From f629c7210788467e0904aaa1c5ae4525e5d53b1e Mon Sep 17 00:00:00 2001
From: Nishant Das
Date: Sat, 10 Oct 2020 00:50:18 +0800
Subject: [PATCH] Tighten Up Snappy Framing (#7479)

* fix framing

* fix up conditions

* fix

* clean up

* change back

* simpler

* no need to cast

* Use math.MaxInt64

* gaz, gofmt

Co-authored-by: Preston Van Loon
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
---
 beacon-chain/p2p/encoder/ssz.go      | 43 ++++++-------
 beacon-chain/p2p/encoder/ssz_test.go | 91 +++++++++++++++++++++++++++-
 2 files changed, 109 insertions(+), 25 deletions(-)

diff --git a/beacon-chain/p2p/encoder/ssz.go b/beacon-chain/p2p/encoder/ssz.go
index 5a1e6ad28..48019bcc0 100644
--- a/beacon-chain/p2p/encoder/ssz.go
+++ b/beacon-chain/p2p/encoder/ssz.go
@@ -3,6 +3,7 @@ package encoder
 import (
 	"fmt"
 	"io"
+	"math"
 	"sync"
 
 	fastssz "github.com/ferranbt/fastssz"
@@ -115,33 +116,23 @@ func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}) erro
 			params.BeaconNetworkConfig().MaxChunkSize,
 		)
 	}
-	r = newBufferedReader(r)
-	defer bufReaderPool.Put(r)
-
-	maxLen, err := e.MaxLength(int(msgLen))
+	msgMax, err := e.MaxLength(msgLen)
 	if err != nil {
 		return err
 	}
+	limitedRdr := io.LimitReader(r, int64(msgMax))
+	r = newBufferedReader(limitedRdr)
+	defer bufReaderPool.Put(r)
 
-	b := make([]byte, maxLen)
-	numOfBytes := 0
-	// Read all bytes from stream to handle multiple
-	// framed chunks. Required if reading objects which
-	// are larger than 65 kb.
-	for numOfBytes < int(msgLen) {
-		readBytes, err := r.Read(b[numOfBytes:])
-		if err == io.EOF {
-			break
-		}
-		if err != nil {
-			return err
-		}
-		numOfBytes += readBytes
+	buf := make([]byte, msgLen)
+	// io.ReadFull returns an error if fewer than msgLen
+	// bytes are read, which ensures we read exactly the
+	// required amount.
+	_, err = io.ReadFull(r, buf)
+	if err != nil {
+		return err
 	}
-	if numOfBytes != int(msgLen) {
-		return errors.Errorf("decompressed data has an unexpected length, wanted %d but got %d", msgLen, numOfBytes)
-	}
-	return e.doDecode(b[:numOfBytes], to)
+	return e.doDecode(buf, to)
 }
 
 // ProtocolSuffix returns the appropriate suffix for protocol IDs.
@@ -151,8 +142,12 @@ func (e SszNetworkEncoder) ProtocolSuffix() string {
 
 // MaxLength specifies the maximum possible length of an encoded
 // chunk of data.
-func (e SszNetworkEncoder) MaxLength(length int) (int, error) {
-	maxLen := snappy.MaxEncodedLen(length)
+func (e SszNetworkEncoder) MaxLength(length uint64) (int, error) {
+	// Defensive check to prevent potential issues when casting to int64.
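+	// A length above math.MaxInt64 would wrap to a negative value
+	// in the int(length) conversion below on 64-bit platforms.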
+	if length > math.MaxInt64 {
+		return 0, errors.Errorf("invalid length provided: %d", length)
+	}
+	maxLen := snappy.MaxEncodedLen(int(length))
 	if maxLen < 0 {
 		return 0, errors.Errorf("max encoded length is negative: %d", maxLen)
 	}
diff --git a/beacon-chain/p2p/encoder/ssz_test.go b/beacon-chain/p2p/encoder/ssz_test.go
index 93550c525..d1eec32bd 100644
--- a/beacon-chain/p2p/encoder/ssz_test.go
+++ b/beacon-chain/p2p/encoder/ssz_test.go
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"encoding/binary"
 	"fmt"
+	"io"
+	"math"
 	"testing"
 
 	"github.com/gogo/protobuf/proto"
@@ -119,7 +121,6 @@ func TestSszNetworkEncoder_DecodeWithMultipleFrames(t *testing.T) {
 	err = e.DecodeWithMaxLength(buf, decoded)
 	assert.NoError(t, err)
 }
-
 func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) {
 	e := &encoder.SszNetworkEncoder{}
 	length, err := e.MaxLength(0xfffffffffff)
@@ -127,3 +128,91 @@ func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) {
 	assert.Equal(t, 0, length, "Received non zero length on bad message length")
 	assert.ErrorContains(t, "max encoded length is negative", err)
 }
+
+func TestSszNetworkEncoder_MaxInt64(t *testing.T) {
+	e := &encoder.SszNetworkEncoder{}
+	length, err := e.MaxLength(math.MaxInt64 + 1)
+
+	assert.Equal(t, 0, length, "Received non zero length on bad message length")
+	assert.ErrorContains(t, "invalid length provided", err)
+}
+
+func TestSszNetworkEncoder_DecodeWithBadSnappyStream(t *testing.T) {
+	st := newBadSnappyStream()
+	e := &encoder.SszNetworkEncoder{}
+	decoded := new(pb.Fork)
+	err := e.DecodeWithMaxLength(st, decoded)
+	assert.ErrorContains(t, io.EOF.Error(), err)
+}
+
+type badSnappyStream struct {
+	varint []byte
+	header []byte
+	repeat []byte
+	i      int
+	// count how many times it was read
+	counter int
+	// count bytes read so far
+	total int
+}
+
+func newBadSnappyStream() *badSnappyStream {
+	const (
+		magicBody  = "sNaPpY"
+		magicChunk = "\xff\x06\x00\x00" + magicBody
+	)
+
+	header := make([]byte, len(magicChunk))
+	// magicChunk == chunkTypeStreamIdentifier byte ++ 3 byte little endian len(magic body) ++ 6 byte magic body
+
+	// header is a special chunk type, with small fixed length, to add some magic to claim it's really snappy.
+	copy(header, magicChunk) // snappy library constants help us construct the common header chunk easily.

+	payload := make([]byte, 4)
+
+	// byte 0 is chunk type
+	// Exploit any fancy ignored chunk type
+	// Section 4.4 Padding (chunk type 0xfe).
+	// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
+	payload[0] = 0xfe
+
+	// byte 1,2,3 are chunk length (little endian)
+	payload[1] = 0
+	payload[2] = 0
+	payload[3] = 0
+
+	return &badSnappyStream{
+		varint:  proto.EncodeVarint(1000),
+		header:  header,
+		repeat:  payload,
+		i:       0,
+		counter: 0,
+		total:   0,
+	}
+}
+
+func (b *badSnappyStream) Read(p []byte) (n int, err error) {
+	// Stream out varint bytes first to make test happy.
+	if len(b.varint) > 0 {
+		copy(p, b.varint[:1])
+		b.varint = b.varint[1:]
+		return 1, nil
+	}
+	defer func() {
+		b.counter += 1
+		b.total += n
+	}()
+	if len(b.repeat) == 0 {
+		panic("no bytes to repeat")
+	}
+	if len(b.header) > 0 {
+		n = copy(p, b.header)
+		b.header = b.header[n:]
+		return
+	}
+	for n < len(p) {
+		n += copy(p[n:], b.repeat[b.i:])
+		b.i = (b.i + n) % len(b.repeat)
+	}
+	return
+}
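
Why the framing needed tightening: a framed snappy stream may legally contain padding chunks (type 0xfe) and reserved skippable chunks (types 0x80-0xfd) that a decoder discards without emitting any decompressed bytes. A malicious peer can therefore stream such chunks forever, and the old read loop, which waited for msgLen decompressed bytes and only stopped on EOF, would never terminate. The sketch below shows the shape of the attack and of the post-patch defense in isolation. It is illustrative only, not Prysm code: it assumes the github.com/golang/snappy package, and endlessPadding is a hypothetical stand-in for the badSnappyStream helper above.

// Sketch (illustrative only): bounding a malicious framed snappy stream.
package main

import (
	"fmt"
	"io"

	"github.com/golang/snappy"
)

// endlessPadding models a hostile peer: after a valid snappy stream
// identifier it sends nothing but zero-length padding chunks (type 0xfe),
// which a framed snappy reader skips without ever producing data.
type endlessPadding struct {
	header []byte // stream identifier chunk, sent once
	i      int    // rotating offset into the padding chunk
}

func (e *endlessPadding) Read(p []byte) (int, error) {
	if len(e.header) > 0 {
		n := copy(p, e.header)
		e.header = e.header[n:]
		return n, nil
	}
	pad := []byte{0xfe, 0x00, 0x00, 0x00} // padding chunk header, body length 0
	n := 0
	for n < len(p) {
		c := copy(p[n:], pad[e.i:])
		e.i = (e.i + c) % len(pad)
		n += c
	}
	return n, nil // never EOF: the stream goes on forever
}

func main() {
	src := &endlessPadding{header: []byte("\xff\x06\x00\x00sNaPpY")}

	// Pre-patch, a loop reading from snappy.NewReader(src) until msgLen
	// decompressed bytes arrived would spin forever. Post-patch, the raw
	// byte budget is capped at the worst-case compressed size first.
	const msgLen = 1000                           // size claimed by the varint prefix
	msgMax := int64(snappy.MaxEncodedLen(msgLen)) // compressed-side cap
	r := snappy.NewReader(io.LimitReader(src, msgMax))

	buf := make([]byte, msgLen)
	_, err := io.ReadFull(r, buf)
	fmt.Println(err) // an EOF-flavored error: budget exhausted, zero bytes decoded
}

Capping the compressed side with io.LimitReader at MaxLength(msgLen) and demanding exactly msgLen decompressed bytes with io.ReadFull pins both ends of the pipeline: a peer can neither stall the decoder with skippable chunks nor deliver a different amount of data than its varint length prefix advertised.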