diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go
index 845900865..346523ddb 100644
--- a/cmd/bootnode/main.go
+++ b/cmd/bootnode/main.go
@@ -119,16 +119,17 @@ func main() {
}
if *runv5 {
- if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil {
+ if _, err := discv5.ListenUDP(nodeKey, conn, "", restrictList); err != nil {
utils.Fatalf("%v", err)
}
} else {
+ db, _ := enode.OpenDB("")
+ ln := enode.NewLocalNode(db, nodeKey)
cfg := discover.Config{
- PrivateKey: nodeKey,
- AnnounceAddr: realaddr,
- NetRestrict: restrictList,
+ PrivateKey: nodeKey,
+ NetRestrict: restrictList,
}
- if _, err := discover.ListenUDP(conn, cfg); err != nil {
+ if _, err := discover.ListenUDP(conn, ln, cfg); err != nil {
utils.Fatalf("%v", err)
}
}
diff --git a/node/node_test.go b/node/node_test.go
index e51900bd1..f833cd688 100644
--- a/node/node_test.go
+++ b/node/node_test.go
@@ -454,9 +454,9 @@ func TestProtocolGather(t *testing.T) {
Count int
Maker InstrumentingWrapper
}{
- "Zero Protocols": {0, InstrumentedServiceMakerA},
- "Single Protocol": {1, InstrumentedServiceMakerB},
- "Many Protocols": {25, InstrumentedServiceMakerC},
+ "zero": {0, InstrumentedServiceMakerA},
+ "one": {1, InstrumentedServiceMakerB},
+ "many": {10, InstrumentedServiceMakerC},
}
for id, config := range services {
protocols := make([]p2p.Protocol, config.Count)
@@ -480,7 +480,7 @@ func TestProtocolGather(t *testing.T) {
defer stack.Stop()
protocols := stack.Server().Protocols
- if len(protocols) != 26 {
+ if len(protocols) != 11 {
t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26)
}
for id, config := range services {
diff --git a/p2p/dial.go b/p2p/dial.go
index 359cdbcbb..d228514fc 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -71,6 +71,7 @@ type dialstate struct {
maxDynDials int
ntab discoverTable
netrestrict *netutil.Netlist
+ self enode.ID
lookupRunning bool
dialing map[enode.ID]connFlag
@@ -84,7 +85,6 @@ type dialstate struct {
}
type discoverTable interface {
- Self() *enode.Node
Close()
Resolve(*enode.Node) *enode.Node
LookupRandom() []*enode.Node
@@ -126,10 +126,11 @@ type waitExpireTask struct {
time.Duration
}
-func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
+func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
+ self: self,
netrestrict: netrestrict,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
@@ -266,7 +267,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errAlreadyDialing
case peers[n.ID()] != nil:
return errAlreadyConnected
- case s.ntab != nil && n.ID() == s.ntab.Self().ID():
+ case n.ID() == s.self:
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index 2de2c5999..f41ab7752 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -89,7 +89,7 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{
- init: newDialState(nil, nil, fakeTable{}, 5, nil),
+ init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
rounds: []round{
// A discovery query is launched.
{
@@ -236,7 +236,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
- init: newDialState(nil, bootnodes, table, 5, nil),
+ init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
@@ -324,7 +324,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(nil, nil, table, 10, nil),
+ init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
@@ -430,7 +430,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
- init: newDialState(nil, nil, table, 10, restrict),
+ init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
rounds: []round{
{
new: []task{
@@ -453,7 +453,7 @@ func TestDialStateStaticDial(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
+ init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -557,7 +557,7 @@ func TestDialStaticAfterReset(t *testing.T) {
},
}
dTest := dialtest{
- init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
+ init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: rounds,
}
runDialTest(t, dTest)
@@ -578,7 +578,7 @@ func TestDialStateCache(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
+ init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -640,7 +640,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) {
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved}
- state := newDialState(nil, nil, table, 0, nil)
+ state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 7a3e41de1..afd4c9a27 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -72,21 +72,20 @@ type Table struct {
ips netutil.DistinctNetSet
db *enode.DB // database of known nodes
+ net transport
refreshReq chan chan struct{}
initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
nodeAddedHook func(*node) // for testing
-
- net transport
- self *node // metadata of the local node
}
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
+ self() *enode.Node
ping(enode.ID, *net.UDPAddr) error
findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error)
close()
@@ -100,11 +99,10 @@ type bucket struct {
ips netutil.DistinctNetSet
}
-func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
+func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
tab := &Table{
net: t,
db: db,
- self: wrapNode(self),
refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}),
@@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No
return tab, nil
}
+func (tab *Table) self() *enode.Node {
+ return tab.net.self()
+}
+
func (tab *Table) seedRand() {
var b [8]byte
crand.Read(b[:])
@@ -136,11 +138,6 @@ func (tab *Table) seedRand() {
tab.mutex.Unlock()
}
-// Self returns the local node.
-func (tab *Table) Self() *enode.Node {
- return unwrapNode(tab.self)
-}
-
// ReadRandomNodes fills the given slice with random nodes from the table. The results
// are guaranteed to be unique for a single invocation, no node will appear twice.
func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
@@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
+ if tab.net != nil {
+ tab.net.close()
+ }
+
select {
case <-tab.closed:
// already closed.
@@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
)
// don't query further if we hit ourself.
// unlikely to happen often in practice.
- asked[tab.self.ID()] = true
+ asked[tab.self().ID()] = true
for {
tab.mutex.Lock()
@@ -340,8 +341,8 @@ func (tab *Table) loop() {
revalidate = time.NewTimer(tab.nextRevalidateTime())
refresh = time.NewTicker(refreshInterval)
copyNodes = time.NewTicker(copyNodesInterval)
- revalidateDone = make(chan struct{})
refreshDone = make(chan struct{}) // where doRefresh reports completion
+ revalidateDone chan struct{} // where doRevalidate reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
defer refresh.Stop()
@@ -372,9 +373,11 @@ loop:
}
waiting, refreshDone = nil, nil
case <-revalidate.C:
+ revalidateDone = make(chan struct{})
go tab.doRevalidate(revalidateDone)
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
+ revalidateDone = nil
case <-copyNodes.C:
go tab.copyLiveNodes()
case <-tab.closeReq:
@@ -382,15 +385,15 @@ loop:
}
}
- if tab.net != nil {
- tab.net.close()
- }
if refreshDone != nil {
<-refreshDone
}
for _, ch := range waiting {
close(ch)
}
+ if revalidateDone != nil {
+ <-revalidateDone
+ }
close(tab.closed)
}
@@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Run self lookup to discover new neighbor nodes.
// We can only do this if we have a secp256k1 identity.
var key ecdsa.PublicKey
- if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil {
+ if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil {
tab.lookup(encodePubkey(&key), false)
}
@@ -530,7 +533,7 @@ func (tab *Table) len() (n int) {
// bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(id enode.ID) *bucket {
- d := enode.LogDist(tab.self.ID(), id)
+ d := enode.LogDist(tab.self().ID(), id)
if d <= bucketMinDistance {
return tab.buckets[0]
}
@@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(n *node) {
- if n.ID() == tab.self.ID() {
+ if n.ID() == tab.self().ID() {
return
}
@@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) {
defer tab.mutex.Unlock()
for _, n := range nodes {
- if n.ID() == tab.self.ID() {
+ if n.ID() == tab.self().ID() {
continue // don't add self
}
b := tab.bucket(n.ID())
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index e8631024b..6b4cd2d18 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -141,7 +141,7 @@ func TestTable_IPLimit(t *testing.T) {
defer db.Close()
for i := 0; i < tableIPLimit+1; i++ {
- n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)})
+ n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > tableIPLimit {
@@ -158,7 +158,7 @@ func TestTable_BucketIPLimit(t *testing.T) {
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
- n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)})
+ n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > bucketIPLimit {
@@ -240,7 +240,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
- tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))})
+ tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
@@ -510,6 +510,10 @@ type preminedTestnet struct {
dists [hashBits + 1][]encPubkey
}
+func (tn *preminedTestnet) self() *enode.Node {
+ return nullNode
+}
+
func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port)
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index 05ae0b6c0..d41519452 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -28,12 +28,17 @@ import (
"github.com/ethereum/go-ethereum/p2p/enr"
)
-func newTestTable(t transport) (*Table, *enode.DB) {
+var nullNode *enode.Node
+
+func init() {
var r enr.Record
r.Set(enr.IP{0, 0, 0, 0})
- n := enode.SignNull(&r, enode.ID{})
+ nullNode = enode.SignNull(&r, enode.ID{})
+}
+
+func newTestTable(t transport) (*Table, *enode.DB) {
db, _ := enode.OpenDB("")
- tab, _ := newTable(t, n, db, nil)
+ tab, _ := newTable(t, db, nil)
return tab, db
}
@@ -70,10 +75,10 @@ func intIP(i int) net.IP {
// fillBucket inserts nodes into the given bucket until it is full.
func fillBucket(tab *Table, n *node) (last *node) {
- ld := enode.LogDist(tab.self.ID(), n.ID())
+ ld := enode.LogDist(tab.self().ID(), n.ID())
b := tab.bucket(n.ID())
for len(b.entries) < bucketSize {
- b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld)))
+ b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld)))
}
return b.entries[bucketSize-1]
}
@@ -81,15 +86,25 @@ func fillBucket(tab *Table, n *node) (last *node) {
type pingRecorder struct {
mu sync.Mutex
dead, pinged map[enode.ID]bool
+ n *enode.Node
}
func newPingRecorder() *pingRecorder {
+ var r enr.Record
+ r.Set(enr.IP{0, 0, 0, 0})
+ n := enode.SignNull(&r, enode.ID{})
+
return &pingRecorder{
dead: make(map[enode.ID]bool),
pinged: make(map[enode.ID]bool),
+ n: n,
}
}
+func (t *pingRecorder) self() *enode.Node {
+ return nullNode
+}
+
func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
return nil, nil
}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 45fcce282..37a044902 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -23,12 +23,12 @@ import (
"errors"
"fmt"
"net"
+ "sync"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
- "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -118,9 +118,11 @@ type (
)
func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
- ip := addr.IP.To4()
- if ip == nil {
- ip = addr.IP.To16()
+ ip := net.IP{}
+ if ip4 := addr.IP.To4(); ip4 != nil {
+ ip = ip4
+ } else if ip6 := addr.IP.To16(); ip6 != nil {
+ ip = ip6
}
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
}
@@ -165,20 +167,19 @@ type conn interface {
LocalAddr() net.Addr
}
-// udp implements the RPC protocol.
+// udp implements the discovery v4 UDP wire protocol.
type udp struct {
conn conn
netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey
- ourEndpoint rpcEndpoint
+ localNode *enode.LocalNode
+ db *enode.DB
+ tab *Table
+ wg sync.WaitGroup
addpending chan *pending
gotreply chan reply
-
- closing chan struct{}
- nat nat.Interface
-
- *Table
+ closing chan struct{}
}
// pending represents a pending reply.
@@ -230,60 +231,57 @@ type Config struct {
PrivateKey *ecdsa.PrivateKey
// These settings are optional:
- AnnounceAddr *net.UDPAddr // local address announced in the DHT
- NodeDBPath string // if set, the node database is stored at this filesystem location
- NetRestrict *netutil.Netlist // network whitelist
- Bootnodes []*enode.Node // list of bootstrap nodes
- Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
+ NetRestrict *netutil.Netlist // network whitelist
+ Bootnodes []*enode.Node // list of bootstrap nodes
+ Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(c conn, cfg Config) (*Table, error) {
- tab, _, err := newUDP(c, cfg)
+func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
+ tab, _, err := newUDP(c, ln, cfg)
if err != nil {
return nil, err
}
- log.Info("UDP listener up", "self", tab.self)
return tab, nil
}
-func newUDP(c conn, cfg Config) (*Table, *udp, error) {
- realaddr := c.LocalAddr().(*net.UDPAddr)
- if cfg.AnnounceAddr != nil {
- realaddr = cfg.AnnounceAddr
- }
- self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port)
- db, err := enode.OpenDB(cfg.NodeDBPath)
- if err != nil {
- return nil, nil, err
- }
-
+func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
priv: cfg.PrivateKey,
netrestrict: cfg.NetRestrict,
+ localNode: ln,
+ db: ln.Database(),
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
- // TODO: separate TCP port
- udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, self, db, cfg.Bootnodes)
+ tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
- udp.Table = tab
+ udp.tab = tab
+ udp.wg.Add(2)
go udp.loop()
go udp.readLoop(cfg.Unhandled)
- return udp.Table, udp, nil
+ return udp.tab, udp, nil
+}
+
+func (t *udp) self() *enode.Node {
+ return t.localNode.Node()
}
func (t *udp) close() {
close(t.closing)
t.conn.Close()
- t.db.Close()
- // TODO: wait for the loops to end.
+ t.wg.Wait()
+}
+
+func (t *udp) ourEndpoint() rpcEndpoint {
+ n := t.self()
+ a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
+ return makeEndpoint(a, uint16(n.TCP()))
}
// ping sends a ping message to the given node and waits for a reply.
@@ -296,7 +294,7 @@ func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error {
func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error {
req := &ping{
Version: 4,
- From: t.ourEndpoint,
+ From: t.ourEndpoint(),
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
}
@@ -313,6 +311,7 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
}
return ok
})
+ t.localNode.UDPContact(toaddr)
t.write(toaddr, req.name(), packet)
return errc
}
@@ -381,6 +380,8 @@ func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
// loop runs in its own goroutine. it keeps track of
// the refresh timer and the pending reply queue.
func (t *udp) loop() {
+ defer t.wg.Done()
+
var (
plist = list.New()
timeout = time.NewTimer(0)
@@ -542,10 +543,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet,
// readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop(unhandled chan<- ReadPacket) {
- defer t.conn.Close()
+ defer t.wg.Done()
if unhandled != nil {
defer close(unhandled)
}
+
// Discovery packets are defined to be no larger than 1280 bytes.
// Packets larger than this size will be cut at the end and treated
// as invalid because their hash won't match.
@@ -629,10 +631,11 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
t.handleReply(n.ID(), pingPacket, req)
if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
- t.sendPing(n.ID(), from, func() { t.addThroughPing(n) })
+ t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
} else {
- t.addThroughPing(n)
+ t.tab.addThroughPing(n)
}
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPingReceived(n.ID(), time.Now())
return nil
}
@@ -647,6 +650,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
if !t.handleReply(fromID, pongPacket, req) {
return errUnsolicitedReply
}
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPongReceived(fromID, time.Now())
return nil
}
@@ -668,9 +672,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
return errUnknownNode
}
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
- t.mutex.Lock()
- closest := t.closest(target, bucketSize).entries
- t.mutex.Unlock()
+ t.tab.mutex.Lock()
+ closest := t.tab.closest(target, bucketSize).entries
+ t.tab.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index da95c4f5c..a4ddaf750 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -71,7 +71,9 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
+ db, _ := enode.OpenDB("")
+ ln := enode.NewLocalNode(db, test.localkey)
+ test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test
@@ -355,12 +357,13 @@ func TestUDP_successfulPing(t *testing.T) {
// remote is unknown, the table pings back.
hash, _ := test.waitPacketOut(func(p *ping) error {
- if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
- t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
+ if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
+ t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
}
wantTo := rpcEndpoint{
// The mirrored UDP address is the UDP packet sender.
- IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port),
+ IP: test.remoteaddr.IP,
+ UDP: uint16(test.remoteaddr.Port),
TCP: 0,
}
if !reflect.DeepEqual(p.To, wantTo) {
diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go
index 49e1cb811..ff5ed983b 100644
--- a/p2p/discv5/udp.go
+++ b/p2p/discv5/udp.go
@@ -230,7 +230,8 @@ type udp struct {
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
+func ListenUDP(priv *ecdsa.PrivateKey, conn conn, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
+ realaddr := conn.LocalAddr().(*net.UDPAddr)
transport, err := listenUDP(priv, conn, realaddr)
if err != nil {
return nil, err
diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go
new file mode 100644
index 000000000..623f8eae1
--- /dev/null
+++ b/p2p/enode/localnode.go
@@ -0,0 +1,246 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
+)
+
+const (
+ // IP tracker configuration
+ iptrackMinStatements = 10
+ iptrackWindow = 5 * time.Minute
+ iptrackContactWindow = 10 * time.Minute
+)
+
+// LocalNode produces the signed node record of a local node, i.e. a node run in the
+// current process. Setting ENR entries via the Set method updates the record. A new version
+// of the record is signed on demand when the Node method is called.
+type LocalNode struct {
+ cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
+ id ID
+ key *ecdsa.PrivateKey
+ db *DB
+
+ // everything below is protected by a lock
+ mu sync.Mutex
+ seq uint64
+ entries map[string]enr.Entry
+ udpTrack *netutil.IPTracker // predicts external UDP endpoint
+ staticIP net.IP
+ fallbackIP net.IP
+ fallbackUDP int
+}
+
+// NewLocalNode creates a local node.
+func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
+ ln := &LocalNode{
+ id: PubkeyToIDV4(&key.PublicKey),
+ db: db,
+ key: key,
+ udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
+ entries: make(map[string]enr.Entry),
+ }
+ ln.seq = db.localSeq(ln.id)
+ ln.invalidate()
+ return ln
+}
+
+// Database returns the node database associated with the local node.
+func (ln *LocalNode) Database() *DB {
+ return ln.db
+}
+
+// Node returns the current version of the local node record.
+func (ln *LocalNode) Node() *Node {
+ n := ln.cur.Load().(*Node)
+ if n != nil {
+ return n
+ }
+ // Record was invalidated, sign a new copy.
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+ ln.sign()
+ return ln.cur.Load().(*Node)
+}
+
+// ID returns the local node ID.
+func (ln *LocalNode) ID() ID {
+ return ln.id
+}
+
+// Set puts the given entry into the local record, overwriting
+// any existing value.
+func (ln *LocalNode) Set(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.set(e)
+}
+
+func (ln *LocalNode) set(e enr.Entry) {
+ val, exists := ln.entries[e.ENRKey()]
+ if !exists || !reflect.DeepEqual(val, e) {
+ ln.entries[e.ENRKey()] = e
+ ln.invalidate()
+ }
+}
+
+// Delete removes the given entry from the local record.
+func (ln *LocalNode) Delete(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.delete(e)
+}
+
+func (ln *LocalNode) delete(e enr.Entry) {
+ _, exists := ln.entries[e.ENRKey()]
+ if exists {
+ delete(ln.entries, e.ENRKey())
+ ln.invalidate()
+ }
+}
+
+// SetStaticIP sets the local IP to the given one unconditionally.
+// This disables endpoint prediction.
+func (ln *LocalNode) SetStaticIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.staticIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackIP sets the last-resort IP address. This address is used
+// if no endpoint prediction can be made and no static IP is set.
+func (ln *LocalNode) SetFallbackIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.fallbackIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackUDP sets the last-resort UDP port. This port is used
+// if no endpoint prediction can be made.
+func (ln *LocalNode) SetFallbackUDP(port int) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.fallbackUDP = port
+ ln.updateEndpoints()
+}
+
+// UDPEndpointStatement should be called whenever a statement about the local node's
+// UDP endpoint is received. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String())
+ ln.updateEndpoints()
+}
+
+// UDPContact should be called whenever the local node has announced itself to another node
+// via UDP. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.udpTrack.AddContact(toaddr.String())
+ ln.updateEndpoints()
+}
+
+func (ln *LocalNode) updateEndpoints() {
+ // Determine the endpoints.
+ newIP := ln.fallbackIP
+ newUDP := ln.fallbackUDP
+ if ln.staticIP != nil {
+ newIP = ln.staticIP
+ } else if ip, port := predictAddr(ln.udpTrack); ip != nil {
+ newIP = ip
+ newUDP = port
+ }
+
+ // Update the record.
+ if newIP != nil && !newIP.IsUnspecified() {
+ ln.set(enr.IP(newIP))
+ if newUDP != 0 {
+ ln.set(enr.UDP(newUDP))
+ } else {
+ ln.delete(enr.UDP(0))
+ }
+ } else {
+ ln.delete(enr.IP{})
+ }
+}
+
+// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
+// endpoint representation to IP and port types.
+func predictAddr(t *netutil.IPTracker) (net.IP, int) {
+ ep := t.PredictEndpoint()
+ if ep == "" {
+ return nil, 0
+ }
+ ipString, portString, _ := net.SplitHostPort(ep)
+ ip := net.ParseIP(ipString)
+ port, _ := strconv.Atoi(portString)
+ return ip, port
+}
+
+func (ln *LocalNode) invalidate() {
+ ln.cur.Store((*Node)(nil))
+}
+
+func (ln *LocalNode) sign() {
+ if n := ln.cur.Load().(*Node); n != nil {
+ return // no changes
+ }
+
+ var r enr.Record
+ for _, e := range ln.entries {
+ r.Set(e)
+ }
+ ln.bumpSeq()
+ r.SetSeq(ln.seq)
+ if err := SignV4(&r, ln.key); err != nil {
+ panic(fmt.Errorf("enode: can't sign record: %v", err))
+ }
+ n, err := New(ValidSchemes, &r)
+ if err != nil {
+ panic(fmt.Errorf("enode: can't verify local record: %v", err))
+ }
+ ln.cur.Store(n)
+ log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
+}
+
+func (ln *LocalNode) bumpSeq() {
+ ln.seq++
+ ln.db.storeLocalSeq(ln.id, ln.seq)
+}
diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go
new file mode 100644
index 000000000..f5e3496d6
--- /dev/null
+++ b/p2p/enode/localnode_test.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+func newLocalNodeForTesting() (*LocalNode, *DB) {
+ db, _ := OpenDB("")
+ key, _ := crypto.GenerateKey()
+ return NewLocalNode(db, key), db
+}
+
+func TestLocalNode(t *testing.T) {
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ if ln.Node().ID() != ln.ID() {
+ t.Fatal("inconsistent ID")
+ }
+
+ ln.Set(enr.WithEntry("x", uint(3)))
+ var x uint
+ if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil {
+ t.Fatal("can't load entry 'x':", err)
+ } else if x != 3 {
+ t.Fatal("wrong value for entry 'x':", x)
+ }
+}
+
+func TestLocalNodeSeqPersist(t *testing.T) {
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ if s := ln.Node().Seq(); s != 1 {
+ t.Fatalf("wrong initial seq %d, want 1", s)
+ }
+ ln.Set(enr.WithEntry("x", uint(1)))
+ if s := ln.Node().Seq(); s != 2 {
+ t.Fatalf("wrong seq %d after set, want 2", s)
+ }
+
+ // Create a new instance, it should reload the sequence number.
+ // The number increases just after that because a new record is
+ // created without the "x" entry.
+ ln2 := NewLocalNode(db, ln.key)
+ if s := ln2.Node().Seq(); s != 3 {
+ t.Fatalf("wrong seq %d on new instance, want 3", s)
+ }
+
+ // Create a new instance with a different node key on the same database.
+ // This should reset the sequence number.
+ key, _ := crypto.GenerateKey()
+ ln3 := NewLocalNode(db, key)
+ if s := ln3.Node().Seq(); s != 1 {
+ t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
+ }
+}
diff --git a/p2p/enode/node.go b/p2p/enode/node.go
index 84088fcd2..b454ab255 100644
--- a/p2p/enode/node.go
+++ b/p2p/enode/node.go
@@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey {
return &key
}
+// Record returns the node's record. The return value is a copy and may
+// be modified by the caller.
+func (n *Node) Record() *enr.Record {
+ cpy := n.r
+ return &cpy
+}
+
// checks whether n is a valid complete node.
func (n *Node) ValidateComplete() error {
if n.Incomplete() {
diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
index a929b75d7..7ee0c09a9 100644
--- a/p2p/enode/nodedb.go
+++ b/p2p/enode/nodedb.go
@@ -35,11 +35,24 @@ import (
"github.com/syndtr/goleveldb/leveldb/util"
)
+// Keys in the node database.
+const (
+ dbVersionKey = "version" // Version of the database to flush if changes
+ dbItemPrefix = "n:" // Identifier to prefix node entries with
+
+ dbDiscoverRoot = ":discover"
+ dbDiscoverSeq = dbDiscoverRoot + ":seq"
+ dbDiscoverPing = dbDiscoverRoot + ":lastping"
+ dbDiscoverPong = dbDiscoverRoot + ":lastpong"
+ dbDiscoverFindFails = dbDiscoverRoot + ":findfail"
+ dbLocalRoot = ":local"
+ dbLocalSeq = dbLocalRoot + ":seq"
+)
+
var (
- nodeDBNilID = ID{} // Special node ID to use as a nil element.
- nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
- nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
- nodeDBVersion = 6
+ dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
+ dbCleanupCycle = time.Hour // Time period for running the expiration task.
+ dbVersion = 7
)
// DB is the node database, storing previously seen nodes and any collected metadata about
@@ -50,17 +63,6 @@ type DB struct {
quit chan struct{} // Channel to signal the expiring thread to stop
}
-// Schema layout for the node database
-var (
- nodeDBVersionKey = []byte("version") // Version of the database to flush if changes
- nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with
-
- nodeDBDiscoverRoot = ":discover"
- nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping"
- nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong"
- nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
-)
-
// OpenDB opens a node database for storing and retrieving infos about known peers in the
// network. If no path is given an in-memory, temporary database is constructed.
func OpenDB(path string) (*DB, error) {
@@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) {
// The nodes contained in the cache correspond to a certain protocol version.
// Flush all nodes if the version doesn't match.
currentVer := make([]byte, binary.MaxVarintLen64)
- currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))]
+ currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
- blob, err := db.Get(nodeDBVersionKey, nil)
+ blob, err := db.Get([]byte(dbVersionKey), nil)
switch err {
case leveldb.ErrNotFound:
// Version not found (i.e. empty cache), insert it
- if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil {
+ if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
db.Close()
return nil, err
}
@@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) {
// makeKey generates the leveldb key-blob from a node id and its particular
// field of interest.
func makeKey(id ID, field string) []byte {
- if bytes.Equal(id[:], nodeDBNilID[:]) {
+ if (id == ID{}) {
return []byte(field)
}
- return append(nodeDBItemPrefix, append(id[:], field...)...)
+ return append([]byte(dbItemPrefix), append(id[:], field...)...)
}
// splitKey tries to split a database key into a node id and a field part.
func splitKey(key []byte) (id ID, field string) {
// If the key is not of a node, return it plainly
- if !bytes.HasPrefix(key, nodeDBItemPrefix) {
+ if !bytes.HasPrefix(key, []byte(dbItemPrefix)) {
return ID{}, string(key)
}
// Otherwise split the id and field
- item := key[len(nodeDBItemPrefix):]
+ item := key[len(dbItemPrefix):]
copy(id[:], item[:len(id)])
field = string(item[len(id):])
return id, field
}
-// fetchInt64 retrieves an integer instance associated with a particular
-// database key.
+// fetchInt64 retrieves an integer associated with a particular key.
func (db *DB) fetchInt64(key []byte) int64 {
blob, err := db.lvl.Get(key, nil)
if err != nil {
@@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 {
return val
}
-// storeInt64 update a specific database entry to the current time instance as a
-// unix timestamp.
+// storeInt64 stores an integer in the given key.
func (db *DB) storeInt64(key []byte, n int64) error {
blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutVarint(blob, n)]
+ return db.lvl.Put(key, blob, nil)
+}
+// fetchUint64 retrieves an integer associated with a particular key.
+func (db *DB) fetchUint64(key []byte) uint64 {
+ blob, err := db.lvl.Get(key, nil)
+ if err != nil {
+ return 0
+ }
+ val, _ := binary.Uvarint(blob)
+ return val
+}
+
+// storeUint64 stores an integer in the given key.
+func (db *DB) storeUint64(key []byte, n uint64) error {
+ blob := make([]byte, binary.MaxVarintLen64)
+ blob = blob[:binary.PutUvarint(blob, n)]
return db.lvl.Put(key, blob, nil)
}
// Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node {
- blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil)
+ blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil)
if err != nil {
return nil
}
@@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node {
// UpdateNode inserts - potentially overwriting - a node into the peer database.
func (db *DB) UpdateNode(node *Node) error {
+ if node.Seq() < db.NodeSeq(node.ID()) {
+ return nil
+ }
blob, err := rlp.EncodeToBytes(&node.r)
if err != nil {
return err
}
- return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil)
+ if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil {
+ return err
+ }
+ return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq())
+}
+
+// NodeSeq returns the stored record sequence number of the given node.
+func (db *DB) NodeSeq(id ID) uint64 {
+ return db.fetchUint64(makeKey(id, dbDiscoverSeq))
+}
+
+// Resolve returns the stored record of the node if it has a larger sequence
+// number than n.
+func (db *DB) Resolve(n *Node) *Node {
+ if n.Seq() > db.NodeSeq(n.ID()) {
+ return n
+ }
+ return db.Node(n.ID())
}
// DeleteNode deletes all information/keys associated with a node.
@@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() {
// expirer should be started in a go routine, and is responsible for looping ad
// infinitum and dropping stale data from the database.
func (db *DB) expirer() {
- tick := time.NewTicker(nodeDBCleanupCycle)
+ tick := time.NewTicker(dbCleanupCycle)
defer tick.Stop()
for {
select {
@@ -235,7 +271,7 @@ func (db *DB) expirer() {
// expireNodes iterates over the database and deletes all nodes that have not
// been seen (i.e. received a pong from) for some allotted time.
func (db *DB) expireNodes() error {
- threshold := time.Now().Add(-nodeDBNodeExpiration)
+ threshold := time.Now().Add(-dbNodeExpiration)
// Find discovered nodes that are older than the allowance
it := db.lvl.NewIterator(nil, nil)
@@ -244,7 +280,7 @@ func (db *DB) expireNodes() error {
for it.Next() {
// Skip the item if not a discovery node
id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
+ if field != dbDiscoverRoot {
continue
}
// Skip the node if not expired yet (and not self)
@@ -260,34 +296,44 @@ func (db *DB) expireNodes() error {
// LastPingReceived retrieves the time of the last ping packet received from
// a remote node.
func (db *DB) LastPingReceived(id ID) time.Time {
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
+ return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0)
}
// UpdateLastPingReceived updates the last time we tried contacting a remote node.
func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
+ return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix())
}
// LastPongReceived retrieves the time of the last successful pong from remote node.
func (db *DB) LastPongReceived(id ID) time.Time {
// Launch expirer
db.ensureExpirer()
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
+ return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0)
}
// UpdateLastPongReceived updates the last pong time of a node.
func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
+ return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix())
}
// FindFails retrieves the number of findnode failures since bonding.
func (db *DB) FindFails(id ID) int {
- return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails)))
+ return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails)))
}
// UpdateFindFails updates the number of findnode failures since bonding.
func (db *DB) UpdateFindFails(id ID, fails int) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails))
+ return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails))
+}
+
+// LocalSeq retrieves the local record sequence counter.
+func (db *DB) localSeq(id ID) uint64 {
+ return db.fetchUint64(makeKey(id, dbLocalSeq))
+}
+
+// storeLocalSeq stores the local record sequence counter.
+func (db *DB) storeLocalSeq(id ID, n uint64) {
+ db.storeUint64(makeKey(id, dbLocalSeq), n)
}
// QuerySeeds retrieves random nodes to be used as potential seed nodes
@@ -309,7 +355,7 @@ seek:
ctr := id[0]
rand.Read(id[:])
id[0] = ctr + id[0]%16
- it.Seek(makeKey(id, nodeDBDiscoverRoot))
+ it.Seek(makeKey(id, dbDiscoverRoot))
n := nextNode(it)
if n == nil {
@@ -334,7 +380,7 @@ seek:
func nextNode(it iterator.Iterator) *Node {
for end := false; !end; end = !it.Next() {
id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
+ if field != dbDiscoverRoot {
continue
}
return mustDecodeNode(id[:], it.Value())
diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go
index b476a3439..96794827c 100644
--- a/p2p/enode/nodedb_test.go
+++ b/p2p/enode/nodedb_test.go
@@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct {
30303,
30303,
),
- pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute),
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
exp: false,
}, {
node: NewV4(
@@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct {
30303,
30303,
),
- pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute),
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
exp: true,
},
}
diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go
index 251caf458..444820c15 100644
--- a/p2p/enr/enr.go
+++ b/p2p/enr/enr.go
@@ -156,7 +156,7 @@ func (r *Record) Set(e Entry) {
}
func (r *Record) invalidate() {
- if r.signature == nil {
+ if r.signature != nil {
r.seq++
}
r.signature = nil
diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go
index 9bf22478d..449c898a8 100644
--- a/p2p/enr/enr_test.go
+++ b/p2p/enr/enr_test.go
@@ -169,6 +169,18 @@ func TestDirty(t *testing.T) {
}
}
+func TestSeq(t *testing.T) {
+ var r Record
+
+ assert.Equal(t, uint64(0), r.Seq())
+ r.Set(UDP(1))
+ assert.Equal(t, uint64(0), r.Seq())
+ signTest([]byte{5}, &r)
+ assert.Equal(t, uint64(0), r.Seq())
+ r.Set(UDP(2))
+ assert.Equal(t, uint64(1), r.Seq())
+}
+
// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
func TestGetSetOverwrite(t *testing.T) {
var r Record
diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go
index a254648c6..8fad921c4 100644
--- a/p2p/nat/nat.go
+++ b/p2p/nat/nat.go
@@ -129,21 +129,15 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na
// ExtIP assumes that the local machine is reachable on the given
// external IP address, and that any required ports were mapped manually.
// Mapping operations will not return an error but won't actually do anything.
-func ExtIP(ip net.IP) Interface {
- if ip == nil {
- panic("IP must not be nil")
- }
- return extIP(ip)
-}
+type ExtIP net.IP
-type extIP net.IP
-
-func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
-func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
+func (n ExtIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
+func (n ExtIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
// These do nothing.
-func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
-func (extIP) DeleteMapping(string, int, int) error { return nil }
+
+func (ExtIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
+func (ExtIP) DeleteMapping(string, int, int) error { return nil }
// Any returns a port mapper that tries to discover any supported
// mechanism on the local network.
diff --git a/p2p/nat/nat_test.go b/p2p/nat/nat_test.go
index 469101e99..814e6d9e1 100644
--- a/p2p/nat/nat_test.go
+++ b/p2p/nat/nat_test.go
@@ -28,7 +28,7 @@ import (
func TestAutoDiscRace(t *testing.T) {
ad := startautodisc("thing", func() Interface {
time.Sleep(500 * time.Millisecond)
- return extIP{33, 44, 55, 66}
+ return ExtIP{33, 44, 55, 66}
})
// Spawn a few concurrent calls to ad.ExternalIP.
diff --git a/p2p/netutil/iptrack.go b/p2p/netutil/iptrack.go
new file mode 100644
index 000000000..b9cbd5e1c
--- /dev/null
+++ b/p2p/netutil/iptrack.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package netutil
+
+import (
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+)
+
+// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host
+// based on statements made by other hosts.
+type IPTracker struct {
+ window time.Duration
+ contactWindow time.Duration
+ minStatements int
+ clock mclock.Clock
+ statements map[string]ipStatement
+ contact map[string]mclock.AbsTime
+ lastStatementGC mclock.AbsTime
+ lastContactGC mclock.AbsTime
+}
+
+type ipStatement struct {
+ endpoint string
+ time mclock.AbsTime
+}
+
+// NewIPTracker creates an IP tracker.
+//
+// The window parameters configure the amount of past network events which are kept. The
+// minStatements parameter enforces a minimum number of statements which must be recorded
+// before any prediction is made. Higher values for these parameters decrease 'flapping' of
+// predictions as network conditions change. Window duration values should typically be in
+// the range of minutes.
+func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker {
+ return &IPTracker{
+ window: window,
+ contactWindow: contactWindow,
+ statements: make(map[string]ipStatement),
+ minStatements: minStatements,
+ contact: make(map[string]mclock.AbsTime),
+ clock: mclock.System{},
+ }
+}
+
+// PredictFullConeNAT checks whether the local host is behind full cone NAT. It predicts by
+// checking whether any statement has been received from a node we didn't contact before
+// the statement was made.
+func (it *IPTracker) PredictFullConeNAT() bool {
+ now := it.clock.Now()
+ it.gcContact(now)
+ it.gcStatements(now)
+ for host, st := range it.statements {
+ if c, ok := it.contact[host]; !ok || c > st.time {
+ return true
+ }
+ }
+ return false
+}
+
+// PredictEndpoint returns the current prediction of the external endpoint.
+func (it *IPTracker) PredictEndpoint() string {
+ it.gcStatements(it.clock.Now())
+
+ // The current strategy is simple: find the endpoint with most statements.
+ counts := make(map[string]int)
+ maxcount, max := 0, ""
+ for _, s := range it.statements {
+ c := counts[s.endpoint] + 1
+ counts[s.endpoint] = c
+ if c > maxcount && c >= it.minStatements {
+ maxcount, max = c, s.endpoint
+ }
+ }
+ return max
+}
+
+// AddStatement records that a certain host thinks our external endpoint is the one given.
+func (it *IPTracker) AddStatement(host, endpoint string) {
+ now := it.clock.Now()
+ it.statements[host] = ipStatement{endpoint, now}
+ if time.Duration(now-it.lastStatementGC) >= it.window {
+ it.gcStatements(now)
+ }
+}
+
+// AddContact records that a packet containing our endpoint information has been sent to a
+// certain host.
+func (it *IPTracker) AddContact(host string) {
+ now := it.clock.Now()
+ it.contact[host] = now
+ if time.Duration(now-it.lastContactGC) >= it.contactWindow {
+ it.gcContact(now)
+ }
+}
+
+func (it *IPTracker) gcStatements(now mclock.AbsTime) {
+ it.lastStatementGC = now
+ cutoff := now.Add(-it.window)
+ for host, s := range it.statements {
+ if s.time < cutoff {
+ delete(it.statements, host)
+ }
+ }
+}
+
+func (it *IPTracker) gcContact(now mclock.AbsTime) {
+ it.lastContactGC = now
+ cutoff := now.Add(-it.contactWindow)
+ for host, ct := range it.contact {
+ if ct < cutoff {
+ delete(it.contact, host)
+ }
+ }
+}
diff --git a/p2p/netutil/iptrack_test.go b/p2p/netutil/iptrack_test.go
new file mode 100644
index 000000000..a9a2998a6
--- /dev/null
+++ b/p2p/netutil/iptrack_test.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package netutil
+
+import (
+ "fmt"
+ mrand "math/rand"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+)
+
+const (
+ opStatement = iota
+ opContact
+ opPredict
+ opCheckFullCone
+)
+
+type iptrackTestEvent struct {
+ op int
+ time int // absolute, in milliseconds
+ ip, from string
+}
+
+func TestIPTracker(t *testing.T) {
+ tests := map[string][]iptrackTestEvent{
+ "minStatements": {
+ {opPredict, 0, "", ""},
+ {opStatement, 0, "127.0.0.1", "127.0.0.2"},
+ {opPredict, 1000, "", ""},
+ {opStatement, 1000, "127.0.0.1", "127.0.0.3"},
+ {opPredict, 1000, "", ""},
+ {opStatement, 1000, "127.0.0.1", "127.0.0.4"},
+ {opPredict, 1000, "127.0.0.1", ""},
+ },
+ "window": {
+ {opStatement, 0, "127.0.0.1", "127.0.0.2"},
+ {opStatement, 2000, "127.0.0.1", "127.0.0.3"},
+ {opStatement, 3000, "127.0.0.1", "127.0.0.4"},
+ {opPredict, 10000, "127.0.0.1", ""},
+ {opPredict, 10001, "", ""}, // first statement expired
+ {opStatement, 10100, "127.0.0.1", "127.0.0.2"},
+ {opPredict, 10200, "127.0.0.1", ""},
+ },
+ "fullcone": {
+ {opContact, 0, "", "127.0.0.2"},
+ {opStatement, 10, "127.0.0.1", "127.0.0.2"},
+ {opContact, 2000, "", "127.0.0.3"},
+ {opStatement, 2010, "127.0.0.1", "127.0.0.3"},
+ {opContact, 3000, "", "127.0.0.4"},
+ {opStatement, 3010, "127.0.0.1", "127.0.0.4"},
+ {opCheckFullCone, 3500, "false", ""},
+ },
+ "fullcone_2": {
+ {opContact, 0, "", "127.0.0.2"},
+ {opStatement, 10, "127.0.0.1", "127.0.0.2"},
+ {opContact, 2000, "", "127.0.0.3"},
+ {opStatement, 2010, "127.0.0.1", "127.0.0.3"},
+ {opStatement, 3000, "127.0.0.1", "127.0.0.4"},
+ {opContact, 3010, "", "127.0.0.4"},
+ {opCheckFullCone, 3500, "true", ""},
+ },
+ }
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) })
+ }
+}
+
+func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
+ var (
+ clock mclock.Simulated
+ it = NewIPTracker(10*time.Second, 10*time.Second, 3)
+ )
+ it.clock = &clock
+ for i, ev := range evs {
+ evtime := time.Duration(ev.time) * time.Millisecond
+ clock.Run(evtime - time.Duration(clock.Now()))
+ switch ev.op {
+ case opStatement:
+ it.AddStatement(ev.from, ev.ip)
+ case opContact:
+ it.AddContact(ev.from)
+ case opPredict:
+ if pred := it.PredictEndpoint(); pred != ev.ip {
+ t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip)
+ }
+ case opCheckFullCone:
+ pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
+ if pred != ev.ip {
+ t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip)
+ }
+ }
+ }
+}
+
+// This checks that old statements and contacts are GCed even if Predict* isn't called.
+func TestIPTrackerForceGC(t *testing.T) {
+ var (
+ clock mclock.Simulated
+ window = 10 * time.Second
+ rate = 50 * time.Millisecond
+ max = int(window/rate) + 1
+ it = NewIPTracker(window, window, 3)
+ )
+ it.clock = &clock
+
+ for i := 0; i < 5*max; i++ {
+ e1 := make([]byte, 4)
+ e2 := make([]byte, 4)
+ mrand.Read(e1)
+ mrand.Read(e2)
+ it.AddStatement(string(e1), string(e2))
+ it.AddContact(string(e1))
+ clock.Run(rate)
+ }
+ if len(it.contact) > 2*max {
+ t.Errorf("contacts not GCed, have %d", len(it.contact))
+ }
+ if len(it.statements) > 2*max {
+ t.Errorf("statements not GCed, have %d", len(it.statements))
+ }
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 4b90a2a70..9438ab8e4 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -20,6 +20,7 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
)
// Protocol represents a P2P subprotocol implementation.
@@ -52,6 +53,9 @@ type Protocol struct {
// about a certain peer in the network. If an info retrieval function is set,
// but returns nil, it is assumed that the protocol handshake is still running.
PeerInfo func(id enode.ID) interface{}
+
+ // Attributes contains protocol specific information for the node record.
+ Attributes []enr.Entry
}
func (p Protocol) cap() Cap {
@@ -64,10 +68,6 @@ type Cap struct {
Version uint
}
-func (cap Cap) RlpData() interface{} {
- return []interface{}{cap.Name, cap.Version}
-}
-
func (cap Cap) String() string {
return fmt.Sprintf("%s/%d", cap.Name, cap.Version)
}
@@ -79,3 +79,5 @@ func (cs capsByNameAndVersion) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
func (cs capsByNameAndVersion) Less(i, j int) bool {
return cs[i].Name < cs[j].Name || (cs[i].Name == cs[j].Name && cs[i].Version < cs[j].Version)
}
+
+func (capsByNameAndVersion) ENRKey() string { return "cap" }
diff --git a/p2p/server.go b/p2p/server.go
index 40db758e2..6482c0401 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -20,9 +20,11 @@ package p2p
import (
"bytes"
"crypto/ecdsa"
+ "encoding/hex"
"errors"
"fmt"
"net"
+ "sort"
"sync"
"sync/atomic"
"time"
@@ -35,8 +37,10 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
+ "github.com/ethereum/go-ethereum/rlp"
)
const (
@@ -160,6 +164,8 @@ type Server struct {
lock sync.Mutex // protects running
running bool
+ nodedb *enode.DB
+ localnode *enode.LocalNode
ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake
@@ -347,43 +353,13 @@ func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription {
// Self returns the local node's endpoint information.
func (srv *Server) Self() *enode.Node {
srv.lock.Lock()
- running, listener, ntab := srv.running, srv.listener, srv.ntab
+ ln := srv.localnode
srv.lock.Unlock()
- if !running {
+ if ln == nil {
return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0)
}
- return srv.makeSelf(listener, ntab)
-}
-
-func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node {
- // If the node is running but discovery is off, manually assemble the node infos.
- if ntab == nil {
- addr := srv.tcpAddr(listener)
- return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0)
- }
- // Otherwise return the discovery node.
- return ntab.Self()
-}
-
-func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr {
- addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}}
- if listener == nil {
- return addr // Inbound connections disabled, use zero address.
- }
- // Otherwise inject the listener address too.
- if a, ok := listener.Addr().(*net.TCPAddr); ok {
- addr = *a
- }
- if srv.NAT != nil {
- if ip, err := srv.NAT.ExternalIP(); err == nil {
- addr.IP = ip
- }
- }
- if addr.IP.IsUnspecified() {
- addr.IP = net.IP{127, 0, 0, 1}
- }
- return addr
+ return ln.Node()
}
// Stop terminates the server and all active peer connections.
@@ -443,7 +419,9 @@ func (srv *Server) Start() (err error) {
if srv.log == nil {
srv.log = log.New()
}
- srv.log.Info("Starting P2P networking")
+ if srv.NoDial && srv.ListenAddr == "" {
+ srv.log.Warn("P2P server will be useless, neither dialing nor listening")
+ }
// static fields
if srv.PrivateKey == nil {
@@ -466,65 +444,120 @@ func (srv *Server) Start() (err error) {
srv.peerOp = make(chan peerOpFunc)
srv.peerOpDone = make(chan struct{})
- var (
- conn *net.UDPConn
- sconn *sharedUDPConn
- realaddr *net.UDPAddr
- unhandled chan discover.ReadPacket
- )
-
- if !srv.NoDiscovery || srv.DiscoveryV5 {
- addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
- if err != nil {
+ if err := srv.setupLocalNode(); err != nil {
+ return err
+ }
+ if srv.ListenAddr != "" {
+ if err := srv.setupListening(); err != nil {
return err
}
- conn, err = net.ListenUDP("udp", addr)
- if err != nil {
- return err
- }
- realaddr = conn.LocalAddr().(*net.UDPAddr)
- if srv.NAT != nil {
- if !realaddr.IP.IsLoopback() {
- go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
- }
- // TODO: react to external IP changes over time.
- if ext, err := srv.NAT.ExternalIP(); err == nil {
- realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
- }
- }
+ }
+ if err := srv.setupDiscovery(); err != nil {
+ return err
}
- if !srv.NoDiscovery && srv.DiscoveryV5 {
- unhandled = make(chan discover.ReadPacket, 100)
- sconn = &sharedUDPConn{conn, unhandled}
+ dynPeers := srv.maxDialedConns()
+ dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
+ srv.loopWG.Add(1)
+ go srv.run(dialer)
+ return nil
+}
+
+func (srv *Server) setupLocalNode() error {
+ // Create the devp2p handshake.
+ pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
+ srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
+ for _, p := range srv.Protocols {
+ srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
+ }
+ sort.Sort(capsByNameAndVersion(srv.ourHandshake.Caps))
+
+ // Create the local node.
+ db, err := enode.OpenDB(srv.Config.NodeDatabase)
+ if err != nil {
+ return err
+ }
+ srv.nodedb = db
+ srv.localnode = enode.NewLocalNode(db, srv.PrivateKey)
+ srv.localnode.SetFallbackIP(net.IP{127, 0, 0, 1})
+ srv.localnode.Set(capsByNameAndVersion(srv.ourHandshake.Caps))
+ // TODO: check conflicts
+ for _, p := range srv.Protocols {
+ for _, e := range p.Attributes {
+ srv.localnode.Set(e)
+ }
+ }
+ switch srv.NAT.(type) {
+ case nil:
+ // No NAT interface, do nothing.
+ case nat.ExtIP:
+ // ExtIP doesn't block, set the IP right away.
+ ip, _ := srv.NAT.ExternalIP()
+ srv.localnode.SetStaticIP(ip)
+ default:
+ // Ask the router about the IP. This takes a while and blocks startup,
+ // do it in the background.
+ srv.loopWG.Add(1)
+ go func() {
+ defer srv.loopWG.Done()
+ if ip, err := srv.NAT.ExternalIP(); err == nil {
+ srv.localnode.SetStaticIP(ip)
+ }
+ }()
+ }
+ return nil
+}
+
+func (srv *Server) setupDiscovery() error {
+ if srv.NoDiscovery && !srv.DiscoveryV5 {
+ return nil
}
- // node table
+ addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
+ if err != nil {
+ return err
+ }
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return err
+ }
+ realaddr := conn.LocalAddr().(*net.UDPAddr)
+ srv.log.Debug("UDP listener up", "addr", realaddr)
+ if srv.NAT != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ }
+ srv.localnode.SetFallbackUDP(realaddr.Port)
+
+ // Discovery V4
+ var unhandled chan discover.ReadPacket
+ var sconn *sharedUDPConn
if !srv.NoDiscovery {
- cfg := discover.Config{
- PrivateKey: srv.PrivateKey,
- AnnounceAddr: realaddr,
- NodeDBPath: srv.NodeDatabase,
- NetRestrict: srv.NetRestrict,
- Bootnodes: srv.BootstrapNodes,
- Unhandled: unhandled,
+ if srv.DiscoveryV5 {
+ unhandled = make(chan discover.ReadPacket, 100)
+ sconn = &sharedUDPConn{conn, unhandled}
}
- ntab, err := discover.ListenUDP(conn, cfg)
+ cfg := discover.Config{
+ PrivateKey: srv.PrivateKey,
+ NetRestrict: srv.NetRestrict,
+ Bootnodes: srv.BootstrapNodes,
+ Unhandled: unhandled,
+ }
+ ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
if err != nil {
return err
}
srv.ntab = ntab
}
-
+ // Discovery V5
if srv.DiscoveryV5 {
- var (
- ntab *discv5.Network
- err error
- )
+ var ntab *discv5.Network
+ var err error
if sconn != nil {
- ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, "", srv.NetRestrict)
} else {
- ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, "", srv.NetRestrict)
}
if err != nil {
return err
@@ -534,32 +567,10 @@ func (srv *Server) Start() (err error) {
}
srv.DiscV5 = ntab
}
-
- dynPeers := srv.maxDialedConns()
- dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
-
- // handshake
- pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
- srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
- for _, p := range srv.Protocols {
- srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
- }
- // listen/dial
- if srv.ListenAddr != "" {
- if err := srv.startListening(); err != nil {
- return err
- }
- }
- if srv.NoDial && srv.ListenAddr == "" {
- srv.log.Warn("P2P server will be useless, neither dialing nor listening")
- }
-
- srv.loopWG.Add(1)
- go srv.run(dialer)
return nil
}
-func (srv *Server) startListening() error {
+func (srv *Server) setupListening() error {
// Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil {
@@ -568,8 +579,11 @@ func (srv *Server) startListening() error {
laddr := listener.Addr().(*net.TCPAddr)
srv.ListenAddr = laddr.String()
srv.listener = listener
+ srv.localnode.Set(enr.TCP(laddr.Port))
+
srv.loopWG.Add(1)
go srv.listenLoop()
+
// Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
@@ -589,7 +603,10 @@ type dialer interface {
}
func (srv *Server) run(dialstate dialer) {
+ srv.log.Info("Started P2P networking", "self", srv.localnode.Node())
defer srv.loopWG.Done()
+ defer srv.nodedb.Close()
+
var (
peers = make(map[enode.ID]*Peer)
inboundCount = 0
@@ -781,7 +798,7 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
return DiscTooManyPeers
case peers[c.node.ID()] != nil:
return DiscAlreadyConnected
- case c.node.ID() == srv.Self().ID():
+ case c.node.ID() == srv.localnode.ID():
return DiscSelf
default:
return nil
@@ -802,15 +819,11 @@ func (srv *Server) maxDialedConns() int {
return srv.MaxPeers / r
}
-type tempError interface {
- Temporary() bool
-}
-
// listenLoop runs in its own goroutine and accepts
// inbound connections.
func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
- srv.log.Info("RLPx listener up", "self", srv.Self())
+ srv.log.Debug("TCP listener up", "addr", srv.listener.Addr())
tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 {
@@ -831,7 +844,7 @@ func (srv *Server) listenLoop() {
)
for {
fd, err = srv.listener.Accept()
- if tempErr, ok := err.(tempError); ok && tempErr.Temporary() {
+ if netutil.IsTemporaryError(err) {
srv.log.Debug("Temporary read error", "err", err)
continue
} else if err != nil {
@@ -864,10 +877,6 @@ func (srv *Server) listenLoop() {
// as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed.
func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
- self := srv.Self()
- if self == nil {
- return errors.New("shutdown")
- }
c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
err := srv.setupConn(c, flags, dialDest)
if err != nil {
@@ -1003,6 +1012,7 @@ type NodeInfo struct {
ID string `json:"id"` // Unique node identifier (also the encryption key)
Name string `json:"name"` // Name of the node, including client type, version, OS, custom data
Enode string `json:"enode"` // Enode URL for adding this peer from remote peers
+ ENR string `json:"enr"` // Ethereum Node Record
IP string `json:"ip"` // IP address of the node
Ports struct {
Discovery int `json:"discovery"` // UDP listening port for discovery protocol
@@ -1014,9 +1024,8 @@ type NodeInfo struct {
// NodeInfo gathers and returns a collection of metadata known about the host.
func (srv *Server) NodeInfo() *NodeInfo {
- node := srv.Self()
-
// Gather and assemble the generic node infos
+ node := srv.Self()
info := &NodeInfo{
Name: srv.Name,
Enode: node.String(),
@@ -1027,6 +1036,9 @@ func (srv *Server) NodeInfo() *NodeInfo {
}
info.Ports.Discovery = node.UDP()
info.Ports.Listener = node.TCP()
+ if enc, err := rlp.EncodeToBytes(node.Record()); err == nil {
+ info.ENR = "0x" + hex.EncodeToString(enc)
+ }
// Gather all the running protocol infos (only once per protocol type)
for _, proto := range srv.Protocols {
diff --git a/p2p/server_test.go b/p2p/server_test.go
index e0b1fc122..7e11577d6 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -225,12 +225,15 @@ func TestServerTaskScheduling(t *testing.T) {
// The Server in this test isn't actually running
// because we're only interested in what run does.
+ db, _ := enode.OpenDB("")
srv := &Server{
- Config: Config{MaxPeers: 10},
- quit: make(chan struct{}),
- ntab: fakeTable{},
- running: true,
- log: log.New(),
+ Config: Config{MaxPeers: 10},
+ localnode: enode.NewLocalNode(db, newkey()),
+ nodedb: db,
+ quit: make(chan struct{}),
+ ntab: fakeTable{},
+ running: true,
+ log: log.New(),
}
srv.loopWG.Add(1)
go func() {
@@ -271,11 +274,14 @@ func TestServerManyTasks(t *testing.T) {
}
var (
- srv = &Server{
- quit: make(chan struct{}),
- ntab: fakeTable{},
- running: true,
- log: log.New(),
+ db, _ = enode.OpenDB("")
+ srv = &Server{
+ quit: make(chan struct{}),
+ localnode: enode.NewLocalNode(db, newkey()),
+ nodedb: db,
+ ntab: fakeTable{},
+ running: true,
+ log: log.New(),
}
done = make(chan *testTask)
start, end = 0, 0