Merge pull request #25 from ledgerwatch/pool10

Pool: subscribe to state changes
This commit is contained in:
Alex Sharov 2021-08-09 11:16:16 +07:00 committed by GitHub
commit 4d8aa5dde6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 438 additions and 20 deletions

View File

@ -0,0 +1,205 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package remote
import (
context "context"
types "github.com/ledgerwatch/erigon-lib/gointerfaces/types"
grpc "google.golang.org/grpc"
emptypb "google.golang.org/protobuf/types/known/emptypb"
sync "sync"
)
// Ensure, that KVClientMock does implement KVClient.
// If this is not the case, regenerate this file with moq.
var _ KVClient = &KVClientMock{}
// KVClientMock is a mock implementation of KVClient.
//
// func TestSomethingThatUsesKVClient(t *testing.T) {
//
// // make and configure a mocked KVClient
// mockedKVClient := &KVClientMock{
// StateChangesFunc: func(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error) {
// panic("mock out the StateChanges method")
// },
// TxFunc: func(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error) {
// panic("mock out the Tx method")
// },
// VersionFunc: func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error) {
// panic("mock out the Version method")
// },
// }
//
// // use mockedKVClient in code that requires KVClient
// // and then make assertions.
//
// }
type KVClientMock struct {
// StateChangesFunc mocks the StateChanges method.
StateChangesFunc func(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error)
// TxFunc mocks the Tx method.
TxFunc func(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error)
// VersionFunc mocks the Version method.
VersionFunc func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error)
// calls tracks calls to the methods.
calls struct {
// StateChanges holds details about calls to the StateChanges method.
StateChanges []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// In is the in argument value.
In *StateChangeRequest
// Opts is the opts argument value.
Opts []grpc.CallOption
}
// Tx holds details about calls to the Tx method.
Tx []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Opts is the opts argument value.
Opts []grpc.CallOption
}
// Version holds details about calls to the Version method.
Version []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// In is the in argument value.
In *emptypb.Empty
// Opts is the opts argument value.
Opts []grpc.CallOption
}
}
lockStateChanges sync.RWMutex
lockTx sync.RWMutex
lockVersion sync.RWMutex
}
// StateChanges calls StateChangesFunc.
func (mock *KVClientMock) StateChanges(ctx context.Context, in *StateChangeRequest, opts ...grpc.CallOption) (KV_StateChangesClient, error) {
callInfo := struct {
Ctx context.Context
In *StateChangeRequest
Opts []grpc.CallOption
}{
Ctx: ctx,
In: in,
Opts: opts,
}
mock.lockStateChanges.Lock()
mock.calls.StateChanges = append(mock.calls.StateChanges, callInfo)
mock.lockStateChanges.Unlock()
if mock.StateChangesFunc == nil {
var (
kV_StateChangesClientOut KV_StateChangesClient
errOut error
)
return kV_StateChangesClientOut, errOut
}
return mock.StateChangesFunc(ctx, in, opts...)
}
// StateChangesCalls gets all the calls that were made to StateChanges.
// Check the length with:
// len(mockedKVClient.StateChangesCalls())
func (mock *KVClientMock) StateChangesCalls() []struct {
Ctx context.Context
In *StateChangeRequest
Opts []grpc.CallOption
} {
var calls []struct {
Ctx context.Context
In *StateChangeRequest
Opts []grpc.CallOption
}
mock.lockStateChanges.RLock()
calls = mock.calls.StateChanges
mock.lockStateChanges.RUnlock()
return calls
}
// Tx calls TxFunc.
func (mock *KVClientMock) Tx(ctx context.Context, opts ...grpc.CallOption) (KV_TxClient, error) {
callInfo := struct {
Ctx context.Context
Opts []grpc.CallOption
}{
Ctx: ctx,
Opts: opts,
}
mock.lockTx.Lock()
mock.calls.Tx = append(mock.calls.Tx, callInfo)
mock.lockTx.Unlock()
if mock.TxFunc == nil {
var (
kV_TxClientOut KV_TxClient
errOut error
)
return kV_TxClientOut, errOut
}
return mock.TxFunc(ctx, opts...)
}
// TxCalls gets all the calls that were made to Tx.
// Check the length with:
// len(mockedKVClient.TxCalls())
func (mock *KVClientMock) TxCalls() []struct {
Ctx context.Context
Opts []grpc.CallOption
} {
var calls []struct {
Ctx context.Context
Opts []grpc.CallOption
}
mock.lockTx.RLock()
calls = mock.calls.Tx
mock.lockTx.RUnlock()
return calls
}
// Version calls VersionFunc.
func (mock *KVClientMock) Version(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*types.VersionReply, error) {
callInfo := struct {
Ctx context.Context
In *emptypb.Empty
Opts []grpc.CallOption
}{
Ctx: ctx,
In: in,
Opts: opts,
}
mock.lockVersion.Lock()
mock.calls.Version = append(mock.calls.Version, callInfo)
mock.lockVersion.Unlock()
if mock.VersionFunc == nil {
var (
versionReplyOut *types.VersionReply
errOut error
)
return versionReplyOut, errOut
}
return mock.VersionFunc(ctx, in, opts...)
}
// VersionCalls gets all the calls that were made to Version.
// Check the length with:
// len(mockedKVClient.VersionCalls())
func (mock *KVClientMock) VersionCalls() []struct {
Ctx context.Context
In *emptypb.Empty
Opts []grpc.CallOption
} {
var calls []struct {
Ctx context.Context
In *emptypb.Empty
Opts []grpc.CallOption
}
mock.lockVersion.RLock()
calls = mock.calls.Version
mock.lockVersion.RUnlock()
return calls
}

