Added checkpoints and justification bits processing (post-altair) (#6699)

* Added processing of checkpoints
* Unit tests rigorously imported from prysm
* They all pass :)
This commit is contained in:
Giulio rebuffo 2023-01-26 00:31:20 +01:00 committed by GitHub
parent 20a865b79f
commit ff21ef7b21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 451 additions and 23 deletions

View File

@ -214,6 +214,15 @@ var CheckpointSyncEndpoints = map[NetworkType][]string{
// BeaconChainConfig contains constant configs for node to participate in beacon chain.
type BeaconChainConfig struct {
// Constants (non-configurable)
GenesisSlot uint64 `yaml:"GENESIS_SLOT"` // GenesisSlot represents the first canonical slot number of the beacon chain.
GenesisEpoch uint64 `yaml:"GENESIS_EPOCH"` // GenesisEpoch represents the first canonical epoch number of the beacon chain.
FarFutureEpoch uint64 `yaml:"FAR_FUTURE_EPOCH"` // FarFutureEpoch represents a epoch extremely far away in the future used as the default penalization epoch for validators.
FarFutureSlot uint64 `yaml:"FAR_FUTURE_SLOT"` // FarFutureSlot represents a slot extremely far away in the future.
BaseRewardsPerEpoch uint64 `yaml:"BASE_REWARDS_PER_EPOCH"` // BaseRewardsPerEpoch is used to calculate the per epoch rewards.
DepositContractTreeDepth uint64 `yaml:"DEPOSIT_CONTRACT_TREE_DEPTH"` // DepositContractTreeDepth depth of the Merkle trie of deposits in the validator deposit contract on the PoW chain.
JustificationBitsLength uint64 `yaml:"JUSTIFICATION_BITS_LENGTH"` // JustificationBitsLength defines number of epochs to track when implementing k-finality in Casper FFG.
// Misc constants.
PresetBase string `yaml:"PRESET_BASE" spec:"true"` // PresetBase represents the underlying spec preset this config is based on.
ConfigName string `yaml:"CONFIG_NAME" spec:"true"` // ConfigName for allowing an easy human-readable way of knowing what chain is being used.
@ -441,11 +450,11 @@ func configForkNames(b *BeaconChainConfig) map[[VersionLength]byte]string {
var MainnetBeaconConfig BeaconChainConfig = BeaconChainConfig{
// Constants (Non-configurable)
/*FarFutureEpoch: math.MaxUint64,
FarFutureEpoch: math.MaxUint64,
FarFutureSlot: math.MaxUint64,
BaseRewardsPerEpoch: 4,
DepositContractTreeDepth: 32,*/
GenesisDelay: 604800, // 1 week.
DepositContractTreeDepth: 32,
GenesisDelay: 604800, // 1 week.
// Misc constant.
TargetCommitteeSize: 128,

View File

@ -0,0 +1,35 @@
package cltypes
import "github.com/ledgerwatch/erigon/cl/utils"
const JustificationBitsLength = 4
type JustificationBits [JustificationBitsLength]bool // Bit vector of size 4
func (j JustificationBits) Byte() (out byte) {
for i, bit := range j {
if !bit {
continue
}
out += byte(utils.PowerOf2(uint64(i)))
}
return
}
func (j *JustificationBits) FromByte(b byte) {
j[0] = b&1 > 0
j[1] = b&2 > 0
j[2] = b&4 > 0
j[3] = b&8 > 0
}
// CheckRange checks if bits in certain range are all enabled.
func (j JustificationBits) CheckRange(start int, end int) bool {
checkBits := j[start:end]
for _, bit := range checkBits {
if !bit {
return false
}
}
return true
}

View File

@ -0,0 +1,15 @@
package cltypes_test
import (
"testing"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/stretchr/testify/require"
)
func TestParticipationBits(t *testing.T) {
bits := cltypes.JustificationBits{}
bits.FromByte(2)
require.Equal(t, bits, cltypes.JustificationBits{false, true, false, false})
require.Equal(t, bits.Byte(), byte(2))
}

View File

@ -0,0 +1,34 @@
package cltypes
import (
"github.com/ledgerwatch/erigon/cl/utils"
)
type ParticipationFlags byte
func (f ParticipationFlags) Add(index int) ParticipationFlags {
return f | ParticipationFlags(utils.PowerOf2(uint64(index)))
}
func (f ParticipationFlags) HasFlag(index int) bool {
flag := ParticipationFlags(utils.PowerOf2(uint64(index)))
return f&flag == flag
}
type ParticipationFlagsList []ParticipationFlags
func (p ParticipationFlagsList) Bytes() []byte {
b := make([]byte, len(p))
for i := range p {
b[i] = byte(p[i])
}
return b
}
func ParticipationFlagsListFromBytes(buf []byte) ParticipationFlagsList {
flagsList := make([]ParticipationFlags, len(buf))
for i := range flagsList {
flagsList[i] = ParticipationFlags(buf[i])
}
return flagsList
}

View File

@ -0,0 +1,14 @@
package cltypes_test
import (
"testing"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/stretchr/testify/require"
)
func TestParticipationFlags(t *testing.T) {
flagsList := cltypes.ParticipationFlagsListFromBytes([]byte{0, 0, 0, 0})
flagsList[0] = flagsList[0].Add(4) // Turn on fourth bit
require.True(t, flagsList[0].HasFlag(4))
}

View File

@ -338,3 +338,8 @@ func (v *Validator) HashSSZ() ([32]byte, error) {
leaves[7] = merkle_tree.Uint64Root(v.WithdrawableEpoch)
return merkle_tree.ArraysRoot(leaves, 8)
}
// Active returns if validator is active for given epoch
func (v *Validator) Active(epoch uint64) bool {
return v.ActivationEpoch <= epoch && epoch < v.ExitEpoch
}

View File

@ -40,7 +40,7 @@ func RetrieveBeaconState(ctx context.Context, beaconConfig *clparams.BeaconChain
epoch := utils.GetCurrentEpoch(genesisConfig.GenesisTime, beaconConfig.SecondsPerSlot, beaconConfig.SlotsPerEpoch)
beaconState := &state.BeaconState{}
beaconState := state.New(beaconConfig)
fmt.Println(int(beaconConfig.GetCurrentStateVersion(epoch)))
err = beaconState.DecodeSSZWithVersion(marshaled, int(beaconConfig.GetCurrentStateVersion(epoch)))
if err != nil {

View File

@ -0,0 +1,98 @@
package state
import (
"fmt"
libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/cltypes"
)
// GetActiveValidatorsIndices returns the list of validator indices active for the given epoch.
func (b *BeaconState) GetActiveValidatorsIndices(epoch uint64) (indicies []uint64) {
for i, validator := range b.validators {
if !validator.Active(epoch) {
continue
}
indicies = append(indicies, uint64(i))
}
return
}
// Epoch returns current epoch.
func (b *BeaconState) Epoch() uint64 {
return b.slot / b.beaconConfig.SlotsPerEpoch // Return current state epoch
}
// PreviousEpoch returns previous epoch.
func (b *BeaconState) PreviousEpoch() uint64 {
epoch := b.Epoch()
if epoch == 0 {
return epoch
}
return epoch - 1
}
// getUnslashedParticipatingIndices returns set of currently unslashed participating indexes
func (b *BeaconState) GetUnslashedParticipatingIndices(flagIndex int, epoch uint64) (validatorSet []uint64, err error) {
var participation cltypes.ParticipationFlagsList
// Must be either previous or current epoch
switch epoch {
case b.Epoch():
participation = b.currentEpochParticipation
case b.PreviousEpoch():
participation = b.previousEpochParticipation
default:
return nil, fmt.Errorf("getUnslashedParticipatingIndices: only epoch and previous epoch can be used")
}
// Iterate over all validators and include the active ones that have flag_index enabled and are not slashed.
for i, validator := range b.Validators() {
if !validator.Active(epoch) ||
!participation[i].HasFlag(flagIndex) ||
validator.Slashed {
continue
}
validatorSet = append(validatorSet, uint64(i))
}
return
}
// GetTotalBalance return the sum of all balances within the given validator set.
func (b *BeaconState) GetTotalBalance(validatorSet []uint64) (uint64, error) {
var (
total uint64
validatorsSize = uint64(len(b.validators))
)
for _, validatorIndex := range validatorSet {
// Should be in bounds.
if validatorIndex >= validatorsSize {
return 0, fmt.Errorf("GetTotalBalance: out of bounds validator index")
}
total += b.validators[validatorIndex].EffectiveBalance
}
// Always minimum set to EffectiveBalanceIncrement
if total < b.beaconConfig.EffectiveBalanceIncrement {
total = b.beaconConfig.EffectiveBalanceIncrement
}
return total, nil
}
// GetTotalActiveBalance return the sum of all balances within active validators.
func (b *BeaconState) GetTotalActiveBalance() (uint64, error) {
return b.GetTotalBalance(b.GetActiveValidatorsIndices(b.Epoch()))
}
// GetBlockRoot returns blook root at start of a given epoch
func (b *BeaconState) GetBlockRoot(epoch uint64) (libcommon.Hash, error) {
return b.GetBlockRootAtSlot(epoch * b.beaconConfig.SlotsPerEpoch)
}
// GetBlockRootAtSlot returns the block root at a given slot
func (b *BeaconState) GetBlockRootAtSlot(slot uint64) (libcommon.Hash, error) {
if slot >= b.slot {
return libcommon.Hash{}, fmt.Errorf("GetBlockRootAtSlot: slot in the future")
}
if b.slot > slot+b.beaconConfig.SlotsPerHistoricalRoot {
return libcommon.Hash{}, fmt.Errorf("GetBlockRootAtSlot: slot too much far behind")
}
return b.blockRoots[slot%b.beaconConfig.SlotsPerHistoricalRoot], nil
}

View File

@ -0,0 +1,50 @@
package state_test
import (
"testing"
"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state"
"github.com/stretchr/testify/require"
)
func TestActiveValidatorIndices(t *testing.T) {
epoch := uint64(2)
testState := state.GetEmptyBeaconState()
// Not Active validator
testState.AddValidator(&cltypes.Validator{
ActivationEpoch: 3,
ExitEpoch: 9,
EffectiveBalance: 2e9,
})
// Active Validator
testState.AddValidator(&cltypes.Validator{
ActivationEpoch: 1,
ExitEpoch: 9,
EffectiveBalance: 2e9,
})
testState.SetSlot(epoch * 32) // Epoch
testFlags := cltypes.ParticipationFlagsListFromBytes([]byte{1, 1})
testState.SetCurrentEpochParticipation(testFlags)
// Only validator at index 1 (second validator) is active.
require.Equal(t, testState.GetActiveValidatorsIndices(epoch), []uint64{1})
set, err := testState.GetUnslashedParticipatingIndices(0x00, epoch)
require.NoError(t, err)
require.Equal(t, set, []uint64{1})
// Check if balances are retrieved correctly
totalBalance, err := testState.GetTotalActiveBalance()
require.NoError(t, err)
require.Equal(t, totalBalance, uint64(2e9))
}
func TestGetBlockRoot(t *testing.T) {
epoch := uint64(2)
testState := state.GetEmptyBeaconState()
root := common.HexToHash("ff")
testState.SetSlot(100)
testState.SetBlockRootAt(int(epoch*32), root)
retrieved, err := testState.GetBlockRoot(epoch)
require.NoError(t, err)
require.Equal(t, retrieved, root)
}

View File

@ -3,6 +3,7 @@ package state
import (
libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/core/types"
)
@ -77,15 +78,15 @@ func (b *BeaconState) SlashingSegmentAt(pos int) uint64 {
return b.slashings[pos]
}
func (b *BeaconState) PreviousEpochParticipation() []byte {
func (b *BeaconState) PreviousEpochParticipation() cltypes.ParticipationFlagsList {
return b.previousEpochParticipation
}
func (b *BeaconState) CurrentEpochParticipation() []byte {
func (b *BeaconState) CurrentEpochParticipation() cltypes.ParticipationFlagsList {
return b.currentEpochParticipation
}
func (b *BeaconState) JustificationBits() byte {
func (b *BeaconState) JustificationBits() cltypes.JustificationBits {
return b.justificationBits
}
@ -124,3 +125,7 @@ func (b *BeaconState) NextWithdrawalValidatorIndex() uint64 {
func (b *BeaconState) HistoricalSummaries() []*cltypes.HistoricalSummary {
return b.historicalSummaries
}
func (b *BeaconState) Version() clparams.StateVersion {
return b.version
}

View File

@ -140,7 +140,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
}
// Field(15): PreviousEpochParticipation
if b.isLeafDirty(PreviousEpochParticipationLeafIndex) {
participationRoot, err := merkle_tree.BitlistRootWithLimitForState(b.previousEpochParticipation, state_encoding.ValidatorRegistryLimit)
participationRoot, err := merkle_tree.BitlistRootWithLimitForState(b.previousEpochParticipation.Bytes(), state_encoding.ValidatorRegistryLimit)
if err != nil {
return err
}
@ -149,7 +149,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
// Field(16): CurrentEpochParticipation
if b.isLeafDirty(CurrentEpochParticipationLeafIndex) {
participationRoot, err := merkle_tree.BitlistRootWithLimitForState(b.currentEpochParticipation, state_encoding.ValidatorRegistryLimit)
participationRoot, err := merkle_tree.BitlistRootWithLimitForState(b.currentEpochParticipation.Bytes(), state_encoding.ValidatorRegistryLimit)
if err != nil {
return err
}
@ -159,7 +159,7 @@ func (b *BeaconState) computeDirtyLeaves() error {
// Field(17): JustificationBits
if b.isLeafDirty(JustificationBitsLeafIndex) {
var root [32]byte
root[0] = b.justificationBits
root[0] = b.justificationBits.Byte()
b.updateLeaf(JustificationBitsLeafIndex, root)
}

View File

@ -100,17 +100,17 @@ func (b *BeaconState) SetSlashingSegmentAt(index int, segment uint64) {
b.slashings[index] = segment
}
func (b *BeaconState) SetPreviousEpochParticipation(previousEpochParticipation []byte) {
func (b *BeaconState) SetPreviousEpochParticipation(previousEpochParticipation []cltypes.ParticipationFlags) {
b.touchedLeaves[PreviousEpochParticipationLeafIndex] = true
b.previousEpochParticipation = previousEpochParticipation
}
func (b *BeaconState) SetCurrentEpochParticipation(currentEpochParticipation []byte) {
func (b *BeaconState) SetCurrentEpochParticipation(currentEpochParticipation []cltypes.ParticipationFlags) {
b.touchedLeaves[CurrentEpochParticipationLeafIndex] = true
b.currentEpochParticipation = currentEpochParticipation
}
func (b *BeaconState) SetJustificationBits(justificationBits byte) {
func (b *BeaconState) SetJustificationBits(justificationBits cltypes.JustificationBits) {
b.touchedLeaves[JustificationBitsLeafIndex] = true
b.justificationBits = justificationBits
}

View File

@ -121,7 +121,7 @@ func (b *BeaconState) EncodeSSZ(buf []byte) ([]byte, error) {
dst = append(dst, ssz_utils.OffsetSSZ(offset)...)
offset += uint32(len(b.currentEpochParticipation))
dst = append(dst, b.justificationBits)
dst = append(dst, b.justificationBits.Byte())
// Checkpoints
if dst, err = b.previousJustifiedCheckpoint.EncodeSSZ(dst); err != nil {
@ -179,8 +179,8 @@ func (b *BeaconState) EncodeSSZ(buf []byte) ([]byte, error) {
}
// Write participations (offset 4 & 5)
dst = append(dst, b.previousEpochParticipation...)
dst = append(dst, b.currentEpochParticipation...)
dst = append(dst, b.previousEpochParticipation.Bytes()...)
dst = append(dst, b.currentEpochParticipation.Bytes()...)
// write inactivity scores (offset 6)
for _, score := range b.inactivityScores {
@ -277,7 +277,7 @@ func (b *BeaconState) DecodeSSZWithVersion(buf []byte, version int) error {
currentEpochParticipationOffset := ssz_utils.DecodeOffset(buf[pos:])
pos += 4
// just take that one smol byte
b.justificationBits = buf[pos]
b.justificationBits.FromByte(buf[pos])
pos++
// Decode checkpoints
b.previousJustifiedCheckpoint = new(cltypes.Checkpoint)
@ -338,12 +338,14 @@ func (b *BeaconState) DecodeSSZWithVersion(buf []byte, version int) error {
if b.balances, err = ssz_utils.DecodeNumbersList(buf, balancesOffset, previousEpochParticipationOffset, state_encoding.ValidatorRegistryLimit); err != nil {
return err
}
if b.previousEpochParticipation, err = ssz_utils.DecodeString(buf, uint64(previousEpochParticipationOffset), uint64(currentEpochParticipationOffset), state_encoding.ValidatorRegistryLimit); err != nil {
var previousEpochParticipation, currentEpochParticipation []byte
if previousEpochParticipation, err = ssz_utils.DecodeString(buf, uint64(previousEpochParticipationOffset), uint64(currentEpochParticipationOffset), state_encoding.ValidatorRegistryLimit); err != nil {
return err
}
if b.currentEpochParticipation, err = ssz_utils.DecodeString(buf, uint64(currentEpochParticipationOffset), uint64(inactivityScoresOffset), state_encoding.ValidatorRegistryLimit); err != nil {
if currentEpochParticipation, err = ssz_utils.DecodeString(buf, uint64(currentEpochParticipationOffset), uint64(inactivityScoresOffset), state_encoding.ValidatorRegistryLimit); err != nil {
return err
}
b.previousEpochParticipation, b.currentEpochParticipation = cltypes.ParticipationFlagsListFromBytes(previousEpochParticipation), cltypes.ParticipationFlagsListFromBytes(currentEpochParticipation)
endOffset := uint32(len(buf))
if executionPayloadOffset != 0 {
endOffset = executionPayloadOffset

View File

@ -34,9 +34,9 @@ type BeaconState struct {
balances []uint64
randaoMixes [randoMixesLength]libcommon.Hash
slashings [slashingsLength]uint64
previousEpochParticipation []byte
currentEpochParticipation []byte
justificationBits byte
previousEpochParticipation cltypes.ParticipationFlagsList
currentEpochParticipation cltypes.ParticipationFlagsList
justificationBits cltypes.JustificationBits
// Altair
previousJustifiedCheckpoint *cltypes.Checkpoint
currentJustifiedCheckpoint *cltypes.Checkpoint
@ -54,6 +54,16 @@ type BeaconState struct {
version clparams.StateVersion // State version
leaves [32][32]byte // Pre-computed leaves.
touchedLeaves map[StateLeafIndex]bool // Maps each leaf to whether they were touched or not.
// Configs
beaconConfig *clparams.BeaconChainConfig
}
func New(cfg *clparams.BeaconChainConfig) *BeaconState {
state := &BeaconState{
beaconConfig: cfg,
}
state.initBeaconState()
return state
}
func preparateRootsForHashing(roots []libcommon.Hash) [][32]byte {

View File

@ -26,7 +26,8 @@ func GetEmptyBeaconState() *BeaconState {
BaseFee: big.NewInt(0),
Number: big.NewInt(0),
},
version: clparams.BellatrixVersion,
version: clparams.BellatrixVersion,
beaconConfig: &clparams.MainnetBeaconConfig,
}
b.initBeaconState()
return b

View File

@ -0,0 +1,101 @@
package transition
import (
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cl/cltypes"
)
// weighJustificationAndFinalization checks justification and finality of epochs and adds records to the state as needed.
func (s *StateTransistor) weighJustificationAndFinalization(totalActiveBalance, previousEpochTargetBalance, currentEpochTargetBalance uint64) error {
currentEpoch := s.state.Epoch()
previousEpoch := s.state.PreviousEpoch()
oldPreviousJustifiedCheckpoint := s.state.PreviousJustifiedCheckpoint()
oldCurrentJustifiedCheckpoint := s.state.CurrentJustifiedCheckpoint()
justificationBits := s.state.JustificationBits()
// Process justification
s.state.SetPreviousJustifiedCheckpoint(oldCurrentJustifiedCheckpoint)
// Discard oldest bit
copy(justificationBits[1:], justificationBits[:3])
// Turn off current justification bit
justificationBits[0] = false
// Update justified checkpoint if super majority is reached on previous epoch
if previousEpochTargetBalance*3 >= totalActiveBalance*2 {
checkPointRoot, err := s.state.GetBlockRoot(previousEpoch)
if err != nil {
return err
}
s.state.SetCurrentJustifiedCheckpoint(&cltypes.Checkpoint{
Epoch: previousEpoch,
Root: checkPointRoot,
})
justificationBits[1] = true
}
if currentEpochTargetBalance*3 >= totalActiveBalance*2 {
checkPointRoot, err := s.state.GetBlockRoot(currentEpoch)
if err != nil {
return err
}
s.state.SetCurrentJustifiedCheckpoint(&cltypes.Checkpoint{
Epoch: currentEpoch,
Root: checkPointRoot,
})
justificationBits[0] = true
}
// Process finalization
// The 2nd/3rd/4th most recent epochs are justified, the 2nd using the 4th as source
// The 2nd/3rd most recent epochs are justified, the 2nd using the 3rd as source
if (justificationBits.CheckRange(1, 4) && oldPreviousJustifiedCheckpoint.Epoch+3 == currentEpoch) ||
(justificationBits.CheckRange(1, 3) && oldPreviousJustifiedCheckpoint.Epoch+2 == currentEpoch) {
s.state.SetFinalizedCheckpoint(oldPreviousJustifiedCheckpoint)
}
// The 1st/2nd/3rd most recent epochs are justified, the 1st using the 3rd as source
// The 1st/2nd most recent epochs are justified, the 1st using the 2nd as source
if (justificationBits.CheckRange(0, 3) && oldCurrentJustifiedCheckpoint.Epoch+2 == currentEpoch) ||
(justificationBits.CheckRange(0, 2) && oldCurrentJustifiedCheckpoint.Epoch+1 == currentEpoch) {
s.state.SetFinalizedCheckpoint(oldCurrentJustifiedCheckpoint)
}
// Write justification bits
s.state.SetJustificationBits(justificationBits)
return nil
}
func (s *StateTransistor) ProcessJustificationBitsAndFinality() error {
if s.state.Version() == clparams.Phase0Version {
return s.processJustificationBitsAndFinalityPreAltair()
}
return s.processJustificationBitsAndFinalityAltair()
}
func (s *StateTransistor) processJustificationBitsAndFinalityPreAltair() error {
panic("NOT IMPLEMENTED. STOOOOOP")
}
func (s *StateTransistor) processJustificationBitsAndFinalityAltair() error {
currentEpoch := s.state.Epoch()
previousEpoch := s.state.PreviousEpoch()
// Skip for first 2 epochs
if currentEpoch <= s.beaconConfig.GenesisEpoch+1 {
return nil
}
previousIndices, err := s.state.GetUnslashedParticipatingIndices(int(s.beaconConfig.TimelyTargetFlagIndex), previousEpoch)
if err != nil {
return err
}
currentIndices, err := s.state.GetUnslashedParticipatingIndices(int(s.beaconConfig.TimelyTargetFlagIndex), currentEpoch)
if err != nil {
return err
}
totalActiveBalance, err := s.state.GetTotalActiveBalance()
if err != nil {
return err
}
previousTargetBalance, err := s.state.GetTotalBalance(previousIndices)
if err != nil {
return err
}
currentTargetBalance, err := s.state.GetTotalBalance(currentIndices)
if err != nil {
return err
}
return s.weighJustificationAndFinalization(totalActiveBalance, previousTargetBalance, currentTargetBalance)
}

View File

@ -0,0 +1,49 @@
package transition_test
import (
"testing"
libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state"
"github.com/ledgerwatch/erigon/cmd/erigon-cl/core/transition"
"github.com/stretchr/testify/require"
)
func getJustificationAndFinalizationState() *state.BeaconState {
cfg := clparams.MainnetBeaconConfig
epoch := cfg.FarFutureEpoch
bal := cfg.MaxEffectiveBalance
state := state.GetEmptyBeaconState()
blockRoots := make([][32]byte, cfg.SlotsPerEpoch*2+1)
for i := 0; i < len(blockRoots); i++ {
blockRoots[i][0] = byte(i)
state.SetBlockRootAt(i, blockRoots[i])
}
state.SetSlot(cfg.SlotsPerEpoch*2 + 1)
bits := cltypes.JustificationBits{}
bits.FromByte(0x3)
state.SetJustificationBits(bits)
state.SetValidators([]*cltypes.Validator{
{ExitEpoch: epoch}, {ExitEpoch: epoch}, {ExitEpoch: epoch}, {ExitEpoch: epoch},
})
state.SetBalances([]uint64{bal, bal, bal, bal})
state.SetCurrentEpochParticipation(cltypes.ParticipationFlagsList{0b01, 0b01, 0b01, 0b01})
state.SetPreviousEpochParticipation(cltypes.ParticipationFlagsList{0b01, 0b01, 0b01, 0b01})
return state
}
func TestProcessJustificationAndFinalizationJustifyCurrentEpoch(t *testing.T) {
cfg := clparams.MainnetBeaconConfig
testState := getJustificationAndFinalizationState()
transitioner := transition.New(testState, &cfg, nil)
transitioner.ProcessJustificationBitsAndFinality()
rt := libcommon.Hash{byte(64)}
require.Equal(t, rt, testState.CurrentJustifiedCheckpoint().Root, "Unexpected current justified root")
require.Equal(t, uint64(2), testState.CurrentJustifiedCheckpoint().Epoch, "Unexpected justified epoch")
require.Equal(t, uint64(0), testState.PreviousJustifiedCheckpoint().Epoch, "Unexpected previous justified epoch")
require.Equal(t, libcommon.Hash{}, testState.FinalizedCheckpoint().Root)
require.Equal(t, uint64(0), testState.FinalizedCheckpoint().Epoch, "Unexpected finalized epoch")
}