From d491e4c093c92161d0619f7f81ad3541836bc768 Mon Sep 17 00:00:00 2001 From: "alex.sharov" Date: Wed, 11 Aug 2021 11:21:36 +0700 Subject: [PATCH] add KV --- txpool/fetch.go | 19 +++++++---- txpool/fetch_test.go | 6 ++-- txpool/mocks_test.go | 30 ++++++++++++----- txpool/pool.go | 70 +++++++++++++++++++++++++++++++--------- txpool/pool_fuzz_test.go | 6 ++-- 5 files changed, 97 insertions(+), 34 deletions(-) diff --git a/txpool/fetch.go b/txpool/fetch.go index 540deff9e..1de80358c 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -28,6 +28,7 @@ import ( "github.com/ledgerwatch/erigon-lib/gointerfaces" "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" + "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/log/v3" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -43,7 +44,8 @@ type Fetch struct { sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network statusData *sentry.StatusData // Status data used for "handshaking" with sentries pool Pool // Transaction pool implementation - wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) + db kv.RoDB + wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) stateChangesClient remote.KVClient } @@ -62,7 +64,7 @@ var DefaultTimings = Timings{ // NewFetch creates a new fetch object that will work with given sentry clients. Since the // SentryClient here is an interface, it is suitable for mocking in tests (mock will need // to implement all the functions of the SentryClient interface). -func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient) *Fetch { +func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient, db kv.RoDB) *Fetch { statusData := &sentry.StatusData{ NetworkId: networkId, TotalDifficulty: gointerfaces.ConvertUint256IntToH256(uint256.NewInt(0)), @@ -78,6 +80,7 @@ func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisH sentryClients: sentryClients, statusData: statusData, pool: pool, + db: db, stateChangesClient: stateChangesClient, } } @@ -176,7 +179,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { if req == nil { return } - if err = f.handleInboundMessage(req, sentryClient); err != nil { + if err = f.handleInboundMessage(streamCtx, req, sentryClient); err != nil { log.Warn("Handling incoming message: %s", "err", err) } if f.wg != nil { @@ -186,7 +189,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { } } -func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient sentry.SentryClient) error { +func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMessage, sentryClient sentry.SentryClient) error { switch req.Id { case sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_65: hashCount, pos, err := ParseHashesCount(req.Data, 0) @@ -279,7 +282,9 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se return err } } - if err := f.pool.Add(txs); err != nil { + if err := f.db.View(ctx, func(tx kv.Tx) error { + return f.pool.Add(tx, txs) + }); err != nil { return err } } @@ -438,7 +443,9 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client remote.KVClient) diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance} } - if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil { + if err := f.db.View(ctx, func(tx kv.Tx) error { + return f.pool.OnNewBlock(tx, diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight) + }); err != nil { log.Warn("onNewBlock", "err", err) } if f.wg != nil { diff --git a/txpool/fetch_test.go b/txpool/fetch_test.go index 651c09f56..ff6cdceda 100644 --- a/txpool/fetch_test.go +++ b/txpool/fetch_test.go @@ -27,6 +27,7 @@ import ( "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" "github.com/ledgerwatch/erigon-lib/gointerfaces/types" + "github.com/ledgerwatch/erigon-lib/kv/memdb" "github.com/ledgerwatch/log/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,7 +46,7 @@ func TestFetch(t *testing.T) { sentryClient := direct.NewSentryClientDirect(direct.ETH66, m) pool := &PoolMock{} - fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}) + fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, nil) var wg sync.WaitGroup fetch.SetWaitGroup(&wg) m.StreamWg.Add(2) @@ -138,6 +139,7 @@ func TestSendTxPropagate(t *testing.T) { func TestOnNewBlock(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + db := memdb.NewTestDB(t) var genesisHash [32]byte var networkId uint64 = 1 @@ -157,7 +159,7 @@ func TestOnNewBlock(t *testing.T) { }, } pool := &PoolMock{} - fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges) + fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges, db) fetch.handleStateChanges(ctx, stateChanges) assert.Equal(t, 1, len(pool.OnNewBlockCalls())) assert.Equal(t, 3, len(pool.OnNewBlockCalls()[0].MinedTxs.txs)) diff --git a/txpool/mocks_test.go b/txpool/mocks_test.go index 1b4ae4343..afbfe311c 100644 --- a/txpool/mocks_test.go +++ b/txpool/mocks_test.go @@ -5,6 +5,8 @@ package txpool import ( "sync" + + "github.com/ledgerwatch/erigon-lib/kv" ) // Ensure, that PoolMock does implement Pool. @@ -17,7 +19,7 @@ var _ Pool = &PoolMock{} // // // make and configure a mocked Pool // mockedPool := &PoolMock{ -// AddFunc: func(newTxs TxSlots) error { +// AddFunc: func(db kv.Tx, newTxs TxSlots) error { // panic("mock out the Add method") // }, // AddNewGoodPeerFunc: func(peerID PeerID) { @@ -29,7 +31,7 @@ var _ Pool = &PoolMock{} // IdHashKnownFunc: func(hash []byte) bool { // panic("mock out the IdHashKnown method") // }, -// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { +// OnNewBlockFunc: func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { // panic("mock out the OnNewBlock method") // }, // } @@ -40,7 +42,7 @@ var _ Pool = &PoolMock{} // } type PoolMock struct { // AddFunc mocks the Add method. - AddFunc func(newTxs TxSlots) error + AddFunc func(db kv.Tx, newTxs TxSlots) error // AddNewGoodPeerFunc mocks the AddNewGoodPeer method. AddNewGoodPeerFunc func(peerID PeerID) @@ -52,12 +54,14 @@ type PoolMock struct { IdHashKnownFunc func(hash []byte) bool // OnNewBlockFunc mocks the OnNewBlock method. - OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error + OnNewBlockFunc func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error // calls tracks calls to the methods. calls struct { // Add holds details about calls to the Add method. Add []struct { + // Db is the db argument value. + Db kv.Tx // NewTxs is the newTxs argument value. NewTxs TxSlots } @@ -78,6 +82,8 @@ type PoolMock struct { } // OnNewBlock holds details about calls to the OnNewBlock method. OnNewBlock []struct { + // Db is the db argument value. + Db kv.Tx // StateChanges is the stateChanges argument value. StateChanges map[string]senderInfo // UnwindTxs is the unwindTxs argument value. @@ -100,10 +106,12 @@ type PoolMock struct { } // Add calls AddFunc. -func (mock *PoolMock) Add(newTxs TxSlots) error { +func (mock *PoolMock) Add(db kv.Tx, newTxs TxSlots) error { callInfo := struct { + Db kv.Tx NewTxs TxSlots }{ + Db: db, NewTxs: newTxs, } mock.lockAdd.Lock() @@ -115,16 +123,18 @@ func (mock *PoolMock) Add(newTxs TxSlots) error { ) return errOut } - return mock.AddFunc(newTxs) + return mock.AddFunc(db, newTxs) } // AddCalls gets all the calls that were made to Add. // Check the length with: // len(mockedPool.AddCalls()) func (mock *PoolMock) AddCalls() []struct { + Db kv.Tx NewTxs TxSlots } { var calls []struct { + Db kv.Tx NewTxs TxSlots } mock.lockAdd.RLock() @@ -233,8 +243,9 @@ func (mock *PoolMock) IdHashKnownCalls() []struct { } // OnNewBlock calls OnNewBlockFunc. -func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { +func (mock *PoolMock) OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { callInfo := struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots @@ -242,6 +253,7 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T BlockBaseFee uint64 BlockHeight uint64 }{ + Db: db, StateChanges: stateChanges, UnwindTxs: unwindTxs, MinedTxs: minedTxs, @@ -258,13 +270,14 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T ) return errOut } - return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight) + return mock.OnNewBlockFunc(db, stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight) } // OnNewBlockCalls gets all the calls that were made to OnNewBlock. // Check the length with: // len(mockedPool.OnNewBlockCalls()) func (mock *PoolMock) OnNewBlockCalls() []struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots @@ -273,6 +286,7 @@ func (mock *PoolMock) OnNewBlockCalls() []struct { BlockHeight uint64 } { var calls []struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots diff --git a/txpool/pool.go b/txpool/pool.go index 0c3fb8321..4ad2b3636 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -27,6 +27,7 @@ import ( "github.com/google/btree" lru "github.com/hashicorp/golang-lru" "github.com/holiman/uint256" + "github.com/ledgerwatch/erigon-lib/kv" "go.uber.org/atomic" ) @@ -37,8 +38,8 @@ type Pool interface { // IdHashKnown check whether transaction with given Id hash is known to the pool IdHashKnown(hash []byte) bool GetRlp(hash []byte) []byte - Add(newTxs TxSlots) error - OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error + Add(db kv.Tx, newTxs TxSlots) error + OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error AddNewGoodPeer(peerID PeerID) } @@ -87,6 +88,8 @@ const PendingSubPoolLimit = 1024 const BaseFeeSubPoolLimit = 1024 const QueuedSubPoolLimit = 1024 +const MaxSendersInfoCache = 1024 + type nonce2Tx struct{ *btree.BTree } type senderInfo struct { @@ -115,6 +118,7 @@ type TxPool struct { protocolBaseFee atomic.Uint64 blockBaseFee atomic.Uint64 + senderID uint64 senderIDs map[string]uint64 senderInfo map[uint64]*senderInfo byHash map[string]*metaTx // tx_hash => tx @@ -202,7 +206,7 @@ func (p *TxPool) IdHashIsLocal(hash []byte) bool { } func (p *TxPool) OnNewPeer(peerID PeerID) { p.recentlyConnectedPeers.AddPeer(peerID) } -func (p *TxPool) Add(newTxs TxSlots) error { +func (p *TxPool) Add(tx kv.Tx, newTxs TxSlots) error { p.lock.Lock() defer p.lock.Unlock() if err := newTxs.Valid(); err != nil { @@ -214,7 +218,9 @@ func (p *TxPool) Add(newTxs TxSlots) error { return fmt.Errorf("non-zero base fee") } - setTxSenderID(p.senderIDs, p.senderInfo, newTxs) + if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, newTxs); err != nil { + return err + } if err := onNewTxs(p.senderInfo, newTxs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil { return err } @@ -286,7 +292,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee return nil } -func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error { +func (p *TxPool) OnNewBlock(tx kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error { p.lock.Lock() defer p.lock.Unlock() if err := unwindTxs.Valid(); err != nil { @@ -300,13 +306,18 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined p.protocolBaseFee.Store(protocolBaseFee) p.blockBaseFee.Store(blockBaseFee) - setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs) - setTxSenderID(p.senderIDs, p.senderInfo, minedTxs) + if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, unwindTxs); err != nil { + return err + } + if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, minedTxs); err != nil { + return err + } for addr, id := range p.senderIDs { // merge state changes if v, ok := stateChanges[addr]; ok { p.senderInfo[id] = &v } } + if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil { return err } @@ -326,24 +337,53 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined } } + /* + // evict sendersInfo without txs + if len(p.senderIDs) > MaxSendersInfoCache { + for i := range p.senderInfo { + if p.senderInfo[i].txNonce2Tx.Len() > 0 { + continue + } + for addr, id := range p.senderIDs { + if id == i { + delete(p.senderIDs, addr) + } + } + delete(p.senderInfo, i) + } + } + */ + return nil } -func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) { + +func setTxSenderID(tx kv.Tx, senderIDSequence *uint64, senderIDs map[string]uint64, sendersInfo map[uint64]*senderInfo, txs TxSlots) error { for i := range txs.txs { addr := string(txs.senders.At(i)) + // assign ID to each new sender id, ok := senderIDs[addr] if !ok { - for i := range senderInfo { //TODO: create field for it? - if id < i { - id = i - } - } - id++ - senderIDs[addr] = id + *senderIDSequence++ + senderIDs[addr] = *senderIDSequence } txs.txs[i].senderID = id + + // load data from db if need + _, ok = sendersInfo[txs.txs[i].senderID] + if !ok { + encoded, err := tx.GetOne(kv.PlainState, txs.senders.At(i)) + if err != nil { + return err + } + nonce, balance, err := DecodeSender(encoded) + if err != nil { + return err + } + sendersInfo[txs.txs[i].senderID] = &senderInfo{nonce: nonce, balance: balance} + } } + return nil } func onNewBlock(senderInfo map[uint64]*senderInfo, unwindTxs TxSlots, minedTxs []*TxSlot, protocolBaseFee, blockBaseFee uint64, pending, baseFee, queued *SubPool, byHash map[string]*metaTx, localsHistory *lru.Cache) error { diff --git a/txpool/pool_fuzz_test.go b/txpool/pool_fuzz_test.go index 257f57592..d905de78f 100644 --- a/txpool/pool_fuzz_test.go +++ b/txpool/pool_fuzz_test.go @@ -426,19 +426,19 @@ func FuzzOnNewBlocks11(f *testing.F) { // go to first fork //fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len()) unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs) - err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1) + err = pool.OnNewBlock(nil, map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1) assert.NoError(err) check(unwindTxs, minedTxs1, "fork1") checkNotify(unwindTxs, minedTxs1, "fork1") // unwind everything and switch to new fork (need unwind mined now) - err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2) + err = pool.OnNewBlock(nil, map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2) assert.NoError(err) check(minedTxs1, minedTxs2, "fork2") checkNotify(minedTxs1, minedTxs2, "fork2") // add some remote txs from p2p - err = pool.Add(p2pReceived) + err = pool.Add(nil, p2pReceived) assert.NoError(err) check(p2pReceived, TxSlots{}, "p2pmsg1") checkNotify(p2pReceived, TxSlots{}, "p2pmsg1")