diff --git a/txpool/fetch_test.go b/txpool/fetch_test.go index 2f6f3c87b..9d1d1afc4 100644 --- a/txpool/fetch_test.go +++ b/txpool/fetch_test.go @@ -18,6 +18,7 @@ package txpool import ( "context" + "encoding/hex" "sync" "testing" @@ -42,7 +43,12 @@ func TestFetch(t *testing.T) { mock.StreamWg.Wait() // Send one transaction id wg.Add(1) - errs := mock.Send(&sentry.InboundMessage{Id: sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, Data: nil, PeerId: PeerId}) + data, _ := hex.DecodeString("e1a0595e27a835cd79729ff1eeacec3120eeb6ed1464a04ec727aaca734ead961328") + errs := mock.Send(&sentry.InboundMessage{ + Id: sentry.MessageId_NEW_POOLED_TRANSACTION_HASHES_66, + Data: data, + PeerId: PeerId, + }) for i, err := range errs { if err != nil { t.Errorf("sending new pool tx hashes 66 (%d): %v", i, err) diff --git a/txpool/packets.go b/txpool/packets.go index 0049d8bed..a3c5b4016 100644 --- a/txpool/packets.go +++ b/txpool/packets.go @@ -16,7 +16,11 @@ package txpool -import "fmt" +import ( + "encoding/binary" + "fmt" + "math/bits" +) const ParseHashErrorPrefix = "parse hash payload" @@ -68,13 +72,44 @@ func ParseHashesCount(payload []byte, pos int) (int, int, error) { if dataLen%33 != 0 { return 0, 0, fmt.Errorf("%s: hashes len must be multiple of 33", ParseHashErrorPrefix) } - return dataLen / 33, dataPos + dataLen, nil + return dataLen / 33, dataPos, nil } // EncodeHashes produces RLP encoding of given number of hashes, as RLP list // It appends encoding to the given given slice (encodeBuf), reusing the space // there is there is enough capacity. -// The first returned value is rthe slice where encodinfg +// The first returned value is the slice where encodinfg func EncodeHashes(hashes []byte, count int, encodeBuf []byte) ([]byte, error) { - return nil, nil + var prefixLen int + dataLen := count * 33 + var beLen int + if dataLen < 56 { + prefixLen = 1 + } else { + beLen = (bits.Len64(uint64(dataLen)) + 7) / 8 + prefixLen = 1 + beLen + } + var encoding []byte + if total := len(encodeBuf) + dataLen + prefixLen; cap(encodeBuf) >= total { + encoding = encodeBuf[:dataLen+prefixLen] // Reuse the space in pkbuf, is it has enough capacity + } else { + encoding = make([]byte, total) + copy(encoding, encodeBuf) + } + if dataLen < 56 { + encoding[0] = 192 + byte(dataLen) + } else { + encoding[0] = 247 + byte(beLen) + binary.BigEndian.PutUint64(encoding[1:], uint64(beLen)) + copy(encoding[1:], encoding[9-beLen:9]) + } + hashP := 0 + encP := prefixLen + for i := 0; i < count; i++ { + encoding[encP] = 128 + 32 + copy(encoding[encP+1:encP+33], hashes[hashP:hashP+32]) + encP += 33 + hashP += 32 + } + return encoding, nil } diff --git a/txpool/packets_test.go b/txpool/packets_test.go index 9715d03cc..b9ab94a18 100644 --- a/txpool/packets_test.go +++ b/txpool/packets_test.go @@ -63,3 +63,39 @@ func TestParseHash(t *testing.T) { }) } } + +var hashEncodeTests = []struct { + payloadStr string + hashesStr string + hashCount int + expectedErr bool +}{ + {payloadStr: "e1a0595e27a835cd79729ff1eeacec3120eeb6ed1464a04ec727aaca734ead961328", hashesStr: "595e27a835cd79729ff1eeacec3120eeb6ed1464a04ec727aaca734ead961328", hashCount: 1, expectedErr: false}, +} + +func TestEncodeHash(t *testing.T) { + for i, tt := range hashEncodeTests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + var payload []byte + var hashes []byte + var err error + var encodeBuf []byte + if payload, err = hex.DecodeString(tt.payloadStr); err != nil { + t.Fatal(err) + } + if hashes, err = hex.DecodeString(tt.hashesStr); err != nil { + t.Fatal(err) + } + if encodeBuf, err = EncodeHashes(hashes, tt.hashCount, encodeBuf); err != nil { + if !tt.expectedErr { + t.Fatal(err) + } + } else if tt.expectedErr { + t.Fatalf("expected error when encoding") + } + if !bytes.Equal(payload, encodeBuf) { + t.Errorf("encoding expected %x, got %x", payload, encodeBuf) + } + }) + } +}