diff --git a/cmd/headers/download/sentry.go b/cmd/headers/download/sentry.go index f38703b85..1b333346f 100644 --- a/cmd/headers/download/sentry.go +++ b/cmd/headers/download/sentry.go @@ -70,6 +70,7 @@ type PeerInfo struct { deadlines []time.Time // Request deadlines height uint64 rw p2p.MsgReadWriter + removed bool } // AddDeadline adds given deadline to the list of deadlines @@ -100,6 +101,18 @@ func (pi *PeerInfo) ClearDeadlines(now time.Time, givePermit bool) int { return len(pi.deadlines) } +func (pi *PeerInfo) Remove() { + pi.lock.Lock() + defer pi.lock.Unlock() + pi.removed = true +} + +func (pi *PeerInfo) Removed() bool { + pi.lock.RLock() + defer pi.lock.RUnlock() + return pi.removed +} + func makeP2PServer( ctx context.Context, nodeName string, @@ -317,7 +330,9 @@ func runPeer( if err = common.Stopped(ctx.Done()); err != nil { return err } - + if peerInfo.Removed() { + return fmt.Errorf("peer removed") + } msg, err := rw.ReadMsg() if err != nil { return fmt.Errorf("reading message: %v", err) @@ -608,6 +623,12 @@ type SentryServerImpl struct { func (ss *SentryServerImpl) PenalizePeer(_ context.Context, req *proto_sentry.PenalizePeerRequest) (*empty.Empty, error) { //log.Warn("Received penalty", "kind", req.GetPenalty().Descriptor().FullName, "from", fmt.Sprintf("%s", req.GetPeerId())) strId := string(gointerfaces.ConvertH512ToBytes(req.PeerId)) + if x, ok := ss.Peers.Load(strId); ok { + peerInfo := x.(*PeerInfo) + if peerInfo != nil { + peerInfo.Remove() + } + } ss.Peers.Delete(strId) return &empty.Empty{}, nil } @@ -672,6 +693,12 @@ func (ss *SentryServerImpl) SendMessageByMinBlock(_ context.Context, inreq *prot return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageByMinBlock not implemented for message Id: %s", inreq.Data.Id) } if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(inreq.Data.Data)), Payload: bytes.NewReader(inreq.Data.Data)}); err != nil { + if x, ok := ss.Peers.Load(peerID); ok { + peerInfo := x.(*PeerInfo) + if peerInfo != nil { + peerInfo.Remove() + } + } ss.Peers.Delete(peerID) return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageByMinBlock to peer %s: %v", peerID, err) } @@ -707,6 +734,12 @@ func (ss *SentryServerImpl) SendMessageById(_ context.Context, inreq *proto_sent } if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(inreq.Data.Data)), Payload: bytes.NewReader(inreq.Data.Data)}); err != nil { + if x, ok := ss.Peers.Load(peerID); ok { + peerInfo := x.(*PeerInfo) + if peerInfo != nil { + peerInfo.Remove() + } + } ss.Peers.Delete(peerID) return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageById to peer %s: %v", peerID, err) } @@ -745,6 +778,12 @@ func (ss *SentryServerImpl) SendMessageToRandomPeers(ctx context.Context, req *p return true } if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(req.Data.Data)), Payload: bytes.NewReader(req.Data.Data)}); err != nil { + if x, ok := ss.Peers.Load(peerID); ok { + peerInfo := x.(*PeerInfo) + if peerInfo != nil { + peerInfo.Remove() + } + } ss.Peers.Delete(peerID) innerErr = err return false @@ -779,6 +818,12 @@ func (ss *SentryServerImpl) SendMessageToAll(ctx context.Context, req *proto_sen return true } if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(req.Data)), Payload: bytes.NewReader(req.Data)}); err != nil { + if x, ok := ss.Peers.Load(peerID); ok { + peerInfo := x.(*PeerInfo) + if peerInfo != nil { + peerInfo.Remove() + } + } ss.Peers.Delete(peerID) innerErr = err return false