diff --git a/gointerfaces/remote/mocks.go b/gointerfaces/remote/mocks.go index 17890ba3d..1cb97e871 100644 --- a/gointerfaces/remote/mocks.go +++ b/gointerfaces/remote/mocks.go @@ -7,6 +7,7 @@ import ( context "context" types "github.com/ledgerwatch/erigon-lib/gointerfaces/types" grpc "google.golang.org/grpc" + "google.golang.org/grpc/metadata" emptypb "google.golang.org/protobuf/types/known/emptypb" sync "sync" ) @@ -203,3 +204,314 @@ func (mock *KVClientMock) VersionCalls() []struct { mock.lockVersion.RUnlock() return calls } + +// Ensure, that KV_StateChangesClientMock does implement KV_StateChangesClient. +// If this is not the case, regenerate this file with moq. +var _ KV_StateChangesClient = &KV_StateChangesClientMock{} + +// KV_StateChangesClientMock is a mock implementation of KV_StateChangesClient. +// +// func TestSomethingThatUsesKV_StateChangesClient(t *testing.T) { +// +// // make and configure a mocked KV_StateChangesClient +// mockedKV_StateChangesClient := &KV_StateChangesClientMock{ +// CloseSendFunc: func() error { +// panic("mock out the CloseSend method") +// }, +// ContextFunc: func() context.Context { +// panic("mock out the Context method") +// }, +// HeaderFunc: func() (metadata.MD, error) { +// panic("mock out the Header method") +// }, +// RecvFunc: func() (*StateChange, error) { +// panic("mock out the Recv method") +// }, +// RecvMsgFunc: func(m interface{}) error { +// panic("mock out the RecvMsg method") +// }, +// SendMsgFunc: func(m interface{}) error { +// panic("mock out the SendMsg method") +// }, +// TrailerFunc: func() metadata.MD { +// panic("mock out the Trailer method") +// }, +// } +// +// // use mockedKV_StateChangesClient in code that requires KV_StateChangesClient +// // and then make assertions. +// +// } +type KV_StateChangesClientMock struct { + // CloseSendFunc mocks the CloseSend method. + CloseSendFunc func() error + + // ContextFunc mocks the Context method. + ContextFunc func() context.Context + + // HeaderFunc mocks the Header method. + HeaderFunc func() (metadata.MD, error) + + // RecvFunc mocks the Recv method. + RecvFunc func() (*StateChange, error) + + // RecvMsgFunc mocks the RecvMsg method. + RecvMsgFunc func(m interface{}) error + + // SendMsgFunc mocks the SendMsg method. + SendMsgFunc func(m interface{}) error + + // TrailerFunc mocks the Trailer method. + TrailerFunc func() metadata.MD + + // calls tracks calls to the methods. + calls struct { + // CloseSend holds details about calls to the CloseSend method. + CloseSend []struct { + } + // Context holds details about calls to the Context method. + Context []struct { + } + // Header holds details about calls to the Header method. + Header []struct { + } + // Recv holds details about calls to the Recv method. + Recv []struct { + } + // RecvMsg holds details about calls to the RecvMsg method. + RecvMsg []struct { + // M is the m argument value. + M interface{} + } + // SendMsg holds details about calls to the SendMsg method. + SendMsg []struct { + // M is the m argument value. + M interface{} + } + // Trailer holds details about calls to the Trailer method. + Trailer []struct { + } + } + lockCloseSend sync.RWMutex + lockContext sync.RWMutex + lockHeader sync.RWMutex + lockRecv sync.RWMutex + lockRecvMsg sync.RWMutex + lockSendMsg sync.RWMutex + lockTrailer sync.RWMutex +} + +// CloseSend calls CloseSendFunc. +func (mock *KV_StateChangesClientMock) CloseSend() error { + callInfo := struct { + }{} + mock.lockCloseSend.Lock() + mock.calls.CloseSend = append(mock.calls.CloseSend, callInfo) + mock.lockCloseSend.Unlock() + if mock.CloseSendFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CloseSendFunc() +} + +// CloseSendCalls gets all the calls that were made to CloseSend. +// Check the length with: +// len(mockedKV_StateChangesClient.CloseSendCalls()) +func (mock *KV_StateChangesClientMock) CloseSendCalls() []struct { +} { + var calls []struct { + } + mock.lockCloseSend.RLock() + calls = mock.calls.CloseSend + mock.lockCloseSend.RUnlock() + return calls +} + +// Context calls ContextFunc. +func (mock *KV_StateChangesClientMock) Context() context.Context { + callInfo := struct { + }{} + mock.lockContext.Lock() + mock.calls.Context = append(mock.calls.Context, callInfo) + mock.lockContext.Unlock() + if mock.ContextFunc == nil { + var ( + contextOut context.Context + ) + return contextOut + } + return mock.ContextFunc() +} + +// ContextCalls gets all the calls that were made to Context. +// Check the length with: +// len(mockedKV_StateChangesClient.ContextCalls()) +func (mock *KV_StateChangesClientMock) ContextCalls() []struct { +} { + var calls []struct { + } + mock.lockContext.RLock() + calls = mock.calls.Context + mock.lockContext.RUnlock() + return calls +} + +// Header calls HeaderFunc. +func (mock *KV_StateChangesClientMock) Header() (metadata.MD, error) { + callInfo := struct { + }{} + mock.lockHeader.Lock() + mock.calls.Header = append(mock.calls.Header, callInfo) + mock.lockHeader.Unlock() + if mock.HeaderFunc == nil { + var ( + mDOut metadata.MD + errOut error + ) + return mDOut, errOut + } + return mock.HeaderFunc() +} + +// HeaderCalls gets all the calls that were made to Header. +// Check the length with: +// len(mockedKV_StateChangesClient.HeaderCalls()) +func (mock *KV_StateChangesClientMock) HeaderCalls() []struct { +} { + var calls []struct { + } + mock.lockHeader.RLock() + calls = mock.calls.Header + mock.lockHeader.RUnlock() + return calls +} + +// Recv calls RecvFunc. +func (mock *KV_StateChangesClientMock) Recv() (*StateChange, error) { + callInfo := struct { + }{} + mock.lockRecv.Lock() + mock.calls.Recv = append(mock.calls.Recv, callInfo) + mock.lockRecv.Unlock() + if mock.RecvFunc == nil { + var ( + stateChangeOut *StateChange + errOut error + ) + return stateChangeOut, errOut + } + return mock.RecvFunc() +} + +// RecvCalls gets all the calls that were made to Recv. +// Check the length with: +// len(mockedKV_StateChangesClient.RecvCalls()) +func (mock *KV_StateChangesClientMock) RecvCalls() []struct { +} { + var calls []struct { + } + mock.lockRecv.RLock() + calls = mock.calls.Recv + mock.lockRecv.RUnlock() + return calls +} + +// RecvMsg calls RecvMsgFunc. +func (mock *KV_StateChangesClientMock) RecvMsg(m interface{}) error { + callInfo := struct { + M interface{} + }{ + M: m, + } + mock.lockRecvMsg.Lock() + mock.calls.RecvMsg = append(mock.calls.RecvMsg, callInfo) + mock.lockRecvMsg.Unlock() + if mock.RecvMsgFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RecvMsgFunc(m) +} + +// RecvMsgCalls gets all the calls that were made to RecvMsg. +// Check the length with: +// len(mockedKV_StateChangesClient.RecvMsgCalls()) +func (mock *KV_StateChangesClientMock) RecvMsgCalls() []struct { + M interface{} +} { + var calls []struct { + M interface{} + } + mock.lockRecvMsg.RLock() + calls = mock.calls.RecvMsg + mock.lockRecvMsg.RUnlock() + return calls +} + +// SendMsg calls SendMsgFunc. +func (mock *KV_StateChangesClientMock) SendMsg(m interface{}) error { + callInfo := struct { + M interface{} + }{ + M: m, + } + mock.lockSendMsg.Lock() + mock.calls.SendMsg = append(mock.calls.SendMsg, callInfo) + mock.lockSendMsg.Unlock() + if mock.SendMsgFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SendMsgFunc(m) +} + +// SendMsgCalls gets all the calls that were made to SendMsg. +// Check the length with: +// len(mockedKV_StateChangesClient.SendMsgCalls()) +func (mock *KV_StateChangesClientMock) SendMsgCalls() []struct { + M interface{} +} { + var calls []struct { + M interface{} + } + mock.lockSendMsg.RLock() + calls = mock.calls.SendMsg + mock.lockSendMsg.RUnlock() + return calls +} + +// Trailer calls TrailerFunc. +func (mock *KV_StateChangesClientMock) Trailer() metadata.MD { + callInfo := struct { + }{} + mock.lockTrailer.Lock() + mock.calls.Trailer = append(mock.calls.Trailer, callInfo) + mock.lockTrailer.Unlock() + if mock.TrailerFunc == nil { + var ( + mDOut metadata.MD + ) + return mDOut + } + return mock.TrailerFunc() +} + +// TrailerCalls gets all the calls that were made to Trailer. +// Check the length with: +// len(mockedKV_StateChangesClient.TrailerCalls()) +func (mock *KV_StateChangesClientMock) TrailerCalls() []struct { +} { + var calls []struct { + } + mock.lockTrailer.RLock() + calls = mock.calls.Trailer + mock.lockTrailer.RUnlock() + return calls +} diff --git a/gointerfaces/test_util.go b/gointerfaces/test_util.go index 44da34ad5..25e9b4575 100644 --- a/gointerfaces/test_util.go +++ b/gointerfaces/test_util.go @@ -1,4 +1,4 @@ package gointerfaces //go:generate moq -stub -out ./sentry/mocks.go ./sentry SentryServer SentryClient -//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient +//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient KV_StateChangesClient diff --git a/txpool/fetch.go b/txpool/fetch.go index 0cb4cfe66..1f70ba930 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -28,6 +28,7 @@ import ( "github.com/ledgerwatch/erigon-lib/gointerfaces" "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" + "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/log/v3" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -43,9 +44,9 @@ type Fetch struct { sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network statusData *sentry.StatusData // Status data used for "handshaking" with sentries pool Pool // Transaction pool implementation - wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) + coreDB kv.RoDB + wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) stateChangesClient remote.KVClient - logger log.Logger } type Timings struct { @@ -63,15 +64,7 @@ var DefaultTimings = Timings{ // NewFetch creates a new fetch object that will work with given sentry clients. Since the // SentryClient here is an interface, it is suitable for mocking in tests (mock will need // to implement all the functions of the SentryClient interface). -func NewFetch(ctx context.Context, - sentryClients []sentry.SentryClient, - genesisHash [32]byte, - networkId uint64, - forks []uint64, - pool Pool, - stateChangesClient remote.KVClient, - logger log.Logger, -) *Fetch { +func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient, db kv.RoDB) *Fetch { statusData := &sentry.StatusData{ NetworkId: networkId, TotalDifficulty: gointerfaces.ConvertUint256IntToH256(uint256.NewInt(0)), @@ -87,7 +80,7 @@ func NewFetch(ctx context.Context, sentryClients: sentryClients, statusData: statusData, pool: pool, - logger: logger, + coreDB: db, stateChangesClient: stateChangesClient, } } @@ -108,11 +101,19 @@ func (f *Fetch) ConnectSentries() { } } func (f *Fetch) ConnectCore() { - go func() { f.stateChangesLoop(f.ctx, f.stateChangesClient) }() + go func() { + for { + select { + case <-f.ctx.Done(): + return + default: + } + f.handleStateChanges(f.ctx, f.stateChangesClient) + } + }() } func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { - logger := f.logger for { select { case <-f.ctx.Done(): @@ -125,7 +126,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { return } // Report error and wait more - logger.Warn("sentry not ready yet", "err", err) + log.Warn("sentry not ready yet", "err", err) time.Sleep(time.Second) continue } @@ -154,7 +155,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { if errors.Is(err, io.EOF) { return } - logger.Warn("messages", "err", err) + log.Warn("messages", "err", err) return } @@ -172,14 +173,14 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { if errors.Is(err, io.EOF) { return } - logger.Warn("stream.Recv", "err", err) + log.Warn("stream.Recv", "err", err) return } if req == nil { return } - if err = f.handleInboundMessage(req, sentryClient); err != nil { - logger.Warn("Handling incoming message: %s", "err", err) + if err = f.handleInboundMessage(streamCtx, req, sentryClient); err != nil { + log.Warn("Handling incoming message: %s", "err", err) } if f.wg != nil { f.wg.Done() @@ -188,7 +189,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { } } -func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient sentry.SentryClient) error { +func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMessage, sentryClient sentry.SentryClient) error { switch req.Id { case sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_65: hashCount, pos, err := ParseHashesCount(req.Data, 0) @@ -281,7 +282,9 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se return err } } - if err := f.pool.Add(txs); err != nil { + if err := f.coreDB.View(ctx, func(tx kv.Tx) error { + return f.pool.Add(tx, txs) + }); err != nil { return err } } @@ -290,7 +293,6 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se } func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) { - logger := f.logger for { select { case <-f.ctx.Done(): @@ -303,7 +305,7 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) { return } // Report error and wait more - logger.Warn("sentry not ready yet", "err", err) + log.Warn("sentry not ready yet", "err", err) time.Sleep(time.Second) continue } @@ -323,7 +325,7 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) { if errors.Is(err, io.EOF) { return } - logger.Warn("peers", "err", err) + log.Warn("peers", "err", err) return } @@ -341,14 +343,14 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) { if errors.Is(err, io.EOF) { return } - logger.Warn("stream.Recv", "err", err) + log.Warn("stream.Recv", "err", err) return } if req == nil { return } if err = f.handleNewPeer(req); err != nil { - logger.Warn("Handling new peer", "err", err) + log.Warn("Handling new peer", "err", err) } if f.wg != nil { f.wg.Done() @@ -369,17 +371,7 @@ func (f *Fetch) handleNewPeer(req *sentry.PeersReply) error { return nil } -func (f *Fetch) stateChangesLoop(ctx context.Context, client remote.KVClient) { - for { - select { - case <-ctx.Done(): - return - default: - } - f.stateChangesStream(ctx, client) - } -} -func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) { +func (f *Fetch) handleStateChanges(ctx context.Context, client remote.KVClient) { streamCtx, cancel := context.WithCancel(ctx) defer cancel() stream, err := client.StateChanges(streamCtx, &remote.StateChangeRequest{WithStorage: false, WithTransactions: true}, grpc.WaitForReady(true)) @@ -418,12 +410,27 @@ func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) return } + parseCtx := NewTxParseContext() var unwindTxs, minedTxs TxSlots if req.Direction == remote.Direction_FORWARD { minedTxs.Growth(len(req.Txs)) + for i := range req.Txs { + minedTxs.txs[i] = &TxSlot{} + if _, err := parseCtx.ParseTransaction(req.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i)); err != nil { + log.Warn("stream.Recv", "err", err) + continue + } + } } if req.Direction == remote.Direction_UNWIND { unwindTxs.Growth(len(req.Txs)) + for i := range req.Txs { + unwindTxs.txs[i] = &TxSlot{} + if _, err := parseCtx.ParseTransaction(req.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i)); err != nil { + log.Warn("stream.Recv", "err", err) + continue + } + } } diff := map[string]senderInfo{} for _, change := range req.Changes { @@ -436,8 +443,13 @@ func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance} } - if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil { + if err := f.coreDB.View(ctx, func(tx kv.Tx) error { + return f.pool.OnNewBlock(tx, diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight) + }); err != nil { log.Warn("onNewBlock", "err", err) } + if f.wg != nil { + f.wg.Done() + } } } diff --git a/txpool/fetch_test.go b/txpool/fetch_test.go index 88ef1ae1e..ff6cdceda 100644 --- a/txpool/fetch_test.go +++ b/txpool/fetch_test.go @@ -19,6 +19,7 @@ package txpool import ( "context" "fmt" + "io" "sync" "testing" @@ -26,15 +27,16 @@ import ( "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" "github.com/ledgerwatch/erigon-lib/gointerfaces/types" + "github.com/ledgerwatch/erigon-lib/kv/memdb" "github.com/ledgerwatch/log/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" ) func TestFetch(t *testing.T) { - logger := log.New() - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() var genesisHash [32]byte var networkId uint64 = 1 @@ -44,7 +46,7 @@ func TestFetch(t *testing.T) { sentryClient := direct.NewSentryClientDirect(direct.ETH66, m) pool := &PoolMock{} - fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, logger) + fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, nil) var wg sync.WaitGroup fetch.SetWaitGroup(&wg) m.StreamWg.Add(2) @@ -133,3 +135,32 @@ func TestSendTxPropagate(t *testing.T) { } }) } + +func TestOnNewBlock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := memdb.NewTestDB(t) + var genesisHash [32]byte + var networkId uint64 = 1 + + i := 0 + stream := &remote.KV_StateChangesClientMock{ + RecvFunc: func() (*remote.StateChange, error) { + if i > 0 { + return nil, io.EOF + } + i++ + return &remote.StateChange{Txs: [][]byte{decodeHex(txParseTests[0].payloadStr), decodeHex(txParseTests[1].payloadStr), decodeHex(txParseTests[2].payloadStr)}}, nil + }, + } + stateChanges := &remote.KVClientMock{ + StateChangesFunc: func(ctx context.Context, in *remote.StateChangeRequest, opts ...grpc.CallOption) (remote.KV_StateChangesClient, error) { + return stream, nil + }, + } + pool := &PoolMock{} + fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges, db) + fetch.handleStateChanges(ctx, stateChanges) + assert.Equal(t, 1, len(pool.OnNewBlockCalls())) + assert.Equal(t, 3, len(pool.OnNewBlockCalls()[0].MinedTxs.txs)) +} diff --git a/txpool/mocks_test.go b/txpool/mocks_test.go index 1b4ae4343..afbfe311c 100644 --- a/txpool/mocks_test.go +++ b/txpool/mocks_test.go @@ -5,6 +5,8 @@ package txpool import ( "sync" + + "github.com/ledgerwatch/erigon-lib/kv" ) // Ensure, that PoolMock does implement Pool. @@ -17,7 +19,7 @@ var _ Pool = &PoolMock{} // // // make and configure a mocked Pool // mockedPool := &PoolMock{ -// AddFunc: func(newTxs TxSlots) error { +// AddFunc: func(db kv.Tx, newTxs TxSlots) error { // panic("mock out the Add method") // }, // AddNewGoodPeerFunc: func(peerID PeerID) { @@ -29,7 +31,7 @@ var _ Pool = &PoolMock{} // IdHashKnownFunc: func(hash []byte) bool { // panic("mock out the IdHashKnown method") // }, -// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { +// OnNewBlockFunc: func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { // panic("mock out the OnNewBlock method") // }, // } @@ -40,7 +42,7 @@ var _ Pool = &PoolMock{} // } type PoolMock struct { // AddFunc mocks the Add method. - AddFunc func(newTxs TxSlots) error + AddFunc func(db kv.Tx, newTxs TxSlots) error // AddNewGoodPeerFunc mocks the AddNewGoodPeer method. AddNewGoodPeerFunc func(peerID PeerID) @@ -52,12 +54,14 @@ type PoolMock struct { IdHashKnownFunc func(hash []byte) bool // OnNewBlockFunc mocks the OnNewBlock method. - OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error + OnNewBlockFunc func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error // calls tracks calls to the methods. calls struct { // Add holds details about calls to the Add method. Add []struct { + // Db is the db argument value. + Db kv.Tx // NewTxs is the newTxs argument value. NewTxs TxSlots } @@ -78,6 +82,8 @@ type PoolMock struct { } // OnNewBlock holds details about calls to the OnNewBlock method. OnNewBlock []struct { + // Db is the db argument value. + Db kv.Tx // StateChanges is the stateChanges argument value. StateChanges map[string]senderInfo // UnwindTxs is the unwindTxs argument value. @@ -100,10 +106,12 @@ type PoolMock struct { } // Add calls AddFunc. -func (mock *PoolMock) Add(newTxs TxSlots) error { +func (mock *PoolMock) Add(db kv.Tx, newTxs TxSlots) error { callInfo := struct { + Db kv.Tx NewTxs TxSlots }{ + Db: db, NewTxs: newTxs, } mock.lockAdd.Lock() @@ -115,16 +123,18 @@ func (mock *PoolMock) Add(newTxs TxSlots) error { ) return errOut } - return mock.AddFunc(newTxs) + return mock.AddFunc(db, newTxs) } // AddCalls gets all the calls that were made to Add. // Check the length with: // len(mockedPool.AddCalls()) func (mock *PoolMock) AddCalls() []struct { + Db kv.Tx NewTxs TxSlots } { var calls []struct { + Db kv.Tx NewTxs TxSlots } mock.lockAdd.RLock() @@ -233,8 +243,9 @@ func (mock *PoolMock) IdHashKnownCalls() []struct { } // OnNewBlock calls OnNewBlockFunc. -func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { +func (mock *PoolMock) OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { callInfo := struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots @@ -242,6 +253,7 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T BlockBaseFee uint64 BlockHeight uint64 }{ + Db: db, StateChanges: stateChanges, UnwindTxs: unwindTxs, MinedTxs: minedTxs, @@ -258,13 +270,14 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T ) return errOut } - return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight) + return mock.OnNewBlockFunc(db, stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight) } // OnNewBlockCalls gets all the calls that were made to OnNewBlock. // Check the length with: // len(mockedPool.OnNewBlockCalls()) func (mock *PoolMock) OnNewBlockCalls() []struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots @@ -273,6 +286,7 @@ func (mock *PoolMock) OnNewBlockCalls() []struct { BlockHeight uint64 } { var calls []struct { + Db kv.Tx StateChanges map[string]senderInfo UnwindTxs TxSlots MinedTxs TxSlots diff --git a/txpool/pool.go b/txpool/pool.go index 0c3fb8321..ef60ed532 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -27,6 +27,7 @@ import ( "github.com/google/btree" lru "github.com/hashicorp/golang-lru" "github.com/holiman/uint256" + "github.com/ledgerwatch/erigon-lib/kv" "go.uber.org/atomic" ) @@ -37,8 +38,8 @@ type Pool interface { // IdHashKnown check whether transaction with given Id hash is known to the pool IdHashKnown(hash []byte) bool GetRlp(hash []byte) []byte - Add(newTxs TxSlots) error - OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error + Add(db kv.Tx, newTxs TxSlots) error + OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error AddNewGoodPeer(peerID PeerID) } @@ -87,6 +88,8 @@ const PendingSubPoolLimit = 1024 const BaseFeeSubPoolLimit = 1024 const QueuedSubPoolLimit = 1024 +const MaxSendersInfoCache = 1024 + type nonce2Tx struct{ *btree.BTree } type senderInfo struct { @@ -115,6 +118,7 @@ type TxPool struct { protocolBaseFee atomic.Uint64 blockBaseFee atomic.Uint64 + senderID uint64 senderIDs map[string]uint64 senderInfo map[uint64]*senderInfo byHash map[string]*metaTx // tx_hash => tx @@ -202,7 +206,7 @@ func (p *TxPool) IdHashIsLocal(hash []byte) bool { } func (p *TxPool) OnNewPeer(peerID PeerID) { p.recentlyConnectedPeers.AddPeer(peerID) } -func (p *TxPool) Add(newTxs TxSlots) error { +func (p *TxPool) Add(coreDB kv.Tx, newTxs TxSlots) error { p.lock.Lock() defer p.lock.Unlock() if err := newTxs.Valid(); err != nil { @@ -214,7 +218,9 @@ func (p *TxPool) Add(newTxs TxSlots) error { return fmt.Errorf("non-zero base fee") } - setTxSenderID(p.senderIDs, p.senderInfo, newTxs) + if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, newTxs); err != nil { + return err + } if err := onNewTxs(p.senderInfo, newTxs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil { return err } @@ -286,7 +292,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee return nil } -func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error { +func (p *TxPool) OnNewBlock(coreDB kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error { p.lock.Lock() defer p.lock.Unlock() if err := unwindTxs.Valid(); err != nil { @@ -300,13 +306,18 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined p.protocolBaseFee.Store(protocolBaseFee) p.blockBaseFee.Store(blockBaseFee) - setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs) - setTxSenderID(p.senderIDs, p.senderInfo, minedTxs) + if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, unwindTxs); err != nil { + return err + } + if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, minedTxs); err != nil { + return err + } for addr, id := range p.senderIDs { // merge state changes if v, ok := stateChanges[addr]; ok { p.senderInfo[id] = &v } } + if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil { return err } @@ -326,24 +337,53 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined } } + /* + // evict sendersInfo without txs + if len(p.senderIDs) > MaxSendersInfoCache { + for i := range p.senderInfo { + if p.senderInfo[i].txNonce2Tx.Len() > 0 { + continue + } + for addr, id := range p.senderIDs { + if id == i { + delete(p.senderIDs, addr) + } + } + delete(p.senderInfo, i) + } + } + */ + return nil } -func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) { + +func setTxSenderID(coreDB kv.Tx, senderIDSequence *uint64, senderIDs map[string]uint64, sendersInfo map[uint64]*senderInfo, txs TxSlots) error { for i := range txs.txs { addr := string(txs.senders.At(i)) + // assign ID to each new sender id, ok := senderIDs[addr] if !ok { - for i := range senderInfo { //TODO: create field for it? - if id < i { - id = i - } - } - id++ - senderIDs[addr] = id + *senderIDSequence++ + senderIDs[addr] = *senderIDSequence } txs.txs[i].senderID = id + + // load data from db if need + _, ok = sendersInfo[txs.txs[i].senderID] + if !ok { + encoded, err := coreDB.GetOne(kv.PlainState, txs.senders.At(i)) + if err != nil { + return err + } + nonce, balance, err := DecodeSender(encoded) + if err != nil { + return err + } + sendersInfo[txs.txs[i].senderID] = &senderInfo{nonce: nonce, balance: balance} + } } + return nil } func onNewBlock(senderInfo map[uint64]*senderInfo, unwindTxs TxSlots, minedTxs []*TxSlot, protocolBaseFee, blockBaseFee uint64, pending, baseFee, queued *SubPool, byHash map[string]*metaTx, localsHistory *lru.Cache) error { diff --git a/txpool/pool_fuzz_test.go b/txpool/pool_fuzz_test.go index 257f57592..d905de78f 100644 --- a/txpool/pool_fuzz_test.go +++ b/txpool/pool_fuzz_test.go @@ -426,19 +426,19 @@ func FuzzOnNewBlocks11(f *testing.F) { // go to first fork //fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len()) unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs) - err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1) + err = pool.OnNewBlock(nil, map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1) assert.NoError(err) check(unwindTxs, minedTxs1, "fork1") checkNotify(unwindTxs, minedTxs1, "fork1") // unwind everything and switch to new fork (need unwind mined now) - err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2) + err = pool.OnNewBlock(nil, map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2) assert.NoError(err) check(minedTxs1, minedTxs2, "fork2") checkNotify(minedTxs1, minedTxs2, "fork2") // add some remote txs from p2p - err = pool.Add(p2pReceived) + err = pool.Add(nil, p2pReceived) assert.NoError(err) check(p2pReceived, TxSlots{}, "p2pmsg1") checkNotify(p2pReceived, TxSlots{}, "p2pmsg1")