diff --git a/beacon-chain/sync/error.go b/beacon-chain/sync/error.go index 5681bcf2e..a5833ba81 100644 --- a/beacon-chain/sync/error.go +++ b/beacon-chain/sync/error.go @@ -15,6 +15,7 @@ const stepError = "invalid range or step" var errWrongForkDigestVersion = errors.New("wrong fork digest version") var errInvalidEpoch = errors.New("invalid epoch") var errInvalidFinalizedRoot = errors.New("invalid finalized root") +var errGeneric = errors.New(genericError) var responseCodeSuccess = byte(0x00) var responseCodeInvalidRequest = byte(0x01) diff --git a/beacon-chain/sync/rpc_goodbye.go b/beacon-chain/sync/rpc_goodbye.go index d0c282cb6..5fa3c1c33 100644 --- a/beacon-chain/sync/rpc_goodbye.go +++ b/beacon-chain/sync/rpc_goodbye.go @@ -8,6 +8,7 @@ import ( libp2pcore "github.com/libp2p/go-libp2p-core" "github.com/libp2p/go-libp2p-core/peer" "github.com/prysmaticlabs/prysm/beacon-chain/p2p" + "github.com/sirupsen/logrus" ) const ( @@ -43,6 +44,22 @@ func (r *Service) goodbyeRPCHandler(ctx context.Context, msg interface{}, stream return r.p2p.Disconnect(stream.Conn().RemotePeer()) } +func (r *Service) sendGoodByeAndDisconnect(ctx context.Context, code uint64, id peer.ID) error { + if err := r.sendGoodByeMessage(ctx, code, id); err != nil { + log.WithFields(logrus.Fields{ + "error": err, + "peer": id, + }).Debug("Could not send goodbye message to peer") + } + // Add a short delay to allow the stream to flush before closing the connection. + // There is still a chance that the peer won't receive the message. + time.Sleep(50 * time.Millisecond) + if err := r.p2p.Disconnect(id); err != nil { + return err + } + return nil +} + func (r *Service) sendGoodByeMessage(ctx context.Context, code uint64, id peer.ID) error { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() diff --git a/beacon-chain/sync/rpc_goodbye_test.go b/beacon-chain/sync/rpc_goodbye_test.go index ae22be530..f853f9b37 100644 --- a/beacon-chain/sync/rpc_goodbye_test.go +++ b/beacon-chain/sync/rpc_goodbye_test.go @@ -103,3 +103,50 @@ func TestSendGoodbye_SendsMessage(t *testing.T) { t.Error("Peer is still not disconnected despite sending a goodbye message") } } + +func TestSendGoodbye_DisconnectWithPeer(t *testing.T) { + p1 := p2ptest.NewTestP2P(t) + p2 := p2ptest.NewTestP2P(t) + p1.Connect(p2) + if len(p1.Host.Network().Peers()) != 1 { + t.Error("Expected peers to be connected") + } + + // Set up a head state in the database with data we expect. + d := db.SetupDB(t) + r := &Service{ + db: d, + p2p: p1, + } + failureCode := codeClientShutdown + + // Setup streams + pcl := protocol.ID("/eth2/beacon_chain/req/goodbye/1/ssz") + var wg sync.WaitGroup + wg.Add(1) + p2.Host.SetStreamHandler(pcl, func(stream network.Stream) { + defer wg.Done() + out := new(uint64) + if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil { + t.Fatal(err) + } + if *out != failureCode { + t.Fatalf("Wanted goodbye code of %d but got %d", failureCode, *out) + } + + }) + + err := r.sendGoodByeAndDisconnect(context.Background(), failureCode, p2.Host.ID()) + if err != nil { + t.Errorf("Unxpected error: %v", err) + } + conns := p1.Host.Network().ConnsToPeer(p2.Host.ID()) + if len(conns) > 0 { + t.Error("Peer is still not disconnected despite sending a goodbye message") + } + + if testutil.WaitTimeout(&wg, 1*time.Second) { + t.Fatal("Did not receive stream within 1 sec") + } + +} diff --git a/beacon-chain/sync/rpc_status.go b/beacon-chain/sync/rpc_status.go index d550139f4..6801e882a 100644 --- a/beacon-chain/sync/rpc_status.go +++ b/beacon-chain/sync/rpc_status.go @@ -151,11 +151,27 @@ func (r *Service) statusRPCHandler(ctx context.Context, msg interface{}, stream } if err := r.validateStatusMessage(ctx, m, stream); err != nil { - log.WithField("peer", stream.Conn().RemotePeer()).Debug("Invalid fork version from peer") + log.WithFields(logrus.Fields{ + "peer": stream.Conn().RemotePeer(), + "error": err}).Debug("Invalid status message from peer") + respCode := byte(0) - switch err.Error() { - case genericError: + switch err { + case errGeneric: respCode = responseCodeServerError + case errWrongForkDigestVersion: + // Respond with our status and disconnect with the peer. + r.p2p.Peers().SetChainState(stream.Conn().RemotePeer(), m) + if err := r.respondWithStatus(ctx, stream); err != nil { + return err + } + if err := stream.Close(); err != nil { // Close before disconnecting. + log.WithError(err).Error("Failed to close stream") + } + if err := r.sendGoodByeAndDisconnect(ctx, codeWrongNetwork, stream.Conn().RemotePeer()); err != nil { + return err + } + return nil default: respCode = responseCodeInvalidRequest r.p2p.Peers().IncrementBadResponses(stream.Conn().RemotePeer()) @@ -184,6 +200,10 @@ func (r *Service) statusRPCHandler(ctx context.Context, msg interface{}, stream } r.p2p.Peers().SetChainState(stream.Conn().RemotePeer(), m) + return r.respondWithStatus(ctx, stream) +} + +func (r *Service) respondWithStatus(ctx context.Context, stream network.Stream) error { headRoot, err := r.chain.HeadRoot(ctx) if err != nil { return err @@ -205,7 +225,6 @@ func (r *Service) statusRPCHandler(ctx context.Context, msg interface{}, stream log.WithError(err).Error("Failed to write to stream") } _, err = r.p2p.Encoding().EncodeWithLength(stream, resp) - return err } @@ -239,10 +258,10 @@ func (r *Service) validateStatusMessage(ctx context.Context, msg *pb.Status, str } blk, err := r.db.Block(ctx, bytesutil.ToBytes32(msg.FinalizedRoot)) if err != nil { - return errors.New(genericError) + return errGeneric } if blk == nil { - return errors.New(genericError) + return errGeneric } // TODO(#5827) Verify the finalized block with the epoch in the // status message diff --git a/beacon-chain/sync/rpc_status_test.go b/beacon-chain/sync/rpc_status_test.go index 2c6624ace..880f24566 100644 --- a/beacon-chain/sync/rpc_status_test.go +++ b/beacon-chain/sync/rpc_status_test.go @@ -1,6 +1,7 @@ package sync import ( + "bytes" "context" "sync" "testing" @@ -35,9 +36,18 @@ func TestHelloRPCHandler_Disconnects_OnForkVersionMismatch(t *testing.T) { if len(p1.Host.Network().Peers()) != 1 { t.Error("Expected peers to be connected") } + root := [32]byte{'C'} r := &Service{p2p: p1, chain: &mock.ChainService{ + Fork: &pb.Fork{ + PreviousVersion: params.BeaconConfig().GenesisForkVersion, + CurrentVersion: params.BeaconConfig().GenesisForkVersion, + }, + FinalizedCheckPoint: ðpb.Checkpoint{ + Epoch: 0, + Root: root[:], + }, Genesis: time.Now(), ValidatorsRoot: [32]byte{'A'}, }} @@ -47,17 +57,29 @@ func TestHelloRPCHandler_Disconnects_OnForkVersionMismatch(t *testing.T) { wg.Add(1) p2.Host.SetStreamHandler(pcl, func(stream network.Stream) { defer wg.Done() - code, errMsg, err := ReadStatusCode(stream, p1.Encoding()) - if err != nil { + expectSuccess(t, r, stream) + out := &pb.Status{} + if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil { t.Fatal(err) } - if code == 0 { - t.Error("Expected a non-zero code") + if !bytes.Equal(out.FinalizedRoot, root[:]) { + t.Errorf("Expected finalized root of %#x but got %#x", root, out.FinalizedRoot) } - if errMsg != errWrongForkDigestVersion.Error() { - t.Logf("Received error string len %d, wanted error string len %d", len(errMsg), len(errWrongForkDigestVersion.Error())) - t.Errorf("Received unexpected message response in the stream: %s. Wanted %s.", errMsg, errWrongForkDigestVersion.Error()) + }) + + pcl2 := protocol.ID("/eth2/beacon_chain/req/goodbye/1/ssz") + var wg2 sync.WaitGroup + wg2.Add(1) + p2.Host.SetStreamHandler(pcl2, func(stream network.Stream) { + defer wg2.Done() + msg := new(uint64) + if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil { + t.Error(err) } + if *msg != codeWrongNetwork { + t.Errorf("Wrong goodbye code: %d", *msg) + } + }) stream1, err := p1.Host.NewStream(context.Background(), p2.Host.ID(), pcl) @@ -66,13 +88,16 @@ func TestHelloRPCHandler_Disconnects_OnForkVersionMismatch(t *testing.T) { } err = r.statusRPCHandler(context.Background(), &pb.Status{ForkDigest: []byte("fake")}, stream1) - if err != errWrongForkDigestVersion { - t.Errorf("Expected error %v, got %v", errWrongForkDigestVersion, err) + if err != nil { + t.Errorf("Expected no error but got %v", err) } if testutil.WaitTimeout(&wg, 1*time.Second) { t.Fatal("Did not receive stream within 1 sec") } + if testutil.WaitTimeout(&wg2, 1*time.Second) { + t.Fatal("Did not receive stream within 1 sec") + } if len(p1.Host.Network().Peers()) != 0 { t.Error("handler did not disconnect peer")