channeldb: add HasSettledHTLC and PaymentFailed fields to state

This commit is contained in:
yyforyongyu 2023-03-06 15:48:24 +08:00 committed by Olaoluwa Osuntokun
parent 52c00e8cc4
commit 89ac071e56
2 changed files with 23 additions and 28 deletions

View file

@ -167,6 +167,14 @@ type MPPaymentState struct {
// shards should be launched. This value is true if we have an HTLC // shards should be launched. This value is true if we have an HTLC
// settled or the payment has an error. // settled or the payment has an error.
Terminate bool Terminate bool
// HasSettledHTLC is true if at least one of the payment's HTLCs is
// settled.
HasSettledHTLC bool
// PaymentFailed is true if the payment has been marked as failed with
// a reason.
PaymentFailed bool
} }
// MPPayment is a wrapper around a payment's PaymentCreationInfo and // MPPayment is a wrapper around a payment's PaymentCreationInfo and
@ -274,11 +282,6 @@ func (m *MPPayment) GetAttempt(id uint64) (*HTLCAttempt, error) {
// registrations when it's newly created, or none of its HTLCs is in a terminal // registrations when it's newly created, or none of its HTLCs is in a terminal
// state. // state.
func (m *MPPayment) Registrable() error { func (m *MPPayment) Registrable() error {
// Get the terminal info.
settle, reason := m.TerminalInfo()
settled := settle != nil
failed := reason != nil
// If updating the payment is not allowed, we can't register new HTLCs. // If updating the payment is not allowed, we can't register new HTLCs.
// Otherwise, the status must be either `StatusInitiated` or // Otherwise, the status must be either `StatusInitiated` or
// `StatusInFlight`. // `StatusInFlight`.
@ -294,12 +297,12 @@ func (m *MPPayment) Registrable() error {
// There are still inflight HTLCs and we need to check whether there // There are still inflight HTLCs and we need to check whether there
// are settled HTLCs or the payment is failed. If we already have // are settled HTLCs or the payment is failed. If we already have
// settled HTLCs, we won't allow adding more HTLCs. // settled HTLCs, we won't allow adding more HTLCs.
if settled { if m.State.HasSettledHTLC {
return ErrPaymentPendingSettled return ErrPaymentPendingSettled
} }
// If the payment is already failed, we won't allow adding more HTLCs. // If the payment is already failed, we won't allow adding more HTLCs.
if failed { if m.State.PaymentFailed {
return ErrPaymentPendingFailed return ErrPaymentPendingFailed
} }
@ -341,6 +344,8 @@ func (m *MPPayment) setState() error {
RemainingAmt: totalAmt - sentAmt, RemainingAmt: totalAmt - sentAmt,
FeesPaid: fees, FeesPaid: fees,
Terminate: terminate, Terminate: terminate,
HasSettledHTLC: settle != nil,
PaymentFailed: failure != nil,
} }
m.Status = status m.Status = status

View file

@ -87,29 +87,15 @@ func TestRegistrable(t *testing.T) {
}, },
} }
// Create test objects.
reason := FailureReasonError
htlcSettled := HTLCAttempt{
Settle: &HTLCSettleInfo{},
}
for i, tc := range testCases { for i, tc := range testCases {
i, tc := i, tc i, tc := i, tc
p := &MPPayment{ p := &MPPayment{
Status: tc.status, Status: tc.status,
} State: &MPPaymentState{
HasSettledHTLC: tc.hasSettledHTLC,
// Add the settled htlc to the payment if needed. PaymentFailed: tc.paymentFailed,
htlcs := make([]HTLCAttempt, 0) },
if tc.hasSettledHTLC {
htlcs = append(htlcs, htlcSettled)
}
p.HTLCs = htlcs
// Add the failure reason if needed.
if tc.paymentFailed {
p.FailureReason = &reason
} }
name := fmt.Sprintf("test_%d_%s", i, p.Status.String()) name := fmt.Sprintf("test_%d_%s", i, p.Status.String())
@ -173,7 +159,8 @@ func TestPaymentSetState(t *testing.T) {
NumAttemptsInFlight: 1, NumAttemptsInFlight: 1,
RemainingAmt: 1000 - 90, RemainingAmt: 1000 - 90,
FeesPaid: 10, FeesPaid: 10,
Terminate: false, HasSettledHTLC: false,
PaymentFailed: false,
}, },
}, },
{ {
@ -193,7 +180,8 @@ func TestPaymentSetState(t *testing.T) {
NumAttemptsInFlight: 0, NumAttemptsInFlight: 0,
RemainingAmt: 1000 - 90, RemainingAmt: 1000 - 90,
FeesPaid: 10, FeesPaid: 10,
Terminate: true, HasSettledHTLC: true,
PaymentFailed: false,
}, },
}, },
{ {
@ -211,13 +199,15 @@ func TestPaymentSetState(t *testing.T) {
NumAttemptsInFlight: 0, NumAttemptsInFlight: 0,
RemainingAmt: 1000, RemainingAmt: 1000,
FeesPaid: 0, FeesPaid: 0,
Terminate: true, HasSettledHTLC: false,
PaymentFailed: true,
}, },
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()