fix nil pointer in fetch.go (#406)

This commit is contained in:
Alex Sharov 2022-03-31 15:13:11 +07:00 committed by GitHub
parent 5315a677b5
commit af4391d7f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 44 additions and 41 deletions

View File

@ -294,16 +294,6 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
case sentry.MessageId_POOLED_TRANSACTIONS_66, sentry.MessageId_TRANSACTIONS_66:
txs := TxSlots{}
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
parseContext.ValidateHash(func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
})
return nil
}); err != nil {
return err
@ -312,7 +302,16 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
switch req.Id {
case sentry.MessageId_TRANSACTIONS_66:
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
if _, err := ParseTransactions(req.Data, 0, parseContext, &txs); err != nil {
if _, err := ParseTransactions(req.Data, 0, parseContext, &txs, func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
}); err != nil {
return err
}
return nil
@ -321,7 +320,16 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
}
case sentry.MessageId_POOLED_TRANSACTIONS_66:
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
if _, _, err := ParsePooledTransactions66(req.Data, 0, parseContext, &txs); err != nil {
if _, _, err := ParsePooledTransactions66(req.Data, 0, parseContext, &txs, func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
}); err != nil {
return err
}
return nil
@ -435,7 +443,7 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client StateChangesClien
for i := range change.Txs {
minedTxs.txs[i] = &TxSlot{}
if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error {
_, err := parseContext.ParseTransaction(change.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i), true /* hasEnvelope */)
_, err := parseContext.ParseTransaction(change.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i), true /* hasEnvelope */, nil)
return err
}); err != nil {
log.Warn("stream.Recv", "err", err)
@ -448,7 +456,7 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client StateChangesClien
for i := range change.Txs {
unwindTxs.txs[i] = &TxSlot{}
if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error {
_, err = parseContext.ParseTransaction(change.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i), true /* hasEnvelope */)
_, err = parseContext.ParseTransaction(change.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i), true /* hasEnvelope */, nil)
return err
}); err != nil {
log.Warn("stream.Recv", "err", err)

View File

@ -177,12 +177,6 @@ func (s *GrpcServer) Add(ctx context.Context, in *txpool_proto.AddRequest) (*txp
var slots TxSlots
parseCtx := NewTxParseContext(s.chainID)
parseCtx.ValidateHash(func(hash []byte) error {
if known, _ := s.txPool.IdHashKnown(tx, hash); known {
return ErrAlreadyKnown
}
return nil
})
parseCtx.ValidateRLP(s.txPool.ValidateSerializedTxn)
reply := &txpool_proto.AddReply{Imported: make([]txpool_proto.ImportResult, len(in.RlpTxs)), Errors: make([]string, len(in.RlpTxs))}
@ -192,7 +186,12 @@ func (s *GrpcServer) Add(ctx context.Context, in *txpool_proto.AddRequest) (*txp
slots.Resize(uint(j + 1))
slots.txs[j] = &TxSlot{}
slots.isLocal[j] = true
if _, err := parseCtx.ParseTransaction(in.RlpTxs[i], 0, slots.txs[j], slots.senders.At(j), false /* hasEnvelope */); err != nil {
if _, err := parseCtx.ParseTransaction(in.RlpTxs[i], 0, slots.txs[j], slots.senders.At(j), false /* hasEnvelope */, func(hash []byte) error {
if known, _ := s.txPool.IdHashKnown(tx, hash); known {
return ErrAlreadyKnown
}
return nil
}); err != nil {
if errors.Is(err, ErrAlreadyKnown) { // Noop, but need to handle to not count these
reply.Errors[i] = AlreadyKnown.String()
reply.Imported[i] = txpool_proto.ImportResult_ALREADY_EXISTS

View File

@ -167,7 +167,7 @@ func EncodeTransactions(txsRlp [][]byte, encodeBuf []byte) []byte {
return encodeBuf
}
func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots) (newPos int, err error) {
func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (newPos int, err error) {
pos, _, err = rlp.List(payload, pos)
if err != nil {
return 0, err
@ -176,7 +176,7 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx
for i := 0; pos < len(payload); i++ {
txSlots.Resize(uint(i + 1))
txSlots.txs[i] = &TxSlot{}
pos, err = ctx.ParseTransaction(payload, pos, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */)
pos, err = ctx.ParseTransaction(payload, pos, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */, validateHash)
if err != nil {
if errors.Is(err, ErrRejected) {
txSlots.Resize(uint(i))
@ -189,7 +189,7 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx
return pos, nil
}
func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots) (requestID uint64, newPos int, err error) {
func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (requestID uint64, newPos int, err error) {
p, _, err := rlp.List(payload, pos)
if err != nil {
return requestID, 0, err
@ -206,7 +206,7 @@ func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txS
for i := 0; p < len(payload); i++ {
txSlots.Resize(uint(i + 1))
txSlots.txs[i] = &TxSlot{}
p, err = ctx.ParseTransaction(payload, p, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */)
p, err = ctx.ParseTransaction(payload, p, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */, validateHash)
if err != nil {
if errors.Is(err, ErrRejected) {
txSlots.Resize(uint(i))

View File

@ -153,7 +153,7 @@ func TestPooledTransactionsPacket66(t *testing.T) {
ctx := NewTxParseContext(*uint256.NewInt(tt.chainID))
slots := &TxSlots{}
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots)
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots, nil)
require.NoError(err)
require.Equal(tt.requestID, requestID)
require.Equal(len(tt.txs), len(slots.txs))
@ -170,9 +170,8 @@ func TestPooledTransactionsPacket66(t *testing.T) {
require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf))
ctx := NewTxParseContext(*u256.N1)
ctx.validateHash = func(bytes []byte) error { return ErrRejected }
slots := &TxSlots{}
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots)
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected })
require.NoError(err)
require.Equal(tt.requestID, requestID)
require.Equal(0, len(slots.txs))
@ -213,7 +212,7 @@ func TestTransactionsPacket(t *testing.T) {
ctx := NewTxParseContext(*uint256.NewInt(tt.chainID))
slots := &TxSlots{}
_, err := ParseTransactions(encodeBuf, 0, ctx, slots)
_, err := ParseTransactions(encodeBuf, 0, ctx, slots, nil)
require.NoError(err)
require.Equal(len(tt.txs), len(slots.txs))
for i, txn := range tt.txs {
@ -229,9 +228,8 @@ func TestTransactionsPacket(t *testing.T) {
require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf))
ctx := NewTxParseContext(*u256.N1)
ctx.validateHash = func(bytes []byte) error { return ErrRejected }
slots := &TxSlots{}
_, err := ParseTransactions(encodeBuf, 0, ctx, slots)
_, err := ParseTransactions(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected })
require.NoError(err)
require.Equal(0, len(slots.txs))
require.Equal(0, slots.senders.Len())

View File

@ -1550,7 +1550,7 @@ func (p *TxPool) fromDB(ctx context.Context, tx kv.Tx, coreTx kv.Tx) error {
addr, txRlp := v[:20], v[20:]
txn := &TxSlot{}
_, err := parseCtx.ParseTransaction(txRlp, 0, txn, nil, false /* hasEnvelope */)
_, err := parseCtx.ParseTransaction(txRlp, 0, txn, nil, false /* hasEnvelope */, nil)
if err != nil {
return fmt.Errorf("err: %w, rlp: %x", err, txRlp)
}

View File

@ -205,7 +205,7 @@ func poolsFromFuzzBytes(rawTxNonce, rawValues, rawTips, rawFeeCap, rawSender []b
feeCap: feeCap[i%len(feeCap)],
}
txRlp := fakeRlpTx(txs.txs[i], senders.At(i%senders.Len()))
_, err := parseCtx.ParseTransaction(txRlp, 0, txs.txs[i], nil, false)
_, err := parseCtx.ParseTransaction(txRlp, 0, txs.txs[i], nil, false, nil)
if err != nil {
panic(err)
}

View File

@ -52,7 +52,6 @@ type TxParseContext struct {
sig [65]byte
withSender bool
isProtected bool
validateHash func([]byte) error
validateRlp func([]byte) error
cfg TxParsseConfig
@ -112,13 +111,12 @@ var ErrRejected = errors.New("rejected")
var ErrAlreadyKnown = errors.New("already known")
var ErrRlpTooBig = errors.New("txn rlp too big")
func (ctx *TxParseContext) ValidateHash(f func(hash []byte) error) { ctx.validateHash = f }
func (ctx *TxParseContext) ValidateRLP(f func(txnRlp []byte) error) { ctx.validateHash = f }
func (ctx *TxParseContext) ValidateRLP(f func(txnRlp []byte) error) { ctx.validateRlp = f }
func (ctx *TxParseContext) WithSender(v bool) { ctx.withSender = v }
// ParseTransaction extracts all the information from the transactions's payload (RLP) necessary to build TxSlot
// it also performs syntactic validation of the transactions
func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope bool) (p int, err error) {
func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope bool, validateHash func([]byte) error) (p int, err error) {
if len(payload) == 0 {
return 0, fmt.Errorf("%w: empty rlp", ErrParseTxn)
}
@ -380,8 +378,8 @@ func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlo
if !ctx.withSender {
return p, nil
}
if ctx.validateHash != nil {
if err := ctx.validateHash(slot.IDHash[:32]); err != nil {
if validateHash != nil {
if err := validateHash(slot.IDHash[:32]); err != nil {
return p, err
}
}

View File

@ -29,6 +29,6 @@ func FuzzParseTx(f *testing.F) {
ctx := NewTxParseContext(*u256.N1)
txn := &TxSlot{}
sender := make([]byte, 20)
_, _ = ctx.ParseTransaction(in, pos, txn, sender, false)
_, _ = ctx.ParseTransaction(in, pos, txn, sender, false, nil)
})
}

View File

@ -111,7 +111,7 @@ func TestParseTransactionRLP(t *testing.T) {
for i, tt := range testSet.tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
payload := decodeHex(tt.payloadStr)
parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */)
parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */, nil)
require.NoError(err)
require.Equal(len(payload), parseEnd)
if tt.signHashStr != "" {