fix nil pointer in fetch.go (#406)

This commit is contained in:
Alex Sharov 2022-03-31 15:13:11 +07:00 committed by GitHub
parent 5315a677b5
commit af4391d7f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 44 additions and 41 deletions

View File

@ -294,16 +294,6 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
case sentry.MessageId_POOLED_TRANSACTIONS_66, sentry.MessageId_TRANSACTIONS_66: case sentry.MessageId_POOLED_TRANSACTIONS_66, sentry.MessageId_TRANSACTIONS_66:
txs := TxSlots{} txs := TxSlots{}
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error { if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
parseContext.ValidateHash(func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
})
return nil return nil
}); err != nil { }); err != nil {
return err return err
@ -312,7 +302,16 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
switch req.Id { switch req.Id {
case sentry.MessageId_TRANSACTIONS_66: case sentry.MessageId_TRANSACTIONS_66:
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error { if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
if _, err := ParseTransactions(req.Data, 0, parseContext, &txs); err != nil { if _, err := ParseTransactions(req.Data, 0, parseContext, &txs, func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
}); err != nil {
return err return err
} }
return nil return nil
@ -321,7 +320,16 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes
} }
case sentry.MessageId_POOLED_TRANSACTIONS_66: case sentry.MessageId_POOLED_TRANSACTIONS_66:
if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error { if err := f.threadSafeParsePooledTxn(func(parseContext *TxParseContext) error {
if _, _, err := ParsePooledTransactions66(req.Data, 0, parseContext, &txs); err != nil { if _, _, err := ParsePooledTransactions66(req.Data, 0, parseContext, &txs, func(hash []byte) error {
known, err := f.pool.IdHashKnown(tx, hash)
if err != nil {
return err
}
if known {
return ErrRejected
}
return nil
}); err != nil {
return err return err
} }
return nil return nil
@ -435,7 +443,7 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client StateChangesClien
for i := range change.Txs { for i := range change.Txs {
minedTxs.txs[i] = &TxSlot{} minedTxs.txs[i] = &TxSlot{}
if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error { if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error {
_, err := parseContext.ParseTransaction(change.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i), true /* hasEnvelope */) _, err := parseContext.ParseTransaction(change.Txs[i], 0, minedTxs.txs[i], minedTxs.senders.At(i), true /* hasEnvelope */, nil)
return err return err
}); err != nil { }); err != nil {
log.Warn("stream.Recv", "err", err) log.Warn("stream.Recv", "err", err)
@ -448,7 +456,7 @@ func (f *Fetch) handleStateChanges(ctx context.Context, client StateChangesClien
for i := range change.Txs { for i := range change.Txs {
unwindTxs.txs[i] = &TxSlot{} unwindTxs.txs[i] = &TxSlot{}
if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error { if err = f.threadSafeParseStateChangeTxn(func(parseContext *TxParseContext) error {
_, err = parseContext.ParseTransaction(change.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i), true /* hasEnvelope */) _, err = parseContext.ParseTransaction(change.Txs[i], 0, unwindTxs.txs[i], unwindTxs.senders.At(i), true /* hasEnvelope */, nil)
return err return err
}); err != nil { }); err != nil {
log.Warn("stream.Recv", "err", err) log.Warn("stream.Recv", "err", err)

View File

@ -177,12 +177,6 @@ func (s *GrpcServer) Add(ctx context.Context, in *txpool_proto.AddRequest) (*txp
var slots TxSlots var slots TxSlots
parseCtx := NewTxParseContext(s.chainID) parseCtx := NewTxParseContext(s.chainID)
parseCtx.ValidateHash(func(hash []byte) error {
if known, _ := s.txPool.IdHashKnown(tx, hash); known {
return ErrAlreadyKnown
}
return nil
})
parseCtx.ValidateRLP(s.txPool.ValidateSerializedTxn) parseCtx.ValidateRLP(s.txPool.ValidateSerializedTxn)
reply := &txpool_proto.AddReply{Imported: make([]txpool_proto.ImportResult, len(in.RlpTxs)), Errors: make([]string, len(in.RlpTxs))} reply := &txpool_proto.AddReply{Imported: make([]txpool_proto.ImportResult, len(in.RlpTxs)), Errors: make([]string, len(in.RlpTxs))}
@ -192,7 +186,12 @@ func (s *GrpcServer) Add(ctx context.Context, in *txpool_proto.AddRequest) (*txp
slots.Resize(uint(j + 1)) slots.Resize(uint(j + 1))
slots.txs[j] = &TxSlot{} slots.txs[j] = &TxSlot{}
slots.isLocal[j] = true slots.isLocal[j] = true
if _, err := parseCtx.ParseTransaction(in.RlpTxs[i], 0, slots.txs[j], slots.senders.At(j), false /* hasEnvelope */); err != nil { if _, err := parseCtx.ParseTransaction(in.RlpTxs[i], 0, slots.txs[j], slots.senders.At(j), false /* hasEnvelope */, func(hash []byte) error {
if known, _ := s.txPool.IdHashKnown(tx, hash); known {
return ErrAlreadyKnown
}
return nil
}); err != nil {
if errors.Is(err, ErrAlreadyKnown) { // Noop, but need to handle to not count these if errors.Is(err, ErrAlreadyKnown) { // Noop, but need to handle to not count these
reply.Errors[i] = AlreadyKnown.String() reply.Errors[i] = AlreadyKnown.String()
reply.Imported[i] = txpool_proto.ImportResult_ALREADY_EXISTS reply.Imported[i] = txpool_proto.ImportResult_ALREADY_EXISTS

View File

@ -167,7 +167,7 @@ func EncodeTransactions(txsRlp [][]byte, encodeBuf []byte) []byte {
return encodeBuf return encodeBuf
} }
func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots) (newPos int, err error) { func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (newPos int, err error) {
pos, _, err = rlp.List(payload, pos) pos, _, err = rlp.List(payload, pos)
if err != nil { if err != nil {
return 0, err return 0, err
@ -176,7 +176,7 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx
for i := 0; pos < len(payload); i++ { for i := 0; pos < len(payload); i++ {
txSlots.Resize(uint(i + 1)) txSlots.Resize(uint(i + 1))
txSlots.txs[i] = &TxSlot{} txSlots.txs[i] = &TxSlot{}
pos, err = ctx.ParseTransaction(payload, pos, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */) pos, err = ctx.ParseTransaction(payload, pos, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */, validateHash)
if err != nil { if err != nil {
if errors.Is(err, ErrRejected) { if errors.Is(err, ErrRejected) {
txSlots.Resize(uint(i)) txSlots.Resize(uint(i))
@ -189,7 +189,7 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx
return pos, nil return pos, nil
} }
func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots) (requestID uint64, newPos int, err error) { func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (requestID uint64, newPos int, err error) {
p, _, err := rlp.List(payload, pos) p, _, err := rlp.List(payload, pos)
if err != nil { if err != nil {
return requestID, 0, err return requestID, 0, err
@ -206,7 +206,7 @@ func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txS
for i := 0; p < len(payload); i++ { for i := 0; p < len(payload); i++ {
txSlots.Resize(uint(i + 1)) txSlots.Resize(uint(i + 1))
txSlots.txs[i] = &TxSlot{} txSlots.txs[i] = &TxSlot{}
p, err = ctx.ParseTransaction(payload, p, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */) p, err = ctx.ParseTransaction(payload, p, txSlots.txs[i], txSlots.senders.At(i), true /* hasEnvelope */, validateHash)
if err != nil { if err != nil {
if errors.Is(err, ErrRejected) { if errors.Is(err, ErrRejected) {
txSlots.Resize(uint(i)) txSlots.Resize(uint(i))

View File

@ -153,7 +153,7 @@ func TestPooledTransactionsPacket66(t *testing.T) {
ctx := NewTxParseContext(*uint256.NewInt(tt.chainID)) ctx := NewTxParseContext(*uint256.NewInt(tt.chainID))
slots := &TxSlots{} slots := &TxSlots{}
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots) requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots, nil)
require.NoError(err) require.NoError(err)
require.Equal(tt.requestID, requestID) require.Equal(tt.requestID, requestID)
require.Equal(len(tt.txs), len(slots.txs)) require.Equal(len(tt.txs), len(slots.txs))
@ -170,9 +170,8 @@ func TestPooledTransactionsPacket66(t *testing.T) {
require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf)) require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf))
ctx := NewTxParseContext(*u256.N1) ctx := NewTxParseContext(*u256.N1)
ctx.validateHash = func(bytes []byte) error { return ErrRejected }
slots := &TxSlots{} slots := &TxSlots{}
requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots) requestID, _, err := ParsePooledTransactions66(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected })
require.NoError(err) require.NoError(err)
require.Equal(tt.requestID, requestID) require.Equal(tt.requestID, requestID)
require.Equal(0, len(slots.txs)) require.Equal(0, len(slots.txs))
@ -213,7 +212,7 @@ func TestTransactionsPacket(t *testing.T) {
ctx := NewTxParseContext(*uint256.NewInt(tt.chainID)) ctx := NewTxParseContext(*uint256.NewInt(tt.chainID))
slots := &TxSlots{} slots := &TxSlots{}
_, err := ParseTransactions(encodeBuf, 0, ctx, slots) _, err := ParseTransactions(encodeBuf, 0, ctx, slots, nil)
require.NoError(err) require.NoError(err)
require.Equal(len(tt.txs), len(slots.txs)) require.Equal(len(tt.txs), len(slots.txs))
for i, txn := range tt.txs { for i, txn := range tt.txs {
@ -229,9 +228,8 @@ func TestTransactionsPacket(t *testing.T) {
require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf)) require.Equal(tt.encoded, fmt.Sprintf("%x", encodeBuf))
ctx := NewTxParseContext(*u256.N1) ctx := NewTxParseContext(*u256.N1)
ctx.validateHash = func(bytes []byte) error { return ErrRejected }
slots := &TxSlots{} slots := &TxSlots{}
_, err := ParseTransactions(encodeBuf, 0, ctx, slots) _, err := ParseTransactions(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected })
require.NoError(err) require.NoError(err)
require.Equal(0, len(slots.txs)) require.Equal(0, len(slots.txs))
require.Equal(0, slots.senders.Len()) require.Equal(0, slots.senders.Len())

View File

@ -1550,7 +1550,7 @@ func (p *TxPool) fromDB(ctx context.Context, tx kv.Tx, coreTx kv.Tx) error {
addr, txRlp := v[:20], v[20:] addr, txRlp := v[:20], v[20:]
txn := &TxSlot{} txn := &TxSlot{}
_, err := parseCtx.ParseTransaction(txRlp, 0, txn, nil, false /* hasEnvelope */) _, err := parseCtx.ParseTransaction(txRlp, 0, txn, nil, false /* hasEnvelope */, nil)
if err != nil { if err != nil {
return fmt.Errorf("err: %w, rlp: %x", err, txRlp) return fmt.Errorf("err: %w, rlp: %x", err, txRlp)
} }

View File

@ -205,7 +205,7 @@ func poolsFromFuzzBytes(rawTxNonce, rawValues, rawTips, rawFeeCap, rawSender []b
feeCap: feeCap[i%len(feeCap)], feeCap: feeCap[i%len(feeCap)],
} }
txRlp := fakeRlpTx(txs.txs[i], senders.At(i%senders.Len())) txRlp := fakeRlpTx(txs.txs[i], senders.At(i%senders.Len()))
_, err := parseCtx.ParseTransaction(txRlp, 0, txs.txs[i], nil, false) _, err := parseCtx.ParseTransaction(txRlp, 0, txs.txs[i], nil, false, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -52,7 +52,6 @@ type TxParseContext struct {
sig [65]byte sig [65]byte
withSender bool withSender bool
isProtected bool isProtected bool
validateHash func([]byte) error
validateRlp func([]byte) error validateRlp func([]byte) error
cfg TxParsseConfig cfg TxParsseConfig
@ -112,13 +111,12 @@ var ErrRejected = errors.New("rejected")
var ErrAlreadyKnown = errors.New("already known") var ErrAlreadyKnown = errors.New("already known")
var ErrRlpTooBig = errors.New("txn rlp too big") var ErrRlpTooBig = errors.New("txn rlp too big")
func (ctx *TxParseContext) ValidateHash(f func(hash []byte) error) { ctx.validateHash = f } func (ctx *TxParseContext) ValidateRLP(f func(txnRlp []byte) error) { ctx.validateRlp = f }
func (ctx *TxParseContext) ValidateRLP(f func(txnRlp []byte) error) { ctx.validateHash = f }
func (ctx *TxParseContext) WithSender(v bool) { ctx.withSender = v } func (ctx *TxParseContext) WithSender(v bool) { ctx.withSender = v }
// ParseTransaction extracts all the information from the transactions's payload (RLP) necessary to build TxSlot // ParseTransaction extracts all the information from the transactions's payload (RLP) necessary to build TxSlot
// it also performs syntactic validation of the transactions // it also performs syntactic validation of the transactions
func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope bool) (p int, err error) { func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope bool, validateHash func([]byte) error) (p int, err error) {
if len(payload) == 0 { if len(payload) == 0 {
return 0, fmt.Errorf("%w: empty rlp", ErrParseTxn) return 0, fmt.Errorf("%w: empty rlp", ErrParseTxn)
} }
@ -380,8 +378,8 @@ func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlo
if !ctx.withSender { if !ctx.withSender {
return p, nil return p, nil
} }
if ctx.validateHash != nil { if validateHash != nil {
if err := ctx.validateHash(slot.IDHash[:32]); err != nil { if err := validateHash(slot.IDHash[:32]); err != nil {
return p, err return p, err
} }
} }

View File

@ -29,6 +29,6 @@ func FuzzParseTx(f *testing.F) {
ctx := NewTxParseContext(*u256.N1) ctx := NewTxParseContext(*u256.N1)
txn := &TxSlot{} txn := &TxSlot{}
sender := make([]byte, 20) sender := make([]byte, 20)
_, _ = ctx.ParseTransaction(in, pos, txn, sender, false) _, _ = ctx.ParseTransaction(in, pos, txn, sender, false, nil)
}) })
} }

View File

@ -111,7 +111,7 @@ func TestParseTransactionRLP(t *testing.T) {
for i, tt := range testSet.tests { for i, tt := range testSet.tests {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(strconv.Itoa(i), func(t *testing.T) {
payload := decodeHex(tt.payloadStr) payload := decodeHex(tt.payloadStr)
parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */) parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */, nil)
require.NoError(err) require.NoError(err)
require.Equal(len(payload), parseEnd) require.Equal(len(payload), parseEnd)
if tt.signHashStr != "" { if tt.signHashStr != "" {