From 206f773a9b3e854dd735cd422ff5c5753c8f9247 Mon Sep 17 00:00:00 2001 From: Ononiwu Maureen Date: Thu, 5 Oct 2023 06:12:54 +0100 Subject: [PATCH] tlv: Added bool to primitive Signed-off-by: Ononiwu Maureen --- tlv/fuzz_test.go | 37 +++++++++++++++++++------------- tlv/primitive.go | 43 +++++++++++++++++++++++++++++++++++++ tlv/primitive_test.go | 49 +++++++++++++++++++++++++++---------------- tlv/record.go | 5 +++++ 4 files changed, 102 insertions(+), 32 deletions(-) diff --git a/tlv/fuzz_test.go b/tlv/fuzz_test.go index ecaab4577..0a38aa66e 100644 --- a/tlv/fuzz_test.go +++ b/tlv/fuzz_test.go @@ -97,6 +97,13 @@ func FuzzVarBytes(f *testing.F) { }) } +func FuzzBool(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + var val bool + harness(t, data, EBool, DBool, &val, 1) + }) +} + // bigSizeHarness works the same as harness, except that it compares decoded // values instead of encoded values. We do this because DBigSize may leave some // bytes unparsed from data, causing the encoded data to be shorter than the @@ -210,20 +217,21 @@ func encodeParsedTypes(t *testing.T, parsedTypes TypeMap, func FuzzStream(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { var ( - u8 uint8 - u16 uint16 - u32 uint32 - u64 uint64 - b32 [32]byte - b33 [33]byte - b64 [64]byte - pk *btcec.PublicKey - b []byte - bs32 uint32 - bs64 uint64 - tu16 uint16 - tu32 uint32 - tu64 uint64 + u8 uint8 + u16 uint16 + u32 uint32 + u64 uint64 + b32 [32]byte + b33 [33]byte + b64 [64]byte + pk *btcec.PublicKey + b []byte + bs32 uint32 + bs64 uint64 + tu16 uint16 + tu32 uint32 + tu64 uint64 + boolean bool ) sizeTU16 := func() uint64 { @@ -260,6 +268,7 @@ func FuzzStream(f *testing.F) { MakeDynamicRecord( 13, &tu64, sizeTU64, ETUint64, DTUint64, ), + MakePrimitiveRecord(14, &boolean), } decodeStream := MustNewStream(decodeRecords...) diff --git a/tlv/primitive.go b/tlv/primitive.go index fbf9a2ca0..51b1d72bc 100644 --- a/tlv/primitive.go +++ b/tlv/primitive.go @@ -2,6 +2,7 @@ package tlv import ( "encoding/binary" + "errors" "fmt" "io" @@ -143,6 +144,33 @@ func EUint64T(w io.Writer, val uint64, buf *[8]byte) error { return err } +// EBool encodes a boolean. An error is returned if val is not a boolean. +func EBool(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*bool); ok { + if *i { + buf[0] = 1 + } else { + buf[0] = 0 + } + _, err := w.Write(buf[:1]) + return err + } + return NewTypeForEncodingErr(val, "bool") +} + +// EBoolT encodes a bool val to the provided io.Writer. This method is exposed +// so that encodings for custom bool-like types can be created without +// incurring an extra heap allocation. +func EBoolT(w io.Writer, val bool, buf *[8]byte) error { + if val { + buf[0] = 1 + } else { + buf[0] = 0 + } + _, err := w.Write(buf[:1]) + return err +} + // DUint8 is a Decoder for uint8 values. An error is returned if val is not a // *uint8. func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { @@ -195,6 +223,21 @@ func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { return NewTypeForDecodingErr(val, "uint64", l, 8) } +// DBool decodes a boolean. An error is returned if val is not a boolean. +func DBool(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*bool); ok && l == 1 { + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return err + } + if buf[0] != 0 && buf[0] != 1 { + return errors.New("corrupted data") + } + *i = buf[0] != 0 + return nil + } + return NewTypeForDecodingErr(val, "bool", l, 1) +} + // EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not // a *[32]byte. func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error { diff --git a/tlv/primitive_test.go b/tlv/primitive_test.go index 7034d4258..ba84320e6 100644 --- a/tlv/primitive_test.go +++ b/tlv/primitive_test.go @@ -17,15 +17,16 @@ var testPK, _ = btcec.ParsePubKey([]byte{0x02, }) type primitive struct { - u8 byte - u16 uint16 - u32 uint32 - u64 uint64 - b32 [32]byte - b33 [33]byte - b64 [64]byte - pk *btcec.PublicKey - bytes []byte + u8 byte + u16 uint16 + u32 uint32 + u64 uint64 + b32 [32]byte + b33 [33]byte + b64 [64]byte + pk *btcec.PublicKey + bytes []byte + boolean bool } // TestWrongEncodingType asserts that all primitives encoders will fail with a @@ -41,6 +42,7 @@ func TestWrongEncodingType(t *testing.T) { tlv.EBytes64, tlv.EPubKey, tlv.EVarBytes, + tlv.EBool, } // We'll use an int32 since it is not a primitive type, which should @@ -73,6 +75,7 @@ func TestWrongDecodingType(t *testing.T) { tlv.DBytes64, tlv.DPubKey, tlv.DVarBytes, + tlv.DBool, } // We'll use an int32 since it is not a primitive type, which should @@ -108,15 +111,16 @@ type fieldDecoder struct { // to decode the output and arrive at the same fields. func TestPrimitiveEncodings(t *testing.T) { prim := primitive{ - u8: 0x01, - u16: 0x0201, - u32: 0x02000001, - u64: 0x0200000000000001, - b32: [32]byte{0x02, 0x01}, - b33: [33]byte{0x03, 0x01}, - b64: [64]byte{0x02, 0x01}, - pk: testPK, - bytes: []byte{0xaa, 0xbb}, + u8: 0x01, + u16: 0x0201, + u32: 0x02000001, + u64: 0x0200000000000001, + b32: [32]byte{0x02, 0x01}, + b33: [33]byte{0x03, 0x01}, + b64: [64]byte{0x02, 0x01}, + pk: testPK, + bytes: []byte{0xaa, 0xbb}, + boolean: true, } encoders := []fieldEncoder{ @@ -156,6 +160,10 @@ func TestPrimitiveEncodings(t *testing.T) { val: &prim.bytes, encoder: tlv.EVarBytes, }, + { + val: &prim.boolean, + encoder: tlv.EBool, + }, } // First we'll encode the primitive fields into a buffer. @@ -222,6 +230,11 @@ func TestPrimitiveEncodings(t *testing.T) { decoder: tlv.DVarBytes, size: 2, }, + { + val: &prim2.boolean, + decoder: tlv.DBool, + size: 1, + }, } for _, field := range decoders { diff --git a/tlv/record.go b/tlv/record.go index 474908ebf..7d3575f20 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -104,6 +104,11 @@ func MakePrimitiveRecord(typ Type, val interface{}) Record { decoder Decoder ) switch e := val.(type) { + case *bool: + staticSize = 1 + encoder = EBool + decoder = DBool + case *uint8: staticSize = 1 encoder = EUint8