/* Copyright 2022 Erigon contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package commitment import ( "bytes" "encoding/binary" "fmt" "io" "math/bits" "github.com/holiman/uint256" "golang.org/x/crypto/sha3" "github.com/ledgerwatch/erigon-lib/common/length" "github.com/ledgerwatch/erigon-lib/rlp" "github.com/ledgerwatch/log/v3" ) const ( maxKeySize = 512 keyHalfSize = maxKeySize / 2 maxChild = 2 ) type bitstring []uint8 // converts slice of nibbles (lowest 4 bits of each byte) to bitstring func hexToBin(hex []byte) bitstring { bin := make([]byte, 4*len(hex)) for i := range bin { if hex[i/4]&(1<<(3-i%4)) != 0 { bin[i] = 1 } } return bin } // encodes bitstring to its compact representation func binToCompact(bin []byte) []byte { compact := make([]byte, 2+(len(bin)+7)/8) binary.BigEndian.PutUint16(compact, uint16(len(bin))) for i := 0; i < len(bin); i++ { if bin[i] != 0 { compact[2+i/8] |= (byte(1) << (i % 8)) } } return compact } // decodes compact bitstring representation into actual bitstring func compactToBin(compact []byte) []byte { bin := make([]byte, binary.BigEndian.Uint16(compact)) for i := 0; i < len(bin); i++ { if compact[2+i/8]&(byte(1)<<(i%8)) == 0 { bin[i] = 0 } else { bin[i] = 1 } } return bin } // BinHashed implements commitment based on patricia merkle tree with radix 16, // with keys pre-hashed by keccak256 type BinHashed struct { root BinCell // Root cell of the tree // Rows of the grid correspond to the level of depth in the patricia tree // Columns of the grid correspond to pointers to the nodes further from the root grid [maxKeySize][maxChild]BinCell // First 64 rows of this grid are for account trie, and next 64 rows are for storage trie // How many rows (starting from row 0) are currently active and have corresponding selected columns // Last active row does not have selected column activeRows int // Length of the key that reflects current positioning of the grid. It maybe larger than number of active rows, // if a account leaf cell represents multiple nibbles in the key currentKeyLen int currentKey [maxKeySize]byte // For each row indicates which column is currently selected depths [maxKeySize]int // For each row, the depth of cells in that row rootChecked bool // Set to false if it is not known whether the root is empty, set to true if it is checked rootTouched bool rootPresent bool branchBefore [maxKeySize]bool // For each row, whether there was a branch node in the database loaded in unfold touchMap [maxKeySize]uint16 // For each row, bitmap of cells that were either present before modification, or modified or deleted afterMap [maxKeySize]uint16 // For each row, bitmap of cells that were present after modification // Function used to load branch node and fill up the cells // For each cell, it sets the cell type, clears the modified flag, fills the hash, // and for the extension, account, and leaf type, the `l` and `k` branchFn func(prefix []byte) ([]byte, error) // Function used to fetch account with given plain key accountFn func(plainKey []byte, cell *BinCell) error // Function used to fetch account with given plain key storageFn func(plainKey []byte, cell *BinCell) error keccak keccakState keccak2 keccakState accountKeyLen int trace bool auxBuf [1 + length.Hash]byte byteArrayWriter ByteArrayWriter } func NewBinPatriciaHashed(accountKeyLen int, branchFn func(prefix []byte) ([]byte, error), accountFn func(plainKey []byte, cell *Cell) error, storageFn func(plainKey []byte, cell *Cell) error, ) *BinHashed { return &BinHashed{ keccak: sha3.NewLegacyKeccak256().(keccakState), keccak2: sha3.NewLegacyKeccak256().(keccakState), accountKeyLen: accountKeyLen, branchFn: branchFn, accountFn: wrapAccountStorageFn(accountFn), storageFn: wrapAccountStorageFn(storageFn), rootPresent: true, } } func (hph *BinHashed) ReviewKeys(plainKeys, hashedKeys [][]byte) (rootHash []byte, branchNodeUpdates map[string]BranchData, err error) { branchNodeUpdates = make(map[string]BranchData) stagedCell := new(BinCell) for i, hashedKey := range hashedKeys { hashedKey = hexToBin(hashedKey) plainKey := plainKeys[i] if hph.trace { fmt.Printf("plainKey=[%x], hashedKey=[%x], currentKey=[%x]\n", plainKey, hashedKey, hph.currentKey[:hph.currentKeyLen]) } // Keep folding until the currentKey is the prefix of the key we modify for hph.needFolding(hashedKey) { if branchData, updateKey, err := hph.fold(); err != nil { return nil, nil, fmt.Errorf("fold: %w", err) } else if branchData != nil { branchNodeUpdates[string(updateKey)] = branchData } } // Now unfold until we step on an empty cell for unfolding := hph.needUnfolding(hashedKey); unfolding > 0; unfolding = hph.needUnfolding(hashedKey) { if err := hph.unfold(hashedKey, unfolding); err != nil { return nil, nil, fmt.Errorf("unfold: %w", err) } } // Update the cell stagedCell.fillEmpty() var deleteCell bool if len(plainKey) == hph.accountKeyLen { if err := hph.accountFn(plainKey, stagedCell); err != nil { return nil, nil, fmt.Errorf("accountFn for key %x failed: %w", plainKey, err) } if stagedCell.isEmpty() { deleteCell = true } else { cell := hph.updateCell(hashedKey) cell.setAccountFields(plainKey, stagedCell.CodeHash[:], &stagedCell.Balance, stagedCell.Nonce) if hph.trace { fmt.Printf("accountFn filled cell plainKey: %x balance: %v nonce: %v codeHash: %x\n", stagedCell.apk, stagedCell.Balance.String(), stagedCell.Nonce, stagedCell.CodeHash) } } } else { if err = hph.storageFn(plainKey, stagedCell); err != nil { return nil, nil, fmt.Errorf("storageFn for key %x failed: %w", plainKey, err) } if hph.trace { fmt.Printf("storageFn filled %x : %x\n", plainKey, stagedCell.Storage) } if stagedCell.StorageLen == 0 { deleteCell = true } else { hph.updateCell(hashedKey).setStorage(plainKey, stagedCell.Storage[:stagedCell.StorageLen]) } } if deleteCell { hph.deleteCell(hashedKey) } } // Folding everything up to the root for hph.activeRows > 0 { if branchData, updateKey, err := hph.fold(); err != nil { return nil, nil, fmt.Errorf("final fold: %w", err) } else if branchData != nil { branchNodeUpdates[string(updateKey)] = branchData } } if branchData, err := hph.foldRoot(); err != nil { return nil, nil, fmt.Errorf("root fold: %w", err) } else if branchData != nil { branchNodeUpdates[string(hexToBin([]byte{}))] = branchData } rhash, err := hph.RootHash() if err != nil { return nil, branchNodeUpdates, fmt.Errorf("root hash evaluation failed: %w", err) } return rhash, branchNodeUpdates, nil } func (hph *BinHashed) Variant() TrieVariant { return VariantBinPatriciaTrie } func (hph *BinHashed) RootHash() ([]byte, error) { hash, err := hph.computeCellHash(&hph.root, 0, hph.auxBuf[:0]) if err != nil { return nil, err } return hash[1:], nil // first byte is 128+hash_len } func (hph *BinHashed) SetTrace(trace bool) { hph.trace = trace } // Reset allows BinHashed instance to be reused for the new commitment calculation func (hph *BinHashed) Reset() { hph.rootChecked = false hph.root.hl = 0 hph.root.downHashedLen = 0 hph.root.apl = 0 hph.root.spl = 0 hph.root.extLen = 0 copy(hph.root.CodeHash[:], EmptyCodeHash) hph.root.StorageLen = 0 hph.root.Balance.Clear() hph.root.Nonce = 0 hph.rootTouched = false hph.rootPresent = true } func wrapAccountStorageFn(fn func([]byte, *Cell) error) func(pk []byte, bc *BinCell) error { return func(pk []byte, bc *BinCell) error { cl := bc.unwrapToHexCell() if err := fn(pk, cl); err != nil { return err } bc.Balance = *cl.Balance.Clone() bc.Nonce = cl.Nonce bc.StorageLen = cl.StorageLen bc.apl = cl.apl bc.spl = cl.spl bc.hl = cl.hl copy(bc.apk[:], cl.apk[:]) copy(bc.spk[:], cl.spk[:]) copy(bc.h[:], cl.h[:]) copy(bc.extension[:], cl.extension[:]) bc.extLen = cl.extLen copy(bc.downHashedKey[:], cl.downHashedKey[:]) bc.downHashedLen = cl.downHashedLen copy(bc.CodeHash[:], cl.CodeHash[:]) copy(bc.Storage[:], cl.Storage[:]) return nil } } func (hph *BinHashed) ResetFns( branchFn func(prefix []byte) ([]byte, error), accountFn func(plainKey []byte, cell *Cell) error, storageFn func(plainKey []byte, cell *Cell) error, ) { hph.branchFn = branchFn hph.accountFn = wrapAccountStorageFn(accountFn) hph.storageFn = wrapAccountStorageFn(storageFn) } func (hph *BinHashed) completeLeafHash(buf, keyPrefix []byte, kp, kl, compactLen int, key []byte, compact0 byte, ni int, val rlp.RlpSerializable, singleton bool) ([]byte, error) { totalLen := kp + kl + val.DoubleRLPLen() var lenPrefix [4]byte pt := rlp.GenerateStructLen(lenPrefix[:], totalLen) embedded := !singleton && totalLen+pt < length.Hash var writer io.Writer if embedded { hph.byteArrayWriter.Setup(buf) writer = &hph.byteArrayWriter } else { hph.keccak.Reset() writer = hph.keccak } if _, err := writer.Write(lenPrefix[:pt]); err != nil { return nil, err } if _, err := writer.Write(keyPrefix[:kp]); err != nil { return nil, err } var b [1]byte b[0] = compact0 if _, err := writer.Write(b[:]); err != nil { return nil, err } for i := 1; i < compactLen; i++ { b[0] = key[ni]*16 + key[ni+1] if _, err := writer.Write(b[:]); err != nil { return nil, err } ni += 2 } var prefixBuf [8]byte if err := val.ToDoubleRLP(writer, prefixBuf[:]); err != nil { return nil, err } if embedded { buf = hph.byteArrayWriter.buf } else { var hashBuf [33]byte hashBuf[0] = 0x80 + length.Hash if _, err := hph.keccak.Read(hashBuf[1:]); err != nil { return nil, err } buf = append(buf, hashBuf[:]...) } return buf, nil } func (hph *BinHashed) leafHashWithKeyVal(buf, key []byte, val rlp.RlpSerializableBytes, singleton bool) ([]byte, error) { // Compute the total length of binary representation var kp, kl int // Write key var compactLen int var ni int var compact0 byte compactLen = (len(key)-1)/2 + 1 if len(key)&1 == 0 { compact0 = 0x30 + key[0] // Odd: (3<<4) + first nibble ni = 1 } else { compact0 = 0x20 } var keyPrefix [1]byte if compactLen > 1 { keyPrefix[0] = 0x80 + byte(compactLen) kp = 1 kl = compactLen } else { kl = 1 } return hph.completeLeafHash(buf, keyPrefix[:], kp, kl, compactLen, key, compact0, ni, val, singleton) } func (hph *BinHashed) accountLeafHashWithKey(buf, key []byte, val rlp.RlpSerializable) ([]byte, error) { // Compute the total length of binary representation var kp, kl int // Write key var compactLen int var ni int var compact0 byte if hasTerm(key) { compactLen = (len(key)-1)/2 + 1 if len(key)&1 == 0 { compact0 = 48 + key[0] // Odd (1<<4) + first nibble ni = 1 } else { compact0 = 32 } } else { compactLen = len(key)/2 + 1 if len(key)&1 == 1 { compact0 = 16 + key[0] // Odd (1<<4) + first nibble ni = 1 } } var keyPrefix [1]byte if compactLen > 1 { keyPrefix[0] = byte(128 + compactLen) kp = 1 kl = compactLen } else { kl = 1 } return hph.completeLeafHash(buf, keyPrefix[:], kp, kl, compactLen, key, compact0, ni, val, true) } func (hph *BinHashed) extensionHash(key []byte, hash []byte) ([length.Hash]byte, error) { var hashBuf [length.Hash]byte // Compute the total length of binary representation var kp, kl int // Write key var compactLen int var ni int var compact0 byte if hasTerm(key) { compactLen = (len(key)-1)/2 + 1 if len(key)&1 == 0 { compact0 = 0x30 + key[0] // Odd: (3<<4) + first nibble ni = 1 } else { compact0 = 0x20 } } else { compactLen = len(key)/2 + 1 if len(key)&1 == 1 { compact0 = 0x10 + key[0] // Odd: (1<<4) + first nibble ni = 1 } } var keyPrefix [1]byte if compactLen > 1 { keyPrefix[0] = 0x80 + byte(compactLen) kp = 1 kl = compactLen } else { kl = 1 } totalLen := kp + kl + 33 var lenPrefix [4]byte pt := rlp.GenerateStructLen(lenPrefix[:], totalLen) hph.keccak.Reset() if _, err := hph.keccak.Write(lenPrefix[:pt]); err != nil { return hashBuf, err } if _, err := hph.keccak.Write(keyPrefix[:kp]); err != nil { return hashBuf, err } var b [1]byte b[0] = compact0 if _, err := hph.keccak.Write(b[:]); err != nil { return hashBuf, err } for i := 1; i < compactLen; i++ { b[0] = key[ni]*16 + key[ni+1] if _, err := hph.keccak.Write(b[:]); err != nil { return hashBuf, err } ni += 2 } b[0] = 0x80 + length.Hash if _, err := hph.keccak.Write(b[:]); err != nil { return hashBuf, err } if _, err := hph.keccak.Write(hash); err != nil { return hashBuf, err } // Replace previous hash with the new one if _, err := hph.keccak.Read(hashBuf[:]); err != nil { return hashBuf, err } return hashBuf, nil } func (hph *BinHashed) computeCellHashLen(cell *BinCell, depth int) int { if cell.spl > 0 && depth >= keyHalfSize { keyLen := maxKeySize - depth + 1 // Length of hex key with terminator character var kp, kl int compactLen := (keyLen-1)/2 + 1 if compactLen > 1 { kp = 1 kl = compactLen } else { kl = 1 } val := rlp.RlpSerializableBytes(cell.Storage[:cell.StorageLen]) totalLen := kp + kl + val.DoubleRLPLen() var lenPrefix [4]byte pt := rlp.GenerateStructLen(lenPrefix[:], totalLen) if totalLen+pt < length.Hash { return totalLen + pt } } return length.Hash + 1 } func (hph *BinHashed) computeCellHash(cell *BinCell, depth int, buf []byte) ([]byte, error) { var err error var storageRootHash [length.Hash]byte storageRootHashIsSet := false if cell.spl > 0 { var hashedKeyOffset int if depth >= keyHalfSize { hashedKeyOffset = depth - keyHalfSize } singleton := depth <= keyHalfSize if err := hashKey(hph.keccak, cell.spk[hph.accountKeyLen:cell.spl], cell.downHashedKey[:], hashedKeyOffset); err != nil { return nil, err } cell.downHashedKey[keyHalfSize-hashedKeyOffset] = 16 // Add terminator if singleton { if hph.trace { fmt.Printf("leafHashWithKeyVal(singleton) for [%x]=>[%x]\n", cell.downHashedKey[:keyHalfSize-hashedKeyOffset+1], cell.Storage[:cell.StorageLen]) } if _, err = hph.leafHashWithKeyVal(storageRootHash[:0], cell.downHashedKey[:keyHalfSize-hashedKeyOffset+1], rlp.RlpSerializableBytes(cell.Storage[:cell.StorageLen]), true); err != nil { return nil, err } storageRootHashIsSet = true } else { if hph.trace { fmt.Printf("leafHashWithKeyVal for [%x]=>[%x]\n", cell.downHashedKey[:keyHalfSize-hashedKeyOffset+1], cell.Storage[:cell.StorageLen]) } return hph.leafHashWithKeyVal(buf, cell.downHashedKey[:keyHalfSize-hashedKeyOffset+1], rlp.RlpSerializableBytes(cell.Storage[:cell.StorageLen]), false) } } if cell.apl > 0 { if err := hashKey(hph.keccak, cell.apk[:cell.apl], cell.downHashedKey[:], depth); err != nil { return nil, err } cell.downHashedKey[keyHalfSize-depth] = 16 // Add terminator if !storageRootHashIsSet { if cell.extLen > 0 { // Extension if cell.hl > 0 { if hph.trace { fmt.Printf("extensionHash for [%x]=>[%x]\n", cell.extension[:cell.extLen], cell.h[:cell.hl]) } if storageRootHash, err = hph.extensionHash(cell.extension[:cell.extLen], cell.h[:cell.hl]); err != nil { return nil, err } } else { return nil, fmt.Errorf("computeCellHash extension without hash") } } else if cell.hl > 0 { storageRootHash = cell.h } else { storageRootHash = *(*[length.Hash]byte)(EmptyRootHash) } } var valBuf [128]byte valLen := cell.accountForHashing(valBuf[:], storageRootHash) if hph.trace { fmt.Printf("accountLeafHashWithKey for [%x]=>[%x]\n", cell.downHashedKey[:keyHalfSize+1-depth], valBuf[:valLen]) } return hph.accountLeafHashWithKey(buf, cell.downHashedKey[:keyHalfSize+1-depth], rlp.RlpEncodedBytes(valBuf[:valLen])) } buf = append(buf, 0x80+32) if cell.extLen > 0 { // Extension if cell.hl > 0 { if hph.trace { fmt.Printf("extensionHash for [%x]=>[%x]\n", cell.extension[:cell.extLen], cell.h[:cell.hl]) } var hash [length.Hash]byte if hash, err = hph.extensionHash(cell.extension[:cell.extLen], cell.h[:cell.hl]); err != nil { return nil, err } buf = append(buf, hash[:]...) } else { return nil, fmt.Errorf("computeCellHash extension without hash") } } else if cell.hl > 0 { buf = append(buf, cell.h[:cell.hl]...) } else { buf = append(buf, EmptyRootHash...) } return buf, nil } func (hph *BinHashed) needUnfolding(hashedKey []byte) int { var cell *BinCell var depth int if hph.activeRows == 0 { if hph.trace { fmt.Printf("needUnfolding root, rootChecked = %t\n", hph.rootChecked) } cell = &hph.root if cell.downHashedLen == 0 && cell.hl == 0 && !hph.rootChecked { // Need to attempt to unfold the root return 1 } } else { col := int(hashedKey[hph.currentKeyLen]) cell = &hph.grid[hph.activeRows-1][col] depth = hph.depths[hph.activeRows-1] if hph.trace { fmt.Printf("needUnfolding cell (%d, %x), currentKey=[%x], depth=%d, cell.h=[%x]\n", hph.activeRows-1, col, hph.currentKey[:hph.currentKeyLen], depth, cell.h[:cell.hl]) } } if len(hashedKey) <= depth { return 0 } if cell.downHashedLen == 0 { if cell.hl == 0 { // cell is empty, no need to unfold further return 0 } else { // unfold branch node return 1 } } cpl := commonPrefixLen(hashedKey[depth:], cell.downHashedKey[:cell.downHashedLen-1]) if hph.trace { fmt.Printf("cpl=%d, cell.downHashedKey=[%x], depth=%d, hashedKey[depth:]=[%x]\n", cpl, cell.downHashedKey[:cell.downHashedLen], depth, hashedKey[depth:]) } unfolding := cpl + 1 if depth < keyHalfSize && depth+unfolding > keyHalfSize { // This is to make sure that unfolding always breaks at the level where storage subtrees start unfolding = keyHalfSize - depth if hph.trace { fmt.Printf("adjusted unfolding=%d\n", unfolding) } } return unfolding } func (hph *BinHashed) unfoldBranchNode(row int, deleted bool, depth int) error { branchData, err := hph.branchFn(binToCompact(hph.currentKey[:hph.currentKeyLen])) if err != nil { return err } if !hph.rootChecked && hph.currentKeyLen == 0 && len(branchData) == 0 { // Special case - empty or deleted root hph.rootChecked = true return nil } if len(branchData) == 0 { log.Warn("got empty branch data during unfold", "row", row, "depth", depth, "deleted", deleted) } hph.branchBefore[row] = true // fmt.Printf("unfoldBranchNode [%x]=>[%x]\n", hph.currentKey[:hph.currentKeyLen], branchData) bitmap := binary.BigEndian.Uint16(branchData[0:]) pos := 2 if deleted { // All cells come as deleted (touched but not present after) hph.afterMap[row] = 0 hph.touchMap[row] = bitmap } else { hph.afterMap[row] = bitmap hph.touchMap[row] = 0 } //fmt.Printf("unfoldBranchNode [unfoldBranchNode%x], afterMap = [%016b], touchMap = [%016b]\n", branchData, hph.afterMap[row], hph.touchMap[row]) // Loop iterating over the set bits of modMask for bitset, j := bitmap, 0; bitset != 0; j++ { bit := bitset & -bitset nibble := bits.TrailingZeros16(bit) cell := &hph.grid[row][nibble] fieldBits := branchData[pos] pos++ var err error if pos, err = cell.fillFromFields(branchData, pos, PartFlags(fieldBits)); err != nil { return fmt.Errorf("prefix [%x], branchData[%x]: %w", hph.currentKey[:hph.currentKeyLen], branchData, err) } if hph.trace { fmt.Printf("cell (%d, %x) depth=%d, hash=[%x], a=[%x], s=[%x], ex=[%x]\n", row, nibble, depth, cell.h[:cell.hl], cell.apk[:cell.apl], cell.spk[:cell.spl], cell.extension[:cell.extLen]) } if cell.apl > 0 { hph.accountFn(cell.apk[:cell.apl], cell) if hph.trace { fmt.Printf("accountFn[%x] return balance=%d, nonce=%d\n", cell.apk[:cell.apl], &cell.Balance, cell.Nonce) } } if cell.spl > 0 { hph.storageFn(cell.spk[:cell.spl], cell) } if err = cell.deriveHashedKeys(depth, hph.keccak, hph.accountKeyLen); err != nil { return err } bitset ^= bit } return nil } func (hph *BinHashed) unfold(hashedKey []byte, unfolding int) error { if hph.trace { fmt.Printf("unfold %d: activeRows: %d\n", unfolding, hph.activeRows) } var upCell *BinCell var touched, present bool var col byte var upDepth, depth int if hph.activeRows == 0 { if hph.rootChecked && hph.root.hl == 0 && hph.root.downHashedLen == 0 { // No unfolding for empty root return nil } upCell = &hph.root touched = hph.rootTouched present = hph.rootPresent if hph.trace { fmt.Printf("root, touched %t, present %t\n", touched, present) } } else { upDepth = hph.depths[hph.activeRows-1] col = hashedKey[upDepth-1] upCell = &hph.grid[hph.activeRows-1][col] touched = hph.touchMap[hph.activeRows-1]&(uint16(1)<