diff --git a/rpc/args.go b/rpc/args.go index 686872a59..11d9a2a70 100644 --- a/rpc/args.go +++ b/rpc/args.go @@ -172,13 +172,8 @@ type NewSigArgs struct { } func (args *NewSigArgs) UnmarshalJSON(b []byte) (err error) { - var obj []json.RawMessage - var ext struct { - From string - Data string - } + var obj []interface{} - // Decode byte slice to array of RawMessages if err := json.Unmarshal(b, &obj); err != nil { return NewDecodeParamError(err.Error()) } @@ -188,21 +183,26 @@ func (args *NewSigArgs) UnmarshalJSON(b []byte) (err error) { return NewInsufficientParamsError(len(obj), 1) } - // Decode 0th RawMessage to temporary struct - if err := json.Unmarshal(obj[0], &ext); err != nil { - return NewDecodeParamError(err.Error()) + from, ok := obj[0].(string) + if !ok { + return NewInvalidTypeError("from", "not a string") } + args.From = from - if len(ext.From) == 0 { + if len(args.From) == 0 { return NewValidationError("from", "is required") } - if len(ext.Data) == 0 { + data, ok := obj[1].(string) + if !ok { + return NewInvalidTypeError("data", "not a string") + } + args.Data = data + + if len(args.Data) == 0 { return NewValidationError("data", "is required") } - args.From = ext.From - args.Data = ext.Data return nil } diff --git a/rpc/args_test.go b/rpc/args_test.go index 09ce12467..9ca73660a 100644 --- a/rpc/args_test.go +++ b/rpc/args_test.go @@ -2504,3 +2504,64 @@ func TestSourceArgsEmpty(t *testing.T) { t.Error(str) } } + +func TestSigArgs(t *testing.T) { + input := `["0xa94f5374fce5edbc8e2a8697c15331677e6ebf0b", "0x0"]` + expected := new(NewSigArgs) + expected.From = "0xa94f5374fce5edbc8e2a8697c15331677e6ebf0b" + expected.Data = "0x0" + + args := new(NewSigArgs) + if err := json.Unmarshal([]byte(input), &args); err != nil { + t.Error(err) + } +} + +func TestSigArgsEmptyData(t *testing.T) { + input := `["0xa94f5374fce5edbc8e2a8697c15331677e6ebf0b", ""]` + + args := new(NewSigArgs) + str := ExpectValidationError(json.Unmarshal([]byte(input), args)) + if len(str) > 0 { + t.Error(str) + } +} + +func TestSigArgsDataType(t *testing.T) { + input := `["0xa94f5374fce5edbc8e2a8697c15331677e6ebf0b", 13]` + + args := new(NewSigArgs) + str := ExpectInvalidTypeError(json.Unmarshal([]byte(input), args)) + if len(str) > 0 { + t.Error(str) + } +} + +func TestSigArgsEmptyFrom(t *testing.T) { + input := `["", "0x0"]` + + args := new(NewSigArgs) + str := ExpectValidationError(json.Unmarshal([]byte(input), args)) + if len(str) > 0 { + t.Error(str) + } +} + +func TestSigArgsFromType(t *testing.T) { + input := `[false, "0x0"]` + + args := new(NewSigArgs) + str := ExpectInvalidTypeError(json.Unmarshal([]byte(input), args)) + if len(str) > 0 { + t.Error(str) + } +} + +func TestSigArgsEmpty(t *testing.T) { + input := `[]` + args := new(NewSigArgs) + str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), args)) + if len(str) > 0 { + t.Error(str) + } +}