Merge pull request #27 from ledgerwatch/pool11

Pool: get senders from DB if needed, add senderIDSequence
Alex Sharov 2021-08-11 11:34:21 +07:00 committed by GitHub
commit 67d92025d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 478 additions and 69 deletions
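
In short: Pool.Add and Pool.OnNewBlock now take a kv.Tx, sender IDs are issued from a new monotonic senderIDSequence counter instead of scanning for the current maximum, and nonce/balance for senders the pool has not cached yet are read from the core DB (kv.PlainState). Below is a minimal, self-contained sketch of that pattern; the types and the assignSenderIDs helper are simplified stand-ins for illustration, not the real erigon-lib API.

// Sketch only: simplified stand-ins for the pool's sender bookkeeping.
package main

import "fmt"

// senderInfo is a simplified version of the pool's per-sender state.
type senderInfo struct {
	nonce   uint64
	balance uint64
}

// stateReader stands in for the kv.Tx lookup into PlainState.
type stateReader interface {
	ReadAccount(addr string) (nonce, balance uint64, ok bool)
}

type pool struct {
	senderIDSequence uint64                 // monotonic counter; replaces scanning senderInfo for the max ID
	senderIDs        map[string]uint64      // address -> sender ID
	sendersInfo      map[uint64]*senderInfo // sender ID -> cached nonce/balance
}

// assignSenderIDs mirrors the shape of setTxSenderID in this change:
// unknown senders get the next sequence value, and senders without cached
// state are loaded from the core DB.
func (p *pool) assignSenderIDs(db stateReader, senders []string) error {
	for _, addr := range senders {
		id, ok := p.senderIDs[addr]
		if !ok {
			p.senderIDSequence++
			id = p.senderIDSequence
			p.senderIDs[addr] = id
		}
		if _, cached := p.sendersInfo[id]; !cached {
			nonce, balance, found := db.ReadAccount(addr)
			if !found {
				return fmt.Errorf("sender %q not found in state", addr)
			}
			p.sendersInfo[id] = &senderInfo{nonce: nonce, balance: balance}
		}
	}
	return nil
}

// memState is a toy in-memory stateReader used only for this example.
type memState map[string][2]uint64

func (m memState) ReadAccount(addr string) (uint64, uint64, bool) {
	v, ok := m[addr]
	return v[0], v[1], ok
}

func main() {
	p := &pool{senderIDs: map[string]uint64{}, sendersInfo: map[uint64]*senderInfo{}}
	db := memState{"0xaa": {3, 1000}, "0xbb": {0, 42}}
	if err := p.assignSenderIDs(db, []string{"0xaa", "0xbb", "0xaa"}); err != nil {
		panic(err)
	}
	// Prints: 1 2 3 (IDs are stable per address; nonce comes from the DB).
	fmt.Println(p.senderIDs["0xaa"], p.senderIDs["0xbb"], p.sendersInfo[1].nonce)
}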

View File

@@ -7,6 +7,7 @@ import (
context "context"
types "github.com/ledgerwatch/erigon-lib/gointerfaces/types"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
emptypb "google.golang.org/protobuf/types/known/emptypb"
sync "sync"
)
@@ -203,3 +204,314 @@ func (mock *KVClientMock) VersionCalls() []struct {
mock.lockVersion.RUnlock()
return calls
}
// Ensure, that KV_StateChangesClientMock does implement KV_StateChangesClient.
// If this is not the case, regenerate this file with moq.
var _ KV_StateChangesClient = &KV_StateChangesClientMock{}
// KV_StateChangesClientMock is a mock implementation of KV_StateChangesClient.
//
// func TestSomethingThatUsesKV_StateChangesClient(t *testing.T) {
//
// // make and configure a mocked KV_StateChangesClient
// mockedKV_StateChangesClient := &KV_StateChangesClientMock{
// CloseSendFunc: func() error {
// panic("mock out the CloseSend method")
// },
// ContextFunc: func() context.Context {
// panic("mock out the Context method")
// },
// HeaderFunc: func() (metadata.MD, error) {
// panic("mock out the Header method")
// },
// RecvFunc: func() (*StateChange, error) {
// panic("mock out the Recv method")
// },
// RecvMsgFunc: func(m interface{}) error {
// panic("mock out the RecvMsg method")
// },
// SendMsgFunc: func(m interface{}) error {
// panic("mock out the SendMsg method")
// },
// TrailerFunc: func() metadata.MD {
// panic("mock out the Trailer method")
// },
// }
//
// // use mockedKV_StateChangesClient in code that requires KV_StateChangesClient
// // and then make assertions.
//
// }
type KV_StateChangesClientMock struct {
// CloseSendFunc mocks the CloseSend method.
CloseSendFunc func() error
// ContextFunc mocks the Context method.
ContextFunc func() context.Context
// HeaderFunc mocks the Header method.
HeaderFunc func() (metadata.MD, error)
// RecvFunc mocks the Recv method.
RecvFunc func() (*StateChange, error)
// RecvMsgFunc mocks the RecvMsg method.
RecvMsgFunc func(m interface{}) error
// SendMsgFunc mocks the SendMsg method.
SendMsgFunc func(m interface{}) error
// TrailerFunc mocks the Trailer method.
TrailerFunc func() metadata.MD
// calls tracks calls to the methods.
calls struct {
// CloseSend holds details about calls to the CloseSend method.
CloseSend []struct {
}
// Context holds details about calls to the Context method.
Context []struct {
}
// Header holds details about calls to the Header method.
Header []struct {
}
// Recv holds details about calls to the Recv method.
Recv []struct {
}
// RecvMsg holds details about calls to the RecvMsg method.
RecvMsg []struct {
// M is the m argument value.
M interface{}
}
// SendMsg holds details about calls to the SendMsg method.
SendMsg []struct {
// M is the m argument value.
M interface{}
}
// Trailer holds details about calls to the Trailer method.
Trailer []struct {
}
}
lockCloseSend sync.RWMutex
lockContext sync.RWMutex
lockHeader sync.RWMutex
lockRecv sync.RWMutex
lockRecvMsg sync.RWMutex
lockSendMsg sync.RWMutex
lockTrailer sync.RWMutex
}
// CloseSend calls CloseSendFunc.
func (mock *KV_StateChangesClientMock) CloseSend() error {
callInfo := struct {
}{}
mock.lockCloseSend.Lock()
mock.calls.CloseSend = append(mock.calls.CloseSend, callInfo)
mock.lockCloseSend.Unlock()
if mock.CloseSendFunc == nil {
var (
errOut error
)
return errOut
}
return mock.CloseSendFunc()
}
// CloseSendCalls gets all the calls that were made to CloseSend.
// Check the length with:
// len(mockedKV_StateChangesClient.CloseSendCalls())
func (mock *KV_StateChangesClientMock) CloseSendCalls() []struct {
} {
var calls []struct {
}
mock.lockCloseSend.RLock()
calls = mock.calls.CloseSend
mock.lockCloseSend.RUnlock()
return calls
}
// Context calls ContextFunc.
func (mock *KV_StateChangesClientMock) Context() context.Context {
callInfo := struct {
}{}
mock.lockContext.Lock()
mock.calls.Context = append(mock.calls.Context, callInfo)
mock.lockContext.Unlock()
if mock.ContextFunc == nil {
var (
contextOut context.Context
)
return contextOut
}
return mock.ContextFunc()
}
// ContextCalls gets all the calls that were made to Context.
// Check the length with:
// len(mockedKV_StateChangesClient.ContextCalls())
func (mock *KV_StateChangesClientMock) ContextCalls() []struct {
} {
var calls []struct {
}
mock.lockContext.RLock()
calls = mock.calls.Context
mock.lockContext.RUnlock()
return calls
}
// Header calls HeaderFunc.
func (mock *KV_StateChangesClientMock) Header() (metadata.MD, error) {
callInfo := struct {
}{}
mock.lockHeader.Lock()
mock.calls.Header = append(mock.calls.Header, callInfo)
mock.lockHeader.Unlock()
if mock.HeaderFunc == nil {
var (
mDOut metadata.MD
errOut error
)
return mDOut, errOut
}
return mock.HeaderFunc()
}
// HeaderCalls gets all the calls that were made to Header.
// Check the length with:
// len(mockedKV_StateChangesClient.HeaderCalls())
func (mock *KV_StateChangesClientMock) HeaderCalls() []struct {
} {
var calls []struct {
}
mock.lockHeader.RLock()
calls = mock.calls.Header
mock.lockHeader.RUnlock()
return calls
}
// Recv calls RecvFunc.
func (mock *KV_StateChangesClientMock) Recv() (*StateChange, error) {
callInfo := struct {
}{}
mock.lockRecv.Lock()
mock.calls.Recv = append(mock.calls.Recv, callInfo)
mock.lockRecv.Unlock()
if mock.RecvFunc == nil {
var (
stateChangeOut *StateChange
errOut error
)
return stateChangeOut, errOut
}
return mock.RecvFunc()
}
// RecvCalls gets all the calls that were made to Recv.
// Check the length with:
// len(mockedKV_StateChangesClient.RecvCalls())
func (mock *KV_StateChangesClientMock) RecvCalls() []struct {
} {
var calls []struct {
}
mock.lockRecv.RLock()
calls = mock.calls.Recv
mock.lockRecv.RUnlock()
return calls
}
// RecvMsg calls RecvMsgFunc.
func (mock *KV_StateChangesClientMock) RecvMsg(m interface{}) error {
callInfo := struct {
M interface{}
}{
M: m,
}
mock.lockRecvMsg.Lock()
mock.calls.RecvMsg = append(mock.calls.RecvMsg, callInfo)
mock.lockRecvMsg.Unlock()
if mock.RecvMsgFunc == nil {
var (
errOut error
)
return errOut
}
return mock.RecvMsgFunc(m)
}
// RecvMsgCalls gets all the calls that were made to RecvMsg.
// Check the length with:
// len(mockedKV_StateChangesClient.RecvMsgCalls())
func (mock *KV_StateChangesClientMock) RecvMsgCalls() []struct {
M interface{}
} {
var calls []struct {
M interface{}
}
mock.lockRecvMsg.RLock()
calls = mock.calls.RecvMsg
mock.lockRecvMsg.RUnlock()
return calls
}
// SendMsg calls SendMsgFunc.
func (mock *KV_StateChangesClientMock) SendMsg(m interface{}) error {
callInfo := struct {
M interface{}
}{
M: m,
}
mock.lockSendMsg.Lock()
mock.calls.SendMsg = append(mock.calls.SendMsg, callInfo)
mock.lockSendMsg.Unlock()
if mock.SendMsgFunc == nil {
var (
errOut error
)
return errOut
}
return mock.SendMsgFunc(m)
}
// SendMsgCalls gets all the calls that were made to SendMsg.
// Check the length with:
// len(mockedKV_StateChangesClient.SendMsgCalls())
func (mock *KV_StateChangesClientMock) SendMsgCalls() []struct {
M interface{}
} {
var calls []struct {
M interface{}
}
mock.lockSendMsg.RLock()
calls = mock.calls.SendMsg
mock.lockSendMsg.RUnlock()
return calls
}
// Trailer calls TrailerFunc.
func (mock *KV_StateChangesClientMock) Trailer() metadata.MD {
callInfo := struct {
}{}
mock.lockTrailer.Lock()
mock.calls.Trailer = append(mock.calls.Trailer, callInfo)
mock.lockTrailer.Unlock()
if mock.TrailerFunc == nil {
var (
mDOut metadata.MD
)
return mDOut
}
return mock.TrailerFunc()
}
// TrailerCalls gets all the calls that were made to Trailer.
// Check the length with:
// len(mockedKV_StateChangesClient.TrailerCalls())
func (mock *KV_StateChangesClientMock) TrailerCalls() []struct {
} {
var calls []struct {
}
mock.lockTrailer.RLock()
calls = mock.calls.Trailer
mock.lockTrailer.RUnlock()
return calls
}
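
The generated stream mock above is driven by overriding its ...Func fields; a short hedged sketch (the newSingleShotStream helper and the example package are illustrative only, not part of the library) of the one-StateChange-then-io.EOF pattern that the new TestOnNewBlock below relies on:

package example

import (
	"io"

	"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
)

// newSingleShotStream returns a KV_StateChangesClientMock whose Recv yields
// the given StateChange once and then reports io.EOF to end the stream.
func newSingleShotStream(change *remote.StateChange) *remote.KV_StateChangesClientMock {
	sent := false
	return &remote.KV_StateChangesClientMock{
		RecvFunc: func() (*remote.StateChange, error) {
			if sent {
				return nil, io.EOF
			}
			sent = true
			return change, nil
		},
	}
}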

View File

@@ -1,4 +1,4 @@
package gointerfaces
//go:generate moq -stub -out ./sentry/mocks.go ./sentry SentryServer SentryClient
//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient
//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient KV_StateChangesClient

View File

@@ -28,6 +28,7 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/log/v3"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -43,9 +44,9 @@ type Fetch struct {
sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network
statusData *sentry.StatusData // Status data used for "handshaking" with sentries
pool Pool // Transaction pool implementation
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
coreDB kv.RoDB
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
stateChangesClient remote.KVClient
logger log.Logger
}
type Timings struct {
@@ -63,15 +64,7 @@ var DefaultTimings = Timings{
// NewFetch creates a new fetch object that will work with given sentry clients. Since the
// SentryClient here is an interface, it is suitable for mocking in tests (mock will need
// to implement all the functions of the SentryClient interface).
func NewFetch(ctx context.Context,
sentryClients []sentry.SentryClient,
genesisHash [32]byte,
networkId uint64,
forks []uint64,
pool Pool,
stateChangesClient remote.KVClient,
logger log.Logger,
) *Fetch {
func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient, db kv.RoDB) *Fetch {
statusData := &sentry.StatusData{
NetworkId: networkId,
TotalDifficulty: gointerfaces.ConvertUint256IntToH256(uint256.NewInt(0)),
@@ -87,7 +80,7 @@ func NewFetch(ctx context.Context,
sentryClients: sentryClients,
statusData: statusData,
pool: pool,
logger: logger,
coreDB: db,
stateChangesClient: stateChangesClient,
}
}
@@ -108,11 +101,19 @@ func (f *Fetch) ConnectSentries() {
}
}
func (f *Fetch) ConnectCore() {
go func() { f.stateChangesLoop(f.ctx, f.stateChangesClient) }()
go func() {
for {
select {
case <-f.ctx.Done():
return
default:
}
f.handleStateChanges(f.ctx, f.stateChangesClient)
}
}()
}
func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
logger := f.logger
for {
select {
case <-f.ctx.Done():
@@ -125,7 +126,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
return
}
// Report error and wait more
logger.Warn("sentry not ready yet", "err", err)
log.Warn("sentry not ready yet", "err", err)
time.Sleep(time.Second)
continue
}
@@ -154,7 +155,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
if errors.Is(err, io.EOF) {
return
}
logger.Warn("messages", "err", err)
log.Warn("messages", "err", err)
return
}
@@ -172,14 +173,14 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
if errors.Is(err, io.EOF) {
return
}
logger.Warn("stream.Recv", "err", err)
log.Warn("stream.Recv", "err", err)
return
}
if req == nil {
return
}
if err = f.handleInboundMessage(req, sentryClient); err != nil {
logger.Warn("Handling incoming message: %s", "err", err)
if err = f.handleInboundMessage(streamCtx, req, sentryClient); err != nil {
log.Warn("Handling incoming message: %s", "err", err)
}
if f.wg != nil {
f.wg.Done()
@@ -188,7 +189,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
}
}
func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient sentry.SentryClient) error {
func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMessage, sentryClient sentry.SentryClient) error {
switch req.Id {
case sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_65:
hashCount, pos, err := ParseHashesCount(req.Data, 0)
@@ -281,7 +282,9 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se
return err
}
}
if err := f.pool.Add(txs); err != nil {
if err := f.coreDB.View(ctx, func(tx kv.Tx) error {
return f.pool.Add(tx, txs)
}); err != nil {
return err
}
}
@@ -290,7 +293,6 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se
}
func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) {
logger := f.logger
for {
select {
case <-f.ctx.Done():
@@ -303,7 +305,7 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) {
return
}
// Report error and wait more
logger.Warn("sentry not ready yet", "err", err)
log.Warn("sentry not ready yet", "err", err)
time.Sleep(time.Second)
continue
}
@@ -323,7 +325,7 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) {
if errors.Is(err, io.EOF) {
return
}
logger.Warn("peers", "err", err)
log.Warn("peers", "err", err)
return
}
@@ -341,14 +343,14 @@ func (f *Fetch) receivePeerLoop(sentryClient sentry.SentryClient) {
if errors.Is(err, io.EOF) {
return
}
logger.Warn("stream.Recv", "err", err)
log.Warn("stream.Recv", "err", err)
return
}
if req == nil {
return
}
if err = f.handleNewPeer(req); err != nil {
logger.Warn("Handling new peer", "err", err)
log.Warn("Handling new peer", "err", err)
}
if f.wg != nil {
f.wg.Done()
@@ -369,17 +371,7 @@ func (f *Fetch) handleNewPeer(req *sentry.PeersReply) error {
return nil
}
func (f *Fetch) stateChangesLoop(ctx context.Context, client remote.KVClient) {
for {
select {
case <-ctx.Done():
return
default:
}
f.stateChangesStream(ctx, client)
}
}
func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) {
func (f *Fetch) handleStateChanges(ctx context.Context, client remote.KVClient) {
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := client.StateChanges(streamCtx, &remote.StateChangeRequest{WithStorage: false, WithTransactions: true}, grpc.WaitForReady(true))
@@ -418,12 +410,27 @@ func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient)
return
}
parseCtx := NewTxParseContext()
var unwindTxs, minedTxs TxSlots
if req.Direction == remote.Direction_FORWARD {
minedTxs.Growth(len(req.Txs))
for i := range req.Txs {
minedTxs.txs[i] = &TxSlot{}
if _, err := parseCtx.ParseTransaction(req.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i)); err != nil {
log.Warn("stream.Recv", "err", err)
continue
}
}
}
if req.Direction == remote.Direction_UNWIND {
unwindTxs.Growth(len(req.Txs))
for i := range req.Txs {
unwindTxs.txs[i] = &TxSlot{}
if _, err := parseCtx.ParseTransaction(req.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i)); err != nil {
log.Warn("stream.Recv", "err", err)
continue
}
}
}
diff := map[string]senderInfo{}
for _, change := range req.Changes {
@@ -436,8 +443,13 @@ func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient)
diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance}
}
if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil {
if err := f.coreDB.View(ctx, func(tx kv.Tx) error {
return f.pool.OnNewBlock(tx, diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight)
}); err != nil {
log.Warn("onNewBlock", "err", err)
}
if f.wg != nil {
f.wg.Done()
}
}
}

View File

@@ -19,6 +19,7 @@ package txpool
import (
"context"
"fmt"
"io"
"sync"
"testing"
@@ -26,15 +27,16 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon-lib/gointerfaces/types"
"github.com/ledgerwatch/erigon-lib/kv/memdb"
"github.com/ledgerwatch/log/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
func TestFetch(t *testing.T) {
logger := log.New()
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var genesisHash [32]byte
var networkId uint64 = 1
@@ -44,7 +46,7 @@ func TestFetch(t *testing.T) {
sentryClient := direct.NewSentryClientDirect(direct.ETH66, m)
pool := &PoolMock{}
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, logger)
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, nil)
var wg sync.WaitGroup
fetch.SetWaitGroup(&wg)
m.StreamWg.Add(2)
@@ -133,3 +135,32 @@ func TestSendTxPropagate(t *testing.T) {
}
})
}
func TestOnNewBlock(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db := memdb.NewTestDB(t)
var genesisHash [32]byte
var networkId uint64 = 1
i := 0
stream := &remote.KV_StateChangesClientMock{
RecvFunc: func() (*remote.StateChange, error) {
if i > 0 {
return nil, io.EOF
}
i++
return &remote.StateChange{Txs: [][]byte{decodeHex(txParseTests[0].payloadStr), decodeHex(txParseTests[1].payloadStr), decodeHex(txParseTests[2].payloadStr)}}, nil
},
}
stateChanges := &remote.KVClientMock{
StateChangesFunc: func(ctx context.Context, in *remote.StateChangeRequest, opts ...grpc.CallOption) (remote.KV_StateChangesClient, error) {
return stream, nil
},
}
pool := &PoolMock{}
fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges, db)
fetch.handleStateChanges(ctx, stateChanges)
assert.Equal(t, 1, len(pool.OnNewBlockCalls()))
assert.Equal(t, 3, len(pool.OnNewBlockCalls()[0].MinedTxs.txs))
}

View File

@@ -5,6 +5,8 @@ package txpool
import (
"sync"
"github.com/ledgerwatch/erigon-lib/kv"
)
// Ensure, that PoolMock does implement Pool.
@@ -17,7 +19,7 @@ var _ Pool = &PoolMock{}
//
// // make and configure a mocked Pool
// mockedPool := &PoolMock{
// AddFunc: func(newTxs TxSlots) error {
// AddFunc: func(db kv.Tx, newTxs TxSlots) error {
// panic("mock out the Add method")
// },
// AddNewGoodPeerFunc: func(peerID PeerID) {
@@ -29,7 +31,7 @@ var _ Pool = &PoolMock{}
// IdHashKnownFunc: func(hash []byte) bool {
// panic("mock out the IdHashKnown method")
// },
// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
// OnNewBlockFunc: func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
// panic("mock out the OnNewBlock method")
// },
// }
@@ -40,7 +42,7 @@ var _ Pool = &PoolMock{}
// }
type PoolMock struct {
// AddFunc mocks the Add method.
AddFunc func(newTxs TxSlots) error
AddFunc func(db kv.Tx, newTxs TxSlots) error
// AddNewGoodPeerFunc mocks the AddNewGoodPeer method.
AddNewGoodPeerFunc func(peerID PeerID)
@@ -52,12 +54,14 @@ type PoolMock struct {
IdHashKnownFunc func(hash []byte) bool
// OnNewBlockFunc mocks the OnNewBlock method.
OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error
OnNewBlockFunc func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error
// calls tracks calls to the methods.
calls struct {
// Add holds details about calls to the Add method.
Add []struct {
// Db is the db argument value.
Db kv.Tx
// NewTxs is the newTxs argument value.
NewTxs TxSlots
}
@@ -78,6 +82,8 @@ type PoolMock struct {
}
// OnNewBlock holds details about calls to the OnNewBlock method.
OnNewBlock []struct {
// Db is the db argument value.
Db kv.Tx
// StateChanges is the stateChanges argument value.
StateChanges map[string]senderInfo
// UnwindTxs is the unwindTxs argument value.
@@ -100,10 +106,12 @@ type PoolMock struct {
}
// Add calls AddFunc.
func (mock *PoolMock) Add(newTxs TxSlots) error {
func (mock *PoolMock) Add(db kv.Tx, newTxs TxSlots) error {
callInfo := struct {
Db kv.Tx
NewTxs TxSlots
}{
Db: db,
NewTxs: newTxs,
}
mock.lockAdd.Lock()
@@ -115,16 +123,18 @@ func (mock *PoolMock) Add(newTxs TxSlots) error {
)
return errOut
}
return mock.AddFunc(newTxs)
return mock.AddFunc(db, newTxs)
}
// AddCalls gets all the calls that were made to Add.
// Check the length with:
// len(mockedPool.AddCalls())
func (mock *PoolMock) AddCalls() []struct {
Db kv.Tx
NewTxs TxSlots
} {
var calls []struct {
Db kv.Tx
NewTxs TxSlots
}
mock.lockAdd.RLock()
@@ -233,8 +243,9 @@ func (mock *PoolMock) IdHashKnownCalls() []struct {
}
// OnNewBlock calls OnNewBlockFunc.
func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
func (mock *PoolMock) OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
callInfo := struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
@@ -242,6 +253,7 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T
BlockBaseFee uint64
BlockHeight uint64
}{
Db: db,
StateChanges: stateChanges,
UnwindTxs: unwindTxs,
MinedTxs: minedTxs,
@@ -258,13 +270,14 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T
)
return errOut
}
return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight)
return mock.OnNewBlockFunc(db, stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight)
}
// OnNewBlockCalls gets all the calls that were made to OnNewBlock.
// Check the length with:
// len(mockedPool.OnNewBlockCalls())
func (mock *PoolMock) OnNewBlockCalls() []struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
@@ -273,6 +286,7 @@ func (mock *PoolMock) OnNewBlockCalls() []struct {
BlockHeight uint64
} {
var calls []struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots

View File

@@ -27,6 +27,7 @@ import (
"github.com/google/btree"
lru "github.com/hashicorp/golang-lru"
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon-lib/kv"
"go.uber.org/atomic"
)
@@ -37,8 +38,8 @@ type Pool interface {
// IdHashKnown check whether transaction with given Id hash is known to the pool
IdHashKnown(hash []byte) bool
GetRlp(hash []byte) []byte
Add(newTxs TxSlots) error
OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error
Add(db kv.Tx, newTxs TxSlots) error
OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error
AddNewGoodPeer(peerID PeerID)
}
@@ -87,6 +88,8 @@ const PendingSubPoolLimit = 1024
const BaseFeeSubPoolLimit = 1024
const QueuedSubPoolLimit = 1024
const MaxSendersInfoCache = 1024
type nonce2Tx struct{ *btree.BTree }
type senderInfo struct {
@@ -115,6 +118,7 @@ type TxPool struct {
protocolBaseFee atomic.Uint64
blockBaseFee atomic.Uint64
senderID uint64
senderIDs map[string]uint64
senderInfo map[uint64]*senderInfo
byHash map[string]*metaTx // tx_hash => tx
@@ -202,7 +206,7 @@ func (p *TxPool) IdHashIsLocal(hash []byte) bool {
}
func (p *TxPool) OnNewPeer(peerID PeerID) { p.recentlyConnectedPeers.AddPeer(peerID) }
func (p *TxPool) Add(newTxs TxSlots) error {
func (p *TxPool) Add(coreDB kv.Tx, newTxs TxSlots) error {
p.lock.Lock()
defer p.lock.Unlock()
if err := newTxs.Valid(); err != nil {
@@ -214,7 +218,9 @@ func (p *TxPool) Add(newTxs TxSlots) error {
return fmt.Errorf("non-zero base fee")
}
setTxSenderID(p.senderIDs, p.senderInfo, newTxs)
if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, newTxs); err != nil {
return err
}
if err := onNewTxs(p.senderInfo, newTxs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil {
return err
}
@@ -286,7 +292,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee
return nil
}
func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error {
func (p *TxPool) OnNewBlock(coreDB kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error {
p.lock.Lock()
defer p.lock.Unlock()
if err := unwindTxs.Valid(); err != nil {
@@ -300,13 +306,18 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined
p.protocolBaseFee.Store(protocolBaseFee)
p.blockBaseFee.Store(blockBaseFee)
setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs)
setTxSenderID(p.senderIDs, p.senderInfo, minedTxs)
if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, unwindTxs); err != nil {
return err
}
if err := setTxSenderID(coreDB, &p.senderID, p.senderIDs, p.senderInfo, minedTxs); err != nil {
return err
}
for addr, id := range p.senderIDs { // merge state changes
if v, ok := stateChanges[addr]; ok {
p.senderInfo[id] = &v
}
}
if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil {
return err
}
@@ -326,24 +337,53 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined
}
}
/*
// evict sendersInfo without txs
if len(p.senderIDs) > MaxSendersInfoCache {
for i := range p.senderInfo {
if p.senderInfo[i].txNonce2Tx.Len() > 0 {
continue
}
for addr, id := range p.senderIDs {
if id == i {
delete(p.senderIDs, addr)
}
}
delete(p.senderInfo, i)
}
}
*/
return nil
}
func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) {
func setTxSenderID(coreDB kv.Tx, senderIDSequence *uint64, senderIDs map[string]uint64, sendersInfo map[uint64]*senderInfo, txs TxSlots) error {
for i := range txs.txs {
addr := string(txs.senders.At(i))
// assign ID to each new sender
id, ok := senderIDs[addr]
if !ok {
for i := range senderInfo { //TODO: create field for it?
if id < i {
id = i
}
}
id++
senderIDs[addr] = id
*senderIDSequence++
senderIDs[addr] = *senderIDSequence
}
txs.txs[i].senderID = id
// load data from db if need
_, ok = sendersInfo[txs.txs[i].senderID]
if !ok {
encoded, err := coreDB.GetOne(kv.PlainState, txs.senders.At(i))
if err != nil {
return err
}
nonce, balance, err := DecodeSender(encoded)
if err != nil {
return err
}
sendersInfo[txs.txs[i].senderID] = &senderInfo{nonce: nonce, balance: balance}
}
}
return nil
}
func onNewBlock(senderInfo map[uint64]*senderInfo, unwindTxs TxSlots, minedTxs []*TxSlot, protocolBaseFee, blockBaseFee uint64, pending, baseFee, queued *SubPool, byHash map[string]*metaTx, localsHistory *lru.Cache) error {

View File

@@ -426,19 +426,19 @@ func FuzzOnNewBlocks11(f *testing.F) {
// go to first fork
//fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len())
unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs)
err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1)
err = pool.OnNewBlock(nil, map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1)
assert.NoError(err)
check(unwindTxs, minedTxs1, "fork1")
checkNotify(unwindTxs, minedTxs1, "fork1")
// unwind everything and switch to new fork (need unwind mined now)
err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2)
err = pool.OnNewBlock(nil, map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2)
assert.NoError(err)
check(minedTxs1, minedTxs2, "fork2")
checkNotify(minedTxs1, minedTxs2, "fork2")
// add some remote txs from p2p
err = pool.Add(p2pReceived)
err = pool.Add(nil, p2pReceived)
assert.NoError(err)
check(p2pReceived, TxSlots{}, "p2pmsg1")
checkNotify(p2pReceived, TxSlots{}, "p2pmsg1")