multi: return parsed types from payload

To separate blinded route parsing from payload parsing, we need to
return the parsed types map so that we can properly validate blinded
data payloads against what we saw in the onion.
This commit is contained in:
Carla Kirk-Cohen 2024-04-02 10:59:57 -04:00
parent 1e6fae37f7
commit 2029a06918
No known key found for this signature in database
GPG Key ID: 4CA7FE54A6213C91
4 changed files with 16 additions and 11 deletions

View File

@ -117,7 +117,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
r := bytes.NewReader(data)
payload1, err := NewPayloadFromReader(r, finalPayload)
payload1, _, err := NewPayloadFromReader(r, finalPayload)
if err != nil {
return
}
@ -146,7 +146,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
require.NoError(t, err)
}
payload2, err := NewPayloadFromReader(&b, finalPayload)
payload2, _, err := NewPayloadFromReader(&b, finalPayload)
require.NoError(t, err)
require.Equal(t, payload1, payload2)

View File

@ -106,11 +106,13 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
// Otherwise, if this is the TLV payload, then we'll make a new stream
// to decode only what we need to make routing decisions.
case sphinx.PayloadTLV:
return NewPayloadFromReader(
payload, _, err := NewPayloadFromReader(
bytes.NewReader(r.processedPacket.Payload.Payload),
r.processedPacket.Action == sphinx.ExitNode,
)
return payload, err
default:
return nil, fmt.Errorf("unknown sphinx payload type: %v",
r.processedPacket.Payload.Type)

View File

@ -133,11 +133,14 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
}
}
// NewPayloadFromReader builds a new Hop from the passed io.Reader. The reader
// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns
// a map of all the types that were found in the payload. The reader
// should correspond to the bytes encapsulated in a TLV onion payload. The
// final hop bool signals that this payload was the final packet parsed by
// sphinx.
func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload,
map[tlv.Type][]byte, error) {
var (
cid uint64
amt uint64
@ -162,27 +165,27 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
record.NewTotalAmtMsatBlinded(&totalAmtMsat),
)
if err != nil {
return nil, err
return nil, nil, err
}
// Since this data is provided by a potentially malicious peer, pass it
// into the P2P decoding variant.
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
if err != nil {
return nil, err
return nil, nil, err
}
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(parsedTypes, finalHop)
if err != nil {
return nil, err
return nil, nil, err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return nil, ErrInvalidPayload{
return nil, nil, ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
@ -229,7 +232,7 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) {
blindingPoint: blindingPoint,
customRecords: customRecords,
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
}, nil
}, nil, nil
}
// ForwardingInfo returns the basic parameters required for HTLC forwarding,

View File

@ -479,7 +479,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
testChildIndex = uint32(9)
)
p, err := hop.NewPayloadFromReader(
p, _, err := hop.NewPayloadFromReader(
bytes.NewReader(test.payload), test.isFinalHop,
)
if !reflect.DeepEqual(test.expErr, err) {