diff --git a/cmd/rpcdaemon/commands/eth_filters_test.go b/cmd/rpcdaemon/commands/eth_filters_test.go index ca4366bb9..e9a420e79 100644 --- a/cmd/rpcdaemon/commands/eth_filters_test.go +++ b/cmd/rpcdaemon/commands/eth_filters_test.go @@ -1,16 +1,22 @@ package commands import ( + "math/rand" + "sync" "testing" + "time" "github.com/ledgerwatch/erigon-lib/gointerfaces/txpool" "github.com/ledgerwatch/erigon-lib/kv/kvcache" + "github.com/stretchr/testify/assert" + "github.com/ledgerwatch/erigon/cmd/rpcdaemon/rpcdaemontest" + "github.com/ledgerwatch/erigon/common" + "github.com/ledgerwatch/erigon/core/types" "github.com/ledgerwatch/erigon/eth/filters" "github.com/ledgerwatch/erigon/turbo/rpchelper" "github.com/ledgerwatch/erigon/turbo/snapshotsync" "github.com/ledgerwatch/erigon/turbo/stages" - "github.com/stretchr/testify/assert" ) func TestNewFilters(t *testing.T) { @@ -43,3 +49,50 @@ func TestNewFilters(t *testing.T) { assert.Nil(err) assert.Equal(ok, true) } + +func TestLogsSubscribeAndUnsubscribe_WithoutConcurrentMapIssue(t *testing.T) { + ctx, conn := rpcdaemontest.CreateTestGrpcConn(t, stages.Mock(t)) + mining := txpool.NewMiningClient(conn) + ff := rpchelper.New(ctx, nil, nil, mining, func() {}) + + // generate some random topics + topics := make([][]common.Hash, 0) + for i := 0; i < 10; i++ { + bytes := make([]byte, common.HashLength) + rand.Read(bytes) + toAdd := []common.Hash{common.BytesToHash(bytes)} + topics = append(topics, toAdd) + } + + // generate some addresses + addresses := make([]common.Address, 0) + for i := 0; i < 10; i++ { + bytes := make([]byte, common.AddressLength) + rand.Read(bytes) + addresses = append(addresses, common.BytesToAddress(bytes)) + } + + crit := filters.FilterCriteria{ + Topics: topics, + Addresses: addresses, + } + + ids := make([]rpchelper.LogsSubID, 1000, 1000) + + // make a lot of subscriptions + wg := sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(idx int) { + out := make(chan *types.Log, 1) + id := ff.SubscribeLogs(out, crit) + defer func() { + time.Sleep(100 * time.Nanosecond) + ff.UnsubscribeLogs(id) + wg.Done() + }() + ids[idx] = id + }(i) + } + wg.Wait() +} diff --git a/turbo/rpchelper/filters.go b/turbo/rpchelper/filters.go index cb98492f6..bde6fe232 100644 --- a/turbo/rpchelper/filters.go +++ b/turbo/rpchelper/filters.go @@ -401,49 +401,65 @@ func (ff *Filters) SubscribeLogs(out chan *types.Log, crit filters.FilterCriteri AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1, AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1, } - ff.mu.Lock() - defer ff.mu.Unlock() - for addr := range ff.logsSubs.aggLogsFilter.addrs { + + addresses, topics := ff.logsSubs.getAggMaps() + + for addr := range addresses { lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr)) } - for topic := range ff.logsSubs.aggLogsFilter.topics { + for topic := range topics { lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic)) } - loaded := ff.logsRequestor.Load() + + loaded := ff.loadLogsRequester() if loaded != nil { if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil { log.Warn("Could not update remote logs filter", "err", err) ff.logsSubs.removeLogsFilter(id) } } + return id } +func (ff *Filters) loadLogsRequester() any { + ff.mu.Lock() + defer ff.mu.Unlock() + return ff.logsRequestor.Load() +} + func (ff *Filters) UnsubscribeLogs(id LogsSubID) bool { isDeleted := ff.logsSubs.removeLogsFilter(id) lfr := &remote.LogsFilterRequest{ AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1, AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1, } - ff.mu.Lock() - defer ff.mu.Unlock() - for addr := range ff.logsSubs.aggLogsFilter.addrs { + + addresses, topics := ff.logsSubs.getAggMaps() + + for addr := range addresses { lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr)) } - for topic := range ff.logsSubs.aggLogsFilter.topics { + for topic := range topics { lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic)) } - loaded := ff.logsRequestor.Load() + loaded := ff.loadLogsRequester() if loaded != nil { if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil { log.Warn("Could not update remote logs filter", "err", err) return isDeleted || ff.logsSubs.removeLogsFilter(id) } } + + ff.deleteLogStore(id) + + return isDeleted +} + +func (ff *Filters) deleteLogStore(id LogsSubID) { ff.storeMu.Lock() defer ff.storeMu.Unlock() delete(ff.logsStores, id) - return isDeleted } func (ff *Filters) OnNewEvent(event *remote.SubscribeReply) { diff --git a/turbo/rpchelper/logsfilter.go b/turbo/rpchelper/logsfilter.go index b22c67678..933425643 100644 --- a/turbo/rpchelper/logsfilter.go +++ b/turbo/rpchelper/logsfilter.go @@ -13,7 +13,7 @@ import ( type LogsFilterAggregator struct { aggLogsFilter LogsFilter // Aggregation of all current log filters logsFilters map[LogsSubID]*LogsFilter // Filter for each subscriber, keyed by filterID - logsFilterLock sync.Mutex + logsFilterLock sync.RWMutex nextFilterId LogsSubID } @@ -94,6 +94,23 @@ func (a *LogsFilterAggregator) addLogsFilters(f *LogsFilter) { } } +func (a *LogsFilterAggregator) getAggMaps() (map[common.Address]int, map[common.Hash]int) { + a.logsFilterLock.RLock() + defer a.logsFilterLock.RUnlock() + + addresses := make(map[common.Address]int) + for k, v := range a.aggLogsFilter.addrs { + addresses[k] = v + } + + topics := make(map[common.Hash]int) + for k, v := range a.aggLogsFilter.topics { + topics[k] = v + } + + return addresses, topics +} + func (a *LogsFilterAggregator) distributeLog(eventLog *remote.SubscribeLogsReply) error { a.logsFilterLock.Lock() defer a.logsFilterLock.Unlock()