mirror of
https://gitlab.com/pulsechaincom/go-pulse.git
synced 2025-01-03 01:07:39 +00:00
rpc: add limit for batch request items and response size (#26681)
This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches were limited only by processing time. The server would pick calls from the batch and answer them until the response timeout occurred, then stop processing the remaining batch items. Here, we are adding two additional limits which can be configured: - the 'item limit': batches can have at most N items - the 'response size limit': batches can contain at most X response bytes These limits are optional in package rpc. In Geth, we set a default limit of 1000 items and 25MB response size. When a batch goes over the limit, an error response is returned to the client. However, doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid `id` can be responded to. Since batches may also contain non-call messages or notifications, the best effort thing we can do to report an error with the batch itself is reporting the limit violation as an error for the first method call in the batch. If a batch is too large, but contains only notifications and responses, the error will be reported with a null `id`. The RPC client was also changed so it can deal with errors resulting from too large batches. An older client connected to the server code in this PR could get stuck until the request timeout occurred when the batch is too large. **Upgrading to a version of the RPC client containing this change is strongly recommended to avoid timeout issues.** For some weird reason, when writing the original client implementation, @fjl worked off of the assumption that responses could be distributed across batches arbitrarily. So for a batch request containing requests `[A B C]`, the server could respond with `[A B C]` but also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the client. So in the implementation of BatchCallContext, the client waited for all requests in the batch individually. If the server didn't respond to some of the requests in the batch, the client would eventually just time out (if a context was used). With the addition of batch limits into the server, we anticipate that people will hit this kind of error way more often. To handle this properly, the client now waits for a single response batch and expects it to contain all responses to the requests. --------- Co-authored-by: Felix Lange <fjl@twurst.com> Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
5ac4da3653
commit
f3314bb6df
@ -732,6 +732,7 @@ func signer(c *cli.Context) error {
|
|||||||
cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name))
|
cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name))
|
||||||
|
|
||||||
srv := rpc.NewServer()
|
srv := rpc.NewServer()
|
||||||
|
srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize)
|
||||||
err := node.RegisterApis(rpcAPI, []string{"account"}, srv)
|
err := node.RegisterApis(rpcAPI, []string{"account"}, srv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Fatalf("Could not register API: %w", err)
|
utils.Fatalf("Could not register API: %w", err)
|
||||||
|
@ -168,6 +168,8 @@ var (
|
|||||||
utils.RPCGlobalEVMTimeoutFlag,
|
utils.RPCGlobalEVMTimeoutFlag,
|
||||||
utils.RPCGlobalTxFeeCapFlag,
|
utils.RPCGlobalTxFeeCapFlag,
|
||||||
utils.AllowUnprotectedTxs,
|
utils.AllowUnprotectedTxs,
|
||||||
|
utils.BatchRequestLimit,
|
||||||
|
utils.BatchResponseMaxSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsFlags = []cli.Flag{
|
metricsFlags = []cli.Flag{
|
||||||
|
@ -713,6 +713,18 @@ var (
|
|||||||
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
|
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
|
||||||
Category: flags.APICategory,
|
Category: flags.APICategory,
|
||||||
}
|
}
|
||||||
|
BatchRequestLimit = &cli.IntFlag{
|
||||||
|
Name: "rpc.batch-request-limit",
|
||||||
|
Usage: "Maximum number of requests in a batch",
|
||||||
|
Value: node.DefaultConfig.BatchRequestLimit,
|
||||||
|
Category: flags.APICategory,
|
||||||
|
}
|
||||||
|
BatchResponseMaxSize = &cli.IntFlag{
|
||||||
|
Name: "rpc.batch-response-max-size",
|
||||||
|
Usage: "Maximum number of bytes returned from a batched call",
|
||||||
|
Value: node.DefaultConfig.BatchResponseMaxSize,
|
||||||
|
Category: flags.APICategory,
|
||||||
|
}
|
||||||
EnablePersonal = &cli.BoolFlag{
|
EnablePersonal = &cli.BoolFlag{
|
||||||
Name: "rpc.enabledeprecatedpersonal",
|
Name: "rpc.enabledeprecatedpersonal",
|
||||||
Usage: "Enables the (deprecated) personal namespace",
|
Usage: "Enables the (deprecated) personal namespace",
|
||||||
@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
|
|||||||
if ctx.IsSet(AllowUnprotectedTxs.Name) {
|
if ctx.IsSet(AllowUnprotectedTxs.Name) {
|
||||||
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
|
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.IsSet(BatchRequestLimit.Name) {
|
||||||
|
cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.IsSet(BatchResponseMaxSize.Name) {
|
||||||
|
cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setGraphQL creates the GraphQL listener interface string from the set
|
// setGraphQL creates the GraphQL listener interface string from the set
|
||||||
|
@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri
|
|||||||
CorsAllowedOrigins: api.node.config.HTTPCors,
|
CorsAllowedOrigins: api.node.config.HTTPCors,
|
||||||
Vhosts: api.node.config.HTTPVirtualHosts,
|
Vhosts: api.node.config.HTTPVirtualHosts,
|
||||||
Modules: api.node.config.HTTPModules,
|
Modules: api.node.config.HTTPModules,
|
||||||
|
rpcEndpointConfig: rpcEndpointConfig{
|
||||||
|
batchItemLimit: api.node.config.BatchRequestLimit,
|
||||||
|
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if cors != nil {
|
if cors != nil {
|
||||||
config.CorsAllowedOrigins = nil
|
config.CorsAllowedOrigins = nil
|
||||||
@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap
|
|||||||
Modules: api.node.config.WSModules,
|
Modules: api.node.config.WSModules,
|
||||||
Origins: api.node.config.WSOrigins,
|
Origins: api.node.config.WSOrigins,
|
||||||
// ExposeAll: api.node.config.WSExposeAll,
|
// ExposeAll: api.node.config.WSExposeAll,
|
||||||
|
rpcEndpointConfig: rpcEndpointConfig{
|
||||||
|
batchItemLimit: api.node.config.BatchRequestLimit,
|
||||||
|
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if apis != nil {
|
if apis != nil {
|
||||||
config.Modules = nil
|
config.Modules = nil
|
||||||
|
@ -197,6 +197,12 @@ type Config struct {
|
|||||||
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
|
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
|
||||||
AllowUnprotectedTxs bool `toml:",omitempty"`
|
AllowUnprotectedTxs bool `toml:",omitempty"`
|
||||||
|
|
||||||
|
// BatchRequestLimit is the maximum number of requests in a batch.
|
||||||
|
BatchRequestLimit int `toml:",omitempty"`
|
||||||
|
|
||||||
|
// BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call.
|
||||||
|
BatchResponseMaxSize int `toml:",omitempty"`
|
||||||
|
|
||||||
// JWTSecret is the path to the hex-encoded jwt secret.
|
// JWTSecret is the path to the hex-encoded jwt secret.
|
||||||
JWTSecret string `toml:",omitempty"`
|
JWTSecret string `toml:",omitempty"`
|
||||||
|
|
||||||
|
@ -46,17 +46,19 @@ var (
|
|||||||
|
|
||||||
// DefaultConfig contains reasonable default settings.
|
// DefaultConfig contains reasonable default settings.
|
||||||
var DefaultConfig = Config{
|
var DefaultConfig = Config{
|
||||||
DataDir: DefaultDataDir(),
|
DataDir: DefaultDataDir(),
|
||||||
HTTPPort: DefaultHTTPPort,
|
HTTPPort: DefaultHTTPPort,
|
||||||
AuthAddr: DefaultAuthHost,
|
AuthAddr: DefaultAuthHost,
|
||||||
AuthPort: DefaultAuthPort,
|
AuthPort: DefaultAuthPort,
|
||||||
AuthVirtualHosts: DefaultAuthVhosts,
|
AuthVirtualHosts: DefaultAuthVhosts,
|
||||||
HTTPModules: []string{"net", "web3"},
|
HTTPModules: []string{"net", "web3"},
|
||||||
HTTPVirtualHosts: []string{"localhost"},
|
HTTPVirtualHosts: []string{"localhost"},
|
||||||
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
|
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
|
||||||
WSPort: DefaultWSPort,
|
WSPort: DefaultWSPort,
|
||||||
WSModules: []string{"net", "web3"},
|
WSModules: []string{"net", "web3"},
|
||||||
GraphQLVirtualHosts: []string{"localhost"},
|
BatchRequestLimit: 1000,
|
||||||
|
BatchResponseMaxSize: 25 * 1000 * 1000,
|
||||||
|
GraphQLVirtualHosts: []string{"localhost"},
|
||||||
P2P: p2p.Config{
|
P2P: p2p.Config{
|
||||||
ListenAddr: ":30303",
|
ListenAddr: ":30303",
|
||||||
MaxPeers: 50,
|
MaxPeers: 50,
|
||||||
|
31
node/node.go
31
node/node.go
@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) {
|
|||||||
if strings.HasSuffix(conf.Name, ".ipc") {
|
if strings.HasSuffix(conf.Name, ".ipc") {
|
||||||
return nil, errors.New(`Config.Name cannot end in ".ipc"`)
|
return nil, errors.New(`Config.Name cannot end in ".ipc"`)
|
||||||
}
|
}
|
||||||
|
server := rpc.NewServer()
|
||||||
|
server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize)
|
||||||
node := &Node{
|
node := &Node{
|
||||||
config: conf,
|
config: conf,
|
||||||
inprocHandler: rpc.NewServer(),
|
inprocHandler: server,
|
||||||
eventmux: new(event.TypeMux),
|
eventmux: new(event.TypeMux),
|
||||||
log: conf.Logger,
|
log: conf.Logger,
|
||||||
stop: make(chan struct{}),
|
stop: make(chan struct{}),
|
||||||
@ -403,6 +404,11 @@ func (n *Node) startRPC() error {
|
|||||||
openAPIs, allAPIs = n.getAPIs()
|
openAPIs, allAPIs = n.getAPIs()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rpcConfig := rpcEndpointConfig{
|
||||||
|
batchItemLimit: n.config.BatchRequestLimit,
|
||||||
|
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
|
||||||
|
}
|
||||||
|
|
||||||
initHttp := func(server *httpServer, port int) error {
|
initHttp := func(server *httpServer, port int) error {
|
||||||
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
|
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -412,6 +418,7 @@ func (n *Node) startRPC() error {
|
|||||||
Vhosts: n.config.HTTPVirtualHosts,
|
Vhosts: n.config.HTTPVirtualHosts,
|
||||||
Modules: n.config.HTTPModules,
|
Modules: n.config.HTTPModules,
|
||||||
prefix: n.config.HTTPPathPrefix,
|
prefix: n.config.HTTPPathPrefix,
|
||||||
|
rpcEndpointConfig: rpcConfig,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -425,9 +432,10 @@ func (n *Node) startRPC() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := server.enableWS(openAPIs, wsConfig{
|
if err := server.enableWS(openAPIs, wsConfig{
|
||||||
Modules: n.config.WSModules,
|
Modules: n.config.WSModules,
|
||||||
Origins: n.config.WSOrigins,
|
Origins: n.config.WSOrigins,
|
||||||
prefix: n.config.WSPathPrefix,
|
prefix: n.config.WSPathPrefix,
|
||||||
|
rpcEndpointConfig: rpcConfig,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -441,26 +449,29 @@ func (n *Node) startRPC() error {
|
|||||||
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
sharedConfig := rpcConfig
|
||||||
|
sharedConfig.jwtSecret = secret
|
||||||
if err := server.enableRPC(allAPIs, httpConfig{
|
if err := server.enableRPC(allAPIs, httpConfig{
|
||||||
CorsAllowedOrigins: DefaultAuthCors,
|
CorsAllowedOrigins: DefaultAuthCors,
|
||||||
Vhosts: n.config.AuthVirtualHosts,
|
Vhosts: n.config.AuthVirtualHosts,
|
||||||
Modules: DefaultAuthModules,
|
Modules: DefaultAuthModules,
|
||||||
prefix: DefaultAuthPrefix,
|
prefix: DefaultAuthPrefix,
|
||||||
jwtSecret: secret,
|
rpcEndpointConfig: sharedConfig,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
servers = append(servers, server)
|
servers = append(servers, server)
|
||||||
|
|
||||||
// Enable auth via WS
|
// Enable auth via WS
|
||||||
server = n.wsServerForPort(port, true)
|
server = n.wsServerForPort(port, true)
|
||||||
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := server.enableWS(allAPIs, wsConfig{
|
if err := server.enableWS(allAPIs, wsConfig{
|
||||||
Modules: DefaultAuthModules,
|
Modules: DefaultAuthModules,
|
||||||
Origins: DefaultAuthOrigins,
|
Origins: DefaultAuthOrigins,
|
||||||
prefix: DefaultAuthPrefix,
|
prefix: DefaultAuthPrefix,
|
||||||
jwtSecret: secret,
|
rpcEndpointConfig: sharedConfig,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -41,15 +41,21 @@ type httpConfig struct {
|
|||||||
CorsAllowedOrigins []string
|
CorsAllowedOrigins []string
|
||||||
Vhosts []string
|
Vhosts []string
|
||||||
prefix string // path prefix on which to mount http handler
|
prefix string // path prefix on which to mount http handler
|
||||||
jwtSecret []byte // optional JWT secret
|
rpcEndpointConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// wsConfig is the JSON-RPC/Websocket configuration
|
// wsConfig is the JSON-RPC/Websocket configuration
|
||||||
type wsConfig struct {
|
type wsConfig struct {
|
||||||
Origins []string
|
Origins []string
|
||||||
Modules []string
|
Modules []string
|
||||||
prefix string // path prefix on which to mount ws handler
|
prefix string // path prefix on which to mount ws handler
|
||||||
jwtSecret []byte // optional JWT secret
|
rpcEndpointConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcEndpointConfig struct {
|
||||||
|
jwtSecret []byte // optional JWT secret
|
||||||
|
batchItemLimit int
|
||||||
|
batchResponseSizeLimit int
|
||||||
}
|
}
|
||||||
|
|
||||||
type rpcHandler struct {
|
type rpcHandler struct {
|
||||||
@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
|
|||||||
|
|
||||||
// Create RPC server and handler.
|
// Create RPC server and handler.
|
||||||
srv := rpc.NewServer()
|
srv := rpc.NewServer()
|
||||||
|
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
|
||||||
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
|
|||||||
}
|
}
|
||||||
// Create RPC server and handler.
|
// Create RPC server and handler.
|
||||||
srv := rpc.NewServer()
|
srv := rpc.NewServer()
|
||||||
|
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
|
||||||
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -339,8 +339,10 @@ func TestJWT(t *testing.T) {
|
|||||||
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
|
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
|
||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
|
cfg := rpcEndpointConfig{jwtSecret: []byte("secret")}
|
||||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
|
httpcfg := &httpConfig{rpcEndpointConfig: cfg}
|
||||||
|
wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg}
|
||||||
|
srv := createAndStartServer(t, httpcfg, true, wscfg, nil)
|
||||||
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
|
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||||
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
|
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||||
|
|
||||||
|
131
rpc/client.go
131
rpc/client.go
@ -34,14 +34,15 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrBadResult = errors.New("bad result in JSON-RPC response")
|
ErrBadResult = errors.New("bad result in JSON-RPC response")
|
||||||
ErrClientQuit = errors.New("client is closed")
|
ErrClientQuit = errors.New("client is closed")
|
||||||
ErrNoResult = errors.New("no result in JSON-RPC response")
|
ErrNoResult = errors.New("JSON-RPC response has no result")
|
||||||
|
ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call")
|
||||||
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
|
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
|
||||||
errClientReconnected = errors.New("client reconnected")
|
errClientReconnected = errors.New("client reconnected")
|
||||||
errDead = errors.New("connection lost")
|
errDead = errors.New("connection lost")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Timeouts
|
||||||
const (
|
const (
|
||||||
// Timeouts
|
|
||||||
defaultDialTimeout = 10 * time.Second // used if context has no deadline
|
defaultDialTimeout = 10 * time.Second // used if context has no deadline
|
||||||
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
|
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
|
||||||
)
|
)
|
||||||
@ -84,6 +85,10 @@ type Client struct {
|
|||||||
// This function, if non-nil, is called when the connection is lost.
|
// This function, if non-nil, is called when the connection is lost.
|
||||||
reconnectFunc reconnectFunc
|
reconnectFunc reconnectFunc
|
||||||
|
|
||||||
|
// config fields
|
||||||
|
batchItemLimit int
|
||||||
|
batchResponseMaxSize int
|
||||||
|
|
||||||
// writeConn is used for writing to the connection on the caller's goroutine. It should
|
// writeConn is used for writing to the connection on the caller's goroutine. It should
|
||||||
// only be accessed outside of dispatch, with the write lock held. The write lock is
|
// only be accessed outside of dispatch, with the write lock held. The write lock is
|
||||||
// taken by sending on reqInit and released by sending on reqSent.
|
// taken by sending on reqInit and released by sending on reqSent.
|
||||||
@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, clientContextKey{}, c)
|
ctx = context.WithValue(ctx, clientContextKey{}, c)
|
||||||
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
|
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
|
||||||
handler := newHandler(ctx, conn, c.idgen, c.services)
|
handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize)
|
||||||
return &clientConn{conn, handler}
|
return &clientConn{conn, handler}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,14 +133,17 @@ type readOp struct {
|
|||||||
batch bool
|
batch bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requestOp represents a pending request. This is used for both batch and non-batch
|
||||||
|
// requests.
|
||||||
type requestOp struct {
|
type requestOp struct {
|
||||||
ids []json.RawMessage
|
ids []json.RawMessage
|
||||||
err error
|
err error
|
||||||
resp chan *jsonrpcMessage // receives up to len(ids) responses
|
resp chan []*jsonrpcMessage // the response goes here
|
||||||
sub *ClientSubscription // only set for EthSubscribe requests
|
sub *ClientSubscription // set for Subscribe requests.
|
||||||
|
hadResponse bool // true when the request was responded to
|
||||||
}
|
}
|
||||||
|
|
||||||
func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
|
func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Send the timeout to dispatch so it can remove the request IDs.
|
// Send the timeout to dispatch so it can remove the request IDs.
|
||||||
@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*
|
|||||||
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
|
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newClient(ctx, reconnect)
|
return newClient(ctx, cfg, reconnect)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
|
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
|
||||||
@ -221,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) {
|
|||||||
return client, ok
|
return client, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
|
func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) {
|
||||||
conn, err := connect(initctx)
|
conn, err := connect(initctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
|
c := initClient(conn, new(serviceRegistry), cfg)
|
||||||
c.reconnectFunc = connect
|
c.reconnectFunc = connect
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
|
func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client {
|
||||||
_, isHTTP := conn.(*httpConn)
|
_, isHTTP := conn.(*httpConn)
|
||||||
c := &Client{
|
c := &Client{
|
||||||
isHTTP: isHTTP,
|
isHTTP: isHTTP,
|
||||||
idgen: idgen,
|
services: services,
|
||||||
services: services,
|
idgen: cfg.idgen,
|
||||||
writeConn: conn,
|
batchItemLimit: cfg.batchItemLimit,
|
||||||
close: make(chan struct{}),
|
batchResponseMaxSize: cfg.batchResponseLimit,
|
||||||
closing: make(chan struct{}),
|
writeConn: conn,
|
||||||
didClose: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
reconnected: make(chan ServerCodec),
|
closing: make(chan struct{}),
|
||||||
readOp: make(chan readOp),
|
didClose: make(chan struct{}),
|
||||||
readErr: make(chan error),
|
reconnected: make(chan ServerCodec),
|
||||||
reqInit: make(chan *requestOp),
|
readOp: make(chan readOp),
|
||||||
reqSent: make(chan error, 1),
|
readErr: make(chan error),
|
||||||
reqTimeout: make(chan *requestOp),
|
reqInit: make(chan *requestOp),
|
||||||
|
reqSent: make(chan error, 1),
|
||||||
|
reqTimeout: make(chan *requestOp),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set defaults.
|
||||||
|
if c.idgen == nil {
|
||||||
|
c.idgen = randomIDGenerator()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch the main loop.
|
||||||
if !isHTTP {
|
if !isHTTP {
|
||||||
go c.dispatch(conn)
|
go c.dispatch(conn)
|
||||||
}
|
}
|
||||||
@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
|
op := &requestOp{
|
||||||
|
ids: []json.RawMessage{msg.ID},
|
||||||
|
resp: make(chan []*jsonrpcMessage, 1),
|
||||||
|
}
|
||||||
|
|
||||||
if c.isHTTP {
|
if c.isHTTP {
|
||||||
err = c.sendHTTP(ctx, op, msg)
|
err = c.sendHTTP(ctx, op, msg)
|
||||||
@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dispatch has accepted the request and will close the channel when it quits.
|
// dispatch has accepted the request and will close the channel when it quits.
|
||||||
switch resp, err := op.wait(ctx, c); {
|
batchresp, err := op.wait(ctx, c)
|
||||||
case err != nil:
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
resp := batchresp[0]
|
||||||
|
switch {
|
||||||
case resp.Error != nil:
|
case resp.Error != nil:
|
||||||
return resp.Error
|
return resp.Error
|
||||||
case len(resp.Result) == 0:
|
case len(resp.Result) == 0:
|
||||||
@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
|
|||||||
)
|
)
|
||||||
op := &requestOp{
|
op := &requestOp{
|
||||||
ids: make([]json.RawMessage, len(b)),
|
ids: make([]json.RawMessage, len(b)),
|
||||||
resp: make(chan *jsonrpcMessage, len(b)),
|
resp: make(chan []*jsonrpcMessage, 1),
|
||||||
}
|
}
|
||||||
for i, elem := range b {
|
for i, elem := range b {
|
||||||
msg, err := c.newMessage(elem.Method, elem.Args...)
|
msg, err := c.newMessage(elem.Method, elem.Args...)
|
||||||
@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
|
|||||||
} else {
|
} else {
|
||||||
err = c.send(ctx, op, msgs)
|
err = c.send(ctx, op, msgs)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
batchresp, err := op.wait(ctx, c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for all responses to come back.
|
// Wait for all responses to come back.
|
||||||
for n := 0; n < len(b) && err == nil; n++ {
|
for n := 0; n < len(batchresp) && err == nil; n++ {
|
||||||
var resp *jsonrpcMessage
|
resp := batchresp[n]
|
||||||
resp, err = op.wait(ctx, c)
|
if resp == nil {
|
||||||
if err != nil {
|
// Ignore null responses. These can happen for batches sent via HTTP.
|
||||||
break
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the element corresponding to this response.
|
// Find the element corresponding to this response.
|
||||||
// The element is guaranteed to be present because dispatch
|
index, ok := byID[string(resp.ID)]
|
||||||
// only sends valid IDs to our channel.
|
if !ok {
|
||||||
elem := &b[byID[string(resp.ID)]]
|
continue
|
||||||
if resp.Error != nil {
|
}
|
||||||
|
delete(byID, string(resp.ID))
|
||||||
|
|
||||||
|
// Assign result and error.
|
||||||
|
elem := &b[index]
|
||||||
|
switch {
|
||||||
|
case resp.Error != nil:
|
||||||
elem.Error = resp.Error
|
elem.Error = resp.Error
|
||||||
continue
|
case resp.Result == nil:
|
||||||
}
|
|
||||||
if len(resp.Result) == 0 {
|
|
||||||
elem.Error = ErrNoResult
|
elem.Error = ErrNoResult
|
||||||
continue
|
default:
|
||||||
|
elem.Error = json.Unmarshal(resp.Result, elem.Result)
|
||||||
}
|
}
|
||||||
elem.Error = json.Unmarshal(resp.Result, elem.Result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that all expected responses have been received.
|
||||||
|
for _, index := range byID {
|
||||||
|
elem := &b[index]
|
||||||
|
elem.Error = ErrMissingBatchResponse
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
|
|||||||
}
|
}
|
||||||
op := &requestOp{
|
op := &requestOp{
|
||||||
ids: []json.RawMessage{msg.ID},
|
ids: []json.RawMessage{msg.ID},
|
||||||
resp: make(chan *jsonrpcMessage),
|
resp: make(chan []*jsonrpcMessage, 1),
|
||||||
sub: newClientSubscription(c, namespace, chanVal),
|
sub: newClientSubscription(c, namespace, chanVal),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,11 +28,18 @@ type ClientOption interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type clientConfig struct {
|
type clientConfig struct {
|
||||||
|
// HTTP settings
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
httpHeaders http.Header
|
httpHeaders http.Header
|
||||||
httpAuth HTTPAuth
|
httpAuth HTTPAuth
|
||||||
|
|
||||||
|
// WebSocket options
|
||||||
wsDialer *websocket.Dialer
|
wsDialer *websocket.Dialer
|
||||||
|
|
||||||
|
// RPC handler options
|
||||||
|
idgen func() ID
|
||||||
|
batchItemLimit int
|
||||||
|
batchResponseLimit int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *clientConfig) initHeaders() {
|
func (cfg *clientConfig) initHeaders() {
|
||||||
@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption {
|
|||||||
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
|
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
|
||||||
// auth information to the request.
|
// auth information to the request.
|
||||||
type HTTPAuth func(h http.Header) error
|
type HTTPAuth func(h http.Header) error
|
||||||
|
|
||||||
|
// WithBatchItemLimit changes the maximum number of items allowed in batch requests.
|
||||||
|
//
|
||||||
|
// Note: this option applies when processing incoming batch requests. It does not affect
|
||||||
|
// batch requests sent by the client.
|
||||||
|
func WithBatchItemLimit(limit int) ClientOption {
|
||||||
|
return optionFunc(func(cfg *clientConfig) {
|
||||||
|
cfg.batchItemLimit = limit
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be
|
||||||
|
// generated for batch requests. When this limit is reached, further calls in the batch
|
||||||
|
// will not be processed.
|
||||||
|
//
|
||||||
|
// Note: this option applies when processing incoming batch requests. It does not affect
|
||||||
|
// batch requests sent by the client.
|
||||||
|
func WithBatchResponseSizeLimit(sizeLimit int) ClientOption {
|
||||||
|
return optionFunc(func(cfg *clientConfig) {
|
||||||
|
cfg.batchResponseLimit = sizeLimit
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This checks that, for HTTP connections, the length of batch responses is validated to
|
||||||
|
// match the request exactly.
|
||||||
func TestClientBatchRequest_len(t *testing.T) {
|
func TestClientBatchRequest_len(t *testing.T) {
|
||||||
b, err := json.Marshal([]jsonrpcMessage{
|
b, err := json.Marshal([]jsonrpcMessage{
|
||||||
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)},
|
{Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)},
|
||||||
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)},
|
{Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to encode jsonrpc message:", err)
|
t.Fatal("failed to encode jsonrpc message:", err)
|
||||||
@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
t.Cleanup(s.Close)
|
t.Cleanup(s.Close)
|
||||||
|
|
||||||
client, err := Dial(s.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("failed to dial test server:", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
t.Run("too-few", func(t *testing.T) {
|
t.Run("too-few", func(t *testing.T) {
|
||||||
|
client, err := Dial(s.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to dial test server:", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
batch := []BatchElem{
|
batch := []BatchElem{
|
||||||
{Method: "foo"},
|
{Method: "foo", Result: new(string)},
|
||||||
{Method: "bar"},
|
{Method: "bar", Result: new(string)},
|
||||||
{Method: "baz"},
|
{Method: "baz", Result: new(string)},
|
||||||
}
|
}
|
||||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
|
|
||||||
t.Errorf("expected %q but got: %v", ErrBadResult, err)
|
if err := client.BatchCallContext(ctx, batch); err != nil {
|
||||||
|
t.Fatal("error:", err)
|
||||||
|
}
|
||||||
|
for i, elem := range batch[:2] {
|
||||||
|
if elem.Error != nil {
|
||||||
|
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, elem := range batch[2:] {
|
||||||
|
if elem.Error != ErrMissingBatchResponse {
|
||||||
|
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("too-many", func(t *testing.T) {
|
t.Run("too-many", func(t *testing.T) {
|
||||||
|
client, err := Dial(s.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to dial test server:", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
batch := []BatchElem{
|
batch := []BatchElem{
|
||||||
{Method: "foo"},
|
{Method: "foo", Result: new(string)},
|
||||||
}
|
}
|
||||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
|
|
||||||
t.Errorf("expected %q but got: %v", ErrBadResult, err)
|
if err := client.BatchCallContext(ctx, batch); err != nil {
|
||||||
|
t.Fatal("error:", err)
|
||||||
|
}
|
||||||
|
for i, elem := range batch[:1] {
|
||||||
|
if elem.Error != nil {
|
||||||
|
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, elem := range batch[1:] {
|
||||||
|
if elem.Error != ErrMissingBatchResponse {
|
||||||
|
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This checks that the client can handle the case where the server doesn't
|
||||||
|
// respond to all requests in a batch.
|
||||||
|
func TestClientBatchRequestLimit(t *testing.T) {
|
||||||
|
server := newTestServer()
|
||||||
|
defer server.Stop()
|
||||||
|
server.SetBatchLimits(2, 100000)
|
||||||
|
client := DialInProc(server)
|
||||||
|
|
||||||
|
batch := []BatchElem{
|
||||||
|
{Method: "foo"},
|
||||||
|
{Method: "bar"},
|
||||||
|
{Method: "baz"},
|
||||||
|
}
|
||||||
|
err := client.BatchCall(batch)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the first response indicates an error with batch size.
|
||||||
|
var err0 Error
|
||||||
|
if !errors.As(batch[0].Error, &err0) {
|
||||||
|
t.Log("error zero:", batch[0].Error)
|
||||||
|
t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error)
|
||||||
|
} else {
|
||||||
|
if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge {
|
||||||
|
t.Fatalf("wrong error on batch elem zero: %v", err0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that remaining response batch elements are reported as absent.
|
||||||
|
for i, elem := range batch[1:] {
|
||||||
|
if elem.Error != ErrMissingBatchResponse {
|
||||||
|
t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientNotify(t *testing.T) {
|
func TestClientNotify(t *testing.T) {
|
||||||
server := newTestServer()
|
server := newTestServer()
|
||||||
defer server.Stop()
|
defer server.Stop()
|
||||||
@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) {
|
|||||||
_, hasDeadline := ctx.Deadline()
|
_, hasDeadline := ctx.Deadline()
|
||||||
t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
|
t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
|
||||||
// default:
|
// default:
|
||||||
// t.Logf("got expected error with %v wait time: %v", timeout, err)
|
// t.Logf("got expected error with %v wait time: %v", timeout, err)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) {
|
|||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
||||||
// Create the client on the other end of the pipe.
|
// Create the client on the other end of the pipe.
|
||||||
client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) {
|
cfg := new(clientConfig)
|
||||||
|
client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) {
|
||||||
return NewCodec(p2), nil
|
return NewCodec(p2), nil
|
||||||
})
|
})
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
@ -61,12 +61,15 @@ const (
|
|||||||
errcodeDefault = -32000
|
errcodeDefault = -32000
|
||||||
errcodeNotificationsUnsupported = -32001
|
errcodeNotificationsUnsupported = -32001
|
||||||
errcodeTimeout = -32002
|
errcodeTimeout = -32002
|
||||||
|
errcodeResponseTooLarge = -32003
|
||||||
errcodePanic = -32603
|
errcodePanic = -32603
|
||||||
errcodeMarshalError = -32603
|
errcodeMarshalError = -32603
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errMsgTimeout = "request timed out"
|
errMsgTimeout = "request timed out"
|
||||||
|
errMsgResponseTooLarge = "response too large"
|
||||||
|
errMsgBatchTooLarge = "batch too large"
|
||||||
)
|
)
|
||||||
|
|
||||||
type methodNotFoundError struct{ method string }
|
type methodNotFoundError struct{ method string }
|
||||||
|
289
rpc/handler.go
289
rpc/handler.go
@ -49,17 +49,19 @@ import (
|
|||||||
// h.removeRequestOp(op) // timeout, etc.
|
// h.removeRequestOp(op) // timeout, etc.
|
||||||
// }
|
// }
|
||||||
type handler struct {
|
type handler struct {
|
||||||
reg *serviceRegistry
|
reg *serviceRegistry
|
||||||
unsubscribeCb *callback
|
unsubscribeCb *callback
|
||||||
idgen func() ID // subscription ID generator
|
idgen func() ID // subscription ID generator
|
||||||
respWait map[string]*requestOp // active client requests
|
respWait map[string]*requestOp // active client requests
|
||||||
clientSubs map[string]*ClientSubscription // active client subscriptions
|
clientSubs map[string]*ClientSubscription // active client subscriptions
|
||||||
callWG sync.WaitGroup // pending call goroutines
|
callWG sync.WaitGroup // pending call goroutines
|
||||||
rootCtx context.Context // canceled by close()
|
rootCtx context.Context // canceled by close()
|
||||||
cancelRoot func() // cancel function for rootCtx
|
cancelRoot func() // cancel function for rootCtx
|
||||||
conn jsonWriter // where responses will be sent
|
conn jsonWriter // where responses will be sent
|
||||||
log log.Logger
|
log log.Logger
|
||||||
allowSubscribe bool
|
allowSubscribe bool
|
||||||
|
batchRequestLimit int
|
||||||
|
batchResponseMaxSize int
|
||||||
|
|
||||||
subLock sync.Mutex
|
subLock sync.Mutex
|
||||||
serverSubs map[ID]*Subscription
|
serverSubs map[ID]*Subscription
|
||||||
@ -70,19 +72,21 @@ type callProc struct {
|
|||||||
notifiers []*Notifier
|
notifiers []*Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
|
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler {
|
||||||
rootCtx, cancelRoot := context.WithCancel(connCtx)
|
rootCtx, cancelRoot := context.WithCancel(connCtx)
|
||||||
h := &handler{
|
h := &handler{
|
||||||
reg: reg,
|
reg: reg,
|
||||||
idgen: idgen,
|
idgen: idgen,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
respWait: make(map[string]*requestOp),
|
respWait: make(map[string]*requestOp),
|
||||||
clientSubs: make(map[string]*ClientSubscription),
|
clientSubs: make(map[string]*ClientSubscription),
|
||||||
rootCtx: rootCtx,
|
rootCtx: rootCtx,
|
||||||
cancelRoot: cancelRoot,
|
cancelRoot: cancelRoot,
|
||||||
allowSubscribe: true,
|
allowSubscribe: true,
|
||||||
serverSubs: make(map[ID]*Subscription),
|
serverSubs: make(map[ID]*Subscription),
|
||||||
log: log.Root(),
|
log: log.Root(),
|
||||||
|
batchRequestLimit: batchRequestLimit,
|
||||||
|
batchResponseMaxSize: batchResponseMaxSize,
|
||||||
}
|
}
|
||||||
if conn.remoteAddr() != "" {
|
if conn.remoteAddr() != "" {
|
||||||
h.log = h.log.New("conn", conn.remoteAddr())
|
h.log = h.log.New("conn", conn.remoteAddr())
|
||||||
@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
|
|||||||
b.doWrite(ctx, conn, false)
|
b.doWrite(ctx, conn, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// timeout sends the responses added so far. For the remaining unanswered call
|
// respondWithError sends the responses added so far. For the remaining unanswered call
|
||||||
// messages, it sends a timeout error response.
|
// messages, it responds with the given error.
|
||||||
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
|
func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) {
|
||||||
b.mutex.Lock()
|
b.mutex.Lock()
|
||||||
defer b.mutex.Unlock()
|
defer b.mutex.Unlock()
|
||||||
|
|
||||||
for _, msg := range b.calls {
|
for _, msg := range b.calls {
|
||||||
if !msg.isNotification() {
|
if !msg.isNotification() {
|
||||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
b.resp = append(b.resp, msg.errorResponse(err))
|
||||||
b.resp = append(b.resp, resp)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b.doWrite(ctx, conn, true)
|
b.doWrite(ctx, conn, true)
|
||||||
@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Apply limit on total number of requests.
|
||||||
// Handle non-call messages first:
|
if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit {
|
||||||
calls := make([]*jsonrpcMessage, 0, len(msgs))
|
h.startCallProc(func(cp *callProc) {
|
||||||
for _, msg := range msgs {
|
h.respondWithBatchTooLarge(cp, msgs)
|
||||||
if handled := h.handleImmediate(msg); !handled {
|
})
|
||||||
calls = append(calls, msg)
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle non-call messages first.
|
||||||
|
// Here we need to find the requestOp that sent the request batch.
|
||||||
|
calls := make([]*jsonrpcMessage, 0, len(msgs))
|
||||||
|
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
|
||||||
|
calls = append(calls, msg)
|
||||||
|
})
|
||||||
if len(calls) == 0 {
|
if len(calls) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process calls on a goroutine because they may block indefinitely:
|
// Process calls on a goroutine because they may block indefinitely:
|
||||||
h.startCallProc(func(cp *callProc) {
|
h.startCallProc(func(cp *callProc) {
|
||||||
var (
|
var (
|
||||||
@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
|||||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||||
timer = time.AfterFunc(timeout, func() {
|
timer = time.AfterFunc(timeout, func() {
|
||||||
cancel()
|
cancel()
|
||||||
callBuffer.timeout(cp.ctx, h.conn)
|
err := &internalServerError{errcodeTimeout, errMsgTimeout}
|
||||||
|
callBuffer.respondWithError(cp.ctx, h.conn, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responseBytes := 0
|
||||||
for {
|
for {
|
||||||
// No need to handle rest of calls if timed out.
|
// No need to handle rest of calls if timed out.
|
||||||
if cp.ctx.Err() != nil {
|
if cp.ctx.Err() != nil {
|
||||||
@ -214,61 +226,88 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
|||||||
}
|
}
|
||||||
resp := h.handleCallMsg(cp, msg)
|
resp := h.handleCallMsg(cp, msg)
|
||||||
callBuffer.pushResponse(resp)
|
callBuffer.pushResponse(resp)
|
||||||
|
if resp != nil && h.batchResponseMaxSize != 0 {
|
||||||
|
responseBytes += len(resp.Result)
|
||||||
|
if responseBytes > h.batchResponseMaxSize {
|
||||||
|
err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge}
|
||||||
|
callBuffer.respondWithError(cp.ctx, h.conn, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if timer != nil {
|
if timer != nil {
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
}
|
}
|
||||||
callBuffer.write(cp.ctx, h.conn)
|
|
||||||
h.addSubscriptions(cp.notifiers)
|
h.addSubscriptions(cp.notifiers)
|
||||||
|
callBuffer.write(cp.ctx, h.conn)
|
||||||
for _, n := range cp.notifiers {
|
for _, n := range cp.notifiers {
|
||||||
n.activate()
|
n.activate()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleMsg handles a single message.
|
func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) {
|
||||||
func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge})
|
||||||
if ok := h.handleImmediate(msg); ok {
|
// Find the first call and add its "id" field to the error.
|
||||||
return
|
// This is the best we can do, given that the protocol doesn't have a way
|
||||||
|
// of reporting an error for the entire batch.
|
||||||
|
for _, msg := range batch {
|
||||||
|
if msg.isCall() {
|
||||||
|
resp.ID = msg.ID
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
h.startCallProc(func(cp *callProc) {
|
h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true)
|
||||||
var (
|
}
|
||||||
responded sync.Once
|
|
||||||
timer *time.Timer
|
|
||||||
cancel context.CancelFunc
|
|
||||||
)
|
|
||||||
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Cancel the request context after timeout and send an error response. Since the
|
// handleMsg handles a single non-batch message.
|
||||||
// running method might not return immediately on timeout, we must wait for the
|
func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||||
// timeout concurrently with processing the request.
|
msgs := []*jsonrpcMessage{msg}
|
||||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
|
||||||
timer = time.AfterFunc(timeout, func() {
|
h.startCallProc(func(cp *callProc) {
|
||||||
cancel()
|
h.handleNonBatchCall(cp, msg)
|
||||||
responded.Do(func() {
|
})
|
||||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
|
||||||
h.conn.writeJSON(cp.ctx, resp, true)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
answer := h.handleCallMsg(cp, msg)
|
|
||||||
if timer != nil {
|
|
||||||
timer.Stop()
|
|
||||||
}
|
|
||||||
h.addSubscriptions(cp.notifiers)
|
|
||||||
if answer != nil {
|
|
||||||
responded.Do(func() {
|
|
||||||
h.conn.writeJSON(cp.ctx, answer, false)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for _, n := range cp.notifiers {
|
|
||||||
n.activate()
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) {
|
||||||
|
var (
|
||||||
|
responded sync.Once
|
||||||
|
timer *time.Timer
|
||||||
|
cancel context.CancelFunc
|
||||||
|
)
|
||||||
|
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Cancel the request context after timeout and send an error response. Since the
|
||||||
|
// running method might not return immediately on timeout, we must wait for the
|
||||||
|
// timeout concurrently with processing the request.
|
||||||
|
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||||
|
timer = time.AfterFunc(timeout, func() {
|
||||||
|
cancel()
|
||||||
|
responded.Do(func() {
|
||||||
|
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||||
|
h.conn.writeJSON(cp.ctx, resp, true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
answer := h.handleCallMsg(cp, msg)
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
h.addSubscriptions(cp.notifiers)
|
||||||
|
if answer != nil {
|
||||||
|
responded.Do(func() {
|
||||||
|
h.conn.writeJSON(cp.ctx, answer, false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, n := range cp.notifiers {
|
||||||
|
n.activate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// close cancels all requests except for inflightReq and waits for
|
// close cancels all requests except for inflightReq and waits for
|
||||||
// call goroutines to shut down.
|
// call goroutines to shut down.
|
||||||
func (h *handler) close(err error, inflightReq *requestOp) {
|
func (h *handler) close(err error, inflightReq *requestOp) {
|
||||||
@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleImmediate executes non-call messages. It returns false if the message is a
|
// handleResponse processes method call responses.
|
||||||
// call or requires a reply.
|
func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) {
|
||||||
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
|
var resolvedops []*requestOp
|
||||||
start := time.Now()
|
handleResp := func(msg *jsonrpcMessage) {
|
||||||
switch {
|
op := h.respWait[string(msg.ID)]
|
||||||
case msg.isNotification():
|
if op == nil {
|
||||||
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
|
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
||||||
h.handleSubscriptionResult(msg)
|
return
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
return false
|
resolvedops = append(resolvedops, op)
|
||||||
case msg.isResponse():
|
delete(h.respWait, string(msg.ID))
|
||||||
h.handleResponse(msg)
|
|
||||||
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
|
// For subscription responses, start the subscription if the server
|
||||||
return true
|
// indicates success. EthSubscribe gets unblocked in either case through
|
||||||
default:
|
// the op.resp channel.
|
||||||
return false
|
if op.sub != nil {
|
||||||
|
if msg.Error != nil {
|
||||||
|
op.err = msg.Error
|
||||||
|
} else {
|
||||||
|
op.err = json.Unmarshal(msg.Result, &op.sub.subid)
|
||||||
|
if op.err == nil {
|
||||||
|
go op.sub.run()
|
||||||
|
h.clientSubs[op.sub.subid] = op.sub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !op.hadResponse {
|
||||||
|
op.hadResponse = true
|
||||||
|
op.resp <- batch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range batch {
|
||||||
|
start := time.Now()
|
||||||
|
switch {
|
||||||
|
case msg.isResponse():
|
||||||
|
handleResp(msg)
|
||||||
|
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
|
||||||
|
|
||||||
|
case msg.isNotification():
|
||||||
|
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
|
||||||
|
h.handleSubscriptionResult(msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handleCall(msg)
|
||||||
|
|
||||||
|
default:
|
||||||
|
handleCall(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range resolvedops {
|
||||||
|
h.removeRequestOp(op)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleResponse processes method call responses.
|
|
||||||
func (h *handler) handleResponse(msg *jsonrpcMessage) {
|
|
||||||
op := h.respWait[string(msg.ID)]
|
|
||||||
if op == nil {
|
|
||||||
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(h.respWait, string(msg.ID))
|
|
||||||
// For normal responses, just forward the reply to Call/BatchCall.
|
|
||||||
if op.sub == nil {
|
|
||||||
op.resp <- msg
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// For subscription responses, start the subscription if the server
|
|
||||||
// indicates success. EthSubscribe gets unblocked in either case through
|
|
||||||
// the op.resp channel.
|
|
||||||
defer close(op.resp)
|
|
||||||
if msg.Error != nil {
|
|
||||||
op.err = msg.Error
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
|
|
||||||
go op.sub.run()
|
|
||||||
h.clientSubs[op.sub.subid] = op.sub
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleCallMsg executes a call message and returns the answer.
|
// handleCallMsg executes a call message and returns the answer.
|
||||||
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
|
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
|
|||||||
h.handleCall(ctx, msg)
|
h.handleCall(ctx, msg)
|
||||||
h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
|
h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case msg.isCall():
|
case msg.isCall():
|
||||||
resp := h.handleCall(ctx, msg)
|
resp := h.handleCall(ctx, msg)
|
||||||
var ctx []interface{}
|
var ctx []interface{}
|
||||||
@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
|
|||||||
h.log.Debug("Served "+msg.Method, ctx...)
|
h.log.Debug("Served "+msg.Method, ctx...)
|
||||||
}
|
}
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
case msg.hasValidID():
|
case msg.hasValidID():
|
||||||
return msg.errorResponse(&invalidRequestError{"invalid request"})
|
return msg.errorResponse(&invalidRequestError{"invalid request"})
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return errorMessage(&invalidRequestError{"invalid request"})
|
return errorMessage(&invalidRequestError{"invalid request"})
|
||||||
}
|
}
|
||||||
@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
|||||||
if callb == nil {
|
if callb == nil {
|
||||||
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
|
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
|
||||||
}
|
}
|
||||||
|
|
||||||
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
|
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return msg.errorResponse(&invalidParamsError{err.Error()})
|
return msg.errorResponse(&invalidParamsError{err.Error()})
|
||||||
}
|
}
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
answer := h.runMethod(cp.ctx, msg, callb, args)
|
answer := h.runMethod(cp.ctx, msg, callb, args)
|
||||||
|
|
||||||
// Collect the statistics for RPC calls if metrics is enabled.
|
// Collect the statistics for RPC calls if metrics is enabled.
|
||||||
// We only care about pure rpc call. Filter out subscription.
|
// We only care about pure rpc call. Filter out subscription.
|
||||||
if callb != h.unsubscribeCb {
|
if callb != h.unsubscribeCb {
|
||||||
@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
|||||||
rpcServingTimer.UpdateSince(start)
|
rpcServingTimer.UpdateSince(start)
|
||||||
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
|
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
|
||||||
}
|
}
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
19
rpc/http.go
19
rpc/http.go
@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
|||||||
var cfg clientConfig
|
var cfg clientConfig
|
||||||
cfg.httpClient = client
|
cfg.httpClient = client
|
||||||
fn := newClientTransportHTTP(endpoint, &cfg)
|
fn := newClientTransportHTTP(endpoint, &cfg)
|
||||||
return newClient(context.Background(), fn)
|
return newClient(context.Background(), &cfg, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
|
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
|
||||||
@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
|
|||||||
}
|
}
|
||||||
defer respBody.Close()
|
defer respBody.Close()
|
||||||
|
|
||||||
var respmsg jsonrpcMessage
|
var resp jsonrpcMessage
|
||||||
if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil {
|
batch := [1]*jsonrpcMessage{&resp}
|
||||||
|
if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
op.resp <- &respmsg
|
op.resp <- batch[:]
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer respBody.Close()
|
defer respBody.Close()
|
||||||
var respmsgs []jsonrpcMessage
|
|
||||||
|
var respmsgs []*jsonrpcMessage
|
||||||
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
|
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(respmsgs) != len(msgs) {
|
op.resp <- respmsgs
|
||||||
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
|
|
||||||
}
|
|
||||||
for i := 0; i < len(respmsgs); i++ {
|
|
||||||
op.resp <- &respmsgs[i]
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,8 @@ import (
|
|||||||
// DialInProc attaches an in-process connection to the given RPC server.
|
// DialInProc attaches an in-process connection to the given RPC server.
|
||||||
func DialInProc(handler *Server) *Client {
|
func DialInProc(handler *Server) *Client {
|
||||||
initctx := context.Background()
|
initctx := context.Background()
|
||||||
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
|
cfg := new(clientConfig)
|
||||||
|
c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) {
|
||||||
p1, p2 := net.Pipe()
|
p1, p2 := net.Pipe()
|
||||||
go handler.ServeCodec(NewCodec(p1), 0)
|
go handler.ServeCodec(NewCodec(p1), 0)
|
||||||
return NewCodec(p2), nil
|
return NewCodec(p2), nil
|
||||||
|
@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error {
|
|||||||
// The context is used for the initial connection establishment. It does not
|
// The context is used for the initial connection establishment. It does not
|
||||||
// affect subsequent interactions with the client.
|
// affect subsequent interactions with the client.
|
||||||
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
|
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
|
||||||
return newClient(ctx, newClientTransportIPC(endpoint))
|
cfg := new(clientConfig)
|
||||||
|
return newClient(ctx, cfg, newClientTransportIPC(endpoint))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientTransportIPC(endpoint string) reconnectFunc {
|
func newClientTransportIPC(endpoint string) reconnectFunc {
|
||||||
|
@ -46,9 +46,11 @@ type Server struct {
|
|||||||
services serviceRegistry
|
services serviceRegistry
|
||||||
idgen func() ID
|
idgen func() ID
|
||||||
|
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
codecs map[ServerCodec]struct{}
|
codecs map[ServerCodec]struct{}
|
||||||
run atomic.Bool
|
run atomic.Bool
|
||||||
|
batchItemLimit int
|
||||||
|
batchResponseLimit int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new server instance with no registered handlers.
|
// NewServer creates a new server instance with no registered handlers.
|
||||||
@ -65,6 +67,17 @@ func NewServer() *Server {
|
|||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit'
|
||||||
|
// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of
|
||||||
|
// response bytes across all requests in a batch.
|
||||||
|
//
|
||||||
|
// This method should be called before processing any requests via ServeCodec, ServeHTTP,
|
||||||
|
// ServeListener etc.
|
||||||
|
func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) {
|
||||||
|
s.batchItemLimit = itemLimit
|
||||||
|
s.batchResponseLimit = maxResponseSize
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterName creates a service for the given receiver type under the given name. When no
|
// RegisterName creates a service for the given receiver type under the given name. When no
|
||||||
// methods on the given receiver match the criteria to be either a RPC method or a
|
// methods on the given receiver match the criteria to be either a RPC method or a
|
||||||
// subscription an error is returned. Otherwise a new service is created and added to the
|
// subscription an error is returned. Otherwise a new service is created and added to the
|
||||||
@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
|
|||||||
}
|
}
|
||||||
defer s.untrackCodec(codec)
|
defer s.untrackCodec(codec)
|
||||||
|
|
||||||
c := initClient(codec, s.idgen, &s.services)
|
cfg := &clientConfig{
|
||||||
|
idgen: s.idgen,
|
||||||
|
batchItemLimit: s.batchItemLimit,
|
||||||
|
batchResponseLimit: s.batchResponseLimit,
|
||||||
|
}
|
||||||
|
c := initClient(codec, &s.services, cfg)
|
||||||
<-codec.closed()
|
<-codec.closed()
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h := newHandler(ctx, codec, s.idgen, &s.services)
|
h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit)
|
||||||
h.allowSubscribe = false
|
h.allowSubscribe = false
|
||||||
defer h.close(io.EOF, nil)
|
defer h.close(io.EOF, nil)
|
||||||
|
|
||||||
|
@ -70,6 +70,7 @@ func TestServer(t *testing.T) {
|
|||||||
|
|
||||||
func runTestScript(t *testing.T, file string) {
|
func runTestScript(t *testing.T, file string) {
|
||||||
server := newTestServer()
|
server := newTestServer()
|
||||||
|
server.SetBatchLimits(4, 100000)
|
||||||
content, err := os.ReadFile(file)
|
content, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerBatchResponseSizeLimit(t *testing.T) {
|
||||||
|
server := newTestServer()
|
||||||
|
defer server.Stop()
|
||||||
|
server.SetBatchLimits(100, 60)
|
||||||
|
var (
|
||||||
|
batch []BatchElem
|
||||||
|
client = DialInProc(server)
|
||||||
|
)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
batch = append(batch, BatchElem{
|
||||||
|
Method: "test_echo",
|
||||||
|
Args: []any{"x", 1},
|
||||||
|
Result: new(echoResult),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := client.BatchCall(batch); err != nil {
|
||||||
|
t.Fatal("error sending batch:", err)
|
||||||
|
}
|
||||||
|
for i := range batch {
|
||||||
|
// We expect the first two queries to be ok, but after that the size limit takes effect.
|
||||||
|
if i < 2 {
|
||||||
|
if batch[i].Error != nil {
|
||||||
|
t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// After two, we expect an error.
|
||||||
|
re, ok := batch[i].Error.(Error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error)
|
||||||
|
}
|
||||||
|
wantedCode := errcodeResponseTooLarge
|
||||||
|
if re.ErrorCode() != wantedCode {
|
||||||
|
t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) {
|
|||||||
|
|
||||||
// DialIO creates a client which uses the given IO channels
|
// DialIO creates a client which uses the given IO channels
|
||||||
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
|
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
|
||||||
return newClient(ctx, newClientTransportIO(in, out))
|
cfg := new(clientConfig)
|
||||||
|
return newClient(ctx, cfg, newClientTransportIO(in, out))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
|
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
|
||||||
|
13
rpc/testdata/invalid-batch-toolarge.js
vendored
Normal file
13
rpc/testdata/invalid-batch-toolarge.js
vendored
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// This file checks the behavior of the batch item limit code.
|
||||||
|
// In tests, the batch item limit is set to 4. So to trigger the error,
|
||||||
|
// all batches in this file have 5 elements.
|
||||||
|
|
||||||
|
// For batches that do not contain any calls, a response message with "id" == null
|
||||||
|
// is returned.
|
||||||
|
|
||||||
|
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
|
||||||
|
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}]
|
||||||
|
|
||||||
|
// For batches with at least one call, the call's "id" is used.
|
||||||
|
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
|
||||||
|
<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}]
|
@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newClient(ctx, connect)
|
return newClient(ctx, cfg, connect)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
|
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
|
||||||
@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newClient(ctx, connect)
|
return newClient(ctx, cfg, connect)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
|
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user