diff --git a/Makefile b/Makefile
index e5a2ab835..7ad5744bb 100644
--- a/Makefile
+++ b/Makefile
@@ -95,7 +95,7 @@ test: semantics/z3/build/libz3.a all
lint: lintci
-lintci:
+lintci: semantics/z3/build/libz3.a all
@echo "--> Running linter for code diff versus commit $(LATEST_COMMIT)"
@./build/bin/golangci-lint run \
--new-from-rev=$(LATEST_COMMIT) \
@@ -116,7 +116,7 @@ lintci:
lintci-deps:
rm -f ./build/bin/golangci-lint
- curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b ./build/bin v1.21.0
+ curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b ./build/bin v1.23.8
clean:
env GO111MODULE=on go clean -cache
@@ -231,4 +231,4 @@ bindings:
go generate ./tests/contracts/
simulator-genesis:
- go run ./cmd/tester genesis > ./cmd/tester/simulator_genesis.json
\ No newline at end of file
+ go run ./cmd/tester genesis > ./cmd/tester/simulator_genesis.json
diff --git a/cmd/hack/hack.go b/cmd/hack/hack.go
index 584dc77fc..e9d8a273f 100644
--- a/cmd/hack/hack.go
+++ b/cmd/hack/hack.go
@@ -617,7 +617,7 @@ func trieChart() {
}
func execToBlock(chaindata string, block uint64, fromScratch bool) {
- state.MaxTrieCacheGen = 32
+ state.MaxTrieCacheSize = 32
blockDb, err := ethdb.NewBoltDatabase(chaindata)
check(err)
bcb, err := core.NewBlockChain(blockDb, nil, params.MainnetChainConfig, ethash.NewFaker(), vm.Config{}, nil)
@@ -1938,50 +1938,50 @@ func validateTxLookups2(db *ethdb.BoltDatabase, startBlock uint64, interruptCh c
func indexSize(chaindata string) {
db, err := ethdb.NewBoltDatabase(chaindata)
check(err)
- fStorage,err:=os.Create("index_sizes_storage.csv")
+ fStorage, err := os.Create("index_sizes_storage.csv")
check(err)
- fAcc,err:=os.Create("index_sizes_acc.csv")
+ fAcc, err := os.Create("index_sizes_acc.csv")
check(err)
- csvAcc:=csv.NewWriter(fAcc)
+ csvAcc := csv.NewWriter(fAcc)
err = csvAcc.Write([]string{"key", "ln"})
check(err)
- csvStorage:=csv.NewWriter(fStorage)
+ csvStorage := csv.NewWriter(fStorage)
err = csvStorage.Write([]string{"key", "ln"})
- i:=0
- j:=0
- maxLenAcc:=0
- maxLenSt:=0
+ i := 0
+ j := 0
+ maxLenAcc := 0
+ maxLenSt := 0
db.Walk(dbutils.AccountsHistoryBucket, []byte{}, 0, func(k, v []byte) (b bool, e error) {
- if i>10000 {
+ if i > 10000 {
fmt.Println(j)
- i=0
+ i = 0
}
i++
j++
- if len(v)> maxLenAcc {
- maxLenAcc=len(v)
+ if len(v) > maxLenAcc {
+ maxLenAcc = len(v)
}
err = csvAcc.Write([]string{common.Bytes2Hex(k), strconv.Itoa(len(v))})
- if err!=nil {
+ if err != nil {
panic(err)
}
return true, nil
})
- i=0
- j=0
+ i = 0
+ j = 0
db.Walk(dbutils.StorageHistoryBucket, []byte{}, 0, func(k, v []byte) (b bool, e error) {
- if i>10000 {
+ if i > 10000 {
fmt.Println(j)
- i=0
+ i = 0
}
i++
j++
- if len(v)> maxLenSt {
- maxLenSt=len(v)
+ if len(v) > maxLenSt {
+ maxLenSt = len(v)
}
err = csvStorage.Write([]string{common.Bytes2Hex(k), strconv.Itoa(len(v))})
- if err!=nil {
+ if err != nil {
panic(err)
}
diff --git a/cmd/state/commands/stateless.go b/cmd/state/commands/stateless.go
index 6024aa967..5b9d73a1c 100644
--- a/cmd/state/commands/stateless.go
+++ b/cmd/state/commands/stateless.go
@@ -28,7 +28,7 @@ func init() {
withBlock(statelessCmd)
statelessCmd.Flags().StringVar(&statefile, "statefile", "state", "path to the file where the state will be periodically written during the analysis")
- statelessCmd.Flags().Uint32Var(&triesize, "triesize", 1024*1024, "maximum number of nodes in the state trie")
+ statelessCmd.Flags().Uint32Var(&triesize, "triesize", 4*1024*1024, "maximum size of a trie in bytes")
statelessCmd.Flags().BoolVar(&preroot, "preroot", false, "Attempt to compute hash of the trie without modifying it")
statelessCmd.Flags().Uint64Var(&snapshotInterval, "snapshotInterval", 0, "how often to take snapshots (0 - never, 1 - every block, 1000 - every 1000th block, etc)")
statelessCmd.Flags().Uint64Var(&snapshotFrom, "snapshotFrom", 0, "from which block to start snapshots")
diff --git a/cmd/state/stateless/stateless.go b/cmd/state/stateless/stateless.go
index 819d72223..b352b6e6d 100644
--- a/cmd/state/stateless/stateless.go
+++ b/cmd/state/stateless/stateless.go
@@ -150,7 +150,7 @@ func Stateless(
witnessDatabasePath string,
writeHistory bool,
) {
- state.MaxTrieCacheGen = triesize
+ state.MaxTrieCacheSize = uint64(triesize)
startTime := time.Now()
sigs := make(chan os.Signal, 1)
interruptCh := make(chan bool, 1)
@@ -324,7 +324,7 @@ func Stateless(
return
}
if len(resolveWitnesses) > 0 {
- witnessDBWriter.MustUpsert(blockNum, state.MaxTrieCacheGen, resolveWitnesses)
+ witnessDBWriter.MustUpsert(blockNum, uint32(state.MaxTrieCacheSize), resolveWitnesses)
}
}
execTime2 := time.Since(execStart)
@@ -454,7 +454,7 @@ func Stateless(
fmt.Printf("Failed to commit batch: %v\n", err)
return
}
- tds.PruneTries(false)
+ tds.EvictTries(false)
}
if willSnapshot {
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index ff5d3d970..019e93f06 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -1589,7 +1589,7 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
// TODO(fjl): move trie cache generations into config
if gen := ctx.GlobalInt(TrieCacheGenFlag.Name); gen > 0 {
- state.MaxTrieCacheGen = uint32(gen)
+ state.MaxTrieCacheSize = uint64(gen)
}
}
diff --git a/core/blockchain.go b/core/blockchain.go
index 446f42da6..39dc0403e 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -1775,7 +1775,7 @@ func (bc *BlockChain) insertChain(ctx context.Context, chain types.Blocks, verif
}
bc.committedBlock.Store(bc.currentBlock.Load())
if bc.trieDbState != nil {
- bc.trieDbState.PruneTries(false)
+ bc.trieDbState.EvictTries(false)
}
log.Info("Database", "size", bc.db.DiskSize(), "written", written)
}
diff --git a/core/state/database.go b/core/state/database.go
index a889e6f68..238f701af 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -14,6 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see .
+//nolint:scopelint
package state
import (
@@ -36,8 +37,8 @@ import (
"github.com/ledgerwatch/turbo-geth/trie"
)
-// Trie cache generation limit after which to evict trie nodes from memory.
-var MaxTrieCacheGen = uint32(1024 * 1024)
+// MaxTrieCacheSize is the trie cache size limit after which to evict trie nodes from memory.
+var MaxTrieCacheSize = uint64(1024 * 1024)
const (
//FirstContractIncarnation - first incarnation for contract accounts. After 1 it increases by 1.
@@ -180,7 +181,7 @@ type TrieDbState struct {
resolveReads bool
savePreimages bool
resolveSetBuilder *trie.ResolveSetBuilder
- tp *trie.TriePruning
+ tp *trie.Eviction
newStream trie.Stream
hashBuilder *trie.HashBuilder
resolver *trie.Resolver
@@ -189,7 +190,7 @@ type TrieDbState struct {
func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) *TrieDbState {
t := trie.New(root)
- tp := trie.NewTriePruning(blockNr)
+ tp := trie.NewEviction()
tds := &TrieDbState{
t: t,
@@ -202,10 +203,11 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) *TrieDb
hashBuilder: trie.NewHashBuilder(false),
incarnationMap: make(map[common.Hash]uint64),
}
- t.SetTouchFunc(tp.Touch)
- tp.SetUnloadNodeFunc(tds.putIntermediateHash)
- tp.SetCreateNodeFunc(tds.delIntermediateHash)
+ tp.SetBlockNumber(blockNr)
+
+ t.AddObserver(tp)
+ t.AddObserver(NewIntermediateHashes(tds.db, tds.db))
return tds
}
@@ -232,7 +234,8 @@ func (tds *TrieDbState) Copy() *TrieDbState {
tds.tMu.Unlock()
n := tds.getBlockNr()
- tp := trie.NewTriePruning(n)
+ tp := trie.NewEviction()
+ tp.SetBlockNumber(n)
cpy := TrieDbState{
t: &tcopy,
@@ -244,36 +247,12 @@ func (tds *TrieDbState) Copy() *TrieDbState {
incarnationMap: make(map[common.Hash]uint64),
}
- cpy.tp.SetUnloadNodeFunc(cpy.putIntermediateHash)
- cpy.tp.SetCreateNodeFunc(cpy.delIntermediateHash)
+ cpy.t.AddObserver(tp)
+ cpy.t.AddObserver(NewIntermediateHashes(cpy.db, cpy.db))
return &cpy
}
-func (tds *TrieDbState) putIntermediateHash(key []byte, nodeHash []byte) {
- if err := tds.db.Put(dbutils.IntermediateTrieHashBucket, common.CopyBytes(key), common.CopyBytes(nodeHash)); err != nil {
- log.Warn("could not put intermediate trie hash", "err", err)
- }
-}
-
-func (tds *TrieDbState) delIntermediateHash(prefixAsNibbles []byte) {
- if len(prefixAsNibbles) == 0 {
- return
- }
-
- if len(prefixAsNibbles)%2 == 1 { // only put to bucket prefixes with even number of nibbles
- return
- }
-
- key := make([]byte, len(prefixAsNibbles)/2)
- trie.CompressNibbles(prefixAsNibbles, &key)
-
- if err := tds.db.Delete(dbutils.IntermediateTrieHashBucket, key); err != nil {
- log.Warn("could not delete intermediate trie hash", "err", err)
- return
- }
-}
-
func (tds *TrieDbState) Database() ethdb.Database {
return tds.db
}
@@ -645,7 +624,7 @@ func (tds *TrieDbState) ResolveStateTrieStateless(database trie.WitnessStorage)
return nil
}
- pos, err := resolver.ResolveStateless(database, tds.blockNr, MaxTrieCacheGen, startPos)
+ pos, err := resolver.ResolveStateless(database, tds.blockNr, uint32(MaxTrieCacheSize), startPos)
if err != nil {
return err
}
@@ -838,7 +817,7 @@ func (tds *TrieDbState) clearUpdates() {
func (tds *TrieDbState) SetBlockNr(blockNr uint64) {
tds.setBlockNr(blockNr)
- tds.tp.SetBlockNr(blockNr)
+ tds.tp.SetBlockNumber(blockNr)
}
func (tds *TrieDbState) GetBlockNr() uint64 {
@@ -1203,28 +1182,69 @@ type TrieStateWriter struct {
tds *TrieDbState
}
-func (tds *TrieDbState) PruneTries(print bool) {
+func (tds *TrieDbState) EvictTries(print bool) {
tds.tMu.Lock()
defer tds.tMu.Unlock()
+ strict := print
tds.incarnationMap = make(map[common.Hash]uint64)
if print {
- prunableNodes := tds.t.CountPrunableNodes()
- fmt.Printf("[Before] Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tds.tp.NodeCount())
+ trieSize := tds.t.TrieSize()
+ fmt.Println("") // newline for better formatting
+ fmt.Printf("[Before] Actual nodes size: %d, accounted size: %d\n", trieSize, tds.tp.TotalSize())
}
- tds.tp.PruneTo(tds.t, int(MaxTrieCacheGen))
+ if strict {
+ actualAccounts := uint64(tds.t.NumberOfAccounts())
+ fmt.Println("number of leaves: ", actualAccounts)
+ accountedAccounts := tds.tp.NumberOf()
+ if actualAccounts != accountedAccounts {
+ panic(fmt.Errorf("account number mismatch: trie=%v eviction=%v", actualAccounts, accountedAccounts))
+ }
+ fmt.Printf("checking number --> ok\n")
+
+ actualSize := uint64(tds.t.TrieSize())
+ accountedSize := tds.tp.TotalSize()
+
+ if actualSize != accountedSize {
+ panic(fmt.Errorf("account size mismatch: trie=%v eviction=%v", actualSize, accountedSize))
+ }
+ fmt.Printf("checking size --> ok\n")
+ }
+
+ tds.tp.EvictToFitSize(tds.t, MaxTrieCacheSize)
+
+ if strict {
+ actualAccounts := uint64(tds.t.NumberOfAccounts())
+ fmt.Println("number of leaves: ", actualAccounts)
+ accountedAccounts := tds.tp.NumberOf()
+ if actualAccounts != accountedAccounts {
+ panic(fmt.Errorf("after eviction account number mismatch: trie=%v eviction=%v", actualAccounts, accountedAccounts))
+ }
+ fmt.Printf("checking number --> ok\n")
+
+ actualSize := uint64(tds.t.TrieSize())
+ accountedSize := tds.tp.TotalSize()
+
+ if actualSize != accountedSize {
+ panic(fmt.Errorf("after eviction account size mismatch: trie=%v eviction=%v", actualSize, accountedSize))
+ }
+ fmt.Printf("checking size --> ok\n")
+ }
if print {
- prunableNodes := tds.t.CountPrunableNodes()
- fmt.Printf("[After] Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tds.tp.NodeCount())
+ trieSize := tds.t.TrieSize()
+ fmt.Printf("[After] Actual nodes size: %d, accounted size: %d\n", trieSize, tds.tp.TotalSize())
+
+ actualAccounts := uint64(tds.t.NumberOfAccounts())
+ fmt.Println("number of leaves: ", actualAccounts)
}
var m runtime.MemStats
runtime.ReadMemStats(&m)
- log.Info("Memory", "nodes", tds.tp.NodeCount(), "hashes", tds.t.HashMapSize(),
+ log.Info("Memory", "nodes size", tds.tp.TotalSize(), "hashes", tds.t.HashMapSize(),
"alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC))
if print {
- fmt.Printf("Pruning done. Nodes: %d, alloc: %d, sys: %d, numGC: %d\n", tds.tp.NodeCount(), int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC))
+ fmt.Printf("Eviction done. Nodes size: %d, alloc: %d, sys: %d, numGC: %d\n", tds.tp.TotalSize(), int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC))
}
}
diff --git a/core/state/intermediate_hashes.go b/core/state/intermediate_hashes.go
new file mode 100644
index 000000000..ca60be45d
--- /dev/null
+++ b/core/state/intermediate_hashes.go
@@ -0,0 +1,51 @@
+package state
+
+import (
+ "github.com/ledgerwatch/turbo-geth/common"
+ "github.com/ledgerwatch/turbo-geth/common/dbutils"
+ "github.com/ledgerwatch/turbo-geth/common/pool"
+ "github.com/ledgerwatch/turbo-geth/ethdb"
+ "github.com/ledgerwatch/turbo-geth/log"
+ "github.com/ledgerwatch/turbo-geth/trie"
+)
+
+const keyBufferSize = 64
+
+type IntermediateHashes struct {
+ trie.NoopObserver // make sure that we don't need to subscribe to unnecessary methods
+ putter ethdb.Putter
+ deleter ethdb.Deleter
+}
+
+func NewIntermediateHashes(putter ethdb.Putter, deleter ethdb.Deleter) *IntermediateHashes {
+ return &IntermediateHashes{putter: putter, deleter: deleter}
+}
+
+func (ih *IntermediateHashes) WillUnloadBranchNode(prefixAsNibbles []byte, nodeHash common.Hash) {
+ // only put to bucket prefixes with even number of nibbles
+ if len(prefixAsNibbles) == 0 || len(prefixAsNibbles)%2 == 1 {
+ return
+ }
+
+ key := pool.GetBuffer(keyBufferSize)
+ trie.CompressNibbles(prefixAsNibbles, &key.B)
+
+ if err := ih.putter.Put(dbutils.IntermediateTrieHashBucket, common.CopyBytes(key.B), common.CopyBytes(nodeHash[:])); err != nil {
+ log.Warn("could not put intermediate trie hash", "err", err)
+ }
+}
+
+func (ih *IntermediateHashes) BranchNodeLoaded(prefixAsNibbles []byte) {
+ // only put to bucket prefixes with even number of nibbles
+ if len(prefixAsNibbles) == 0 || len(prefixAsNibbles)%2 == 1 {
+ return
+ }
+
+ key := pool.GetBuffer(keyBufferSize)
+ trie.CompressNibbles(prefixAsNibbles, &key.B)
+
+ if err := ih.deleter.Delete(dbutils.IntermediateTrieHashBucket, key.B); err != nil {
+ log.Warn("could not delete intermediate trie hash", "err", err)
+ return
+ }
+}
diff --git a/trie/node.go b/trie/node.go
index 314d26441..c4c9bc1d3 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -228,3 +228,77 @@ func (n hashNode) String() string { return n.fstring("") }
func (n valueNode) String() string { return n.fstring("") }
func (n codeNode) String() string { return n.fstring("") }
func (an accountNode) String() string { return an.fstring("") }
+
+func CodeKeyFromAddrHash(addrHash []byte) []byte {
+ return append(addrHash, 0xC0, 0xDE)
+}
+
+func CodeHexFromHex(hex []byte) []byte {
+ return append(hex, 0x0C, 0x00, 0x0D, 0x0E)
+}
+
+func IsPointingToCode(key []byte) bool {
+ // checking for 0xC0DE
+ l := len(key)
+ if l < 2 {
+ return false
+ }
+
+ return key[l-2] == 0xC0 && key[l-1] == 0xDE
+}
+
+func AddrHashFromCodeKey(codeKey []byte) []byte {
+ // cut off 0xC0DE
+ return codeKey[:len(codeKey)-2]
+}
+
+func calcSubtreeSize(node node) int {
+ switch n := node.(type) {
+ case nil:
+ return 0
+ case valueNode:
+ return 0
+ case *shortNode:
+ return calcSubtreeSize(n.Val)
+ case *duoNode:
+ return 1 + calcSubtreeSize(n.child1) + calcSubtreeSize(n.child2)
+ case *fullNode:
+ size := 1
+ for _, child := range n.Children {
+ size += calcSubtreeSize(child)
+ }
+ return size
+ case *accountNode:
+ return len(n.code) + calcSubtreeSize(n.storage)
+ case hashNode:
+ return 0
+ }
+ return 0
+}
+
+func calcSubtreeNodes(node node) int {
+ switch n := node.(type) {
+ case nil:
+ return 0
+ case valueNode:
+ return 0
+ case *shortNode:
+ return calcSubtreeNodes(n.Val)
+ case *duoNode:
+ return 1 + calcSubtreeNodes(n.child1) + calcSubtreeNodes(n.child2)
+ case *fullNode:
+ size := 1
+ for _, child := range n.Children {
+ size += calcSubtreeNodes(child)
+ }
+ return size
+ case *accountNode:
+ if n.code != nil {
+ return 1 + calcSubtreeNodes(n.storage)
+ }
+ return calcSubtreeNodes(n.storage)
+ case hashNode:
+ return 0
+ }
+ return 0
+}
diff --git a/trie/trie.go b/trie/trie.go
index e020cc132..c02bb57b1 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -48,8 +48,6 @@ var (
type Trie struct {
root node
- touchFunc func(hex []byte, del bool)
-
newHasherFunc func() *hasher
Version uint8
@@ -57,6 +55,8 @@ type Trie struct {
binary bool
hashMap map[common.Hash]node
+
+ observers *ObserverMux
}
// New creates a trie with an existing root node from db.
@@ -67,9 +67,9 @@ type Trie struct {
// not exist in the database. Accessing the trie loads nodes from db on demand.
func New(root common.Hash) *Trie {
trie := &Trie{
- touchFunc: func([]byte, bool) {},
newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ false) },
hashMap: make(map[common.Hash]node),
+ observers: NewTrieObserverMux(),
}
if (root != common.Hash{}) && root != EmptyRoot {
trie.root = hashNode(root[:])
@@ -87,9 +87,9 @@ func NewBinary(root common.Hash) *Trie {
// it is usually used for testing purposes.
func NewTestRLPTrie(root common.Hash) *Trie {
trie := &Trie{
- touchFunc: func([]byte, bool) {},
newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ true) },
hashMap: make(map[common.Hash]node),
+ observers: NewTrieObserverMux(),
}
if (root != common.Hash{}) && root != EmptyRoot {
trie.root = hashNode(root[:])
@@ -97,8 +97,8 @@ func NewTestRLPTrie(root common.Hash) *Trie {
return trie
}
-func (t *Trie) SetTouchFunc(touchFunc func(hex []byte, del bool)) {
- t.touchFunc = touchFunc
+func (t *Trie) AddObserver(observer Observer) {
+ t.observers.AddChild(observer)
}
// Get returns the value for key stored in the trie.
@@ -149,6 +149,8 @@ func (t *Trie) GetAccountCode(key []byte) (value []byte, gotValue bool) {
return nil, gotValue
}
+ t.observers.CodeNodeTouched(hex)
+
if accNode.code == nil {
return nil, false
}
@@ -174,7 +176,7 @@ func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNod
return nil, true
}
case *duoNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
i1, i2 := n.childrenIdx()
switch key[pos] {
case i1:
@@ -185,7 +187,7 @@ func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNod
return nil, true
}
case *fullNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
child := n.Children[key[pos]]
return t.getAccount(child, key, pos+1)
case hashNode:
@@ -215,7 +217,7 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, gotValue b
}
return
case *duoNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
i1, i2 := n.childrenIdx()
switch key[pos] {
case i1:
@@ -227,7 +229,7 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, gotValue b
}
return
case *fullNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
child := n.Children[key[pos]]
if child == nil {
return nil, true
@@ -254,9 +256,10 @@ func (t *Trie) Update(key, value []byte) {
hex = keyHexToBin(hex)
}
+ newnode := valueNode(value)
+
if t.root == nil {
- newnode := &shortNode{Key: hex, Val: valueNode(value)}
- t.root = newnode
+ t.root = &shortNode{Key: hex, Val: newnode}
} else {
_, t.root = t.insert(t.root, hex, 0, valueNode(value))
}
@@ -271,20 +274,18 @@ func (t *Trie) UpdateAccount(key []byte, acc *accounts.Account) {
if t.binary {
hex = keyHexToBin(hex)
}
- if t.root == nil {
- var newnode node
- if value.Root == EmptyRoot || value.Root == (common.Hash{}) {
- newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true, nil}}
- } else {
- newnode = &shortNode{Key: hex, Val: &accountNode{*value, hashNode(value.Root[:]), true, nil}}
- }
- t.root = newnode
+
+ var newnode *accountNode
+ if value.Root == EmptyRoot || value.Root == (common.Hash{}) {
+ newnode = &accountNode{*value, nil, true, nil}
} else {
- if value.Root == EmptyRoot || value.Root == (common.Hash{}) {
- _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true, nil})
- } else {
- _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true, nil})
- }
+ newnode = &accountNode{*value, hashNode(value.Root[:]), true, nil}
+ }
+
+ if t.root == nil {
+ t.root = &shortNode{Key: hex, Val: newnode}
+ } else {
+ _, t.root = t.insert(t.root, hex, 0, newnode)
}
}
@@ -311,6 +312,7 @@ func (t *Trie) UpdateAccountCode(key []byte, code codeNode) error {
accNode.code = code
+ // t.insert will call the observer methods itself
_, t.root = t.insert(t.root, hex, 0, accNode)
return nil
}
@@ -467,11 +469,19 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b
if updated {
if !bytes.Equal(origAccN.CodeHash[:], vAccN.CodeHash[:]) {
origAccN.code = nil
+ } else if vAccN.code != nil {
+ origAccN.code = vAccN.code
}
origAccN.Account.Copy(&vAccN.Account)
origAccN.rootCorrect = false
}
newNode = origAccN
+
+ if len(origAccN.code) > 0 {
+ t.observers.CodeNodeSizeChanged(key[:pos-1], uint(len(origAccN.code)))
+ } else {
+ t.observers.CodeNodeDeleted(key[:pos-1])
+ }
return
}
@@ -532,22 +542,21 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b
// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
- t.touchFunc(key[:pos], false)
newNode = branch // current node leaves the generation, but new node branch joins it
} else {
// Otherwise, replace it with a short node leading up to the branch.
- t.touchFunc(key[:pos+matchlen], false)
n.Key = common.CopyBytes(key[pos : pos+matchlen])
n.Val = branch
n.ref.len = 0
newNode = n
}
+ t.observers.BranchNodeCreated(key[:pos+matchlen])
updated = true
}
return
case *duoNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
i1, i2 := n.childrenIdx()
switch key[pos] {
case i1:
@@ -585,7 +594,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b
return
case *fullNode:
- t.touchFunc(key[:pos], false)
+ t.observers.BranchNodeTouched(key[:pos])
child := n.Children[key[pos]]
if child == nil {
t.evictNodeFromHashMap(n)
@@ -635,9 +644,8 @@ func (t *Trie) getNode(hex []byte, doTouch bool) (node, node, bool) {
}
case *duoNode:
if doTouch {
- t.touchFunc(hex[:pos], false)
+ t.observers.BranchNodeTouched(hex[:pos])
}
-
i1, i2 := n.childrenIdx()
switch hex[pos] {
case i1:
@@ -653,7 +661,7 @@ func (t *Trie) getNode(hex []byte, doTouch bool) (node, node, bool) {
}
case *fullNode:
if doTouch {
- t.touchFunc(hex[:pos], false)
+ t.observers.BranchNodeTouched(hex[:pos])
}
child := n.Children[hex[pos]]
if child == nil {
@@ -733,7 +741,12 @@ func (t *Trie) touchAll(n node, hex []byte, del bool) {
t.touchAll(n.Val, hexVal, del)
}
case *duoNode:
- t.touchFunc(hex, del)
+ if del {
+ t.observers.BranchNodeDeleted(hex)
+ } else {
+ t.observers.BranchNodeCreated(hex)
+ t.observers.BranchNodeLoaded(hex)
+ }
i1, i2 := n.childrenIdx()
hex1 := make([]byte, len(hex)+1)
copy(hex1, hex)
@@ -744,13 +757,23 @@ func (t *Trie) touchAll(n node, hex []byte, del bool) {
t.touchAll(n.child1, hex1, del)
t.touchAll(n.child2, hex2, del)
case *fullNode:
- t.touchFunc(hex, del)
+ if del {
+ t.observers.BranchNodeDeleted(hex)
+ } else {
+ t.observers.BranchNodeCreated(hex)
+ t.observers.BranchNodeLoaded(hex)
+ }
for i, child := range n.Children {
if child != nil {
t.touchAll(child, concat(hex, byte(i)), del)
}
}
case *accountNode:
+ if del {
+ t.observers.CodeNodeDeleted(hex)
+ } else {
+ t.observers.CodeNodeTouched(hex)
+ }
if n.storage != nil {
t.touchAll(n.storage, hex, del)
}
@@ -853,15 +876,15 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
case i1:
updated, nn = t.delete(n.child1, key, keyStart+1, preserveAccountNode)
if !updated {
- t.touchFunc(key[:keyStart], false)
newNode = n
+ t.observers.BranchNodeTouched(key[:keyStart])
} else {
t.evictNodeFromHashMap(n)
if nn == nil {
- t.touchFunc(key[:keyStart], true)
newNode = t.convertToShortNode(n.child2, uint(i2))
+ t.observers.BranchNodeDeleted(key[:keyStart])
} else {
- t.touchFunc(key[:keyStart], false)
+ t.observers.BranchNodeTouched(key[:keyStart])
n.child1 = nn
n.ref.len = 0
newNode = n
@@ -870,22 +893,21 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
case i2:
updated, nn = t.delete(n.child2, key, keyStart+1, preserveAccountNode)
if !updated {
- t.touchFunc(key[:keyStart], false)
newNode = n
+ t.observers.BranchNodeTouched(key[:keyStart])
} else {
t.evictNodeFromHashMap(n)
if nn == nil {
- t.touchFunc(key[:keyStart], true)
newNode = t.convertToShortNode(n.child1, uint(i1))
+ t.observers.BranchNodeDeleted(key[:keyStart])
} else {
- t.touchFunc(key[:keyStart], false)
+ t.observers.BranchNodeTouched(key[:keyStart])
n.child2 = nn
n.ref.len = 0
newNode = n
}
}
default:
- t.touchFunc(key[:keyStart], false)
updated = false
newNode = n
}
@@ -895,8 +917,8 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
child := n.Children[key[keyStart]]
updated, nn = t.delete(child, key, keyStart+1, preserveAccountNode)
if !updated {
- t.touchFunc(key[:keyStart], false)
newNode = n
+ t.observers.BranchNodeTouched(key[:keyStart])
} else {
t.evictNodeFromHashMap(n)
n.Children[key[keyStart]] = nn
@@ -926,10 +948,10 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
}
}
if count == 1 {
- t.touchFunc(key[:keyStart], true)
newNode = t.convertToShortNode(n.Children[pos1], uint(pos1))
+ t.observers.BranchNodeDeleted(key[:keyStart])
} else if count == 2 {
- t.touchFunc(key[:keyStart], false)
+ t.observers.BranchNodeTouched(key[:keyStart])
duo := &duoNode{}
if pos1 == int(key[keyStart]) {
duo.child1 = nn
@@ -944,7 +966,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
duo.mask = (1 << uint(pos1)) | (uint32(1) << uint(pos2))
newNode = duo
} else if count > 2 {
- t.touchFunc(key[:keyStart], false)
+ t.observers.BranchNodeTouched(key[:keyStart])
// n still contains at least three values and cannot be reduced.
n.ref.len = 0
newNode = n
@@ -960,20 +982,23 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, preserveAccountNo
case *accountNode:
if keyStart >= len(key) || key[keyStart] == 16 {
// Key terminates here
+ h := key[:keyStart]
+ if h[len(h)-1] == 16 {
+ h = h[:len(h)-1]
+ }
if n.storage != nil {
- h := key[:keyStart]
- if h[len(h)-1] == 16 {
- h = h[:len(h)-1]
- }
// Mark all the storage nodes as deleted
t.touchAll(n.storage, h, true)
}
if preserveAccountNode {
n.storage = nil
+ n.code = nil
n.Root = EmptyRoot
n.rootCorrect = true
+ t.observers.CodeNodeDeleted(h)
return true, n
}
+
return true, nil
}
updated, nn = t.delete(n.storage, key, keyStart, preserveAccountNode)
@@ -1067,11 +1092,22 @@ func (t *Trie) DeepHash(keyPrefix []byte) (bool, common.Hash) {
return true, accNode.Root
}
-func (t *Trie) unload(hex []byte) {
+func (t *Trie) EvictNode(hex []byte) {
+ isCode := IsPointingToCode(hex)
+ if isCode {
+ hex = AddrHashFromCodeKey(hex)
+ }
+
nd, parent, ok := t.getNode(hex, false)
if !ok {
return
}
+ if accNode, ok := parent.(*accountNode); isCode && ok {
+ // add special treatment to code nodes
+ accNode.code = nil
+ return
+ }
+
switch nd.(type) {
case valueNode, hashNode:
return
@@ -1084,9 +1120,12 @@ func (t *Trie) unload(hex []byte) {
var hn common.Hash
if nd == nil {
fmt.Printf("nd == nil, hex %x, parent node: %T\n", hex, parent)
+ return
}
copy(hn[:], nd.reference())
hnode := hashNode(hn[:])
+ t.observers.WillUnloadBranchNode(hex, hn)
+
switch p := parent.(type) {
case nil:
t.root = hnode
@@ -1108,57 +1147,12 @@ func (t *Trie) unload(hex []byte) {
}
}
-func (t *Trie) CountPrunableNodes() int {
- return t.countPrunableNodes(t.root, []byte{}, false)
+func (t *Trie) TrieSize() int {
+ return calcSubtreeSize(t.root)
}
-func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int {
- switch n := nd.(type) {
- case nil:
- return 0
- case valueNode:
- return 0
- case *accountNode:
- return t.countPrunableNodes(n.storage, hex, print)
- case hashNode:
- return 0
- case *shortNode:
- var hexVal []byte
- if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf
- h := n.Key
- if h[len(h)-1] == 16 {
- h = h[:len(h)-1]
- }
- hexVal = concat(hex, h...)
- }
- //@todo accountNode?
- return t.countPrunableNodes(n.Val, hexVal, print)
- case *duoNode:
- i1, i2 := n.childrenIdx()
- hex1 := make([]byte, len(hex)+1)
- copy(hex1, hex)
- hex1[len(hex)] = byte(i1)
- hex2 := make([]byte, len(hex)+1)
- copy(hex2, hex)
- hex2[len(hex)] = byte(i2)
- if print {
- fmt.Printf("%T node: %x\n", n, hex)
- }
- return 1 + t.countPrunableNodes(n.child1, hex1, print) + t.countPrunableNodes(n.child2, hex2, print)
- case *fullNode:
- if print {
- fmt.Printf("%T node: %x\n", n, hex)
- }
- count := 0
- for i, child := range n.Children {
- if child != nil {
- count += t.countPrunableNodes(child, concat(hex, byte(i)), print)
- }
- }
- return 1 + count
- default:
- panic("")
- }
+func (t *Trie) NumberOfAccounts() int {
+ return calcSubtreeNodes(t.root)
}
func (t *Trie) hashRoot() (node, error) {
diff --git a/trie/trie_eviction.go b/trie/trie_eviction.go
new file mode 100644
index 000000000..6be7c3d3b
--- /dev/null
+++ b/trie/trie_eviction.go
@@ -0,0 +1,343 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty off
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Pruning of the Merkle Patricia trees
+
+package trie
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+)
+
+type AccountEvicter interface {
+ EvictNode([]byte)
+}
+
+type generations struct {
+ blockNumToGeneration map[uint64]*generation
+ keyToBlockNum map[string]uint64
+ oldestBlockNum uint64
+ totalSize int64
+}
+
+func newGenerations() *generations {
+ return &generations{
+ make(map[uint64]*generation),
+ make(map[string]uint64),
+ 0,
+ 0,
+ }
+}
+
+func (gs *generations) add(blockNum uint64, key []byte, size uint) {
+ if _, ok := gs.keyToBlockNum[string(key)]; ok {
+ gs.updateSize(blockNum, key, size)
+ return
+ }
+
+ generation, ok := gs.blockNumToGeneration[blockNum]
+ if !ok {
+ generation = newGeneration()
+ gs.blockNumToGeneration[blockNum] = generation
+ }
+ generation.add(key, size)
+ gs.keyToBlockNum[string(key)] = blockNum
+ if gs.oldestBlockNum > blockNum {
+ gs.oldestBlockNum = blockNum
+ }
+ gs.totalSize += int64(size)
+}
+
+func (gs *generations) touch(blockNum uint64, key []byte) {
+ if len(gs.blockNumToGeneration) == 0 {
+ return
+ }
+
+ oldBlockNum, ok := gs.keyToBlockNum[string(key)]
+ if !ok {
+ return
+ }
+
+ oldGeneration, ok := gs.blockNumToGeneration[oldBlockNum]
+ if !ok {
+ return
+ }
+
+ currentGeneration, ok := gs.blockNumToGeneration[blockNum]
+ if !ok {
+ currentGeneration = newGeneration()
+ gs.blockNumToGeneration[blockNum] = currentGeneration
+ }
+
+ currentGeneration.grabFrom(key, oldGeneration)
+
+ gs.keyToBlockNum[string(key)] = blockNum
+
+ if gs.oldestBlockNum > blockNum {
+ gs.oldestBlockNum = blockNum
+ }
+
+ if oldGeneration.empty() {
+ delete(gs.blockNumToGeneration, oldBlockNum)
+ }
+}
+
+func (gs *generations) remove(key []byte) {
+ oldBlockNum, ok := gs.keyToBlockNum[string(key)]
+ if !ok {
+ return
+ }
+ generation, ok := gs.blockNumToGeneration[oldBlockNum]
+ if !ok {
+ return
+ }
+ sizeDiff := generation.remove(key)
+ gs.totalSize += sizeDiff
+ delete(gs.keyToBlockNum, string(key))
+}
+
+func (gs *generations) updateSize(blockNum uint64, key []byte, newSize uint) {
+ oldBlockNum, ok := gs.keyToBlockNum[string(key)]
+ if !ok {
+ gs.add(blockNum, key, newSize)
+ return
+ }
+ generation, ok := gs.blockNumToGeneration[oldBlockNum]
+ if !ok {
+ gs.add(blockNum, key, newSize)
+ return
+ }
+
+ sizeDiff := generation.updateAccountSize(key, newSize)
+ gs.totalSize += sizeDiff
+ gs.touch(blockNum, key)
+}
+
+// popKeysToEvict returns the keys to evict from the trie,
+// also removing them from generations
+func (gs *generations) popKeysToEvict(threshold uint64) []string {
+ keys := make([]string, 0)
+ for uint64(gs.totalSize) > threshold && len(gs.blockNumToGeneration) > 0 {
+ generation, ok := gs.blockNumToGeneration[gs.oldestBlockNum]
+ if !ok {
+ gs.oldestBlockNum++
+ continue
+ }
+
+ gs.totalSize -= generation.totalSize
+ if gs.totalSize < 0 {
+ gs.totalSize = 0
+ }
+ keysToEvict := generation.keys()
+ keys = append(keys, keysToEvict...)
+ for _, k := range keysToEvict {
+ delete(gs.keyToBlockNum, k)
+ }
+ delete(gs.blockNumToGeneration, gs.oldestBlockNum)
+ gs.oldestBlockNum++
+ }
+ return keys
+}
+
+type generation struct {
+ sizesByKey map[string]uint
+ totalSize int64
+}
+
+func newGeneration() *generation {
+ return &generation{
+ make(map[string]uint),
+ 0,
+ }
+}
+
+func (g *generation) empty() bool {
+ return len(g.sizesByKey) == 0
+}
+
+func (g *generation) grabFrom(key []byte, other *generation) {
+ if g == other {
+ return
+ }
+
+ keyStr := string(key)
+ size, ok := other.sizesByKey[keyStr]
+ if !ok {
+ return
+ }
+
+ g.sizesByKey[keyStr] = size
+ g.totalSize += int64(size)
+ other.totalSize -= int64(size)
+
+ if other.totalSize < 0 {
+ other.totalSize = 0
+ }
+
+ delete(other.sizesByKey, keyStr)
+}
+
+func (g *generation) add(key []byte, size uint) {
+ g.sizesByKey[string(key)] = size
+ g.totalSize += int64(size)
+}
+
+func (g *generation) updateAccountSize(key []byte, size uint) int64 {
+ oldSize := g.sizesByKey[string(key)]
+ g.sizesByKey[string(key)] = size
+ diff := int64(size) - int64(oldSize)
+ g.totalSize += diff
+ return diff
+}
+
+func (g *generation) remove(key []byte) int64 {
+ oldSize := g.sizesByKey[string(key)]
+ delete(g.sizesByKey, string(key))
+ g.totalSize -= int64(oldSize)
+ if g.totalSize < 0 {
+ g.totalSize = 0
+ }
+
+ return -1 * int64(oldSize)
+}
+
+func (g *generation) keys() []string {
+ keys := make([]string, len(g.sizesByKey))
+ i := 0
+ for k := range g.sizesByKey {
+ keys[i] = k
+ i++
+ }
+ return keys
+}
+
+type Eviction struct {
+ NoopObserver // make sure that we don't need to implement unnecessary observer methods
+
+ blockNumber uint64
+
+ generations *generations
+}
+
+func NewEviction() *Eviction {
+ return &Eviction{
+ generations: newGenerations(),
+ }
+}
+
+func (tp *Eviction) SetBlockNumber(blockNumber uint64) {
+ tp.blockNumber = blockNumber
+}
+
+func (tp *Eviction) BlockNumber() uint64 {
+ return tp.blockNumber
+}
+
+func (tp *Eviction) BranchNodeCreated(hex []byte) {
+ key := hex
+ tp.generations.add(tp.blockNumber, key, 1)
+}
+
+func (tp *Eviction) BranchNodeDeleted(hex []byte) {
+ key := hex
+ tp.generations.remove(key)
+}
+
+func (tp *Eviction) BranchNodeTouched(hex []byte) {
+ key := hex
+ tp.generations.touch(tp.blockNumber, key)
+}
+
+func (tp *Eviction) CodeNodeCreated(hex []byte, size uint) {
+ key := hex
+ tp.generations.add(tp.blockNumber, CodeKeyFromAddrHash(key), size)
+}
+
+func (tp *Eviction) CodeNodeDeleted(hex []byte) {
+ key := hex
+ tp.generations.remove(CodeKeyFromAddrHash(key))
+}
+
+func (tp *Eviction) CodeNodeTouched(hex []byte) {
+ key := hex
+ tp.generations.touch(tp.blockNumber, CodeKeyFromAddrHash(key))
+}
+
+func (tp *Eviction) CodeNodeSizeChanged(hex []byte, newSize uint) {
+ key := hex
+ tp.generations.updateSize(tp.blockNumber, CodeKeyFromAddrHash(key), newSize)
+}
+
+func evictList(evicter AccountEvicter, hexes []string) bool {
+ var empty = false
+ sort.Strings(hexes)
+
+ // from long to short -- a naive way to first clean up nodes and then accounts
+ // FIXME: optimize to avoid the same paths
+ for i := len(hexes) - 1; i >= 0; i-- {
+ evicter.EvictNode([]byte(hexes[i]))
+ }
+ return empty
+}
+
+// EvictToFitSize evicts mininum number of generations necessary so that the total
+// size of accounts left is fits into the provided threshold
+func (tp *Eviction) EvictToFitSize(
+ evicter AccountEvicter,
+ threshold uint64,
+) bool {
+
+ if uint64(tp.generations.totalSize) <= threshold {
+ return false
+ }
+
+ keys := tp.generations.popKeysToEvict(threshold)
+
+ return evictList(evicter, keys)
+}
+
+func (tp *Eviction) TotalSize() uint64 {
+ return uint64(tp.generations.totalSize)
+}
+
+func (tp *Eviction) NumberOf() uint64 {
+ total := uint64(0)
+ for _, gen := range tp.generations.blockNumToGeneration {
+ if gen == nil {
+ continue
+ }
+ total += uint64(len(gen.sizesByKey))
+ }
+ return total
+}
+
+func (tp *Eviction) DebugDump() string {
+ var sb strings.Builder
+
+ for block, gen := range tp.generations.blockNumToGeneration {
+ if gen.empty() {
+ continue
+ }
+ sb.WriteString(fmt.Sprintf("Block: %v\n", block))
+ for key, size := range gen.sizesByKey {
+ sb.WriteString(fmt.Sprintf(" %x->%v\n", key, size))
+ }
+ }
+
+ return sb.String()
+}
diff --git a/trie/trie_eviction_test.go b/trie/trie_eviction_test.go
new file mode 100644
index 000000000..0eb4e5754
--- /dev/null
+++ b/trie/trie_eviction_test.go
@@ -0,0 +1,394 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty off
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Pruning of the Merkle Patricia trees
+
+package trie
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type mockAccountEvicter struct {
+ keys [][]byte
+}
+
+func newMockAccountEvicter() *mockAccountEvicter {
+ return &mockAccountEvicter{make([][]byte, 0)}
+}
+
+func (m *mockAccountEvicter) EvictNode(key []byte) {
+ m.keys = append(m.keys, key)
+}
+
+func TestEvictionBasicOperations(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ key := []byte{0x01, 0x01, 0x01, 0x01}
+ hex := keybytesToHex(key)
+ eviction.CodeNodeCreated(hex, 1024)
+
+ assert.Equal(t, 1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 1024, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ // grow
+ eviction.CodeNodeSizeChanged(hex, 2048)
+
+ assert.Equal(t, 2048, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 2048, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ // shrink
+ eviction.CodeNodeSizeChanged(hex, 100)
+
+ assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ // shrink
+ eviction.CodeNodeDeleted(hex)
+
+ assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+}
+
+func TestEvictionPartialSingleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 100kb or accounts
+ for i := 0; i < 100; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.BranchNodeCreated(keybytesToHex(key))
+ }
+
+ assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 99)
+
+ assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 2, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, len(mock.keys), "should evict all 100 accounts")
+}
+
+func TestEvictionFullSingleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 100kb or accounts
+ for i := 0; i < 100; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.BranchNodeCreated(keybytesToHex(key))
+ }
+
+ assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 0)
+
+ assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 2, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, len(mock.keys), "should evict all 100 accounts")
+}
+
+func TestEvictionNoNeedSingleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 100kb or accounts
+ for i := 0; i < 100; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.BranchNodeCreated(keybytesToHex(key))
+ }
+
+ assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 100)
+
+ assert.Equal(t, 100, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ assert.Equal(t, 100, int(eviction.generations.blockNumToGeneration[1].totalSize), "should register size of gen")
+
+ assert.Equal(t, 0, len(mock.keys), "should evict all 100 accounts")
+}
+
+func TestEvictionNoNeedMultipleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(2)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x02, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(4)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x03, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(5)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x04, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(7)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x05, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ // 50 kb total
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ for _, i := range []uint64{1, 2, 4, 5, 7} {
+ assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen")
+ }
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 50*1024)
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+
+ assert.Equal(t, 0, len(mock.keys), "should not evict anything")
+}
+
+func TestEvictionPartialMultipleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(2)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x02, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(4)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x03, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(5)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x04, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(7)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x05, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ // 50 kb total
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ for _, i := range []uint64{1, 2, 4, 5, 7} {
+ assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen")
+ }
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 20*1024)
+
+ assert.Equal(t, 20*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 2, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 5, int(eviction.generations.oldestBlockNum), "should register block num")
+ for _, i := range []uint64{5, 7} {
+ assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen")
+ }
+ assert.Equal(t, 30, len(mock.keys), "should evict only 3 generations")
+}
+
+func TestEvictionFullMultipleGen(t *testing.T) {
+ eviction := NewEviction()
+ eviction.SetBlockNumber(1)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(2)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x02, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(4)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x03, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(5)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x04, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ eviction.SetBlockNumber(7)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x05, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ // 50 kb total
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 5, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 0, int(eviction.generations.oldestBlockNum), "should register block num")
+ for _, i := range []uint64{1, 2, 4, 5, 7} {
+ assert.Equal(t, 10*1024, int(eviction.generations.blockNumToGeneration[i].totalSize), "should register size of gen")
+ }
+
+ mock := newMockAccountEvicter()
+
+ eviction.EvictToFitSize(mock, 0)
+
+ assert.Equal(t, 0, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 0, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 8, int(eviction.generations.oldestBlockNum), "should register block num")
+
+ assert.Equal(t, 50, len(mock.keys), "should evict only 3 generations")
+
+}
+
+func TestEvictionMoveBetweenGen(t *testing.T) {
+ eviction := NewEviction()
+
+ eviction.SetBlockNumber(2)
+
+ for i := 0; i < 2; i++ {
+ key := []byte{0x01, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 10*1024)
+ }
+
+ eviction.SetBlockNumber(4)
+
+ // create 10kb or accounts
+ for i := 0; i < 1; i++ {
+ key := []byte{0x04, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 20*1024)
+ }
+
+ eviction.SetBlockNumber(5)
+
+ // create 10kb or accounts
+ for i := 0; i < 10; i++ {
+ key := []byte{0x05, 0x01, 0x01, byte(i)}
+ eviction.CodeNodeCreated(keybytesToHex(key), 1024)
+ }
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 3, len(eviction.generations.blockNumToGeneration), "should register generation")
+
+ eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00}))
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 3, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 1, len(eviction.generations.blockNumToGeneration[2].keys()), "should move one acc")
+ assert.Equal(t, 11, len(eviction.generations.blockNumToGeneration[5].keys()), "should move one acc to gen 5")
+ assert.Equal(t, CodeKeyFromAddrHash(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x01})), []byte(eviction.generations.blockNumToGeneration[2].keys()[0]), "should move one acc")
+
+ // move the acc again to the new block!
+ eviction.SetBlockNumber(10)
+ eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00}))
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 4, len(eviction.generations.blockNumToGeneration), "should register generation")
+ assert.Equal(t, 10, len(eviction.generations.blockNumToGeneration[5].keys()), "should move one acc from gen 5, again 10")
+ assert.Equal(t, CodeKeyFromAddrHash(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x00})), []byte(eviction.generations.blockNumToGeneration[10].keys()[0]), "should move one acc")
+
+ // move the last acc from the gen
+ eviction.CodeNodeTouched(keybytesToHex([]byte{0x01, 0x01, 0x01, 0x01}))
+
+ assert.Equal(t, 50*1024, int(eviction.TotalSize()), "should register all accounts")
+ assert.Equal(t, 3, len(eviction.generations.blockNumToGeneration), "should register generation")
+ _, found := eviction.generations.blockNumToGeneration[2]
+ assert.False(t, found, "2nd generation is empty and should be removed")
+
+ assert.Equal(t, 2, len(eviction.generations.blockNumToGeneration[10].keys()), "should move one acc")
+}
diff --git a/trie/trie_observers.go b/trie/trie_observers.go
new file mode 100644
index 000000000..4aa23cdc3
--- /dev/null
+++ b/trie/trie_observers.go
@@ -0,0 +1,104 @@
+package trie
+
+import "github.com/ledgerwatch/turbo-geth/common"
+
+type Observer interface {
+ BranchNodeCreated(hex []byte)
+ BranchNodeDeleted(hex []byte)
+ BranchNodeTouched(hex []byte)
+
+ CodeNodeCreated(hex []byte, size uint)
+ CodeNodeDeleted(hex []byte)
+ CodeNodeTouched(hex []byte)
+ CodeNodeSizeChanged(hex []byte, newSize uint)
+
+ WillUnloadBranchNode(key []byte, nodeHash common.Hash)
+ BranchNodeLoaded(prefixAsNibbles []byte)
+}
+
+var _ Observer = (*NoopObserver)(nil) // make sure that NoopTrieObserver is compliant
+
+// NoopTrieObserver might be used to emulate optional methods in observers
+type NoopObserver struct{}
+
+func (*NoopObserver) BranchNodeCreated(_ []byte) {}
+func (*NoopObserver) BranchNodeDeleted(_ []byte) {}
+func (*NoopObserver) BranchNodeTouched(_ []byte) {}
+func (*NoopObserver) CodeNodeCreated(_ []byte, _ uint) {}
+func (*NoopObserver) CodeNodeDeleted(_ []byte) {}
+func (*NoopObserver) CodeNodeTouched(_ []byte) {}
+func (*NoopObserver) CodeNodeSizeChanged(_ []byte, _ uint) {}
+func (*NoopObserver) WillUnloadBranchNode(_ []byte, _ common.Hash) {}
+func (*NoopObserver) BranchNodeLoaded(_ []byte) {}
+
+// TrieObserverMux multiplies the callback methods and sends them to
+// all it's children.
+type ObserverMux struct {
+ children []Observer
+}
+
+func NewTrieObserverMux() *ObserverMux {
+ return &ObserverMux{make([]Observer, 0)}
+}
+
+func (mux *ObserverMux) AddChild(child Observer) {
+ if child == nil {
+ return
+ }
+
+ mux.children = append(mux.children, child)
+}
+
+func (mux *ObserverMux) BranchNodeCreated(hex []byte) {
+ for _, child := range mux.children {
+ child.BranchNodeCreated(hex)
+ }
+}
+
+func (mux *ObserverMux) BranchNodeDeleted(hex []byte) {
+ for _, child := range mux.children {
+ child.BranchNodeDeleted(hex)
+ }
+}
+
+func (mux *ObserverMux) BranchNodeTouched(hex []byte) {
+ for _, child := range mux.children {
+ child.BranchNodeTouched(hex)
+ }
+}
+
+func (mux *ObserverMux) CodeNodeCreated(hex []byte, size uint) {
+ for _, child := range mux.children {
+ child.CodeNodeCreated(hex, size)
+ }
+}
+
+func (mux *ObserverMux) CodeNodeDeleted(hex []byte) {
+ for _, child := range mux.children {
+ child.CodeNodeDeleted(hex)
+ }
+}
+
+func (mux *ObserverMux) CodeNodeTouched(hex []byte) {
+ for _, child := range mux.children {
+ child.CodeNodeTouched(hex)
+ }
+}
+
+func (mux *ObserverMux) CodeNodeSizeChanged(hex []byte, newSize uint) {
+ for _, child := range mux.children {
+ child.CodeNodeSizeChanged(hex, newSize)
+ }
+}
+
+func (mux *ObserverMux) WillUnloadBranchNode(key []byte, nodeHash common.Hash) {
+ for _, child := range mux.children {
+ child.WillUnloadBranchNode(key, nodeHash)
+ }
+}
+
+func (mux *ObserverMux) BranchNodeLoaded(prefixAsNibbles []byte) {
+ for _, child := range mux.children {
+ child.BranchNodeLoaded(prefixAsNibbles)
+ }
+}
diff --git a/trie/trie_observers_test.go b/trie/trie_observers_test.go
new file mode 100644
index 000000000..3f4eeb6d5
--- /dev/null
+++ b/trie/trie_observers_test.go
@@ -0,0 +1,454 @@
+package trie
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "github.com/ledgerwatch/turbo-geth/common"
+ "github.com/ledgerwatch/turbo-geth/common/dbutils"
+ "github.com/ledgerwatch/turbo-geth/core/types/accounts"
+ "github.com/ledgerwatch/turbo-geth/crypto"
+ "github.com/stretchr/testify/assert"
+)
+
+func genAccount() *accounts.Account {
+ acc := &accounts.Account{}
+ acc.Nonce = 123
+ return acc
+}
+
+func genNKeys(n int) [][]byte {
+ result := make([][]byte, n)
+ for i := range result {
+ result[i] = crypto.Keccak256([]byte{0x0, 0x0, 0x0, byte(i % 256), byte(i / 256)})
+ }
+ return result
+}
+
+func genByteArrayOfLen(n int) []byte {
+ result := make([]byte, n)
+ for i := range result {
+ result[i] = byte(rand.Intn(255))
+ }
+ return result
+}
+
+type partialObserver struct {
+ NoopObserver
+ callbackCalled bool
+}
+
+func (o *partialObserver) BranchNodeDeleted(_ []byte) {
+ o.callbackCalled = true
+}
+
+type mockObserver struct {
+ createdNodes map[string]uint
+ deletedNodes map[string]struct{}
+ touchedNodes map[string]int
+
+ unloadedNodes map[string]int
+ unloadedNodeHashes map[string][]byte
+ reloadedNodes map[string]int
+}
+
+func newMockObserver() *mockObserver {
+ return &mockObserver{
+ createdNodes: make(map[string]uint),
+ deletedNodes: make(map[string]struct{}),
+ touchedNodes: make(map[string]int),
+ unloadedNodes: make(map[string]int),
+ unloadedNodeHashes: make(map[string][]byte),
+ reloadedNodes: make(map[string]int),
+ }
+}
+
+func (m *mockObserver) BranchNodeCreated(hex []byte) {
+ m.createdNodes[common.Bytes2Hex(hex)] = 1
+}
+
+func (m *mockObserver) CodeNodeCreated(hex []byte, size uint) {
+ m.createdNodes[common.Bytes2Hex(hex)] = size
+}
+
+func (m *mockObserver) BranchNodeDeleted(hex []byte) {
+ m.deletedNodes[common.Bytes2Hex(hex)] = struct{}{}
+}
+
+func (m *mockObserver) CodeNodeDeleted(hex []byte) {
+ m.deletedNodes[common.Bytes2Hex(hex)] = struct{}{}
+}
+
+func (m *mockObserver) BranchNodeTouched(hex []byte) {
+ value := m.touchedNodes[common.Bytes2Hex(hex)]
+ value++
+ m.touchedNodes[common.Bytes2Hex(hex)] = value
+}
+
+func (m *mockObserver) CodeNodeTouched(hex []byte) {
+ value := m.touchedNodes[common.Bytes2Hex(hex)]
+ value++
+ m.touchedNodes[common.Bytes2Hex(hex)] = value
+}
+
+func (m *mockObserver) CodeNodeSizeChanged(hex []byte, newSize uint) {
+ m.createdNodes[common.Bytes2Hex(hex)] = newSize
+}
+
+func (m *mockObserver) WillUnloadBranchNode(hex []byte, hash common.Hash) {
+ dictKey := common.Bytes2Hex(hex)
+ value := m.unloadedNodes[dictKey]
+ value++
+ m.unloadedNodes[dictKey] = value
+ m.unloadedNodeHashes[dictKey] = common.CopyBytes(hash[:])
+}
+
+func (m *mockObserver) BranchNodeLoaded(hex []byte) {
+ dictKey := common.Bytes2Hex(hex)
+ value := m.reloadedNodes[dictKey]
+ value++
+ m.reloadedNodes[dictKey] = value
+}
+
+func TestObserversBranchNodesCreateDelete(t *testing.T) {
+ trie := newEmpty()
+
+ observer := newMockObserver()
+
+ trie.AddObserver(observer)
+
+ keys := genNKeys(100)
+
+ for _, key := range keys {
+ acc := genAccount()
+ code := genByteArrayOfLen(100)
+ codeHash := crypto.Keccak256(code)
+ acc.CodeHash = common.BytesToHash(codeHash)
+ trie.UpdateAccount(key, acc)
+ trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck
+ }
+
+ expectedNodes := calcSubtreeNodes(trie.root)
+
+ assert.True(t, expectedNodes >= 100, "should register all code nodes")
+ assert.Equal(t, expectedNodes, len(observer.createdNodes), "should register all")
+
+ for _, key := range keys {
+ trie.Delete(key)
+ }
+
+ assert.Equal(t, expectedNodes, len(observer.deletedNodes), "should register all")
+}
+
+func TestObserverCodeSizeChanged(t *testing.T) {
+ rand.Seed(9999)
+
+ trie := newEmpty()
+
+ observer := newMockObserver()
+
+ trie.AddObserver(observer)
+
+ keys := genNKeys(10)
+
+ for _, key := range keys {
+ acc := genAccount()
+ code := genByteArrayOfLen(100)
+ codeHash := crypto.Keccak256(code)
+ acc.CodeHash = common.BytesToHash(codeHash)
+ trie.UpdateAccount(key, acc)
+ trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck
+
+ hex := keybytesToHex(key)
+ hex = hex[:len(hex)-1]
+ newSize, ok := observer.createdNodes[common.Bytes2Hex(hex)]
+ assert.True(t, ok, "account should be registed as created")
+ assert.Equal(t, 100, int(newSize), "account size should increase when the account code grows")
+
+ code2 := genByteArrayOfLen(50)
+ codeHash2 := crypto.Keccak256(code2)
+ acc.CodeHash = common.BytesToHash(codeHash2)
+ trie.UpdateAccount(key, acc)
+ trie.UpdateAccountCode(key, codeNode(code2)) //nolint:errcheck
+
+ newSize2, ok := observer.createdNodes[common.Bytes2Hex(hex)]
+ assert.True(t, ok, "account should be registed as created")
+ assert.Equal(t, -50, int(newSize2)-int(newSize), "account size should decrease when the account code shrinks")
+ }
+}
+
+func TestObserverUnloadStorageNodes(t *testing.T) {
+ rand.Seed(9999)
+
+ trie := newEmpty()
+
+ observer := newMockObserver()
+
+ trie.AddObserver(observer)
+
+ key := genNKeys(1)[0]
+
+ storageKeys := genNKeys(10)
+ // group all storage keys into a single fullNode
+ for i := range storageKeys {
+ storageKeys[i][0] = byte(0)
+ storageKeys[i][1] = byte(i)
+ }
+
+ fullKeys := make([][]byte, len(storageKeys))
+ for i, storageKey := range storageKeys {
+ fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey))
+ fullKeys[i] = fullKey
+ }
+
+ acc := genAccount()
+ trie.UpdateAccount(key, acc)
+
+ for i, fullKey := range fullKeys {
+ trie.Update(fullKey, []byte(fmt.Sprintf("test-value-%d", i)))
+ }
+
+ rootHash := trie.Hash()
+
+ // adding nodes doesn't add anything
+ assert.Equal(t, 0, len(observer.reloadedNodes), "adding nodes doesn't add anything")
+
+ // unloading nodes adds to the list
+
+ hex := keybytesToHex(key)
+ hex = hex[:len(hex)-1]
+ hex = append(hex, 0x0, 0x0, 0x0)
+ trie.EvictNode(hex)
+
+ newRootHash := trie.Hash()
+ assert.Equal(t, rootHash, newRootHash, "root hash shouldn't change")
+
+ assert.Equal(t, 1, len(observer.unloadedNodes), "should unload one full node")
+
+ hex = keybytesToHex(key)
+
+ storageKey := fmt.Sprintf("%s000000", common.Bytes2Hex(hex[:len(hex)-1]))
+ assert.Equal(t, 1, observer.unloadedNodes[storageKey], "should unload structure nodes")
+
+ accNode, ok := trie.getAccount(trie.root, hex, 0)
+ assert.True(t, ok, "account should be found")
+
+ sn, ok := accNode.storage.(*shortNode)
+ assert.True(t, ok, "storage should be the shortnode contaning hash")
+
+ _, ok = sn.Val.(hashNode)
+ assert.True(t, ok, "storage should be the shortnode contaning hash")
+}
+
+func TestObserverLoadNodes(t *testing.T) {
+ rand.Seed(9999)
+
+ subtrie := newEmpty()
+
+ observer := newMockObserver()
+
+ // this test needs a specific trie structure
+ // ( full )
+ // (full) (duo) (duo)
+ // (short)(short)(short) (short)(short) (short)(short)
+ // (acc1) (acc2) (acc3) (acc4) (acc5) (acc6) (acc7)
+ //
+ // to ensure this structure we override prefixes of
+ // random account keys with the follwing paths
+
+ prefixes := [][]byte{
+ {0x00, 0x00}, //acc1
+ {0x00, 0x02}, //acc2
+ {0x00, 0x05}, //acc3
+ {0x02, 0x02}, //acc4
+ {0x02, 0x05}, //acc5
+ {0x0A, 0x00}, //acc6
+ {0x0A, 0x03}, //acc7
+ }
+
+ keys := genNKeys(7)
+ for i := range keys {
+ copy(keys[i][:2], prefixes[i])
+ }
+
+ storageKeys := genNKeys(10)
+ // group all storage keys into a single fullNode
+ for i := range storageKeys {
+ storageKeys[i][0] = byte(0)
+ storageKeys[i][1] = byte(i)
+ }
+
+ for _, key := range keys {
+ acc := genAccount()
+ subtrie.UpdateAccount(key, acc)
+
+ for i, storageKey := range storageKeys {
+ fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey))
+ subtrie.Update(fullKey, []byte(fmt.Sprintf("test-value-%d", i)))
+ }
+ }
+
+ trie := newEmpty()
+ trie.AddObserver(observer)
+
+ trie.hook([]byte{}, subtrie.root)
+
+ // fullNode
+ assert.Equal(t, 1, observer.reloadedNodes["000000"], "should reload structure nodes")
+ // duoNode
+ assert.Equal(t, 1, observer.reloadedNodes["000200"], "should reload structure nodes")
+ // duoNode
+ assert.Equal(t, 1, observer.reloadedNodes["000a00"], "should reload structure nodes")
+ // root
+ assert.Equal(t, 1, observer.reloadedNodes["00"], "should reload structure nodes")
+
+ // check storages (should have a single fullNode per account)
+ for _, key := range keys {
+ hex := keybytesToHex(key)
+ storageKey := fmt.Sprintf("%s000000", common.Bytes2Hex(hex[:len(hex)-1]))
+ assert.Equal(t, 1, observer.reloadedNodes[storageKey], "should reload structure nodes")
+ }
+}
+
+func TestObserverTouches(t *testing.T) {
+ rand.Seed(9999)
+
+ trie := newEmpty()
+
+ observer := newMockObserver()
+
+ trie.AddObserver(observer)
+
+ keys := genNKeys(3)
+ for i := range keys {
+ keys[i][0] = 0x00 // 3 belong to the same branch
+ }
+
+ var acc *accounts.Account
+ for _, key := range keys {
+ // creation touches the account
+ acc = genAccount()
+ trie.UpdateAccount(key, acc)
+ }
+
+ key := keys[0]
+ branchNodeHex := "0000"
+
+ assert.Equal(t, uint(1), observer.createdNodes[branchNodeHex], "node is created")
+ assert.Equal(t, 1, observer.touchedNodes[branchNodeHex])
+
+ // updating touches the account
+ code := genByteArrayOfLen(100)
+ codeHash := crypto.Keccak256(code)
+ acc.CodeHash = common.BytesToHash(codeHash)
+ trie.UpdateAccount(key, acc)
+ assert.Equal(t, 2, observer.touchedNodes[branchNodeHex])
+
+ // updating code touches the account
+ trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck
+ // 2 touches -- retrieve + updae
+ assert.Equal(t, 4, observer.touchedNodes[branchNodeHex])
+
+ // changing storage touches the account
+ storageKey := genNKeys(1)[0]
+ fullKey := dbutils.GenerateCompositeTrieKey(common.BytesToHash(key), common.BytesToHash(storageKey))
+
+ trie.Update(fullKey, []byte("value-1"))
+ assert.Equal(t, 5, observer.touchedNodes[branchNodeHex])
+
+ trie.Update(fullKey, []byte("value-2"))
+ assert.Equal(t, 6, observer.touchedNodes[branchNodeHex])
+
+ // getting storage touches the account
+ _, ok := trie.Get(fullKey)
+ assert.True(t, ok, "should be able to receive storage")
+ assert.Equal(t, 7, observer.touchedNodes[branchNodeHex])
+
+ // deleting storage touches the account
+ trie.Delete(fullKey)
+ assert.Equal(t, 8, observer.touchedNodes[branchNodeHex])
+
+ // getting code touches the account
+ _, ok = trie.GetAccountCode(key)
+ assert.True(t, ok, "should be able to receive code")
+ assert.Equal(t, 9, observer.touchedNodes[branchNodeHex])
+
+ // getting account touches the account
+ _, ok = trie.GetAccount(key)
+ assert.True(t, ok, "should be able to receive account")
+ assert.Equal(t, 10, observer.touchedNodes[branchNodeHex])
+}
+
+func TestObserverMux(t *testing.T) {
+ trie := newEmpty()
+
+ observer1 := newMockObserver()
+ observer2 := newMockObserver()
+ mux := NewTrieObserverMux()
+ mux.AddChild(observer1)
+ mux.AddChild(observer2)
+
+ trie.AddObserver(mux)
+
+ keys := genNKeys(100)
+ for _, key := range keys {
+ acc := genAccount()
+ trie.UpdateAccount(key, acc)
+
+ code := genByteArrayOfLen(100)
+ codeHash := crypto.Keccak256(code)
+ acc.CodeHash = common.BytesToHash(codeHash)
+ trie.UpdateAccount(key, acc)
+ trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck
+
+ _, ok := trie.GetAccount(key)
+ assert.True(t, ok, "acount should be found")
+
+ }
+
+ trie.Hash()
+
+ for i, key := range keys {
+ if i < 80 {
+ trie.Delete(key)
+ } else {
+ hex := keybytesToHex(key)
+ hex = hex[:len(hex)-1]
+ trie.EvictNode(CodeKeyFromAddrHash(hex))
+ }
+ }
+
+ assert.Equal(t, observer1.createdNodes, observer2.createdNodes, "should propagate created events")
+ assert.Equal(t, observer1.deletedNodes, observer2.deletedNodes, "should propagate deleted events")
+ assert.Equal(t, observer1.touchedNodes, observer2.touchedNodes, "should propagate touched events")
+
+ assert.Equal(t, observer1.unloadedNodes, observer2.unloadedNodes, "should propagage unloads")
+ assert.Equal(t, observer1.unloadedNodeHashes, observer2.unloadedNodeHashes, "should propagage unloads")
+ assert.Equal(t, observer1.reloadedNodes, observer2.reloadedNodes, "should propagage reloads")
+}
+
+func TestObserverPartial(t *testing.T) {
+ trie := newEmpty()
+
+ observer := &partialObserver{callbackCalled: false} // only implements `BranchNodeDeleted`
+ trie.AddObserver(observer)
+
+ keys := genNKeys(2)
+ for _, key := range keys {
+ acc := genAccount()
+ trie.UpdateAccount(key, acc)
+
+ code := genByteArrayOfLen(100)
+ codeHash := crypto.Keccak256(code)
+ acc.CodeHash = common.BytesToHash(codeHash)
+ trie.UpdateAccount(key, acc)
+ trie.UpdateAccountCode(key, codeNode(code)) //nolint:errcheck
+
+ }
+ for _, key := range keys {
+ trie.Delete(key)
+ }
+
+ assert.True(t, observer.callbackCalled, "should be called")
+}
diff --git a/trie/trie_pruning.go b/trie/trie_pruning.go
deleted file mode 100644
index f413debef..000000000
--- a/trie/trie_pruning.go
+++ /dev/null
@@ -1,267 +0,0 @@
-// Copyright 2019 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty off
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-// Pruning of the Merkle Patricia trees
-
-package trie
-
-import (
- "fmt"
- "sort"
- "strings"
-
- "github.com/ledgerwatch/turbo-geth/common"
- "github.com/ledgerwatch/turbo-geth/common/pool"
-)
-
-type TriePruning struct {
- accountTimestamps map[string]uint64
-
- // Maps timestamp (uint64) to set of prefixes of nodes (string)
- accounts map[uint64]map[string]struct{}
-
- // For each timestamp, keeps number of branch nodes belonging to it
- generationCounts map[uint64]int
-
- // Keeps total number of branch nodes
- nodeCount int
-
- // The oldest timestamp of all branch nodes
- oldestGeneration uint64
-
- // Current timestamp
- blockNr uint64
-
- createNodeFunc func(prefixAsNibbles []byte)
- unloadNodeFunc func(prefix []byte, nodeHash []byte) // called when fullNode or dualNode unloaded
-}
-
-func NewTriePruning(oldestGeneration uint64) *TriePruning {
- return &TriePruning{
- oldestGeneration: oldestGeneration,
- blockNr: oldestGeneration,
- accountTimestamps: make(map[string]uint64),
- accounts: make(map[uint64]map[string]struct{}),
- generationCounts: make(map[uint64]int),
- createNodeFunc: func([]byte) {},
- }
-}
-
-func (tp *TriePruning) SetBlockNr(blockNr uint64) {
- tp.blockNr = blockNr
-}
-
-func (tp *TriePruning) BlockNr() uint64 {
- return tp.blockNr
-}
-
-func (tp *TriePruning) SetCreateNodeFunc(f func(prefixAsNibbles []byte)) {
- tp.createNodeFunc = f
-}
-
-func (tp *TriePruning) SetUnloadNodeFunc(f func(prefix []byte, nodeHash []byte)) {
- tp.unloadNodeFunc = f
-}
-
-// Updates a node to the current timestamp
-// contract is effectively address of the smart contract
-// hex is the prefix of the key
-// parent is the node that needs to be modified to unload the touched node
-// exists is true when the node existed before, and false if it is a new one
-// prevTimestamp is the timestamp the node current has
-func (tp *TriePruning) touch(hexS string, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) {
- //fmt.Printf("TouchFrom %x, exists: %t, prevTimestamp %d, del %t, newTimestamp %d\n", hexS, exists, prevTimestamp, del, newTimestamp)
- if exists && !del && prevTimestamp == newTimestamp {
- return
- }
- if !del {
- var newMap map[string]struct{}
- if m, ok := tp.accounts[newTimestamp]; ok {
- newMap = m
- } else {
- newMap = make(map[string]struct{})
- tp.accounts[newTimestamp] = newMap
- }
-
- newMap[hexS] = struct{}{}
- }
- if exists {
- if m, ok := tp.accounts[prevTimestamp]; ok {
- delete(m, hexS)
- if len(m) == 0 {
- delete(tp.accounts, prevTimestamp)
- }
- }
- }
- // Update generation count
- if !del {
- tp.generationCounts[newTimestamp]++
- tp.nodeCount++
- }
- if exists {
- tp.generationCounts[prevTimestamp]--
- if tp.generationCounts[prevTimestamp] == 0 {
- delete(tp.generationCounts, prevTimestamp)
- }
- tp.nodeCount--
- }
-}
-
-func (tp *TriePruning) Timestamp(hex []byte) uint64 {
- ts := tp.accountTimestamps[string(hex)]
- return ts
-}
-
-// Updates a node to the current timestamp
-// contract is effectively address of the smart contract
-// hex is the prefix of the key
-// parent is the node that needs to be modified to unload the touched node
-func (tp *TriePruning) Touch(hex []byte, del bool) {
- var exists = false
- var prevTimestamp uint64
- hexS := string(common.CopyBytes(hex))
-
- if m, ok := tp.accountTimestamps[hexS]; ok {
- prevTimestamp = m
- exists = true
- if del {
- delete(tp.accountTimestamps, hexS)
- }
- }
- if !del {
- tp.accountTimestamps[hexS] = tp.blockNr
- }
- if !exists {
- tp.createNodeFunc([]byte(hexS))
- }
-
- tp.touch(hexS, exists, prevTimestamp, del, tp.blockNr)
-}
-
-func pruneMap(t *Trie, m map[string]struct{}) bool {
- hexes := make([]string, len(m))
- i := 0
- for hexS := range m {
- hexes[i] = hexS
- i++
- }
- var empty = false
- sort.Strings(hexes)
-
- for i, hex := range hexes {
- if i == 0 || len(hex) == 0 || !strings.HasPrefix(hex, hexes[i-1]) { // If the parent nodes pruned, there is no need to prune descendants
- t.unload([]byte(hex))
- if len(hex) == 0 {
- empty = true
- }
- }
- }
- return empty
-}
-
-// Prunes all nodes that are older than given timestamp
-func (tp *TriePruning) PruneToTimestamp(
- accountsTrie *Trie,
- targetTimestamp uint64,
-) {
- // Remove (unload) nodes from storage tries and account trie
- aggregateAccounts := make(map[string]struct{})
- for gen := tp.oldestGeneration; gen < targetTimestamp; gen++ {
- tp.nodeCount -= tp.generationCounts[gen]
- if m, ok := tp.accounts[gen]; ok {
- for hexS := range m {
- aggregateAccounts[hexS] = struct{}{}
- }
- }
- delete(tp.accounts, gen)
- }
-
- // intermediate hashes
- key := pool.GetBuffer(64)
- defer pool.PutBuffer(key)
- for prefix := range aggregateAccounts {
- if len(prefix) == 0 || len(prefix)%2 == 1 {
- continue
- }
-
- nd, parent, ok := accountsTrie.getNode([]byte(prefix), false)
- if !ok {
- continue
- }
- switch nd.(type) {
- case *duoNode, *fullNode:
- // will work only with these types of nodes
- default:
- continue
- }
- switch parent.(type) { // without this condition - doesn't work. Need investigate why.
- case *duoNode, *fullNode:
- // will work only with these types of nodes
- CompressNibbles([]byte(prefix), &key.B)
- tp.unloadNodeFunc(key.B, nd.reference())
- default:
- continue
- }
- }
-
- pruneMap(accountsTrie, aggregateAccounts)
-
- // Remove fom the timestamp structure
- for hexS := range aggregateAccounts {
- delete(tp.accountTimestamps, hexS)
- }
- tp.oldestGeneration = targetTimestamp
-}
-
-// Prunes mininum number of generations necessary so that the total
-// number of prunable nodes is at most `targetNodeCount`
-func (tp *TriePruning) PruneTo(
- accountsTrie *Trie,
- targetNodeCount int,
-) bool {
- if tp.nodeCount <= targetNodeCount {
- return false
- }
- excess := tp.nodeCount - targetNodeCount
- prunable := 0
- pruneGeneration := tp.oldestGeneration
- for prunable < excess && pruneGeneration < tp.blockNr {
- prunable += tp.generationCounts[pruneGeneration]
- pruneGeneration++
- }
- //fmt.Printf("Will prune to generation %d, nodes to prune: %d, excess %d\n", pruneGeneration, prunable, excess)
- tp.PruneToTimestamp(accountsTrie, pruneGeneration)
- return true
-}
-
-func (tp *TriePruning) NodeCount() int {
- return tp.nodeCount
-}
-
-func (tp *TriePruning) GenCounts() map[uint64]int {
- return tp.generationCounts
-}
-
-// DebugDump is used in the tests to ensure that there are no prunable entries (in such case, this function returns empty string)
-func (tp *TriePruning) DebugDump() string {
- var sb strings.Builder
- for timestamp, m := range tp.accounts {
- for account := range m {
- sb.WriteString(fmt.Sprintf("%d %x\n", timestamp, account))
- }
- }
- return sb.String()
-}
diff --git a/trie/trie_pruning_test.go b/trie/trie_pruning_test.go
deleted file mode 100644
index a9690892e..000000000
--- a/trie/trie_pruning_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2019 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty off
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-// Pruning of the Merkle Patricia trees
-
-package trie
-
-import (
- "encoding/binary"
- "fmt"
- "testing"
-
- "github.com/ledgerwatch/turbo-geth/common"
-)
-
-func TestOnePerTimestamp(t *testing.T) {
- tp := NewTriePruning(0)
- tr := New(common.Hash{})
- tr.SetTouchFunc(tp.Touch)
- var key [4]byte
- value := []byte("V")
- var timestamp uint64 = 0
- for n := uint32(0); n < uint32(100); n++ {
- tp.SetBlockNr(timestamp)
- binary.BigEndian.PutUint32(key[:], n)
- tr.Update(key[:], value) // Each key is added within a new generation
- timestamp++
- }
- for n := uint32(50); n < uint32(60); n++ {
- tp.SetBlockNr(timestamp)
- binary.BigEndian.PutUint32(key[:], n)
- tr.Delete(key[:]) // Each key is added within a new generation
- timestamp++
- }
- for n := uint32(30); n < uint32(59); n++ {
- tp.SetBlockNr(timestamp)
- binary.BigEndian.PutUint32(key[:], n)
- tr.Get(key[:]) // Each key is added within a new generation
- timestamp++
- }
- prunableNodes := tr.CountPrunableNodes()
- fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount())
- if b := tp.PruneTo(tr, 4); !b {
- t.Fatal("Not pruned")
- }
- prunableNodes = tr.CountPrunableNodes()
- fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount())
-}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 5066b925a..780ecd9e7 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -887,6 +887,6 @@ func TestCodeNodeUpdateAccountNoChangeCodeHash(t *testing.T) {
trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
- assert.Equal(t, value, codeValue1, "the value should NOT reset after account's non codehash had changed")
+ assert.Equal(t, codeValue1, value, "the value should NOT reset after account's non codehash had changed")
assert.True(t, gotValue, "should indicate that the code is still in the cache")
}