View File

@ -1,3 +1,4 @@
package gointerfaces
//go:generate moq -stub -out ./sentry/mocks.go ./sentry SentryServer SentryClient
//go:generate moq -stub -out ./remote/mocks.go ./remote KVClient

View File

@ -26,6 +26,7 @@ import (
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/log/v3"
"google.golang.org/grpc"
@ -38,12 +39,13 @@ import (
// genesis hash and list of forks, but with zero max block and total difficulty
// Sentry should have a logic not to overwrite statusData with messages from tx pool
type Fetch struct {
ctx context.Context // Context used for cancellation and closing of the fetcher
sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network
statusData *sentry.StatusData // Status data used for "handshaking" with sentries
pool Pool // Transaction pool implementation
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
logger log.Logger
ctx context.Context // Context used for cancellation and closing of the fetcher
sentryClients []sentry.SentryClient // sentry clients that will be used for accessing the network
statusData *sentry.StatusData // Status data used for "handshaking" with sentries
pool Pool // Transaction pool implementation
wg *sync.WaitGroup // used for synchronisation in the tests (nil when not in tests)
stateChangesClient remote.KVClient
logger log.Logger
}
type Timings struct {
@ -67,6 +69,7 @@ func NewFetch(ctx context.Context,
networkId uint64,
forks []uint64,
pool Pool,
stateChangesClient remote.KVClient,
logger log.Logger,
) *Fetch {
statusData := &sentry.StatusData{
@ -80,11 +83,12 @@ func NewFetch(ctx context.Context,
},
}
return &Fetch{
ctx: ctx,
sentryClients: sentryClients,
statusData: statusData,
pool: pool,
logger: logger,
ctx: ctx,
sentryClients: sentryClients,
statusData: statusData,
pool: pool,
logger: logger,
stateChangesClient: stateChangesClient,
}
}
@ -92,8 +96,8 @@ func (f *Fetch) SetWaitGroup(wg *sync.WaitGroup) {
f.wg = wg
}
// Start initialises connection to the sentry
func (f *Fetch) Start() {
// ConnectSentries initialises connection to the sentry
func (f *Fetch) ConnectSentries() {
for i := range f.sentryClients {
go func(i int) {
f.receiveMessageLoop(f.sentryClients[i])
@ -103,6 +107,9 @@ func (f *Fetch) Start() {
}(i)
}
}
func (f *Fetch) ConnectCore() {
go func() { f.stateChangesLoop(f.ctx, f.stateChangesClient) }()
}
func (f *Fetch) receiveMessageLoop(sentryClient sentry.SentryClient) {
logger := f.logger
@ -361,3 +368,76 @@ func (f *Fetch) handleNewPeer(req *sentry.PeersReply) error {
return nil
}
func (f *Fetch) stateChangesLoop(ctx context.Context, client remote.KVClient) {
for {
select {
case <-ctx.Done():
return
default:
}
f.stateChangesStream(ctx, client)
}
}
func (f *Fetch) stateChangesStream(ctx context.Context, client remote.KVClient) {
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := client.StateChanges(streamCtx, &remote.StateChangeRequest{WithStorage: false, WithTransactions: true}, grpc.WaitForReady(true))
if err != nil {
select {
case <-ctx.Done():
return
default:
}
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return
}
if errors.Is(err, io.EOF) {
return
}
time.Sleep(time.Second)
log.Warn("state changes", "err", err)
}
for req, err := stream.Recv(); ; req, err = stream.Recv() {
if err != nil {
select {
case <-ctx.Done():
return
default:
}
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return
}
if errors.Is(err, io.EOF) {
return
}
log.Warn("stream.Recv", "err", err)
return
}
if req == nil {
return
}
var unwindTxs, minedTxs TxSlots
if req.Direction == remote.Direction_FORWARD {
minedTxs.Growth(len(req.Txs))
}
if req.Direction == remote.Direction_UNWIND {
unwindTxs.Growth(len(req.Txs))
}
diff := map[string]senderInfo{}
for _, change := range req.Changes {
nonce, balance, err := DecodeSender(change.Data)
if err != nil {
log.Warn("stateChanges.decodeSender", "err", err)
continue
}
addr := gointerfaces.ConvertH160toAddress(change.Address)
diff[string(addr[:])] = senderInfo{nonce: nonce, balance: balance}
}
if err := f.pool.OnNewBlock(diff, unwindTxs, minedTxs, req.ProtocolBaseFee, req.BlockBaseFee, req.BlockHeight); err != nil {
log.Warn("onNewBlock", "err", err)
}
}
}

View File

@ -23,6 +23,7 @@ import (
"testing"
"github.com/ledgerwatch/erigon-lib/direct"
"github.com/ledgerwatch/erigon-lib/gointerfaces/remote"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon-lib/gointerfaces/types"
"github.com/ledgerwatch/log/v3"
@ -43,11 +44,11 @@ func TestFetch(t *testing.T) {
sentryClient := direct.NewSentryClientDirect(direct.ETH66, m)
pool := &PoolMock{}
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, logger)
fetch := NewFetch(ctx, []sentry.SentryClient{sentryClient}, genesisHash, networkId, forks, pool, &remote.KVClientMock{}, logger)
var wg sync.WaitGroup
fetch.SetWaitGroup(&wg)
m.StreamWg.Add(2)
fetch.Start()
fetch.ConnectSentries()
m.StreamWg.Wait()
// Send one transaction id
wg.Add(1)

View File

@ -29,6 +29,9 @@ var _ Pool = &PoolMock{}
// IdHashKnownFunc: func(hash []byte) bool {
// panic("mock out the IdHashKnown method")
// },
// OnNewBlockFunc: func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
// panic("mock out the OnNewBlock method")
// },
// }
//
// // use mockedPool in code that requires Pool
@ -48,6 +51,9 @@ type PoolMock struct {
// IdHashKnownFunc mocks the IdHashKnown method.
IdHashKnownFunc func(hash []byte) bool
// OnNewBlockFunc mocks the OnNewBlock method.
OnNewBlockFunc func(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error
// calls tracks calls to the methods.
calls struct {
// Add holds details about calls to the Add method.
@ -70,11 +76,27 @@ type PoolMock struct {
// Hash is the hash argument value.
Hash []byte
}
// OnNewBlock holds details about calls to the OnNewBlock method.
OnNewBlock []struct {
// StateChanges is the stateChanges argument value.
StateChanges map[string]senderInfo
// UnwindTxs is the unwindTxs argument value.
UnwindTxs TxSlots
// MinedTxs is the minedTxs argument value.
MinedTxs TxSlots
// ProtocolBaseFee is the protocolBaseFee argument value.
ProtocolBaseFee uint64
// BlockBaseFee is the blockBaseFee argument value.
BlockBaseFee uint64
// BlockHeight is the blockHeight argument value.
BlockHeight uint64
}
}
lockAdd sync.RWMutex
lockAddNewGoodPeer sync.RWMutex
lockGetRlp sync.RWMutex
lockIdHashKnown sync.RWMutex
lockOnNewBlock sync.RWMutex
}
// Add calls AddFunc.
@ -209,3 +231,57 @@ func (mock *PoolMock) IdHashKnownCalls() []struct {
mock.lockIdHashKnown.RUnlock()
return calls
}
// OnNewBlock calls OnNewBlockFunc.
func (mock *PoolMock) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs TxSlots, minedTxs TxSlots, protocolBaseFee uint64, blockBaseFee uint64, blockHeight uint64) error {
callInfo := struct {
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
ProtocolBaseFee uint64
BlockBaseFee uint64
BlockHeight uint64
}{
StateChanges: stateChanges,
UnwindTxs: unwindTxs,
MinedTxs: minedTxs,
ProtocolBaseFee: protocolBaseFee,
BlockBaseFee: blockBaseFee,
BlockHeight: blockHeight,
}
mock.lockOnNewBlock.Lock()
mock.calls.OnNewBlock = append(mock.calls.OnNewBlock, callInfo)
mock.lockOnNewBlock.Unlock()
if mock.OnNewBlockFunc == nil {
var (
errOut error
)
return errOut
}
return mock.OnNewBlockFunc(stateChanges, unwindTxs, minedTxs, protocolBaseFee, blockBaseFee, blockHeight)
}
// OnNewBlockCalls gets all the calls that were made to OnNewBlock.
// Check the length with:
// len(mockedPool.OnNewBlockCalls())
func (mock *PoolMock) OnNewBlockCalls() []struct {
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
ProtocolBaseFee uint64
BlockBaseFee uint64
BlockHeight uint64
} {
var calls []struct {
StateChanges map[string]senderInfo
UnwindTxs TxSlots
MinedTxs TxSlots
ProtocolBaseFee uint64
BlockBaseFee uint64
BlockHeight uint64
}
mock.lockOnNewBlock.RLock()
calls = mock.calls.OnNewBlock
mock.lockOnNewBlock.RUnlock()
return calls
}

View File

@ -38,6 +38,7 @@ type Pool interface {
IdHashKnown(hash []byte) bool
GetRlp(hash []byte) []byte
Add(newTxs TxSlots) error
OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error
AddNewGoodPeer(peerID PeerID)
}
@ -110,6 +111,7 @@ func (i *nonce2TxItem) Less(than btree.Item) bool {
type TxPool struct {
lock *sync.RWMutex
blockHeight atomic.Uint64
protocolBaseFee atomic.Uint64
blockBaseFee atomic.Uint64
@ -284,7 +286,7 @@ func onNewTxs(senderInfo map[uint64]*senderInfo, newTxs TxSlots, protocolBaseFee
return nil
}
func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee uint64) error {
func (p *TxPool) OnNewBlock(stateChanges map[string]senderInfo, unwindTxs, minedTxs TxSlots, protocolBaseFee, blockBaseFee, blockHeight uint64) error {
p.lock.Lock()
defer p.lock.Unlock()
if err := unwindTxs.Valid(); err != nil {
@ -294,11 +296,17 @@ func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockB
return err
}
p.blockHeight.Store(blockHeight)
p.protocolBaseFee.Store(protocolBaseFee)
p.blockBaseFee.Store(blockBaseFee)
setTxSenderID(p.senderIDs, p.senderInfo, unwindTxs)
setTxSenderID(p.senderIDs, p.senderInfo, minedTxs)
for addr, id := range p.senderIDs { // merge state changes
if v, ok := stateChanges[addr]; ok {
p.senderInfo[id] = &v
}
}
if err := onNewBlock(p.senderInfo, unwindTxs, minedTxs.txs, protocolBaseFee, blockBaseFee, p.pending, p.baseFee, p.queued, p.byHash, p.localsHistory); err != nil {
return err
}
@ -322,7 +330,9 @@ func (p *TxPool) OnNewBlock(unwindTxs, minedTxs TxSlots, protocolBaseFee, blockB
}
func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInfo, txs TxSlots) {
for i := range txs.txs {
id, ok := senderIDs[string(txs.senders.At(i))]
addr := string(txs.senders.At(i))
id, ok := senderIDs[addr]
if !ok {
for i := range senderInfo { //TODO: create field for it?
if id < i {
@ -330,7 +340,7 @@ func setTxSenderID(senderIDs map[string]uint64, senderInfo map[uint64]*senderInf
}
}
id++
senderIDs[string(txs.senders.At(i))] = id
senderIDs[addr] = id
}
txs.txs[i].senderID = id
}

View File

@ -426,13 +426,13 @@ func FuzzOnNewBlocks11(f *testing.F) {
// go to first fork
//fmt.Printf("ll1: %d,%d,%d\n", pool.pending.Len(), pool.baseFee.Len(), pool.queued.Len())
unwindTxs, minedTxs1, p2pReceived, minedTxs2 := splitDataset(txs)
err = pool.OnNewBlock(unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee)
err = pool.OnNewBlock(map[string]senderInfo{}, unwindTxs, minedTxs1, protocolBaseFee, blockBaseFee, 1)
assert.NoError(err)
check(unwindTxs, minedTxs1, "fork1")
checkNotify(unwindTxs, minedTxs1, "fork1")
// unwind everything and switch to new fork (need unwind mined now)
err = pool.OnNewBlock(minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee)
err = pool.OnNewBlock(map[string]senderInfo{}, minedTxs1, minedTxs2, protocolBaseFee, blockBaseFee, 2)
assert.NoError(err)
check(minedTxs1, minedTxs2, "fork2")
checkNotify(minedTxs1, minedTxs2, "fork2")

View File

@ -426,3 +426,48 @@ func (s *TxSlots) Growth(targetSize int) {
}
var addressesGrowth = make([]byte, 20)
func DecodeSender(enc []byte) (nonce uint64, balance uint256.Int, err error) {
if len(enc) == 0 {
return
}
var fieldSet = enc[0]
var pos = 1
if fieldSet&1 > 0 {
decodeLength := int(enc[pos])
if len(enc) < pos+decodeLength+1 {
return nonce, balance, fmt.Errorf(
"malformed CBOR for Account.Nonce: %s, Length %d",
enc[pos+1:], decodeLength)
}
nonce = bytesToUint64(enc[pos+1 : pos+decodeLength+1])
pos += decodeLength + 1
}
if fieldSet&2 > 0 {
decodeLength := int(enc[pos])
if len(enc) < pos+decodeLength+1 {
return nonce, balance, fmt.Errorf(
"malformed CBOR for Account.Nonce: %s, Length %d",
enc[pos+1:], decodeLength)
}
(&balance).SetBytes(enc[pos+1 : pos+decodeLength+1])
}
return
}
func bytesToUint64(buf []byte) (x uint64) {
for i, b := range buf {
x = x<<8 + uint64(b)
if i == 7 {
return
}
}
return
}