sentry: refactor eth handshake checks (#4163)

This commit is contained in:
battlmonstr 2022-05-16 15:28:03 +02:00 committed by GitHub
parent 2fd2826b85
commit d845eb2a69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 169 additions and 50 deletions

View File

@ -0,0 +1,83 @@
package sentry
import (
"fmt"
"github.com/ledgerwatch/erigon-lib/gointerfaces"
proto_sentry "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon/core/forkid"
"github.com/ledgerwatch/erigon/eth/protocols/eth"
"github.com/ledgerwatch/erigon/p2p"
)
func readAndValidatePeerStatusMessage(
rw p2p.MsgReadWriter,
status *proto_sentry.StatusData,
version uint,
minVersion uint,
) (*eth.StatusPacket, error) {
msg, err := rw.ReadMsg()
if err != nil {
return nil, err
}
reply, err := tryDecodeStatusMessage(&msg)
msg.Discard()
if err != nil {
return nil, err
}
err = checkPeerStatusCompatibility(reply, status, version, minVersion)
return reply, err
}
func tryDecodeStatusMessage(msg *p2p.Msg) (*eth.StatusPacket, error) {
if msg.Code != eth.StatusMsg {
return nil, fmt.Errorf("first msg has code %x (!= %x)", msg.Code, eth.StatusMsg)
}
if msg.Size > eth.ProtocolMaxMsgSize {
return nil, fmt.Errorf("message is too large %d, limit %d", msg.Size, eth.ProtocolMaxMsgSize)
}
var reply eth.StatusPacket
if err := msg.Decode(&reply); err != nil {
return nil, fmt.Errorf("decode message %v: %w", msg, err)
}
return &reply, nil
}
func checkPeerStatusCompatibility(
reply *eth.StatusPacket,
status *proto_sentry.StatusData,
version uint,
minVersion uint,
) error {
networkID := status.NetworkId
if reply.NetworkID != networkID {
return fmt.Errorf("network id does not match: theirs %d, ours %d", reply.NetworkID, networkID)
}
if uint(reply.ProtocolVersion) > version {
return fmt.Errorf("version is more than what this senty supports: theirs %d, max %d", reply.ProtocolVersion, version)
}
if uint(reply.ProtocolVersion) < minVersion {
return fmt.Errorf("version is less than allowed minimum: theirs %d, min %d", reply.ProtocolVersion, minVersion)
}
genesisHash := gointerfaces.ConvertH256ToHash(status.ForkData.Genesis)
if reply.Genesis != genesisHash {
return fmt.Errorf("genesis hash does not match: theirs %x, ours %x", reply.Genesis, genesisHash)
}
forks := make([]uint64, len(status.ForkData.Forks))
// copy because forkid.NewFilterFromForks will write into this slice
copy(forks, status.ForkData.Forks)
forkFilter := forkid.NewFilterFromForks(forks, genesisHash, status.MaxBlock)
if err := forkFilter(reply.ForkID); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,78 @@
package sentry
import (
"math/big"
"testing"
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon-lib/gointerfaces"
proto_sentry "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon/common"
"github.com/ledgerwatch/erigon/core/forkid"
"github.com/ledgerwatch/erigon/eth/protocols/eth"
"github.com/ledgerwatch/erigon/params"
"github.com/stretchr/testify/assert"
)
func TestCheckPeerStatusCompatibility(t *testing.T) {
var version uint = eth.ETH66
networkID := params.MainnetChainConfig.ChainID.Uint64()
goodReply := eth.StatusPacket{
ProtocolVersion: uint32(version),
NetworkID: networkID,
TD: big.NewInt(0),
Head: common.Hash{},
Genesis: params.MainnetGenesisHash,
ForkID: forkid.NewID(params.MainnetChainConfig, params.MainnetGenesisHash, 0),
}
status := proto_sentry.StatusData{
NetworkId: networkID,
TotalDifficulty: gointerfaces.ConvertUint256IntToH256(new(uint256.Int)),
BestHash: nil,
ForkData: &proto_sentry.Forks{
Genesis: gointerfaces.ConvertHashToH256(params.MainnetGenesisHash),
Forks: forkid.GatherForks(params.MainnetChainConfig),
},
MaxBlock: 0,
}
t.Run("ok", func(t *testing.T) {
err := checkPeerStatusCompatibility(&goodReply, &status, version, version)
assert.Nil(t, err)
})
t.Run("network mismatch", func(t *testing.T) {
reply := goodReply
reply.NetworkID = 0
err := checkPeerStatusCompatibility(&reply, &status, version, version)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "network")
})
t.Run("version mismatch min", func(t *testing.T) {
reply := goodReply
reply.ProtocolVersion = eth.ETH66 - 1
err := checkPeerStatusCompatibility(&reply, &status, version, version)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "version is less")
})
t.Run("version mismatch max", func(t *testing.T) {
reply := goodReply
reply.ProtocolVersion = eth.ETH66 + 1
err := checkPeerStatusCompatibility(&reply, &status, version, version)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "version is more")
})
t.Run("genesis mismatch", func(t *testing.T) {
reply := goodReply
reply.Genesis = common.Hash{}
err := checkPeerStatusCompatibility(&reply, &status, version, version)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "genesis")
})
t.Run("fork mismatch", func(t *testing.T) {
reply := goodReply
reply.ForkID = forkid.ID{}
err := checkPeerStatusCompatibility(&reply, &status, version, version)
assert.NotNil(t, err)
assert.ErrorIs(t, err, forkid.ErrLocalIncompatibleOrStale)
})
}

