diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 1228d6631..015aa2318 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -458,6 +458,253 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { } } +// TestPaymentControlMultiShard checks the ability of payment control to +// have multiple in-flight HTLCs for a single payment. +func TestPaymentControlMultiShard(t *testing.T) { + t.Parallel() + + // We will register three HTLC attempts, and always fail the second + // one. We'll generate all combinations of settling/failing the first + // and third HTLC, and assert that the payment status end up as we + // expect. + type testCase struct { + settleFirst bool + settleLast bool + } + + var tests []testCase + for _, f := range []bool{true, false} { + for _, l := range []bool{true, false} { + tests = append(tests, testCase{f, l}) + } + } + + runSubTest := func(t *testing.T, test testCase) { + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Init the payment, moving it to the StatusInFlight state. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, pControl, info.PaymentHash, info, nil, nil, + ) + + // Create three unique attempts we'll use for the test, and + // register them with the payment control. + var attempts []*HTLCAttemptInfo + for i := uint64(0); i < 3; i++ { + a := *attempt + a.AttemptID = i + attempts = append(attempts, &a) + + err = pControl.RegisterAttempt(info.PaymentHash, &a) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus( + t, pControl, info.PaymentHash, StatusInFlight, + ) + + htlc := &htlcStatus{ + HTLCAttemptInfo: &a, + } + assertPaymentInfo( + t, pControl, info.PaymentHash, info, nil, htlc, + ) + } + + // Fail the second attempt. + a := attempts[1] + htlcFail := HTLCFailUnreadable + err = pControl.FailAttempt( + info.PaymentHash, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) + if err != nil { + t.Fatal(err) + } + + htlc := &htlcStatus{ + HTLCAttemptInfo: a, + failure: &htlcFail, + } + assertPaymentInfo( + t, pControl, info.PaymentHash, info, nil, htlc, + ) + + // Payment should still be in-flight. + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) + + // Depending on the test case, settle or fail the first attempt. + a = attempts[0] + htlc = &htlcStatus{ + HTLCAttemptInfo: a, + } + + var firstFailReason *FailureReason + if test.settleFirst { + _, err := pControl.SettleAttempt( + info.PaymentHash, a.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } + + // Assert that the HTLC has had the preimage recorded. + htlc.settle = &preimg + assertPaymentInfo( + t, pControl, info.PaymentHash, info, nil, htlc, + ) + } else { + err := pControl.FailAttempt( + info.PaymentHash, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } + + // Assert the failure was recorded. + htlc.failure = &htlcFail + assertPaymentInfo( + t, pControl, info.PaymentHash, info, nil, htlc, + ) + + // We also record a payment level fail, to move it into + // a terminal state. + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Record the reason we failed the payment, such that + // we can assert this later in the test. + firstFailReason = &failReason + } + + // The payment should still be considered in-flight, since there + // is still an active HTLC. + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) + + // Try to register yet another attempt. This should fail now + // that the payment has reached a terminal condition. + b := *attempt + b.AttemptID = 3 + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrPaymentTerminal { + t.Fatalf("expected ErrPaymentTerminal, got: %v", err) + } + + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) + + // Settle or fail the remaining attempt based on the testcase. + a = attempts[2] + htlc = &htlcStatus{ + HTLCAttemptInfo: a, + } + if test.settleLast { + // Settle the last outstanding attempt. + _, err = pControl.SettleAttempt( + info.PaymentHash, a.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } + + htlc.settle = &preimg + assertPaymentInfo( + t, pControl, info.PaymentHash, info, + firstFailReason, htlc, + ) + } else { + // Fail the attempt. + err := pControl.FailAttempt( + info.PaymentHash, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } + + // Assert the failure was recorded. + htlc.failure = &htlcFail + assertPaymentInfo( + t, pControl, info.PaymentHash, info, + firstFailReason, htlc, + ) + + // Check that we can override any perevious terminal + // failure. This is to allow multiple concurrent shard + // write a terminal failure to the database without + // syncing. + failReason := FailureReasonPaymentDetails + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + } + + // If any of the two attempts settled, the payment should end + // up in the Succeeded state. If both failed the payment should + // also be Failed at this poinnt. + finalStatus := StatusFailed + expRegErr := ErrPaymentAlreadyFailed + if test.settleFirst || test.settleLast { + finalStatus = StatusSucceeded + expRegErr = ErrPaymentAlreadySucceeded + } + + assertPaymentStatus(t, pControl, info.PaymentHash, finalStatus) + + // Finally assert we cannot register more attempts. + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != expRegErr { + t.Fatalf("expected error %v, got: %v", expRegErr, err) + } + } + + for _, test := range tests { + test := test + subTest := fmt.Sprintf("first=%v, second=%v", + test.settleFirst, test.settleLast) + + t.Run(subTest, func(t *testing.T) { + runSubTest(t, test) + }) + } +} + // assertPaymentStatus retrieves the status of the payment referred to by hash // and compares it with the expected state. func assertPaymentStatus(t *testing.T, p *PaymentControl,