diff --git a/txpool/fetch.go b/txpool/fetch.go index 83a9e36d2..2febf9c7a 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -19,6 +19,7 @@ package txpool import ( "context" "errors" + "fmt" "io" "log" "sync" @@ -168,6 +169,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { } func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient sentry.SentryClient) error { + fmt.Printf("got inbound message\n") return nil } diff --git a/txpool/fetch_test.go b/txpool/fetch_test.go index 769068abc..55a77abb8 100644 --- a/txpool/fetch_test.go +++ b/txpool/fetch_test.go @@ -18,6 +18,7 @@ package txpool import ( "context" + "sync" "testing" "github.com/ledgerwatch/erigon-lib/direct" @@ -34,5 +35,16 @@ func TestFetch(t *testing.T) { sentryClient := direct.NewSentryClientDirect(direct.ETH66, mock) fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks) + var wg sync.WaitGroup + fetch.SetWaitGroup(&wg) fetch.Start() + // Send one transaction id + wg.Add(1) + errs := mock.Send(&sentry.InboundMessage{Id: sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, Data: nil, PeerId: PeerId}) + for i, err := range errs { + if err != nil { + t.Errorf("sending new pool tx hashes 66 (%d): %v", i, err) + } + } + wg.Wait() } diff --git a/txpool/test_util.go b/txpool/test_util.go index 762e12c8c..c07227424 100644 --- a/txpool/test_util.go +++ b/txpool/test_util.go @@ -18,8 +18,8 @@ package txpool import ( "context" - "sync" + "github.com/ledgerwatch/erigon-lib/gointerfaces" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" "google.golang.org/protobuf/types/known/emptypb" ) @@ -27,19 +27,19 @@ import ( type MockSentry struct { sentry.UnimplementedSentryServer streams map[sentry.MessageId][]sentry.Sentry_MessagesServer - StreamWg sync.WaitGroup - peersStream sentry.Sentry_PeersServer + peersStreams []sentry.Sentry_PeersServer sentMessages []*sentry.OutboundMessageData ctx context.Context } func NewMockSentry(ctx context.Context) *MockSentry { - return &MockSentry{} + return &MockSentry{ctx: ctx} } +var PeerId = gointerfaces.ConvertBytesToH512([]byte("12345")) + // Stream returns stream, waiting if necessary func (ms *MockSentry) Send(req *sentry.InboundMessage) (errs []error) { - ms.StreamWg.Wait() for _, stream := range ms.streams[req.Id] { if err := stream.Send(req); err != nil { errs = append(errs, err) @@ -84,7 +84,6 @@ func (ms *MockSentry) Messages(req *sentry.MessagesRequest, stream sentry.Sentry for _, id := range req.Ids { ms.streams[id] = append(ms.streams[id], stream) } - ms.StreamWg.Done() select { case <-ms.ctx.Done(): return nil @@ -97,8 +96,7 @@ func (ms *MockSentry) PeerCount(_ context.Context, req *sentry.PeerCountRequest) } func (ms *MockSentry) Peers(req *sentry.PeersRequest, stream sentry.Sentry_PeersServer) error { - ms.peersStream = stream - ms.StreamWg.Done() + ms.peersStreams = append(ms.peersStreams, stream) select { case <-ms.ctx.Done(): return nil