multi: handle all blinding point validation in ValidateParsedPayloadTypes

This commit moves all our validation related to the presence of fields
into ValidateParsedPayloadTypes so that we can handle them in a single
place. We draw the distinction between:
- Validation of the payload (and the context within it's being parsed,
  final hop / blinded hop etc)
- Processing and validation of encrypted data, where we perform
  additional cryptographic operations and validate that the fields
  contained in the blob are valid.

This helps draw the line more clearly between the two validation types,
rather than splitting some payload-releated blinded hop processing
into the encrypted data processing part. The downside of this approach
(vs doing the blinded path payload check _after_ payload validation)
is that we have to pass additional context into payload validation
(ie, whether we got a blinding point in our UpdateAddHtlc - as we
already do for isFinalHop).
This commit is contained in:
Carla Kirk-Cohen 2024-04-22 14:06:17 -04:00
parent c2c0158c84
commit 4d051b4170
No known key found for this signature in database
GPG key ID: 4CA7FE54A6213C91
5 changed files with 174 additions and 75 deletions

View file

@ -97,19 +97,37 @@ func hopFromPayload(p *Payload) (*route.Hop, uint64) {
// FuzzPayloadFinal fuzzes final hop payloads, providing the additional context // FuzzPayloadFinal fuzzes final hop payloads, providing the additional context
// that the hop should be final (which is usually obtained by the structure // that the hop should be final (which is usually obtained by the structure
// of the sphinx packet). // of the sphinx packet) for the case where a blinding point was provided in
func FuzzPayloadFinal(f *testing.F) { // UpdateAddHtlc.
fuzzPayload(f, true) func FuzzPayloadFinalBlinding(f *testing.F) {
fuzzPayload(f, true, true)
}
// FuzzPayloadFinal fuzzes final hop payloads, providing the additional context
// that the hop should be final (which is usually obtained by the structure
// of the sphinx packet) for the case where no blinding point was provided in
// UpdateAddHtlc.
func FuzzPayloadFinalNoBlinding(f *testing.F) {
fuzzPayload(f, true, false)
} }
// FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the // FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the
// additional context that a hop should be intermediate (which is usually // additional context that a hop should be intermediate (which is usually
// obtained by the structure of the sphinx packet). // obtained by the structure of the sphinx packet) for the case where a
func FuzzPayloadIntermediate(f *testing.F) { // blinding point was provided in UpdateAddHtlc.
fuzzPayload(f, false) func FuzzPayloadIntermediateBlinding(f *testing.F) {
fuzzPayload(f, false, true)
} }
func fuzzPayload(f *testing.F, finalPayload bool) { // FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the
// additional context that a hop should be intermediate (which is usually
// obtained by the structure of the sphinx packet) for the case where no
// blinding point was provided in UpdateAddHtlc.
func FuzzPayloadIntermediateNoBlinding(f *testing.F) {
fuzzPayload(f, false, false)
}
func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
f.Fuzz(func(t *testing.T, data []byte) { f.Fuzz(func(t *testing.T, data []byte) {
if len(data) > sphinx.MaxPayloadSize { if len(data) > sphinx.MaxPayloadSize {
return return
@ -117,7 +135,9 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
payload1, _, err := NewPayloadFromReader(r, finalPayload) payload1, _, err := NewPayloadFromReader(
r, finalPayload, updateAddBlinded,
)
if err != nil { if err != nil {
return return
} }
@ -146,7 +166,9 @@ func fuzzPayload(f *testing.F, finalPayload bool) {
require.NoError(t, err) require.NoError(t, err)
} }
payload2, _, err := NewPayloadFromReader(&b, finalPayload) payload2, _, err := NewPayloadFromReader(
&b, finalPayload, updateAddBlinded,
)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, payload1, payload2) require.Equal(t, payload1, payload2)

View file

@ -17,6 +17,11 @@ import (
var ( var (
// ErrDecodeFailed is returned when we can't decode blinded data. // ErrDecodeFailed is returned when we can't decode blinded data.
ErrDecodeFailed = errors.New("could not decode blinded data") ErrDecodeFailed = errors.New("could not decode blinded data")
// ErrNoBlindingPoint is returned when we have not provided a blinding
// point for a validated payload with encrypted data set.
ErrNoBlindingPoint = errors.New("no blinding point set for validated " +
"blinded hop")
) )
// Iterator is an interface that abstracts away the routing information // Iterator is an interface that abstracts away the routing information
@ -109,7 +114,7 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
isFinal := r.processedPacket.Action == sphinx.ExitNode isFinal := r.processedPacket.Action == sphinx.ExitNode
payload, parsed, err := NewPayloadFromReader( payload, parsed, err := NewPayloadFromReader(
bytes.NewReader(r.processedPacket.Payload.Payload), bytes.NewReader(r.processedPacket.Payload.Payload),
isFinal, isFinal, r.blindingKit.UpdateAddBlinding.IsSome(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -182,35 +187,16 @@ type BlindingKit struct {
IncomingAmount lnwire.MilliSatoshi IncomingAmount lnwire.MilliSatoshi
} }
// validateBlindingPoint validates that only one blinding point is present for // getBlindingPoint returns either the payload or updateAddHtlc blinding point,
// the hop and returns the relevant one. // assuming that validation that these values are appropriately set has already
func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey, // been handled elsewhere.
isFinalHop bool) (*btcec.PublicKey, error) { func (b *BlindingKit) getBlindingPoint(payloadBlinding *btcec.PublicKey) (
*btcec.PublicKey, error) {
// Bolt 04: if encrypted_recipient_data is present:
// - if blinding_point (in update add) is set:
// - MUST error if current_blinding_point is set (in payload)
// - otherwise:
// - MUST return an error if current_blinding_point is not present
// (in payload)
payloadBlindingSet := payloadBlinding != nil payloadBlindingSet := payloadBlinding != nil
updateBlindingSet := b.UpdateAddBlinding.IsSome() updateBlindingSet := b.UpdateAddBlinding.IsSome()
switch { switch {
case !(payloadBlindingSet || updateBlindingSet):
return nil, ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
case payloadBlindingSet && updateBlindingSet:
return nil, ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
case payloadBlindingSet: case payloadBlindingSet:
return payloadBlinding, nil return payloadBlinding, nil
@ -223,9 +209,10 @@ func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey,
} }
return pk.Val, nil return pk.Val, nil
}
return nil, fmt.Errorf("expected blinded point set") default:
return nil, ErrNoBlindingPoint
}
} }
// DecryptAndValidateFwdInfo performs all operations required to decrypt and // DecryptAndValidateFwdInfo performs all operations required to decrypt and
@ -235,11 +222,10 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
*ForwardingInfo, error) { *ForwardingInfo, error) {
// We expect this function to be called when we have encrypted data // We expect this function to be called when we have encrypted data
// present, and a blinding key is set either in the payload or the // present, and expect validation to already have ensured that a
// blinding key is set either in the payload or the
// update_add_htlc message. // update_add_htlc message.
blindingPoint, err := b.validateBlindingPoint( blindingPoint, err := b.getBlindingPoint(payload.blindingPoint)
payload.blindingPoint, isFinalHop,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -216,24 +216,10 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{ {
name: "no blinding point", name: "no blinding point",
data: validData, data: validData,
processor: &mockProcessor{}, processor: &mockProcessor{},
expectedErr: ErrInvalidPayload{ expectedErr: ErrNoBlindingPoint,
Type: record.BlindingPointOnionType,
Violation: OmittedViolation,
},
},
{
name: "both blinding points",
data: validData,
updateAddBlinding: &btcec.PublicKey{},
payloadBlinding: &btcec.PublicKey{},
processor: &mockProcessor{},
expectedErr: ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: IncludedViolation,
},
}, },
{ {
name: "decryption failed", name: "decryption failed",
@ -265,12 +251,19 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) {
}, },
}, },
{ {
name: "valid", name: "valid using update add",
updateAddBlinding: &btcec.PublicKey{}, updateAddBlinding: &btcec.PublicKey{},
data: validData, data: validData,
processor: &mockProcessor{}, processor: &mockProcessor{},
expectedErr: nil, expectedErr: nil,
}, },
{
name: "valid using payload",
payloadBlinding: &btcec.PublicKey{},
data: validData,
processor: &mockProcessor{},
expectedErr: nil,
},
} }
for _, testCase := range tests { for _, testCase := range tests {

View file

@ -138,8 +138,8 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
// should correspond to the bytes encapsulated in a TLV onion payload. The // should correspond to the bytes encapsulated in a TLV onion payload. The
// final hop bool signals that this payload was the final packet parsed by // final hop bool signals that this payload was the final packet parsed by
// sphinx. // sphinx.
func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, func NewPayloadFromReader(r io.Reader, finalHop,
map[tlv.Type][]byte, error) { updateAddBlinding bool) (*Payload, map[tlv.Type][]byte, error) {
var ( var (
cid uint64 cid uint64
@ -177,7 +177,9 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload,
// Validate whether the sender properly included or omitted tlv records // Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04. // in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(parsedTypes, finalHop) err = ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -259,7 +261,7 @@ func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet {
// boolean should be true if the payload was parsed for an exit hop. The // boolean should be true if the payload was parsed for an exit hop. The
// requirements for this method are described in BOLT 04. // requirements for this method are described in BOLT 04.
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
isFinalHop bool) error { isFinalHop, updateAddBlinding bool) error {
_, hasAmt := parsedTypes[record.AmtOnionType] _, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType]
@ -267,6 +269,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
_, hasMPP := parsedTypes[record.MPPOnionType] _, hasMPP := parsedTypes[record.MPPOnionType]
_, hasAMP := parsedTypes[record.AMPOnionType] _, hasAMP := parsedTypes[record.AMPOnionType]
_, hasEncryptedData := parsedTypes[record.EncryptedDataOnionType] _, hasEncryptedData := parsedTypes[record.EncryptedDataOnionType]
_, hasBlinding := parsedTypes[record.BlindingPointOnionType]
// All cleartext hops (including final hop) and the final hop in a // All cleartext hops (including final hop) and the final hop in a
// blinded path require the forwading amount and expiry TLVs to be set. // blinded path require the forwading amount and expiry TLVs to be set.
@ -277,6 +280,32 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
needNextHop := !(hasEncryptedData || isFinalHop) needNextHop := !(hasEncryptedData || isFinalHop)
switch { switch {
// Both blinding point being set is invalid.
case hasBlinding && updateAddBlinding:
return ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// If encrypted data is not provided, blinding points should not be
// set.
case !hasEncryptedData && (hasBlinding || updateAddBlinding):
return ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
// If encrypted data is present, we require that one blinding point
// is set.
case hasEncryptedData && !(hasBlinding || updateAddBlinding):
return ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// Hops that need forwarding info must include an amount to forward. // Hops that need forwarding info must include an amount to forward.
case needFwdInfo && !hasAmt: case needFwdInfo && !hasAmt:
return ErrInvalidPayload{ return ErrInvalidPayload{

View file

@ -26,6 +26,7 @@ type decodePayloadTest struct {
name string name string
payload []byte payload []byte
isFinalHop bool isFinalHop bool
updateAddBlinded bool
expErr error expErr error
expCustomRecords map[uint64][]byte expCustomRecords map[uint64][]byte
shouldHaveMPP bool shouldHaveMPP bool
@ -271,8 +272,9 @@ var decodePayloadTests = []decodePayloadTest{
}, },
}, },
{ {
name: "intermediate hop with encrypted data", name: "intermediate hop with encrypted data",
isFinalHop: false, isFinalHop: false,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// encrypted data // encrypted data
0x0a, 0x03, 0x03, 0x02, 0x01, 0x0a, 0x03, 0x03, 0x02, 0x01,
@ -365,8 +367,9 @@ var decodePayloadTests = []decodePayloadTest{
shouldHaveTotalAmt: true, shouldHaveTotalAmt: true,
}, },
{ {
name: "final blinded hop with total amount", name: "final blinded hop with total amount",
isFinalHop: true, isFinalHop: true,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// amount // amount
0x02, 0x00, 0x02, 0x00,
@ -378,8 +381,9 @@ var decodePayloadTests = []decodePayloadTest{
shouldHaveEncData: true, shouldHaveEncData: true,
}, },
{ {
name: "final blinded missing amt", name: "final blinded missing amt",
isFinalHop: true, isFinalHop: true,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// cltv // cltv
0x04, 0x00, 0x04, 0x00,
@ -394,8 +398,9 @@ var decodePayloadTests = []decodePayloadTest{
}, },
}, },
{ {
name: "final blinded missing cltv", name: "final blinded missing cltv",
isFinalHop: true, isFinalHop: true,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// amount // amount
0x02, 0x00, 0x02, 0x00,
@ -410,8 +415,9 @@ var decodePayloadTests = []decodePayloadTest{
}, },
}, },
{ {
name: "intermediate blinded has amount", name: "intermediate blinded has amount",
isFinalHop: false, isFinalHop: false,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// amount // amount
0x02, 0x00, 0x02, 0x00,
@ -425,8 +431,9 @@ var decodePayloadTests = []decodePayloadTest{
}, },
}, },
{ {
name: "intermediate blinded has expiry", name: "intermediate blinded has expiry",
isFinalHop: false, isFinalHop: false,
updateAddBlinded: true,
payload: []byte{ payload: []byte{
// cltv // cltv
0x04, 0x00, 0x04, 0x00,
@ -439,6 +446,67 @@ var decodePayloadTests = []decodePayloadTest{
FinalHop: false, FinalHop: false,
}, },
}, },
{
name: "update add blinding no data",
isFinalHop: false,
payload: []byte{
// cltv
0x04, 0x00,
},
updateAddBlinded: true,
expErr: hop.ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: hop.OmittedViolation,
FinalHop: false,
},
},
{
name: "onion blinding point no data",
isFinalHop: false,
payload: append([]byte{
// blinding point (type / length)
0x0c, 0x21,
},
// blinding point (value)
testPubKey.SerializeCompressed()...,
),
expErr: hop.ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: hop.OmittedViolation,
FinalHop: false,
},
},
{
name: "encrypted data no blinding",
isFinalHop: false,
payload: []byte{
// encrypted data
0x0a, 0x03, 0x03, 0x02, 0x01,
},
expErr: hop.ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: hop.IncludedViolation,
},
},
{
name: "both blinding points",
isFinalHop: false,
updateAddBlinded: true,
payload: append([]byte{
// encrypted data
0x0a, 0x03, 0x03, 0x02, 0x01,
// blinding point (type / length)
0x0c, 0x21,
},
// blinding point (value)
testPubKey.SerializeCompressed()...,
),
expErr: hop.ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: hop.IncludedViolation,
FinalHop: false,
},
},
} }
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
@ -481,6 +549,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
p, _, err := hop.NewPayloadFromReader( p, _, err := hop.NewPayloadFromReader(
bytes.NewReader(test.payload), test.isFinalHop, bytes.NewReader(test.payload), test.isFinalHop,
test.updateAddBlinded,
) )
if !reflect.DeepEqual(test.expErr, err) { if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v", t.Fatalf("expected error mismatch, want: %v, got: %v",