mirror of
https://gitlab.com/pulsechaincom/erigon-pulse.git
synced 2024-12-22 03:30:37 +00:00
p2p: fix RLPx disconnect message decoding (#8056)
The disconnect message could either be a plain integer, or a list with one integer element. We were encoding it as a plain integer, but decoding as a list. Change this to be able to decode any format.
This commit is contained in:
parent
66d93f2489
commit
bb2c2adbb6
@ -1,12 +1,12 @@
|
||||
package observer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
libcommon "github.com/ledgerwatch/erigon-lib/common"
|
||||
@ -215,15 +215,11 @@ func readMessage(conn *rlpx.Conn, expectedMessageID uint64, decodeError Handshak
|
||||
return readMessage(conn, expectedMessageID, decodeError, message)
|
||||
}
|
||||
if messageID == RLPxMessageIDDisconnect {
|
||||
var reason [1]p2p.DiscReason
|
||||
err = rlp.DecodeBytes(data, &reason)
|
||||
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
|
||||
err = rlp.DecodeBytes(data, &reason[0])
|
||||
}
|
||||
reason, err := p2p.DisconnectMessagePayloadDecode(bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return NewHandshakeError(HandshakeErrorIDDisconnectDecode, err, 0)
|
||||
}
|
||||
return NewHandshakeError(HandshakeErrorIDDisconnect, reason[0], uint64(reason[0]))
|
||||
return NewHandshakeError(HandshakeErrorIDDisconnect, reason, uint64(reason))
|
||||
}
|
||||
if messageID != expectedMessageID {
|
||||
return NewHandshakeError(HandshakeErrorIDUnexpectedMessage, nil, messageID)
|
||||
|
12
p2p/peer.go
12
p2p/peer.go
@ -322,11 +322,13 @@ func (p *Peer) handle(msg Msg) error {
|
||||
msg.Discard()
|
||||
go SendItems(p.rw, pongMsg)
|
||||
case msg.Code == discMsg:
|
||||
// This is the last message. We don't need to discard or
|
||||
// check errors because, the connection will be closed after it.
|
||||
var m struct{ R DiscReason }
|
||||
rlp.Decode(msg.Payload, &m)
|
||||
return m.R
|
||||
// This is the last message.
|
||||
// We don't need to discard because the connection will be closed after it.
|
||||
reason, err := DisconnectMessagePayloadDecode(msg.Payload)
|
||||
if err != nil {
|
||||
p.log.Debug("Peer.handle: failed to rlp.Decode msg.Payload", "err", err)
|
||||
}
|
||||
return reason
|
||||
case msg.Code < baseProtocolLength:
|
||||
// ignore other base protocol messages
|
||||
msg.Discard()
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -123,7 +124,7 @@ func (t *rlpxTransport) close(err error) {
|
||||
if err := t.conn.SetWriteDeadline(deadline); err == nil {
|
||||
// Connection supports write deadline.
|
||||
t.wbuf.Reset()
|
||||
rlp.Encode(&t.wbuf, []DiscReason{r}) //nolint:errcheck
|
||||
_ = DisconnectMessagePayloadEncode(&t.wbuf, r)
|
||||
t.conn.Write(discMsg, t.wbuf.Bytes()) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
@ -169,13 +170,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
|
||||
if msg.Code == discMsg {
|
||||
// Disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if the post-handshake checks fail.
|
||||
// We can't return the reason directly, though, because it is echoed
|
||||
// back otherwise. Wrap it in a string instead.
|
||||
var reason [1]DiscReason
|
||||
if err = rlp.Decode(msg.Payload, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, reason[0]
|
||||
reason, _ := DisconnectMessagePayloadDecode(msg.Payload)
|
||||
return nil, reason
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||
@ -189,3 +185,34 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
|
||||
}
|
||||
return &hs, nil
|
||||
}
|
||||
|
||||
func DisconnectMessagePayloadDecode(reader io.Reader) (DiscReason, error) {
|
||||
var buffer bytes.Buffer
|
||||
_, err := buffer.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return DiscRequested, err
|
||||
}
|
||||
data := buffer.Bytes()
|
||||
if len(data) == 0 {
|
||||
return DiscRequested, nil
|
||||
}
|
||||
|
||||
var reasonList struct{ Reason DiscReason }
|
||||
err = rlp.DecodeBytes(data, &reasonList)
|
||||
|
||||
// en empty list
|
||||
if (err != nil) && strings.Contains(err.Error(), "rlp: too few elements") {
|
||||
return DiscRequested, nil
|
||||
}
|
||||
|
||||
// not a list, try to decode as a plain integer
|
||||
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
|
||||
err = rlp.DecodeBytes(data, &reasonList.Reason)
|
||||
}
|
||||
|
||||
return reasonList.Reason, err
|
||||
}
|
||||
|
||||
func DisconnectMessagePayloadEncode(writer io.Writer, reason DiscReason) error {
|
||||
return rlp.Encode(writer, []DiscReason{reason})
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
@ -146,3 +147,54 @@ func TestProtocolHandshakeErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnectMessagePayloadDecode(t *testing.T) {
|
||||
var buffer bytes.Buffer
|
||||
err := DisconnectMessagePayloadEncode(&buffer, DiscTooManyPeers)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
reason, err := DisconnectMessagePayloadDecode(&buffer)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if reason != DiscTooManyPeers {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// plain integer
|
||||
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{uint8(DiscTooManyPeers)}))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if reason != DiscTooManyPeers {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// single-element RLP list
|
||||
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC1, uint8(DiscTooManyPeers)}))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if reason != DiscTooManyPeers {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// empty RLP list
|
||||
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC0}))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if reason != DiscRequested {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// empty payload
|
||||
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{}))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if reason != DiscRequested {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user