This commit is contained in:
alex.sharov 2021-08-11 11:21:36 +07:00
parent 4bf6b1b29b
commit d491e4c093
5 changed files with 97 additions and 34 deletions

View File

@ -28,6 +28,7 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/log/v3"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -43,7 +44,8 @@ type Fetch struct {
sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network
statusData *sentry.StatusData // Status data used for "handshaking" with sentries
pool Pool // Transaction pool implementation
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
db kv.RoDB
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
stateChangesClient remote.KVClient
}
@ -62,7 +64,7 @@ var DefaultTimings = Timings{
// NewFetch creates a new fetch object that will work with given sentry clients. Since the
// SentryClient here is an interface, it is suitable for mocking in tests (mock will need
// to implement all the functions of the SentryClient interface).
func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient) *Fetch {
func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisHash [32]byte, networkId uint64, forks []uint64, pool Pool, stateChangesClient remote.KVClient, db kv.RoDB) *Fetch {
statusData := &sentry.StatusData{
NetworkId: networkId,
TotalDifficulty: gointerfaces.ConvertUint256IntToH256(uint256.NewInt(0)),
@ -78,6 +80,7 @@ func NewFetch(ctx context.Context, sentryClients []sentry.SentryClient, genesisH
sentryClients: sentryClients,
statusData: statusData,
pool: pool,
db: db,
stateChangesClient: stateChangesClient,
}
}
@ -176,7 +179,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
if req == nil {
return
}
if err = f.handleInboundMessage(req, sentryClient); err != nil {
if err = f.handleInboundMessage(streamCtx, req, sentryClient); err != nil {
log.Warn("Handling incoming message: %s", "err", err)
}
if f.wg != nil {
@ -186,7 +189,7 @@ func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
}
}
func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient sentry.SentryClient) error {
func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMessage, sentryClient sentry.SentryClient) error {
switch req.Id {
case sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_65:
hashCount, pos, err := ParseHashesCount(req.Data, 0)
@ -279,7 +282,9 @@ func (f *Fetch) handleInboundMessage(req *sentry.InboundMessage, sentryClient se
return err
}
}
if err := f.pool.Add(txs); err != nil {
if err := f.db.View(ctx, func(tx kv.Tx) error {
return f.pool.Add(tx, txs)
}); err != nil {
return err
}
}
@ -438,7 +443,9 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client remote.KVClient)
diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance}
}
if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil {
if err := f.db.View(ctx, func(tx kv.Tx) error {
return f.pool.OnNewBlock(tx, diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight)
}); err != nil {
log.Warn("onNewBlock", "err", err)
}
if f.wg != nil {

View File

@ -27,6 +27,7 @@ import (
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon-lib/gointerfaces/types"
"github.com/ledgerwatch/erigon-lib/kv/memdb"
"github.com/ledgerwatch/log/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -45,7 +46,7 @@ func TestFetch(t *testing.T) {
sentryClient := direct.NewSentryClientDirect(direct.ETH66, m)
pool := &PoolMock{}
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{})
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, nil)
var wg sync.WaitGroup
fetch.SetWaitGroup(&wg)
m.StreamWg.Add(2)
@ -138,6 +139,7 @@ func TestSendTxPropagate(t *testing.T) {
func TestOnNewBlock(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db := memdb.NewTestDB(t)
var genesisHash [32]byte
var networkId uint64 = 1
@ -157,7 +159,7 @@ func TestOnNewBlock(t *testing.T) {
},
}
pool := &PoolMock{}
fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges)
fetch := NewFetch(ctx, nil, genesisHash, networkId, nil, pool, stateChanges, db)
fetch.handleStateChanges(ctx, stateChanges)
assert.Equal(t, 1, len(pool.OnNewBlockCalls()))
assert.Equal(t, 3, len(pool.OnNewBlockCalls()[0].MinedTxs.txs))

View File

@ -5,6 +5,8 @@ package txpool
import (
"sync"
"github.com/ledgerwatch/erigon-lib/kv"
)
// Ensure, that PoolMock does implement Pool.
@ -17,7 +19,7 @@ var _ Pool = &PoolMock{}
//
// // make and configure a mocked Pool
// mockedPool := &PoolMock{
// AddFunc: func(newTxs TxSlots) error {
// AddFunc: func(db kv.Tx, newTxs TxSlots) error {
// panic("mock out the Add method")
// },
// AddNewGoodPeerFunc: func(peerID PeerID) {
@ -29,7 +31,7 @@ var _ Pool = &PoolMock{}
// IdHashKnownFunc: func(hash []byte) bool {
// panic("mock out the IdHashKnown method")
// },
// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
// OnNewBlockFunc: func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
// panic("mock out the OnNewBlock method")
// },
// }
@ -40,7 +42,7 @@ var _ Pool = &PoolMock{}
// }
type PoolMock struct {
// AddFunc mocks the Add method.
AddFunc func(newTxs TxSlots) error
AddFunc func(db kv.Tx, newTxs TxSlots) error
// AddNewGoodPeerFunc mocks the AddNewGoodPeer method.
AddNewGoodPeerFunc func(peerID PeerID)
@ -52,12 +54,14 @@ type PoolMock struct {
IdHashKnownFunc func(hash []byte) bool
// OnNewBlockFunc mocks the OnNewBlock method.
OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error
OnNewBlockFunc func(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error
// calls tracks calls to the methods.
calls struct {
// Add holds details about calls to the Add method.
Add []struct {
// Db is the db argument value.
Db kv.Tx
// NewTxs is the newTxs argument value.
NewTxs TxSlots
}
@ -78,6 +82,8 @@ type PoolMock struct {
}
// OnNewBlock holds details about calls to the OnNewBlock method.
OnNewBlock []struct {
// Db is the db argument value.
Db kv.Tx
// StateChanges is the stateChanges argument value.
StateChanges map[string]senderInfo
// UnwindTxs is the unwindTxs argument value.
@ -100,10 +106,12 @@ type PoolMock struct {
}
// Add calls AddFunc.
func (mock *PoolMock) Add(newTxs TxSlots) error {
func (mock *PoolMock) Add(db kv.Tx, newTxs TxSlots) error {
callInfo := struct {
Db kv.Tx
NewTxs TxSlots
}{
Db: db,
NewTxs: newTxs,
}
mock.lockAdd.Lock()
@ -115,16 +123,18 @@ func (mock *PoolMock) Add(newTxs TxSlots) error {
)
return errOut
}
return mock.AddFunc(newTxs)
return mock.AddFunc(db, newTxs)
}
// AddCalls gets all the calls that were made to Add.
// Check the length with:
// len(mockedPool.AddCalls())
func (mock *PoolMock) AddCalls() []struct {
Db kv.Tx
NewTxs TxSlots
} {
var calls []struct {
Db kv.Tx
NewTxs TxSlots
}
mock.lockAdd.RLock()
@ -233,8 +243,9 @@ func (mock *PoolMock) IdHashKnownCalls() []struct {
}
// OnNewBlock calls OnNewBlockFunc.
func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
func (mock *PoolMock) OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
callInfo := struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
@ -242,6 +253,7 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T
BlockBaseFee uint64
BlockHeight uint64
}{
Db: db,
StateChanges: stateChanges,
UnwindTxs: unwindTxs,
MinedTxs: minedTxs,
@ -258,13 +270,14 @@ func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs T
)
return errOut
}
return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight)
return mock.OnNewBlockFunc(db, stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight)
}
// OnNewBlockCalls gets all the calls that were made to OnNewBlock.
// Check the length with:
// len(mockedPool.OnNewBlockCalls())
func (mock *PoolMock) OnNewBlockCalls() []struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
@ -273,6 +286,7 @@ func (mock *PoolMock) OnNewBlockCalls() []struct {
BlockHeight uint64
} {
var calls []struct {
Db kv.Tx
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots

View File

@ -27,6 +27,7 @@ import (
"github.com/google/btree"
lru "github.com/hashicorp/golang-lru"
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon-lib/kv"
"go.uber.org/atomic"
)
@ -37,8 +38,8 @@ type Pool interface {
// IdHashKnown check whether transaction with given Id hash is known to the pool
IdHashKnown(hash []byte) bool
GetRlp(hash []byte) []byte
Add(newTxs TxSlots) error
OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error
Add(db kv.Tx, newTxs TxSlots) error
OnNewBlock(db kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error
AddNewGoodPeer(peerID PeerID)
}
@ -87,6 +88,8 @@ const PendingSubPoolLimit = 1024
const BaseFeeSubPoolLimit = 1024
const QueuedSubPoolLimit = 1024
const MaxSendersInfoCache = 1024
type nonce2Tx struct{ *btree.BTree }
type senderInfo struct {
@ -115,6 +118,7 @@ type TxPool struct {
protocolBaseFee atomic.Uint64
blockBaseFee atomic.Uint64
senderID uint64
senderIDs map[string]uint64
senderInfo map[uint64]*senderInfo
byHash map[string]*metaTx // tx_hash => tx
@ -202,7 +206,7 @@ func (p *TxPool) IdHashIsLocal(hash []byte) bool {
}
func (p *TxPool) OnNewPeer(peerID PeerID) { p.recentlyConnectedPeers.AddPeer(peerID) }
func (p *TxPool) Add(newTxs TxSlots) error {
func (p *TxPool) Add(tx kv.Tx, newTxs TxSlots) error {
p.lock.Lock()
defer p.lock.Unlock()
if err := newTxs.Valid(); err != nil {
@ -214,7 +218,9 @@ func (p *TxPool) Add(newTxs TxSlots) error {
return fmt.Errorf("non-zero base fee")
}
setTxSenderID(p.senderIDs, p.senderInfo, newTxs)
if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, newTxs); err != nil {
return err
}
if err := onNewTxs(p.senderInfo, newTxs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil {
return err
}
@ -286,7 +292,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee
return nil
}
func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error {
func (p *TxPool) OnNewBlock(tx kv.Tx, stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error {
p.lock.Lock()
defer p.lock.Unlock()
if err := unwindTxs.Valid(); err != nil {
@ -300,13 +306,18 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined
p.protocolBaseFee.Store(protocolBaseFee)
p.blockBaseFee.Store(blockBaseFee)
setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs)
setTxSenderID(p.senderIDs, p.senderInfo, minedTxs)
if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, unwindTxs); err != nil {
return err
}
if err := setTxSenderID(tx, &p.senderID, p.senderIDs, p.senderInfo, minedTxs); err != nil {
return err
}
for addr, id := range p.senderIDs { // merge state changes
if v, ok := stateChanges[addr]; ok {
p.senderInfo[id] = &v
}
}
if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil {
return err
}
@ -326,24 +337,53 @@ func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, mined
}
}
/*
// evict sendersInfo without txs
if len(p.senderIDs) > MaxSendersInfoCache {
for i := range p.senderInfo {
if p.senderInfo[i].txNonce2Tx.Len() > 0 {
continue
}
for addr, id := range p.senderIDs {
if id == i {
delete(p.senderIDs, addr)
}
}
delete(p.senderInfo, i)
}
}
*/
return nil
}
func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) {
func setTxSenderID(tx kv.Tx, senderIDSequence *uint64, senderIDs map[string]uint64, sendersInfo map[uint64]*senderInfo, txs TxSlots) error {
for i := range txs.txs {
addr := string(txs.senders.At(i))
// assign ID to each new sender
id, ok := senderIDs[addr]
if !ok {
for i := range senderInfo { //TODO: create field for it?
if id < i {
id = i
}
}
id++
senderIDs[addr] = id
*senderIDSequence++
senderIDs[addr] = *senderIDSequence
}
txs.txs[i].senderID = id
// load data from db if need
_, ok = sendersInfo[txs.txs[i].senderID]
if !ok {
encoded, err := tx.GetOne(kv.PlainState, txs.senders.At(i))
if err != nil {
return err
}
nonce, balance, err := DecodeSender(encoded)
if err != nil {
return err
}
sendersInfo[txs.txs[i].senderID] = &senderInfo{nonce: nonce, balance: balance}
}
}
return nil
}
func onNewBlock(senderInfo map[uint64]*senderInfo, unwindTxs TxSlots, minedTxs []*TxSlot, protocolBaseFee, blockBaseFee uint64, pending, baseFee, queued *SubPool, byHash map[string]*metaTx, localsHistory *lru.Cache) error {

View File

@ -426,19 +426,19 @@ func FuzzOnNewBlocks11(f *testing.F) {
// go to first fork
//fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len())
unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs)
err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1)
err = pool.OnNewBlock(nil, map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1)
assert.NoError(err)
check(unwindTxs, minedTxs1, "fork1")
checkNotify(unwindTxs, minedTxs1, "fork1")
// unwind everything and switch to new fork (need unwind mined now)
err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2)
err = pool.OnNewBlock(nil, map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2)
assert.NoError(err)
check(minedTxs1, minedTxs2, "fork2")
checkNotify(minedTxs1, minedTxs2, "fork2")
// add some remote txs from p2p
err = pool.Add(p2pReceived)
err = pool.Add(nil, p2pReceived)
assert.NoError(err)
check(p2pReceived, TxSlots{}, "p2pmsg1")
checkNotify(p2pReceived, TxSlots{}, "p2pmsg1")