mirror of
https://gitlab.com/pulsechaincom/erigon-pulse.git
synced 2024-12-21 19:20:39 +00:00
[devnet tool] Separate logging (#7553)
Co-authored-by: Alex Sharp <alexsharp@Alexs-MacBook-Pro-2.local>
This commit is contained in:
parent
2a872b4d54
commit
067f695fff
@ -65,7 +65,7 @@ func action(ctx *cli.Context) error {
|
||||
logger := logging.SetupLoggerCtx("devnet", ctx, false /* rootLogger */)
|
||||
|
||||
// Make root logger fail
|
||||
//log.Root().SetHandler(PanicHandler{})
|
||||
log.Root().SetHandler(PanicHandler{})
|
||||
|
||||
// clear all the dev files
|
||||
if err := devnetutils.ClearDevDB(dataDir, logger); err != nil {
|
||||
@ -84,7 +84,7 @@ func action(ctx *cli.Context) error {
|
||||
time.Sleep(time.Second * 10)
|
||||
|
||||
// start up the subscription services for the different sub methods
|
||||
services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads})
|
||||
services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads}, logger)
|
||||
|
||||
// execute all rpc methods amongst the two nodes
|
||||
commands.ExecuteAllMethods()
|
||||
|
@ -9,11 +9,12 @@ import (
|
||||
"github.com/ledgerwatch/erigon/cmd/devnet/devnetutils"
|
||||
"github.com/ledgerwatch/erigon/cmd/devnet/models"
|
||||
"github.com/ledgerwatch/erigon/rpc"
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
func InitSubscriptions(methods []models.SubMethod) {
|
||||
func InitSubscriptions(methods []models.SubMethod, logger log.Logger) {
|
||||
fmt.Printf("CONNECTING TO WEBSOCKETS AND SUBSCRIBING TO METHODS...\n")
|
||||
if err := subscribeAll(methods); err != nil {
|
||||
if err := subscribeAll(methods, logger); err != nil {
|
||||
fmt.Printf("failed to subscribe to all methods: %v\n", err)
|
||||
return
|
||||
}
|
||||
@ -65,8 +66,8 @@ func subscribe(client *rpc.Client, method models.SubMethod, args ...interface{})
|
||||
return methodSub, nil
|
||||
}
|
||||
|
||||
func subscribeToMethod(method models.SubMethod) (*models.MethodSubscription, error) {
|
||||
client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "")
|
||||
func subscribeToMethod(method models.SubMethod, logger log.Logger) (*models.MethodSubscription, error) {
|
||||
client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "", logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial websocket: %v", err)
|
||||
}
|
||||
@ -132,11 +133,11 @@ func UnsubscribeAll() {
|
||||
}
|
||||
|
||||
// subscribeAll subscribes to the range of methods provided
|
||||
func subscribeAll(methods []models.SubMethod) error {
|
||||
func subscribeAll(methods []models.SubMethod, logger log.Logger) error {
|
||||
m := make(map[models.SubMethod]*models.MethodSubscription)
|
||||
models.MethodSubscriptionMap = &m
|
||||
for _, method := range methods {
|
||||
sub, err := subscribeToMethod(method)
|
||||
sub, err := subscribeToMethod(method, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -510,7 +510,7 @@ func NewBackend(stack *node.Node, config *ethconfig.Config, logger log.Logger) (
|
||||
// Initialize ethbackend
|
||||
ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events,
|
||||
backend.blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger)
|
||||
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi)
|
||||
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger)
|
||||
|
||||
var creds credentials.TransportCredentials
|
||||
if stack.Config().PrivateApiAddr != "" {
|
||||
|
@ -45,6 +45,7 @@ import (
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/ledgerwatch/erigon/cmd/utils"
|
||||
"github.com/ledgerwatch/erigon/turbo/logging"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/ledgerwatch/erigon/crypto"
|
||||
@ -69,7 +70,8 @@ func main() {
|
||||
},
|
||||
}
|
||||
app.Before = func(ctx *cli.Context) error {
|
||||
client = simulations.NewClient(ctx.String("api"))
|
||||
logger := logging.SetupLoggerCtx("p2psim", ctx, false /* rootLogger */)
|
||||
client = simulations.NewClient(ctx.String("api"), logger)
|
||||
return nil
|
||||
}
|
||||
app.Commands = []*cli.Command{
|
||||
|
@ -499,7 +499,7 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp
|
||||
httpEndpoint := fmt.Sprintf("%s:%d", cfg.HttpListenAddress, cfg.HttpPort)
|
||||
|
||||
logger.Trace("TraceRequests = %t\n", cfg.TraceRequests)
|
||||
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable)
|
||||
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger)
|
||||
|
||||
allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath)
|
||||
if err != nil {
|
||||
@ -619,7 +619,7 @@ type engineInfo struct {
|
||||
|
||||
func startAuthenticatedRpcServer(cfg httpcfg.HttpCfg, rpcAPI []rpc.API, logger log.Logger) (*engineInfo, error) {
|
||||
logger.Trace("TraceRequests = %t\n", cfg.TraceRequests)
|
||||
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable)
|
||||
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger)
|
||||
|
||||
engineListener, engineSrv, engineHttpEndpoint, err := createEngineListener(cfg, rpcAPI, logger)
|
||||
if err != nil {
|
||||
@ -709,7 +709,7 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand
|
||||
func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API, logger log.Logger) (*http.Server, *rpc.Server, string, error) {
|
||||
engineHttpEndpoint := fmt.Sprintf("%s:%d", cfg.AuthRpcHTTPListenAddress, cfg.AuthRpcPort)
|
||||
|
||||
engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true)
|
||||
engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true, logger)
|
||||
|
||||
if err := node.RegisterApisFromWhitelist(engineApi, nil, engineSrv, true, logger); err != nil {
|
||||
return nil, nil, "", fmt.Errorf("could not start register RPC engine api: %w", err)
|
||||
|
@ -298,7 +298,7 @@ func CreateTestGrpcConn(t *testing.T, m *stages.MockSentry) (context.Context, *g
|
||||
remote.RegisterETHBACKENDServer(server, privateapi.NewEthBackendServer(ctx, nil, m.DB, m.Notifications.Events,
|
||||
snapshotsync.NewBlockReaderWithSnapshots(m.BlockSnapshots, m.TransactionsV3), nil, nil, nil, false, log.New()))
|
||||
txpool.RegisterTxpoolServer(server, m.TxPoolGrpcServer)
|
||||
txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi))
|
||||
txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi, m.Log))
|
||||
listener := bufconn.Listen(1024 * 1024)
|
||||
|
||||
dialer := func() func(context.Context, string) (net.Conn, error) {
|
||||
|
@ -172,7 +172,7 @@ func doTxpool(ctx context.Context, logger log.Logger) error {
|
||||
ethashApi = casted.APIs(nil)[1].Service.(*ethash.API)
|
||||
}
|
||||
*/
|
||||
miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil)
|
||||
miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil, logger)
|
||||
|
||||
grpcServer, err := txpool.StartGrpc(txpoolGrpcServer, miningGrpcServer, txpoolApiAddr, nil, logger)
|
||||
if err != nil {
|
||||
|
@ -546,7 +546,7 @@ func New(stack *node.Node, config *ethconfig.Config, logger log.Logger) (*Ethere
|
||||
// Initialize ethbackend
|
||||
ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events,
|
||||
blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger)
|
||||
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi)
|
||||
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger)
|
||||
|
||||
var creds credentials.TransportCredentials
|
||||
if stack.Config().PrivateApiAddr != "" {
|
||||
|
@ -47,7 +47,7 @@ func DefaultStages(ctx context.Context, snapshots SnapshotsCfg, headers HeadersC
|
||||
ID: stages.CumulativeIndex,
|
||||
Description: "Write Cumulative Index",
|
||||
Forward: func(firstCycle bool, badBlockUnwind bool, s *StageState, u Unwinder, tx kv.RwTx, logger log.Logger) error {
|
||||
return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx)
|
||||
return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx, logger)
|
||||
},
|
||||
Unwind: func(firstCycle bool, u *UnwindState, s *StageState, tx kv.RwTx, logger log.Logger) error {
|
||||
return UnwindCumulativeIndexStage(u, cumulativeIndex, tx, ctx)
|
||||
|
@ -29,7 +29,7 @@ func StageCumulativeIndexCfg(db kv.RwDB) CumulativeIndexCfg {
|
||||
}
|
||||
}
|
||||
|
||||
func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context) error {
|
||||
func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context, logger log.Logger) error {
|
||||
useExternalTx := tx != nil
|
||||
|
||||
if !useExternalTx {
|
||||
@ -105,11 +105,11 @@ func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-logEvery.C:
|
||||
log.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()),
|
||||
logger.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()),
|
||||
"gasUsed", cumulativeGasUsed.String(), "now", currentBlockNumber, "blk/sec", float64(currentBlockNumber-prevProgress)/float64(logInterval/time.Second))
|
||||
prevProgress = currentBlockNumber
|
||||
default:
|
||||
log.Trace("RequestQueueTime (header) ticked")
|
||||
logger.Trace("RequestQueueTime (header) ticked")
|
||||
}
|
||||
// Cleanup timer
|
||||
}
|
||||
|
@ -30,14 +30,15 @@ type MiningServer struct {
|
||||
minedBlockStreams MinedBlockStreams
|
||||
ethash *ethash.API
|
||||
isMining IsMining
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type IsMining interface {
|
||||
IsMining() bool
|
||||
}
|
||||
|
||||
func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API) *MiningServer {
|
||||
return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi}
|
||||
func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API, logger log.Logger) *MiningServer {
|
||||
return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi, logger: logger}
|
||||
}
|
||||
|
||||
func (s *MiningServer) Version(context.Context, *emptypb.Empty) (*types2.VersionReply, error) {
|
||||
@ -100,7 +101,7 @@ func (s *MiningServer) BroadcastPendingLogs(l types.Logs) error {
|
||||
return err
|
||||
}
|
||||
reply := &proto_txpool.OnPendingBlockReply{RplBlock: b}
|
||||
s.pendingBlockStreams.Broadcast(reply)
|
||||
s.pendingBlockStreams.Broadcast(reply, s.logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -121,7 +122,7 @@ func (s *MiningServer) BroadcastPendingBlock(block *types.Block) error {
|
||||
return err
|
||||
}
|
||||
reply := &proto_txpool.OnPendingBlockReply{RplBlock: buf.Bytes()}
|
||||
s.pendingBlockStreams.Broadcast(reply)
|
||||
s.pendingBlockStreams.Broadcast(reply, s.logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -133,21 +134,22 @@ func (s *MiningServer) OnMinedBlock(req *proto_txpool.OnMinedBlockRequest, reply
|
||||
}
|
||||
|
||||
func (s *MiningServer) BroadcastMinedBlock(block *types.Block) error {
|
||||
log.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed())
|
||||
s.logger.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed())
|
||||
var buf bytes.Buffer
|
||||
if err := block.EncodeRLP(&buf); err != nil {
|
||||
return err
|
||||
}
|
||||
reply := &proto_txpool.OnMinedBlockReply{RplBlock: buf.Bytes()}
|
||||
s.minedBlockStreams.Broadcast(reply)
|
||||
s.minedBlockStreams.Broadcast(reply, s.logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinedBlockStreams - it's safe to use this class as non-pointer
|
||||
type MinedBlockStreams struct {
|
||||
chans map[uint]proto_txpool.Mining_OnMinedBlockServer
|
||||
id uint
|
||||
mu sync.Mutex
|
||||
chans map[uint]proto_txpool.Mining_OnMinedBlockServer
|
||||
id uint
|
||||
mu sync.Mutex
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) (remove func()) {
|
||||
@ -162,13 +164,13 @@ func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) (
|
||||
return func() { s.remove(id) }
|
||||
}
|
||||
|
||||
func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply) {
|
||||
func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply, logger log.Logger) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for id, stream := range s.chans {
|
||||
err := stream.Send(reply)
|
||||
if err != nil {
|
||||
log.Trace("failed send to mined block stream", "err", err)
|
||||
logger.Trace("failed send to mined block stream", "err", err)
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
delete(s.chans, id)
|
||||
@ -207,13 +209,13 @@ func (s *PendingBlockStreams) Add(stream proto_txpool.Mining_OnPendingBlockServe
|
||||
return func() { s.remove(id) }
|
||||
}
|
||||
|
||||
func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply) {
|
||||
func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply, logger log.Logger) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for id, stream := range s.chans {
|
||||
err := stream.Send(reply)
|
||||
if err != nil {
|
||||
log.Trace("failed send to mined block stream", "err", err)
|
||||
logger.Trace("failed send to mined block stream", "err", err)
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
delete(s.chans, id)
|
||||
@ -252,13 +254,13 @@ func (s *PendingLogsStreams) Add(stream proto_txpool.Mining_OnPendingLogsServer)
|
||||
return func() { s.remove(id) }
|
||||
}
|
||||
|
||||
func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply) {
|
||||
func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply, logger log.Logger) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for id, stream := range s.chans {
|
||||
err := stream.Send(reply)
|
||||
if err != nil {
|
||||
log.Trace("failed send to mined block stream", "err", err)
|
||||
logger.Trace("failed send to mined block stream", "err", err)
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
delete(s.chans, id)
|
||||
|
@ -56,7 +56,7 @@ type rpcHandler struct {
|
||||
}
|
||||
|
||||
type httpServer struct {
|
||||
log log.Logger
|
||||
logger log.Logger
|
||||
timeouts rpccfg.HTTPTimeouts
|
||||
mux http.ServeMux // registered handlers go here
|
||||
|
||||
@ -82,7 +82,7 @@ type httpServer struct {
|
||||
}
|
||||
|
||||
func newHTTPServer(logger log.Logger, timeouts rpccfg.HTTPTimeouts) *httpServer {
|
||||
h := &httpServer{log: logger, timeouts: timeouts, handlerNames: make(map[string]string)}
|
||||
h := &httpServer{logger: logger, timeouts: timeouts, handlerNames: make(map[string]string)}
|
||||
|
||||
h.httpHandler.Store((*rpcHandler)(nil))
|
||||
h.wsHandler.Store((*rpcHandler)(nil))
|
||||
@ -150,14 +150,14 @@ func (h *httpServer) start() error {
|
||||
if h.wsConfig.prefix != "" {
|
||||
url += h.wsConfig.prefix
|
||||
}
|
||||
h.log.Info("WebSocket enabled", "url", url)
|
||||
h.logger.Info("WebSocket enabled", "url", url)
|
||||
}
|
||||
// if server is websocket only, return after logging
|
||||
if !h.rpcAllowed() {
|
||||
return nil
|
||||
}
|
||||
// Log http endpoint.
|
||||
h.log.Info("HTTP server started",
|
||||
h.logger.Info("HTTP server started",
|
||||
"endpoint", listener.Addr(),
|
||||
"prefix", h.httpConfig.prefix,
|
||||
"cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","),
|
||||
@ -176,7 +176,7 @@ func (h *httpServer) start() error {
|
||||
for _, path := range paths {
|
||||
name := h.handlerNames[path]
|
||||
if !logged[name] {
|
||||
h.log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
|
||||
h.logger.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
|
||||
logged[name] = true
|
||||
}
|
||||
}
|
||||
@ -248,7 +248,7 @@ func (h *httpServer) doStop() {
|
||||
}
|
||||
h.server.Shutdown(context.Background()) //nolint:errcheck
|
||||
h.listener.Close()
|
||||
h.log.Info("HTTP server stopped", "endpoint", h.listener.Addr())
|
||||
h.logger.Info("HTTP server stopped", "endpoint", h.listener.Addr())
|
||||
|
||||
// Clear out everything to allow re-configuring it later.
|
||||
h.host, h.port, h.endpoint = "", 0, ""
|
||||
@ -265,9 +265,9 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig, allowList rpc.
|
||||
}
|
||||
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer(50, false /* traceRequests */, true)
|
||||
srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger)
|
||||
srv.SetAllowList(allowList)
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil {
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil {
|
||||
return err
|
||||
}
|
||||
h.httpConfig = config
|
||||
@ -298,14 +298,14 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.All
|
||||
}
|
||||
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer(50, false /* traceRequests */, true)
|
||||
srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger)
|
||||
srv.SetAllowList(allowList)
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil {
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil {
|
||||
return err
|
||||
}
|
||||
h.wsConfig = config
|
||||
h.wsHandler.Store(&rpcHandler{
|
||||
Handler: srv.WebsocketHandler(config.Origins, nil, false, h.log),
|
||||
Handler: srv.WebsocketHandler(config.Origins, nil, false, h.logger),
|
||||
server: srv,
|
||||
})
|
||||
return nil
|
||||
|
@ -36,11 +36,12 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
// DefaultClient is the default simulation API client which expects the API
|
||||
// to be running at http://localhost:8888
|
||||
var DefaultClient = NewClient("http://localhost:8888")
|
||||
var DefaultClient = NewClient("http://localhost:8888", log.New())
|
||||
|
||||
// Client is a client for the simulation HTTP API which supports creating
|
||||
// and managing simulation networks
|
||||
@ -48,13 +49,15 @@ type Client struct {
|
||||
URL string
|
||||
|
||||
client *http.Client
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewClient returns a new simulation API client
|
||||
func NewClient(url string) *Client {
|
||||
func NewClient(url string, logger log.Logger) *Client {
|
||||
return &Client{
|
||||
URL: url,
|
||||
client: http.DefaultClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -208,7 +211,7 @@ func (c *Client) DisconnectNode(nodeID, peerID string) error {
|
||||
// RPCClient returns an RPC client connected to a node
|
||||
func (c *Client) RPCClient(ctx context.Context, nodeID string) (*rpc.Client, error) {
|
||||
baseURL := strings.Replace(c.URL, "http", "ws", 1)
|
||||
return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "")
|
||||
return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "", c.logger)
|
||||
}
|
||||
|
||||
// Get performs a HTTP GET request decoding the resulting JSON response
|
||||
|
@ -99,6 +99,7 @@ type Client struct {
|
||||
reqInit chan *requestOp // register response IDs, takes write lock
|
||||
reqSent chan error // signals write completion, releases write lock
|
||||
reqTimeout chan *requestOp // removes response IDs when call timeout expires
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type reconnectFunc func(ctx context.Context) (ServerCodec, error)
|
||||
@ -112,7 +113,7 @@ type clientConn struct {
|
||||
|
||||
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
|
||||
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */)
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */, c.logger)
|
||||
return &clientConn{conn, handler}
|
||||
}
|
||||
|
||||
@ -159,26 +160,26 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
|
||||
// For websocket connections, the origin is set to the local host name.
|
||||
//
|
||||
// The client reconnects automatically if the connection is lost.
|
||||
func Dial(rawurl string) (*Client, error) {
|
||||
return DialContext(context.Background(), rawurl)
|
||||
func Dial(rawurl string, logger log.Logger) (*Client, error) {
|
||||
return DialContext(context.Background(), rawurl, logger)
|
||||
}
|
||||
|
||||
// DialContext creates a new RPC client, just like Dial.
|
||||
//
|
||||
// The context is used to cancel or time out the initial connection establishment. It does
|
||||
// not affect subsequent interactions with the client.
|
||||
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
|
||||
func DialContext(ctx context.Context, rawurl string, logger log.Logger) (*Client, error) {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
return DialHTTP(rawurl)
|
||||
return DialHTTP(rawurl, logger)
|
||||
case "ws", "wss":
|
||||
return DialWebsocket(ctx, rawurl, "")
|
||||
return DialWebsocket(ctx, rawurl, "", logger)
|
||||
case "stdio":
|
||||
return DialStdIO(ctx)
|
||||
return DialStdIO(ctx, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
|
||||
}
|
||||
@ -186,22 +187,23 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) {
|
||||
|
||||
// Client retrieves the client from the context, if any. This can be used to perform
|
||||
// 'reverse calls' in a handler method.
|
||||
func ClientFromContext(ctx context.Context) (*Client, bool) {
|
||||
func ClientFromContext(ctx context.Context, logger log.Logger) (*Client, bool) {
|
||||
client, ok := ctx.Value(clientContextKey{}).(*Client)
|
||||
client.logger = logger
|
||||
return client, ok
|
||||
}
|
||||
|
||||
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
|
||||
func newClient(initctx context.Context, connect reconnectFunc, logger log.Logger) (*Client, error) {
|
||||
conn, err := connect(initctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
|
||||
c := initClient(conn, randomIDGenerator(), new(serviceRegistry), logger)
|
||||
c.reconnectFunc = connect
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
|
||||
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry, logger log.Logger) *Client {
|
||||
_, isHTTP := conn.(*httpConn)
|
||||
c := &Client{
|
||||
idgen: idgen,
|
||||
@ -217,6 +219,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C
|
||||
reqInit: make(chan *requestOp),
|
||||
reqSent: make(chan error, 1),
|
||||
reqTimeout: make(chan *requestOp),
|
||||
logger: logger,
|
||||
}
|
||||
if !isHTTP {
|
||||
go c.dispatch(conn)
|
||||
@ -515,7 +518,7 @@ func (c *Client) reconnect(ctx context.Context) error {
|
||||
}
|
||||
newconn, err := c.reconnectFunc(ctx)
|
||||
if err != nil {
|
||||
log.Trace("RPC client reconnect failed", "err", err)
|
||||
c.logger.Trace("RPC client reconnect failed", "err", err)
|
||||
return err
|
||||
}
|
||||
select {
|
||||
@ -564,13 +567,13 @@ func (c *Client) dispatch(codec ServerCodec) {
|
||||
}
|
||||
|
||||
case err := <-c.readErr:
|
||||
conn.handler.log.Trace("RPC connection read error", "err", err)
|
||||
conn.handler.logger.Trace("RPC connection read error", "err", err)
|
||||
conn.close(err, lastOp)
|
||||
reading = false
|
||||
|
||||
// Reconnect:
|
||||
case newcodec := <-c.reconnected:
|
||||
log.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
|
||||
c.logger.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
|
||||
if reading {
|
||||
// Wait for the previous read loop to exit. This is a rare case which
|
||||
// happens if this loop isn't notified in time after the connection breaks.
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/ledgerwatch/erigon/common/hexutil"
|
||||
"github.com/ledgerwatch/erigon/rpc"
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
// In this example, our client wishes to track the latest 'block number'
|
||||
@ -40,7 +41,8 @@ type Block struct {
|
||||
|
||||
func ExampleClientSubscription() {
|
||||
// Connect the client.
|
||||
client, _ := rpc.Dial("ws://127.0.0.1:8545")
|
||||
logger := log.New()
|
||||
client, _ := rpc.Dial("ws://127.0.0.1:8545", logger)
|
||||
subch := make(chan Block)
|
||||
|
||||
// Ensure that subch receives the latest block.
|
||||
|
@ -35,9 +35,10 @@ import (
|
||||
)
|
||||
|
||||
func TestClientRequest(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
var resp echoResult
|
||||
@ -50,9 +51,10 @@ func TestClientRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientResponseType(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
|
||||
@ -68,9 +70,10 @@ func TestClientResponseType(t *testing.T) {
|
||||
|
||||
// This test checks that server-returned errors with code and data come out of Client.Call.
|
||||
func TestClientErrorData(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
var resp interface{}
|
||||
@ -94,9 +97,10 @@ func TestClientErrorData(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientBatchRequest(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
batch := []BatchElem{
|
||||
@ -143,9 +147,10 @@ func TestClientBatchRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNotify(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
|
||||
@ -154,18 +159,18 @@ func TestClientNotify(t *testing.T) {
|
||||
}
|
||||
|
||||
// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) }
|
||||
func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) }
|
||||
func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) }
|
||||
func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t, log.New()) }
|
||||
func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t, log.New()) }
|
||||
|
||||
// This test checks that requests made through CallContext can be canceled by canceling
|
||||
// the context.
|
||||
func testClientCancel(transport string, t *testing.T) {
|
||||
func testClientCancel(transport string, t *testing.T, logger log.Logger) {
|
||||
// These tests take a lot of time, run them all at once.
|
||||
// You probably want to run with -parallel 1 or comment out
|
||||
// the call to t.Parallel if you enable the logging.
|
||||
t.Parallel()
|
||||
|
||||
server := newTestServer()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
|
||||
// What we want to achieve is that the context gets canceled
|
||||
@ -243,9 +248,10 @@ func testClientCancel(transport string, t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientSubscribeInvalidArg(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
check := func(shouldPanic bool, arg interface{}) {
|
||||
@ -271,9 +277,10 @@ func TestClientSubscribeInvalidArg(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientSubscribe(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
nc := make(chan int)
|
||||
@ -303,7 +310,8 @@ func TestClientSubscribe(t *testing.T) {
|
||||
|
||||
// In this test, the connection drops while Subscribe is waiting for a response.
|
||||
func TestClientSubscribeClose(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
service := ¬ificationTestService{
|
||||
gotHangSubscriptionReq: make(chan struct{}),
|
||||
unblockHangSubscription: make(chan struct{}),
|
||||
@ -313,7 +321,7 @@ func TestClientSubscribeClose(t *testing.T) {
|
||||
}
|
||||
|
||||
defer server.Stop()
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
|
||||
var (
|
||||
@ -347,11 +355,12 @@ func TestClientSubscribeClose(t *testing.T) {
|
||||
// This test reproduces https://github.com/ledgerwatch/erigon/issues/17837 where the
|
||||
// client hangs during shutdown when Unsubscribe races with Client.Close.
|
||||
func TestClientCloseUnsubscribeRace(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
nc := make(chan int)
|
||||
sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1)
|
||||
if err != nil {
|
||||
@ -370,11 +379,12 @@ func TestClientCloseUnsubscribeRace(t *testing.T) {
|
||||
// This test checks that Client doesn't lock up when a single subscriber
|
||||
// doesn't read subscription events.
|
||||
func TestClientNotificationStorm(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
|
||||
doTest := func(count int, wantError bool) {
|
||||
client := DialInProc(server)
|
||||
client := DialInProc(server, logger)
|
||||
defer client.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@ -422,8 +432,9 @@ func TestClientNotificationStorm(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientSetHeader(t *testing.T) {
|
||||
logger := log.New()
|
||||
var gotHeader bool
|
||||
srv := newTestServer()
|
||||
srv := newTestServer(logger)
|
||||
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("test") == "ok" {
|
||||
gotHeader = true
|
||||
@ -433,7 +444,7 @@ func TestClientSetHeader(t *testing.T) {
|
||||
defer httpsrv.Close()
|
||||
defer srv.Stop()
|
||||
|
||||
client, err := Dial(httpsrv.URL)
|
||||
client, err := Dial(httpsrv.URL, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -458,7 +469,8 @@ func TestClientSetHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientHTTP(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
|
||||
client, hs := httpTestClient(server, "http", nil)
|
||||
@ -504,7 +516,7 @@ func TestClientHTTP(t *testing.T) {
|
||||
func TestClientReconnect(t *testing.T) {
|
||||
logger := log.New()
|
||||
startServer := func(addr string) (*Server, net.Listener) {
|
||||
srv := newTestServer()
|
||||
srv := newTestServer(logger)
|
||||
l, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatal("can't listen:", err)
|
||||
@ -518,7 +530,7 @@ func TestClientReconnect(t *testing.T) {
|
||||
|
||||
// Start a server and corresponding client.
|
||||
s1, l1 := startServer("127.0.0.1:0")
|
||||
client, err := DialContext(ctx, "ws://"+l1.Addr().String())
|
||||
client, err := DialContext(ctx, "ws://"+l1.Addr().String(), logger)
|
||||
if err != nil {
|
||||
t.Fatal("can't dial", err)
|
||||
}
|
||||
@ -588,7 +600,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client,
|
||||
}
|
||||
// Connect the client.
|
||||
hs.Start()
|
||||
client, err := Dial(transport + "://" + hs.Listener.Addr().String())
|
||||
client, err := Dial(transport+"://"+hs.Listener.Addr().String(), logger)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ type handler struct {
|
||||
rootCtx context.Context // canceled by close()
|
||||
cancelRoot func() // cancel function for rootCtx
|
||||
conn jsonWriter // where responses will be sent
|
||||
log log.Logger
|
||||
logger log.Logger
|
||||
allowSubscribe bool
|
||||
|
||||
allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed
|
||||
@ -110,7 +110,7 @@ func HandleError(err error, stream *jsoniter.Stream) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool) *handler {
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool, logger log.Logger) *handler {
|
||||
rootCtx, cancelRoot := context.WithCancel(connCtx)
|
||||
forbiddenList := newForbiddenList()
|
||||
h := &handler{
|
||||
@ -123,7 +123,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
cancelRoot: cancelRoot,
|
||||
allowSubscribe: true,
|
||||
serverSubs: make(map[ID]*Subscription),
|
||||
log: log.Root(),
|
||||
logger: logger,
|
||||
allowList: allowList,
|
||||
forbiddenList: forbiddenList,
|
||||
|
||||
@ -132,7 +132,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
}
|
||||
|
||||
if conn.remoteAddr() != "" {
|
||||
h.log = h.log.New("conn", conn.remoteAddr())
|
||||
h.logger = h.logger.New("conn", conn.remoteAddr())
|
||||
}
|
||||
h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe), "unsubscribe")
|
||||
return h
|
||||
@ -330,7 +330,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
|
||||
return false
|
||||
case msg.isResponse():
|
||||
h.handleResponse(msg)
|
||||
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
|
||||
h.logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@ -341,7 +341,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
|
||||
func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
|
||||
var result subscriptionResult
|
||||
if err := json.Unmarshal(msg.Params, &result); err != nil {
|
||||
h.log.Trace("Dropping invalid subscription message")
|
||||
h.logger.Trace("Dropping invalid subscription message")
|
||||
return
|
||||
}
|
||||
if h.clientSubs[result.ID] != nil {
|
||||
@ -353,7 +353,7 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
|
||||
func (h *handler) handleResponse(msg *jsonrpcMessage) {
|
||||
op := h.respWait[string(msg.ID)]
|
||||
if op == nil {
|
||||
h.log.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
||||
h.logger.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
||||
return
|
||||
}
|
||||
delete(h.respWait, string(msg.ID))
|
||||
@ -383,26 +383,26 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage, stream *json
|
||||
case msg.isNotification():
|
||||
h.handleCall(ctx, msg, stream)
|
||||
if h.traceRequests {
|
||||
h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
|
||||
h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
|
||||
} else {
|
||||
h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
|
||||
h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
|
||||
}
|
||||
return nil
|
||||
case msg.isCall():
|
||||
resp := h.handleCall(ctx, msg, stream)
|
||||
if resp != nil && resp.Error != nil {
|
||||
if resp.Error.Data != nil {
|
||||
h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
|
||||
h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
|
||||
"err", resp.Error.Message, "errdata", resp.Error.Data)
|
||||
} else {
|
||||
h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
|
||||
h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
|
||||
"err", resp.Error.Message)
|
||||
}
|
||||
}
|
||||
if h.traceRequests {
|
||||
h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
|
||||
h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
|
||||
} else {
|
||||
h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
|
||||
h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
|
||||
}
|
||||
return resp
|
||||
case msg.hasValidID():
|
||||
|
@ -33,6 +33,7 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -77,7 +78,7 @@ func (hc *httpConn) closed() <-chan interface{} {
|
||||
|
||||
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
|
||||
// using the provided HTTP Client.
|
||||
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
||||
func DialHTTPWithClient(endpoint string, client *http.Client, logger log.Logger) (*Client, error) {
|
||||
// Sanity check URL so we don't end up with a client that will fail every request.
|
||||
_, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
@ -96,12 +97,12 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
||||
closeCh: make(chan interface{}),
|
||||
}
|
||||
return hc, nil
|
||||
})
|
||||
}, logger)
|
||||
}
|
||||
|
||||
// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
|
||||
func DialHTTP(endpoint string) (*Client, error) {
|
||||
return DialHTTPWithClient(endpoint, new(http.Client))
|
||||
func DialHTTP(endpoint string, logger log.Logger) (*Client, error) {
|
||||
return DialHTTPWithClient(endpoint, new(http.Client), logger)
|
||||
}
|
||||
|
||||
func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {
|
||||
|
@ -21,6 +21,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
func confirmStatusCode(t *testing.T, got, want int) {
|
||||
@ -102,9 +104,10 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) {
|
||||
|
||||
// This checks that maxRequestContentLength is not applied to the response of a request.
|
||||
func TestHTTPRespBodyUnlimited(t *testing.T) {
|
||||
logger := log.New()
|
||||
const respLength = maxRequestContentLength * 3
|
||||
|
||||
s := NewServer(50, false /* traceRequests */, true)
|
||||
s := NewServer(50, false /* traceRequests */, true, logger)
|
||||
defer s.Stop()
|
||||
if err := s.RegisterName("test", largeRespService{respLength}); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -112,7 +115,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) {
|
||||
ts := httptest.NewServer(s)
|
||||
defer ts.Close()
|
||||
|
||||
c, err := DialHTTP(ts.URL)
|
||||
c, err := DialHTTP(ts.URL, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -19,15 +19,17 @@ package rpc
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
// DialInProc attaches an in-process connection to the given RPC server.
|
||||
func DialInProc(handler *Server) *Client {
|
||||
func DialInProc(handler *Server, logger log.Logger) *Client {
|
||||
initctx := context.Background()
|
||||
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
|
||||
p1, p2 := net.Pipe()
|
||||
go handler.ServeCodec(NewCodec(p1), 0)
|
||||
return NewCodec(p2), nil
|
||||
})
|
||||
}, logger)
|
||||
return c
|
||||
}
|
||||
|
@ -54,11 +54,13 @@ type Server struct {
|
||||
disableStreaming bool
|
||||
traceRequests bool // Whether to print requests at INFO level
|
||||
batchLimit int // Maximum number of requests in a batch
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewServer creates a new server instance with no registered handlers.
|
||||
func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool) *Server {
|
||||
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency, disableStreaming: disableStreaming, traceRequests: traceRequests}
|
||||
func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool, logger log.Logger) *Server {
|
||||
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency,
|
||||
disableStreaming: disableStreaming, traceRequests: traceRequests, logger: logger}
|
||||
// Register the default service providing meta information about the RPC service such
|
||||
// as the services and methods it offers.
|
||||
rpcService := &RPCService{server: server}
|
||||
@ -101,7 +103,7 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
|
||||
s.codecs.Add(codec)
|
||||
defer s.codecs.Remove(codec)
|
||||
|
||||
c := initClient(codec, s.idgen, &s.services)
|
||||
c := initClient(codec, s.idgen, &s.services, s.logger)
|
||||
<-codec.closed()
|
||||
c.Close()
|
||||
}
|
||||
@ -115,7 +117,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre
|
||||
return
|
||||
}
|
||||
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests)
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests, s.logger)
|
||||
h.allowSubscribe = false
|
||||
defer h.close(io.EOF, nil)
|
||||
|
||||
@ -142,7 +144,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre
|
||||
// subscriptions.
|
||||
func (s *Server) Stop() {
|
||||
if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
|
||||
log.Info("RPC server shutting down")
|
||||
s.logger.Info("RPC server shutting down")
|
||||
s.codecs.Each(func(c interface{}) bool {
|
||||
c.(ServerCodec).close()
|
||||
return true
|
||||
|
@ -28,10 +28,13 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
func TestServerRegisterName(t *testing.T) {
|
||||
server := NewServer(50, false /* traceRequests */, true)
|
||||
logger := log.New()
|
||||
server := NewServer(50, false /* traceRequests */, true, logger)
|
||||
service := new(testService)
|
||||
|
||||
if err := server.RegisterName("test", service); err != nil {
|
||||
@ -54,6 +57,7 @@ func TestServerRegisterName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
logger := log.New()
|
||||
files, err := os.ReadDir("testdata")
|
||||
if err != nil {
|
||||
t.Fatal("where'd my testdata go?")
|
||||
@ -65,13 +69,13 @@ func TestServer(t *testing.T) {
|
||||
path := filepath.Join("testdata", f.Name())
|
||||
name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name()))
|
||||
t.Run(name, func(t *testing.T) {
|
||||
runTestScript(t, path)
|
||||
runTestScript(t, path, logger)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runTestScript(t *testing.T, file string) {
|
||||
server := newTestServer()
|
||||
func runTestScript(t *testing.T, file string, logger log.Logger) {
|
||||
server := newTestServer(logger)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -136,7 +140,8 @@ func runTestScript(t *testing.T, file string) {
|
||||
// This test checks that responses are delivered for very short-lived connections that
|
||||
// only carry a single request.
|
||||
func TestServerShortLivedConn(t *testing.T) {
|
||||
server := newTestServer()
|
||||
logger := log.New()
|
||||
server := newTestServer(logger)
|
||||
defer server.Stop()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
10
rpc/stdio.go
10
rpc/stdio.go
@ -23,21 +23,23 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
// DialStdIO creates a client on stdin/stdout.
|
||||
func DialStdIO(ctx context.Context) (*Client, error) {
|
||||
return DialIO(ctx, os.Stdin, os.Stdout)
|
||||
func DialStdIO(ctx context.Context, logger log.Logger) (*Client, error) {
|
||||
return DialIO(ctx, os.Stdin, os.Stdout, logger)
|
||||
}
|
||||
|
||||
// DialIO creates a client which uses the given IO channels
|
||||
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
|
||||
func DialIO(ctx context.Context, in io.Reader, out io.Writer, logger log.Logger) (*Client, error) {
|
||||
return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
|
||||
return NewCodec(stdioConn{
|
||||
in: in,
|
||||
out: out,
|
||||
}), nil
|
||||
})
|
||||
}, logger)
|
||||
}
|
||||
|
||||
type stdioConn struct {
|
||||
|
@ -23,6 +23,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
func TestNewID(t *testing.T) {
|
||||
@ -47,13 +49,14 @@ func TestNewID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSubscriptions(t *testing.T) {
|
||||
logger := log.New()
|
||||
var (
|
||||
namespaces = []string{"eth", "bzz"}
|
||||
service = ¬ificationTestService{}
|
||||
subCount = len(namespaces)
|
||||
notificationCount = 3
|
||||
|
||||
server = NewServer(50, false /* traceRequests */, true)
|
||||
server = NewServer(50, false /* traceRequests */, true, logger)
|
||||
clientConn, serverConn = net.Pipe()
|
||||
out = json.NewEncoder(clientConn)
|
||||
in = json.NewDecoder(clientConn)
|
||||
@ -125,11 +128,12 @@ func TestSubscriptions(t *testing.T) {
|
||||
|
||||
// This test checks that unsubscribing works.
|
||||
func TestServerUnsubscribe(t *testing.T) {
|
||||
logger := log.New()
|
||||
p1, p2 := net.Pipe()
|
||||
defer p2.Close()
|
||||
|
||||
// Start the server.
|
||||
server := newTestServer()
|
||||
server := newTestServer(logger)
|
||||
service := ¬ificationTestService{unsubscribed: make(chan string, 1)}
|
||||
server.RegisterName("nftest2", service)
|
||||
go server.ServeCodec(NewCodec(p1), 0)
|
||||
|
@ -23,10 +23,12 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ledgerwatch/log/v3"
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
server := NewServer(50, false /* traceRequests */, true)
|
||||
func newTestServer(logger log.Logger) *Server {
|
||||
server := NewServer(50, false /* traceRequests */, true, logger)
|
||||
server.idgen = sequentialIDGenerator()
|
||||
if err := server.RegisterName("test", new(testService)); err != nil {
|
||||
panic(err)
|
||||
@ -111,7 +113,7 @@ func (s *testService) ReturnError() error {
|
||||
}
|
||||
|
||||
func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
c, ok := ClientFromContext(ctx, log.New())
|
||||
if !ok {
|
||||
return nil, errors.New("no client")
|
||||
}
|
||||
@ -121,7 +123,7 @@ func (s *testService) CallMeBack(ctx context.Context, method string, args []inte
|
||||
}
|
||||
|
||||
func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
c, ok := ClientFromContext(ctx, log.New())
|
||||
if !ok {
|
||||
return errors.New("no client")
|
||||
}
|
||||
|
@ -186,7 +186,7 @@ func parseOriginURL(origin string) (string, string, string, error) {
|
||||
|
||||
// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
|
||||
// that is listening on the given endpoint using the provided dialer.
|
||||
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
|
||||
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer, logger log.Logger) (*Client, error) {
|
||||
endpoint, header, err := wsClientHeaders(endpoint, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -202,7 +202,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
|
||||
return nil, hErr
|
||||
}
|
||||
return newWebsocketCodec(conn), nil
|
||||
})
|
||||
}, logger)
|
||||
}
|
||||
|
||||
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
|
||||
@ -210,13 +210,13 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
|
||||
//
|
||||
// The context is used for the initial connection establishment. It does not
|
||||
// affect subsequent interactions with the client.
|
||||
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
|
||||
func DialWebsocket(ctx context.Context, endpoint, origin string, logger log.Logger) (*Client, error) {
|
||||
dialer := websocket.Dialer{
|
||||
ReadBufferSize: wsReadBuffer,
|
||||
WriteBufferSize: wsWriteBuffer,
|
||||
WriteBufferPool: wsBufferPool,
|
||||
}
|
||||
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
|
||||
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer, logger)
|
||||
}
|
||||
|
||||
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
|
||||
|
@ -54,14 +54,14 @@ func TestWebsocketOriginCheck(t *testing.T) {
|
||||
|
||||
logger := log.New()
|
||||
var (
|
||||
srv = newTestServer()
|
||||
srv = newTestServer(logger)
|
||||
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, nil, false, logger))
|
||||
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
|
||||
)
|
||||
defer srv.Stop()
|
||||
defer httpsrv.Close()
|
||||
|
||||
client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com")
|
||||
client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com", logger)
|
||||
if err == nil {
|
||||
client.Close()
|
||||
t.Fatal("no error for wrong origin")
|
||||
@ -72,7 +72,7 @@ func TestWebsocketOriginCheck(t *testing.T) {
|
||||
}
|
||||
|
||||
// Connections without origin header should work.
|
||||
client, err = DialWebsocket(context.Background(), wsURL, "")
|
||||
client, err = DialWebsocket(context.Background(), wsURL, "", logger)
|
||||
if err != nil {
|
||||
t.Fatal("error for empty origin")
|
||||
}
|
||||
@ -85,14 +85,14 @@ func TestWebsocketLargeCall(t *testing.T) {
|
||||
logger := log.New()
|
||||
|
||||
var (
|
||||
srv = newTestServer()
|
||||
srv = newTestServer(logger)
|
||||
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, nil, false, logger))
|
||||
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
|
||||
)
|
||||
defer srv.Stop()
|
||||
defer httpsrv.Close()
|
||||
|
||||
client, clientErr := DialWebsocket(context.Background(), wsURL, "")
|
||||
client, clientErr := DialWebsocket(context.Background(), wsURL, "", logger)
|
||||
if clientErr != nil {
|
||||
t.Fatalf("can't dial: %v", clientErr)
|
||||
}
|
||||
@ -122,6 +122,7 @@ func TestClientWebsocketPing(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
logger := log.New()
|
||||
|
||||
var (
|
||||
sendPing = make(chan struct{})
|
||||
@ -131,7 +132,7 @@ func TestClientWebsocketPing(t *testing.T) {
|
||||
defer cancel()
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
client, err := DialContext(ctx, "ws://"+server.Addr)
|
||||
client, err := DialContext(ctx, "ws://"+server.Addr, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("client dial error: %v", err)
|
||||
}
|
||||
@ -167,7 +168,7 @@ func TestClientWebsocketPing(t *testing.T) {
|
||||
func TestClientWebsocketLargeMessage(t *testing.T) {
|
||||
logger := log.New()
|
||||
var (
|
||||
srv = NewServer(50, false /* traceRequests */, true)
|
||||
srv = NewServer(50, false /* traceRequests */, true, logger)
|
||||
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, nil, false, logger))
|
||||
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
|
||||
)
|
||||
@ -179,7 +180,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, err := DialWebsocket(context.Background(), wsURL, "")
|
||||
c, err := DialWebsocket(context.Background(), wsURL, "", logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user