rpc: fix for map concurrency issue in logs subscription (#4903)

moving a couple of mutex locks and introducing another to prevent a deferred call to unsubscribe clashing with a new call to subscribe
This commit is contained in:
hexoscott 2022-08-02 18:37:34 +01:00 committed by GitHub
parent ec67e80a8a
commit 98c639784b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 6 deletions

View File

@ -17,12 +17,13 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/txpool" "github.com/ledgerwatch/erigon-lib/gointerfaces/txpool"
txpool2 "github.com/ledgerwatch/erigon-lib/txpool" txpool2 "github.com/ledgerwatch/erigon-lib/txpool"
"github.com/ledgerwatch/log/v3"
"google.golang.org/grpc"
"github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common"
"github.com/ledgerwatch/erigon/core/types" "github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/eth/filters" "github.com/ledgerwatch/erigon/eth/filters"
"github.com/ledgerwatch/erigon/rlp" "github.com/ledgerwatch/erigon/rlp"
"github.com/ledgerwatch/log/v3"
"google.golang.org/grpc"
) )
type ( type (
@ -400,14 +401,14 @@ func (ff *Filters) SubscribeLogs(out chan *types.Log, crit filters.FilterCriteri
AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1, AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1,
AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1, AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1,
} }
ff.mu.Lock()
defer ff.mu.Unlock()
for addr := range ff.logsSubs.aggLogsFilter.addrs { for addr := range ff.logsSubs.aggLogsFilter.addrs {
lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr)) lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr))
} }
for topic := range ff.logsSubs.aggLogsFilter.topics { for topic := range ff.logsSubs.aggLogsFilter.topics {
lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic)) lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic))
} }
ff.mu.Lock()
defer ff.mu.Unlock()
loaded := ff.logsRequestor.Load() loaded := ff.logsRequestor.Load()
if loaded != nil { if loaded != nil {
if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil { if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil {
@ -424,14 +425,14 @@ func (ff *Filters) UnsubscribeLogs(id LogsSubID) bool {
AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1, AllAddresses: ff.logsSubs.aggLogsFilter.allAddrs == 1,
AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1, AllTopics: ff.logsSubs.aggLogsFilter.allTopics == 1,
} }
ff.mu.Lock()
defer ff.mu.Unlock()
for addr := range ff.logsSubs.aggLogsFilter.addrs { for addr := range ff.logsSubs.aggLogsFilter.addrs {
lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr)) lfr.Addresses = append(lfr.Addresses, gointerfaces.ConvertAddressToH160(addr))
} }
for topic := range ff.logsSubs.aggLogsFilter.topics { for topic := range ff.logsSubs.aggLogsFilter.topics {
lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic)) lfr.Topics = append(lfr.Topics, gointerfaces.ConvertHashToH256(topic))
} }
ff.mu.Lock()
defer ff.mu.Unlock()
loaded := ff.logsRequestor.Load() loaded := ff.logsRequestor.Load()
if loaded != nil { if loaded != nil {
if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil { if err := loaded.(func(*remote.LogsFilterRequest) error)(lfr); err != nil {

View File

@ -5,6 +5,7 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces" "github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common"
types2 "github.com/ledgerwatch/erigon/core/types" types2 "github.com/ledgerwatch/erigon/core/types"
) )
@ -81,6 +82,8 @@ func (a *LogsFilterAggregator) subtractLogFilters(f *LogsFilter) {
} }
func (a *LogsFilterAggregator) addLogsFilters(f *LogsFilter) { func (a *LogsFilterAggregator) addLogsFilters(f *LogsFilter) {
a.logsFilterLock.Lock()
defer a.logsFilterLock.Unlock()
a.aggLogsFilter.allAddrs += f.allAddrs a.aggLogsFilter.allAddrs += f.allAddrs
for addr, count := range f.addrs { for addr, count := range f.addrs {
a.aggLogsFilter.addrs[addr] += count a.aggLogsFilter.addrs[addr] += count