mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2024-12-31 23:41:22 +00:00
226 lines
6.3 KiB
Go
226 lines
6.3 KiB
Go
package ssz
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
bytesutil "github.com/prysmaticlabs/prysm/shared/bytes"
|
|
"github.com/prysmaticlabs/prysm/shared/hashutil"
|
|
)
|
|
|
|
const hashLengthBytes = 32
|
|
const sszChunkSize = 128
|
|
|
|
// Hashable defines the interface for supporting tree-hash function.
|
|
type Hashable interface {
|
|
TreeHashSSZ() ([32]byte, error)
|
|
}
|
|
|
|
// TreeHash calculates tree-hash result for input value.
|
|
func TreeHash(val interface{}) ([32]byte, error) {
|
|
if val == nil {
|
|
return [32]byte{}, newHashError("nil is not supported", nil)
|
|
}
|
|
rval := reflect.ValueOf(val)
|
|
sszUtils, err := cachedSSZUtils(rval.Type())
|
|
if err != nil {
|
|
return [32]byte{}, newHashError(fmt.Sprint(err), rval.Type())
|
|
}
|
|
output, err := sszUtils.hasher(rval)
|
|
if err != nil {
|
|
return [32]byte{}, newHashError(fmt.Sprint(err), rval.Type())
|
|
}
|
|
// Right-pad with 0 to make 32 bytes long, if necessary
|
|
paddedOutput := bytesutil.ToBytes32(output)
|
|
return paddedOutput, nil
|
|
}
|
|
|
|
type hashError struct {
|
|
msg string
|
|
typ reflect.Type
|
|
}
|
|
|
|
func (err *hashError) Error() string {
|
|
return fmt.Sprintf("hash error: %s for input type %v", err.msg, err.typ)
|
|
}
|
|
|
|
func newHashError(msg string, typ reflect.Type) *hashError {
|
|
return &hashError{msg, typ}
|
|
}
|
|
|
|
func makeHasher(typ reflect.Type) (hasher, error) {
|
|
kind := typ.Kind()
|
|
switch {
|
|
case kind == reflect.Bool ||
|
|
kind == reflect.Uint8 ||
|
|
kind == reflect.Uint16 ||
|
|
kind == reflect.Uint32 ||
|
|
kind == reflect.Uint64:
|
|
return getEncoding, nil
|
|
case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 ||
|
|
kind == reflect.Array && typ.Elem().Kind() == reflect.Uint8:
|
|
return hashedEncoding, nil
|
|
case kind == reflect.Slice || kind == reflect.Array:
|
|
return makeSliceHasher(typ)
|
|
case kind == reflect.Struct:
|
|
return makeStructHasher(typ)
|
|
case kind == reflect.Ptr:
|
|
return makePtrHasher(typ)
|
|
default:
|
|
return nil, fmt.Errorf("type %v is not hashable", typ)
|
|
}
|
|
}
|
|
|
|
func getEncoding(val reflect.Value) ([]byte, error) {
|
|
utils, err := cachedSSZUtilsNoAcquireLock(val.Type())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buf := &encbuf{}
|
|
if err = utils.encoder(val, buf); err != nil {
|
|
return nil, err
|
|
}
|
|
writer := new(bytes.Buffer)
|
|
if err = buf.toWriter(writer); err != nil {
|
|
return nil, err
|
|
}
|
|
return writer.Bytes(), nil
|
|
}
|
|
|
|
func hashedEncoding(val reflect.Value) ([]byte, error) {
|
|
encoding, err := getEncoding(val)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
output := hashutil.Hash(encoding)
|
|
return output[:], nil
|
|
}
|
|
|
|
func makeSliceHasher(typ reflect.Type) (hasher, error) {
|
|
elemSSZUtils, err := cachedSSZUtilsNoAcquireLock(typ.Elem())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get ssz utils: %v", err)
|
|
}
|
|
hasher := func(val reflect.Value) ([]byte, error) {
|
|
var elemHashList [][]byte
|
|
for i := 0; i < val.Len(); i++ {
|
|
elemHash, err := elemSSZUtils.hasher(val.Index(i))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to hash element of slice/array: %v", err)
|
|
}
|
|
elemHashList = append(elemHashList, elemHash)
|
|
}
|
|
output, err := merkleHash(elemHashList)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to calculate merkle hash of element hash list: %v", err)
|
|
}
|
|
return output, nil
|
|
}
|
|
return hasher, nil
|
|
}
|
|
|
|
func makeStructHasher(typ reflect.Type) (hasher, error) {
|
|
fields, err := structFields(typ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hasher := func(val reflect.Value) ([]byte, error) {
|
|
concatElemHash := make([]byte, 0)
|
|
for _, f := range fields {
|
|
elemHash, err := f.sszUtils.hasher(val.Field(f.index))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to hash field of struct: %v", err)
|
|
}
|
|
concatElemHash = append(concatElemHash, elemHash...)
|
|
}
|
|
result := hashutil.Hash(concatElemHash)
|
|
return result[:], nil
|
|
}
|
|
return hasher, nil
|
|
}
|
|
|
|
// Notice: Currently we don't support nil pointer:
|
|
// - Input for encoding must not contain nil pointer
|
|
// - Output for decoding will never contain nil pointer
|
|
// (Not to be confused with empty slice. Empty slice is supported)
|
|
func makePtrHasher(typ reflect.Type) (hasher, error) {
|
|
elemSSZUtils, err := cachedSSZUtilsNoAcquireLock(typ.Elem())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hasher := func(val reflect.Value) ([]byte, error) {
|
|
if val.IsNil() {
|
|
return nil, errors.New("nil is not supported")
|
|
}
|
|
return elemSSZUtils.hasher(val.Elem())
|
|
}
|
|
return hasher, nil
|
|
}
|
|
|
|
// merkelHash implements a merkle-tree style hash algorithm.
|
|
//
|
|
// Please refer to the official spec for details:
|
|
// https://github.com/ethereum/eth2.0-specs/blob/master/specs/simple-serialize.md#tree-hash
|
|
//
|
|
// The overall idea is:
|
|
// 1. Create a bunch of bytes chunk (each has a size of sszChunkSize) from the input hash list.
|
|
// 2. Treat each bytes chunk as the leaf of a binary tree.
|
|
// 3. For every pair of leaves, we set their parent's value using the hash value of the concatenation of the two leaves.
|
|
// The original two leaves are then removed.
|
|
// 4. Keep doing step 3 until there's only one node left in the tree (the root).
|
|
// 5. Return the hash of the concatenation of the root and the data length encoding.
|
|
//
|
|
// Time complexity is O(n) given input list of size n.
|
|
func merkleHash(list [][]byte) ([]byte, error) {
|
|
// Assume len(list) < 2^64
|
|
dataLenEnc := make([]byte, hashLengthBytes)
|
|
binary.BigEndian.PutUint64(dataLenEnc[hashLengthBytes-8:], uint64(len(list)))
|
|
|
|
var chunkz [][]byte
|
|
emptyChunk := make([]byte, sszChunkSize)
|
|
|
|
if len(list) == 0 {
|
|
chunkz = make([][]byte, 1)
|
|
chunkz[0] = emptyChunk
|
|
} else if len(list[0]) < sszChunkSize {
|
|
if sszChunkSize%len(list[0]) != 0 {
|
|
return nil, fmt.Errorf("element hash size needs to be factor of %d", sszChunkSize)
|
|
}
|
|
itemsPerChunk := sszChunkSize / len(list[0])
|
|
chunkz = make([][]byte, 0)
|
|
for i := 0; i < len(list); i += itemsPerChunk {
|
|
chunk := make([]byte, 0)
|
|
j := i + itemsPerChunk
|
|
if j > len(list) {
|
|
j = len(list)
|
|
}
|
|
// Every chunk should have sszChunkSize bytes except that the last one could have less bytes
|
|
for _, elemHash := range list[i:j] {
|
|
chunk = append(chunk, elemHash...)
|
|
}
|
|
chunkz = append(chunkz, chunk)
|
|
}
|
|
} else {
|
|
chunkz = list
|
|
}
|
|
|
|
for len(chunkz) > 1 {
|
|
if len(chunkz)%2 == 1 {
|
|
chunkz = append(chunkz, emptyChunk)
|
|
}
|
|
hashedChunkz := make([][]byte, 0)
|
|
for i := 0; i < len(chunkz); i += 2 {
|
|
hashedChunk := hashutil.Hash(append(chunkz[i], chunkz[i+1]...))
|
|
hashedChunkz = append(hashedChunkz, hashedChunk[:])
|
|
}
|
|
chunkz = hashedChunkz
|
|
}
|
|
|
|
result := hashutil.Hash(append(chunkz[0], dataLenEnc...))
|
|
return result[:], nil
|
|
}
|