diff --git a/Makefile b/Makefile index e5a2ab835..7ad5744bb 100644 --- a/Makefile +++ b/Makefile @@ -95,7 +95,7 @@ test: semantics/z3/build/libz3.a all lint: lintci -lintci: +lintci: semantics/z3/build/libz3.a all @echo "--> Running linter for code diff versus commit $(LATEST_COMMIT)" @./build/bin/golangci-lint run \ --new-from-rev=$(LATEST_COMMIT) \ @@ -116,7 +116,7 @@ lintci: lintci-deps: rm -f ./build/bin/golangci-lint - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b ./build/bin v1.21.0 + curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b ./build/bin v1.23.8 clean: env GO111MODULE=on go clean -cache @@ -231,4 +231,4 @@ bindings: go generate ./tests/contracts/ simulator-genesis: - go run ./cmd/tester genesis > ./cmd/tester/simulator_genesis.json \ No newline at end of file + go run ./cmd/tester genesis > ./cmd/tester/simulator_genesis.json diff --git a/cmd/hack/hack.go b/cmd/hack/hack.go index 584dc77fc..e9d8a273f 100644 --- a/cmd/hack/hack.go +++ b/cmd/hack/hack.go @@ -617,7 +617,7 @@ func trieChart() { } func execToBlock(chaindata string, block uint64, fromScratch bool) { - state.MaxTrieCacheGen = 32 + state.MaxTrieCacheSize = 32 blockDb, err := ethdb.NewBoltDatabase(chaindata) check(err) bcb, err := core.NewBlockChain(blockDb, nil, params.MainnetChainConfig, ethash.NewFaker(), vm.Config{}, nil) @@ -1938,50 +1938,50 @@ func validateTxLookups2(db *ethdb.BoltDatabase, startBlock uint64, interruptCh c func indexSize(chaindata string) { db, err := ethdb.NewBoltDatabase(chaindata) check(err) - fStorage,err:=os.Create("index_sizes_storage.csv") + fStorage, err := os.Create("index_sizes_storage.csv") check(err) - fAcc,err:=os.Create("index_sizes_acc.csv") + fAcc, err := os.Create("index_sizes_acc.csv") check(err) - csvAcc:=csv.NewWriter(fAcc) + csvAcc := csv.NewWriter(fAcc) err = csvAcc.Write([]string{"key", "ln"}) check(err) - csvStorage:=csv.NewWriter(fStorage) + csvStorage := 
csv.NewWriter(fStorage) err = csvStorage.Write([]string{"key", "ln"}) - i:=0 - j:=0 - maxLenAcc:=0 - maxLenSt:=0 + i := 0 + j := 0 + maxLenAcc := 0 + maxLenSt := 0 db.Walk(dbutils.AccountsHistoryBucket, []byte{}, 0, func(k, v []byte) (b bool, e error) { - if i>10000 { + if i > 10000 { fmt.Println(j) - i=0 + i = 0 } i++ j++ - if len(v)> maxLenAcc { - maxLenAcc=len(v) + if len(v) > maxLenAcc { + maxLenAcc = len(v) } err = csvAcc.Write([]string{common.Bytes2Hex(k), strconv.Itoa(len(v))}) - if err!=nil { + if err != nil { panic(err) } return true, nil }) - i=0 - j=0 + i = 0 + j = 0 db.Walk(dbutils.StorageHistoryBucket, []byte{}, 0, func(k, v []byte) (b bool, e error) { - if i>10000 { + if i > 10000 { fmt.Println(j) - i=0 + i = 0 } i++ j++ - if len(v)> maxLenSt { - maxLenSt=len(v) + if len(v) > maxLenSt { + maxLenSt = len(v) } err = csvStorage.Write([]string{common.Bytes2Hex(k), strconv.Itoa(len(v))}) - if err!=nil { + if err != nil { panic(err) } diff --git a/cmd/state/commands/stateless.go b/cmd/state/commands/stateless.go index 6024aa967..5b9d73a1c 100644 --- a/cmd/state/commands/stateless.go +++ b/cmd/state/commands/stateless.go @@ -28,7 +28,7 @@ func init() { withBlock(statelessCmd) statelessCmd.Flags().StringVar(&statefile, "statefile", "state", "path to the file where the state will be periodically written during the analysis") - statelessCmd.Flags().Uint32Var(&triesize, "triesize", 1024*1024, "maximum number of nodes in the state trie") + statelessCmd.Flags().Uint32Var(&triesize, "triesize", 4*1024*1024, "maximum size of a trie in bytes") statelessCmd.Flags().BoolVar(&preroot, "preroot", false, "Attempt to compute hash of the trie without modifying it") statelessCmd.Flags().Uint64Var(&snapshotInterval, "snapshotInterval", 0, "how often to take snapshots (0 - never, 1 - every block, 1000 - every 1000th block, etc)") statelessCmd.Flags().Uint64Var(&snapshotFrom, "snapshotFrom", 0, "from which block to start snapshots") diff --git a/cmd/state/stateless/stateless.go 
b/cmd/state/stateless/stateless.go index 819d72223..b352b6e6d 100644 --- a/cmd/state/stateless/stateless.go +++ b/cmd/state/stateless/stateless.go @@ -150,7 +150,7 @@ func Stateless( witnessDatabasePath string, writeHistory bool, ) { - state.MaxTrieCacheGen = triesize + state.MaxTrieCacheSize = uint64(triesize) startTime := time.Now() sigs := make(chan os.Signal, 1) interruptCh := make(chan bool, 1) @@ -324,7 +324,7 @@ func Stateless( return } if len(resolveWitnesses) > 0 { - witnessDBWriter.MustUpsert(blockNum, state.MaxTrieCacheGen, resolveWitnesses) + witnessDBWriter.MustUpsert(blockNum, uint32(state.MaxTrieCacheSize), resolveWitnesses) } } execTime2 := time.Since(execStart) @@ -454,7 +454,7 @@ func Stateless( fmt.Printf("Failed to commit batch: %v\n", err) return } - tds.PruneTries(false) + tds.EvictTries(false) } if willSnapshot { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index ff5d3d970..019e93f06 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -1589,7 +1589,7 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { // TODO(fjl): move trie cache generations into config if gen := ctx.GlobalInt(TrieCacheGenFlag.Name); gen > 0 { - state.MaxTrieCacheGen = uint32(gen) + state.MaxTrieCacheSize = uint64(gen) } } diff --git a/core/blockchain.go b/core/blockchain.go index 446f42da6..39dc0403e 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1775,7 +1775,7 @@ func (bc *BlockChain) insertChain(ctx context.Context, chain types.Blocks, verif } bc.committedBlock.Store(bc.currentBlock.Load()) if bc.trieDbState != nil { - bc.trieDbState.PruneTries(false) + bc.trieDbState.EvictTries(false) } log.Info("Database", "size", bc.db.DiskSize(), "written", written) } diff --git a/core/state/database.go b/core/state/database.go index a889e6f68..238f701af 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -14,6 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the 
go-ethereum library. If not, see . +//nolint:scopelint package state import ( @@ -36,8 +37,8 @@ import ( "github.com/ledgerwatch/turbo-geth/trie" ) -// Trie cache generation limit after which to evict trie nodes from memory. -var MaxTrieCacheGen = uint32(1024 * 1024) +// MaxTrieCacheSize is the trie cache size limit after which to evict trie nodes from memory. +var MaxTrieCacheSize = uint64(1024 * 1024) const ( //FirstContractIncarnation - first incarnation for contract accounts. After 1 it increases by 1. @@ -180,7 +181,7 @@ type TrieDbState struct { resolveReads bool savePreimages bool resolveSetBuilder *trie.ResolveSetBuilder - tp *trie.TriePruning + tp *trie.Eviction newStream trie.Stream hashBuilder *trie.HashBuilder resolver *trie.Resolver @@ -189,7 +190,7 @@ type TrieDbState struct { func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) *TrieDbState { t := trie.New(root) - tp := trie.NewTriePruning(blockNr) + tp := trie.NewEviction() tds := &TrieDbState{ t: t, @@ -202,10 +203,11 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) *TrieDb hashBuilder: trie.NewHashBuilder(false), incarnationMap: make(map[common.Hash]uint64), } - t.SetTouchFunc(tp.Touch) - tp.SetUnloadNodeFunc(tds.putIntermediateHash) - tp.SetCreateNodeFunc(tds.delIntermediateHash) + tp.SetBlockNumber(blockNr) + + t.AddObserver(tp) + t.AddObserver(NewIntermediateHashes(tds.db, tds.db)) return tds } @@ -232,7 +234,8 @@ func (tds *TrieDbState) Copy() *TrieDbState { tds.tMu.Unlock() n := tds.getBlockNr() - tp := trie.NewTriePruning(n) + tp := trie.NewEviction() + tp.SetBlockNumber(n) cpy := TrieDbState{ t: &tcopy, @@ -244,36 +247,12 @@ func (tds *TrieDbState) Copy() *TrieDbState { incarnationMap: make(map[common.Hash]uint64), } - cpy.tp.SetUnloadNodeFunc(cpy.putIntermediateHash) - cpy.tp.SetCreateNodeFunc(cpy.delIntermediateHash) + cpy.t.AddObserver(tp) + cpy.t.AddObserver(NewIntermediateHashes(cpy.db, cpy.db)) return &cpy } -func (tds *TrieDbState) 
putIntermediateHash(key []byte, nodeHash []byte) { - if err := tds.db.Put(dbutils.IntermediateTrieHashBucket, common.CopyBytes(key), common.CopyBytes(nodeHash)); err != nil { - log.Warn("could not put intermediate trie hash", "err", err) - } -} - -func (tds *TrieDbState) delIntermediateHash(prefixAsNibbles []byte) { - if len(prefixAsNibbles) == 0 { - return - } - - if len(prefixAsNibbles)%2 == 1 { // only put to bucket prefixes with even number of nibbles - return - } - - key := make([]byte, len(prefixAsNibbles)/2) - trie.CompressNibbles(prefixAsNibbles, &key) - - if err := tds.db.Delete(dbutils.IntermediateTrieHashBucket, key); err != nil { - log.Warn("could not delete intermediate trie hash", "err", err) - return - } -} - func (tds *TrieDbState) Database() ethdb.Database { return tds.db } @@ -645,7 +624,7 @@ func (tds *TrieDbState) ResolveStateTrieStateless(database trie.WitnessStorage) return nil } - pos, err := resolver.ResolveStateless(database, tds.blockNr, MaxTrieCacheGen, startPos) + pos, err := resolver.ResolveStateless(database, tds.blockNr, uint32(MaxTrieCacheSize), startPos) if err != nil { return err } @@ -838,7 +817,7 @@ func (tds *TrieDbState) clearUpdates() { func (tds *TrieDbState) SetBlockNr(blockNr uint64) { tds.setBlockNr(blockNr) - tds.tp.SetBlockNr(blockNr) + tds.tp.SetBlockNumber(blockNr) } func (tds *TrieDbState) GetBlockNr() uint64 { @@ -1203,28 +1182,69 @@ type TrieStateWriter struct { tds *TrieDbState } -func (tds *TrieDbState) PruneTries(print bool) { +func (tds *TrieDbState) EvictTries(print bool) { tds.tMu.Lock() defer tds.tMu.Unlock() + strict := print tds.incarnationMap = make(map[common.Hash]uint64) if print { - prunableNodes := tds.t.CountPrunableNodes() - fmt.Printf("[Before] Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tds.tp.NodeCount()) + trieSize := tds.t.TrieSize() + fmt.Println("") // newline for better formatting + fmt.Printf("[Before] Actual nodes size: %d, accounted size: %d\n", trieSize, 
tds.tp.TotalSize()) } - tds.tp.PruneTo(tds.t, int(MaxTrieCacheGen)) + if strict { + actualAccounts := uint64(tds.t.NumberOfAccounts()) + fmt.Println("number of leaves: ", actualAccounts) + accountedAccounts := tds.tp.NumberOf() + if actualAccounts != accountedAccounts { + panic(fmt.Errorf("account number mismatch: trie=%v eviction=%v", actualAccounts, accountedAccounts)) + } + fmt.Printf("checking number --> ok\n") + + actualSize := uint64(tds.t.TrieSize()) + accountedSize := tds.tp.TotalSize() + + if actualSize != accountedSize { + panic(fmt.Errorf("account size mismatch: trie=%v eviction=%v", actualSize, accountedSize)) + } + fmt.Printf("checking size --> ok\n") + } + + tds.tp.EvictToFitSize(tds.t, MaxTrieCacheSize) + + if strict { + actualAccounts := uint64(tds.t.NumberOfAccounts()) + fmt.Println("number of leaves: ", actualAccounts) + accountedAccounts := tds.tp.NumberOf() + if actualAccounts != accountedAccounts { + panic(fmt.Errorf("after eviction account number mismatch: trie=%v eviction=%v", actualAccounts, accountedAccounts)) + } + fmt.Printf("checking number --> ok\n") + + actualSize := uint64(tds.t.TrieSize()) + accountedSize := tds.tp.TotalSize() + + if actualSize != accountedSize { + panic(fmt.Errorf("after eviction account size mismatch: trie=%v eviction=%v", actualSize, accountedSize)) + } + fmt.Printf("checking size --> ok\n") + } if print { - prunableNodes := tds.t.CountPrunableNodes() - fmt.Printf("[After] Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tds.tp.NodeCount()) + trieSize := tds.t.TrieSize() + fmt.Printf("[After] Actual nodes size: %d, accounted size: %d\n", trieSize, tds.tp.TotalSize()) + + actualAccounts := uint64(tds.t.NumberOfAccounts()) + fmt.Println("number of leaves: ", actualAccounts) } var m runtime.MemStats runtime.ReadMemStats(&m) - log.Info("Memory", "nodes", tds.tp.NodeCount(), "hashes", tds.t.HashMapSize(), + log.Info("Memory", "nodes size", tds.tp.TotalSize(), "hashes", tds.t.HashMapSize(), "alloc", 
int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC)) if print { - fmt.Printf("Pruning done. Nodes: %d, alloc: %d, sys: %d, numGC: %d\n", tds.tp.NodeCount(), int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC)) + fmt.Printf("Eviction done. Nodes size: %d, alloc: %d, sys: %d, numGC: %d\n", tds.tp.TotalSize(), int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC)) } } diff --git a/core/state/intermediate_hashes.go b/core/state/intermediate_hashes.go new file mode 100644 index 000000000..ca60be45d --- /dev/null +++ b/core/state/intermediate_hashes.go @@ -0,0 +1,51 @@ +package state + +import ( + "github.com/ledgerwatch/turbo-geth/common" + "github.com/ledgerwatch/turbo-geth/common/dbutils" + "github.com/ledgerwatch/turbo-geth/common/pool" + "github.com/ledgerwatch/turbo-geth/ethdb" + "github.com/ledgerwatch/turbo-geth/log" + "github.com/ledgerwatch/turbo-geth/trie" +) + +const keyBufferSize = 64 + +type IntermediateHashes struct { + trie.NoopObserver // make sure that we don't need to subscribe to unnecessary methods + putter ethdb.Putter + deleter ethdb.Deleter +} + +func NewIntermediateHashes(putter ethdb.Putter, deleter ethdb.Deleter) *IntermediateHashes { + return &IntermediateHashes{putter: putter, deleter: deleter} +} + +func (ih *IntermediateHashes) WillUnloadBranchNode(prefixAsNibbles []byte, nodeHash common.Hash) { + // only put to bucket prefixes with even number of nibbles + if len(prefixAsNibbles) == 0 || len(prefixAsNibbles)%2 == 1 { + return + } + + key := pool.GetBuffer(keyBufferSize) + trie.CompressNibbles(prefixAsNibbles, &key.B) + + if err := ih.putter.Put(dbutils.IntermediateTrieHashBucket, common.CopyBytes(key.B), common.CopyBytes(nodeHash[:])); err != nil { + log.Warn("could not put intermediate trie hash", "err", err) + } +} + +func (ih *IntermediateHashes) BranchNodeLoaded(prefixAsNibbles []byte) { + // only put to bucket prefixes with even number of nibbles + if len(prefixAsNibbles) == 0 || len(prefixAsNibbles)%2 == 1 { + return + } + + key 
:= pool.GetBuffer(keyBufferSize) + trie.CompressNibbles(prefixAsNibbles, &key.B) + + if err := ih.deleter.Delete(dbutils.IntermediateTrieHashBucket, key.B); err != nil { + log.Warn("could not delete intermediate trie hash", "err", err) + return + } +} diff --git a/trie/node.go b/trie/node.go index 314d26441..c4c9bc1d3 100644 --- a/trie/node.go +++ b/trie/node.go @@ -228,3 +228,77 @@ func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } func (n codeNode) String() string { return n.fstring("") } func (an accountNode) String() string { return an.fstring("") } + +func CodeKeyFromAddrHash(addrHash []byte) []byte { + return append(addrHash, 0xC0, 0xDE) +} + +func CodeHexFromHex(hex []byte) []byte { + return append(hex, 0x0C, 0x00, 0x0D, 0x0E) +} + +func IsPointingToCode(key []byte) bool { + // checking for 0xC0DE + l := len(key) + if l < 2 { + return false + } + + return key[l-2] == 0xC0 && key[l-1] == 0xDE +} + +func AddrHashFromCodeKey(codeKey []byte) []byte { + // cut off 0xC0DE + return codeKey[:len(codeKey)-2] +} + +func calcSubtreeSize(node node) int { + switch n := node.(type) { + case nil: + return 0 + case valueNode: + return 0 + case *shortNode: + return calcSubtreeSize(n.Val) + case *duoNode: + return 1 + calcSubtreeSize(n.child1) + calcSubtreeSize(n.child2) + case *fullNode: + size := 1 + for _, child := range n.Children { + size += calcSubtreeSize(child) + } + return size + case *accountNode: + return len(n.code) + calcSubtreeSize(n.storage) + case hashNode: + return 0 + } + return 0 +} + +func calcSubtreeNodes(node node) int { + switch n := node.(type) { + case nil: + return 0 + case valueNode: + return 0 + case *shortNode: + return calcSubtreeNodes(n.Val) + case *duoNode: + return 1 + calcSubtreeNodes(n.child1) + calcSubtreeNodes(n.child2) + case *fullNode: + size := 1 + for _, child := range n.Children { + size += calcSubtreeNodes(child) + } + return size + case *accountNode: + if n.code != 
nil { + return 1 + calcSubtreeNodes(n.storage) + } + return calcSubtreeNodes(n.storage) + case hashNode: + return 0 + } + return 0 +} diff --git a/trie/trie.go b/trie/trie.go index e020cc132..c02bb57b1 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -48,8 +48,6 @@ var ( type Trie struct { root node - touchFunc func(hex []byte, del bool) - newHasherFunc func() *hasher Version uint8 @@ -57,6 +55,8 @@ type Trie struct { binary bool hashMap map[common.Hash]node + + observers *ObserverMux } // New creates a trie with an existing root node from db. @@ -67,9 +67,9 @@ type Trie struct { // not exist in the database. Accessing the trie loads nodes from db on demand. func New(root common.Hash) *Trie { trie := &Trie{ - touchFunc: func([]byte, bool) {}, newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ false) }, hashMap: make(map[common.Hash]node), + observers: NewTrieObserverMux(), } if (root != common.Hash{}) && root != EmptyRoot { trie.root = hashNode(root[:]) @@ -87,9 +87,9 @@ func NewBinary(root common.Hash) *Trie { // it is usually used for testing purposes. func NewTestRLPTrie(root common.Hash) *Trie { trie := &Trie{ - touchFunc: func([]byte, bool) {}, newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ true) }, hashMap: make(map[common.Hash]node), + observers: NewTrieObserverMux(), } if (root != common.Hash{}) && root != EmptyRoot { trie.root = hashNode(root[:]) @@ -97,8 +97,8 @@ func NewTestRLPTrie(root common.Hash) *Trie { return trie } -func (t *Trie) SetTouchFunc(touchFunc func(hex []byte, del bool)) { - t.touchFunc = touchFunc +func (t *Trie) AddObserver(observer Observer) { + t.observers.AddChild(observer) } // Get returns the value for key stored in the trie. 
@@ -149,6 +149,8 @@ func (t *Trie) GetAccountCode(key []byte) (value []byte, gotValue bool) { return nil, gotValue } + t.observers.CodeNodeTouched(hex) + if accNode.code == nil { return nil, false } @@ -174,7 +176,7 @@ func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNod return nil, true } case *duoNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) i1, i2 := n.childrenIdx() switch key[pos] { case i1: @@ -185,7 +187,7 @@ func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNod return nil, true } case *fullNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) child := n.Children[key[pos]] return t.getAccount(child, key, pos+1) case hashNode: @@ -215,7 +217,7 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, gotValue b } return case *duoNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) i1, i2 := n.childrenIdx() switch key[pos] { case i1: @@ -227,7 +229,7 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, gotValue b } return case *fullNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) child := n.Children[key[pos]] if child == nil { return nil, true @@ -254,9 +256,10 @@ func (t *Trie) Update(key, value []byte) { hex = keyHexToBin(hex) } + newnode := valueNode(value) + if t.root == nil { - newnode := &shortNode{Key: hex, Val: valueNode(value)} - t.root = newnode + t.root = &shortNode{Key: hex, Val: newnode} } else { _, t.root = t.insert(t.root, hex, 0, valueNode(value)) } @@ -271,20 +274,18 @@ func (t *Trie) UpdateAccount(key []byte, acc *accounts.Account) { if t.binary { hex = keyHexToBin(hex) } - if t.root == nil { - var newnode node - if value.Root == EmptyRoot || value.Root == (common.Hash{}) { - newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true, nil}} - } else { - newnode = &shortNode{Key: hex, Val: &accountNode{*value, 
hashNode(value.Root[:]), true, nil}} - } - t.root = newnode + + var newnode *accountNode + if value.Root == EmptyRoot || value.Root == (common.Hash{}) { + newnode = &accountNode{*value, nil, true, nil} } else { - if value.Root == EmptyRoot || value.Root == (common.Hash{}) { - _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true, nil}) - } else { - _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true, nil}) - } + newnode = &accountNode{*value, hashNode(value.Root[:]), true, nil} + } + + if t.root == nil { + t.root = &shortNode{Key: hex, Val: newnode} + } else { + _, t.root = t.insert(t.root, hex, 0, newnode) } } @@ -311,6 +312,7 @@ func (t *Trie) UpdateAccountCode(key []byte, code codeNode) error { accNode.code = code + // t.insert will call the observer methods itself _, t.root = t.insert(t.root, hex, 0, accNode) return nil } @@ -467,11 +469,19 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b if updated { if !bytes.Equal(origAccN.CodeHash[:], vAccN.CodeHash[:]) { origAccN.code = nil + } else if vAccN.code != nil { + origAccN.code = vAccN.code } origAccN.Account.Copy(&vAccN.Account) origAccN.rootCorrect = false } newNode = origAccN + + if len(origAccN.code) > 0 { + t.observers.CodeNodeSizeChanged(key[:pos-1], uint(len(origAccN.code))) + } else { + t.observers.CodeNodeDeleted(key[:pos-1]) + } return } @@ -532,22 +542,21 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { - t.touchFunc(key[:pos], false) newNode = branch // current node leaves the generation, but new node branch joins it } else { // Otherwise, replace it with a short node leading up to the branch. 
- t.touchFunc(key[:pos+matchlen], false) n.Key = common.CopyBytes(key[pos : pos+matchlen]) n.Val = branch n.ref.len = 0 newNode = n } + t.observers.BranchNodeCreated(key[:pos+matchlen]) updated = true } return case *duoNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) i1, i2 := n.childrenIdx() switch key[pos] { case i1: @@ -585,7 +594,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b return case *fullNode: - t.touchFunc(key[:pos], false) + t.observers.BranchNodeTouched(key[:pos]) child := n.Children[key[pos]] if child == nil { t.evictNodeFromHashMap(n) @@ -635,9 +644,8 @@ func (t *Trie) getNode(hex []byte, doTouch bool) (node, node, bool) { } case *duoNode: if doTouch { - t.touchFunc(hex[:pos], false) + t.observers.BranchNodeTouched(hex[:pos]) } - i1, i2 := n.childrenIdx() switch hex[pos] { case i1: @@ -653,7 +661,7 @@ func (t *Trie) getNode(hex []byte, doTouch bool) (node, node, bool) { } case *fullNode: if doTouch { - t.touchFunc(hex[:pos], false) + t.observers.BranchNodeTouched(hex[:pos]) } child := n.Children[hex[pos]] if child == nil { @@ -733,7 +741,12 @@ func (t *Trie) touchAll(n node, hex []byte, del bool) { t.touchAll(n.Val, hexVal, del) } case *duoNode: - t.touchFunc(hex, del) + if del { + t.observers.BranchNodeDeleted(hex) + } else { + t.observers.BranchNodeCreated(hex) + t.observers.BranchNodeLoaded(hex) + } i1, i2 := n.childrenIdx() hex1 := make([]byte, len(hex)+1) copy(hex1, hex) @@ -744,13 +757,23 @@ func (t *Trie) touchAll(n node, hex []byte, del bool) { t.touchAll(n.child1, hex1, del) t.touchAll(n.child2, hex2, del) case *fullNode: - t.touchFunc(hex, del) + if del { + t.observers.BranchNodeDeleted(hex) + } else { + t.observers.BranchNodeCreated(hex) + t.observers.BranchNodeLoaded(hex) + } for i, child := range n.Children { if child != nil { t.touchAll(child, concat(hex, byte(i)), del) } } case *accountNode: + if del { + t.observers.CodeNodeDeleted(hex) + } else { + 
t.observers.CodeNodeTouched(hex) + } if n.storage != nil { t.touchAll(n.storage, hex, del) } @@ -853,15 +876,15 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo case i1: updated, nn = t.delete(n.child1, key, keyStart+1, preserveAccountNode) if !updated { - t.touchFunc(key[:keyStart], false) newNode = n + t.observers.BranchNodeTouched(key[:keyStart]) } else { t.evictNodeFromHashMap(n) if nn == nil { - t.touchFunc(key[:keyStart], true) newNode = t.convertToShortNode(n.child2, uint(i2)) + t.observers.BranchNodeDeleted(key[:keyStart]) } else { - t.touchFunc(key[:keyStart], false) + t.observers.BranchNodeTouched(key[:keyStart]) n.child1 = nn n.ref.len = 0 newNode = n @@ -870,22 +893,21 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo case i2: updated, nn = t.delete(n.child2, key, keyStart+1, preserveAccountNode) if !updated { - t.touchFunc(key[:keyStart], false) newNode = n + t.observers.BranchNodeTouched(key[:keyStart]) } else { t.evictNodeFromHashMap(n) if nn == nil { - t.touchFunc(key[:keyStart], true) newNode = t.convertToShortNode(n.child1, uint(i1)) + t.observers.BranchNodeDeleted(key[:keyStart]) } else { - t.touchFunc(key[:keyStart], false) + t.observers.BranchNodeTouched(key[:keyStart]) n.child2 = nn n.ref.len = 0 newNode = n } } default: - t.touchFunc(key[:keyStart], false) updated = false newNode = n } @@ -895,8 +917,8 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo child := n.Children[key[keyStart]] updated, nn = t.delete(child, key, keyStart+1, preserveAccountNode) if !updated { - t.touchFunc(key[:keyStart], false) newNode = n + t.observers.BranchNodeTouched(key[:keyStart]) } else { t.evictNodeFromHashMap(n) n.Children[key[keyStart]] = nn @@ -926,10 +948,10 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo } } if count == 1 { - t.touchFunc(key[:keyStart], true) newNode = t.convertToShortNode(n.Children[pos1], uint(pos1)) 
+ t.observers.BranchNodeDeleted(key[:keyStart]) } else if count == 2 { - t.touchFunc(key[:keyStart], false) + t.observers.BranchNodeTouched(key[:keyStart]) duo := &duoNode{} if pos1 == int(key[keyStart]) { duo.child1 = nn @@ -944,7 +966,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo duo.mask = (1 << uint(pos1)) | (uint32(1) << uint(pos2)) newNode = duo } else if count > 2 { - t.touchFunc(key[:keyStart], false) + t.observers.BranchNodeTouched(key[:keyStart]) // n still contains at least three values and cannot be reduced. n.ref.len = 0 newNode = n @@ -960,20 +982,23 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo case *accountNode: if keyStart >= len(key) || key[keyStart] == 16 { // Key terminates here + h := key[:keyStart] + if h[len(h)-1] == 16 { + h = h[:len(h)-1] + } if n.storage != nil { - h := key[:keyStart] - if h[len(h)-1] == 16 { - h = h[:len(h)-1] - } // Mark all the storage nodes as deleted t.touchAll(n.storage, h, true) } if preserveAccountNode { n.storage = nil + n.code = nil n.Root = EmptyRoot n.rootCorrect = true + t.observers.CodeNodeDeleted(h) return true, n } + return true, nil } updated, nn = t.delete(n.storage, key, keyStart, preserveAccountNode) @@ -1067,11 +1092,22 @@ func (t *Trie) DeepHash(keyPrefix []byte) (bool, common.Hash) { return true, accNode.Root } -func (t *Trie) unload(hex []byte) { +func (t *Trie) EvictNode(hex []byte) { + isCode := IsPointingToCode(hex) + if isCode { + hex = AddrHashFromCodeKey(hex) + } + nd, parent, ok := t.getNode(hex, false) if !ok { return } + if accNode, ok := parent.(*accountNode); isCode && ok { + // add special treatment to code nodes + accNode.code = nil + return + } + switch nd.(type) { case valueNode, hashNode: return @@ -1084,9 +1120,12 @@ func (t *Trie) unload(hex []byte) { var hn common.Hash if nd == nil { fmt.Printf("nd == nil, hex %x, parent node: %T\n", hex, parent) + return } copy(hn[:], nd.reference()) hnode := 
hashNode(hn[:]) + t.observers.WillUnloadBranchNode(hex, hn) + switch p := parent.(type) { case nil: t.root = hnode @@ -1108,57 +1147,12 @@ func (t *Trie) unload(hex []byte) { } } -func (t *Trie) CountPrunableNodes() int { - return t.countPrunableNodes(t.root, []byte{}, false) +func (t *Trie) TrieSize() int { + return calcSubtreeSize(t.root) } -func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { - switch n := nd.(type) { - case nil: - return 0 - case valueNode: - return 0 - case *accountNode: - return t.countPrunableNodes(n.storage, hex, print) - case hashNode: - return 0 - case *shortNode: - var hexVal []byte - if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf - h := n.Key - if h[len(h)-1] == 16 { - h = h[:len(h)-1] - } - hexVal = concat(hex, h...) - } - //@todo accountNode? - return t.countPrunableNodes(n.Val, hexVal, print) - case *duoNode: - i1, i2 := n.childrenIdx() - hex1 := make([]byte, len(hex)+1) - copy(hex1, hex) - hex1[len(hex)] = byte(i1) - hex2 := make([]byte, len(hex)+1) - copy(hex2, hex) - hex2[len(hex)] = byte(i2) - if print { - fmt.Printf("%T node: %x\n", n, hex) - } - return 1 + t.countPrunableNodes(n.child1, hex1, print) + t.countPrunableNodes(n.child2, hex2, print) - case *fullNode: - if print { - fmt.Printf("%T node: %x\n", n, hex) - } - count := 0 - for i, child := range n.Children { - if child != nil { - count += t.countPrunableNodes(child, concat(hex, byte(i)), print) - } - } - return 1 + count - default: - panic("") - } +func (t *Trie) NumberOfAccounts() int { + return calcSubtreeNodes(t.root) } func (t *Trie) hashRoot() (node, error) { diff --git a/trie/trie_eviction.go b/trie/trie_eviction.go new file mode 100644 index 000000000..6be7c3d3b --- /dev/null +++ b/trie/trie_eviction.go @@ -0,0 +1,343 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Pruning of the Merkle Patricia trees + +package trie + +import ( + "fmt" + "sort" + "strings" +) + +type AccountEvicter interface { + EvictNode([]byte) +} + +type generations struct { + blockNumToGeneration map[uint64]*generation + keyToBlockNum map[string]uint64 + oldestBlockNum uint64 + totalSize int64 +} + +func newGenerations() *generations { + return &generations{ + make(map[uint64]*generation), + make(map[string]uint64), + 0, + 0, + } +} + +func (gs *generations) add(blockNum uint64, key []byte, size uint) { + if _, ok := gs.keyToBlockNum[string(key)]; ok { + gs.updateSize(blockNum, key, size) + return + } + + generation, ok := gs.blockNumToGeneration[blockNum] + if !ok { + generation = newGeneration() + gs.blockNumToGeneration[blockNum] = generation + } + generation.add(key, size) + gs.keyToBlockNum[string(key)] = blockNum + if gs.oldestBlockNum > blockNum { + gs.oldestBlockNum = blockNum + } + gs.totalSize += int64(size) +} + +func (gs *generations) touch(blockNum uint64, key []byte) { + if len(gs.blockNumToGeneration) == 0 { + return + } + + oldBlockNum, ok := gs.keyToBlockNum[string(key)] + if !ok { + return + } + + oldGeneration, ok := gs.blockNumToGeneration[oldBlockNum] + if !ok { + return + } + + currentGeneration, ok := gs.blockNumToGeneration[blockNum] + if !ok { + 
currentGeneration = newGeneration() + gs.blockNumToGeneration[blockNum] = currentGeneration + } + + currentGeneration.grabFrom(key, oldGeneration) + + gs.keyToBlockNum[string(key)] = blockNum + + if gs.oldestBlockNum > blockNum { + gs.oldestBlockNum = blockNum + } + + if oldGeneration.empty() { + delete(gs.blockNumToGeneration, oldBlockNum) + } +} + +func (gs *generations) remove(key []byte) { + oldBlockNum, ok := gs.keyToBlockNum[string(key)] + if !ok { + return + } + generation, ok := gs.blockNumToGeneration[oldBlockNum] + if !ok { + return + } + sizeDiff := generation.remove(key) + gs.totalSize += sizeDiff + delete(gs.keyToBlockNum, string(key)) +} + +func (gs *generations) updateSize(blockNum uint64, key []byte, newSize uint) { + oldBlockNum, ok := gs.keyToBlockNum[string(key)] + if !ok { + gs.add(blockNum, key, newSize) + return + } + generation, ok := gs.blockNumToGeneration[oldBlockNum] + if !ok { + gs.add(blockNum, key, newSize) + return + } + + sizeDiff := generation.updateAccountSize(key, newSize) + gs.totalSize += sizeDiff + gs.touch(blockNum, key) +} + +// popKeysToEvict returns the keys to evict from the trie, +// also removing them from generations +func (gs *generations) popKeysToEvict(threshold uint64) []string { + keys := make([]string, 0) + for uint64(gs.totalSize) > threshold && len(gs.blockNumToGeneration) > 0 { + generation, ok := gs.blockNumToGeneration[gs.oldestBlockNum] + if !ok { + gs.oldestBlockNum++ + continue + } + + gs.totalSize -= generation.totalSize + if gs.totalSize < 0 { + gs.totalSize = 0 + } + keysToEvict := generation.keys() + keys = append(keys, keysToEvict...) 
+ for _, k := range keysToEvict { + delete(gs.keyToBlockNum, k) + } + delete(gs.blockNumToGeneration, gs.oldestBlockNum) + gs.oldestBlockNum++ + } + return keys +} + +type generation struct { + sizesByKey map[string]uint + totalSize int64 +} + +func newGeneration() *generation { + return &generation{ + make(map[string]uint), + 0, + } +} + +func (g *generation) empty() bool { + return len(g.sizesByKey) == 0 +} + +func (g *generation) grabFrom(key []byte, other *generation) { + if g == other { + return + } + + keyStr := string(key) + size, ok := other.sizesByKey[keyStr] + if !ok { + return + } + + g.sizesByKey[keyStr] = size + g.totalSize += int64(size) + other.totalSize -= int64(size) + + if other.totalSize < 0 { + other.totalSize = 0 + } + + delete(other.sizesByKey, keyStr) +} + +func (g *generation) add(key []byte, size uint) { + g.sizesByKey[string(key)] = size + g.totalSize += int64(size) +} + +func (g *generation) updateAccountSize(key []byte, size uint) int64 { + oldSize := g.sizesByKey[string(key)] + g.sizesByKey[string(key)] = size + diff := int64(size) - int64(oldSize) + g.totalSize += diff + return diff +} + +func (g *generation) remove(key []byte) int64 { + oldSize := g.sizesByKey[string(key)] + delete(g.sizesByKey, string(key)) + g.totalSize -= int64(oldSize) + if g.totalSize < 0 { + g.totalSize = 0 + } + + return -1 * int64(oldSize) +} + +func (g *generation) keys() []string { + keys := make([]string, len(g.sizesByKey)) + i := 0 + for k := range g.sizesByKey { + keys[i] = k + i++ + } + return keys +} + +type Eviction struct { + NoopObserver // make sure that we don't need to implement unnecessary observer methods + + blockNumber uint64 + + generations *generations +} + +func NewEviction() *Eviction { + return &Eviction{ + generations: newGenerations(), + } +} + +func (tp *Eviction) SetBlockNumber(blockNumber uint64) { + tp.blockNumber = blockNumber +} + +func (tp *Eviction) BlockNumber() uint64 { + return tp.blockNumber +} + +func (tp *Eviction) 
BranchNodeCreated(hex []byte) { + key := hex + tp.generations.add(tp.blockNumber, key, 1) +} + +func (tp *Eviction) BranchNodeDeleted(hex []byte) { + key := hex + tp.generations.remove(key) +} + +func (tp *Eviction) BranchNodeTouched(hex []byte) { + key := hex + tp.generations.touch(tp.blockNumber, key) +} + +func (tp *Eviction) CodeNodeCreated(hex []byte, size uint) { + key := hex + tp.generations.add(tp.blockNumber, CodeKeyFromAddrHash(key), size) +} + +func (tp *Eviction) CodeNodeDeleted(hex []byte) { + key := hex + tp.generations.remove(CodeKeyFromAddrHash(key)) +} + +func (tp *Eviction) CodeNodeTouched(hex []byte) { + key := hex + tp.generations.touch(tp.blockNumber, CodeKeyFromAddrHash(key)) +} + +func (tp *Eviction) CodeNodeSizeChanged(hex []byte, newSize uint) { + key := hex + tp.generations.updateSize(tp.blockNumber, CodeKeyFromAddrHash(key), newSize) +} + +func evictList(evicter AccountEvicter, hexes []string) bool { + var empty = false + sort.Strings(hexes) + + // from long to short -- a naive way to first clean up nodes and then accounts + // FIXME: optimize to avoid the same paths + for i := len(hexes) - 1; i >= 0; i-- { + evicter.EvictNode([]byte(hexes[i])) + } + return empty +} + +// EvictToFitSize evicts mininum number of generations necessary so that the total +// size of accounts left is fits into the provided threshold +func (tp *Eviction) EvictToFitSize( + evicter AccountEvicter, + threshold uint64, +) bool { + + if uint64(tp.generations.totalSize) <= threshold { + return false + } + + keys := tp.generations.popKeysToEvict(threshold) + + return evictList(evicter, keys) +} + +func (tp *Eviction) TotalSize() uint64 { + return uint64(tp.generations.totalSize) +} + +func (tp *Eviction) NumberOf() uint64 { + total := uint64(0) + for _, gen := range tp.generations.blockNumToGeneration { + if gen == nil { + continue + } + total += uint64(len(gen.sizesByKey)) + } + return total +} + +func (tp *Eviction) DebugDump() string { + var sb strings.Builder + + 
for block, gen := range tp.generations.blockNumToGeneration { + if gen.empty() { + continue + } + sb.WriteString(fmt.Sprintf("Block: %v\n", block)) + for key, size := range gen.sizesByKey { + sb.WriteString(fmt.Sprintf(" %x->%v\n", key, size)) + } + } + + return sb.String() +} diff --git a/trie/trie_eviction_test.go b/trie/trie_eviction_test.go new file mode 100644 index 000000000..0eb4e5754 --- /dev/null +++ b/trie/trie_eviction_test.go @@ -0,0 +1,394 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty off +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +// Pruning of the Merkle Patricia trees + +package trie + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockAccountEvicter struct { + keys [][]byte +} + +func newMockAccountEvicter() *mockAccountEvicter { + return &mockAccountEvicter{make([][]byte, 0)} +} + +func (m *mockAccountEvicter) EvictNode(key []byte) { + m.keys = append(m.keys, key) +} + +func TestEvictionBasicOperations(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + key := []byte{0x01, 0x01, 0x01, 0x01} + hex := keybytesToHex(key) + eviction.CodeNodeCreated(hex, 1024) + + assert.Equal(t, 1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 1024, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + // grow + eviction.CodeNodeSizeChanged(hex, 2048) + + assert.Equal(t, 2048, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 2048, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + // shrink + eviction.CodeNodeSizeChanged(hex, 100) + + assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + // shrink + eviction.CodeNodeDeleted(hex) + + assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") +} + +func TestEvictionPartialSingleGen(t *testing.T) { + 
eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 100kb or accounts + for i := 0; i < 100; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.BranchNodeCreated(keybytesToHex(key)) + } + + assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 99) + + assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 2, int(eviction.generations.oldestBlockNum), "should register block num") + assert.Equal(t, 100, len(mock.keys), "should evict all 100 accounts") +} + +func TestEvictionFullSingleGen(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 100kb or accounts + for i := 0; i < 100; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.BranchNodeCreated(keybytesToHex(key)) + } + + assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 0) + + assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 2, int(eviction.generations.oldestBlockNum), 
"should register block num") + assert.Equal(t, 100, len(mock.keys), "should evict all 100 accounts") +} + +func TestEvictionNoNeedSingleGen(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 100kb or accounts + for i := 0; i < 100; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.BranchNodeCreated(keybytesToHex(key)) + } + + assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 100) + + assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen") + + assert.Equal(t, 0, len(mock.keys), "should evict all 100 accounts") +} + +func TestEvictionNoNeedMultipleGen(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(2) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x02, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(4) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x03, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(5) + + // create 10kb 
or accounts + for i := 0; i < 10; i++ { + key := []byte{0x04, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(7) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x05, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + // 50 kb total + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + for _, i := range []uint64{1, 2, 4, 5, 7} { + assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen") + } + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 50*1024) + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + + assert.Equal(t, 0, len(mock.keys), "should not evict anything") +} + +func TestEvictionPartialMultipleGen(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(2) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x02, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(4) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x03, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(5) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x04, 0x01, 
0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(7) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x05, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + // 50 kb total + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + for _, i := range []uint64{1, 2, 4, 5, 7} { + assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen") + } + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 20*1024) + + assert.Equal(t, 20*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 2, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 5, int(eviction.generations.oldestBlockNum), "should register block num") + for _, i := range []uint64{5, 7} { + assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen") + } + assert.Equal(t, 30, len(mock.keys), "should evict only 3 generations") +} + +func TestEvictionFullMultipleGen(t *testing.T) { + eviction := NewEviction() + eviction.SetBlockNumber(1) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(2) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x02, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(4) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x03, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + 
eviction.SetBlockNumber(5) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x04, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + eviction.SetBlockNumber(7) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x05, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + // 50 kb total + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num") + for _, i := range []uint64{1, 2, 4, 5, 7} { + assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen") + } + + mock := newMockAccountEvicter() + + eviction.EvictToFitSize(mock, 0) + + assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 8, int(eviction.generations.oldestBlockNum), "should register block num") + + assert.Equal(t, 50, len(mock.keys), "should evict only 3 generations") + +} + +func TestEvictionMoveBetweenGen(t *testing.T) { + eviction := NewEviction() + + eviction.SetBlockNumber(2) + + for i := 0; i < 2; i++ { + key := []byte{0x01, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 10*1024) + } + + eviction.SetBlockNumber(4) + + // create 10kb or accounts + for i := 0; i < 1; i++ { + key := []byte{0x04, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 20*1024) + } + + eviction.SetBlockNumber(5) + + // create 10kb or accounts + for i := 0; i < 10; i++ { + key := []byte{0x05, 0x01, 0x01, byte(i)} + eviction.CodeNodeCreated(keybytesToHex(key), 1024) + } + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 
3, len(eviction.generations.blockNumToGeneration), "should register generation") + + eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00})) + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 3, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration[2].keys()), "should move one acc") + assert.Equal(t, 11, len(eviction.generations.blockNumToGeneration[5].keys()), "should move one acc to gen 5") + assert.Equal(t, CodeKeyFromAddrHash(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x01})), []byte(eviction.generations.blockNumToGeneration[2].keys()[0]), "should move one acc") + + // move the acc again to the new block! + eviction.SetBlockNumber(10) + eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00})) + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 4, len(eviction.generations.blockNumToGeneration), "should register generation") + assert.Equal(t, 10, len(eviction.generations.blockNumToGeneration[5].keys()), "should move one acc from gen 5, again 10") + assert.Equal(t, CodeKeyFromAddrHash(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00})), []byte(eviction.generations.blockNumToGeneration[10].keys()[0]), "should move one acc") + + // move the last acc from the gen + eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x01})) + + assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts") + assert.Equal(t, 3, len(eviction.generations.blockNumToGeneration), "should register generation") + _, found := eviction.generations.blockNumToGeneration[2] + assert.False(t, found, "2nd generation is empty and should be removed") + + assert.Equal(t, 2, len(eviction.generations.blockNumToGeneration[10].keys()), "should move one acc") +} diff --git a/trie/trie_observers.go b/trie/trie_observers.go new file mode 100644 index 
000000000..4aa23cdc3 --- /dev/null +++ b/trie/trie_observers.go @@ -0,0 +1,104 @@ +package trie + +import "github.com/ledgerwatch/turbo-geth/common" + +type Observer interface { + BranchNodeCreated(hex []byte) + BranchNodeDeleted(hex []byte) + BranchNodeTouched(hex []byte) + + CodeNodeCreated(hex []byte, size uint) + CodeNodeDeleted(hex []byte) + CodeNodeTouched(hex []byte) + CodeNodeSizeChanged(hex []byte, newSize uint) + + WillUnloadBranchNode(key []byte, nodeHash common.Hash) + BranchNodeLoaded(prefixAsNibbles []byte) +} + +var _ Observer = (*NoopObserver)(nil) // make sure that NoopTrieObserver is compliant + +// NoopTrieObserver might be used to emulate optional methods in observers +type NoopObserver struct{} + +func (*NoopObserver) BranchNodeCreated(_ []byte) {} +func (*NoopObserver) BranchNodeDeleted(_ []byte) {} +func (*NoopObserver) BranchNodeTouched(_ []byte) {} +func (*NoopObserver) CodeNodeCreated(_ []byte, _ uint) {} +func (*NoopObserver) CodeNodeDeleted(_ []byte) {} +func (*NoopObserver) CodeNodeTouched(_ []byte) {} +func (*NoopObserver) CodeNodeSizeChanged(_ []byte, _ uint) {} +func (*NoopObserver) WillUnloadBranchNode(_ []byte, _ common.Hash) {} +func (*NoopObserver) BranchNodeLoaded(_ []byte) {} + +// TrieObserverMux multiplies the callback methods and sends them to +// all it's children. 
+type ObserverMux struct { + children []Observer +} + +func NewTrieObserverMux() *ObserverMux { + return &ObserverMux{make([]Observer, 0)} +} + +func (mux *ObserverMux) AddChild(child Observer) { + if child == nil { + return + } + + mux.children = append(mux.children, child) +} + +func (mux *ObserverMux) BranchNodeCreated(hex []byte) { + for _, child := range mux.children { + child.BranchNodeCreated(hex) + } +} + +func (mux *ObserverMux) BranchNodeDeleted(hex []byte) { + for _, child := range mux.children { + child.BranchNodeDeleted(hex) + } +} + +func (mux *ObserverMux) BranchNodeTouched(hex []byte) { + for _, child := range mux.children { + child.BranchNodeTouched(hex) + } +} + +func (mux *ObserverMux) CodeNodeCreated(hex []byte, size uint) { + for _, child := range mux.children { + child.CodeNodeCreated(hex, size) + } +} + +func (mux *ObserverMux) CodeNodeDeleted(hex []byte) { + for _, child := range mux.children { + child.CodeNodeDeleted(hex) + } +} + +func (mux *ObserverMux) CodeNodeTouched(hex []byte) { + for _, child := range mux.children { + child.CodeNodeTouched(hex) + } +} + +func (mux *ObserverMux) CodeNodeSizeChanged(hex []byte, newSize uint) { + for _, child := range mux.children { + child.CodeNodeSizeChanged(hex, newSize) + } +} + +func (mux *ObserverMux) WillUnloadBranchNode(key []byte, nodeHash common.Hash) { + for _, child := range mux.children { + child.WillUnloadBranchNode(key, nodeHash) + } +} + +func (mux *ObserverMux) BranchNodeLoaded(prefixAsNibbles []byte) { + for _, child := range mux.children { + child.BranchNodeLoaded(prefixAsNibbles) + } +} diff --git a/trie/trie_observers_test.go b/trie/trie_observers_test.go new file mode 100644 index 000000000..3f4eeb6d5 --- /dev/null +++ b/trie/trie_observers_test.go @@ -0,0 +1,454 @@ +package trie + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/ledgerwatch/turbo-geth/common" + "github.com/ledgerwatch/turbo-geth/common/dbutils" + "github.com/ledgerwatch/turbo-geth/core/types/accounts" + 
"github.com/ledgerwatch/turbo-geth/crypto" + "github.com/stretchr/testify/assert" +) + +func genAccount() *accounts.Account { + acc := &accounts.Account{} + acc.Nonce = 123 + return acc +} + +func genNKeys(n int) [][]byte { + result := make([][]byte, n) + for i := range result { + result[i] = crypto.Keccak256([]byte{0x0, 0x0, 0x0, byte(i % 256), byte(i / 256)}) + } + return result +} + +func genByteArrayOfLen(n int) []byte { + result := make([]byte, n) + for i := range result { + result[i] = byte(rand.Intn(255)) + } + return result +} + +type partialObserver struct { + NoopObserver + callbackCalled bool +} + +func (o *partialObserver) BranchNodeDeleted(_ []byte) { + o.callbackCalled = true +} + +type mockObserver struct { + createdNodes map[string]uint + deletedNodes map[string]struct{} + touchedNodes map[string]int + + unloadedNodes map[string]int + unloadedNodeHashes map[string][]byte + reloadedNodes map[string]int +} + +func newMockObserver() *mockObserver { + return &mockObserver{ + createdNodes: make(map[string]uint), + deletedNodes: make(map[string]struct{}), + touchedNodes: make(map[string]int), + unloadedNodes: make(map[string]int), + unloadedNodeHashes: make(map[string][]byte), + reloadedNodes: make(map[string]int), + } +} + +func (m *mockObserver) BranchNodeCreated(hex []byte) { + m.createdNodes[common.Bytes2Hex(hex)] = 1 +} + +func (m *mockObserver) CodeNodeCreated(hex []byte, size uint) { + m.createdNodes[common.Bytes2Hex(hex)] = size +} + +func (m *mockObserver) BranchNodeDeleted(hex []byte) { + m.deletedNodes[common.Bytes2Hex(hex)] = struct{}{} +} + +func (m *mockObserver) CodeNodeDeleted(hex []byte) { + m.deletedNodes[common.Bytes2Hex(hex)] = struct{}{} +} + +func (m *mockObserver) BranchNodeTouched(hex []byte) { + value := m.touchedNodes[common.Bytes2Hex(hex)] + value++ + m.touchedNodes[common.Bytes2Hex(hex)] = value +} + +func (m *mockObserver) CodeNodeTouched(hex []byte) { + value := m.touchedNodes[common.Bytes2Hex(hex)] + value++ + 
m.touchedNodes[common.Bytes2Hex(hex)] = value +} + +func (m *mockObserver) CodeNodeSizeChanged(hex []byte, newSize uint) { + m.createdNodes[common.Bytes2Hex(hex)] = newSize +} + +func (m *mockObserver) WillUnloadBranchNode(hex []byte, hash common.Hash) { + dictKey := common.Bytes2Hex(hex) + value := m.unloadedNodes[dictKey] + value++ + m.unloadedNodes[dictKey] = value + m.unloadedNodeHashes[dictKey] = common.CopyBytes(hash[:]) +} + +func (m *mockObserver) BranchNodeLoaded(hex []byte) { + dictKey := common.Bytes2Hex(hex) + value := m.reloadedNodes[dictKey] + value++ + m.reloadedNodes[dictKey] = value +} + +func TestObserversBranchNodesCreateDelete(t *testing.T) { + trie := newEmpty() + + observer := newMockObserver() + + trie.AddObserver(observer) + + keys := genNKeys(100) + + for _, key := range keys { + acc := genAccount() + code := genByteArrayOfLen(100) + codeHash := crypto.Keccak256(code) + acc.CodeHash = common.BytesToHash(codeHash) + trie.UpdateAccount(key, acc) + trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck + } + + expectedNodes := calcSubtreeNodes(trie.root) + + assert.True(t, expectedNodes >= 100, "should register all code nodes") + assert.Equal(t, expectedNodes, len(observer.createdNodes), "should register all") + + for _, key := range keys { + trie.Delete(key) + } + + assert.Equal(t, expectedNodes, len(observer.deletedNodes), "should register all") +} + +func TestObserverCodeSizeChanged(t *testing.T) { + rand.Seed(9999) + + trie := newEmpty() + + observer := newMockObserver() + + trie.AddObserver(observer) + + keys := genNKeys(10) + + for _, key := range keys { + acc := genAccount() + code := genByteArrayOfLen(100) + codeHash := crypto.Keccak256(code) + acc.CodeHash = common.BytesToHash(codeHash) + trie.UpdateAccount(key, acc) + trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck + + hex := keybytesToHex(key) + hex = hex[:len(hex)-1] + newSize, ok := observer.createdNodes[common.Bytes2Hex(hex)] + assert.True(t, ok, "account 
should be registed as created") + assert.Equal(t, 100, int(newSize), "account size should increase when the account code grows") + + code2 := genByteArrayOfLen(50) + codeHash2 := crypto.Keccak256(code2) + acc.CodeHash = common.BytesToHash(codeHash2) + trie.UpdateAccount(key, acc) + trie.UpdateAccountCode(key, codeNode(code2)) //nolint:errcheck + + newSize2, ok := observer.createdNodes[common.Bytes2Hex(hex)] + assert.True(t, ok, "account should be registed as created") + assert.Equal(t, -50, int(newSize2)-int(newSize), "account size should decrease when the account code shrinks") + } +} + +func TestObserverUnloadStorageNodes(t *testing.T) { + rand.Seed(9999) + + trie := newEmpty() + + observer := newMockObserver() + + trie.AddObserver(observer) + + key := genNKeys(1)[0] + + storageKeys := genNKeys(10) + // group all storage keys into a single fullNode + for i := range storageKeys { + storageKeys[i][0] = byte(0) + storageKeys[i][1] = byte(i) + } + + fullKeys := make([][]byte, len(storageKeys)) + for i, storageKey := range storageKeys { + fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey)) + fullKeys[i] = fullKey + } + + acc := genAccount() + trie.UpdateAccount(key, acc) + + for i, fullKey := range fullKeys { + trie.Update(fullKey, []byte(fmt.Sprintf("test-value-%d", i))) + } + + rootHash := trie.Hash() + + // adding nodes doesn't add anything + assert.Equal(t, 0, len(observer.reloadedNodes), "adding nodes doesn't add anything") + + // unloading nodes adds to the list + + hex := keybytesToHex(key) + hex = hex[:len(hex)-1] + hex = append(hex, 0x0, 0x0, 0x0) + trie.EvictNode(hex) + + newRootHash := trie.Hash() + assert.Equal(t, rootHash, newRootHash, "root hash shouldn't change") + + assert.Equal(t, 1, len(observer.unloadedNodes), "should unload one full node") + + hex = keybytesToHex(key) + + storageKey := fmt.Sprintf("%s000000", common.Bytes2Hex(hex[:len(hex)-1])) + assert.Equal(t, 1, observer.unloadedNodes[storageKey], 
"should unload structure nodes") + + accNode, ok := trie.getAccount(trie.root, hex, 0) + assert.True(t, ok, "account should be found") + + sn, ok := accNode.storage.(*shortNode) + assert.True(t, ok, "storage should be the shortnode contaning hash") + + _, ok = sn.Val.(hashNode) + assert.True(t, ok, "storage should be the shortnode contaning hash") +} + +func TestObserverLoadNodes(t *testing.T) { + rand.Seed(9999) + + subtrie := newEmpty() + + observer := newMockObserver() + + // this test needs a specific trie structure + // ( full ) + // (full) (duo) (duo) + // (short)(short)(short) (short)(short) (short)(short) + // (acc1) (acc2) (acc3) (acc4) (acc5) (acc6) (acc7) + // + // to ensure this structure we override prefixes of + // random account keys with the follwing paths + + prefixes := [][]byte{ + {0x00, 0x00}, //acc1 + {0x00, 0x02}, //acc2 + {0x00, 0x05}, //acc3 + {0x02, 0x02}, //acc4 + {0x02, 0x05}, //acc5 + {0x0A, 0x00}, //acc6 + {0x0A, 0x03}, //acc7 + } + + keys := genNKeys(7) + for i := range keys { + copy(keys[i][:2], prefixes[i]) + } + + storageKeys := genNKeys(10) + // group all storage keys into a single fullNode + for i := range storageKeys { + storageKeys[i][0] = byte(0) + storageKeys[i][1] = byte(i) + } + + for _, key := range keys { + acc := genAccount() + subtrie.UpdateAccount(key, acc) + + for i, storageKey := range storageKeys { + fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey)) + subtrie.Update(fullKey, []byte(fmt.Sprintf("test-value-%d", i))) + } + } + + trie := newEmpty() + trie.AddObserver(observer) + + trie.hook([]byte{}, subtrie.root) + + // fullNode + assert.Equal(t, 1, observer.reloadedNodes["000000"], "should reload structure nodes") + // duoNode + assert.Equal(t, 1, observer.reloadedNodes["000200"], "should reload structure nodes") + // duoNode + assert.Equal(t, 1, observer.reloadedNodes["000a00"], "should reload structure nodes") + // root + assert.Equal(t, 1, 
observer.reloadedNodes["00"], "should reload structure nodes") + + // check storages (should have a single fullNode per account) + for _, key := range keys { + hex := keybytesToHex(key) + storageKey := fmt.Sprintf("%s000000", common.Bytes2Hex(hex[:len(hex)-1])) + assert.Equal(t, 1, observer.reloadedNodes[storageKey], "should reload structure nodes") + } +} + +func TestObserverTouches(t *testing.T) { + rand.Seed(9999) + + trie := newEmpty() + + observer := newMockObserver() + + trie.AddObserver(observer) + + keys := genNKeys(3) + for i := range keys { + keys[i][0] = 0x00 // 3 belong to the same branch + } + + var acc *accounts.Account + for _, key := range keys { + // creation touches the account + acc = genAccount() + trie.UpdateAccount(key, acc) + } + + key := keys[0] + branchNodeHex := "0000" + + assert.Equal(t, uint(1), observer.createdNodes[branchNodeHex], "node is created") + assert.Equal(t, 1, observer.touchedNodes[branchNodeHex]) + + // updating touches the account + code := genByteArrayOfLen(100) + codeHash := crypto.Keccak256(code) + acc.CodeHash = common.BytesToHash(codeHash) + trie.UpdateAccount(key, acc) + assert.Equal(t, 2, observer.touchedNodes[branchNodeHex]) + + // updating code touches the account + trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck + // 2 touches -- retrieve + updae + assert.Equal(t, 4, observer.touchedNodes[branchNodeHex]) + + // changing storage touches the account + storageKey := genNKeys(1)[0] + fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey)) + + trie.Update(fullKey, []byte("value-1")) + assert.Equal(t, 5, observer.touchedNodes[branchNodeHex]) + + trie.Update(fullKey, []byte("value-2")) + assert.Equal(t, 6, observer.touchedNodes[branchNodeHex]) + + // getting storage touches the account + _, ok := trie.Get(fullKey) + assert.True(t, ok, "should be able to receive storage") + assert.Equal(t, 7, observer.touchedNodes[branchNodeHex]) + + // deleting storage touches 
the account + trie.Delete(fullKey) + assert.Equal(t, 8, observer.touchedNodes[branchNodeHex]) + + // getting code touches the account + _, ok = trie.GetAccountCode(key) + assert.True(t, ok, "should be able to receive code") + assert.Equal(t, 9, observer.touchedNodes[branchNodeHex]) + + // getting account touches the account + _, ok = trie.GetAccount(key) + assert.True(t, ok, "should be able to receive account") + assert.Equal(t, 10, observer.touchedNodes[branchNodeHex]) +} + +func TestObserverMux(t *testing.T) { + trie := newEmpty() + + observer1 := newMockObserver() + observer2 := newMockObserver() + mux := NewTrieObserverMux() + mux.AddChild(observer1) + mux.AddChild(observer2) + + trie.AddObserver(mux) + + keys := genNKeys(100) + for _, key := range keys { + acc := genAccount() + trie.UpdateAccount(key, acc) + + code := genByteArrayOfLen(100) + codeHash := crypto.Keccak256(code) + acc.CodeHash = common.BytesToHash(codeHash) + trie.UpdateAccount(key, acc) + trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck + + _, ok := trie.GetAccount(key) + assert.True(t, ok, "acount should be found") + + } + + trie.Hash() + + for i, key := range keys { + if i < 80 { + trie.Delete(key) + } else { + hex := keybytesToHex(key) + hex = hex[:len(hex)-1] + trie.EvictNode(CodeKeyFromAddrHash(hex)) + } + } + + assert.Equal(t, observer1.createdNodes, observer2.createdNodes, "should propagate created events") + assert.Equal(t, observer1.deletedNodes, observer2.deletedNodes, "should propagate deleted events") + assert.Equal(t, observer1.touchedNodes, observer2.touchedNodes, "should propagate touched events") + + assert.Equal(t, observer1.unloadedNodes, observer2.unloadedNodes, "should propagage unloads") + assert.Equal(t, observer1.unloadedNodeHashes, observer2.unloadedNodeHashes, "should propagage unloads") + assert.Equal(t, observer1.reloadedNodes, observer2.reloadedNodes, "should propagage reloads") +} + +func TestObserverPartial(t *testing.T) { + trie := newEmpty() + + 
observer := &partialObserver{callbackCalled: false} // only implements `BranchNodeDeleted` + trie.AddObserver(observer) + + keys := genNKeys(2) + for _, key := range keys { + acc := genAccount() + trie.UpdateAccount(key, acc) + + code := genByteArrayOfLen(100) + codeHash := crypto.Keccak256(code) + acc.CodeHash = common.BytesToHash(codeHash) + trie.UpdateAccount(key, acc) + trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck + + } + for _, key := range keys { + trie.Delete(key) + } + + assert.True(t, observer.callbackCalled, "should be called") +} diff --git a/trie/trie_pruning.go b/trie/trie_pruning.go deleted file mode 100644 index f413debef..000000000 --- a/trie/trie_pruning.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty off -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
- -// Pruning of the Merkle Patricia trees - -package trie - -import ( - "fmt" - "sort" - "strings" - - "github.com/ledgerwatch/turbo-geth/common" - "github.com/ledgerwatch/turbo-geth/common/pool" -) - -type TriePruning struct { - accountTimestamps map[string]uint64 - - // Maps timestamp (uint64) to set of prefixes of nodes (string) - accounts map[uint64]map[string]struct{} - - // For each timestamp, keeps number of branch nodes belonging to it - generationCounts map[uint64]int - - // Keeps total number of branch nodes - nodeCount int - - // The oldest timestamp of all branch nodes - oldestGeneration uint64 - - // Current timestamp - blockNr uint64 - - createNodeFunc func(prefixAsNibbles []byte) - unloadNodeFunc func(prefix []byte, nodeHash []byte) // called when fullNode or dualNode unloaded -} - -func NewTriePruning(oldestGeneration uint64) *TriePruning { - return &TriePruning{ - oldestGeneration: oldestGeneration, - blockNr: oldestGeneration, - accountTimestamps: make(map[string]uint64), - accounts: make(map[uint64]map[string]struct{}), - generationCounts: make(map[uint64]int), - createNodeFunc: func([]byte) {}, - } -} - -func (tp *TriePruning) SetBlockNr(blockNr uint64) { - tp.blockNr = blockNr -} - -func (tp *TriePruning) BlockNr() uint64 { - return tp.blockNr -} - -func (tp *TriePruning) SetCreateNodeFunc(f func(prefixAsNibbles []byte)) { - tp.createNodeFunc = f -} - -func (tp *TriePruning) SetUnloadNodeFunc(f func(prefix []byte, nodeHash []byte)) { - tp.unloadNodeFunc = f -} - -// Updates a node to the current timestamp -// contract is effectively address of the smart contract -// hex is the prefix of the key -// parent is the node that needs to be modified to unload the touched node -// exists is true when the node existed before, and false if it is a new one -// prevTimestamp is the timestamp the node current has -func (tp *TriePruning) touch(hexS string, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) { - //fmt.Printf("TouchFrom %x, 
exists: %t, prevTimestamp %d, del %t, newTimestamp %d\n", hexS, exists, prevTimestamp, del, newTimestamp) - if exists && !del && prevTimestamp == newTimestamp { - return - } - if !del { - var newMap map[string]struct{} - if m, ok := tp.accounts[newTimestamp]; ok { - newMap = m - } else { - newMap = make(map[string]struct{}) - tp.accounts[newTimestamp] = newMap - } - - newMap[hexS] = struct{}{} - } - if exists { - if m, ok := tp.accounts[prevTimestamp]; ok { - delete(m, hexS) - if len(m) == 0 { - delete(tp.accounts, prevTimestamp) - } - } - } - // Update generation count - if !del { - tp.generationCounts[newTimestamp]++ - tp.nodeCount++ - } - if exists { - tp.generationCounts[prevTimestamp]-- - if tp.generationCounts[prevTimestamp] == 0 { - delete(tp.generationCounts, prevTimestamp) - } - tp.nodeCount-- - } -} - -func (tp *TriePruning) Timestamp(hex []byte) uint64 { - ts := tp.accountTimestamps[string(hex)] - return ts -} - -// Updates a node to the current timestamp -// contract is effectively address of the smart contract -// hex is the prefix of the key -// parent is the node that needs to be modified to unload the touched node -func (tp *TriePruning) Touch(hex []byte, del bool) { - var exists = false - var prevTimestamp uint64 - hexS := string(common.CopyBytes(hex)) - - if m, ok := tp.accountTimestamps[hexS]; ok { - prevTimestamp = m - exists = true - if del { - delete(tp.accountTimestamps, hexS) - } - } - if !del { - tp.accountTimestamps[hexS] = tp.blockNr - } - if !exists { - tp.createNodeFunc([]byte(hexS)) - } - - tp.touch(hexS, exists, prevTimestamp, del, tp.blockNr) -} - -func pruneMap(t *Trie, m map[string]struct{}) bool { - hexes := make([]string, len(m)) - i := 0 - for hexS := range m { - hexes[i] = hexS - i++ - } - var empty = false - sort.Strings(hexes) - - for i, hex := range hexes { - if i == 0 || len(hex) == 0 || !strings.HasPrefix(hex, hexes[i-1]) { // If the parent nodes pruned, there is no need to prune descendants - t.unload([]byte(hex)) - if 
len(hex) == 0 { - empty = true - } - } - } - return empty -} - -// Prunes all nodes that are older than given timestamp -func (tp *TriePruning) PruneToTimestamp( - accountsTrie *Trie, - targetTimestamp uint64, -) { - // Remove (unload) nodes from storage tries and account trie - aggregateAccounts := make(map[string]struct{}) - for gen := tp.oldestGeneration; gen < targetTimestamp; gen++ { - tp.nodeCount -= tp.generationCounts[gen] - if m, ok := tp.accounts[gen]; ok { - for hexS := range m { - aggregateAccounts[hexS] = struct{}{} - } - } - delete(tp.accounts, gen) - } - - // intermediate hashes - key := pool.GetBuffer(64) - defer pool.PutBuffer(key) - for prefix := range aggregateAccounts { - if len(prefix) == 0 || len(prefix)%2 == 1 { - continue - } - - nd, parent, ok := accountsTrie.getNode([]byte(prefix), false) - if !ok { - continue - } - switch nd.(type) { - case *duoNode, *fullNode: - // will work only with these types of nodes - default: - continue - } - switch parent.(type) { // without this condition - doesn't work. Need investigate why. 
- case *duoNode, *fullNode: - // will work only with these types of nodes - CompressNibbles([]byte(prefix), &key.B) - tp.unloadNodeFunc(key.B, nd.reference()) - default: - continue - } - } - - pruneMap(accountsTrie, aggregateAccounts) - - // Remove fom the timestamp structure - for hexS := range aggregateAccounts { - delete(tp.accountTimestamps, hexS) - } - tp.oldestGeneration = targetTimestamp -} - -// Prunes mininum number of generations necessary so that the total -// number of prunable nodes is at most `targetNodeCount` -func (tp *TriePruning) PruneTo( - accountsTrie *Trie, - targetNodeCount int, -) bool { - if tp.nodeCount <= targetNodeCount { - return false - } - excess := tp.nodeCount - targetNodeCount - prunable := 0 - pruneGeneration := tp.oldestGeneration - for prunable < excess && pruneGeneration < tp.blockNr { - prunable += tp.generationCounts[pruneGeneration] - pruneGeneration++ - } - //fmt.Printf("Will prune to generation %d, nodes to prune: %d, excess %d\n", pruneGeneration, prunable, excess) - tp.PruneToTimestamp(accountsTrie, pruneGeneration) - return true -} - -func (tp *TriePruning) NodeCount() int { - return tp.nodeCount -} - -func (tp *TriePruning) GenCounts() map[uint64]int { - return tp.generationCounts -} - -// DebugDump is used in the tests to ensure that there are no prunable entries (in such case, this function returns empty string) -func (tp *TriePruning) DebugDump() string { - var sb strings.Builder - for timestamp, m := range tp.accounts { - for account := range m { - sb.WriteString(fmt.Sprintf("%d %x\n", timestamp, account)) - } - } - return sb.String() -} diff --git a/trie/trie_pruning_test.go b/trie/trie_pruning_test.go deleted file mode 100644 index a9690892e..000000000 --- a/trie/trie_pruning_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. 
-// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty off -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -// Pruning of the Merkle Patricia trees - -package trie - -import ( - "encoding/binary" - "fmt" - "testing" - - "github.com/ledgerwatch/turbo-geth/common" -) - -func TestOnePerTimestamp(t *testing.T) { - tp := NewTriePruning(0) - tr := New(common.Hash{}) - tr.SetTouchFunc(tp.Touch) - var key [4]byte - value := []byte("V") - var timestamp uint64 = 0 - for n := uint32(0); n < uint32(100); n++ { - tp.SetBlockNr(timestamp) - binary.BigEndian.PutUint32(key[:], n) - tr.Update(key[:], value) // Each key is added within a new generation - timestamp++ - } - for n := uint32(50); n < uint32(60); n++ { - tp.SetBlockNr(timestamp) - binary.BigEndian.PutUint32(key[:], n) - tr.Delete(key[:]) // Each key is added within a new generation - timestamp++ - } - for n := uint32(30); n < uint32(59); n++ { - tp.SetBlockNr(timestamp) - binary.BigEndian.PutUint32(key[:], n) - tr.Get(key[:]) // Each key is added within a new generation - timestamp++ - } - prunableNodes := tr.CountPrunableNodes() - fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount()) - if b := tp.PruneTo(tr, 4); !b { - t.Fatal("Not pruned") - } - prunableNodes = tr.CountPrunableNodes() - fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount()) -} diff --git a/trie/trie_test.go 
b/trie/trie_test.go index 5066b925a..780ecd9e7 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -887,6 +887,6 @@ func TestCodeNodeUpdateAccountNoChangeCodeHash(t *testing.T) { trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) - assert.Equal(t, value, codeValue1, "the value should NOT reset after account's non codehash had changed") + assert.Equal(t, codeValue1, value, "the value should NOT reset after account's non codehash had changed") assert.True(t, gotValue, "should indicate that the code is still in the cache") }