diff --git a/beacon-chain/state/stategen/cold.go b/beacon-chain/state/stategen/cold.go index a8d5de949..f14f934a4 100644 --- a/beacon-chain/state/stategen/cold.go +++ b/beacon-chain/state/stategen/cold.go @@ -97,6 +97,52 @@ func (s *State) loadColdIntermediateStateByRoot(ctx context.Context, slot uint64 return s.ReplayBlocks(ctx, lowArchivedPointState, replayBlks, slot) } +// This loads a cold state by slot where the slot lies between the archived point. +// This is a slower implementation given there's no root and slot is the only argument. It requires fetching +// all the blocks between the archival points. +func (s *State) loadColdIntermediateStateBySlot(ctx context.Context, slot uint64) (*state.BeaconState, error) { + ctx, span := trace.StartSpan(ctx, "stateGen.loadColdIntermediateStateBySlot") + defer span.End() + + // Load the archive point for lower and high side of the intermediate state. + lowArchivedPointIdx := slot / s.slotsPerArchivedPoint + highArchivedPointIdx := lowArchivedPointIdx + 1 + + lowArchivedPointState, err := s.archivedPointByIndex(ctx, lowArchivedPointIdx) + if err != nil { + return nil, errors.Wrap(err, "could not get lower bound archived state using index") + } + if lowArchivedPointState == nil { + return nil, errUnknownArchivedState + } + + // If the slot of the high archived point lies outside of the split slot, use the split slot and root + // for the upper archived point. + var highArchivedPointRoot [32]byte + highArchivedPointSlot := highArchivedPointIdx * s.slotsPerArchivedPoint + if highArchivedPointSlot >= s.splitInfo.slot { + highArchivedPointRoot = s.splitInfo.root + highArchivedPointSlot = s.splitInfo.slot + } else { + if _, err := s.archivedPointByIndex(ctx, highArchivedPointSlot); err != nil { + return nil, errors.Wrap(err, "could not get upper bound archived state using index") + } + highArchivedPointRoot = s.beaconDB.ArchivedPointRoot(ctx, highArchivedPointIdx) + slot, err := s.blockRootSlot(ctx, highArchivedPointRoot) + if err != nil { + return nil, errors.Wrap(err, "could not get high archived point slot") + } + highArchivedPointSlot = slot + } + + replayBlks, err := s.LoadBlocks(ctx, lowArchivedPointState.Slot()+1, highArchivedPointSlot, highArchivedPointRoot) + if err != nil { + return nil, errors.Wrap(err, "could not load block for cold state using slot") + } + + return s.ReplayBlocks(ctx, lowArchivedPointState, replayBlks, slot) +} + // Given the archive index, this returns the archived cold state in the DB. // If the archived state does not exist in the state, it'll compute it and save it. func (s *State) archivedPointByIndex(ctx context.Context, archiveIndex uint64) (*state.BeaconState, error) { diff --git a/beacon-chain/state/stategen/cold_test.go b/beacon-chain/state/stategen/cold_test.go index def96a04d..2bc8c3188 100644 --- a/beacon-chain/state/stategen/cold_test.go +++ b/beacon-chain/state/stategen/cold_test.go @@ -2,6 +2,7 @@ package stategen import ( "context" + "strings" "testing" "github.com/gogo/protobuf/proto" @@ -9,6 +10,7 @@ import ( "github.com/prysmaticlabs/go-ssz" testDB "github.com/prysmaticlabs/prysm/beacon-chain/db/testing" pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" + "github.com/prysmaticlabs/prysm/shared/params" "github.com/prysmaticlabs/prysm/shared/testutil" ) @@ -132,6 +134,79 @@ func TestLoadColdStateByRoot_IntermediatePlayback(t *testing.T) { } } +func TestLoadColdStateBySlotIntermediatePlayback_BeforeCutoff(t *testing.T) { + ctx := context.Background() + db := testDB.SetupDB(t) + defer testDB.TeardownDB(t, db) + + service := New(db) + service.slotsPerArchivedPoint = params.BeaconConfig().SlotsPerEpoch * 2 + + beaconState, _ := testutil.DeterministicGenesisState(t, 32) + if err := service.beaconDB.SaveArchivedPointState(ctx, beaconState, 0); err != nil { + t.Fatal(err) + } + if err := service.beaconDB.SaveArchivedPointRoot(ctx, [32]byte{}, 0); err != nil { + t.Fatal(err) + } + if err := service.beaconDB.SaveArchivedPointState(ctx, beaconState, 1); err != nil { + t.Fatal(err) + } + if err := service.beaconDB.SaveArchivedPointRoot(ctx, [32]byte{}, 1); err != nil { + t.Fatal(err) + } + + slot := uint64(20) + loadedState, err := service.loadColdIntermediateStateBySlot(ctx, slot) + if err != nil { + t.Fatal(err) + } + if loadedState.Slot() != slot { + t.Error("Did not correctly save state") + } +} + +func TestLoadColdStateBySlotIntermediatePlayback_AfterCutoff(t *testing.T) { + ctx := context.Background() + db := testDB.SetupDB(t) + defer testDB.TeardownDB(t, db) + + service := New(db) + service.slotsPerArchivedPoint = params.BeaconConfig().SlotsPerEpoch + + beaconState, _ := testutil.DeterministicGenesisState(t, 32) + if err := service.beaconDB.SaveArchivedPointState(ctx, beaconState, 0); err != nil { + t.Fatal(err) + } + if err := service.beaconDB.SaveArchivedPointRoot(ctx, [32]byte{}, 0); err != nil { + t.Fatal(err) + } + + slot := uint64(10) + loadedState, err := service.loadColdIntermediateStateBySlot(ctx, slot) + if err != nil { + t.Fatal(err) + } + if loadedState.Slot() != slot { + t.Error("Did not correctly save state") + } +} + + +func TestLoadColdStateByRoot_UnknownArchivedState(t *testing.T) { + ctx := context.Background() + db := testDB.SetupDB(t) + defer testDB.TeardownDB(t, db) + + service := New(db) + service.slotsPerArchivedPoint = 1 + if _, err := service.loadColdIntermediateStateBySlot(ctx, 0); + !strings.Contains(err.Error(), errUnknownArchivedState.Error()) { + t.Log(err) + t.Error("Did not get wanted error") + } +} + func TestArchivedPointByIndex_HasPoint(t *testing.T) { ctx := context.Background() db := testDB.SetupDB(t)