diff --git a/tlv/stream.go b/tlv/stream.go index d716d7d07..3d353ad22 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -111,11 +111,11 @@ func (s *Stream) Encode(w io.Writer) error { return nil } -// Decode deserializes TLV Stream from the passed io.Reader. The Stream will -// inspect each record that is parsed and check to see if it has a corresponding -// Record to facilitate deserialization of that field. If the record is unknown, -// the Stream will discard the record's bytes and proceed to the subsequent -// record. +// Decode deserializes TLV Stream from the passed io.Reader for non-P2P +// settings. The Stream will inspect each record that is parsed and check to +// see if it has a corresponding Record to facilitate deserialization of that +// field. If the record is unknown, the Stream will discard the record's bytes +// and proceed to the subsequent record. // // Each record has the following format: // @@ -137,7 +137,14 @@ func (s *Stream) Encode(w io.Writer) error { // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { - _, err := s.decode(r, nil) + _, err := s.decode(r, nil, false) + return err +} + +// DecodeP2P is identical to Decode except that the maximum record size is +// capped at 65535. +func (s *Stream) DecodeP2P(r io.Reader) error { + _, err := s.decode(r, nil, true) return err } @@ -145,13 +152,23 @@ func (s *Stream) Decode(r io.Reader) error { // TypeMap containing the types of all records that were decoded or ignored from // the stream. func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeMap, error) { - return s.decode(r, make(TypeMap)) + return s.decode(r, make(TypeMap), false) +} + +// DecodeWithParsedTypesP2P is identical to DecodeWithParsedTypes except that +// the record size is capped at 65535. This should only be called from a p2p +// setting where untrusted input is being deserialized. +func (s *Stream) DecodeWithParsedTypesP2P(r io.Reader) (TypeMap, error) { + return s.decode(r, make(TypeMap), true) } // decode is a helper function that performs the basis of stream decoding. If // the caller needs the set of parsed types, it must provide an initialized -// parsedTypes, otherwise the returned TypeMap will be nil. -func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) { +// parsedTypes, otherwise the returned TypeMap will be nil. If the p2p bool is +// true, then the record size is capped at 65535. Otherwise, it is not capped. +func (s *Stream) decode(r io.Reader, parsedTypes TypeMap, p2p bool) (TypeMap, + error) { + var ( typ Type min Type @@ -204,8 +221,8 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) { // Place a soft limit on the size of a sane record, which // prevents malicious encoders from causing us to allocate an // unbounded amount of memory when decoding variable-sized - // fields. - if length > MaxRecordSize { + // fields. This is only checked when the p2p bool is true. + if p2p && length > MaxRecordSize { return nil, ErrRecordTooLarge } diff --git a/tlv/stream_test.go b/tlv/stream_test.go index 60a3db73b..0a74bbf60 100644 --- a/tlv/stream_test.go +++ b/tlv/stream_test.go @@ -248,3 +248,64 @@ func makeBigSizeFormatTlvStream(t *testing.T, vUint32 *uint32, return ts } + +// TestDecodeP2P tests that the p2p variants of the stream decode functions +// work with small records and fail with large records. +func TestDecodeP2P(t *testing.T) { + t.Parallel() + + const ( + smallType tlv.Type = 8 + largeType tlv.Type = 10 + ) + + var ( + smallBytes = []byte{ + 0x08, // tlv type = 8 + 0x10, // length = 16 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + } + + largeBytes = []byte{ + 0x0a, // tlv type = 10 + 0xfe, 0x00, 0x01, 0x00, 0x00, // length = 65536 + } + ) + + // Verify the expected behavior for the large type. + s, err := tlv.NewStream(tlv.MakePrimitiveRecord(largeType, &[]byte{})) + require.NoError(t, err) + + // Decoding with either of the p2p stream decoders should fail with the + // record too large error. + buf := bytes.NewBuffer(largeBytes) + require.Equal(t, s.DecodeP2P(buf), tlv.ErrRecordTooLarge) + + buf2 := bytes.NewBuffer(largeBytes) + _, err = s.DecodeWithParsedTypesP2P(buf2) + require.Equal(t, err, tlv.ErrRecordTooLarge) + + // Extend largeBytes with a payload of 65536 bytes so that the non-p2p + // decoders can successfully decode it. + largeSlice := make([]byte, 65542) + copy(largeSlice[:6], largeBytes) + buf3 := bytes.NewBuffer(largeSlice) + require.NoError(t, s.Decode(buf3)) + + buf4 := bytes.NewBuffer(largeSlice) + _, err = s.DecodeWithParsedTypes(buf4) + require.NoError(t, err) + + // Now create a new stream and assert that the p2p-variants can decode + // small types. + s2, err := tlv.NewStream(tlv.MakePrimitiveRecord(smallType, &[]byte{})) + require.NoError(t, err) + + buf5 := bytes.NewBuffer(smallBytes) + require.NoError(t, s2.DecodeP2P(buf5)) + + buf6 := bytes.NewBuffer(smallBytes) + _, err = s2.DecodeWithParsedTypesP2P(buf6) + require.NoError(t, err) +} diff --git a/tlv/tlv_test.go b/tlv/tlv_test.go index 3dd520b32..35e23cab7 100644 --- a/tlv/tlv_test.go +++ b/tlv/tlv_test.go @@ -370,12 +370,6 @@ var tlvDecodingFailureTests = []struct { bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00}, expErr: tlv.ErrStreamNotCanonical, }, - { - name: "absurd record length", - bytes: []byte{0xfd, 0x01, 0x91, 0xfe, 0xff, 0xff, 0xff, 0xff}, - expErr: tlv.ErrRecordTooLarge, - skipN2: true, - }, } // TestTLVDecodingSuccess asserts that the TLV parser fails to decode invalid