diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 9da504a5d..e5cbcaba9 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -650,18 +650,13 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, return err } - // If the set ID hint is non-nil, then we'll use that to filter - // out the HTLCs for AMP invoice so we don't need to read them - // all out to satisfy the invoice callback below. If it's nil, - // then we pass in the zero set ID which means no HTLCs will be - // read out. - var invSetID invpkg.SetID - - if setIDHint != nil { - invSetID = *setIDHint - } + // setIDHint can also be nil here, which means all the HTLCs + // for AMP invoices are fetched. If the blank setID is passed + // in, then no HTLCs are fetched for the AMP invoice. If a + // specific setID is passed in, then only the HTLCs for that + // setID are fetched for a particular sub-AMP invoice. invoice, err := fetchInvoice( - invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false, + invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false, ) if err != nil { return err @@ -691,7 +686,7 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, // If this is an AMP update, then limit the returned AMP state // to only the requested set ID. if setIDHint != nil { - filterInvoiceAMPState(updatedInvoice, &invSetID) + filterInvoiceAMPState(updatedInvoice, setIDHint) } return nil @@ -848,7 +843,10 @@ func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error { return k.storeSettleHodlInvoiceUpdate() case invpkg.CancelInvoiceUpdate: - return k.serializeAndStoreInvoice() + // Persist all changes which where made when cancelling the + // invoice. All HTLCs which were accepted are now canceled, so + // we persist this state. + return k.storeCancelHtlcsUpdate() } return fmt.Errorf("unknown update type: %v", updateType) diff --git a/invoices/interface.go b/invoices/interface.go index c906da1c3..c49493d5b 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -56,6 +56,11 @@ type InvoiceDB interface { // passed payment hash. If an invoice matching the passed payment hash // doesn't exist within the database, then the action will fail with a // "not found" error. + // The setIDHint is used to signal whether AMP HTLCs should be fetched + // for the invoice. If a blank setID is passed no HTLCs will be fetched + // in case of an AMP invoice. Nil means all HTLCs for all sub AMP + // invoices will be fetched and if a specific setID is supplied only + // HTLCs for that setID will be fetched. // // The update is performed inside the same database transaction that // fetches the invoice and is therefore atomic. The fields to update diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index fbaaee6b2..7f448d238 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -704,7 +704,10 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, // Try to mark the specified htlc as canceled in the invoice database. // Intercept the update descriptor to set the local updated variable. If // no invoice update is performed, we can return early. + // setID is only set for AMP HTLCs, so it can be nil and it is expected + // to be nil for non-AMP HTLCs. setID := (*SetID)(invoiceRef.SetID()) + var updated bool invoice, err := i.idb.UpdateInvoice( context.Background(), invoiceRef, setID, @@ -1014,6 +1017,9 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( HtlcResolution, invoiceExpiry, error) { invoiceRef := ctx.invoiceRef() + + // This setID is only set for AMP HTLCs, so it can be nil and it is + // also expected to be nil for non-AMP HTLCs. setID := (*SetID)(ctx.setID()) // We need to look up the current state of the invoice in order to send @@ -1370,7 +1376,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(ctx context.Context, hash := preimage.Hash() invoiceRef := InvoiceRefByHash(hash) - invoice, err := i.idb.UpdateInvoice(ctx, invoiceRef, nil, updateInvoice) + + // AMP hold invoices are not supported so we set the setID to nil. + // For non-AMP invoices this parameter is ignored during the fetching + // of the database state. + setID := (*SetID)(nil) + + invoice, err := i.idb.UpdateInvoice( + ctx, invoiceRef, setID, updateInvoice, + ) if err != nil { log.Errorf("SettleHodlInvoice with preimage %v: %v", preimage, err) @@ -1454,10 +1468,14 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(ctx context.Context, }, nil } + // If it's an AMP invoice we need to fetch all AMP HTLCs here so that + // we can cancel all of HTLCs which are in the accepted state across + // different setIDs. + setID := (*SetID)(nil) invoiceRef := InvoiceRefByHash(payHash) - - // We pass a nil setID which means no HTLCs will be read out. - invoice, err := i.idb.UpdateInvoice(ctx, invoiceRef, nil, updateInvoice) + invoice, err := i.idb.UpdateInvoice( + ctx, invoiceRef, setID, updateInvoice, + ) // Implement idempotency by returning success if the invoice was already // canceled. @@ -1483,8 +1501,8 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(ctx context.Context, // that are waiting for resolution. Any htlcs that were already canceled // before, will be notified again. This isn't necessary but doesn't hurt // either. - // - // TODO(ziggie): Also consider AMP HTLCs here. + // For AMP invoices we fetched all AMP HTLCs for all sub AMP invoices + // here so we can clean up all of them. for key, htlc := range invoice.Htlcs { if htlc.State != HtlcStateCanceled { continue @@ -1496,6 +1514,7 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(ctx context.Context, ), ) } + i.notifyClients(payHash, invoice, nil) // Attempt to also delete the invoice if requested through the registry diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index ba110791b..427d225d6 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -117,6 +117,10 @@ func TestInvoiceRegistry(t *testing.T) { name: "FailPartialAMPPayment", test: testFailPartialAMPPayment, }, + { + name: "CancelAMPInvoicePendingHTLCs", + test: testCancelAMPInvoicePendingHTLCs, + }, } makeKeyValueDB := func(t *testing.T) (invpkg.InvoiceDB, @@ -2441,3 +2445,130 @@ func testFailPartialAMPPayment(t *testing.T, "expected HTLC to be canceled") } } + +// testCancelAMPInvoicePendingHTLCs tests the case where an AMP invoice is +// canceled and the remaining HTLCs are also canceled so that no HTLCs are left +// in the accepted state. +func testCancelAMPInvoicePendingHTLCs(t *testing.T, + makeDB func(t *testing.T) (invpkg.InvoiceDB, *clock.TestClock)) { + + t.Parallel() + + ctx := newTestContext(t, nil, makeDB) + ctxb := context.Background() + + const ( + expiry = uint32(testCurrentHeight + 20) + numShards = 4 + ) + + var ( + shardAmt = testInvoiceAmount / lnwire.MilliSatoshi(numShards) + payAddr [32]byte + ) + _, err := rand.Read(payAddr[:]) + require.NoError(t, err) + + // Create an AMP invoice we are going to pay via a multi-part payment. + ampInvoice := newInvoice(t, false, true) + + // An AMP invoice is referenced by the payment address. + ampInvoice.Terms.PaymentAddr = payAddr + + _, err = ctx.registry.AddInvoice( + ctxb, ampInvoice, testInvoicePaymentHash, + ) + require.NoError(t, err) + + htlcPayloadSet1 := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, payAddr), + // We are not interested in settling the AMP HTLC so we don't + // use valid shares. + amp: record.NewAMP([32]byte{1}, [32]byte{1}, 1), + } + + // Send first HTLC which pays part of the invoice. + hodlChan1 := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{1}, shardAmt, expiry, testCurrentHeight, + getCircuitKey(1), hodlChan1, nil, htlcPayloadSet1, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + htlcPayloadSet2 := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, payAddr), + // We are not interested in settling the AMP HTLC so we don't + // use valid shares. + amp: record.NewAMP([32]byte{2}, [32]byte{2}, 1), + } + + // Send htlc 2 which should be added to the invoice as expected. + hodlChan2 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{2}, shardAmt, expiry, testCurrentHeight, + getCircuitKey(2), hodlChan2, nil, htlcPayloadSet2, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + return len(inv.Htlcs) == 2 + }, testTimeout, time.Millisecond*100, "HTLCs not added to invoice") + + // expire the invoice here. + ctx.clock.SetTime(testTime.Add(65 * time.Minute)) + + // Expect HLTC 1 to be canceled via the MPPTimeout fail resolution. + select { + case resolution := <-hodlChan1: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + _, ok := htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", resolution, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + // Expect HLTC 2 to be canceled via the MPPTimeout fail resolution. + select { + case resolution := <-hodlChan2: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + _, ok := htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", resolution, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + return inv.State == invpkg.ContractCanceled + }, testTimeout, time.Millisecond*100, "invoice not canceled") + + // Fetch the invoice again and compare the number of cancelled HTLCs. + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + // Make sure all HTLCs are in the cancelled state. + require.Len(t, inv.Htlcs, 2) + for _, htlc := range inv.Htlcs { + require.Equal(t, invpkg.HtlcStateCanceled, htlc.State, + "expected HTLC to be canceled") + } +} diff --git a/invoices/sql_store.go b/invoices/sql_store.go index eb465eabb..a548b7df5 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -1336,14 +1336,26 @@ func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { - if setID != nil { - // Make sure to use the set ID if this is an AMP update. + switch { + // For the default case we fetch all HTLCs. + case setID == nil: + ref.refModifier = DefaultModifier + + // If the setID is the blank but NOT nil, we set the + // refModifier to HtlcSetBlankModifier to fetch no HTLC for the + // AMP invoice. + case *setID == BlankPayAddr: + ref.refModifier = HtlcSetBlankModifier + + // A setID is provided, we use the refModifier to fetch only + // the HTLCs for the given setID and also make sure we add the + // setID to the ref. + default: var setIDBytes [32]byte copy(setIDBytes[:], setID[:]) ref.setID = &setIDBytes - // If we're updating an AMP invoice, we'll also only - // need to fetch the HTLCs for the given set ID. + // We only fetch the HTLCs for the given setID. ref.refModifier = HtlcSetOnlyModifier } diff --git a/invoices/update_invoice.go b/invoices/update_invoice.go index a2de1b8f2..2b1ac32a9 100644 --- a/invoices/update_invoice.go +++ b/invoices/update_invoice.go @@ -31,7 +31,7 @@ func acceptHtlcsAmp(invoice *Invoice, setID SetID, } // cancelHtlcsAmp processes a cancellation of an HTLC that belongs to an AMP -// HTLC set. We'll need to update the meta data in the main invoice, and also +// HTLC set. We'll need to update the meta data in the main invoice, and also // apply the new update to the update MAP, since all the HTLCs for a given HTLC // set need to be written in-line with each other. func cancelHtlcsAmp(invoice *Invoice, circuitKey models.CircuitKey, @@ -552,6 +552,9 @@ func cancelInvoice(invoice *Invoice, hash *lntypes.Hash, invoice.State = ContractCanceled for key, htlc := range invoice.Htlcs { + // We might not have a setID here in case we are cancelling + // an AMP invoice however the setID is only important when + // settling an AMP HTLC. canceled, _, err := getUpdatedHtlcState( htlc, ContractCanceled, setID, ) @@ -567,6 +570,19 @@ func cancelInvoice(invoice *Invoice, hash *lntypes.Hash, if err != nil { return err } + + // If its an AMP HTLC we need to make sure we persist + // this new state otherwise AMP HTLCs are not updated + // on disk because HTLCs for AMP invoices are stored + // separately. + if htlc.AMP != nil { + err := cancelHtlcsAmp( + invoice, key, htlc, updater, + ) + if err != nil { + return err + } + } } }