[devnet tool] Separate logging (#7553)

Co-authored-by: Alex Sharp <alexsharp@Alexs-MacBook-Pro-2.local>
This commit is contained in:
ledgerwatch 2023-05-20 14:48:16 +01:00 committed by GitHub
parent 2a872b4d54
commit 067f695fff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 194 additions and 147 deletions

View File

@ -65,7 +65,7 @@ func action(ctx *cli.Context) error {
logger := logging.SetupLoggerCtx("devnet", ctx, false /* rootLogger */)
// Make root logger fail
//log.Root().SetHandler(PanicHandler{})
log.Root().SetHandler(PanicHandler{})
// clear all the dev files
if err := devnetutils.ClearDevDB(dataDir, logger); err != nil {
@ -84,7 +84,7 @@ func action(ctx *cli.Context) error {
time.Sleep(time.Second * 10)
// start up the subscription services for the different sub methods
services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads})
services.InitSubscriptions([]models.SubMethod{models.ETHNewHeads}, logger)
// execute all rpc methods amongst the two nodes
commands.ExecuteAllMethods()

View File

@ -9,11 +9,12 @@ import (
"github.com/ledgerwatch/erigon/cmd/devnet/devnetutils"
"github.com/ledgerwatch/erigon/cmd/devnet/models"
"github.com/ledgerwatch/erigon/rpc"
"github.com/ledgerwatch/log/v3"
)
func InitSubscriptions(methods []models.SubMethod) {
func InitSubscriptions(methods []models.SubMethod, logger log.Logger) {
fmt.Printf("CONNECTING TO WEBSOCKETS AND SUBSCRIBING TO METHODS...\n")
if err := subscribeAll(methods); err != nil {
if err := subscribeAll(methods, logger); err != nil {
fmt.Printf("failed to subscribe to all methods: %v\n", err)
return
}
@ -65,8 +66,8 @@ func subscribe(client *rpc.Client, method models.SubMethod, args ...interface{})
return methodSub, nil
}
func subscribeToMethod(method models.SubMethod) (*models.MethodSubscription, error) {
client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "")
func subscribeToMethod(method models.SubMethod, logger log.Logger) (*models.MethodSubscription, error) {
client, err := rpc.DialWebsocket(context.Background(), fmt.Sprintf("ws://%s", models.Localhost), "", logger)
if err != nil {
return nil, fmt.Errorf("failed to dial websocket: %v", err)
}
@ -132,11 +133,11 @@ func UnsubscribeAll() {
}
// subscribeAll subscribes to the range of methods provided
func subscribeAll(methods []models.SubMethod) error {
func subscribeAll(methods []models.SubMethod, logger log.Logger) error {
m := make(map[models.SubMethod]*models.MethodSubscription)
models.MethodSubscriptionMap = &m
for _, method := range methods {
sub, err := subscribeToMethod(method)
sub, err := subscribeToMethod(method, logger)
if err != nil {
return err
}

View File

@ -510,7 +510,7 @@ func NewBackend(stack *node.Node, config *ethconfig.Config, logger log.Logger) (
// Initialize ethbackend
ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events,
backend.blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger)
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi)
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger)
var creds credentials.TransportCredentials
if stack.Config().PrivateApiAddr != "" {

View File

@ -45,6 +45,7 @@ import (
"text/tabwriter"
"github.com/ledgerwatch/erigon/cmd/utils"
"github.com/ledgerwatch/erigon/turbo/logging"
"github.com/urfave/cli/v2"
"github.com/ledgerwatch/erigon/crypto"
@ -69,7 +70,8 @@ func main() {
},
}
app.Before = func(ctx *cli.Context) error {
client = simulations.NewClient(ctx.String("api"))
logger := logging.SetupLoggerCtx("p2psim", ctx, false /* rootLogger */)
client = simulations.NewClient(ctx.String("api"), logger)
return nil
}
app.Commands = []*cli.Command{

View File

@ -499,7 +499,7 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp
httpEndpoint := fmt.Sprintf("%s:%d", cfg.HttpListenAddress, cfg.HttpPort)
logger.Trace("TraceRequests = %t\n", cfg.TraceRequests)
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable)
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger)
allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath)
if err != nil {
@ -619,7 +619,7 @@ type engineInfo struct {
func startAuthenticatedRpcServer(cfg httpcfg.HttpCfg, rpcAPI []rpc.API, logger log.Logger) (*engineInfo, error) {
logger.Trace("TraceRequests = %t\n", cfg.TraceRequests)
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable)
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable, logger)
engineListener, engineSrv, engineHttpEndpoint, err := createEngineListener(cfg, rpcAPI, logger)
if err != nil {
@ -709,7 +709,7 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand
func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API, logger log.Logger) (*http.Server, *rpc.Server, string, error) {
engineHttpEndpoint := fmt.Sprintf("%s:%d", cfg.AuthRpcHTTPListenAddress, cfg.AuthRpcPort)
engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true)
engineSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, true, logger)
if err := node.RegisterApisFromWhitelist(engineApi, nil, engineSrv, true, logger); err != nil {
return nil, nil, "", fmt.Errorf("could not start register RPC engine api: %w", err)

View File

@ -298,7 +298,7 @@ func CreateTestGrpcConn(t *testing.T, m *stages.MockSentry) (context.Context, *g
remote.RegisterETHBACKENDServer(server, privateapi.NewEthBackendServer(ctx, nil, m.DB, m.Notifications.Events,
snapshotsync.NewBlockReaderWithSnapshots(m.BlockSnapshots, m.TransactionsV3), nil, nil, nil, false, log.New()))
txpool.RegisterTxpoolServer(server, m.TxPoolGrpcServer)
txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi))
txpool.RegisterMiningServer(server, privateapi.NewMiningServer(ctx, &IsMiningMock{}, ethashApi, m.Log))
listener := bufconn.Listen(1024 * 1024)
dialer := func() func(context.Context, string) (net.Conn, error) {

View File

@ -172,7 +172,7 @@ func doTxpool(ctx context.Context, logger log.Logger) error {
ethashApi = casted.APIs(nil)[1].Service.(*ethash.API)
}
*/
miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil)
miningGrpcServer := privateapi.NewMiningServer(ctx, &rpcdaemontest.IsMiningMock{}, nil, logger)
grpcServer, err := txpool.StartGrpc(txpoolGrpcServer, miningGrpcServer, txpoolApiAddr, nil, logger)
if err != nil {

View File

@ -546,7 +546,7 @@ func New(stack *node.Node, config *ethconfig.Config, logger log.Logger) (*Ethere
// Initialize ethbackend
ethBackendRPC := privateapi.NewEthBackendServer(ctx, backend, backend.chainDB, backend.notifications.Events,
blockReader, chainConfig, assembleBlockPOS, backend.sentriesClient.Hd, config.Miner.EnabledPOS, logger)
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi)
miningRPC = privateapi.NewMiningServer(ctx, backend, ethashApi, logger)
var creds credentials.TransportCredentials
if stack.Config().PrivateApiAddr != "" {

View File

@ -47,7 +47,7 @@ func DefaultStages(ctx context.Context, snapshots SnapshotsCfg, headers HeadersC
ID: stages.CumulativeIndex,
Description: "Write Cumulative Index",
Forward: func(firstCycle bool, badBlockUnwind bool, s *StageState, u Unwinder, tx kv.RwTx, logger log.Logger) error {
return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx)
return SpawnStageCumulativeIndex(cumulativeIndex, s, tx, ctx, logger)
},
Unwind: func(firstCycle bool, u *UnwindState, s *StageState, tx kv.RwTx, logger log.Logger) error {
return UnwindCumulativeIndexStage(u, cumulativeIndex, tx, ctx)

View File

@ -29,7 +29,7 @@ func StageCumulativeIndexCfg(db kv.RwDB) CumulativeIndexCfg {
}
}
func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context) error {
func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx, ctx context.Context, logger log.Logger) error {
useExternalTx := tx != nil
if !useExternalTx {
@ -105,11 +105,11 @@ func SpawnStageCumulativeIndex(cfg CumulativeIndexCfg, s *StageState, tx kv.RwTx
case <-ctx.Done():
return ctx.Err()
case <-logEvery.C:
log.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()),
logger.Info(fmt.Sprintf("[%s] Wrote Cumulative Index", s.LogPrefix()),
"gasUsed", cumulativeGasUsed.String(), "now", currentBlockNumber, "blk/sec", float64(currentBlockNumber-prevProgress)/float64(logInterval/time.Second))
prevProgress = currentBlockNumber
default:
log.Trace("RequestQueueTime (header) ticked")
logger.Trace("RequestQueueTime (header) ticked")
}
// Cleanup timer
}

View File

@ -30,14 +30,15 @@ type MiningServer struct {
minedBlockStreams MinedBlockStreams
ethash *ethash.API
isMining IsMining
logger log.Logger
}
type IsMining interface {
IsMining() bool
}
func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API) *MiningServer {
return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi}
func NewMiningServer(ctx context.Context, isMining IsMining, ethashApi *ethash.API, logger log.Logger) *MiningServer {
return &MiningServer{ctx: ctx, isMining: isMining, ethash: ethashApi, logger: logger}
}
func (s *MiningServer) Version(context.Context, *emptypb.Empty) (*types2.VersionReply, error) {
@ -100,7 +101,7 @@ func (s *MiningServer) BroadcastPendingLogs(l types.Logs) error {
return err
}
reply := &proto_txpool.OnPendingBlockReply{RplBlock: b}
s.pendingBlockStreams.Broadcast(reply)
s.pendingBlockStreams.Broadcast(reply, s.logger)
return nil
}
@ -121,7 +122,7 @@ func (s *MiningServer) BroadcastPendingBlock(block *types.Block) error {
return err
}
reply := &proto_txpool.OnPendingBlockReply{RplBlock: buf.Bytes()}
s.pendingBlockStreams.Broadcast(reply)
s.pendingBlockStreams.Broadcast(reply, s.logger)
return nil
}
@ -133,21 +134,22 @@ func (s *MiningServer) OnMinedBlock(req *proto_txpool.OnMinedBlockRequest, reply
}
func (s *MiningServer) BroadcastMinedBlock(block *types.Block) error {
log.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed())
s.logger.Debug("BroadcastMinedBlock", "block hash", block.Hash(), "block number", block.Number(), "root", block.Root(), "gas", block.GasUsed())
var buf bytes.Buffer
if err := block.EncodeRLP(&buf); err != nil {
return err
}
reply := &proto_txpool.OnMinedBlockReply{RplBlock: buf.Bytes()}
s.minedBlockStreams.Broadcast(reply)
s.minedBlockStreams.Broadcast(reply, s.logger)
return nil
}
// MinedBlockStreams - it's safe to use this class as non-pointer
type MinedBlockStreams struct {
chans map[uint]proto_txpool.Mining_OnMinedBlockServer
id uint
mu sync.Mutex
chans map[uint]proto_txpool.Mining_OnMinedBlockServer
id uint
mu sync.Mutex
logger log.Logger
}
func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) (remove func()) {
@ -162,13 +164,13 @@ func (s *MinedBlockStreams) Add(stream proto_txpool.Mining_OnMinedBlockServer) (
return func() { s.remove(id) }
}
func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply) {
func (s *MinedBlockStreams) Broadcast(reply *proto_txpool.OnMinedBlockReply, logger log.Logger) {
s.mu.Lock()
defer s.mu.Unlock()
for id, stream := range s.chans {
err := stream.Send(reply)
if err != nil {
log.Trace("failed send to mined block stream", "err", err)
logger.Trace("failed send to mined block stream", "err", err)
select {
case <-stream.Context().Done():
delete(s.chans, id)
@ -207,13 +209,13 @@ func (s *PendingBlockStreams) Add(stream proto_txpool.Mining_OnPendingBlockServe
return func() { s.remove(id) }
}
func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply) {
func (s *PendingBlockStreams) Broadcast(reply *proto_txpool.OnPendingBlockReply, logger log.Logger) {
s.mu.Lock()
defer s.mu.Unlock()
for id, stream := range s.chans {
err := stream.Send(reply)
if err != nil {
log.Trace("failed send to mined block stream", "err", err)
logger.Trace("failed send to mined block stream", "err", err)
select {
case <-stream.Context().Done():
delete(s.chans, id)
@ -252,13 +254,13 @@ func (s *PendingLogsStreams) Add(stream proto_txpool.Mining_OnPendingLogsServer)
return func() { s.remove(id) }
}
func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply) {
func (s *PendingLogsStreams) Broadcast(reply *proto_txpool.OnPendingLogsReply, logger log.Logger) {
s.mu.Lock()
defer s.mu.Unlock()
for id, stream := range s.chans {
err := stream.Send(reply)
if err != nil {
log.Trace("failed send to mined block stream", "err", err)
logger.Trace("failed send to mined block stream", "err", err)
select {
case <-stream.Context().Done():
delete(s.chans, id)

View File

@ -56,7 +56,7 @@ type rpcHandler struct {
}
type httpServer struct {
log log.Logger
logger log.Logger
timeouts rpccfg.HTTPTimeouts
mux http.ServeMux // registered handlers go here
@ -82,7 +82,7 @@ type httpServer struct {
}
func newHTTPServer(logger log.Logger, timeouts rpccfg.HTTPTimeouts) *httpServer {
h := &httpServer{log: logger, timeouts: timeouts, handlerNames: make(map[string]string)}
h := &httpServer{logger: logger, timeouts: timeouts, handlerNames: make(map[string]string)}
h.httpHandler.Store((*rpcHandler)(nil))
h.wsHandler.Store((*rpcHandler)(nil))
@ -150,14 +150,14 @@ func (h *httpServer) start() error {
if h.wsConfig.prefix != "" {
url += h.wsConfig.prefix
}
h.log.Info("WebSocket enabled", "url", url)
h.logger.Info("WebSocket enabled", "url", url)
}
// if server is websocket only, return after logging
if !h.rpcAllowed() {
return nil
}
// Log http endpoint.
h.log.Info("HTTP server started",
h.logger.Info("HTTP server started",
"endpoint", listener.Addr(),
"prefix", h.httpConfig.prefix,
"cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","),
@ -176,7 +176,7 @@ func (h *httpServer) start() error {
for _, path := range paths {
name := h.handlerNames[path]
if !logged[name] {
h.log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
h.logger.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
logged[name] = true
}
}
@ -248,7 +248,7 @@ func (h *httpServer) doStop() {
}
h.server.Shutdown(context.Background()) //nolint:errcheck
h.listener.Close()
h.log.Info("HTTP server stopped", "endpoint", h.listener.Addr())
h.logger.Info("HTTP server stopped", "endpoint", h.listener.Addr())
// Clear out everything to allow re-configuring it later.
h.host, h.port, h.endpoint = "", 0, ""
@ -265,9 +265,9 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig, allowList rpc.
}
// Create RPC server and handler.
srv := rpc.NewServer(50, false /* traceRequests */, true)
srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger)
srv.SetAllowList(allowList)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil {
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil {
return err
}
h.httpConfig = config
@ -298,14 +298,14 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.All
}
// Create RPC server and handler.
srv := rpc.NewServer(50, false /* traceRequests */, true)
srv := rpc.NewServer(50, false /* traceRequests */, true, h.logger)
srv.SetAllowList(allowList)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.log); err != nil {
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false, h.logger); err != nil {
return err
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins, nil, false, h.log),
Handler: srv.WebsocketHandler(config.Origins, nil, false, h.logger),
server: srv,
})
return nil

View File

@ -36,11 +36,12 @@ import (
"github.com/gorilla/websocket"
"github.com/julienschmidt/httprouter"
"github.com/ledgerwatch/log/v3"
)
// DefaultClient is the default simulation API client which expects the API
// to be running at http://localhost:8888
var DefaultClient = NewClient("http://localhost:8888")
var DefaultClient = NewClient("http://localhost:8888", log.New())
// Client is a client for the simulation HTTP API which supports creating
// and managing simulation networks
@ -48,13 +49,15 @@ type Client struct {
URL string
client *http.Client
logger log.Logger
}
// NewClient returns a new simulation API client
func NewClient(url string) *Client {
func NewClient(url string, logger log.Logger) *Client {
return &Client{
URL: url,
client: http.DefaultClient,
logger: logger,
}
}
@ -208,7 +211,7 @@ func (c *Client) DisconnectNode(nodeID, peerID string) error {
// RPCClient returns an RPC client connected to a node
func (c *Client) RPCClient(ctx context.Context, nodeID string) (*rpc.Client, error) {
baseURL := strings.Replace(c.URL, "http", "ws", 1)
return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "")
return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "", c.logger)
}
// Get performs a HTTP GET request decoding the resulting JSON response

View File

@ -99,6 +99,7 @@ type Client struct {
reqInit chan *requestOp // register response IDs, takes write lock
reqSent chan error // signals write completion, releases write lock
reqTimeout chan *requestOp // removes response IDs when call timeout expires
logger log.Logger
}
type reconnectFunc func(ctx context.Context) (ServerCodec, error)
@ -112,7 +113,7 @@ type clientConn struct {
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */)
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */, c.logger)
return &clientConn{conn, handler}
}
@ -159,26 +160,26 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
// For websocket connections, the origin is set to the local host name.
//
// The client reconnects automatically if the connection is lost.
func Dial(rawurl string) (*Client, error) {
return DialContext(context.Background(), rawurl)
func Dial(rawurl string, logger log.Logger) (*Client, error) {
return DialContext(context.Background(), rawurl, logger)
}
// DialContext creates a new RPC client, just like Dial.
//
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
func DialContext(ctx context.Context, rawurl string, logger log.Logger) (*Client, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
switch u.Scheme {
case "http", "https":
return DialHTTP(rawurl)
return DialHTTP(rawurl, logger)
case "ws", "wss":
return DialWebsocket(ctx, rawurl, "")
return DialWebsocket(ctx, rawurl, "", logger)
case "stdio":
return DialStdIO(ctx)
return DialStdIO(ctx, logger)
default:
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
}
@ -186,22 +187,23 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) {
// Client retrieves the client from the context, if any. This can be used to perform
// 'reverse calls' in a handler method.
func ClientFromContext(ctx context.Context) (*Client, bool) {
func ClientFromContext(ctx context.Context, logger log.Logger) (*Client, bool) {
client, ok := ctx.Value(clientContextKey{}).(*Client)
client.logger = logger
return client, ok
}
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
func newClient(initctx context.Context, connect reconnectFunc, logger log.Logger) (*Client, error) {
conn, err := connect(initctx)
if err != nil {
return nil, err
}
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
c := initClient(conn, randomIDGenerator(), new(serviceRegistry), logger)
c.reconnectFunc = connect
return c, nil
}
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry, logger log.Logger) *Client {
_, isHTTP := conn.(*httpConn)
c := &Client{
idgen: idgen,
@ -217,6 +219,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C
reqInit: make(chan *requestOp),
reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
logger: logger,
}
if !isHTTP {
go c.dispatch(conn)
@ -515,7 +518,7 @@ func (c *Client) reconnect(ctx context.Context) error {
}
newconn, err := c.reconnectFunc(ctx)
if err != nil {
log.Trace("RPC client reconnect failed", "err", err)
c.logger.Trace("RPC client reconnect failed", "err", err)
return err
}
select {
@ -564,13 +567,13 @@ func (c *Client) dispatch(codec ServerCodec) {
}
case err := <-c.readErr:
conn.handler.log.Trace("RPC connection read error", "err", err)
conn.handler.logger.Trace("RPC connection read error", "err", err)
conn.close(err, lastOp)
reading = false
// Reconnect:
case newcodec := <-c.reconnected:
log.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
c.logger.Trace("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
if reading {
// Wait for the previous read loop to exit. This is a rare case which
// happens if this loop isn't notified in time after the connection breaks.

View File

@ -23,6 +23,7 @@ import (
"github.com/ledgerwatch/erigon/common/hexutil"
"github.com/ledgerwatch/erigon/rpc"
"github.com/ledgerwatch/log/v3"
)
// In this example, our client wishes to track the latest 'block number'
@ -40,7 +41,8 @@ type Block struct {
func ExampleClientSubscription() {
// Connect the client.
client, _ := rpc.Dial("ws://127.0.0.1:8545")
logger := log.New()
client, _ := rpc.Dial("ws://127.0.0.1:8545", logger)
subch := make(chan Block)
// Ensure that subch receives the latest block.

View File

@ -35,9 +35,10 @@ import (
)
func TestClientRequest(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
var resp echoResult
@ -50,9 +51,10 @@ func TestClientRequest(t *testing.T) {
}
func TestClientResponseType(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
@ -68,9 +70,10 @@ func TestClientResponseType(t *testing.T) {
// This test checks that server-returned errors with code and data come out of Client.Call.
func TestClientErrorData(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
var resp interface{}
@ -94,9 +97,10 @@ func TestClientErrorData(t *testing.T) {
}
func TestClientBatchRequest(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
batch := []BatchElem{
@ -143,9 +147,10 @@ func TestClientBatchRequest(t *testing.T) {
}
func TestClientNotify(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
@ -154,18 +159,18 @@ func TestClientNotify(t *testing.T) {
}
// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) }
func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) }
func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) }
func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t, log.New()) }
func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t, log.New()) }
// This test checks that requests made through CallContext can be canceled by canceling
// the context.
func testClientCancel(transport string, t *testing.T) {
func testClientCancel(transport string, t *testing.T, logger log.Logger) {
// These tests take a lot of time, run them all at once.
// You probably want to run with -parallel 1 or comment out
// the call to t.Parallel if you enable the logging.
t.Parallel()
server := newTestServer()
server := newTestServer(logger)
defer server.Stop()
// What we want to achieve is that the context gets canceled
@ -243,9 +248,10 @@ func testClientCancel(transport string, t *testing.T) {
}
func TestClientSubscribeInvalidArg(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
check := func(shouldPanic bool, arg interface{}) {
@ -271,9 +277,10 @@ func TestClientSubscribeInvalidArg(t *testing.T) {
}
func TestClientSubscribe(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
nc := make(chan int)
@ -303,7 +310,8 @@ func TestClientSubscribe(t *testing.T) {
// In this test, the connection drops while Subscribe is waiting for a response.
func TestClientSubscribeClose(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
service := &notificationTestService{
gotHangSubscriptionReq: make(chan struct{}),
unblockHangSubscription: make(chan struct{}),
@ -313,7 +321,7 @@ func TestClientSubscribeClose(t *testing.T) {
}
defer server.Stop()
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
var (
@ -347,11 +355,12 @@ func TestClientSubscribeClose(t *testing.T) {
// This test reproduces https://github.com/ledgerwatch/erigon/issues/17837 where the
// client hangs during shutdown when Unsubscribe races with Client.Close.
func TestClientCloseUnsubscribeRace(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
for i := 0; i < 20; i++ {
client := DialInProc(server)
client := DialInProc(server, logger)
nc := make(chan int)
sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1)
if err != nil {
@ -370,11 +379,12 @@ func TestClientCloseUnsubscribeRace(t *testing.T) {
// This test checks that Client doesn't lock up when a single subscriber
// doesn't read subscription events.
func TestClientNotificationStorm(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
doTest := func(count int, wantError bool) {
client := DialInProc(server)
client := DialInProc(server, logger)
defer client.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@ -422,8 +432,9 @@ func TestClientNotificationStorm(t *testing.T) {
}
func TestClientSetHeader(t *testing.T) {
logger := log.New()
var gotHeader bool
srv := newTestServer()
srv := newTestServer(logger)
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("test") == "ok" {
gotHeader = true
@ -433,7 +444,7 @@ func TestClientSetHeader(t *testing.T) {
defer httpsrv.Close()
defer srv.Stop()
client, err := Dial(httpsrv.URL)
client, err := Dial(httpsrv.URL, logger)
if err != nil {
t.Fatal(err)
}
@ -458,7 +469,8 @@ func TestClientSetHeader(t *testing.T) {
}
func TestClientHTTP(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
client, hs := httpTestClient(server, "http", nil)
@ -504,7 +516,7 @@ func TestClientHTTP(t *testing.T) {
func TestClientReconnect(t *testing.T) {
logger := log.New()
startServer := func(addr string) (*Server, net.Listener) {
srv := newTestServer()
srv := newTestServer(logger)
l, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal("can't listen:", err)
@ -518,7 +530,7 @@ func TestClientReconnect(t *testing.T) {
// Start a server and corresponding client.
s1, l1 := startServer("127.0.0.1:0")
client, err := DialContext(ctx, "ws://"+l1.Addr().String())
client, err := DialContext(ctx, "ws://"+l1.Addr().String(), logger)
if err != nil {
t.Fatal("can't dial", err)
}
@ -588,7 +600,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client,
}
// Connect the client.
hs.Start()
client, err := Dial(transport + "://" + hs.Listener.Addr().String())
client, err := Dial(transport+"://"+hs.Listener.Addr().String(), logger)
if err != nil {
panic(err)
}

View File

@ -61,7 +61,7 @@ type handler struct {
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
logger log.Logger
allowSubscribe bool
allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed
@ -110,7 +110,7 @@ func HandleError(err error, stream *jsoniter.Stream) error {
return nil
}
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool) *handler {
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool, logger log.Logger) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
forbiddenList := newForbiddenList()
h := &handler{
@ -123,7 +123,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
logger: logger,
allowList: allowList,
forbiddenList: forbiddenList,
@ -132,7 +132,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
}
if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr())
h.logger = h.logger.New("conn", conn.remoteAddr())
}
h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe), "unsubscribe")
return h
@ -330,7 +330,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
return false
case msg.isResponse():
h.handleResponse(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
h.logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
return true
default:
return false
@ -341,7 +341,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
var result subscriptionResult
if err := json.Unmarshal(msg.Params, &result); err != nil {
h.log.Trace("Dropping invalid subscription message")
h.logger.Trace("Dropping invalid subscription message")
return
}
if h.clientSubs[result.ID] != nil {
@ -353,7 +353,7 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
func (h *handler) handleResponse(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID})
h.logger.Trace("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
delete(h.respWait, string(msg.ID))
@ -383,26 +383,26 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage, stream *json
case msg.isNotification():
h.handleCall(ctx, msg, stream)
if h.traceRequests {
h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
} else {
h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "params", string(msg.Params))
}
return nil
case msg.isCall():
resp := h.handleCall(ctx, msg, stream)
if resp != nil && resp.Error != nil {
if resp.Error.Data != nil {
h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
"err", resp.Error.Message, "errdata", resp.Error.Data)
} else {
h.log.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
h.logger.Warn("Served", "method", msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start),
"err", resp.Error.Message)
}
}
if h.traceRequests {
h.log.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
h.logger.Info("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
} else {
h.log.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
h.logger.Trace("Served", "t", time.Since(start), "method", msg.Method, "reqid", idForLog{msg.ID}, "params", string(msg.Params))
}
return resp
case msg.hasValidID():

View File

@ -33,6 +33,7 @@ import (
"github.com/golang-jwt/jwt/v4"
jsoniter "github.com/json-iterator/go"
"github.com/ledgerwatch/log/v3"
)
const (
@ -77,7 +78,7 @@ func (hc *httpConn) closed() <-chan interface{} {
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
func DialHTTPWithClient(endpoint string, client *http.Client, logger log.Logger) (*Client, error) {
// Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
if err != nil {
@ -96,12 +97,12 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
closeCh: make(chan interface{}),
}
return hc, nil
})
}, logger)
}
// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
func DialHTTP(endpoint string) (*Client, error) {
return DialHTTPWithClient(endpoint, new(http.Client))
func DialHTTP(endpoint string, logger log.Logger) (*Client, error) {
return DialHTTPWithClient(endpoint, new(http.Client), logger)
}
func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {

View File

@ -21,6 +21,8 @@ import (
"net/http/httptest"
"strings"
"testing"
"github.com/ledgerwatch/log/v3"
)
func confirmStatusCode(t *testing.T, got, want int) {
@ -102,9 +104,10 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) {
// This checks that maxRequestContentLength is not applied to the response of a request.
func TestHTTPRespBodyUnlimited(t *testing.T) {
logger := log.New()
const respLength = maxRequestContentLength * 3
s := NewServer(50, false /* traceRequests */, true)
s := NewServer(50, false /* traceRequests */, true, logger)
defer s.Stop()
if err := s.RegisterName("test", largeRespService{respLength}); err != nil {
t.Fatal(err)
@ -112,7 +115,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) {
ts := httptest.NewServer(s)
defer ts.Close()
c, err := DialHTTP(ts.URL)
c, err := DialHTTP(ts.URL, logger)
if err != nil {
t.Fatal(err)
}

View File

@ -19,15 +19,17 @@ package rpc
import (
"context"
"net"
"github.com/ledgerwatch/log/v3"
)
// DialInProc attaches an in-process connection to the given RPC server.
func DialInProc(handler *Server) *Client {
func DialInProc(handler *Server, logger log.Logger) *Client {
initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe()
go handler.ServeCodec(NewCodec(p1), 0)
return NewCodec(p2), nil
})
}, logger)
return c
}

View File

@ -54,11 +54,13 @@ type Server struct {
disableStreaming bool
traceRequests bool // Whether to print requests at INFO level
batchLimit int // Maximum number of requests in a batch
logger log.Logger
}
// NewServer creates a new server instance with no registered handlers.
func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool) *Server {
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency, disableStreaming: disableStreaming, traceRequests: traceRequests}
func NewServer(batchConcurrency uint, traceRequests, disableStreaming bool, logger log.Logger) *Server {
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, batchConcurrency: batchConcurrency,
disableStreaming: disableStreaming, traceRequests: traceRequests, logger: logger}
// Register the default service providing meta information about the RPC service such
// as the services and methods it offers.
rpcService := &RPCService{server: server}
@ -101,7 +103,7 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
s.codecs.Add(codec)
defer s.codecs.Remove(codec)
c := initClient(codec, s.idgen, &s.services)
c := initClient(codec, s.idgen, &s.services, s.logger)
<-codec.closed()
c.Close()
}
@ -115,7 +117,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre
return
}
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests)
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList, s.batchConcurrency, s.traceRequests, s.logger)
h.allowSubscribe = false
defer h.close(io.EOF, nil)
@ -142,7 +144,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec, stre
// subscriptions.
func (s *Server) Stop() {
if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
log.Info("RPC server shutting down")
s.logger.Info("RPC server shutting down")
s.codecs.Each(func(c interface{}) bool {
c.(ServerCodec).close()
return true

View File

@ -28,10 +28,13 @@ import (
"strings"
"testing"
"time"
"github.com/ledgerwatch/log/v3"
)
func TestServerRegisterName(t *testing.T) {
server := NewServer(50, false /* traceRequests */, true)
logger := log.New()
server := NewServer(50, false /* traceRequests */, true, logger)
service := new(testService)
if err := server.RegisterName("test", service); err != nil {
@ -54,6 +57,7 @@ func TestServerRegisterName(t *testing.T) {
}
func TestServer(t *testing.T) {
logger := log.New()
files, err := os.ReadDir("testdata")
if err != nil {
t.Fatal("where'd my testdata go?")
@ -65,13 +69,13 @@ func TestServer(t *testing.T) {
path := filepath.Join("testdata", f.Name())
name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name()))
t.Run(name, func(t *testing.T) {
runTestScript(t, path)
runTestScript(t, path, logger)
})
}
}
func runTestScript(t *testing.T, file string) {
server := newTestServer()
func runTestScript(t *testing.T, file string, logger log.Logger) {
server := newTestServer(logger)
content, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
@ -136,7 +140,8 @@ func runTestScript(t *testing.T, file string) {
// This test checks that responses are delivered for very short-lived connections that
// only carry a single request.
func TestServerShortLivedConn(t *testing.T) {
server := newTestServer()
logger := log.New()
server := newTestServer(logger)
defer server.Stop()
listener, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -23,21 +23,23 @@ import (
"net"
"os"
"time"
"github.com/ledgerwatch/log/v3"
)
// DialStdIO creates a client on stdin/stdout.
func DialStdIO(ctx context.Context) (*Client, error) {
return DialIO(ctx, os.Stdin, os.Stdout)
func DialStdIO(ctx context.Context, logger log.Logger) (*Client, error) {
return DialIO(ctx, os.Stdin, os.Stdout, logger)
}
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
func DialIO(ctx context.Context, in io.Reader, out io.Writer, logger log.Logger) (*Client, error) {
return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
return NewCodec(stdioConn{
in: in,
out: out,
}), nil
})
}, logger)
}
type stdioConn struct {

View File

@ -23,6 +23,8 @@ import (
"strings"
"testing"
"time"
"github.com/ledgerwatch/log/v3"
)
func TestNewID(t *testing.T) {
@ -47,13 +49,14 @@ func TestNewID(t *testing.T) {
}
func TestSubscriptions(t *testing.T) {
logger := log.New()
var (
namespaces = []string{"eth", "bzz"}
service = &notificationTestService{}
subCount = len(namespaces)
notificationCount = 3
server = NewServer(50, false /* traceRequests */, true)
server = NewServer(50, false /* traceRequests */, true, logger)
clientConn, serverConn = net.Pipe()
out = json.NewEncoder(clientConn)
in = json.NewDecoder(clientConn)
@ -125,11 +128,12 @@ func TestSubscriptions(t *testing.T) {
// This test checks that unsubscribing works.
func TestServerUnsubscribe(t *testing.T) {
logger := log.New()
p1, p2 := net.Pipe()
defer p2.Close()
// Start the server.
server := newTestServer()
server := newTestServer(logger)
service := &notificationTestService{unsubscribed: make(chan string, 1)}
server.RegisterName("nftest2", service)
go server.ServeCodec(NewCodec(p1), 0)

View File

@ -23,10 +23,12 @@ import (
"strings"
"sync"
"time"
"github.com/ledgerwatch/log/v3"
)
func newTestServer() *Server {
server := NewServer(50, false /* traceRequests */, true)
func newTestServer(logger log.Logger) *Server {
server := NewServer(50, false /* traceRequests */, true, logger)
server.idgen = sequentialIDGenerator()
if err := server.RegisterName("test", new(testService)); err != nil {
panic(err)
@ -111,7 +113,7 @@ func (s *testService) ReturnError() error {
}
func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) {
c, ok := ClientFromContext(ctx)
c, ok := ClientFromContext(ctx, log.New())
if !ok {
return nil, errors.New("no client")
}
@ -121,7 +123,7 @@ func (s *testService) CallMeBack(ctx context.Context, method string, args []inte
}
func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error {
c, ok := ClientFromContext(ctx)
c, ok := ClientFromContext(ctx, log.New())
if !ok {
return errors.New("no client")
}

View File

@ -186,7 +186,7 @@ func parseOriginURL(origin string) (string, string, string, error) {
// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint using the provided dialer.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer, logger log.Logger) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, origin)
if err != nil {
return nil, err
@ -202,7 +202,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
return nil, hErr
}
return newWebsocketCodec(conn), nil
})
}, logger)
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@ -210,13 +210,13 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
//
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
func DialWebsocket(ctx context.Context, endpoint, origin string, logger log.Logger) (*Client, error) {
dialer := websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer, logger)
}
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {

View File

@ -54,14 +54,14 @@ func TestWebsocketOriginCheck(t *testing.T) {
logger := log.New()
var (
srv = newTestServer()
srv = newTestServer(logger)
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, nil, false, logger))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
defer httpsrv.Close()
client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com")
client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com", logger)
if err == nil {
client.Close()
t.Fatal("no error for wrong origin")
@ -72,7 +72,7 @@ func TestWebsocketOriginCheck(t *testing.T) {
}
// Connections without origin header should work.
client, err = DialWebsocket(context.Background(), wsURL, "")
client, err = DialWebsocket(context.Background(), wsURL, "", logger)
if err != nil {
t.Fatal("error for empty origin")
}
@ -85,14 +85,14 @@ func TestWebsocketLargeCall(t *testing.T) {
logger := log.New()
var (
srv = newTestServer()
srv = newTestServer(logger)
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, nil, false, logger))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
defer httpsrv.Close()
client, clientErr := DialWebsocket(context.Background(), wsURL, "")
client, clientErr := DialWebsocket(context.Background(), wsURL, "", logger)
if clientErr != nil {
t.Fatalf("can't dial: %v", clientErr)
}
@ -122,6 +122,7 @@ func TestClientWebsocketPing(t *testing.T) {
}
t.Parallel()
logger := log.New()
var (
sendPing = make(chan struct{})
@ -131,7 +132,7 @@ func TestClientWebsocketPing(t *testing.T) {
defer cancel()
defer server.Shutdown(ctx)
client, err := DialContext(ctx, "ws://"+server.Addr)
client, err := DialContext(ctx, "ws://"+server.Addr, logger)
if err != nil {
t.Fatalf("client dial error: %v", err)
}
@ -167,7 +168,7 @@ func TestClientWebsocketPing(t *testing.T) {
func TestClientWebsocketLargeMessage(t *testing.T) {
logger := log.New()
var (
srv = NewServer(50, false /* traceRequests */, true)
srv = NewServer(50, false /* traceRequests */, true, logger)
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, nil, false, logger))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
@ -179,7 +180,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
t.Fatal(err)
}
c, err := DialWebsocket(context.Background(), wsURL, "")
c, err := DialWebsocket(context.Background(), wsURL, "", logger)
if err != nil {
t.Fatal(err)
}