optimized state roots cache for erigon-cl (#7020)

This commit is contained in:
Giulio rebuffo 2023-03-04 17:28:20 +01:00 committed by GitHub
parent e48609ee8d
commit 90ed3c1cb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 40 deletions

View File

@ -10,30 +10,43 @@ import (
)
func (b *BeaconState) HashSSZ() ([32]byte, error) {
if err := b.computeDirtyLeaves(); err != nil {
var err error
if err = b.computeDirtyLeaves(); err != nil {
return [32]byte{}, err
}
if b.cachedStateRoot != (libcommon.Hash{}) {
return b.cachedStateRoot, nil
}
// Pad to 32 of length
return merkle_tree.MerkleRootFromLeaves(b.leaves[:])
b.cachedStateRoot, err = merkle_tree.MerkleRootFromLeaves(b.leaves[:])
return b.cachedStateRoot, err
}
func (b *BeaconState) OptimisticallySetStateRoot(root libcommon.Hash) {
b.cachedStateRoot = root
for index := range b.touchedLeaves {
b.touchedLeaves[index] = false
}
}
func (b *BeaconState) computeDirtyLeaves() error {
// Update all dirty leafs
// ----
// Field(0): GenesisTime
if b.isLeafDirty(GenesisTimeLeafIndex) {
b.updateLeaf(GenesisTimeLeafIndex, merkle_tree.Uint64Root(b.genesisTime))
b.cachedStateRoot = libcommon.Hash{}
}
// Field(1): GenesisValidatorsRoot
if b.isLeafDirty(GenesisValidatorsRootLeafIndex) {
b.updateLeaf(GenesisValidatorsRootLeafIndex, b.genesisValidatorsRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(2): Slot
if b.isLeafDirty(SlotLeafIndex) {
b.updateLeaf(SlotLeafIndex, merkle_tree.Uint64Root(b.slot))
b.cachedStateRoot = libcommon.Hash{}
}
// Field(3): Fork
@ -43,6 +56,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(ForkLeafIndex, forkRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(4): LatestBlockHeader
@ -52,6 +66,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(LatestBlockHeaderLeafIndex, headerRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(5): BlockRoots
@ -61,6 +76,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(BlockRootsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(6): StateRoots
@ -70,6 +86,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(StateRootsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
@ -80,6 +97,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(HistoricalRootsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(8): Eth1Data
@ -89,6 +107,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(Eth1DataLeafIndex, dataRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(9): Eth1DataVotes
@ -98,11 +117,13 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(Eth1DataVotesLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(10): Eth1DepositIndex
if b.isLeafDirty(Eth1DepositIndexLeafIndex) {
b.updateLeaf(Eth1DepositIndexLeafIndex, merkle_tree.Uint64Root(b.eth1DepositIndex))
b.cachedStateRoot = libcommon.Hash{}
}
// Field(11): Validators
@ -112,6 +133,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(ValidatorsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(12): Balances
@ -121,7 +143,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(BalancesLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(13): RandaoMixes
@ -131,6 +153,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(RandaoMixesLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(14): Slashings
@ -140,7 +163,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(SlashingsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(15): PreviousEpochParticipation
if b.isLeafDirty(PreviousEpochParticipationLeafIndex) {
@ -149,7 +172,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(PreviousEpochParticipationLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(16): CurrentEpochParticipation
@ -159,6 +182,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(CurrentEpochParticipationLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(17): JustificationBits
@ -166,6 +190,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
var root [32]byte
root[0] = b.justificationBits.Byte()
b.updateLeaf(JustificationBitsLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(18): PreviousJustifiedCheckpoint
@ -175,6 +200,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(PreviousJustifiedCheckpointLeafIndex, checkpointRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(19): CurrentJustifiedCheckpoint
@ -184,6 +210,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(CurrentJustifiedCheckpointLeafIndex, checkpointRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(20): FinalizedCheckpoint
@ -193,6 +220,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(FinalizedCheckpointLeafIndex, checkpointRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(21): Inactivity Scores
@ -202,6 +230,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(InactivityScoresLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(22): CurrentSyncCommitte
@ -211,6 +240,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(CurrentSyncCommitteeLeafIndex, committeeRoot)
b.cachedStateRoot = libcommon.Hash{}
}
// Field(23): NextSyncCommitte
@ -220,6 +250,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(NextSyncCommitteeLeafIndex, committeeRoot)
b.cachedStateRoot = libcommon.Hash{}
}
if b.version < clparams.BellatrixVersion {
@ -232,6 +263,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(LatestExecutionPayloadHeaderLeafIndex, headerRoot)
b.cachedStateRoot = libcommon.Hash{}
}
if b.version >= clparams.CapellaVersion {
@ -239,11 +271,13 @@ func (b *BeaconState) computeDirtyLeaves() error {
// Field(25): NextWithdrawalIndex
if b.isLeafDirty(NextWithdrawalIndexLeafIndex) {
b.updateLeaf(NextWithdrawalIndexLeafIndex, merkle_tree.Uint64Root(b.nextWithdrawalIndex))
b.cachedStateRoot = libcommon.Hash{}
}
// Field(26): NextWithdrawalValidatorIndex
if b.isLeafDirty(NextWithdrawalValidatorIndexLeafIndex) {
b.updateLeaf(NextWithdrawalValidatorIndexLeafIndex, merkle_tree.Uint64Root(b.nextWithdrawalValidatorIndex))
b.cachedStateRoot = libcommon.Hash{}
}
// Field(27): HistoricalSummaries
@ -253,6 +287,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
return err
}
b.updateLeaf(HistoricalSummariesLeafIndex, root)
b.cachedStateRoot = libcommon.Hash{}
}
}

View File

@ -14,8 +14,8 @@ func BenchmarkStateRootNonCached(b *testing.B) {
}
}
// Prev: 13953
// Curr: 2093
// Prev: 1400
// Curr: 139.4
func BenchmarkStateRootCached(b *testing.B) {
// Re-use same fields
base := state.GetEmptyBeaconState()

View File

@ -67,6 +67,7 @@ type BeaconState struct {
totalActiveBalanceCache *uint64
totalActiveBalanceRootCache uint64
proposerIndex *uint64
cachedStateRoot libcommon.Hash
// Configs
beaconConfig *clparams.BeaconChainConfig
}

View File

@ -5,7 +5,6 @@ import (
"time"
"github.com/Giulio2002/bls"
libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cl/fork"
"github.com/ledgerwatch/erigon/cl/merkle_tree"
@ -43,27 +42,20 @@ func (s *StateTransistor) TransitionState(block *cltypes.SignedBeaconBlock) erro
return fmt.Errorf("expected state root differs from received state root")
}
}
// Write the block root to the cache
s.stateRootsCache.Add(block.Block.Slot, block.Block.StateRoot)
s.state.OptimisticallySetStateRoot(block.Block.StateRoot)
return nil
}
// transitionSlot is called each time there is a new slot to process
func (s *StateTransistor) transitionSlot() error {
slot := s.state.Slot()
var (
previousStateRoot libcommon.Hash
err error
)
if previousStateRootI, ok := s.stateRootsCache.Get(slot); ok {
previousStateRoot = previousStateRootI.(libcommon.Hash)
} else {
previousStateRoot, err = s.state.HashSSZ()
if err != nil {
return err
}
previousStateRoot, err := s.state.HashSSZ()
if err != nil {
return err
}
s.state.SetStateRootAt(int(slot%s.beaconConfig.SlotsPerHistoricalRoot), previousStateRoot)
latestBlockHeader := s.state.LatestBlockHeader()

View File

@ -1,33 +1,21 @@
package transition
import (
lru "github.com/hashicorp/golang-lru"
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state"
)
// StateTransistor takes care of state transition
type StateTransistor struct {
state *state.BeaconState
beaconConfig *clparams.BeaconChainConfig
genesisConfig *clparams.GenesisConfig
noValidate bool // Whether we want to do cryptography checks.
// stateRootsCache caches slot => stateRoot
stateRootsCache *lru.Cache
state *state.BeaconState
beaconConfig *clparams.BeaconChainConfig
noValidate bool // Whether we want to do cryptography checks.
}
const stateRootsCacheSize = 256
func New(state *state.BeaconState, beaconConfig *clparams.BeaconChainConfig, genesisConfig *clparams.GenesisConfig, noValidate bool) *StateTransistor {
stateRootsCache, err := lru.New(stateRootsCacheSize)
if err != nil {
panic(err)
}
return &StateTransistor{
state: state,
beaconConfig: beaconConfig,
genesisConfig: genesisConfig,
noValidate: noValidate,
stateRootsCache: stateRootsCache,
state: state,
beaconConfig: beaconConfig,
noValidate: noValidate,
}
}