View File

@ -211,6 +211,7 @@ func handShake(
ourTD := gointerfaces.ConvertH256ToUint256Int(status.TotalDifficulty)
// Convert proto status data into the one required by devp2p
genesisHash := gointerfaces.ConvertH256ToHash(status.ForkData.Genesis)
go func() {
defer debug.LogPanic()
s := &eth.StatusPacket{
@ -223,58 +224,15 @@ func handShake(
}
errc <- p2p.Send(rw, eth.StatusMsg, s)
}()
var readStatus = func() error {
forks := make([]uint64, len(status.ForkData.Forks)) // copy because forkid.NewFilterFromForks will write into this slice
copy(forks, status.ForkData.Forks)
forkFilter := forkid.NewFilterFromForks(forks, genesisHash, status.MaxBlock)
networkID := status.NetworkId
// Read handshake message
msg, err1 := rw.ReadMsg()
if err1 != nil {
return err1
}
if msg.Code != eth.StatusMsg {
msg.Discard()
return fmt.Errorf("first msg has code %x (!= %x)", msg.Code, eth.StatusMsg)
}
if msg.Size > eth.ProtocolMaxMsgSize {
msg.Discard()
return fmt.Errorf("message is too large %d, limit %d", msg.Size, eth.ProtocolMaxMsgSize)
}
// Decode the handshake and make sure everything matches
var reply eth.StatusPacket
if err1 = msg.Decode(&reply); err1 != nil {
msg.Discard()
return fmt.Errorf("decode message %v: %w", msg, err1)
}
msg.Discard()
if reply.NetworkID != networkID {
return fmt.Errorf("network id does not match: theirs %d, ours %d", reply.NetworkID, networkID)
}
if uint(reply.ProtocolVersion) < minVersion {
return fmt.Errorf("version is less than allowed minimum: theirs %d, min %d", reply.ProtocolVersion, minVersion)
}
if uint(reply.ProtocolVersion) > version {
return fmt.Errorf("version is more than what this senty supports: theirs %d, max %d", reply.ProtocolVersion, version)
}
if reply.Genesis != genesisHash {
return fmt.Errorf("genesis hash does not match: theirs %x, ours %x", reply.Genesis, genesisHash)
}
if err1 = forkFilter(reply.ForkID); err1 != nil {
return fmt.Errorf("%w", err1)
}
if startSync != nil {
if err := startSync(reply.Head); err != nil {
return err
}
}
return nil
}
go func() {
errc <- readStatus()
reply, err := readAndValidatePeerStatusMessage(rw, status, version, minVersion)
if (err == nil) && (startSync != nil) {
err = startSync(reply.Head)
}
errc <- err
}()
timeout := time.NewTimer(handshakeTimeout)