diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 3deb6405c..b926260f8 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -108,6 +109,14 @@ func TestInvoiceRegistry(t *testing.T) { name: "SpontaneousAmpPayment", test: testSpontaneousAmpPayment, }, + { + name: "FailPartialMPPPaymentExternal", + test: testFailPartialMPPPaymentExternal, + }, + { + name: "FailPartialAMPPayment", + test: testFailPartialAMPPayment, + }, } makeKeyValueDB := func(t *testing.T) (invpkg.InvoiceDB, @@ -204,7 +213,7 @@ func testSettleInvoice(t *testing.T, require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - testInvoice := newInvoice(t, false) + testInvoice := newInvoice(t, false, false) addIdx, err := ctx.registry.AddInvoice( ctxb, testInvoice, testInvoicePaymentHash, ) @@ -395,7 +404,7 @@ func testCancelInvoiceImpl(t *testing.T, gc bool, require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - testInvoice := newInvoice(t, false) + testInvoice := newInvoice(t, false, false) _, err = ctx.registry.AddInvoice( ctxb, testInvoice, testInvoicePaymentHash, ) @@ -555,7 +564,7 @@ func testSettleHoldInvoice(t *testing.T, require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - invoice := newInvoice(t, true) + invoice := newInvoice(t, true, false) _, err = registry.AddInvoice(ctxb, invoice, testInvoicePaymentHash) require.NoError(t, err) @@ -716,7 +725,7 @@ func testCancelHoldInvoice(t *testing.T, ctxb := context.Background() // Add the invoice. - invoice := newInvoice(t, true) + invoice := newInvoice(t, true, false) _, err = registry.AddInvoice(ctxb, invoice, testInvoicePaymentHash) require.NoError(t, err) @@ -1043,7 +1052,7 @@ func testMppPayment(t *testing.T, ctxb := context.Background() // Add the invoice. - testInvoice := newInvoice(t, false) + testInvoice := newInvoice(t, false, false) _, err := ctx.registry.AddInvoice( ctxb, testInvoice, testInvoicePaymentHash, ) @@ -1141,7 +1150,7 @@ func testMppPaymentWithOverpayment(t *testing.T, ctx := newTestContext(t, nil, makeDB) // Add the invoice. - testInvoice := newInvoice(t, false) + testInvoice := newInvoice(t, false, false) _, err := ctx.registry.AddInvoice( ctxb, testInvoice, testInvoicePaymentHash, ) @@ -1432,7 +1441,7 @@ func testHeightExpiryWithRegistryImpl(t *testing.T, numParts int, settle bool, // Add a hold invoice, we set a non-nil payment request so that this // invoice is not considered a keysend by the expiry watcher. - testInvoice := newInvoice(t, false) + testInvoice := newInvoice(t, false, false) testInvoice.HodlInvoice = true testInvoice.PaymentRequest = []byte{1, 2, 3} @@ -1545,7 +1554,7 @@ func testMultipleSetHeightExpiry(t *testing.T, ctx := newTestContext(t, nil, makeDB) // Add a hold invoice. - testInvoice := newInvoice(t, true) + testInvoice := newInvoice(t, true, false) ctxb := context.Background() _, err := ctx.registry.AddInvoice( @@ -2109,3 +2118,326 @@ func testSpontaneousAmpPaymentImpl( } } } + +// testFailPartialMPPPaymentExternal tests that the HTLC set is cancelled back +// as soon as the HTLC interceptor denies one of the HTLCs. +func testFailPartialMPPPaymentExternal(t *testing.T, + makeDB func(t *testing.T) (invpkg.InvoiceDB, *clock.TestClock)) { + + t.Parallel() + + mockHtlcInterceptor := &invpkg.MockHtlcModifier{} + cfg := defaultRegistryConfig() + cfg.HtlcInterceptor = mockHtlcInterceptor + ctx := newTestContext(t, &cfg, makeDB) + + // Add an invoice which we are going to pay via a MPP set. + testInvoice := newInvoice(t, false, false) + + ctxb := context.Background() + _, err := ctx.registry.AddInvoice( + ctxb, testInvoice, testInvoicePaymentHash, + ) + require.NoError(t, err) + + mppPayload := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, [32]byte{}), + } + + // Send first HTLC which pays part of the invoice but keeps the invoice + // in an open state because the amount is less than the invoice amount. + hodlChan1 := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, testInvoice.Terms.Value/3, + testHtlcExpiry, testCurrentHeight, getCircuitKey(1), + hodlChan1, nil, mppPayload, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // Register the expected response from the interceptor so that the + // whole HTLC set is cancelled. + expectedResponse := invpkg.HtlcModifyResponse{ + CancelSet: true, + } + mockHtlcInterceptor.On("Intercept", mock.Anything, mock.Anything). + Return(nil, expectedResponse) + + // Send htlc 2. We expect the HTLC to be cancelled because the + // interceptor will deny it. + resolution, err = ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, testInvoice.Terms.Value/2, + testHtlcExpiry, testCurrentHeight, getCircuitKey(2), nil, + nil, mppPayload, + ) + require.NoError(t, err) + failResolution, ok := resolution.(*invpkg.HtlcFailResolution) + require.True(t, ok, "expected fail resolution, got: %T", resolution) + + // Make sure the resolution includes the custom error msg. + require.Equal(t, invpkg.ExternalValidationFailed, + failResolution.Outcome, "expected ExternalValidationFailed, "+ + "got: %v", failResolution.Outcome) + + // Expect HLTC 1 also to be cancelled because it is part of the cancel + // set and the interceptor cancelled the whole set after receiving the + // second HTLC. + select { + case resolution := <-hodlChan1: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + failResolution, ok = htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", + htlcResolution, + ) + require.Equal( + t, invpkg.ExternalValidationFailed, + failResolution.Outcome, "expected "+ + "ExternalValidationFailed, got: %v", + failResolution.Outcome, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + // Assert that the invoice is still open. + inv, err := ctx.registry.LookupInvoice(ctxb, testInvoicePaymentHash) + require.NoError(t, err) + require.Equal(t, invpkg.ContractOpen, inv.State, "expected "+ + "OPEN invoice") + + // Now let the invoice expire. + currentTime := ctx.clock.Now() + ctx.clock.SetTime(currentTime.Add(61 * time.Minute)) + + // Make sure the invoices changes to the canceled state. + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + return inv.State == invpkg.ContractCanceled + }, testTimeout, time.Millisecond*100, "invoice not canceled") + + // Fetch the invoice again and compare the number of cancelled HTLCs. + inv, err = ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + // Make sure all HTLCs are in the canceled state which in our case is + // only the first one because the second HTLC was never added to the + // invoice registry in the first place. + require.Len(t, inv.Htlcs, 1) + require.Equal( + t, invpkg.HtlcStateCanceled, inv.Htlcs[getCircuitKey(1)].State, + ) +} + +// testFailPartialAMPPayment tests the MPP timeout logic for AMP invoices. It +// makes sure that all HTLCs are cancelled if the full invoice amount is not +// received. Moreover it points out some TODOs to make AMP invoices more robust. +func testFailPartialAMPPayment(t *testing.T, + makeDB func(t *testing.T) (invpkg.InvoiceDB, *clock.TestClock)) { + + t.Parallel() + + ctx := newTestContext(t, nil, makeDB) + ctxb := context.Background() + + const ( + expiry = uint32(testCurrentHeight + 20) + numShards = 4 + ) + + var ( + shardAmt = testInvoiceAmount / lnwire.MilliSatoshi(numShards) + setID [32]byte + payAddr [32]byte + ) + _, err := rand.Read(payAddr[:]) + require.NoError(t, err) + + // Create an AMP invoice we are going to pay via a multi-part payment. + ampInvoice := newInvoice(t, false, true) + + // An AMP invoice is referenced by the payment address. + ampInvoice.Terms.PaymentAddr = payAddr + + _, err = ctx.registry.AddInvoice( + ctxb, ampInvoice, testInvoicePaymentHash, + ) + require.NoError(t, err) + + // Generate a random setID for the HTLCs. + _, err = rand.Read(setID[:]) + require.NoError(t, err) + + htlcPayload1 := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, payAddr), + // We are not interested in settling the AMP HTLC so we don't + // use valid shares. + amp: record.NewAMP([32]byte{1}, setID, 1), + } + + // Send first HTLC which pays part of the invoice. + hodlChan1 := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{1}, shardAmt, expiry, testCurrentHeight, + getCircuitKey(1), hodlChan1, nil, htlcPayload1, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + htlcPayload2 := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, payAddr), + // We are not interested in settling the AMP HTLC so we don't + // use valid shares. + amp: record.NewAMP([32]byte{2}, setID, 2), + } + + // Send htlc 2 which should be added to the invoice as expected. + hodlChan2 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{2}, shardAmt, expiry, testCurrentHeight, + getCircuitKey(2), hodlChan2, nil, htlcPayload2, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // Now time-out the HTLCs. The HoldDuration is 30 seconds after the + // HTLC will be cancelled. + currentTime := ctx.clock.Now() + ctx.clock.SetTime(currentTime.Add(35 * time.Second)) + + // Expect HLTC 1 to be canceled via the MPPTimeout fail resolution. + select { + case resolution := <-hodlChan1: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + failRes, ok := htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", resolution, + ) + require.Equal( + t, invpkg.ResultMppTimeout, failRes.Outcome, + "expected MPPTimeout, got: %v", failRes.Outcome, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + // Expect HLTC 2 to be canceled via the MPPTimeout fail resolution. + select { + case resolution := <-hodlChan2: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + failRes, ok := htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", resolution, + ) + require.Equal( + t, invpkg.ResultMppTimeout, failRes.Outcome, + "expected MPPTimeout, got: %v", failRes.Outcome, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + // The AMP invoice should still be open. + inv, err := ctx.registry.LookupInvoice(ctxb, testInvoicePaymentHash) + require.NoError(t, err) + require.Equal(t, invpkg.ContractOpen, inv.State, "expected "+ + "OPEN invoice") + + // Because one HTLC of the set was cancelled we expect the AMPState to + // be set to canceled. + ampState, ok := inv.AMPState[setID] + require.True(t, ok, "expected AMPState to be set") + require.Equal(t, invpkg.HtlcStateCanceled, ampState.State, "expected "+ + "AMPState CANCELED") + + // The following is a bug and should not be allowed because the sub + // AMP invoice is already marked as canceled. However LND will accept + // other HTLCs to the AMP sub-invoice. + // + // TODO(ziggie): Fix this bug. + htlcPayload3 := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmount, payAddr), + // We are not interested in settling the AMP HTLC so we don't + // use valid shares. + amp: record.NewAMP([32]byte{3}, setID, 3), + } + + // Send htlc 3 which should be added to the invoice as expected. + hodlChan3 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{3}, shardAmt, expiry, testCurrentHeight, + getCircuitKey(3), hodlChan3, nil, htlcPayload3, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // TODO(ziggie): This is a race condition between the invoice being + // cancelled and the htlc being added to the invoice. If we do not wait + // here until the HTLC is added to the invoice, the test might fail + // because the HTLC will not be resolved. + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + return len(inv.Htlcs) == 3 + }, testTimeout, time.Millisecond*100, "HTLC 3 not added to invoice") + + // Now also let the invoice expire the invoice expiry is 1 hour. + currentTime = ctx.clock.Now() + ctx.clock.SetTime(currentTime.Add(1 * time.Minute)) + + // Expect HLTC 3 to be canceled either via the cancelation of the + // invoice or because the MPP timeout kicks in. + select { + case resolution := <-hodlChan3: + htlcResolution, _ := resolution.(invpkg.HtlcResolution) + failRes, ok := htlcResolution.(*invpkg.HtlcFailResolution) + require.True( + t, ok, "expected fail resolution, got: %T", resolution, + ) + require.Equal( + t, invpkg.ResultMppTimeout, failRes.Outcome, + "expected MPPTimeout, got: %v", failRes.Outcome, + ) + + case <-time.After(testTimeout): + t.Fatal("timeout waiting for HTLC resolution") + } + + // expire the invoice here. + currentTime = ctx.clock.Now() + ctx.clock.SetTime(currentTime.Add(61 * time.Minute)) + + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + return inv.State == invpkg.ContractCanceled + }, testTimeout, time.Millisecond*100, "invoice not canceled") + + // Fetch the invoice again and compare the number of cancelled HTLCs. + inv, err = ctx.registry.LookupInvoice( + ctxb, testInvoicePaymentHash, + ) + require.NoError(t, err) + + // Make sure all HTLCs are in the cancelled state. + require.Len(t, inv.Htlcs, 3) + for _, htlc := range inv.Htlcs { + require.Equal(t, invpkg.HtlcStateCanceled, htlc.State, + "expected HTLC to be canceled") + } +} diff --git a/invoices/mock.go b/invoices/mock.go index 5d929c227..68f5f66dc 100644 --- a/invoices/mock.go +++ b/invoices/mock.go @@ -86,6 +86,7 @@ func (m *MockInvoiceDB) DeleteCanceledInvoices(ctx context.Context) error { // MockHtlcModifier is a mock implementation of the HtlcModifier interface. type MockHtlcModifier struct { + mock.Mock } // Intercept generates a new intercept session for the given invoice. @@ -94,9 +95,23 @@ type MockHtlcModifier struct { // created in the first place, which is only the case if a client is // registered. func (m *MockHtlcModifier) Intercept( - _ HtlcModifyRequest, _ func(HtlcModifyResponse)) error { + req HtlcModifyRequest, callback func(HtlcModifyResponse)) error { - return nil + // If no expectations are set, return nil by default. + if len(m.ExpectedCalls) == 0 { + return nil + } + + args := m.Called(req, callback) + + // If a response was provided to the mock, execute the callback with it. + if response, ok := args.Get(1).(HtlcModifyResponse); ok && + callback != nil { + + callback(response) + } + + return args.Error(0) } // RegisterInterceptor sets the client callback function that will be diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 509b92d85..6062b3b80 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -207,7 +207,7 @@ func getCircuitKey(htlcID uint64) invpkg.CircuitKey { // Note that this invoice *does not* have a payment address set. It will // create a regular invoice with a preimage is hodl is false, and a hodl // invoice with no preimage otherwise. -func newInvoice(t *testing.T, hodl bool) *invpkg.Invoice { +func newInvoice(t *testing.T, hodl bool, ampInvoice bool) *invpkg.Invoice { invoice := &invpkg.Invoice{ Terms: invpkg.ContractTerm{ Value: testInvoiceAmount, @@ -217,6 +217,23 @@ func newInvoice(t *testing.T, hodl bool) *invpkg.Invoice { CreationDate: testInvoiceCreationDate, } + // This makes the invoice an AMP invoice. We do not support AMP hodl + // invoices. + if ampInvoice { + ampFeature := lnwire.NewRawFeatureVector( + lnwire.TLVOnionPayloadOptional, + lnwire.PaymentAddrOptional, + lnwire.AMPRequired, + ) + + ampFeatures := lnwire.NewFeatureVector( + ampFeature, lnwire.Features, + ) + invoice.Terms.Features = ampFeatures + + return invoice + } + // If creating a hodl invoice, we don't include a preimage. if hodl { invoice.HodlInvoice = true