lnd/tlv/fuzz_test.go

299 lines
6.9 KiB
Go
Raw Normal View History

package tlv
import (
"bytes"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/stretchr/testify/require"
)
// harness decodes the passed data, re-encodes it, and verifies that the
// re-encoded data matches the original data.
//
// Inputs longer than decodeLen, and inputs the decoder rejects, are silently
// skipped. The property under test is only that every input the decoder
// accepts re-encodes byte-for-byte. val is the destination the decoder writes
// into and the encoder then reads from.
func harness(t *testing.T, data []byte, encode Encoder, decode Decoder,
	val any, decodeLen uint64) {

	// Attribute assertion failures to the calling fuzz target rather than
	// to this helper.
	t.Helper()

	if uint64(len(data)) > decodeLen {
		return
	}

	// buf is handed to both the decoder and the encoder.
	var buf [8]byte

	r := bytes.NewReader(data)
	if err := decode(r, val, &buf, decodeLen); err != nil {
		return
	}

	var b bytes.Buffer
	require.NoError(t, encode(&b, val, &buf))

	// Use bytes.Equal instead of require.Equal so that nil and empty slices
	// are considered equal.
	require.True(
		t, bytes.Equal(data, b.Bytes()), "%v != %v", data, b.Bytes(),
	)
}
// FuzzUint8 round-trips fuzzer input through the uint8 codec using a
// one-byte decode length.
func FuzzUint8(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EUint8, DUint8, new(uint8), 1)
	})
}
// FuzzUint16 round-trips fuzzer input through the uint16 codec using a
// two-byte decode length.
func FuzzUint16(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EUint16, DUint16, new(uint16), 2)
	})
}
// FuzzUint32 round-trips fuzzer input through the uint32 codec using a
// four-byte decode length.
func FuzzUint32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EUint32, DUint32, new(uint32), 4)
	})
}
// FuzzUint64 round-trips fuzzer input through the uint64 codec using an
// eight-byte decode length.
func FuzzUint64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EUint64, DUint64, new(uint64), 8)
	})
}
// FuzzBytes32 round-trips fuzzer input through the fixed 32-byte array
// codec.
func FuzzBytes32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EBytes32, DBytes32, new([32]byte), 32)
	})
}
// FuzzBytes33 round-trips fuzzer input through the fixed 33-byte array
// codec.
func FuzzBytes33(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EBytes33, DBytes33, new([33]byte), 33)
	})
}
// FuzzBytes64 round-trips fuzzer input through the fixed 64-byte array
// codec.
func FuzzBytes64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EBytes64, DBytes64, new([64]byte), 64)
	})
}
func FuzzPubKey(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
var val *btcec.PublicKey
harness(t, data, EPubKey, DPubKey, &val, 33)
})
}
// FuzzVarBytes round-trips fuzzer input through the variable-length byte
// slice codec. The decode length is always the full input length, so every
// input is attempted.
func FuzzVarBytes(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		n := uint64(len(input))
		harness(t, input, EVarBytes, DVarBytes, new([]byte), n)
	})
}
// FuzzBool round-trips fuzzer input through the bool codec using a
// one-byte decode length.
func FuzzBool(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		harness(t, input, EBool, DBool, new(bool), 1)
	})
}
// bigSizeHarness works the same as harness, except that it compares decoded
// values instead of encoded values. We do this because DBigSize may leave some
// bytes unparsed from data, causing the encoded data to be shorter than the
// original.
//
// val1 receives the value decoded from data, and val2 receives the value
// decoded from val1's re-encoding. Callers in this file pass pointers to two
// variables of the same type (*uint32 or *uint64). Inputs that are too long,
// or that DBigSize rejects, are skipped.
func bigSizeHarness(t *testing.T, data []byte, val1, val2 any) {
	// Attribute assertion failures to the calling fuzz target rather than
	// to this helper.
	t.Helper()

	// maxBigSizeLen is the longest BigSize encoding: one 0xff prefix byte
	// followed by an 8-byte value.
	const maxBigSizeLen = 9

	if len(data) > maxBigSizeLen {
		return
	}

	var buf [8]byte

	r := bytes.NewReader(data)
	if err := DBigSize(r, val1, &buf, maxBigSizeLen); err != nil {
		return
	}

	var b bytes.Buffer
	require.NoError(t, EBigSize(&b, val1, &buf))

	r2 := bytes.NewReader(b.Bytes())
	require.NoError(t, DBigSize(r2, val2, &buf, maxBigSizeLen))

	require.Equal(t, val1, val2)
}
// FuzzBigSize32 checks that a BigSize value decoded into a uint32 survives
// an encode/decode round trip.
func FuzzBigSize32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		first, second := new(uint32), new(uint32)
		bigSizeHarness(t, input, first, second)
	})
}
// FuzzBigSize64 checks that a BigSize value decoded into a uint64 survives
// an encode/decode round trip.
func FuzzBigSize64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		first, second := new(uint64), new(uint64)
		bigSizeHarness(t, input, first, second)
	})
}
// FuzzTUint16 runs the truncated uint16 codec against the same fuzzer input
// with each decode length from 0 through 2 bytes.
func FuzzTUint16(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded uint16
		for n := uint64(0); n <= 2; n++ {
			harness(t, input, ETUint16, DTUint16, &decoded, n)
		}
	})
}
// FuzzTUint32 runs the truncated uint32 codec against the same fuzzer input
// with each decode length from 0 through 4 bytes.
func FuzzTUint32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded uint32
		for n := uint64(0); n <= 4; n++ {
			harness(t, input, ETUint32, DTUint32, &decoded, n)
		}
	})
}
// FuzzTUint64 runs the truncated uint64 codec against the same fuzzer input
// with each decode length from 0 through 8 bytes.
func FuzzTUint64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded uint64
		for n := uint64(0); n <= 8; n++ {
			harness(t, input, ETUint64, DTUint64, &decoded, n)
		}
	})
}
// encodeParsedTypes re-encodes TLVs decoded from a stream, using the
// parsedTypes and decodedRecords produced during decoding. This function
// requires that each record in decodedRecords has a type number equivalent to
// its index in the slice.
//
// A type below len(decodedRecords) is re-encoded from the matching decoded
// record. Any other type in parsedTypes is re-encoded from its map value
// through a newly created primitive record. The encoded stream bytes are
// returned, and an encoding error fails the test.
func encodeParsedTypes(t *testing.T, parsedTypes TypeMap,
	decodedRecords []Record) []byte {

	// Attribute assertion failures to the calling fuzz target rather than
	// to this helper.
	t.Helper()

	// Every entry of parsedTypes contributes exactly one record.
	encodeRecords := make([]Record, 0, len(parsedTypes))
	for typ, val := range parsedTypes {
		// If typ is present in decodedRecords, use the decoded value.
		if typ < Type(len(decodedRecords)) {
			encodeRecords = append(
				encodeRecords, decodedRecords[typ],
			)
			continue
		}

		// Otherwise, typ is not present in decodedRecords, and we must
		// create a new one. Copy val so that each record points at its
		// own variable. Before Go 1.22 the range variable is shared by
		// all iterations.
		val := val
		encodeRecords = append(
			encodeRecords, MakePrimitiveRecord(typ, &val),
		)
	}

	// Map iteration order is random, so sort the records before building
	// the stream.
	SortRecords(encodeRecords)
	encodeStream := MustNewStream(encodeRecords...)

	var b bytes.Buffer
	require.NoError(t, encodeStream.Encode(&b))

	return b.Bytes()
}
// FuzzStream does two stream decode-encode cycles on the fuzzer data and checks
// that the encoded values match.
func FuzzStream(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
var (
u8 uint8
u16 uint16
u32 uint32
u64 uint64
b32 [32]byte
b33 [33]byte
b64 [64]byte
pk *btcec.PublicKey
b []byte
bs32 uint32
bs64 uint64
tu16 uint16
tu32 uint32
tu64 uint64
boolean bool
)
sizeTU16 := func() uint64 {
return SizeTUint16(tu16)
}
sizeTU32 := func() uint64 {
return SizeTUint32(tu32)
}
sizeTU64 := func() uint64 {
return SizeTUint64(tu64)
}
// We deliberately set each record's type number to its index in
// the slice, as this simplifies the re-encoding logic in
// encodeParsedTypes().
decodeRecords := []Record{
MakePrimitiveRecord(0, &u8),
MakePrimitiveRecord(1, &u16),
MakePrimitiveRecord(2, &u32),
MakePrimitiveRecord(3, &u64),
MakePrimitiveRecord(4, &b32),
MakePrimitiveRecord(5, &b33),
MakePrimitiveRecord(6, &b64),
MakePrimitiveRecord(7, &pk),
MakePrimitiveRecord(8, &b),
MakeBigSizeRecord(9, &bs32),
MakeBigSizeRecord(10, &bs64),
MakeDynamicRecord(
11, &tu16, sizeTU16, ETUint16, DTUint16,
),
MakeDynamicRecord(
12, &tu32, sizeTU32, ETUint32, DTUint32,
),
MakeDynamicRecord(
13, &tu64, sizeTU64, ETUint64, DTUint64,
),
MakePrimitiveRecord(14, &boolean),
}
decodeStream := MustNewStream(decodeRecords...)
r := bytes.NewReader(data)
// Use the P2P decoding method to avoid OOMs from large lengths
// in the fuzzer TLV data.
parsedTypes, err := decodeStream.DecodeWithParsedTypesP2P(r)
if err != nil {
return
}
encoded := encodeParsedTypes(t, parsedTypes, decodeRecords)
r2 := bytes.NewReader(encoded)
decodeStream2 := MustNewStream(decodeRecords...)
// The P2P decoding method is not required here since we're now
// decoding TLV data that we created (not the fuzzer).
parsedTypes2, err := decodeStream2.DecodeWithParsedTypes(r2)
require.NoError(t, err)
encoded2 := encodeParsedTypes(t, parsedTypes2, decodeRecords)
require.Equal(t, encoded, encoded2)
})
}