sweep: add method markInputFailed

This commit is contained in:
yyforyongyu 2024-02-28 03:10:49 +08:00
parent a8f5a09dea
commit 34b6a3d718
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
2 changed files with 43 additions and 1 deletions

View file

@ -1421,7 +1421,7 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) {
)
if err != nil {
err := fmt.Errorf("wait for spend: %w", err)
s.signalResult(pi, Result{Err: err})
s.markInputFailed(pi, err)
return
}
@ -1624,6 +1624,21 @@ func (s *UtxoSweeper) markInputsSwept(tx *wire.MsgTx, isOurTx bool) error {
return nil
}
// markInputFailed marks the given input as failed and won't be retried. It
// will also notify all the subscribers of this input.
func (s *UtxoSweeper) markInputFailed(pi *pendingInput, err error) {
log.Errorf("Failed to sweep input: %v, error: %v", pi, err)
pi.state = StateFailed
// Remove all other inputs in this exclusive group.
if pi.params.ExclusiveGroup != nil {
s.removeExclusiveGroup(*pi.params.ExclusiveGroup)
}
s.signalResult(pi, Result{Err: err})
}
// updateSweeperInputs updates the sweeper's internal state and returns a map
// of inputs to be swept. It will remove the inputs that are in final states,
// and returns a map of inputs that have either StateInit or

View file

@ -2369,3 +2369,30 @@ func TestAttachAvailableRBFInfo(t *testing.T) {
mockMempool.AssertExpectations(t)
mockStore.AssertExpectations(t)
}
// TestMarkInputFailed checks that the input is marked as failed as expected.
func TestMarkInputFailed(t *testing.T) {
t.Parallel()
// Create a mock input.
mockInput := &input.MockInput{}
defer mockInput.AssertExpectations(t)
// Mock the `OutPoint` to return a dummy outpoint.
mockInput.On("OutPoint").Return(&wire.OutPoint{Hash: chainhash.Hash{1}})
// Create a test sweeper.
s := New(&UtxoSweeperConfig{})
// Create a testing pending input.
pi := &pendingInput{
state: StateInit,
Input: mockInput,
}
// Call the method under test.
s.markInputFailed(pi, errors.New("dummy error"))
// Assert the state is updated.
require.Equal(t, StateFailed, pi.state)
}