mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
tlv/stream: parse entire stream to find all required types
This commit is contained in:
parent
d08e8ddd61
commit
e85aaa45f6
@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
var (
|
||||
typ Type
|
||||
min Type
|
||||
firstFail *Type
|
||||
recordIdx int
|
||||
overflow bool
|
||||
)
|
||||
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// We'll silence an EOF when zero bytes remain, meaning the
|
||||
// stream was cleanly encoded.
|
||||
case err == io.EOF:
|
||||
return parsedTypes, nil
|
||||
if firstFail == nil {
|
||||
return parsedTypes, nil
|
||||
}
|
||||
return parsedTypes, ErrUnknownRequiredType(*firstFail)
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// This record type is unknown to the stream, fail if the type
|
||||
// is even meaning that we are required to understand it.
|
||||
case typ%2 == 0:
|
||||
return nil, ErrUnknownRequiredType(typ)
|
||||
// We'll fail immediately in the case that we aren't
|
||||
// tracking the set of parsed types.
|
||||
if parsedTypes == nil {
|
||||
return nil, ErrUnknownRequiredType(typ)
|
||||
}
|
||||
|
||||
// Otherwise, we'll track the first such failure and
|
||||
// allow parsing to continue. If no other types of
|
||||
// errors are encountered, the first failure will be
|
||||
// returned as an ErrUnknownRequiredType so that the
|
||||
// full set of included types can be returned.
|
||||
if firstFail == nil {
|
||||
failTyp := typ
|
||||
firstFail = &failTyp
|
||||
}
|
||||
|
||||
// With the failure type recorded, we'll simply discard
|
||||
// the remainder of the record as if it were optional.
|
||||
// The first failure will be returned after reaching the
|
||||
// stopping condition.
|
||||
fallthrough
|
||||
|
||||
// Otherwise, the record type is unknown and is odd, discard the
|
||||
// number of bytes specified by length.
|
||||
|
@ -2,50 +2,106 @@ package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type parsedTypeTest struct {
|
||||
name string
|
||||
encode []tlv.Type
|
||||
decode []tlv.Type
|
||||
expErr error
|
||||
}
|
||||
|
||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||
func TestParsedTypes(t *testing.T) {
|
||||
const (
|
||||
knownType = 1
|
||||
unknownType = 3
|
||||
firstReqType = 0
|
||||
knownType = 1
|
||||
unknownType = 3
|
||||
secondReqType = 4
|
||||
)
|
||||
|
||||
// Construct a stream that will encode two types, one that will be known
|
||||
// to the decoder and another that will be unknown.
|
||||
encStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
||||
)
|
||||
tests := []parsedTypeTest{
|
||||
{
|
||||
name: "known optional and unknown optional",
|
||||
encode: []tlv.Type{knownType, unknownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
},
|
||||
{
|
||||
name: "unknown required and known optional",
|
||||
encode: []tlv.Type{firstReqType, knownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
{
|
||||
name: "unknown required and unknown optional",
|
||||
encode: []tlv.Type{unknownType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
||||
},
|
||||
{
|
||||
name: "unknown required and known required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
decode: []tlv.Type{secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
{
|
||||
name: "two unknown required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testParsedTypes(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
||||
encRecords := make([]tlv.Record, 0, len(test.encode))
|
||||
for _, typ := range test.encode {
|
||||
encRecords = append(
|
||||
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||
)
|
||||
}
|
||||
|
||||
decRecords := make([]tlv.Record, 0, len(test.decode))
|
||||
for _, typ := range test.decode {
|
||||
decRecords = append(
|
||||
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||
)
|
||||
}
|
||||
|
||||
// Construct a stream that will encode the test's set of types.
|
||||
encStream := tlv.MustNewStream(encRecords...)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := encStream.Encode(&b); err != nil {
|
||||
t.Fatalf("unable to encode stream: %v", err)
|
||||
}
|
||||
|
||||
// Create a stream that will parse only the known type.
|
||||
decStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
)
|
||||
// Create a stream that will parse a subset of the test's types.
|
||||
decStream := tlv.MustNewStream(decRecords...)
|
||||
|
||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||
bytes.NewReader(b.Bytes()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode stream: %v", err)
|
||||
if !reflect.DeepEqual(err, test.expErr) {
|
||||
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
||||
}
|
||||
|
||||
// Assert that both the known and unknown types are included in the set
|
||||
// of parsed types.
|
||||
if _, ok := parsedTypes[knownType]; !ok {
|
||||
t.Fatalf("known type %d should be in parsed types", knownType)
|
||||
}
|
||||
if _, ok := parsedTypes[unknownType]; !ok {
|
||||
t.Fatalf("unknown type %d should be in parsed types",
|
||||
unknownType)
|
||||
// Assert that all encoded types are included in the set of parsed
|
||||
// types.
|
||||
for _, typ := range test.encode {
|
||||
if _, ok := parsedTypes[typ]; !ok {
|
||||
t.Fatalf("encoded type %d should be in parsed types",
|
||||
typ)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user