diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index daad2958e..bc665d0a1 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -83,7 +83,9 @@ type UDPv4 struct { closeOnce sync.Once wg sync.WaitGroup - addReplyMatcher chan *replyMatcher + addReplyMatcher chan *replyMatcher + addReplyMatcherMutex sync.Mutex + gotreply chan reply gotkey chan v4wire.Pubkey gotnodes chan nodes @@ -160,7 +162,7 @@ func ListenV4(ctx context.Context, protocol string, c UDPConn, ln *enode.LocalNo localNode: ln, db: ln.Database(), gotreply: make(chan reply, 10), - addReplyMatcher: make(chan *replyMatcher), + addReplyMatcher: make(chan *replyMatcher, 10), gotkey: make(chan v4wire.Pubkey, 10), gotnodes: make(chan nodes, 10), replyTimeout: cfg.ReplyTimeout, @@ -456,6 +458,13 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, port int, ptype byte, callback r ch := make(chan error, 1) p := &replyMatcher{from: id, ip: ip, port: port, ptype: ptype, callback: callback, errc: ch} + t.addReplyMatcherMutex.Lock() + defer t.addReplyMatcherMutex.Unlock() + if t.addReplyMatcher == nil { + ch <- errClosed + return p + } + select { case t.addReplyMatcher <- p: // loop will handle it @@ -582,6 +591,14 @@ func (t *UDPv4) loop() { el.Value.(*replyMatcher).errc <- errClosed } }() + + t.addReplyMatcherMutex.Lock() + defer t.addReplyMatcherMutex.Unlock() + close(t.addReplyMatcher) + for matcher := range t.addReplyMatcher { + matcher.errc <- errClosed + } + t.addReplyMatcher = nil return case p := <-t.addReplyMatcher: