diff --git a/rlp/parse.go b/rlp/parse.go index d707d43e2..ffe711489 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -127,6 +127,28 @@ func U64(payload []byte, pos int) (int, uint64, error) { return dataPos + dataLen, r, nil } +// U32 parses uint64 number from given payload at given position +func U32(payload []byte, pos int) (int, uint32, error) { + dataPos, dataLen, isList, err := Prefix(payload, pos) + if err != nil { + return 0, 0, err + } + if isList { + return 0, 0, fmt.Errorf("uint32 must be a string, not isList") + } + if dataLen > 4 { + return 0, 0, fmt.Errorf("uint32 must not be more than 4 bytes long, got %d", dataLen) + } + if dataLen > 0 && payload[dataPos] == 0 { + return 0, 0, fmt.Errorf("integer encoding for RLP must not have leading zeros: %x", payload[dataPos:dataPos+dataLen]) + } + var r uint32 + for _, b := range payload[dataPos : dataPos+dataLen] { + r = (r << 8) | uint32(b) + } + return dataPos + dataLen, r, nil +} + // U256 parses uint256 number from given payload at given position func U256(payload []byte, pos int, x *uint256.Int) (int, error) { dataPos, dataLen, err := String(payload, pos) diff --git a/rlp/parse_test.go b/rlp/parse_test.go new file mode 100644 index 000000000..d094fb34c --- /dev/null +++ b/rlp/parse_test.go @@ -0,0 +1,58 @@ +package rlp + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func decodeHex(in string) []byte { + payload, err := hex.DecodeString(in) + if err != nil { + panic(err) + } + return payload +} + +var parseU64Tests = []struct { + payload []byte + expectPos int + expectRes uint64 + expectErr error +}{ + {payload: decodeHex("820400"), expectPos: 3, expectRes: 1024}, + {payload: decodeHex("07"), expectPos: 1, expectRes: 7}, +} + +var parseU32Tests = []struct { + payload []byte + expectPos int + expectRes uint32 + expectErr error +}{ + {payload: decodeHex("820400"), expectPos: 3, expectRes: 1024}, + {payload: decodeHex("07"), expectPos: 1, expectRes: 7}, +} + +func TestPrimitives(t *testing.T) { + for i, tt := range parseU64Tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + assert := assert.New(t) + pos, res, err := U64(tt.payload, 0) + assert.NoError(err) + assert.Equal(tt.expectPos, pos) + assert.Equal(tt.expectRes, res) + }) + } + for i, tt := range parseU32Tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + assert := assert.New(t) + pos, res, err := U32(tt.payload, 0) + assert.NoError(err) + assert.Equal(tt.expectPos, pos) + assert.Equal(tt.expectRes, res) + }) + } +}