erigon-pulse/erigon-lib/rlp2/unmarshaler.go
a 2aab8f496c
rlp2 (#8643)
rlp2 is a package that aims to replace the existing erigon-lib/rlp
package and erigon/common/rlp.

It is called rlp2 for now because it requires breaking changes to
erigon-lib/rlp, and I do not have the time right now to test all current
uses of those functions.

However, the encoder/decoder characteristics of rlp2 might be desirable
for Caplin, and also for execution-layer parsing of blob transactions, so I'm
putting it in a folder called rlp2 (note that it exports package rlp for
easier switching later).

Importantly, rlp2 is designed for single-pass decoding with the ability
to skip elements one does not care about. It is also zero-alloc.
2023-11-04 09:22:11 +07:00

192 lines
4.3 KiB
Go

package rlp
import (
"fmt"
"reflect"
)
// Unmarshaler is implemented by types that decode themselves from RLP.
// The decoder calls UnmarshalRLP, instead of decoding by reflection, whenever
// the address of the destination value implements this interface.
type Unmarshaler interface {
	UnmarshalRLP(enc []byte) error
}
func Unmarshal(data []byte, val any) error {
buf := newBuf(data, 0)
return unmarshal(buf, val)
}
// unmarshal decodes the next RLP item from b into the value val points to.
//
// val must be a non-nil pointer; otherwise an ErrDecode-wrapped error is
// returned without reading from b. Errors coming from the reflection-based
// decoder are wrapped in ErrDecode as well.
func unmarshal(b *buf, val any) error {
	target := reflect.ValueOf(val)
	isPtr := target.Kind() == reflect.Pointer
	if !isPtr || target.IsNil() {
		return fmt.Errorf("%w: v must be ptr", ErrDecode)
	}
	if err := reflectAny(b, target.Elem(), target); err != nil {
		return fmt.Errorf("%w: %w", ErrDecode, err)
	}
	return nil
}
// reflectAny decodes the next RLP item from w into v.
//
// v is the destination value and rv is its address (a pointer to v). Both may
// be the zero reflect.Value, in which case one item is read from w and
// discarded. Dispatch:
//   - if rv implements Unmarshaler, exactly one item (header included) is
//     consumed from w and its raw bytes are passed to UnmarshalRLP;
//   - a single-byte item (TokenDecimal), whose value is the prefix byte itself,
//     is stored directly into integer kinds, or treated as a one-byte blob for
//     string, slice and array kinds;
//   - short/long blobs are read in full and stored via putBlob;
//   - short/long lists are read in full and decoded by reflectList over a
//     buffer limited to the list payload.
func reflectAny(w *buf, v reflect.Value, rv reflect.Value) error {
	// rv is invalid when an item is being skipped, and Interface() would panic
	// on it (or on a value reached through an unexported field).
	if rv.IsValid() && rv.CanInterface() {
		if um, ok := rv.Interface().(Unmarshaler); ok {
			// Consume exactly one item so that whatever follows it (e.g. the next
			// struct field) is decoded from the right position; the consumed
			// bytes are the item handed to UnmarshalRLP.
			// NOTE(review): assumes w.Bytes() returns the unread remainder of w
			// as a suffix of one backing array — confirm in buf.
			item := w.Bytes()
			if err := reflectAny(w, reflect.Value{}, reflect.Value{}); err != nil {
				return err
			}
			return um.UnmarshalRLP(item[:len(item)-len(w.Bytes())])
		}
	}
	// figure out what we are reading
	prefix, err := w.ReadByte()
	if err != nil {
		return err
	}
	token := identifyToken(prefix)
	switch token {
	case TokenDecimal:
		// in this case, the value is just the byte itself
		switch v.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			v.SetInt(int64(prefix))
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			v.SetUint(uint64(prefix))
		case reflect.String, reflect.Slice, reflect.Array:
			// Per the RLP spec a lone byte below 0x80 is the encoding of a
			// one-byte string, so byte-oriented destinations must accept it.
			// (Presumably TokenDecimal is exactly that range — confirm.)
			return putBlob([]byte{prefix}, v, rv)
		case reflect.Invalid:
			// skipping this item: nothing to store
		default:
			return fmt.Errorf("%w: decimal must be unmarshal into integer type", ErrDecode)
		}
	case TokenShortBlob:
		sz := int(token.Diff(prefix))
		payload, err := nextFull(w, sz)
		if err != nil {
			return err
		}
		return putBlob(payload, v, rv)
	case TokenLongBlob:
		// The prefix encodes the width of the big-endian length that follows.
		lenSz := int(token.Diff(prefix))
		sz, err := nextBeInt(w, lenSz)
		if err != nil {
			return err
		}
		payload, err := nextFull(w, sz)
		if err != nil {
			return err
		}
		return putBlob(payload, v, rv)
	case TokenShortList:
		sz := int(token.Diff(prefix))
		body, err := nextFull(w, sz)
		if err != nil {
			return err
		}
		return reflectList(newBuf(body, 0), v, rv)
	case TokenLongList:
		lenSz := int(token.Diff(prefix))
		sz, err := nextBeInt(w, lenSz)
		if err != nil {
			return err
		}
		body, err := nextFull(w, sz)
		if err != nil {
			return err
		}
		return reflectList(newBuf(body, 0), v, rv)
	case TokenUnknown:
		return fmt.Errorf("%w: unknown token", ErrDecode)
	}
	return nil
}
// putBlob stores the payload w of an RLP string item into v.
//
// Supported destination kinds:
//   - string: the payload is copied into a new string;
//   - slices with a uint8-kind element: v is set to w itself (no copy), so
//     the result aliases the input buffer;
//   - arrays with a uint8-kind element: w is copied in and must have exactly
//     the array's length;
//   - signed/unsigned integers: w is parsed with BeInt, and values that do not
//     fit the destination type are rejected instead of silently truncated;
//   - reflect.Invalid: the payload is discarded (used when skipping items).
//
// Any other destination kind is reported as an ErrDecode error rather than
// silently dropping the data. rv is currently unused.
func putBlob(w []byte, v reflect.Value, rv reflect.Value) error {
	switch v.Kind() {
	case reflect.String:
		v.SetString(string(w))
	case reflect.Slice:
		if v.Type().Elem().Kind() != reflect.Uint8 {
			return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode)
		}
		v.SetBytes(w)
	case reflect.Array:
		if v.Type().Elem().Kind() != reflect.Uint8 {
			return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode)
		}
		if len(w) != v.Len() {
			return fmt.Errorf("%w: blob of %d bytes cannot be stored in %s", ErrDecode, len(w), v.Type())
		}
		// Value.Bytes accepts addressable byte arrays and, unlike reflect.Copy,
		// does not require the element type to be exactly uint8.
		copy(v.Bytes(), w)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		val, err := BeInt(w, 0, len(w))
		if err != nil {
			return err
		}
		n := int64(val)
		if n < 0 || v.OverflowInt(n) {
			return fmt.Errorf("%w: integer %d overflows %s", ErrDecode, val, v.Type())
		}
		v.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		val, err := BeInt(w, 0, len(w))
		if err != nil {
			return err
		}
		u := uint64(val)
		if v.OverflowUint(u) {
			return fmt.Errorf("%w: integer %d overflows %s", ErrDecode, val, v.Type())
		}
		v.SetUint(u)
	case reflect.Invalid:
		// do nothing: the item is being skipped
		return nil
	default:
		return fmt.Errorf("%w: cannot unmarshal blob into %s", ErrDecode, v.Type())
	}
	return nil
}
// reflectList decodes the payload of an RLP list item into v.
//
// In this file w is always a fresh buffer over just the list payload (the
// caller has already stripped the list header). Anything left unread in w is
// therefore simply discarded and does not affect the enclosing stream. rv is
// the address of v and is currently unused.
//
// Supported destination kinds:
//   - reflect.Invalid: the list is skipped;
//   - map: items are read as alternating key, value pairs until w is
//     exhausted; a nil map is allocated first;
//   - struct: exported fields are decoded in declaration order, one item per
//     field; unexported fields are left untouched and consume no item;
//   - slice: every item is decoded into successive elements, growing the
//     slice as needed; the final length equals the number of items;
//   - array: items fill the array from index 0; surplus items are ignored and
//     elements past the last decoded item are zeroed.
//
// Any other kind is reported as an ErrDecode error.
//
// NOTE(review): exhaustion is detected with len(w.Bytes()) == 0, which assumes
// Bytes returns the unread remainder of w — confirm in buf.
func reflectList(w *buf, v reflect.Value, rv reflect.Value) error {
	switch v.Kind() {
	case reflect.Invalid:
		// do nothing: the list is being skipped
		return nil
	case reflect.Map:
		// NOTE(review): assumes the encoder writes maps as a flat list of
		// alternating keys and values — confirm against the marshaler.
		if v.IsNil() {
			v.Set(reflect.MakeMap(v.Type()))
		}
		for len(w.Bytes()) > 0 {
			key := reflect.New(v.Type().Key())
			if err := reflectAny(w, key.Elem(), key); err != nil {
				return err
			}
			val := reflect.New(v.Type().Elem())
			if err := reflectAny(w, val.Elem(), val); err != nil {
				return err
			}
			// SetMapIndex needs the key/value themselves, not pointers to them.
			v.SetMapIndex(key.Elem(), val.Elem())
		}
	case reflect.Struct:
		t := v.Type()
		for i := 0; i < v.NumField(); i++ {
			if !t.Field(i).IsExported() {
				continue
			}
			field := v.Field(i)
			if err := reflectAny(w, field, field.Addr()); err != nil {
				return err
			}
		}
	case reflect.Slice, reflect.Array:
		isSlice := v.Kind() == reflect.Slice
		i := 0
		for ; len(w.Bytes()) > 0; i++ {
			if isSlice {
				// Grow/SetLen are only valid on slices; arrays have a fixed length.
				if i >= v.Cap() {
					v.Grow(1)
				}
				if i >= v.Len() {
					v.SetLen(i + 1)
				}
			} else if i >= v.Len() {
				// Ran out of fixed array: leave the surplus items unread. w only
				// covers this list, so nothing downstream is misaligned.
				break
			}
			elem := v.Index(i)
			if err := reflectAny(w, elem, elem.Addr()); err != nil {
				return err
			}
		}
		if isSlice {
			// Drop stale elements left over from a previously longer slice.
			v.SetLen(i)
		} else {
			for ; i < v.Len(); i++ {
				v.Index(i).SetZero()
			}
		}
	default:
		return fmt.Errorf("%w: cannot unmarshal list into %s", ErrDecode, v.Type())
	}
	return nil
}