From d845eb2a6935883a92fb7c6f7a6b2821fb97091c Mon Sep 17 00:00:00 2001 From: battlmonstr Date: Mon, 16 May 2022 15:28:03 +0200 Subject: [PATCH] sentry: refactor eth handshake checks (#4163) --- cmd/sentry/sentry/eth_handshake.go | 83 +++++++++++++++++++++++++ cmd/sentry/sentry/eth_handshake_test.go | 78 +++++++++++++++++++++++ cmd/sentry/sentry/sentry_grpc_server.go | 58 +++-------------- 3 files changed, 169 insertions(+), 50 deletions(-) create mode 100644 cmd/sentry/sentry/eth_handshake.go create mode 100644 cmd/sentry/sentry/eth_handshake_test.go diff --git a/cmd/sentry/sentry/eth_handshake.go b/cmd/sentry/sentry/eth_handshake.go new file mode 100644 index 000000000..acda7a8c4 --- /dev/null +++ b/cmd/sentry/sentry/eth_handshake.go @@ -0,0 +1,83 @@ +package sentry + +import ( + "fmt" + + "github.com/ledgerwatch/erigon-lib/gointerfaces" + proto_sentry "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" + "github.com/ledgerwatch/erigon/core/forkid" + "github.com/ledgerwatch/erigon/eth/protocols/eth" + "github.com/ledgerwatch/erigon/p2p" +) + +func readAndValidatePeerStatusMessage( + rw p2p.MsgReadWriter, + status *proto_sentry.StatusData, + version uint, + minVersion uint, +) (*eth.StatusPacket, error) { + msg, err := rw.ReadMsg() + if err != nil { + return nil, err + } + + reply, err := tryDecodeStatusMessage(&msg) + msg.Discard() + if err != nil { + return nil, err + } + + err = checkPeerStatusCompatibility(reply, status, version, minVersion) + return reply, err +} + +func tryDecodeStatusMessage(msg *p2p.Msg) (*eth.StatusPacket, error) { + if msg.Code != eth.StatusMsg { + return nil, fmt.Errorf("first msg has code %x (!= %x)", msg.Code, eth.StatusMsg) + } + + if msg.Size > eth.ProtocolMaxMsgSize { + return nil, fmt.Errorf("message is too large %d, limit %d", msg.Size, eth.ProtocolMaxMsgSize) + } + + var reply eth.StatusPacket + if err := msg.Decode(&reply); err != nil { + return nil, fmt.Errorf("decode message %v: %w", msg, err) + } + + return &reply, nil +} + +func checkPeerStatusCompatibility( + reply *eth.StatusPacket, + status *proto_sentry.StatusData, + version uint, + minVersion uint, +) error { + networkID := status.NetworkId + if reply.NetworkID != networkID { + return fmt.Errorf("network id does not match: theirs %d, ours %d", reply.NetworkID, networkID) + } + + if uint(reply.ProtocolVersion) > version { + return fmt.Errorf("version is more than what this senty supports: theirs %d, max %d", reply.ProtocolVersion, version) + } + if uint(reply.ProtocolVersion) < minVersion { + return fmt.Errorf("version is less than allowed minimum: theirs %d, min %d", reply.ProtocolVersion, minVersion) + } + + genesisHash := gointerfaces.ConvertH256ToHash(status.ForkData.Genesis) + if reply.Genesis != genesisHash { + return fmt.Errorf("genesis hash does not match: theirs %x, ours %x", reply.Genesis, genesisHash) + } + + forks := make([]uint64, len(status.ForkData.Forks)) + // copy because forkid.NewFilterFromForks will write into this slice + copy(forks, status.ForkData.Forks) + forkFilter := forkid.NewFilterFromForks(forks, genesisHash, status.MaxBlock) + if err := forkFilter(reply.ForkID); err != nil { + return err + } + + return nil +} diff --git a/cmd/sentry/sentry/eth_handshake_test.go b/cmd/sentry/sentry/eth_handshake_test.go new file mode 100644 index 000000000..eeec146f3 --- /dev/null +++ b/cmd/sentry/sentry/eth_handshake_test.go @@ -0,0 +1,78 @@ +package sentry + +import ( + "math/big" + "testing" + + "github.com/holiman/uint256" + "github.com/ledgerwatch/erigon-lib/gointerfaces" + proto_sentry "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" + "github.com/ledgerwatch/erigon/common" + "github.com/ledgerwatch/erigon/core/forkid" + "github.com/ledgerwatch/erigon/eth/protocols/eth" + "github.com/ledgerwatch/erigon/params" + "github.com/stretchr/testify/assert" +) + +func TestCheckPeerStatusCompatibility(t *testing.T) { + var version uint = eth.ETH66 + networkID := params.MainnetChainConfig.ChainID.Uint64() + goodReply := eth.StatusPacket{ + ProtocolVersion: uint32(version), + NetworkID: networkID, + TD: big.NewInt(0), + Head: common.Hash{}, + Genesis: params.MainnetGenesisHash, + ForkID: forkid.NewID(params.MainnetChainConfig, params.MainnetGenesisHash, 0), + } + status := proto_sentry.StatusData{ + NetworkId: networkID, + TotalDifficulty: gointerfaces.ConvertUint256IntToH256(new(uint256.Int)), + BestHash: nil, + ForkData: &proto_sentry.Forks{ + Genesis: gointerfaces.ConvertHashToH256(params.MainnetGenesisHash), + Forks: forkid.GatherForks(params.MainnetChainConfig), + }, + MaxBlock: 0, + } + + t.Run("ok", func(t *testing.T) { + err := checkPeerStatusCompatibility(&goodReply, &status, version, version) + assert.Nil(t, err) + }) + t.Run("network mismatch", func(t *testing.T) { + reply := goodReply + reply.NetworkID = 0 + err := checkPeerStatusCompatibility(&reply, &status, version, version) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "network") + }) + t.Run("version mismatch min", func(t *testing.T) { + reply := goodReply + reply.ProtocolVersion = eth.ETH66 - 1 + err := checkPeerStatusCompatibility(&reply, &status, version, version) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "version is less") + }) + t.Run("version mismatch max", func(t *testing.T) { + reply := goodReply + reply.ProtocolVersion = eth.ETH66 + 1 + err := checkPeerStatusCompatibility(&reply, &status, version, version) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "version is more") + }) + t.Run("genesis mismatch", func(t *testing.T) { + reply := goodReply + reply.Genesis = common.Hash{} + err := checkPeerStatusCompatibility(&reply, &status, version, version) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "genesis") + }) + t.Run("fork mismatch", func(t *testing.T) { + reply := goodReply + reply.ForkID = forkid.ID{} + err := checkPeerStatusCompatibility(&reply, &status, version, version) + assert.NotNil(t, err) + assert.ErrorIs(t, err, forkid.ErrLocalIncompatibleOrStale) + }) +} diff --git a/cmd/sentry/sentry/sentry_grpc_server.go b/cmd/sentry/sentry/sentry_grpc_server.go index 541d35f71..5fe9e2612 100644 --- a/cmd/sentry/sentry/sentry_grpc_server.go +++ b/cmd/sentry/sentry/sentry_grpc_server.go @@ -211,6 +211,7 @@ func handShake( ourTD := gointerfaces.ConvertH256ToUint256Int(status.TotalDifficulty) // Convert proto status data into the one required by devp2p genesisHash := gointerfaces.ConvertH256ToHash(status.ForkData.Genesis) + go func() { defer debug.LogPanic() s := ð.StatusPacket{ @@ -223,58 +224,15 @@ func handShake( } errc <- p2p.Send(rw, eth.StatusMsg, s) }() - var readStatus = func() error { - forks := make([]uint64, len(status.ForkData.Forks)) // copy because forkid.NewFilterFromForks will write into this slice - copy(forks, status.ForkData.Forks) - forkFilter := forkid.NewFilterFromForks(forks, genesisHash, status.MaxBlock) - networkID := status.NetworkId - // Read handshake message - msg, err1 := rw.ReadMsg() - if err1 != nil { - return err1 - } - if msg.Code != eth.StatusMsg { - msg.Discard() - return fmt.Errorf("first msg has code %x (!= %x)", msg.Code, eth.StatusMsg) - } - if msg.Size > eth.ProtocolMaxMsgSize { - msg.Discard() - return fmt.Errorf("message is too large %d, limit %d", msg.Size, eth.ProtocolMaxMsgSize) - } - // Decode the handshake and make sure everything matches - var reply eth.StatusPacket - if err1 = msg.Decode(&reply); err1 != nil { - msg.Discard() - return fmt.Errorf("decode message %v: %w", msg, err1) - } - msg.Discard() - if reply.NetworkID != networkID { - return fmt.Errorf("network id does not match: theirs %d, ours %d", reply.NetworkID, networkID) - } - if uint(reply.ProtocolVersion) < minVersion { - return fmt.Errorf("version is less than allowed minimum: theirs %d, min %d", reply.ProtocolVersion, minVersion) - } - if uint(reply.ProtocolVersion) > version { - return fmt.Errorf("version is more than what this senty supports: theirs %d, max %d", reply.ProtocolVersion, version) - } - if reply.Genesis != genesisHash { - return fmt.Errorf("genesis hash does not match: theirs %x, ours %x", reply.Genesis, genesisHash) - } - if err1 = forkFilter(reply.ForkID); err1 != nil { - return fmt.Errorf("%w", err1) - } - - if startSync != nil { - if err := startSync(reply.Head); err != nil { - return err - } - } - - return nil - } go func() { - errc <- readStatus() + reply, err := readAndValidatePeerStatusMessage(rw, status, version, minVersion) + + if (err == nil) && (startSync != nil) { + err = startSync(reply.Head) + } + + errc <- err }() timeout := time.NewTimer(handshakeTimeout)