diff --git a/beacon-chain/rpc/beacon_chain_server.go b/beacon-chain/rpc/beacon_chain_server.go index 4be359515..c1d3c58b0 100644 --- a/beacon-chain/rpc/beacon_chain_server.go +++ b/beacon-chain/rpc/beacon_chain_server.go @@ -238,11 +238,9 @@ func (bs *BeaconChainServer) GetChainHead(ctx context.Context, _ *ptypes.Empty) }, nil } -// ListValidatorBalances retrieves the validator balances for a given set of public key at -// a specific epoch in time. -// -// TODO(#3064): Implement balances for a specific epoch. Current implementation returns latest balances, -// this is blocked by DB refactor. +// ListValidatorBalances retrieves the validator balances for a given set of public keys. +// An optional Epoch parameter is provided to request historical validator balances from +// archived, persistent data. func (bs *BeaconChainServer) ListValidatorBalances( ctx context.Context, req *ethpb.GetValidatorBalancesRequest) (*ethpb.ValidatorBalances, error) { @@ -250,12 +248,33 @@ func (bs *BeaconChainServer) ListValidatorBalances( res := make([]*ethpb.ValidatorBalances_Balance, 0, len(req.PublicKeys)+len(req.Indices)) filtered := map[uint64]bool{} // track filtered validators to prevent duplication in the response. - headState, err := bs.beaconDB.HeadState(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "could not retrieve head state: %v", err) + var requestingGenesis bool + var epoch uint64 + switch q := req.QueryFilter.(type) { + case *ethpb.GetValidatorBalancesRequest_Epoch: + epoch = q.Epoch + case *ethpb.GetValidatorBalancesRequest_Genesis: + requestingGenesis = q.Genesis + default: } - balances := headState.Balances + + var balances []uint64 + var err error + headState := bs.headFetcher.HeadState() validators := headState.Validators + if requestingGenesis { + balances, err = bs.beaconDB.ArchivedBalances(ctx, 0 /* genesis epoch */) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "could not retrieve balances for epoch %d", epoch) + } + } else if !requestingGenesis && epoch < helpers.CurrentEpoch(headState) { + balances, err = bs.beaconDB.ArchivedBalances(ctx, epoch) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "could not retrieve balances for epoch %d", epoch) + } + } else { + balances = headState.Balances + } for _, pubKey := range req.PublicKeys { // Skip empty public key @@ -274,7 +293,7 @@ func (bs *BeaconChainServer) ListValidatorBalances( filtered[index] = true if int(index) >= len(balances) { - return nil, status.Errorf(codes.InvalidArgument, "validator index %d >= balance list %d", + return nil, status.Errorf(codes.OutOfRange, "validator index %d >= balance list %d", index, len(balances)) } @@ -287,7 +306,11 @@ func (bs *BeaconChainServer) ListValidatorBalances( for _, index := range req.Indices { if int(index) >= len(balances) { - return nil, status.Errorf(codes.InvalidArgument, "validator index %d >= balance list %d", + if epoch <= helpers.CurrentEpoch(headState) { + return nil, status.Errorf(codes.OutOfRange, "validator index %d does not exist in historical balances", + index) + } + return nil, status.Errorf(codes.OutOfRange, "validator index %d >= balance list %d", index, len(balances)) } diff --git a/beacon-chain/rpc/beacon_chain_server_test.go b/beacon-chain/rpc/beacon_chain_server_test.go index 23237a884..0cd2d5cf6 100644 --- a/beacon-chain/rpc/beacon_chain_server_test.go +++ b/beacon-chain/rpc/beacon_chain_server_test.go @@ -396,8 +396,14 @@ func TestBeaconChainServer_ListValidatorBalances(t *testing.T) { setupValidators(t, db, 100) + headState, err := db.HeadState(context.Background()) + if err != nil { + t.Fatal(err) + } + bs := &BeaconChainServer{ - beaconDB: db, + beaconDB: db, + headFetcher: &mock.ChainService{State: headState}, } tests := []struct { @@ -447,19 +453,105 @@ func TestBeaconChainServer_ListValidatorBalances(t *testing.T) { func TestBeaconChainServer_ListValidatorBalancesOutOfRange(t *testing.T) { db := dbTest.SetupDB(t) defer dbTest.TeardownDB(t, db) - _, balances := setupValidators(t, db, 1) + setupValidators(t, db, 1) + + headState, err := db.HeadState(context.Background()) + if err != nil { + t.Fatal(err) + } bs := &BeaconChainServer{ - beaconDB: db, + beaconDB: db, + headFetcher: &mock.ChainService{State: headState}, } req := ðpb.GetValidatorBalancesRequest{Indices: []uint64{uint64(1)}} - wanted := fmt.Sprintf("validator index %d >= balance list %d", 1, len(balances)) + wanted := "does not exist" if _, err := bs.ListValidatorBalances(context.Background(), req); !strings.Contains(err.Error(), wanted) { t.Errorf("Expected error %v, received %v", wanted, err) } } +func TestBeaconChainServer_ListValidatorBalancesFromArchive(t *testing.T) { + db := dbTest.SetupDB(t) + defer dbTest.TeardownDB(t, db) + ctx := context.Background() + epoch := uint64(0) + validators, balances := setupValidators(t, db, 100) + + if err := db.SaveArchivedBalances(ctx, epoch, balances); err != nil { + t.Fatal(err) + } + + newerBalances := make([]uint64, len(balances)) + for i := 0; i < len(newerBalances); i++ { + newerBalances[i] = balances[i] * 2 + } + bs := &BeaconChainServer{ + beaconDB: db, + headFetcher: &mock.ChainService{ + State: &pbp2p.BeaconState{ + Slot: params.BeaconConfig().SlotsPerEpoch * 3, + Validators: validators, + Balances: newerBalances, + }, + }, + } + + req := ðpb.GetValidatorBalancesRequest{ + QueryFilter: ðpb.GetValidatorBalancesRequest_Epoch{Epoch: 0}, + Indices: []uint64{uint64(1)}, + } + res, err := bs.ListValidatorBalances(context.Background(), req) + if err != nil { + t.Fatal(err) + } + // We should expect a response containing the old balance from epoch 0, + // not the new balance from the current state. + want := []*ethpb.ValidatorBalances_Balance{ + { + PublicKey: validators[1].PublicKey, + Index: 1, + Balance: balances[1], + }, + } + if !reflect.DeepEqual(want, res.Balances) { + t.Errorf("Wanted %v, received %v", want, res.Balances) + } +} + +func TestBeaconChainServer_ListValidatorBalancesFromArchive_NewValidatorNotFound(t *testing.T) { + db := dbTest.SetupDB(t) + defer dbTest.TeardownDB(t, db) + ctx := context.Background() + epoch := uint64(0) + _, balances := setupValidators(t, db, 100) + + if err := db.SaveArchivedBalances(ctx, epoch, balances); err != nil { + t.Fatal(err) + } + + newValidators, newBalances := setupValidators(t, db, 200) + bs := &BeaconChainServer{ + beaconDB: db, + headFetcher: &mock.ChainService{ + State: &pbp2p.BeaconState{ + Slot: params.BeaconConfig().SlotsPerEpoch * 3, + Validators: newValidators, + Balances: newBalances, + }, + }, + } + + req := ðpb.GetValidatorBalancesRequest{ + QueryFilter: ðpb.GetValidatorBalancesRequest_Epoch{Epoch: 0}, + Indices: []uint64{1, 150, 161}, + } + if _, err := bs.ListValidatorBalances(context.Background(), req); !strings.Contains(err.Error(), "does not exist") { + t.Errorf("Wanted out of range error for including newer validators in the arguments, received %v", err) + } +} + func TestBeaconChainServer_GetValidatorsNoPagination(t *testing.T) { db := dbTest.SetupDB(t) defer dbTest.TeardownDB(t, db)