package tlv import ( "bytes" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/stretchr/testify/require" ) // harness decodes the passed data, re-encodes it, and verifies that the // re-encoded data matches the original data. func harness(t *testing.T, data []byte, encode Encoder, decode Decoder, val interface{}, decodeLen uint64) { if uint64(len(data)) > decodeLen { return } r := bytes.NewReader(data) var buf [8]byte if err := decode(r, val, &buf, decodeLen); err != nil { return } var b bytes.Buffer require.NoError(t, encode(&b, val, &buf)) // Use bytes.Equal instead of require.Equal so that nil and empty slices // are considered equal. require.True( t, bytes.Equal(data, b.Bytes()), "%v != %v", data, b.Bytes(), ) } func FuzzUint8(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint8 harness(t, data, EUint8, DUint8, &val, 1) }) } func FuzzUint16(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint16 harness(t, data, EUint16, DUint16, &val, 2) }) } func FuzzUint32(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint32 harness(t, data, EUint32, DUint32, &val, 4) }) } func FuzzUint64(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint64 harness(t, data, EUint64, DUint64, &val, 8) }) } func FuzzBytes32(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val [32]byte harness(t, data, EBytes32, DBytes32, &val, 32) }) } func FuzzBytes33(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val [33]byte harness(t, data, EBytes33, DBytes33, &val, 33) }) } func FuzzBytes64(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val [64]byte harness(t, data, EBytes64, DBytes64, &val, 64) }) } func FuzzPubKey(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val *btcec.PublicKey harness(t, data, EPubKey, DPubKey, &val, 33) }) } func FuzzVarBytes(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val []byte harness(t, data, EVarBytes, DVarBytes, &val, uint64(len(data))) }) } func FuzzBool(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val bool harness(t, data, EBool, DBool, &val, 1) }) } // bigSizeHarness works the same as harness, except that it compares decoded // values instead of encoded values. We do this because DBigSize may leave some // bytes unparsed from data, causing the encoded data to be shorter than the // original. func bigSizeHarness(t *testing.T, data []byte, val1, val2 interface{}) { if len(data) > 9 { return } r := bytes.NewReader(data) var buf [8]byte if err := DBigSize(r, val1, &buf, 9); err != nil { return } var b bytes.Buffer require.NoError(t, EBigSize(&b, val1, &buf)) r2 := bytes.NewReader(b.Bytes()) require.NoError(t, DBigSize(r2, val2, &buf, 9)) require.Equal(t, val1, val2) } func FuzzBigSize32(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val1, val2 uint32 bigSizeHarness(t, data, &val1, &val2) }) } func FuzzBigSize64(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val1, val2 uint64 bigSizeHarness(t, data, &val1, &val2) }) } func FuzzTUint16(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint16 for decodeLen := 0; decodeLen <= 2; decodeLen++ { harness( t, data, ETUint16, DTUint16, &val, uint64(decodeLen), ) } }) } func FuzzTUint32(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint32 for decodeLen := 0; decodeLen <= 4; decodeLen++ { harness( t, data, ETUint32, DTUint32, &val, uint64(decodeLen), ) } }) } func FuzzTUint64(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var val uint64 for decodeLen := 0; decodeLen <= 8; decodeLen++ { harness( t, data, ETUint64, DTUint64, &val, uint64(decodeLen), ) } }) } // encodeParsedTypes re-encodes TLVs decoded from a stream, using the // parsedTypes and decodedRecords produced during decoding. This function // requires that each record in decodedRecords has a type number equivalent to // its index in the slice. func encodeParsedTypes(t *testing.T, parsedTypes TypeMap, decodedRecords []Record) []byte { var encodeRecords []Record for typ, val := range parsedTypes { // If typ is present in decodedRecords, use the decoded value. if typ < Type(len(decodedRecords)) { encodeRecords = append( encodeRecords, decodedRecords[typ], ) continue } // Otherwise, typ is not present in decodedRecords, and we must // create a new one. val := val encodeRecords = append( encodeRecords, MakePrimitiveRecord(typ, &val), ) } SortRecords(encodeRecords) encodeStream := MustNewStream(encodeRecords...) var b bytes.Buffer require.NoError(t, encodeStream.Encode(&b)) return b.Bytes() } // FuzzStream does two stream decode-encode cycles on the fuzzer data and checks // that the encoded values match. func FuzzStream(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var ( u8 uint8 u16 uint16 u32 uint32 u64 uint64 b32 [32]byte b33 [33]byte b64 [64]byte pk *btcec.PublicKey b []byte bs32 uint32 bs64 uint64 tu16 uint16 tu32 uint32 tu64 uint64 boolean bool ) sizeTU16 := func() uint64 { return SizeTUint16(tu16) } sizeTU32 := func() uint64 { return SizeTUint32(tu32) } sizeTU64 := func() uint64 { return SizeTUint64(tu64) } // We deliberately set each record's type number to its index in // the slice, as this simplifies the re-encoding logic in // encodeParsedTypes(). decodeRecords := []Record{ MakePrimitiveRecord(0, &u8), MakePrimitiveRecord(1, &u16), MakePrimitiveRecord(2, &u32), MakePrimitiveRecord(3, &u64), MakePrimitiveRecord(4, &b32), MakePrimitiveRecord(5, &b33), MakePrimitiveRecord(6, &b64), MakePrimitiveRecord(7, &pk), MakePrimitiveRecord(8, &b), MakeBigSizeRecord(9, &bs32), MakeBigSizeRecord(10, &bs64), MakeDynamicRecord( 11, &tu16, sizeTU16, ETUint16, DTUint16, ), MakeDynamicRecord( 12, &tu32, sizeTU32, ETUint32, DTUint32, ), MakeDynamicRecord( 13, &tu64, sizeTU64, ETUint64, DTUint64, ), MakePrimitiveRecord(14, &boolean), } decodeStream := MustNewStream(decodeRecords...) r := bytes.NewReader(data) // Use the P2P decoding method to avoid OOMs from large lengths // in the fuzzer TLV data. parsedTypes, err := decodeStream.DecodeWithParsedTypesP2P(r) if err != nil { return } encoded := encodeParsedTypes(t, parsedTypes, decodeRecords) r2 := bytes.NewReader(encoded) decodeStream2 := MustNewStream(decodeRecords...) // The P2P decoding method is not required here since we're now // decoding TLV data that we created (not the fuzzer). parsedTypes2, err := decodeStream2.DecodeWithParsedTypes(r2) require.NoError(t, err) encoded2 := encodeParsedTypes(t, parsedTypes2, decodeRecords) require.Equal(t, encoded, encoded2) }) }