mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2024-12-22 03:30:35 +00:00
Tighten Up Snappy Framing (#7479)
* fix framing * fix up conditions * fix * clean up * change back * simpler * no need to cast * Use math.MaxInt64 * gaz, gofmt Co-authored-by: Preston Van Loon <preston@prysmaticlabs.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
parent
98a20766c9
commit
f629c72107
@ -3,6 +3,7 @@ package encoder
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
fastssz "github.com/ferranbt/fastssz"
|
||||
@ -115,33 +116,23 @@ func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}) erro
|
||||
params.BeaconNetworkConfig().MaxChunkSize,
|
||||
)
|
||||
}
|
||||
r = newBufferedReader(r)
|
||||
defer bufReaderPool.Put(r)
|
||||
|
||||
maxLen, err := e.MaxLength(int(msgLen))
|
||||
msgMax, err := e.MaxLength(msgLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
limitedRdr := io.LimitReader(r, int64(msgMax))
|
||||
r = newBufferedReader(limitedRdr)
|
||||
defer bufReaderPool.Put(r)
|
||||
|
||||
b := make([]byte, maxLen)
|
||||
numOfBytes := 0
|
||||
// Read all bytes from stream to handle multiple
|
||||
// framed chunks. Required if reading objects which
|
||||
// are larger than 65 kb.
|
||||
for numOfBytes < int(msgLen) {
|
||||
readBytes, err := r.Read(b[numOfBytes:])
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
numOfBytes += readBytes
|
||||
buf := make([]byte, msgLen)
|
||||
// Returns an error if less than msgLen bytes
|
||||
// are read. This ensures we read exactly the
|
||||
// required amount.
|
||||
_, err = io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if numOfBytes != int(msgLen) {
|
||||
return errors.Errorf("decompressed data has an unexpected length, wanted %d but got %d", msgLen, numOfBytes)
|
||||
}
|
||||
return e.doDecode(b[:numOfBytes], to)
|
||||
return e.doDecode(buf, to)
|
||||
}
|
||||
|
||||
// ProtocolSuffix returns the appropriate suffix for protocol IDs.
|
||||
@ -151,8 +142,12 @@ func (e SszNetworkEncoder) ProtocolSuffix() string {
|
||||
|
||||
// MaxLength specifies the maximum possible length of an encoded
|
||||
// chunk of data.
|
||||
func (e SszNetworkEncoder) MaxLength(length int) (int, error) {
|
||||
maxLen := snappy.MaxEncodedLen(length)
|
||||
func (e SszNetworkEncoder) MaxLength(length uint64) (int, error) {
|
||||
// Defensive check to prevent potential issues when casting to int64.
|
||||
if length > math.MaxInt64 {
|
||||
return 0, errors.Errorf("invalid length provided: %d", length)
|
||||
}
|
||||
maxLen := snappy.MaxEncodedLen(int(length))
|
||||
if maxLen < 0 {
|
||||
return 0, errors.Errorf("max encoded length is negative: %d", maxLen)
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
@ -119,7 +121,6 @@ func TestSszNetworkEncoder_DecodeWithMultipleFrames(t *testing.T) {
|
||||
err = e.DecodeWithMaxLength(buf, decoded)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) {
|
||||
e := &encoder.SszNetworkEncoder{}
|
||||
length, err := e.MaxLength(0xfffffffffff)
|
||||
@ -127,3 +128,91 @@ func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) {
|
||||
assert.Equal(t, 0, length, "Received non zero length on bad message length")
|
||||
assert.ErrorContains(t, "max encoded length is negative", err)
|
||||
}
|
||||
|
||||
func TestSszNetworkEncoder_MaxInt64(t *testing.T) {
|
||||
e := &encoder.SszNetworkEncoder{}
|
||||
length, err := e.MaxLength(math.MaxInt64 + 1)
|
||||
|
||||
assert.Equal(t, 0, length, "Received non zero length on bad message length")
|
||||
assert.ErrorContains(t, "invalid length provided", err)
|
||||
}
|
||||
|
||||
func TestSszNetworkEncoder_DecodeWithBadSnappyStream(t *testing.T) {
|
||||
st := newBadSnappyStream()
|
||||
e := &encoder.SszNetworkEncoder{}
|
||||
decoded := new(pb.Fork)
|
||||
err := e.DecodeWithMaxLength(st, decoded)
|
||||
assert.ErrorContains(t, io.EOF.Error(), err)
|
||||
}
|
||||
|
||||
type badSnappyStream struct {
|
||||
varint []byte
|
||||
header []byte
|
||||
repeat []byte
|
||||
i int
|
||||
// count how many times it was read
|
||||
counter int
|
||||
// count bytes read so far
|
||||
total int
|
||||
}
|
||||
|
||||
func newBadSnappyStream() *badSnappyStream {
|
||||
const (
|
||||
magicBody = "sNaPpY"
|
||||
magicChunk = "\xff\x06\x00\x00" + magicBody
|
||||
)
|
||||
|
||||
header := make([]byte, len(magicChunk))
|
||||
// magicChunk == chunkTypeStreamIdentifier byte ++ 3 byte little endian len(magic body) ++ 6 byte magic body
|
||||
|
||||
// header is a special chunk type, with small fixed length, to add some magic to claim it's really snappy.
|
||||
copy(header, magicChunk) // snappy library constants help us construct the common header chunk easily.
|
||||
|
||||
payload := make([]byte, 4)
|
||||
|
||||
// byte 0 is chunk type
|
||||
// Exploit any fancy ignored chunk type
|
||||
// Section 4.4 Padding (chunk type 0xfe).
|
||||
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
|
||||
payload[0] = 0xfe
|
||||
|
||||
// byte 1,2,3 are chunk length (little endian)
|
||||
payload[1] = 0
|
||||
payload[2] = 0
|
||||
payload[3] = 0
|
||||
|
||||
return &badSnappyStream{
|
||||
varint: proto.EncodeVarint(1000),
|
||||
header: header,
|
||||
repeat: payload,
|
||||
i: 0,
|
||||
counter: 0,
|
||||
total: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *badSnappyStream) Read(p []byte) (n int, err error) {
|
||||
// Stream out varint bytes first to make test happy.
|
||||
if len(b.varint) > 0 {
|
||||
copy(p, b.varint[:1])
|
||||
b.varint = b.varint[1:]
|
||||
return 1, nil
|
||||
}
|
||||
defer func() {
|
||||
b.counter += 1
|
||||
b.total += n
|
||||
}()
|
||||
if len(b.repeat) == 0 {
|
||||
panic("no bytes to repeat")
|
||||
}
|
||||
if len(b.header) > 0 {
|
||||
n = copy(p, b.header)
|
||||
b.header = b.header[n:]
|
||||
return
|
||||
}
|
||||
for n < len(p) {
|
||||
n += copy(p[n:], b.repeat[b.i:])
|
||||
b.i = (b.i + n) % len(b.repeat)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user