diff --git a/htlcswitch/hop/fuzz_test.go b/htlcswitch/hop/fuzz_test.go index cbbe88260..f5fe648f5 100644 --- a/htlcswitch/hop/fuzz_test.go +++ b/htlcswitch/hop/fuzz_test.go @@ -97,19 +97,37 @@ func hopFromPayload(p *Payload) (*route.Hop, uint64) { // FuzzPayloadFinal fuzzes final hop payloads, providing the additional context // that the hop should be final (which is usually obtained by the structure -// of the sphinx packet). -func FuzzPayloadFinal(f *testing.F) { - fuzzPayload(f, true) +// of the sphinx packet) for the case where a blinding point was provided in +// UpdateAddHtlc. +func FuzzPayloadFinalBlinding(f *testing.F) { + fuzzPayload(f, true, true) +} + +// FuzzPayloadFinal fuzzes final hop payloads, providing the additional context +// that the hop should be final (which is usually obtained by the structure +// of the sphinx packet) for the case where no blinding point was provided in +// UpdateAddHtlc. +func FuzzPayloadFinalNoBlinding(f *testing.F) { + fuzzPayload(f, true, false) } // FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the // additional context that a hop should be intermediate (which is usually -// obtained by the structure of the sphinx packet). -func FuzzPayloadIntermediate(f *testing.F) { - fuzzPayload(f, false) +// obtained by the structure of the sphinx packet) for the case where a +// blinding point was provided in UpdateAddHtlc. +func FuzzPayloadIntermediateBlinding(f *testing.F) { + fuzzPayload(f, false, true) } -func fuzzPayload(f *testing.F, finalPayload bool) { +// FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the +// additional context that a hop should be intermediate (which is usually +// obtained by the structure of the sphinx packet) for the case where no +// blinding point was provided in UpdateAddHtlc. +func FuzzPayloadIntermediateNoBlinding(f *testing.F) { + fuzzPayload(f, false, false) +} + +func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) { f.Fuzz(func(t *testing.T, data []byte) { if len(data) > sphinx.MaxPayloadSize { return @@ -117,7 +135,9 @@ func fuzzPayload(f *testing.F, finalPayload bool) { r := bytes.NewReader(data) - payload1, _, err := NewPayloadFromReader(r, finalPayload) + payload1, _, err := NewPayloadFromReader( + r, finalPayload, updateAddBlinded, + ) if err != nil { return } @@ -146,7 +166,9 @@ func fuzzPayload(f *testing.F, finalPayload bool) { require.NoError(t, err) } - payload2, _, err := NewPayloadFromReader(&b, finalPayload) + payload2, _, err := NewPayloadFromReader( + &b, finalPayload, updateAddBlinded, + ) require.NoError(t, err) require.Equal(t, payload1, payload2) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index df6f5aac7..7088127c3 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -17,6 +17,11 @@ import ( var ( // ErrDecodeFailed is returned when we can't decode blinded data. ErrDecodeFailed = errors.New("could not decode blinded data") + + // ErrNoBlindingPoint is returned when we have not provided a blinding + // point for a validated payload with encrypted data set. + ErrNoBlindingPoint = errors.New("no blinding point set for validated " + + "blinded hop") ) // Iterator is an interface that abstracts away the routing information @@ -109,7 +114,7 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) { isFinal := r.processedPacket.Action == sphinx.ExitNode payload, parsed, err := NewPayloadFromReader( bytes.NewReader(r.processedPacket.Payload.Payload), - isFinal, + isFinal, r.blindingKit.UpdateAddBlinding.IsSome(), ) if err != nil { return nil, err @@ -182,35 +187,16 @@ type BlindingKit struct { IncomingAmount lnwire.MilliSatoshi } -// validateBlindingPoint validates that only one blinding point is present for -// the hop and returns the relevant one. -func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey, - isFinalHop bool) (*btcec.PublicKey, error) { +// getBlindingPoint returns either the payload or updateAddHtlc blinding point, +// assuming that validation that these values are appropriately set has already +// been handled elsewhere. +func (b *BlindingKit) getBlindingPoint(payloadBlinding *btcec.PublicKey) ( + *btcec.PublicKey, error) { - // Bolt 04: if encrypted_recipient_data is present: - // - if blinding_point (in update add) is set: - // - MUST error if current_blinding_point is set (in payload) - // - otherwise: - // - MUST return an error if current_blinding_point is not present - // (in payload) payloadBlindingSet := payloadBlinding != nil updateBlindingSet := b.UpdateAddBlinding.IsSome() switch { - case !(payloadBlindingSet || updateBlindingSet): - return nil, ErrInvalidPayload{ - Type: record.BlindingPointOnionType, - Violation: OmittedViolation, - FinalHop: isFinalHop, - } - - case payloadBlindingSet && updateBlindingSet: - return nil, ErrInvalidPayload{ - Type: record.BlindingPointOnionType, - Violation: IncludedViolation, - FinalHop: isFinalHop, - } - case payloadBlindingSet: return payloadBlinding, nil @@ -223,9 +209,10 @@ func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey, } return pk.Val, nil - } - return nil, fmt.Errorf("expected blinded point set") + default: + return nil, ErrNoBlindingPoint + } } // DecryptAndValidateFwdInfo performs all operations required to decrypt and @@ -235,11 +222,10 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, *ForwardingInfo, error) { // We expect this function to be called when we have encrypted data - // present, and a blinding key is set either in the payload or the + // present, and expect validation to already have ensured that a + // blinding key is set either in the payload or the // update_add_htlc message. - blindingPoint, err := b.validateBlindingPoint( - payload.blindingPoint, isFinalHop, - ) + blindingPoint, err := b.getBlindingPoint(payload.blindingPoint) if err != nil { return nil, err } diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index 60919333b..a850c6dc1 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -216,24 +216,10 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) { expectedErr error }{ { - name: "no blinding point", - data: validData, - processor: &mockProcessor{}, - expectedErr: ErrInvalidPayload{ - Type: record.BlindingPointOnionType, - Violation: OmittedViolation, - }, - }, - { - name: "both blinding points", - data: validData, - updateAddBlinding: &btcec.PublicKey{}, - payloadBlinding: &btcec.PublicKey{}, - processor: &mockProcessor{}, - expectedErr: ErrInvalidPayload{ - Type: record.BlindingPointOnionType, - Violation: IncludedViolation, - }, + name: "no blinding point", + data: validData, + processor: &mockProcessor{}, + expectedErr: ErrNoBlindingPoint, }, { name: "decryption failed", @@ -265,12 +251,19 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) { }, }, { - name: "valid", + name: "valid using update add", updateAddBlinding: &btcec.PublicKey{}, data: validData, processor: &mockProcessor{}, expectedErr: nil, }, + { + name: "valid using payload", + payloadBlinding: &btcec.PublicKey{}, + data: validData, + processor: &mockProcessor{}, + expectedErr: nil, + }, } for _, testCase := range tests { diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 70fdb1403..8454bc847 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -138,8 +138,8 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { // should correspond to the bytes encapsulated in a TLV onion payload. The // final hop bool signals that this payload was the final packet parsed by // sphinx. -func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, - map[tlv.Type][]byte, error) { +func NewPayloadFromReader(r io.Reader, finalHop, + updateAddBlinding bool) (*Payload, map[tlv.Type][]byte, error) { var ( cid uint64 @@ -177,7 +177,9 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, // Validate whether the sender properly included or omitted tlv records // in accordance with BOLT 04. - err = ValidateParsedPayloadTypes(parsedTypes, finalHop) + err = ValidateParsedPayloadTypes( + parsedTypes, finalHop, updateAddBlinding, + ) if err != nil { return nil, nil, err } @@ -259,7 +261,7 @@ func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet { // boolean should be true if the payload was parsed for an exit hop. The // requirements for this method are described in BOLT 04. func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, - isFinalHop bool) error { + isFinalHop, updateAddBlinding bool) error { _, hasAmt := parsedTypes[record.AmtOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType] @@ -267,6 +269,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, _, hasMPP := parsedTypes[record.MPPOnionType] _, hasAMP := parsedTypes[record.AMPOnionType] _, hasEncryptedData := parsedTypes[record.EncryptedDataOnionType] + _, hasBlinding := parsedTypes[record.BlindingPointOnionType] // All cleartext hops (including final hop) and the final hop in a // blinded path require the forwading amount and expiry TLVs to be set. @@ -277,6 +280,32 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, needNextHop := !(hasEncryptedData || isFinalHop) switch { + // Both blinding point being set is invalid. + case hasBlinding && updateAddBlinding: + return ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } + + // If encrypted data is not provided, blinding points should not be + // set. + case !hasEncryptedData && (hasBlinding || updateAddBlinding): + return ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: OmittedViolation, + FinalHop: isFinalHop, + } + + // If encrypted data is present, we require that one blinding point + // is set. + case hasEncryptedData && !(hasBlinding || updateAddBlinding): + return ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } + // Hops that need forwarding info must include an amount to forward. case needFwdInfo && !hasAmt: return ErrInvalidPayload{ diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 301e57716..666cf64bc 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -26,6 +26,7 @@ type decodePayloadTest struct { name string payload []byte isFinalHop bool + updateAddBlinded bool expErr error expCustomRecords map[uint64][]byte shouldHaveMPP bool @@ -271,8 +272,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate hop with encrypted data", - isFinalHop: false, + name: "intermediate hop with encrypted data", + isFinalHop: false, + updateAddBlinded: true, payload: []byte{ // encrypted data 0x0a, 0x03, 0x03, 0x02, 0x01, @@ -365,8 +367,9 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveTotalAmt: true, }, { - name: "final blinded hop with total amount", - isFinalHop: true, + name: "final blinded hop with total amount", + isFinalHop: true, + updateAddBlinded: true, payload: []byte{ // amount 0x02, 0x00, @@ -378,8 +381,9 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveEncData: true, }, { - name: "final blinded missing amt", - isFinalHop: true, + name: "final blinded missing amt", + isFinalHop: true, + updateAddBlinded: true, payload: []byte{ // cltv 0x04, 0x00, @@ -394,8 +398,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "final blinded missing cltv", - isFinalHop: true, + name: "final blinded missing cltv", + isFinalHop: true, + updateAddBlinded: true, payload: []byte{ // amount 0x02, 0x00, @@ -410,8 +415,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate blinded has amount", - isFinalHop: false, + name: "intermediate blinded has amount", + isFinalHop: false, + updateAddBlinded: true, payload: []byte{ // amount 0x02, 0x00, @@ -425,8 +431,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate blinded has expiry", - isFinalHop: false, + name: "intermediate blinded has expiry", + isFinalHop: false, + updateAddBlinded: true, payload: []byte{ // cltv 0x04, 0x00, @@ -439,6 +446,67 @@ var decodePayloadTests = []decodePayloadTest{ FinalHop: false, }, }, + { + name: "update add blinding no data", + isFinalHop: false, + payload: []byte{ + // cltv + 0x04, 0x00, + }, + updateAddBlinded: true, + expErr: hop.ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: hop.OmittedViolation, + FinalHop: false, + }, + }, + { + name: "onion blinding point no data", + isFinalHop: false, + payload: append([]byte{ + // blinding point (type / length) + 0x0c, 0x21, + }, + // blinding point (value) + testPubKey.SerializeCompressed()..., + ), + expErr: hop.ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: hop.OmittedViolation, + FinalHop: false, + }, + }, + { + name: "encrypted data no blinding", + isFinalHop: false, + payload: []byte{ + // encrypted data + 0x0a, 0x03, 0x03, 0x02, 0x01, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: hop.IncludedViolation, + }, + }, + { + name: "both blinding points", + isFinalHop: false, + updateAddBlinded: true, + payload: append([]byte{ + // encrypted data + 0x0a, 0x03, 0x03, 0x02, 0x01, + // blinding point (type / length) + 0x0c, 0x21, + }, + // blinding point (value) + testPubKey.SerializeCompressed()..., + ), + expErr: hop.ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: hop.IncludedViolation, + FinalHop: false, + }, + }, } // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the @@ -481,6 +549,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { p, _, err := hop.NewPayloadFromReader( bytes.NewReader(test.payload), test.isFinalHop, + test.updateAddBlinded, ) if !reflect.DeepEqual(test.expErr, err) { t.Fatalf("expected error mismatch, want: %v, got: %v",