diff --git a/beacon-chain/blockchain/receive_attestation.go b/beacon-chain/blockchain/receive_attestation.go index 5fb743cc3..bbf96e0c5 100644 --- a/beacon-chain/blockchain/receive_attestation.go +++ b/beacon-chain/blockchain/receive_attestation.go @@ -9,6 +9,7 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/core/blocks" "github.com/prysmaticlabs/prysm/beacon-chain/core/feed" "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" + "github.com/prysmaticlabs/prysm/beacon-chain/state" "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/featureconfig" "github.com/prysmaticlabs/prysm/shared/params" @@ -22,6 +23,7 @@ import ( type AttestationReceiver interface { ReceiveAttestationNoPubsub(ctx context.Context, att *ethpb.Attestation) error IsValidAttestation(ctx context.Context, att *ethpb.Attestation) bool + AttestationPreState(ctx context.Context, att *ethpb.Attestation) (*state.BeaconState, error) } // ReceiveAttestationNoPubsub is a function that defines the operations that are preformed on @@ -58,7 +60,7 @@ func (s *Service) ReceiveAttestationNoPubsub(ctx context.Context, att *ethpb.Att // IsValidAttestation returns true if the attestation can be verified against its pre-state. func (s *Service) IsValidAttestation(ctx context.Context, att *ethpb.Attestation) bool { - baseState, err := s.getAttPreState(ctx, att.Data.Target) + baseState, err := s.AttestationPreState(ctx, att) if err != nil { log.WithError(err).Error("Failed to get attestation pre state") return false @@ -72,6 +74,11 @@ func (s *Service) IsValidAttestation(ctx context.Context, att *ethpb.Attestation return true } +// AttestationPreState returns the pre state of attestation. +func (s *Service) AttestationPreState(ctx context.Context, att *ethpb.Attestation) (*state.BeaconState, error) { + return s.getAttPreState(ctx, att.Data.Target) +} + // This processes attestations from the attestation pool to account for validator votes and fork choice. func (s *Service) processAttestation(subscribedToStateEvents chan struct{}) { // Wait for state to be initialized. diff --git a/beacon-chain/blockchain/testing/mock.go b/beacon-chain/blockchain/testing/mock.go index 9d80be356..d8d83e9da 100644 --- a/beacon-chain/blockchain/testing/mock.go +++ b/beacon-chain/blockchain/testing/mock.go @@ -202,6 +202,11 @@ func (ms *ChainService) ReceiveAttestationNoPubsub(context.Context, *ethpb.Attes return nil } +// AttestationPreState mocks AttestationPreState method in chain service. +func (ms *ChainService) AttestationPreState(ctx context.Context, att *ethpb.Attestation) (*stateTrie.BeaconState, error) { + return ms.State, nil +} + // HeadValidatorsIndices mocks the same method in the chain service. func (ms *ChainService) HeadValidatorsIndices(epoch uint64) ([]uint64, error) { if ms.State == nil { diff --git a/beacon-chain/sync/validate_aggregate_proof.go b/beacon-chain/sync/validate_aggregate_proof.go index d3427727a..2ed7b27c5 100644 --- a/beacon-chain/sync/validate_aggregate_proof.go +++ b/beacon-chain/sync/validate_aggregate_proof.go @@ -91,7 +91,7 @@ func (r *Service) validateAggregatedAtt(ctx context.Context, signed *ethpb.Signe return false } - s, err := r.chain.HeadState(ctx) + s, err := r.chain.AttestationPreState(ctx, signed.Message.Aggregate) if err != nil { traceutil.AnnotateError(span, err) return false