diff --git a/gointerfaces/remote/mocks.go b/gointerfaces/remote/mocks.go new file mode 100644 index 000000000..17890ba3d --- /dev/null +++ b/gointerfaces/remote/mocks.go @@ -0,0 +1,205 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package remote + +import ( + context "context" + types "github.com/ledgerwatch/erigon-lib/gointerfaces/types" + grpc "google.golang.org/grpc" + emptypb "google.golang.org/protobuf/types/known/emptypb" + sync "sync" +) + +// Ensure, that KVClientMock does implement KVClient. +// If this is not the case, regenerate this file with moq. +var _ KVClient = &KVClientMock{} + +// KVClientMock is a mock implementation of KVClient. +// +// func TestSomethingThatUsesKVClient(t *testing.T) { +// +// // make and configure a mocked KVClient +// mockedKVClient := &KVClientMock{ +// StateChangesFunc: func(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error) { +// panic("mock out the StateChanges method") +// }, +// TxFunc: func(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error) { +// panic("mock out the Tx method") +// }, +// VersionFunc: func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error) { +// panic("mock out the Version method") +// }, +// } +// +// // use mockedKVClient in code that requires KVClient +// // and then make assertions. +// +// } +type KVClientMock struct { + // StateChangesFunc mocks the StateChanges method. + StateChangesFunc func(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error) + + // TxFunc mocks the Tx method. + TxFunc func(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error) + + // VersionFunc mocks the Version method. + VersionFunc func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error) + + // calls tracks calls to the methods. + calls struct { + // StateChanges holds details about calls to the StateChanges method. + StateChanges []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // In is the in argument value. + In *StateChangeRequest + // Opts is the opts argument value. + Opts []grpc.CallOption + } + // Tx holds details about calls to the Tx method. + Tx []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Opts is the opts argument value. + Opts []grpc.CallOption + } + // Version holds details about calls to the Version method. + Version []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // In is the in argument value. + In *emptypb.Empty + // Opts is the opts argument value. + Opts []grpc.CallOption + } + } + lockStateChanges sync.RWMutex + lockTx sync.RWMutex + lockVersion sync.RWMutex +} + +// StateChanges calls StateChangesFunc. +func (mock *KVClientMock) StateChanges(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error) { + callInfo := struct { + Ctx context.Context + In *StateChangeRequest + Opts []grpc.CallOption + }{ + Ctx: ctx, + In: in, + Opts: opts, + } + mock.lockStateChanges.Lock() + mock.calls.StateChanges = append(mock.calls.StateChanges, callInfo) + mock.lockStateChanges.Unlock() + if mock.StateChangesFunc == nil { + var ( + kV_StateChangesClientOut KV_StateChangesClient + errOut error + ) + return kV_StateChangesClientOut, errOut + } + return mock.StateChangesFunc(ctx, in, opts...) +} + +// StateChangesCalls gets all the calls that were made to StateChanges. +// Check the length with: +// len(mockedKVClient.StateChangesCalls()) +func (mock *KVClientMock) StateChangesCalls() []struct { + Ctx context.Context + In *StateChangeRequest + Opts []grpc.CallOption +} { + var calls []struct { + Ctx context.Context + In *StateChangeRequest + Opts []grpc.CallOption + } + mock.lockStateChanges.RLock() + calls = mock.calls.StateChanges + mock.lockStateChanges.RUnlock() + return calls +} + +// Tx calls TxFunc. +func (mock *KVClientMock) Tx(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error) { + callInfo := struct { + Ctx context.Context + Opts []grpc.CallOption + }{ + Ctx: ctx, + Opts: opts, + } + mock.lockTx.Lock() + mock.calls.Tx = append(mock.calls.Tx, callInfo) + mock.lockTx.Unlock() + if mock.TxFunc == nil { + var ( + kV_TxClientOut KV_TxClient + errOut error + ) + return kV_TxClientOut, errOut + } + return mock.TxFunc(ctx, opts...) +} + +// TxCalls gets all the calls that were made to Tx. +// Check the length with: +// len(mockedKVClient.TxCalls()) +func (mock *KVClientMock) TxCalls() []struct { + Ctx context.Context + Opts []grpc.CallOption +} { + var calls []struct { + Ctx context.Context + Opts []grpc.CallOption + } + mock.lockTx.RLock() + calls = mock.calls.Tx + mock.lockTx.RUnlock() + return calls +} + +// Version calls VersionFunc. +func (mock *KVClientMock) Version(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error) { + callInfo := struct { + Ctx context.Context + In *emptypb.Empty + Opts []grpc.CallOption + }{ + Ctx: ctx, + In: in, + Opts: opts, + } + mock.lockVersion.Lock() + mock.calls.Version = append(mock.calls.Version, callInfo) + mock.lockVersion.Unlock() + if mock.VersionFunc == nil { + var ( + versionReplyOut *types.VersionReply + errOut error + ) + return versionReplyOut, errOut + } + return mock.VersionFunc(ctx, in, opts...) +} + +// VersionCalls gets all the calls that were made to Version. +// Check the length with: +// len(mockedKVClient.VersionCalls()) +func (mock *KVClientMock) VersionCalls() []struct { + Ctx context.Context + In *emptypb.Empty + Opts []grpc.CallOption +} { + var calls []struct { + Ctx context.Context + In *emptypb.Empty + Opts []grpc.CallOption + } + mock.lockVersion.RLock() + calls = mock.calls.Version + mock.lockVersion.RUnlock() + return calls +} diff --git a/gointerfaces/test_util.go b/gointerfaces/test_util.go index 500328691..44da34ad5 100644 --- a/gointerfaces/test_util.go +++ b/gointerfaces/test_util.go @@ -1,3 +1,4 @@ package gointerfaces //go:generate moq -stub -out ./sentry/mocks.go ./sentry SentryServer SentryClient +//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient diff --git a/txpool/fetch.go b/txpool/fetch.go index d8edaeede..0cb4cfe66 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -26,6 +26,7 @@ import ( "github.com/holiman/uint256" "github.com/ledgerwatch/erigon-lib/gointerfaces" + "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" "github.com/ledgerwatch/log/v3" "google.golang.org/grpc" @@ -38,12 +39,13 @@ import ( // genesis hash and list of forks, but with zero max block and total difficulty // Sentry should have a logic not to overwrite statusData with messages from tx pool type Fetch struct { - ctx context.Context // Context used for cancellation and closing of the fetcher - sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network - statusData *sentry.StatusData // Status data used for "handshaking" with sentries - pool Pool // Transaction pool implementation - wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) - logger log.Logger + ctx context.Context // Context used for cancellation and closing of the fetcher + sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network + statusData *sentry.StatusData // Status data used for "handshaking" with sentries + pool Pool // Transaction pool implementation + wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests) + stateChangesClient remote.KVClient + logger log.Logger } type Timings struct { @@ -67,6 +69,7 @@ func NewFetch(ctx context.Context, networkId uint64, forks []uint64, pool Pool, + stateChangesClient remote.KVClient, logger log.Logger, ) *Fetch { statusData := &sentry.StatusData{ @@ -80,11 +83,12 @@ func NewFetch(ctx context.Context, }, } return &Fetch{ - ctx: ctx, - sentryClients: sentryClients, - statusData: statusData, - pool: pool, - logger: logger, + ctx: ctx, + sentryClients: sentryClients, + statusData: statusData, + pool: pool, + logger: logger, + stateChangesClient: stateChangesClient, } } @@ -92,8 +96,8 @@ func (f *Fetch) SetWaitGroup(wg *sync.WaitGroup) { f.wg = wg } -// Start initialises connection to the sentry -func (f *Fetch) Start() { +// ConnectSentries initialises connection to the sentry +func (f *Fetch) ConnectSentries() { for i := range f.sentryClients { go func(i int) { f.receiveMessageLoop(f.sentryClients[i]) @@ -103,6 +107,9 @@ func (f *Fetch) Start() { }(i) } } +func (f *Fetch) ConnectCore() { + go func() { f.stateChangesLoop(f.ctx, f.stateChangesClient) }() +} func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) { logger := f.logger @@ -361,3 +368,76 @@ func (f *Fetch) handleNewPeer(req *sentry.PeersReply) error { return nil } + +func (f *Fetch) stateChangesLoop(ctx context.Context, client remote.KVClient) { + for { + select { + case <-ctx.Done(): + return + default: + } + f.stateChangesStream(ctx, client) + } +} +func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) { + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + stream, err := client.StateChanges(streamCtx, &remote.StateChangeRequest{WithStorage: false, WithTransactions: true}, grpc.WaitForReady(true)) + if err != nil { + select { + case <-ctx.Done(): + return + default: + } + if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + return + } + if errors.Is(err, io.EOF) { + return + } + time.Sleep(time.Second) + log.Warn("state changes", "err", err) + } + for req, err := stream.Recv(); ; req, err = stream.Recv() { + if err != nil { + select { + case <-ctx.Done(): + return + default: + } + if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + return + } + if errors.Is(err, io.EOF) { + return + } + log.Warn("stream.Recv", "err", err) + return + } + if req == nil { + return + } + + var unwindTxs, minedTxs TxSlots + if req.Direction == remote.Direction_FORWARD { + minedTxs.Growth(len(req.Txs)) + } + if req.Direction == remote.Direction_UNWIND { + unwindTxs.Growth(len(req.Txs)) + } + diff := map[string]senderInfo{} + for _, change := range req.Changes { + nonce, balance, err := DecodeSender(change.Data) + if err != nil { + log.Warn("stateChanges.decodeSender", "err", err) + continue + } + addr := gointerfaces.ConvertH160toAddress(change.Address) + diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance} + } + + if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil { + log.Warn("onNewBlock", "err", err) + } + } +} diff --git a/txpool/fetch_test.go b/txpool/fetch_test.go index 502cb28c9..88ef1ae1e 100644 --- a/txpool/fetch_test.go +++ b/txpool/fetch_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/ledgerwatch/erigon-lib/direct" + "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" "github.com/ledgerwatch/erigon-lib/gointerfaces/types" "github.com/ledgerwatch/log/v3" @@ -43,11 +44,11 @@ func TestFetch(t *testing.T) { sentryClient := direct.NewSentryClientDirect(direct.ETH66, m) pool := &PoolMock{} - fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, logger) + fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, logger) var wg sync.WaitGroup fetch.SetWaitGroup(&wg) m.StreamWg.Add(2) - fetch.Start() + fetch.ConnectSentries() m.StreamWg.Wait() // Send one transaction id wg.Add(1) diff --git a/txpool/mocks_test.go b/txpool/mocks_test.go index d3178604d..1b4ae4343 100644 --- a/txpool/mocks_test.go +++ b/txpool/mocks_test.go @@ -29,6 +29,9 @@ var _ Pool = &PoolMock{} // IdHashKnownFunc: func(hash []byte) bool { // panic("mock out the IdHashKnown method") // }, +// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { +// panic("mock out the OnNewBlock method") +// }, // } // // // use mockedPool in code that requires Pool @@ -48,6 +51,9 @@ type PoolMock struct { // IdHashKnownFunc mocks the IdHashKnown method. IdHashKnownFunc func(hash []byte) bool + // OnNewBlockFunc mocks the OnNewBlock method. + OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error + // calls tracks calls to the methods. calls struct { // Add holds details about calls to the Add method. @@ -70,11 +76,27 @@ type PoolMock struct { // Hash is the hash argument value. Hash []byte } + // OnNewBlock holds details about calls to the OnNewBlock method. + OnNewBlock []struct { + // StateChanges is the stateChanges argument value. + StateChanges map[string]senderInfo + // UnwindTxs is the unwindTxs argument value. + UnwindTxs TxSlots + // MinedTxs is the minedTxs argument value. + MinedTxs TxSlots + // ProtocolBaseFee is the protocolBaseFee argument value. + ProtocolBaseFee uint64 + // BlockBaseFee is the blockBaseFee argument value. + BlockBaseFee uint64 + // BlockHeight is the blockHeight argument value. + BlockHeight uint64 + } } lockAdd sync.RWMutex lockAddNewGoodPeer sync.RWMutex lockGetRlp sync.RWMutex lockIdHashKnown sync.RWMutex + lockOnNewBlock sync.RWMutex } // Add calls AddFunc. @@ -209,3 +231,57 @@ func (mock *PoolMock) IdHashKnownCalls() []struct { mock.lockIdHashKnown.RUnlock() return calls } + +// OnNewBlock calls OnNewBlockFunc. +func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error { + callInfo := struct { + StateChanges map[string]senderInfo + UnwindTxs TxSlots + MinedTxs TxSlots + ProtocolBaseFee uint64 + BlockBaseFee uint64 + BlockHeight uint64 + }{ + StateChanges: stateChanges, + UnwindTxs: unwindTxs, + MinedTxs: minedTxs, + ProtocolBaseFee: protocolBaseFee, + BlockBaseFee: blockBaseFee, + BlockHeight: blockHeight, + } + mock.lockOnNewBlock.Lock() + mock.calls.OnNewBlock = append(mock.calls.OnNewBlock, callInfo) + mock.lockOnNewBlock.Unlock() + if mock.OnNewBlockFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight) +} + +// OnNewBlockCalls gets all the calls that were made to OnNewBlock. +// Check the length with: +// len(mockedPool.OnNewBlockCalls()) +func (mock *PoolMock) OnNewBlockCalls() []struct { + StateChanges map[string]senderInfo + UnwindTxs TxSlots + MinedTxs TxSlots + ProtocolBaseFee uint64 + BlockBaseFee uint64 + BlockHeight uint64 +} { + var calls []struct { + StateChanges map[string]senderInfo + UnwindTxs TxSlots + MinedTxs TxSlots + ProtocolBaseFee uint64 + BlockBaseFee uint64 + BlockHeight uint64 + } + mock.lockOnNewBlock.RLock() + calls = mock.calls.OnNewBlock + mock.lockOnNewBlock.RUnlock() + return calls +} diff --git a/txpool/pool.go b/txpool/pool.go index e1b4adb6c..0c3fb8321 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -38,6 +38,7 @@ type Pool interface { IdHashKnown(hash []byte) bool GetRlp(hash []byte) []byte Add(newTxs TxSlots) error + OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error AddNewGoodPeer(peerID PeerID) } @@ -110,6 +111,7 @@ func (i *nonce2TxItem) Less(than btree.Item) bool { type TxPool struct { lock *sync.RWMutex + blockHeight atomic.Uint64 protocolBaseFee atomic.Uint64 blockBaseFee atomic.Uint64 @@ -284,7 +286,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee return nil } -func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee uint64) error { +func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error { p.lock.Lock() defer p.lock.Unlock() if err := unwindTxs.Valid(); err != nil { @@ -294,11 +296,17 @@ func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockB return err } + p.blockHeight.Store(blockHeight) p.protocolBaseFee.Store(protocolBaseFee) p.blockBaseFee.Store(blockBaseFee) setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs) setTxSenderID(p.senderIDs, p.senderInfo, minedTxs) + for addr, id := range p.senderIDs { // merge state changes + if v, ok := stateChanges[addr]; ok { + p.senderInfo[id] = &v + } + } if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil { return err } @@ -322,7 +330,9 @@ func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockB } func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) { for i := range txs.txs { - id, ok := senderIDs[string(txs.senders.At(i))] + addr := string(txs.senders.At(i)) + + id, ok := senderIDs[addr] if !ok { for i := range senderInfo { //TODO: create field for it? if id < i { @@ -330,7 +340,7 @@ func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInf } } id++ - senderIDs[string(txs.senders.At(i))] = id + senderIDs[addr] = id } txs.txs[i].senderID = id } diff --git a/txpool/pool_fuzz_test.go b/txpool/pool_fuzz_test.go index f0f3bfd53..257f57592 100644 --- a/txpool/pool_fuzz_test.go +++ b/txpool/pool_fuzz_test.go @@ -426,13 +426,13 @@ func FuzzOnNewBlocks11(f *testing.F) { // go to first fork //fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len()) unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs) - err = pool.OnNewBlock(unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee) + err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1) assert.NoError(err) check(unwindTxs, minedTxs1, "fork1") checkNotify(unwindTxs, minedTxs1, "fork1") // unwind everything and switch to new fork (need unwind mined now) - err = pool.OnNewBlock(minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee) + err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2) assert.NoError(err) check(minedTxs1, minedTxs2, "fork2") checkNotify(minedTxs1, minedTxs2, "fork2") diff --git a/txpool/types.go b/txpool/types.go index aabceb080..5fce74dc1 100644 --- a/txpool/types.go +++ b/txpool/types.go @@ -426,3 +426,48 @@ func (s *TxSlots) Growth(targetSize int) { } var addressesGrowth = make([]byte, 20) + +func DecodeSender(enc []byte) (nonce uint64, balance uint256.Int, err error) { + if len(enc) == 0 { + return + } + + var fieldSet = enc[0] + var pos = 1 + + if fieldSet&1 > 0 { + decodeLength := int(enc[pos]) + + if len(enc) < pos+decodeLength+1 { + return nonce, balance, fmt.Errorf( + "malformed CBOR for Account.Nonce: %s, Length %d", + enc[pos+1:], decodeLength) + } + + nonce = bytesToUint64(enc[pos+1 : pos+decodeLength+1]) + pos += decodeLength + 1 + } + + if fieldSet&2 > 0 { + decodeLength := int(enc[pos]) + + if len(enc) < pos+decodeLength+1 { + return nonce, balance, fmt.Errorf( + "malformed CBOR for Account.Nonce: %s, Length %d", + enc[pos+1:], decodeLength) + } + + (&balance).SetBytes(enc[pos+1 : pos+decodeLength+1]) + } + return +} + +func bytesToUint64(buf []byte) (x uint64) { + for i, b := range buf { + x = x<<8 + uint64(b) + if i == 7 { + return + } + } + return +}