mirror of
https://gitlab.com/pulsechaincom/erigon-pulse.git
synced 2025-01-11 05:20:05 +00:00
5110f80bba
There are now two deadlines, frameReadTimeout and payloadReadTimeout. The frame timeout is longer and allows for connections that are idle. The message timeout is still short and ensures that we don't get stuck in the middle of a message.
398 lines
9.3 KiB
Go
398 lines
9.3 KiB
Go
package p2p
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/logger"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
)
|
|
|
|
const (
|
|
baseProtocolVersion = 2
|
|
baseProtocolLength = uint64(16)
|
|
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
|
|
|
disconnectGracePeriod = 2 * time.Second
|
|
)
|
|
|
|
const (
|
|
// devp2p message codes
|
|
handshakeMsg = 0x00
|
|
discMsg = 0x01
|
|
pingMsg = 0x02
|
|
pongMsg = 0x03
|
|
getPeersMsg = 0x04
|
|
peersMsg = 0x05
|
|
)
|
|
|
|
// handshake is the RLP structure of the protocol handshake.
|
|
type handshake struct {
|
|
Version uint64
|
|
Name string
|
|
Caps []Cap
|
|
ListenPort uint64
|
|
NodeID discover.NodeID
|
|
}
|
|
|
|
// Peer represents a connected remote node.
|
|
type Peer struct {
|
|
// Peers have all the log methods.
|
|
// Use them to display messages related to the peer.
|
|
*logger.Logger
|
|
|
|
infoMu sync.Mutex
|
|
name string
|
|
caps []Cap
|
|
|
|
ourID, remoteID *discover.NodeID
|
|
ourName string
|
|
|
|
rw *frameRW
|
|
|
|
// These fields maintain the running protocols.
|
|
protocols []Protocol
|
|
runlock sync.RWMutex // protects running
|
|
running map[string]*proto
|
|
|
|
// disables protocol handshake, for testing
|
|
noHandshake bool
|
|
|
|
protoWG sync.WaitGroup
|
|
protoErr chan error
|
|
closed chan struct{}
|
|
disc chan DiscReason
|
|
}
|
|
|
|
// NewPeer returns a peer for testing purposes.
|
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
|
conn, _ := net.Pipe()
|
|
peer := newPeer(conn, nil, "", nil, &id)
|
|
peer.setHandshakeInfo(name, caps)
|
|
close(peer.closed) // ensures Disconnect doesn't block
|
|
return peer
|
|
}
|
|
|
|
// ID returns the node's public key.
|
|
func (p *Peer) ID() discover.NodeID {
|
|
return *p.remoteID
|
|
}
|
|
|
|
// Name returns the node name that the remote node advertised.
|
|
func (p *Peer) Name() string {
|
|
// this needs a lock because the information is part of the
|
|
// protocol handshake.
|
|
p.infoMu.Lock()
|
|
name := p.name
|
|
p.infoMu.Unlock()
|
|
return name
|
|
}
|
|
|
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
|
func (p *Peer) Caps() []Cap {
|
|
// this needs a lock because the information is part of the
|
|
// protocol handshake.
|
|
p.infoMu.Lock()
|
|
caps := p.caps
|
|
p.infoMu.Unlock()
|
|
return caps
|
|
}
|
|
|
|
// RemoteAddr returns the remote address of the network connection.
|
|
func (p *Peer) RemoteAddr() net.Addr {
|
|
return p.rw.RemoteAddr()
|
|
}
|
|
|
|
// LocalAddr returns the local address of the network connection.
|
|
func (p *Peer) LocalAddr() net.Addr {
|
|
return p.rw.LocalAddr()
|
|
}
|
|
|
|
// Disconnect terminates the peer connection with the given reason.
|
|
// It returns immediately and does not wait until the connection is closed.
|
|
func (p *Peer) Disconnect(reason DiscReason) {
|
|
select {
|
|
case p.disc <- reason:
|
|
case <-p.closed:
|
|
}
|
|
}
|
|
|
|
// String implements fmt.Stringer.
|
|
func (p *Peer) String() string {
|
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
|
}
|
|
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
|
return &Peer{
|
|
Logger: logger.NewLogger(logtag),
|
|
rw: newFrameRW(conn, msgWriteTimeout),
|
|
ourID: ourID,
|
|
ourName: ourName,
|
|
remoteID: remoteID,
|
|
protocols: protocols,
|
|
running: make(map[string]*proto),
|
|
disc: make(chan DiscReason),
|
|
protoErr: make(chan error),
|
|
closed: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
|
|
p.infoMu.Lock()
|
|
p.name = name
|
|
p.caps = caps
|
|
p.infoMu.Unlock()
|
|
}
|
|
|
|
func (p *Peer) run() DiscReason {
|
|
var readErr = make(chan error, 1)
|
|
defer p.closeProtocols()
|
|
defer close(p.closed)
|
|
|
|
go func() { readErr <- p.readLoop() }()
|
|
|
|
if !p.noHandshake {
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
|
p.rw.Close()
|
|
return DiscProtocolError
|
|
}
|
|
}
|
|
|
|
// Wait for an error or disconnect.
|
|
var reason DiscReason
|
|
select {
|
|
case err := <-readErr:
|
|
// We rely on protocols to abort if there is a write error. It
|
|
// might be more robust to handle them here as well.
|
|
p.DebugDetailf("Read error: %v\n", err)
|
|
p.rw.Close()
|
|
return DiscNetworkError
|
|
|
|
case err := <-p.protoErr:
|
|
reason = discReasonForError(err)
|
|
case reason = <-p.disc:
|
|
}
|
|
p.politeDisconnect(reason)
|
|
|
|
// Wait for readLoop. It will end because conn is now closed.
|
|
<-readErr
|
|
p.Debugf("Disconnected: %v\n", reason)
|
|
return reason
|
|
}
|
|
|
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
|
// Wait for the other side to close the connection.
|
|
// Discard any data that they send until then.
|
|
io.Copy(ioutil.Discard, p.rw)
|
|
close(done)
|
|
}()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(disconnectGracePeriod):
|
|
}
|
|
p.rw.Close()
|
|
}
|
|
|
|
func (p *Peer) readLoop() error {
|
|
if !p.noHandshake {
|
|
if err := readProtocolHandshake(p, p.rw); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for {
|
|
msg, err := p.rw.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err = p.handle(msg); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Peer) handle(msg Msg) error {
|
|
switch {
|
|
case msg.Code == pingMsg:
|
|
msg.Discard()
|
|
go EncodeMsg(p.rw, pongMsg)
|
|
case msg.Code == discMsg:
|
|
var reason DiscReason
|
|
// no need to discard or for error checking, we'll close the
|
|
// connection after this.
|
|
rlp.Decode(msg.Payload, &reason)
|
|
p.Disconnect(DiscRequested)
|
|
return discRequestedError(reason)
|
|
case msg.Code < baseProtocolLength:
|
|
// ignore other base protocol messages
|
|
return msg.Discard()
|
|
default:
|
|
// it's a subprotocol message
|
|
proto, err := p.getProto(msg.Code)
|
|
if err != nil {
|
|
return fmt.Errorf("msg code out of range: %v", msg.Code)
|
|
}
|
|
proto.in <- msg
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
|
// read and handle remote handshake
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if msg.Code != handshakeMsg {
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
|
}
|
|
if msg.Size > baseProtocolMaxMsgSize {
|
|
return newPeerError(errInvalidMsg, "message too big")
|
|
}
|
|
var hs handshake
|
|
if err := msg.Decode(&hs); err != nil {
|
|
return err
|
|
}
|
|
// validate handshake info
|
|
if hs.Version != baseProtocolVersion {
|
|
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
|
|
baseProtocolVersion, hs.Version)
|
|
}
|
|
if hs.NodeID == *p.remoteID {
|
|
return newPeerError(errPubkeyForbidden, "node ID mismatch")
|
|
}
|
|
// TODO: remove Caps with empty name
|
|
p.setHandshakeInfo(hs.Name, hs.Caps)
|
|
p.startSubprotocols(hs.Caps)
|
|
return nil
|
|
}
|
|
|
|
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
|
|
var caps []interface{}
|
|
for _, proto := range ps {
|
|
caps = append(caps, proto.cap())
|
|
}
|
|
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
|
|
}
|
|
|
|
// startProtocols starts matching named subprotocols.
|
|
func (p *Peer) startSubprotocols(caps []Cap) {
|
|
sort.Sort(capsByName(caps))
|
|
p.runlock.Lock()
|
|
defer p.runlock.Unlock()
|
|
offset := baseProtocolLength
|
|
outer:
|
|
for _, cap := range caps {
|
|
for _, proto := range p.protocols {
|
|
if proto.Name == cap.Name &&
|
|
proto.Version == cap.Version &&
|
|
p.running[cap.Name] == nil {
|
|
p.running[cap.Name] = p.startProto(offset, proto)
|
|
offset += proto.Length
|
|
continue outer
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
|
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
|
|
rw := &proto{
|
|
name: impl.Name,
|
|
in: make(chan Msg),
|
|
offset: offset,
|
|
maxcode: impl.Length,
|
|
w: p.rw,
|
|
}
|
|
p.protoWG.Add(1)
|
|
go func() {
|
|
err := impl.Run(p, rw)
|
|
if err == nil {
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
|
err = errors.New("protocol returned")
|
|
} else {
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
|
}
|
|
select {
|
|
case p.protoErr <- err:
|
|
case <-p.closed:
|
|
}
|
|
p.protoWG.Done()
|
|
}()
|
|
return rw
|
|
}
|
|
|
|
// getProto finds the protocol responsible for handling
|
|
// the given message code.
|
|
func (p *Peer) getProto(code uint64) (*proto, error) {
|
|
p.runlock.RLock()
|
|
defer p.runlock.RUnlock()
|
|
for _, proto := range p.running {
|
|
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
|
return proto, nil
|
|
}
|
|
}
|
|
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
|
}
|
|
|
|
func (p *Peer) closeProtocols() {
|
|
p.runlock.RLock()
|
|
for _, p := range p.running {
|
|
close(p.in)
|
|
}
|
|
p.runlock.RUnlock()
|
|
p.protoWG.Wait()
|
|
}
|
|
|
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
|
// this exists because of Server.Broadcast.
|
|
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
|
p.runlock.RLock()
|
|
proto, ok := p.running[protoName]
|
|
p.runlock.RUnlock()
|
|
if !ok {
|
|
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
|
}
|
|
if msg.Code >= proto.maxcode {
|
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
|
}
|
|
msg.Code += proto.offset
|
|
return p.rw.WriteMsg(msg)
|
|
}
|
|
|
|
type proto struct {
|
|
name string
|
|
in chan Msg
|
|
maxcode, offset uint64
|
|
w MsgWriter
|
|
}
|
|
|
|
func (rw *proto) WriteMsg(msg Msg) error {
|
|
if msg.Code >= rw.maxcode {
|
|
return newPeerError(errInvalidMsgCode, "not handled")
|
|
}
|
|
msg.Code += rw.offset
|
|
return rw.w.WriteMsg(msg)
|
|
}
|
|
|
|
func (rw *proto) ReadMsg() (Msg, error) {
|
|
msg, ok := <-rw.in
|
|
if !ok {
|
|
return msg, io.EOF
|
|
}
|
|
msg.Code -= rw.offset
|
|
return msg, nil
|
|
}
|