p2p: fix RLPx disconnect message decoding (#8056)

The disconnect message could either be a plain integer, or a list with
one integer element. We were encoding it as a plain integer, but
decoding as a list. Change this to be able to decode any format.
This commit is contained in:
battlmonstr 2023-08-24 13:49:19 +02:00 committed by GitHub
parent 66d93f2489
commit bb2c2adbb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 20 deletions

View File

@ -1,12 +1,12 @@
package observer
import (
"bytes"
"context"
"crypto/ecdsa"
"fmt"
"math/big"
"net"
"strings"
"time"
libcommon "github.com/ledgerwatch/erigon-lib/common"
@ -215,15 +215,11 @@ func readMessage(conn *rlpx.Conn, expectedMessageID uint64, decodeError Handshak
return readMessage(conn, expectedMessageID, decodeError, message)
}
if messageID == RLPxMessageIDDisconnect {
var reason [1]p2p.DiscReason
err = rlp.DecodeBytes(data, &reason)
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
err = rlp.DecodeBytes(data, &reason[0])
}
reason, err := p2p.DisconnectMessagePayloadDecode(bytes.NewBuffer(data))
if err != nil {
return NewHandshakeError(HandshakeErrorIDDisconnectDecode, err, 0)
}
return NewHandshakeError(HandshakeErrorIDDisconnect, reason[0], uint64(reason[0]))
return NewHandshakeError(HandshakeErrorIDDisconnect, reason, uint64(reason))
}
if messageID != expectedMessageID {
return NewHandshakeError(HandshakeErrorIDUnexpectedMessage, nil, messageID)

View File

@ -322,11 +322,13 @@ func (p *Peer) handle(msg Msg) error {
msg.Discard()
go SendItems(p.rw, pongMsg)
case msg.Code == discMsg:
// This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it.
var m struct{ R DiscReason }
rlp.Decode(msg.Payload, &m)
return m.R
// This is the last message.
// We don't need to discard because the connection will be closed after it.
reason, err := DisconnectMessagePayloadDecode(msg.Payload)
if err != nil {
p.log.Debug("Peer.handle: failed to rlp.Decode msg.Payload", "err", err)
}
return reason
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
msg.Discard()

View File

@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"
@ -123,7 +124,7 @@ func (t *rlpxTransport) close(err error) {
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []DiscReason{r}) //nolint:errcheck
_ = DisconnectMessagePayloadEncode(&t.wbuf, r)
t.conn.Write(discMsg, t.wbuf.Bytes()) //nolint:errcheck
}
}
@ -169,13 +170,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the post-handshake checks fail.
// We can't return the reason directly, though, because it is echoed
// back otherwise. Wrap it in a string instead.
var reason [1]DiscReason
if err = rlp.Decode(msg.Payload, &reason); err != nil {
return nil, err
}
return nil, reason[0]
reason, _ := DisconnectMessagePayloadDecode(msg.Payload)
return nil, reason
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
@ -189,3 +185,34 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
}
return &hs, nil
}
func DisconnectMessagePayloadDecode(reader io.Reader) (DiscReason, error) {
var buffer bytes.Buffer
_, err := buffer.ReadFrom(reader)
if err != nil {
return DiscRequested, err
}
data := buffer.Bytes()
if len(data) == 0 {
return DiscRequested, nil
}
var reasonList struct{ Reason DiscReason }
err = rlp.DecodeBytes(data, &reasonList)
// en empty list
if (err != nil) && strings.Contains(err.Error(), "rlp: too few elements") {
return DiscRequested, nil
}
// not a list, try to decode as a plain integer
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
err = rlp.DecodeBytes(data, &reasonList.Reason)
}
return reasonList.Reason, err
}
func DisconnectMessagePayloadEncode(writer io.Writer, reason DiscReason) error {
return rlp.Encode(writer, []DiscReason{reason})
}

View File

@ -17,6 +17,7 @@
package p2p
import (
"bytes"
"errors"
"reflect"
"sync"
@ -146,3 +147,54 @@ func TestProtocolHandshakeErrors(t *testing.T) {
}
}
}
func TestDisconnectMessagePayloadDecode(t *testing.T) {
var buffer bytes.Buffer
err := DisconnectMessagePayloadEncode(&buffer, DiscTooManyPeers)
if err != nil {
t.Error(err)
}
reason, err := DisconnectMessagePayloadDecode(&buffer)
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}
// plain integer
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{uint8(DiscTooManyPeers)}))
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}
// single-element RLP list
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC1, uint8(DiscTooManyPeers)}))
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}
// empty RLP list
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC0}))
if err != nil {
t.Error(err)
}
if reason != DiscRequested {
t.Fail()
}
// empty payload
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{}))
if err != nil {
t.Error(err)
}
if reason != DiscRequested {
t.Fail()
}
}