From 067f695fffa3eb295afdfaa6f840c47e1866125d Mon Sep 17 00:00:00 2001 From: ledgerwatch Date: Sat, 20 May 2023 14:48:16 +0100 Subject: [PATCH] [devnet tool] Separate logging (#7553) Co-authored-by: Alex Sharp --- cmd/devnet/main.go | 4 +- cmd/devnet/services/event.go | 13 +++-- cmd/erigon-el/backend/backend.go | 2 +- cmd/p2psim/main.go | 4 +- cmd/rpcdaemon/cli/config.go | 6 +- cmd/rpcdaemon/rpcdaemontest/test_util.go | 2 +- cmd/txpool/main.go | 2 +- eth/backend.go | 2 +- eth/stagedsync/default_stages.go | 2 +- eth/stagedsync/stage_cumulative_index.go | 6 +- ethdb/privateapi/mining.go | 32 ++++++----- node/rpcstack.go | 22 ++++---- p2p/simulations/http.go | 9 ++- rpc/client.go | 31 +++++----- rpc/client_example_test.go | 4 +- rpc/client_test.go | 72 ++++++++++++++---------- rpc/handler.go | 26 ++++----- rpc/http.go | 9 +-- rpc/http_test.go | 7 ++- rpc/inproc.go | 6 +- rpc/server.go | 12 ++-- rpc/server_test.go | 15 +++-- rpc/stdio.go | 10 ++-- rpc/subscription_test.go | 8 ++- rpc/testservice_test.go | 10 ++-- rpc/websocket.go | 8 +-- rpc/websocket_test.go | 17 +++--- 27 files changed, 194 insertions(+), 147 deletions(-) diff --git a/cmd/devnet/main.go b/cmd/devnet/main.go index 07e4350ac..1ac5d5f50 100644 --- a/cmd/devnet/main.go +++ b/cmd/devnet/main.go @@ -65,7 +65,7 @@ func action(ctx *cli.Context) error { logger := logging.SetupLoggerCtx("devnet", ctx, false /* rootLogger */) // Make root logger fail - //log.Root().SetHandler(PanicHandler{}) + log.Root().SetHandler(PanicHandler{}) // clear all the dev files if err := devnetutils.ClearDevDB(dataDir, logger); err != nil { @@ -84,7 +84,7 @@ func action(ctx *cli.Context) error { time.Sleep(time.Second * 10) // start up the subscription services for the different sub methods - services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads}) + services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads}, logger) // execute all rpc methods amongst the two nodes commands.ExecuteAllMethods() diff --git a/cmd/devnet/services/event.go b/cmd/devnet/services/event.go index a8f2c6b22..51bc62347 100644 --- a/cmd/devnet/services/event.go +++ b/cmd/devnet/services/event.go @@ -9,11 +9,12 @@ import ( "github.com/ledgerwatch/erigon/cmd/devnet/devnetutils" "github.com/ledgerwatch/erigon/cmd/devnet/models" "github.com/ledgerwatch/erigon/rpc" + "github.com/ledgerwatch/log/v3" ) -func InitSubscriptions(methods []models.SubMethod) { +func InitSubscriptions(methods []models.SubMethod, logger log.Logger) { fmt.Printf("CONNECTING TO WEBSOCKETS AND SUBSCRIBING TO METHODS...\n") - if err := subscribeAll(methods); err != nil { + if err := subscribeAll(methods, logger); err != nil { fmt.Printf("failed to subscribe to all methods: %v\n", err) return } @@ -65,8 +66,8 @@ func subscribe(client *rpc.Client, method models.SubMethod, args ...interface{}) return methodSub, nil } -func subscribeToMethod(method models.SubMethod) (*models.MethodSubscription, error) { - client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "") +func subscribeToMethod(method models.SubMethod, logger log.Logger) (*models.MethodSubscription, error) { + client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "", logger) if err != nil { return nil, fmt.Errorf("failed to dial websocket: %v", err) } @@ -132,11 +133,11 @@ func UnsubscribeAll() { } // subscribeAll subscribes to the range of methods provided -func subscribeAll(methods []models.SubMethod) error { +func subscribeAll(methods []models.SubMethod, logger log.Logger) error { m := make(map[models.SubMethod]*models.MethodSubscription) models.MethodSubscriptionMap = &m for _, method := range methods { - sub, err := subscribeToMethod(method) + sub, err := subscribeToMethod(method, logger) if err != nil { return err } diff --git a/cmd/erigon-el/backend/backend.go b/cmd/erigon-el/backend/backend.go index 1ecaba253..53dbcb853 100644 --- a/cmd/erigon-el/backend/backend.go +++ b/cmd/erigon-el/backend/backend.go @@ -510,7 +510,7 @@ func NewBackend(stack *node.Node, config *ethconfig.Config, logger log.Logger) ( // Initialize ethbackend ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events, backend.blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger) - miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi) + miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger) var creds credentials.TransportCredentials if stack.Config().PrivateApiAddr != "" { diff --git a/cmd/p2psim/main.go b/cmd/p2psim/main.go index b2e83339e..6eb48e923 100644 --- a/cmd/p2psim/main.go +++ b/cmd/p2psim/main.go @@ -45,6 +45,7 @@ import ( "text/tabwriter" "github.com/ledgerwatch/erigon/cmd/utils" + "github.com/ledgerwatch/erigon/turbo/logging" "github.com/urfave/cli/v2" "github.com/ledgerwatch/erigon/crypto" @@ -69,7 +70,8 @@ func main() { }, } app.Before = func(ctx *cli.Context) error { - client = simulations.NewClient(ctx.String("api")) + logger := logging.SetupLoggerCtx("p2psim", ctx, false /* rootLogger */) + client = simulations.NewClient(ctx.String("api"), logger) return nil } app.Commands = []*cli.Command{ diff --git a/cmd/rpcdaemon/cli/config.go b/cmd/rpcdaemon/cli/config.go index 5463057c4..57ce0a20f 100644 --- a/cmd/rpcdaemon/cli/config.go +++ b/cmd/rpcdaemon/cli/config.go @@ -499,7 +499,7 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp httpEndpoint := fmt.Sprintf("%s:%d", cfg.HttpListenAddress, cfg.HttpPort) logger.Trace("TraceRequests = %t\n", cfg.TraceRequests) - srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable) + srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger) allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath) if err != nil { @@ -619,7 +619,7 @@ type engineInfo struct { func startAuthenticatedRpcServer(cfg httpcfg.HttpCfg, rpcAPI []rpc.API, logger log.Logger) (*engineInfo, error) { logger.Trace("TraceRequests = %t\n", cfg.TraceRequests) - srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable) + srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger) engineListener, engineSrv, engineHttpEndpoint, err := createEngineListener(cfg, rpcAPI, logger) if err != nil { @@ -709,7 +709,7 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API, logger log.Logger) (*http.Server, *rpc.Server, string, error) { engineHttpEndpoint := fmt.Sprintf("%s:%d", cfg.AuthRpcHTTPListenAddress, cfg.AuthRpcPort) - engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true) + engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true, logger) if err := node.RegisterApisFromWhitelist(engineApi, nil, engineSrv, true, logger); err != nil { return nil, nil, "", fmt.Errorf("could not start register RPC engine api: %w", err) diff --git a/cmd/rpcdaemon/rpcdaemontest/test_util.go b/cmd/rpcdaemon/rpcdaemontest/test_util.go index 182af9fb5..9999bcce8 100644 --- a/cmd/rpcdaemon/rpcdaemontest/test_util.go +++ b/cmd/rpcdaemon/rpcdaemontest/test_util.go @@ -298,7 +298,7 @@ func CreateTestGrpcConn(t *testing.T, m *stages.MockSentry) (context.Context, *g remote.RegisterETHBACKENDServer(server, privateapi.NewEthBackendServer(ctx, nil, m.DB, m.Notifications.Events, snapshotsync.NewBlockReaderWithSnapshots(m.BlockSnapshots, m.TransactionsV3), nil, nil, nil, false, log.New())) txpool.RegisterTxpoolServer(server, m.TxPoolGrpcServer) - txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi)) + txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi, m.Log)) listener := bufconn.Listen(1024 * 1024) dialer := func() func(context.Context, string) (net.Conn, error) { diff --git a/cmd/txpool/main.go b/cmd/txpool/main.go index 1d5eda607..7b806810c 100644 --- a/cmd/txpool/main.go +++ b/cmd/txpool/main.go @@ -172,7 +172,7 @@ func doTxpool(ctx context.Context, logger log.Logger) error { ethashApi = casted.APIs(nil)[1].Service.(*ethash.API) } */ - miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil) + miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil, logger) grpcServer, err := txpool.StartGrpc(txpoolGrpcServer, miningGrpcServer, txpoolApiAddr, nil, logger) if err != nil { diff --git a/eth/backend.go b/eth/backend.go index d9149a998..325edf08a 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -546,7 +546,7 @@ func New(stack *node.Node, config *ethconfig.Config, logger log.Logger) (*Ethere // Initialize ethbackend ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events, blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger) - miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi) + miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger) var creds credentials.TransportCredentials if stack.Config().PrivateApiAddr != "" { diff --git a/eth/stagedsync/default_stages.go b/eth/stagedsync/default_stages.go index 064b0418c..7ee8152e1 100644 --- a/eth/stagedsync/default_stages.go +++ b/eth/stagedsync/default_stages.go @@ -47,7 +47,7 @@ func DefaultStages(ctx context.Context, snapshots SnapshotsCfg, headers HeadersC ID: stages.CumulativeIndex, Description: "Write Cumulative Index", Forward: func(firstCycle bool, badBlockUnwind bool, s *StageState, u Unwinder, tx kv.RwTx, logger log.Logger) error { - return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx) + return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx, logger) }, Unwind: func(firstCycle bool, u *UnwindState, s *StageState, tx kv.RwTx, logger log.Logger) error { return UnwindCumulativeIndexStage(u, cumulativeIndex, tx, ctx) diff --git a/eth/stagedsync/stage_cumulative_index.go b/eth/stagedsync/stage_cumulative_index.go index 23fbc6aa7..82a0e05e0 100644 --- a/eth/stagedsync/stage_cumulative_index.go +++ b/eth/stagedsync/stage_cumulative_index.go @@ -29,7 +29,7 @@ func StageCumulativeIndexCfg(db kv.RwDB) CumulativeIndexCfg { } } -func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context) error { +func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context, logger log.Logger) error { useExternalTx := tx != nil if !useExternalTx { @@ -105,11 +105,11 @@ func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx case <-ctx.Done(): return ctx.Err() case <-logEvery.C: - log.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()), + logger.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()), "gasUsed", cumulativeGasUsed.String(), "now", currentBlockNumber, "blk/sec", float64(currentBlockNumber-prevProgress)/float64(logInterval/time.Second)) prevProgress = currentBlockNumber default: - log.Trace("RequestQueueTime (header) ticked") + logger.Trace("RequestQueueTime (header) ticked") } // Cleanup timer } diff --git a/ethdb/privateapi/mining.go b/ethdb/privateapi/mining.go index 1a24ed3db..414fb9891 100644 --- a/ethdb/privateapi/mining.go +++ b/ethdb/privateapi/mining.go @@ -30,14 +30,15 @@ type MiningServer struct { minedBlockStreams MinedBlockStreams ethash *ethash.API isMining IsMining + logger log.Logger } type IsMining interface { IsMining() bool } -func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API) *MiningServer { - return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi} +func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API, logger log.Logger) *MiningServer { + return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi, logger: logger} } func (s *MiningServer) Version(context.Context, *emptypb.Empty) (*types2.VersionReply, error) { @@ -100,7 +101,7 @@ func (s *MiningServer) BroadcastPendingLogs(l types.Logs) error { return err } reply := &proto_txpool.OnPendingBlockReply{RplBlock: b} - s.pendingBlockStreams.Broadcast(reply) + s.pendingBlockStreams.Broadcast(reply, s.logger) return nil } @@ -121,7 +122,7 @@ func (s *MiningServer) BroadcastPendingBlock(block *types.Block) error { return err } reply := &proto_txpool.OnPendingBlockReply{RplBlock: buf.Bytes()} - s.pendingBlockStreams.Broadcast(reply) + s.pendingBlockStreams.Broadcast(reply, s.logger) return nil } @@ -133,21 +134,22 @@ func (s *MiningServer) OnMinedBlock(req *proto_txpool.OnMinedBlockRequest, reply } func (s *MiningServer) BroadcastMinedBlock(block *types.Block) error { - log.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed()) + s.logger.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed()) var buf bytes.Buffer if err := block.EncodeRLP(&buf); err != nil { return err } reply := &proto_txpool.OnMinedBlockReply{RplBlock: buf.Bytes()} - s.minedBlockStreams.Broadcast(reply) + s.minedBlockStreams.Broadcast(reply, s.logger) return nil } // MinedBlockStreams - it's safe to use this class as non-pointer type MinedBlockStreams struct { - chans map[uint]proto_txpool.Mining_OnMinedBlockServer - id uint - mu sync.Mutex + chans map[uint]proto_txpool.Mining_OnMinedBlockServer + id uint + mu sync.Mutex + logger log.Logger } func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) (remove func()) { @@ -162,13 +164,13 @@ func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) ( return func() { s.remove(id) } } -func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply) { +func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply, logger log.Logger) { s.mu.Lock() defer s.mu.Unlock() for id, stream := range s.chans { err := stream.Send(reply) if err != nil { - log.Trace("failed send to mined block stream", "err", err) + logger.Trace("failed send to mined block stream", "err", err) select { case <-stream.Context().Done(): delete(s.chans, id) @@ -207,13 +209,13 @@ func (s *PendingBlockStreams) Add(stream proto_txpool.Mining_OnPendingBlockServe return func() { s.remove(id) } } -func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply) { +func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply, logger log.Logger) { s.mu.Lock() defer s.mu.Unlock() for id, stream := range s.chans { err := stream.Send(reply) if err != nil { - log.Trace("failed send to mined block stream", "err", err) + logger.Trace("failed send to mined block stream", "err", err) select { case <-stream.Context().Done(): delete(s.chans, id) @@ -252,13 +254,13 @@ func (s *PendingLogsStreams) Add(stream proto_txpool.Mining_OnPendingLogsServer) return func() { s.remove(id) } } -func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply) { +func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply, logger log.Logger) { s.mu.Lock() defer s.mu.Unlock() for id, stream := range s.chans { err := stream.Send(reply) if err != nil { - log.Trace("failed send to mined block stream", "err", err) + logger.Trace("failed send to mined block stream", "err", err) select { case <-stream.Context().Done(): delete(s.chans, id) diff --git a/node/rpcstack.go b/node/rpcstack.go index c5f3f99ac..3dbd4a306 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -56,7 +56,7 @@ type rpcHandler struct { } type httpServer struct { - log log.Logger + logger log.Logger timeouts rpccfg.HTTPTimeouts mux http.ServeMux // registered handlers go here @@ -82,7 +82,7 @@ type httpServer struct { } func newHTTPServer(logger log.Logger, timeouts rpccfg.HTTPTimeouts) *httpServer { - h := &httpServer{log: logger, timeouts: timeouts, handlerNames: make(map[string]string)} + h := &httpServer{logger: logger, timeouts: timeouts, handlerNames: make(map[string]string)} h.httpHandler.Store((*rpcHandler)(nil)) h.wsHandler.Store((*rpcHandler)(nil)) @@ -150,14 +150,14 @@ func (h *httpServer) start() error { if h.wsConfig.prefix != "" { url += h.wsConfig.prefix } - h.log.Info("WebSocket enabled", "url", url) + h.logger.Info("WebSocket enabled", "url", url) } // if server is websocket only, return after logging if !h.rpcAllowed() { return nil } // Log http endpoint. - h.log.Info("HTTP server started", + h.logger.Info("HTTP server started", "endpoint", listener.Addr(), "prefix", h.httpConfig.prefix, "cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","), @@ -176,7 +176,7 @@ func (h *httpServer) start() error { for _, path := range paths { name := h.handlerNames[path] if !logged[name] { - h.log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path) + h.logger.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path) logged[name] = true } } @@ -248,7 +248,7 @@ func (h *httpServer) doStop() { } h.server.Shutdown(context.Background()) //nolint:errcheck h.listener.Close() - h.log.Info("HTTP server stopped", "endpoint", h.listener.Addr()) + h.logger.Info("HTTP server stopped", "endpoint", h.listener.Addr()) // Clear out everything to allow re-configuring it later. h.host, h.port, h.endpoint = "", 0, "" @@ -265,9 +265,9 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig, allowList rpc. } // Create RPC server and handler. - srv := rpc.NewServer(50, false /* traceRequests */, true) + srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger) srv.SetAllowList(allowList) - if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil { + if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil { return err } h.httpConfig = config @@ -298,14 +298,14 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.All } // Create RPC server and handler. - srv := rpc.NewServer(50, false /* traceRequests */, true) + srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger) srv.SetAllowList(allowList) - if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil { + if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil { return err } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: srv.WebsocketHandler(config.Origins, nil, false, h.log), + Handler: srv.WebsocketHandler(config.Origins, nil, false, h.logger), server: srv, }) return nil diff --git a/p2p/simulations/http.go b/p2p/simulations/http.go index 25321e0f9..c72101a4c 100644 --- a/p2p/simulations/http.go +++ b/p2p/simulations/http.go @@ -36,11 +36,12 @@ import ( "github.com/gorilla/websocket" "github.com/julienschmidt/httprouter" + "github.com/ledgerwatch/log/v3" ) // DefaultClient is the default simulation API client which expects the API // to be running at http://localhost:8888 -var DefaultClient = NewClient("http://localhost:8888") +var DefaultClient = NewClient("http://localhost:8888", log.New()) // Client is a client for the simulation HTTP API which supports creating // and managing simulation networks @@ -48,13 +49,15 @@ type Client struct { URL string client *http.Client + logger log.Logger } // NewClient returns a new simulation API client -func NewClient(url string) *Client { +func NewClient(url string, logger log.Logger) *Client { return &Client{ URL: url, client: http.DefaultClient, + logger: logger, } } @@ -208,7 +211,7 @@ func (c *Client) DisconnectNode(nodeID, peerID string) error { // RPCClient returns an RPC client connected to a node func (c *Client) RPCClient(ctx context.Context, nodeID string) (*rpc.Client, error) { baseURL := strings.Replace(c.URL, "http", "ws", 1) - return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "") + return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "", c.logger) } // Get performs a HTTP GET request decoding the resulting JSON response diff --git a/rpc/client.go b/rpc/client.go index 56262d4dc..e60c99db4 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -99,6 +99,7 @@ type Client struct { reqInit chan *requestOp // register response IDs, takes write lock reqSent chan error // signals write completion, releases write lock reqTimeout chan *requestOp // removes response IDs when call timeout expires + logger log.Logger } type reconnectFunc func(ctx context.Context) (ServerCodec, error) @@ -112,7 +113,7 @@ type clientConn struct { func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.WithValue(context.Background(), clientContextKey{}, c) - handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */) + handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */, c.logger) return &clientConn{conn, handler} } @@ -159,26 +160,26 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro // For websocket connections, the origin is set to the local host name. // // The client reconnects automatically if the connection is lost. -func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) +func Dial(rawurl string, logger log.Logger) (*Client, error) { + return DialContext(context.Background(), rawurl, logger) } // DialContext creates a new RPC client, just like Dial. // // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. -func DialContext(ctx context.Context, rawurl string) (*Client, error) { +func DialContext(ctx context.Context, rawurl string, logger log.Logger) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } switch u.Scheme { case "http", "https": - return DialHTTP(rawurl) + return DialHTTP(rawurl, logger) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + return DialWebsocket(ctx, rawurl, "", logger) case "stdio": - return DialStdIO(ctx) + return DialStdIO(ctx, logger) default: return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } @@ -186,22 +187,23 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) { // Client retrieves the client from the context, if any. This can be used to perform // 'reverse calls' in a handler method. -func ClientFromContext(ctx context.Context) (*Client, bool) { +func ClientFromContext(ctx context.Context, logger log.Logger) (*Client, bool) { client, ok := ctx.Value(clientContextKey{}).(*Client) + client.logger = logger return client, ok } -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { +func newClient(initctx context.Context, connect reconnectFunc, logger log.Logger) (*Client, error) { conn, err := connect(initctx) if err != nil { return nil, err } - c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) + c := initClient(conn, randomIDGenerator(), new(serviceRegistry), logger) c.reconnectFunc = connect return c, nil } -func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { +func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry, logger log.Logger) *Client { _, isHTTP := conn.(*httpConn) c := &Client{ idgen: idgen, @@ -217,6 +219,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C reqInit: make(chan *requestOp), reqSent: make(chan error, 1), reqTimeout: make(chan *requestOp), + logger: logger, } if !isHTTP { go c.dispatch(conn) @@ -515,7 +518,7 @@ func (c *Client) reconnect(ctx context.Context) error { } newconn, err := c.reconnectFunc(ctx) if err != nil { - log.Trace("RPC client reconnect failed", "err", err) + c.logger.Trace("RPC client reconnect failed", "err", err) return err } select { @@ -564,13 +567,13 @@ func (c *Client) dispatch(codec ServerCodec) { } case err := <-c.readErr: - conn.handler.log.Trace("RPC connection read error", "err", err) + conn.handler.logger.Trace("RPC connection read error", "err", err) conn.close(err, lastOp) reading = false // Reconnect: case newcodec := <-c.reconnected: - log.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) + c.logger.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) if reading { // Wait for the previous read loop to exit. This is a rare case which // happens if this loop isn't notified in time after the connection breaks. diff --git a/rpc/client_example_test.go b/rpc/client_example_test.go index 28126b66d..203f9786a 100644 --- a/rpc/client_example_test.go +++ b/rpc/client_example_test.go @@ -23,6 +23,7 @@ import ( "github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/rpc" + "github.com/ledgerwatch/log/v3" ) // In this example, our client wishes to track the latest 'block number' @@ -40,7 +41,8 @@ type Block struct { func ExampleClientSubscription() { // Connect the client. - client, _ := rpc.Dial("ws://127.0.0.1:8545") + logger := log.New() + client, _ := rpc.Dial("ws://127.0.0.1:8545", logger) subch := make(chan Block) // Ensure that subch receives the latest block. diff --git a/rpc/client_test.go b/rpc/client_test.go index 46e95fcb4..b504a9c32 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -35,9 +35,10 @@ import ( ) func TestClientRequest(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() var resp echoResult @@ -50,9 +51,10 @@ func TestClientRequest(t *testing.T) { } func TestClientResponseType(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { @@ -68,9 +70,10 @@ func TestClientResponseType(t *testing.T) { // This test checks that server-returned errors with code and data come out of Client.Call. func TestClientErrorData(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() var resp interface{} @@ -94,9 +97,10 @@ func TestClientErrorData(t *testing.T) { } func TestClientBatchRequest(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() batch := []BatchElem{ @@ -143,9 +147,10 @@ func TestClientBatchRequest(t *testing.T) { } func TestClientNotify(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { @@ -154,18 +159,18 @@ func TestClientNotify(t *testing.T) { } // func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } -func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } -func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } +func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t, log.New()) } +func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t, log.New()) } // This test checks that requests made through CallContext can be canceled by canceling // the context. -func testClientCancel(transport string, t *testing.T) { +func testClientCancel(transport string, t *testing.T, logger log.Logger) { // These tests take a lot of time, run them all at once. // You probably want to run with -parallel 1 or comment out // the call to t.Parallel if you enable the logging. t.Parallel() - server := newTestServer() + server := newTestServer(logger) defer server.Stop() // What we want to achieve is that the context gets canceled @@ -243,9 +248,10 @@ func testClientCancel(transport string, t *testing.T) { } func TestClientSubscribeInvalidArg(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() check := func(shouldPanic bool, arg interface{}) { @@ -271,9 +277,10 @@ func TestClientSubscribeInvalidArg(t *testing.T) { } func TestClientSubscribe(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() nc := make(chan int) @@ -303,7 +310,8 @@ func TestClientSubscribe(t *testing.T) { // In this test, the connection drops while Subscribe is waiting for a response. func TestClientSubscribeClose(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) service := ¬ificationTestService{ gotHangSubscriptionReq: make(chan struct{}), unblockHangSubscription: make(chan struct{}), @@ -313,7 +321,7 @@ func TestClientSubscribeClose(t *testing.T) { } defer server.Stop() - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() var ( @@ -347,11 +355,12 @@ func TestClientSubscribeClose(t *testing.T) { // This test reproduces https://github.com/ledgerwatch/erigon/issues/17837 where the // client hangs during shutdown when Unsubscribe races with Client.Close. func TestClientCloseUnsubscribeRace(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() for i := 0; i < 20; i++ { - client := DialInProc(server) + client := DialInProc(server, logger) nc := make(chan int) sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1) if err != nil { @@ -370,11 +379,12 @@ func TestClientCloseUnsubscribeRace(t *testing.T) { // This test checks that Client doesn't lock up when a single subscriber // doesn't read subscription events. func TestClientNotificationStorm(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() doTest := func(count int, wantError bool) { - client := DialInProc(server) + client := DialInProc(server, logger) defer client.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -422,8 +432,9 @@ func TestClientNotificationStorm(t *testing.T) { } func TestClientSetHeader(t *testing.T) { + logger := log.New() var gotHeader bool - srv := newTestServer() + srv := newTestServer(logger) httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("test") == "ok" { gotHeader = true @@ -433,7 +444,7 @@ func TestClientSetHeader(t *testing.T) { defer httpsrv.Close() defer srv.Stop() - client, err := Dial(httpsrv.URL) + client, err := Dial(httpsrv.URL, logger) if err != nil { t.Fatal(err) } @@ -458,7 +469,8 @@ func TestClientSetHeader(t *testing.T) { } func TestClientHTTP(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() client, hs := httpTestClient(server, "http", nil) @@ -504,7 +516,7 @@ func TestClientHTTP(t *testing.T) { func TestClientReconnect(t *testing.T) { logger := log.New() startServer := func(addr string) (*Server, net.Listener) { - srv := newTestServer() + srv := newTestServer(logger) l, err := net.Listen("tcp", addr) if err != nil { t.Fatal("can't listen:", err) @@ -518,7 +530,7 @@ func TestClientReconnect(t *testing.T) { // Start a server and corresponding client. s1, l1 := startServer("127.0.0.1:0") - client, err := DialContext(ctx, "ws://"+l1.Addr().String()) + client, err := DialContext(ctx, "ws://"+l1.Addr().String(), logger) if err != nil { t.Fatal("can't dial", err) } @@ -588,7 +600,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, } // Connect the client. hs.Start() - client, err := Dial(transport + "://" + hs.Listener.Addr().String()) + client, err := Dial(transport+"://"+hs.Listener.Addr().String(), logger) if err != nil { panic(err) } diff --git a/rpc/handler.go b/rpc/handler.go index abe764833..b1da7fa23 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -61,7 +61,7 @@ type handler struct { rootCtx context.Context // canceled by close() cancelRoot func() // cancel function for rootCtx conn jsonWriter // where responses will be sent - log log.Logger + logger log.Logger allowSubscribe bool allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed @@ -110,7 +110,7 @@ func HandleError(err error, stream *jsoniter.Stream) error { return nil } -func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool) *handler { +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool, logger log.Logger) *handler { rootCtx, cancelRoot := context.WithCancel(connCtx) forbiddenList := newForbiddenList() h := &handler{ @@ -123,7 +123,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * cancelRoot: cancelRoot, allowSubscribe: true, serverSubs: make(map[ID]*Subscription), - log: log.Root(), + logger: logger, allowList: allowList, forbiddenList: forbiddenList, @@ -132,7 +132,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * } if conn.remoteAddr() != "" { - h.log = h.log.New("conn", conn.remoteAddr()) + h.logger = h.logger.New("conn", conn.remoteAddr()) } h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe), "unsubscribe") return h @@ -330,7 +330,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { return false case msg.isResponse(): h.handleResponse(msg) - h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start)) + h.logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start)) return true default: return false @@ -341,7 +341,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { var result subscriptionResult if err := json.Unmarshal(msg.Params, &result); err != nil { - h.log.Trace("Dropping invalid subscription message") + h.logger.Trace("Dropping invalid subscription message") return } if h.clientSubs[result.ID] != nil { @@ -353,7 +353,7 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { func (h *handler) handleResponse(msg *jsonrpcMessage) { op := h.respWait[string(msg.ID)] if op == nil { - h.log.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + h.logger.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID}) return } delete(h.respWait, string(msg.ID)) @@ -383,26 +383,26 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage, stream *json case msg.isNotification(): h.handleCall(ctx, msg, stream) if h.traceRequests { - h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params)) + h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params)) } else { - h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params)) + h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params)) } return nil case msg.isCall(): resp := h.handleCall(ctx, msg, stream) if resp != nil && resp.Error != nil { if resp.Error.Data != nil { - h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start), + h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start), "err", resp.Error.Message, "errdata", resp.Error.Data) } else { - h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start), + h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start), "err", resp.Error.Message) } } if h.traceRequests { - h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params)) + h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params)) } else { - h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params)) + h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params)) } return resp case msg.hasValidID(): diff --git a/rpc/http.go b/rpc/http.go index a6163994d..df33c5d2e 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -33,6 +33,7 @@ import ( "github.com/golang-jwt/jwt/v4" jsoniter "github.com/json-iterator/go" + "github.com/ledgerwatch/log/v3" ) const ( @@ -77,7 +78,7 @@ func (hc *httpConn) closed() <-chan interface{} { // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // using the provided HTTP Client. -func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { +func DialHTTPWithClient(endpoint string, client *http.Client, logger log.Logger) (*Client, error) { // Sanity check URL so we don't end up with a client that will fail every request. _, err := url.Parse(endpoint) if err != nil { @@ -96,12 +97,12 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { closeCh: make(chan interface{}), } return hc, nil - }) + }, logger) } // DialHTTP creates a new RPC client that connects to an RPC server over HTTP. -func DialHTTP(endpoint string) (*Client, error) { - return DialHTTPWithClient(endpoint, new(http.Client)) +func DialHTTP(endpoint string, logger log.Logger) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client), logger) } func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { diff --git a/rpc/http_test.go b/rpc/http_test.go index 728f90cbf..be1b43962 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -21,6 +21,8 @@ import ( "net/http/httptest" "strings" "testing" + + "github.com/ledgerwatch/log/v3" ) func confirmStatusCode(t *testing.T, got, want int) { @@ -102,9 +104,10 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) { // This checks that maxRequestContentLength is not applied to the response of a request. func TestHTTPRespBodyUnlimited(t *testing.T) { + logger := log.New() const respLength = maxRequestContentLength * 3 - s := NewServer(50, false /* traceRequests */, true) + s := NewServer(50, false /* traceRequests */, true, logger) defer s.Stop() if err := s.RegisterName("test", largeRespService{respLength}); err != nil { t.Fatal(err) @@ -112,7 +115,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) { ts := httptest.NewServer(s) defer ts.Close() - c, err := DialHTTP(ts.URL) + c, err := DialHTTP(ts.URL, logger) if err != nil { t.Fatal(err) } diff --git a/rpc/inproc.go b/rpc/inproc.go index fbe9a40ce..93d727265 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -19,15 +19,17 @@ package rpc import ( "context" "net" + + "github.com/ledgerwatch/log/v3" ) // DialInProc attaches an in-process connection to the given RPC server. -func DialInProc(handler *Server) *Client { +func DialInProc(handler *Server, logger log.Logger) *Client { initctx := context.Background() c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() go handler.ServeCodec(NewCodec(p1), 0) return NewCodec(p2), nil - }) + }, logger) return c } diff --git a/rpc/server.go b/rpc/server.go index 761080f64..b0805702c 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -54,11 +54,13 @@ type Server struct { disableStreaming bool traceRequests bool // Whether to print requests at INFO level batchLimit int // Maximum number of requests in a batch + logger log.Logger } // NewServer creates a new server instance with no registered handlers. -func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool) *Server { - server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency, disableStreaming: disableStreaming, traceRequests: traceRequests} +func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool, logger log.Logger) *Server { + server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency, + disableStreaming: disableStreaming, traceRequests: traceRequests, logger: logger} // Register the default service providing meta information about the RPC service such // as the services and methods it offers. rpcService := &RPCService{server: server} @@ -101,7 +103,7 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { s.codecs.Add(codec) defer s.codecs.Remove(codec) - c := initClient(codec, s.idgen, &s.services) + c := initClient(codec, s.idgen, &s.services, s.logger) <-codec.closed() c.Close() } @@ -115,7 +117,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre return } - h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests) + h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests, s.logger) h.allowSubscribe = false defer h.close(io.EOF, nil) @@ -142,7 +144,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre // subscriptions. func (s *Server) Stop() { if atomic.CompareAndSwapInt32(&s.run, 1, 0) { - log.Info("RPC server shutting down") + s.logger.Info("RPC server shutting down") s.codecs.Each(func(c interface{}) bool { c.(ServerCodec).close() return true diff --git a/rpc/server_test.go b/rpc/server_test.go index 3c1e2f95e..b9e0b12e8 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -28,10 +28,13 @@ import ( "strings" "testing" "time" + + "github.com/ledgerwatch/log/v3" ) func TestServerRegisterName(t *testing.T) { - server := NewServer(50, false /* traceRequests */, true) + logger := log.New() + server := NewServer(50, false /* traceRequests */, true, logger) service := new(testService) if err := server.RegisterName("test", service); err != nil { @@ -54,6 +57,7 @@ func TestServerRegisterName(t *testing.T) { } func TestServer(t *testing.T) { + logger := log.New() files, err := os.ReadDir("testdata") if err != nil { t.Fatal("where'd my testdata go?") @@ -65,13 +69,13 @@ func TestServer(t *testing.T) { path := filepath.Join("testdata", f.Name()) name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name())) t.Run(name, func(t *testing.T) { - runTestScript(t, path) + runTestScript(t, path, logger) }) } } -func runTestScript(t *testing.T, file string) { - server := newTestServer() +func runTestScript(t *testing.T, file string, logger log.Logger) { + server := newTestServer(logger) content, err := os.ReadFile(file) if err != nil { t.Fatal(err) @@ -136,7 +140,8 @@ func runTestScript(t *testing.T, file string) { // This test checks that responses are delivered for very short-lived connections that // only carry a single request. func TestServerShortLivedConn(t *testing.T) { - server := newTestServer() + logger := log.New() + server := newTestServer(logger) defer server.Stop() listener, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/rpc/stdio.go b/rpc/stdio.go index be2bab1c9..2129b117a 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -23,21 +23,23 @@ import ( "net" "os" "time" + + "github.com/ledgerwatch/log/v3" ) // DialStdIO creates a client on stdin/stdout. -func DialStdIO(ctx context.Context) (*Client, error) { - return DialIO(ctx, os.Stdin, os.Stdout) +func DialStdIO(ctx context.Context, logger log.Logger) (*Client, error) { + return DialIO(ctx, os.Stdin, os.Stdout, logger) } // DialIO creates a client which uses the given IO channels -func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { +func DialIO(ctx context.Context, in io.Reader, out io.Writer, logger log.Logger) (*Client, error) { return newClient(ctx, func(_ context.Context) (ServerCodec, error) { return NewCodec(stdioConn{ in: in, out: out, }), nil - }) + }, logger) } type stdioConn struct { diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index 874ab977c..64099f953 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -23,6 +23,8 @@ import ( "strings" "testing" "time" + + "github.com/ledgerwatch/log/v3" ) func TestNewID(t *testing.T) { @@ -47,13 +49,14 @@ func TestNewID(t *testing.T) { } func TestSubscriptions(t *testing.T) { + logger := log.New() var ( namespaces = []string{"eth", "bzz"} service = ¬ificationTestService{} subCount = len(namespaces) notificationCount = 3 - server = NewServer(50, false /* traceRequests */, true) + server = NewServer(50, false /* traceRequests */, true, logger) clientConn, serverConn = net.Pipe() out = json.NewEncoder(clientConn) in = json.NewDecoder(clientConn) @@ -125,11 +128,12 @@ func TestSubscriptions(t *testing.T) { // This test checks that unsubscribing works. func TestServerUnsubscribe(t *testing.T) { + logger := log.New() p1, p2 := net.Pipe() defer p2.Close() // Start the server. - server := newTestServer() + server := newTestServer(logger) service := ¬ificationTestService{unsubscribed: make(chan string, 1)} server.RegisterName("nftest2", service) go server.ServeCodec(NewCodec(p1), 0) diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 04d2e792e..8e6b4ffd7 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -23,10 +23,12 @@ import ( "strings" "sync" "time" + + "github.com/ledgerwatch/log/v3" ) -func newTestServer() *Server { - server := NewServer(50, false /* traceRequests */, true) +func newTestServer(logger log.Logger) *Server { + server := NewServer(50, false /* traceRequests */, true, logger) server.idgen = sequentialIDGenerator() if err := server.RegisterName("test", new(testService)); err != nil { panic(err) @@ -111,7 +113,7 @@ func (s *testService) ReturnError() error { } func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) { - c, ok := ClientFromContext(ctx) + c, ok := ClientFromContext(ctx, log.New()) if !ok { return nil, errors.New("no client") } @@ -121,7 +123,7 @@ func (s *testService) CallMeBack(ctx context.Context, method string, args []inte } func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error { - c, ok := ClientFromContext(ctx) + c, ok := ClientFromContext(ctx, log.New()) if !ok { return errors.New("no client") } diff --git a/rpc/websocket.go b/rpc/websocket.go index b86a7cd71..00706be29 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -186,7 +186,7 @@ func parseOriginURL(origin string) (string, string, string, error) { // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. -func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { +func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer, logger log.Logger) (*Client, error) { endpoint, header, err := wsClientHeaders(endpoint, origin) if err != nil { return nil, err @@ -202,7 +202,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale return nil, hErr } return newWebsocketCodec(conn), nil - }) + }, logger) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -210,13 +210,13 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +func DialWebsocket(ctx context.Context, endpoint, origin string, logger log.Logger) (*Client, error) { dialer := websocket.Dialer{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, WriteBufferPool: wsBufferPool, } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + return DialWebsocketWithDialer(ctx, endpoint, origin, dialer, logger) } func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 178ba9560..3a3fce2e8 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -54,14 +54,14 @@ func TestWebsocketOriginCheck(t *testing.T) { logger := log.New() var ( - srv = newTestServer() + srv = newTestServer(logger) httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, nil, false, logger)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() defer httpsrv.Close() - client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") + client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com", logger) if err == nil { client.Close() t.Fatal("no error for wrong origin") @@ -72,7 +72,7 @@ func TestWebsocketOriginCheck(t *testing.T) { } // Connections without origin header should work. - client, err = DialWebsocket(context.Background(), wsURL, "") + client, err = DialWebsocket(context.Background(), wsURL, "", logger) if err != nil { t.Fatal("error for empty origin") } @@ -85,14 +85,14 @@ func TestWebsocketLargeCall(t *testing.T) { logger := log.New() var ( - srv = newTestServer() + srv = newTestServer(logger) httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, nil, false, logger)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() defer httpsrv.Close() - client, clientErr := DialWebsocket(context.Background(), wsURL, "") + client, clientErr := DialWebsocket(context.Background(), wsURL, "", logger) if clientErr != nil { t.Fatalf("can't dial: %v", clientErr) } @@ -122,6 +122,7 @@ func TestClientWebsocketPing(t *testing.T) { } t.Parallel() + logger := log.New() var ( sendPing = make(chan struct{}) @@ -131,7 +132,7 @@ func TestClientWebsocketPing(t *testing.T) { defer cancel() defer server.Shutdown(ctx) - client, err := DialContext(ctx, "ws://"+server.Addr) + client, err := DialContext(ctx, "ws://"+server.Addr, logger) if err != nil { t.Fatalf("client dial error: %v", err) } @@ -167,7 +168,7 @@ func TestClientWebsocketPing(t *testing.T) { func TestClientWebsocketLargeMessage(t *testing.T) { logger := log.New() var ( - srv = NewServer(50, false /* traceRequests */, true) + srv = NewServer(50, false /* traceRequests */, true, logger) httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, nil, false, logger)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) @@ -179,7 +180,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) { t.Fatal(err) } - c, err := DialWebsocket(context.Background(), wsURL, "") + c, err := DialWebsocket(context.Background(), wsURL, "", logger) if err != nil { t.Fatal(err) }