diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index e233537f4..be7be5eeb 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -93,6 +93,10 @@ type Payload struct { // customRecords are user-defined records in the custom type range that // were included in the payload. customRecords record.CustomSet + + // metadata is additional data that is sent along with the payment to + // the payee. + metadata []byte } // NewLegacyPayload builds a Payload from the amount, cltv, and next hop @@ -115,11 +119,12 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { // should correspond to the bytes encapsulated in a TLV onion payload. func NewPayloadFromReader(r io.Reader) (*Payload, error) { var ( - cid uint64 - amt uint64 - cltv uint32 - mpp = &record.MPP{} - amp = &record.AMP{} + cid uint64 + amt uint64 + cltv uint32 + mpp = &record.MPP{} + amp = &record.AMP{} + metadata []byte ) tlvStream, err := tlv.NewStream( @@ -128,6 +133,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { record.NewNextHopIDRecord(&cid), mpp.Record(), amp.Record(), + record.NewMetadataRecord(&metadata), ) if err != nil { return nil, err @@ -168,6 +174,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { amp = nil } + // If no metadata field was parsed, set the metadata field on the + // resulting payload to nil. + if _, ok := parsedTypes[record.MetadataOnionType]; !ok { + metadata = nil + } + // Filter out the custom records. customRecords := NewCustomRecords(parsedTypes) @@ -180,6 +192,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { }, MPP: mpp, AMP: amp, + metadata: metadata, customRecords: customRecords, }, nil } @@ -284,6 +297,12 @@ func (h *Payload) CustomRecords() record.CustomSet { return h.customRecords } +// Metadata returns the additional data that is sent along with the +// payment to the payee. +func (h *Payload) Metadata() []byte { + return h.metadata +} + // getMinRequiredViolation checks for unrecognized required (even) fields in the // standard range and returns the lowest required type. Always returning the // lowest required type allows a failure message to be deterministic. diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index c7abc9fad..130b363dd 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -11,15 +11,16 @@ import ( "github.com/stretchr/testify/require" ) -const testUnknownRequiredType = 0x10 +const testUnknownRequiredType = 0x80 type decodePayloadTest struct { - name string - payload []byte - expErr error - expCustomRecords map[uint64][]byte - shouldHaveMPP bool - shouldHaveAMP bool + name string + payload []byte + expErr error + expCustomRecords map[uint64][]byte + shouldHaveMPP bool + shouldHaveAMP bool + shouldHaveMetadata bool } var decodePayloadTests = []decodePayloadTest{ @@ -258,6 +259,18 @@ var decodePayloadTests = []decodePayloadTest{ }, shouldHaveAMP: true, }, + { + name: "final hop with metadata", + payload: []byte{ + // amount + 0x02, 0x00, + // cltv + 0x04, 0x00, + // metadata + 0x10, 0x03, 0x01, 0x02, 0x03, + }, + shouldHaveMetadata: true, + }, } // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the @@ -293,6 +306,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, } + testMetadata = []byte{1, 2, 3} testChildIndex = uint32(9) ) @@ -331,6 +345,15 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { t.Fatalf("unexpected AMP payload") } + if test.shouldHaveMetadata { + if p.Metadata() == nil { + t.Fatalf("payload should have metadata") + } + require.Equal(t, testMetadata, p.Metadata()) + } else if p.Metadata() != nil { + t.Fatalf("unexpected metadata") + } + // Convert expected nil map to empty map, because we always expect an // initiated map from the payload. expCustomRecords := make(record.CustomSet) diff --git a/invoices/interface.go b/invoices/interface.go index 732b39e17..804817dd8 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -18,4 +18,8 @@ type Payload interface { // CustomRecords returns the custom tlv type records that were parsed // from the payload. CustomRecords() record.CustomSet + + // Metadata returns the additional data that is sent along with the + // payment to the payee. + Metadata() []byte } diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index aaee2a775..327974be5 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -906,6 +906,7 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, customRecords: payload.CustomRecords(), mpp: payload.MultiPath(), amp: payload.AMPRecord(), + metadata: payload.Metadata(), } switch { diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 48cf9e8c0..3d8842cb2 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -30,6 +30,7 @@ type mockPayload struct { mpp *record.MPP amp *record.AMP customRecords record.CustomSet + metadata []byte } func (p *mockPayload) MultiPath() *record.MPP { @@ -50,6 +51,10 @@ func (p *mockPayload) CustomRecords() record.CustomSet { return p.customRecords } +func (p *mockPayload) Metadata() []byte { + return p.metadata +} + const ( testHtlcExpiry = uint32(5) diff --git a/invoices/update.go b/invoices/update.go index bdfc8ca17..e1a2469ac 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -1,6 +1,7 @@ package invoices import ( + "encoding/hex" "errors" "github.com/lightningnetwork/lnd/amp" @@ -22,6 +23,7 @@ type invoiceUpdateCtx struct { customRecords record.CustomSet mpp *record.MPP amp *record.AMP + metadata []byte } // invoiceRef returns an identifier that can be used to lookup or update the @@ -52,9 +54,16 @@ func (i invoiceUpdateCtx) setID() *[32]byte { // log logs a message specific to this update context. func (i *invoiceUpdateCtx) log(s string) { + // Don't use %x in the log statement below, because it doesn't + // distinguish between nil and empty metadata. + metadata := "" + if i.metadata != nil { + metadata = hex.EncodeToString(i.metadata) + } + log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v, "+ - "amp=%v", i.invoiceRef(), s, i.amtPaid, i.expiry, i.circuitKey, - i.mpp, i.amp) + "amp=%v, metadata=%v", i.invoiceRef(), s, i.amtPaid, i.expiry, + i.circuitKey, i.mpp, i.amp, metadata) } // failRes is a helper function which creates a failure resolution with