mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
Merge pull request #3701 from joostjager/isolate-odd-even
tlv+hop: contain odd/even logic in payload parsing
This commit is contained in:
commit
e2e94c3b6e
@ -29,6 +29,12 @@ const (
|
||||
RequiredViolation
|
||||
)
|
||||
|
||||
const (
|
||||
// customTypeStart is the start of the custom tlv type range as defined
|
||||
// in BOLT 01.
|
||||
customTypeStart = 65536
|
||||
)
|
||||
|
||||
// String returns a human-readable description of the violation as a verb.
|
||||
func (v PayloadViolation) String() string {
|
||||
switch v {
|
||||
@ -124,28 +130,6 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
|
||||
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
||||
if err != nil {
|
||||
// Promote any required type failures into ErrInvalidPayload.
|
||||
if e, required := err.(tlv.ErrUnknownRequiredType); required {
|
||||
// If the parser returned an unknown required type
|
||||
// failure, we'll first check that the payload is
|
||||
// properly formed according to our known set of
|
||||
// constraints. If an error is discovered, this
|
||||
// overrides the required type failure.
|
||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Otherwise the known constraints were applied
|
||||
// successfully, report the invalid type failure
|
||||
// returned by the parser.
|
||||
return nil, ErrInvalidPayload{
|
||||
Type: tlv.Type(e),
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: nextHop == Exit,
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -157,6 +141,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for violation of the rules for mandatory fields.
|
||||
violatingType := getMinRequiredViolation(parsedTypes)
|
||||
if violatingType != nil {
|
||||
return nil, ErrInvalidPayload{
|
||||
Type: *violatingType,
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: nextHop == Exit,
|
||||
}
|
||||
}
|
||||
|
||||
// If no MPP field was parsed, set the MPP field on the resulting
|
||||
// payload to nil.
|
||||
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
|
||||
@ -239,3 +233,35 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
func (h *Payload) MultiPath() *record.MPP {
|
||||
return h.MPP
|
||||
}
|
||||
|
||||
// getMinRequiredViolation checks for unrecognized required (even) fields in the
|
||||
// standard range and returns the lowest required type. Always returning the
|
||||
// lowest required type allows a failure message to be deterministic.
|
||||
func getMinRequiredViolation(set tlv.TypeSet) *tlv.Type {
|
||||
var (
|
||||
requiredViolation bool
|
||||
minRequiredViolationType tlv.Type
|
||||
)
|
||||
for t, known := range set {
|
||||
// If a type is even but not known to us, we cannot process the
|
||||
// payload. We are required to understand a field that we don't
|
||||
// support.
|
||||
//
|
||||
// We always accept custom fields, because a higher level
|
||||
// application may understand them.
|
||||
if known || t%2 != 0 || t >= customTypeStart {
|
||||
continue
|
||||
}
|
||||
|
||||
if !requiredViolation || t < minRequiredViolationType {
|
||||
minRequiredViolationType = t
|
||||
}
|
||||
requiredViolation = true
|
||||
}
|
||||
|
||||
if requiredViolation {
|
||||
return &minRequiredViolationType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -130,6 +130,12 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type in custom range",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00,
|
||||
0xfe, 0x00, 0x01, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid intermediate hop",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||
|
@ -12,8 +12,9 @@ import (
|
||||
// Type is an 64-bit identifier for a TLV Record.
|
||||
type Type uint64
|
||||
|
||||
// TypeSet is an unordered set of Types.
|
||||
type TypeSet map[Type]struct{}
|
||||
// TypeSet is an unordered set of Types. The map item boolean values indicate
|
||||
// whether the type that we parsed was known.
|
||||
type TypeSet map[Type]bool
|
||||
|
||||
// Encoder is a signature for methods that can encode TLV values. An error
|
||||
// should be returned if the Encoder cannot support the underlying type of val.
|
||||
|
@ -2,7 +2,6 @@ package tlv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
@ -22,15 +21,6 @@ var ErrStreamNotCanonical = errors.New("tlv stream is not canonical")
|
||||
// long to parse.
|
||||
var ErrRecordTooLarge = errors.New("record is too large")
|
||||
|
||||
// ErrUnknownRequiredType is an error returned when decoding an unknown and even
|
||||
// type from a Stream.
|
||||
type ErrUnknownRequiredType Type
|
||||
|
||||
// Error returns a human-readable description of unknown required type.
|
||||
func (t ErrUnknownRequiredType) Error() string {
|
||||
return fmt.Sprintf("unknown required type: %d", t)
|
||||
}
|
||||
|
||||
// Stream defines a TLV stream that can be used for encoding or decoding a set
|
||||
// of TLV Records.
|
||||
type Stream struct {
|
||||
@ -162,7 +152,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
var (
|
||||
typ Type
|
||||
min Type
|
||||
firstFail *Type
|
||||
recordIdx int
|
||||
overflow bool
|
||||
)
|
||||
@ -177,10 +166,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// We'll silence an EOF when zero bytes remain, meaning the
|
||||
// stream was cleanly encoded.
|
||||
case err == io.EOF:
|
||||
if firstFail == nil {
|
||||
return parsedTypes, nil
|
||||
}
|
||||
return parsedTypes, ErrUnknownRequiredType(*firstFail)
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
@ -244,31 +230,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This record type is unknown to the stream, fail if the type
|
||||
// is even meaning that we are required to understand it.
|
||||
case typ%2 == 0:
|
||||
// We'll fail immediately in the case that we aren't
|
||||
// tracking the set of parsed types.
|
||||
if parsedTypes == nil {
|
||||
return nil, ErrUnknownRequiredType(typ)
|
||||
}
|
||||
|
||||
// Otherwise, we'll track the first such failure and
|
||||
// allow parsing to continue. If no other types of
|
||||
// errors are encountered, the first failure will be
|
||||
// returned as an ErrUnknownRequiredType so that the
|
||||
// full set of included types can be returned.
|
||||
if firstFail == nil {
|
||||
failTyp := typ
|
||||
firstFail = &failTyp
|
||||
}
|
||||
|
||||
// With the failure type recorded, we'll simply discard
|
||||
// the remainder of the record as if it were optional.
|
||||
// The first failure will be returned after reaching the
|
||||
// stopping condition.
|
||||
fallthrough
|
||||
|
||||
// Otherwise, the record type is unknown and is odd, discard the
|
||||
// number of bytes specified by length.
|
||||
default:
|
||||
@ -289,7 +250,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// Record the successfully decoded or ignored type if the
|
||||
// caller provided an initialized TypeSet.
|
||||
if parsedTypes != nil {
|
||||
parsedTypes[typ] = struct{}{}
|
||||
parsedTypes[typ] = ok
|
||||
}
|
||||
|
||||
// Update our record index so that we can begin our next search
|
||||
|
@ -12,46 +12,35 @@ type parsedTypeTest struct {
|
||||
name string
|
||||
encode []tlv.Type
|
||||
decode []tlv.Type
|
||||
expErr error
|
||||
expParsedTypes tlv.TypeSet
|
||||
}
|
||||
|
||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||
func TestParsedTypes(t *testing.T) {
|
||||
const (
|
||||
firstReqType = 0
|
||||
knownType = 1
|
||||
unknownType = 3
|
||||
secondReqType = 4
|
||||
secondKnownType = 4
|
||||
)
|
||||
|
||||
tests := []parsedTypeTest{
|
||||
{
|
||||
name: "known optional and unknown optional",
|
||||
name: "known and unknown",
|
||||
encode: []tlv.Type{knownType, unknownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
expParsedTypes: tlv.TypeSet{
|
||||
unknownType: false,
|
||||
knownType: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown required and known optional",
|
||||
encode: []tlv.Type{firstReqType, knownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
name: "known and missing known",
|
||||
encode: []tlv.Type{knownType},
|
||||
decode: []tlv.Type{knownType, secondKnownType},
|
||||
expParsedTypes: tlv.TypeSet{
|
||||
knownType: true,
|
||||
},
|
||||
{
|
||||
name: "unknown required and unknown optional",
|
||||
encode: []tlv.Type{unknownType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
||||
},
|
||||
{
|
||||
name: "unknown required and known required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
decode: []tlv.Type{secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
{
|
||||
name: "two unknown required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
}
|
||||
|
||||
@ -92,16 +81,10 @@ func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||
bytes.NewReader(b.Bytes()),
|
||||
)
|
||||
if !reflect.DeepEqual(err, test.expErr) {
|
||||
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
||||
}
|
||||
|
||||
// Assert that all encoded types are included in the set of parsed
|
||||
// types.
|
||||
for _, typ := range test.encode {
|
||||
if _, ok := parsedTypes[typ]; !ok {
|
||||
t.Fatalf("encoded type %d should be in parsed types",
|
||||
typ)
|
||||
if err != nil {
|
||||
t.Fatalf("error decoding: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(parsedTypes, test.expParsedTypes) {
|
||||
t.Fatalf("error mismatch on parsed types")
|
||||
}
|
||||
}
|
||||
|
@ -203,26 +203,6 @@ var tlvDecodingFailureTests = []struct {
|
||||
},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0x12, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x12),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xfd, 0x01, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x102),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x01000002),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x0100000000000002),
|
||||
},
|
||||
{
|
||||
name: "greater than encoding length for n1's amt",
|
||||
bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
@ -340,12 +320,6 @@ var tlvDecodingFailureTests = []struct {
|
||||
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "unknown required type or n1",
|
||||
bytes: []byte{0x00, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x00),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's cltvDelta",
|
||||
bytes: []byte{0xfd, 0x00, 0x0fe, 0x00},
|
||||
@ -364,12 +338,6 @@ var tlvDecodingFailureTests = []struct {
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "unknown even field for n1's namespace",
|
||||
bytes: []byte{0x0a, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x0a),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "valid records but invalid ordering",
|
||||
bytes: []byte{0x02, 0x08,
|
||||
|
Loading…
Reference in New Issue
Block a user