package ssz import ( "bytes" "encoding/binary" "errors" "fmt" "reflect" "github.com/prysmaticlabs/prysm/shared/hashutil" ) const hashLengthBytes = 32 const sszChunkSize = 128 // Hashable defines the interface for supporting tree-hash function. type Hashable interface { TreeHashSSZ() ([32]byte, error) } // TreeHash calculates tree-hash result for input value. func TreeHash(val interface{}) ([32]byte, error) { if val == nil { return [32]byte{}, newHashError("nil is not supported", nil) } rval := reflect.ValueOf(val) sszUtils, err := cachedSSZUtils(rval.Type()) if err != nil { return [32]byte{}, newHashError(fmt.Sprint(err), rval.Type()) } output, err := sszUtils.hasher(rval) if err != nil { return [32]byte{}, newHashError(fmt.Sprint(err), rval.Type()) } // Right-pad with 0 to make 32 bytes long, if necessary var paddedOutput [32]byte copy(paddedOutput[:], output) return paddedOutput, nil } type hashError struct { msg string typ reflect.Type } func (err *hashError) Error() string { return fmt.Sprintf("hash error: %s for input type %v", err.msg, err.typ) } func newHashError(msg string, typ reflect.Type) *hashError { return &hashError{msg, typ} } func makeHasher(typ reflect.Type) (hasher, error) { kind := typ.Kind() switch { case kind == reflect.Bool || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64: return getEncoding, nil case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 || kind == reflect.Array && typ.Elem().Kind() == reflect.Uint8: return hashedEncoding, nil case kind == reflect.Slice || kind == reflect.Array: return makeSliceHasher(typ) case kind == reflect.Struct: return makeStructHasher(typ) case kind == reflect.Ptr: return makePtrHasher(typ) default: return nil, fmt.Errorf("type %v is not hashable", typ) } } func getEncoding(val reflect.Value) ([]byte, error) { utils, err := cachedSSZUtilsNoAcquireLock(val.Type()) if err != nil { return nil, err } buf := &encbuf{} if err = utils.encoder(val, buf); err != nil { return nil, err } writer := new(bytes.Buffer) if err = buf.toWriter(writer); err != nil { return nil, err } return writer.Bytes(), nil } func hashedEncoding(val reflect.Value) ([]byte, error) { encoding, err := getEncoding(val) if err != nil { return nil, err } output := hashutil.Hash(encoding) return output[:], nil } func makeSliceHasher(typ reflect.Type) (hasher, error) { elemSSZUtils, err := cachedSSZUtilsNoAcquireLock(typ.Elem()) if err != nil { return nil, fmt.Errorf("failed to get ssz utils: %v", err) } hasher := func(val reflect.Value) ([]byte, error) { var elemHashList [][]byte for i := 0; i < val.Len(); i++ { elemHash, err := elemSSZUtils.hasher(val.Index(i)) if err != nil { return nil, fmt.Errorf("failed to hash element of slice/array: %v", err) } elemHashList = append(elemHashList, elemHash) } output, err := merkleHash(elemHashList) if err != nil { return nil, fmt.Errorf("failed to calculate merkle hash of element hash list: %v", err) } return output, nil } return hasher, nil } func makeStructHasher(typ reflect.Type) (hasher, error) { fields, err := structFields(typ) if err != nil { return nil, err } hasher := func(val reflect.Value) ([]byte, error) { concatElemHash := make([]byte, 0) for _, f := range fields { elemHash, err := f.sszUtils.hasher(val.Field(f.index)) if err != nil { return nil, fmt.Errorf("failed to hash field of struct: %v", err) } concatElemHash = append(concatElemHash, elemHash...) } result := hashutil.Hash(concatElemHash) return result[:], nil } return hasher, nil } // Notice: Currently we don't support nil pointer: // - Input for encoding must not contain nil pointer // - Output for decoding will never contain nil pointer // (Not to be confused with empty slice. Empty slice is supported) func makePtrHasher(typ reflect.Type) (hasher, error) { elemSSZUtils, err := cachedSSZUtilsNoAcquireLock(typ.Elem()) if err != nil { return nil, err } hasher := func(val reflect.Value) ([]byte, error) { if val.IsNil() { return nil, errors.New("nil is not supported") } return elemSSZUtils.hasher(val.Elem()) } return hasher, nil } // merkelHash implements a merkle-tree style hash algorithm. // // Please refer to the official spec for details: // https://github.com/ethereum/eth2.0-specs/blob/master/specs/simple-serialize.md#tree-hash // // The overall idea is: // 1. Create a bunch of bytes chunk (each has a size of sszChunkSize) from the input hash list. // 2. Treat each bytes chunk as the leaf of a binary tree. // 3. For every pair of leaves, we set their parent's value using the hash value of the concatenation of the two leaves. // The original two leaves are then removed. // 4. Keep doing step 3 until there's only one node left in the tree (the root). // 5. Return the hash of the concatenation of the root and the data length encoding. // // Time complexity is O(n) given input list of size n. func merkleHash(list [][]byte) ([]byte, error) { // Assume len(list) < 2^64 dataLenEnc := make([]byte, hashLengthBytes) binary.BigEndian.PutUint64(dataLenEnc[hashLengthBytes-8:], uint64(len(list))) var chunkz [][]byte emptyChunk := make([]byte, sszChunkSize) if len(list) == 0 { chunkz = make([][]byte, 1) chunkz[0] = emptyChunk } else if len(list[0]) < sszChunkSize { if sszChunkSize%len(list[0]) != 0 { return nil, fmt.Errorf("element hash size needs to be factor of %d", sszChunkSize) } itemsPerChunk := sszChunkSize / len(list[0]) chunkz = make([][]byte, 0) for i := 0; i < len(list); i += itemsPerChunk { chunk := make([]byte, 0) j := i + itemsPerChunk if j > len(list) { j = len(list) } // Every chunk should have sszChunkSize bytes except that the last one could have less bytes for _, elemHash := range list[i:j] { chunk = append(chunk, elemHash...) } chunkz = append(chunkz, chunk) } } else { chunkz = list } for len(chunkz) > 1 { if len(chunkz)%2 == 1 { chunkz = append(chunkz, emptyChunk) } hashedChunkz := make([][]byte, 0) for i := 0; i < len(chunkz); i += 2 { hashedChunk := hashutil.Hash(append(chunkz[i], chunkz[i+1]...)) hashedChunkz = append(hashedChunkz, hashedChunk[:]) } chunkz = hashedChunkz } result := hashutil.Hash(append(chunkz[0], dataLenEnc...)) return result[:], nil }