mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
multi: return parsed types from payload
To separate blinded route parsing from payload parsing, we need to return the parsed types map so that we can properly validate blinded data payloads against what we saw in the onion.
This commit is contained in:
parent
1e6fae37f7
commit
2029a06918
@ -117,7 +117,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
payload1, err := NewPayloadFromReader(r, finalPayload)
|
||||
payload1, _, err := NewPayloadFromReader(r, finalPayload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -146,7 +146,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
payload2, err := NewPayloadFromReader(&b, finalPayload)
|
||||
payload2, _, err := NewPayloadFromReader(&b, finalPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, payload1, payload2)
|
||||
|
@ -106,11 +106,13 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
|
||||
// Otherwise, if this is the TLV payload, then we'll make a new stream
|
||||
// to decode only what we need to make routing decisions.
|
||||
case sphinx.PayloadTLV:
|
||||
return NewPayloadFromReader(
|
||||
payload, _, err := NewPayloadFromReader(
|
||||
bytes.NewReader(r.processedPacket.Payload.Payload),
|
||||
r.processedPacket.Action == sphinx.ExitNode,
|
||||
)
|
||||
|
||||
return payload, err
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown sphinx payload type: %v",
|
||||
r.processedPacket.Payload.Type)
|
||||
|
@ -133,11 +133,14 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
|
||||
}
|
||||
}
|
||||
|
||||
// NewPayloadFromReader builds a new Hop from the passed io.Reader. The reader
|
||||
// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns
|
||||
// a map of all the types that were found in the payload. The reader
|
||||
// should correspond to the bytes encapsulated in a TLV onion payload. The
|
||||
// final hop bool signals that this payload was the final packet parsed by
|
||||
// sphinx.
|
||||
func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
|
||||
func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload,
|
||||
map[tlv.Type][]byte, error) {
|
||||
|
||||
var (
|
||||
cid uint64
|
||||
amt uint64
|
||||
@ -162,27 +165,27 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
|
||||
record.NewTotalAmtMsatBlinded(&totalAmtMsat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Since this data is provided by a potentially malicious peer, pass it
|
||||
// into the P2P decoding variant.
|
||||
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Validate whether the sender properly included or omitted tlv records
|
||||
// in accordance with BOLT 04.
|
||||
err = ValidateParsedPayloadTypes(parsedTypes, finalHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Check for violation of the rules for mandatory fields.
|
||||
violatingType := getMinRequiredViolation(parsedTypes)
|
||||
if violatingType != nil {
|
||||
return nil, ErrInvalidPayload{
|
||||
return nil, nil, ErrInvalidPayload{
|
||||
Type: *violatingType,
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: finalHop,
|
||||
@ -229,7 +232,7 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
|
||||
blindingPoint: blindingPoint,
|
||||
customRecords: customRecords,
|
||||
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
|
||||
}, nil
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
// ForwardingInfo returns the basic parameters required for HTLC forwarding,
|
||||
|
@ -479,7 +479,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
testChildIndex = uint32(9)
|
||||
)
|
||||
|
||||
p, err := hop.NewPayloadFromReader(
|
||||
p, _, err := hop.NewPayloadFromReader(
|
||||
bytes.NewReader(test.payload), test.isFinalHop,
|
||||
)
|
||||
if !reflect.DeepEqual(test.expErr, err) {
|
||||
|
Loading…
Reference in New Issue
Block a user