diff --git a/rpc/handler.go b/rpc/handler.go index 2cae39e03..d79102817 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -510,7 +510,7 @@ func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *cal stream.WriteObjectField("result") _, err := callb.call(ctx, msg.Method, args, stream) if err != nil { - stream.WriteNil() + writeNilIfNotPresent(stream) stream.WriteMore() HandleError(err, stream) } @@ -519,6 +519,29 @@ func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *cal return nil } +var nullAsBytes = []byte{110, 117, 108, 108} + +// there are many avenues that could lead to an error being handled in runMethod, so we need to check +// if nil has already been written to the stream before writing it again here +func writeNilIfNotPresent(stream *jsoniter.Stream) { + b := stream.Buffer() + hasNil := true + if len(b) >= 4 { + b = b[len(b)-4:] + for i, v := range nullAsBytes { + if v != b[i] { + hasNil = false + break + } + } + } else { + hasNil = false + } + if !hasNil { + stream.WriteNil() + } +} + // unsubscribe is the callback function for all *_unsubscribe calls. func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) { h.subLock.Lock() diff --git a/rpc/handler_test.go b/rpc/handler_test.go new file mode 100644 index 000000000..3b0e4a044 --- /dev/null +++ b/rpc/handler_test.go @@ -0,0 +1,84 @@ +package rpc + +import ( + "bytes" + "context" + "fmt" + "reflect" + "testing" + + jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" +) + +func TestHandlerDoesNotDoubleWriteNull(t *testing.T) { + + tests := map[string]struct { + params []byte + expected string + }{ + "error_with_stream_write": { + params: []byte("[1]"), + expected: `{"jsonrpc":"2.0","id":1,"result":null,"error":{"code":-32000,"message":"id 1"}}`, + }, + "error_without_stream_write": { + params: []byte("[2]"), + expected: `{"jsonrpc":"2.0","id":1,"result":null,"error":{"code":-32000,"message":"id 2"}}`, + }, + "no_error": { + params: []byte("[3]"), + expected: `{"jsonrpc":"2.0","id":1,"result":{}}`, + }, + } + + for name, testParams := range tests { + t.Run(name, func(t *testing.T) { + msg := jsonrpcMessage{ + Version: "2.0", + ID: []byte{49}, + Method: "test_test", + Params: testParams.params, + Error: nil, + Result: nil, + } + + dummyFunc := func(id int, stream *jsoniter.Stream) error { + if id == 1 { + stream.WriteNil() + return fmt.Errorf("id 1") + } + if id == 2 { + return fmt.Errorf("id 2") + } + stream.WriteEmptyObject() + return nil + } + + var arg1 int + cb := &callback{ + fn: reflect.ValueOf(dummyFunc), + rcvr: reflect.Value{}, + argTypes: []reflect.Type{reflect.TypeOf(arg1)}, + hasCtx: false, + errPos: 0, + isSubscribe: false, + streamable: true, + } + + args, err := parsePositionalArguments((msg).Params, cb.argTypes) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + stream := jsoniter.NewStream(jsoniter.ConfigDefault, &buf, 4096) + + h := handler{} + h.runMethod(context.Background(), &msg, cb, args, stream) + + output := buf.String() + assert.Equal(t, testParams.expected, output, "expected output should match") + }) + } + +}