sweep: introduce Bumper interface to handle RBF

This commit adds a new interface, `Bumper`, to handle RBF for a given
input set. It's responsible for creating the sweeping tx using the input
set, and monitors its confirmation status to decide whether a RBF should
be attempted or not.

We leave implementation details to future commits, and focus on mounting
this `Bumper` interface to our sweeper in this commit.
This commit is contained in:
yyforyongyu 2024-01-17 17:21:09 +08:00
parent a088501e47
commit 1187b868ad
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
5 changed files with 812 additions and 51 deletions

142
sweep/fee_bumper.go Normal file
View file

@ -0,0 +1,142 @@
package sweep
import (
"errors"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
// ErrInvalidBumpResult is returned when the bump result is invalid.
ErrInvalidBumpResult = errors.New("invalid bump result")
)
// Bumper defines an interface that can be used by other subsystems for fee
// bumping.
type Bumper interface {
// Broadcast is used to publish the tx created from the given inputs
// specified in the request. It handles the tx creation, broadcasts it,
// and monitors its confirmation status for potential fee bumping. It
// returns a chan that the caller can use to receive updates about the
// broadcast result and potential RBF attempts.
Broadcast(req *BumpRequest) (<-chan *BumpResult, error)
}
// BumpEvent represents the event of a fee bumping attempt.
type BumpEvent uint8
const (
// TxPublished is sent when the broadcast attempt is finished.
TxPublished BumpEvent = iota
// TxFailed is sent when the broadcast attempt fails.
TxFailed
// TxReplaced is sent when the original tx is replaced by a new one.
TxReplaced
// TxConfirmed is sent when the tx is confirmed.
TxConfirmed
// sentinalEvent is used to check if an event is unknown.
sentinalEvent
)
// String returns a human-readable string for the event.
func (e BumpEvent) String() string {
switch e {
case TxPublished:
return "Published"
case TxFailed:
return "Failed"
case TxReplaced:
return "Replaced"
case TxConfirmed:
return "Confirmed"
default:
return "Unknown"
}
}
// Unknown returns true if the event is unknown.
func (e BumpEvent) Unknown() bool {
return e >= sentinalEvent
}
// BumpRequest is used by the caller to give the Bumper the necessary info to
// create and manage potential fee bumps for a set of inputs.
type BumpRequest struct {
// Budget givens the total amount that can be used as fees by these
// inputs.
Budget btcutil.Amount
// Inputs is the set of inputs to sweep.
Inputs []input.Input
// DeadlineHeight is the block height at which the tx should be
// confirmed.
DeadlineHeight int32
// DeliveryAddress is the script to send the change output to.
DeliveryAddress []byte
// MaxFeeRate is the maximum fee rate that can be used for fee bumping.
MaxFeeRate chainfee.SatPerKWeight
}
// BumpResult is used by the Bumper to send updates about the tx being
// broadcast.
type BumpResult struct {
// Event is the type of event that the result is for.
Event BumpEvent
// Tx is the tx being broadcast.
Tx *wire.MsgTx
// ReplacedTx is the old, replaced tx if a fee bump is attempted.
ReplacedTx *wire.MsgTx
// FeeRate is the fee rate used for the new tx.
FeeRate chainfee.SatPerKWeight
// Fee is the fee paid by the new tx.
Fee btcutil.Amount
// Err is the error that occurred during the broadcast.
Err error
}
// Validate validates the BumpResult so it's safe to use.
func (b *BumpResult) Validate() error {
// Every result must have a tx.
if b.Tx == nil {
return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult)
}
// Every result must have a known event.
if b.Event.Unknown() {
return fmt.Errorf("%w: unknown event", ErrInvalidBumpResult)
}
// If it's a replacing event, it must have a replaced tx.
if b.Event == TxReplaced && b.ReplacedTx == nil {
return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult)
}
// If it's a failed event, it must have an error.
if b.Event == TxFailed && b.Err == nil {
return fmt.Errorf("%w: nil error", ErrInvalidBumpResult)
}
// If it's a confirmed event, it must have a fee rate and fee.
if b.Event == TxConfirmed && (b.FeeRate == 0 || b.Fee == 0) {
return fmt.Errorf("%w: missing fee rate or fee",
ErrInvalidBumpResult)
}
return nil
}

52
sweep/fee_bumper_test.go Normal file
View file

@ -0,0 +1,52 @@
package sweep
import (
"testing"
"github.com/btcsuite/btcd/wire"
"github.com/stretchr/testify/require"
)
// TestBumpResultValidate tests the validate method of the BumpResult struct.
func TestBumpResultValidate(t *testing.T) {
t.Parallel()
// An empty result will give an error.
b := BumpResult{}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
// Unknown event type will give an error.
b = BumpResult{
Tx: &wire.MsgTx{},
Event: sentinalEvent,
}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
// A replacing event without a new tx will give an error.
b = BumpResult{
Tx: &wire.MsgTx{},
Event: TxReplaced,
}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
// A failed event without a failure reason will give an error.
b = BumpResult{
Tx: &wire.MsgTx{},
Event: TxFailed,
}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
// A confirmed event without fee info will give an error.
b = BumpResult{
Tx: &wire.MsgTx{},
Event: TxConfirmed,
}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
// Test a valid result.
b = BumpResult{
Tx: &wire.MsgTx{},
Event: TxPublished,
}
require.NoError(t, b.Validate())
}

View file

@ -462,3 +462,22 @@ func (m *MockInputSet) Budget() btcutil.Amount {
return args.Get(0).(btcutil.Amount)
}
// MockBumper is a mock implementation of the interface Bumper.
type MockBumper struct {
mock.Mock
}
// Compile-time constraint to ensure MockBumper implements Bumper.
var _ Bumper = (*MockBumper)(nil)
// Broadcast broadcasts the transaction to the network.
func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
args := m.Called(req)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(chan *BumpResult), args.Error(1)
}

View file

@ -13,7 +13,6 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
@ -41,6 +40,12 @@ var (
// an input is included in a publish attempt before giving up and
// returning an error to the caller.
DefaultMaxSweepAttempts = 10
// DefaultDeadlineDelta defines a default deadline delta (1 week) to be
// used when sweeping inputs with no deadline pressure.
//
// TODO(yy): make this configurable.
DefaultDeadlineDelta = int32(1008)
)
// Params contains the parameters that control the sweeping process.
@ -317,6 +322,10 @@ type UtxoSweeper struct {
// currentHeight is the best known height of the main chain. This is
// updated whenever a new block epoch is received.
currentHeight int32
// bumpResultChan is a channel that receives broadcast results from the
// TxPublisher.
bumpResultChan chan *BumpResult
}
// UtxoSweeperConfig contains dependencies of UtxoSweeper.
@ -364,6 +373,10 @@ type UtxoSweeperConfig struct {
// Aggregator is used to group inputs into clusters based on its
// implemention-specific strategy.
Aggregator UtxoAggregator
// Publisher is used to publish the sweep tx crafted here and monitors
// it for potential fee bumps.
Publisher Bumper
}
// Result is the struct that is pushed through the result channel. Callers can
@ -397,6 +410,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper {
pendingSweepsReqs: make(chan *pendingSweepsReq),
quit: make(chan struct{}),
pendingInputs: make(pendingInputs),
bumpResultChan: make(chan *BumpResult, 100),
}
}
@ -670,11 +684,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
err: err,
}
// A new block comes in, update the bestHeight.
//
// TODO(yy): this is where we check our published transactions
// and perform RBF if needed. We'd also like to consult our fee
// bumper to get an updated fee rate.
case result := <-s.bumpResultChan:
// Handle the bump event.
err := s.handleBumpEvent(result)
if err != nil {
log.Errorf("Failed to handle bump event: %v",
err)
}
// A new block comes in, update the bestHeight, perform a check
// over all pending inputs and publish sweeping txns if needed.
case epoch, ok := <-blockEpochs:
if !ok {
// We should stop the sweeper before stopping
@ -779,8 +798,8 @@ func (s *UtxoSweeper) signalResult(pi *pendingInput, result Result) {
}
}
// sweep takes a set of preselected inputs, creates a sweep tx and publishes the
// tx. The output address is only marked as used if the publish succeeds.
// sweep takes a set of preselected inputs, creates a sweep tx and publishes
// the tx. The output address is only marked as used if the publish succeeds.
func (s *UtxoSweeper) sweep(set InputSet) error {
// Generate an output script if there isn't an unused script available.
if s.currentOutputScript == nil {
@ -791,20 +810,21 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
s.currentOutputScript = pkScript
}
// Create sweep tx.
tx, fee, err := createSweepTx(
set.Inputs(), nil, s.currentOutputScript,
uint32(s.currentHeight), set.FeeRate(),
s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer,
)
if err != nil {
return fmt.Errorf("create sweep tx: %w", err)
}
// Create a default deadline height, and replace it with set's
// DeadlineHeight if it's set.
deadlineHeight := s.currentHeight + DefaultDeadlineDelta
deadlineHeight = set.DeadlineHeight().UnwrapOr(deadlineHeight)
tr := &TxRecord{
Txid: tx.TxHash(),
FeeRate: uint64(set.FeeRate()),
Fee: uint64(fee),
// Create a fee bump request and ask the publisher to broadcast it. The
// publisher will then take over and start monitoring the tx for
// potential fee bump.
req := &BumpRequest{
Inputs: set.Inputs(),
Budget: set.Budget(),
DeadlineHeight: deadlineHeight,
DeliveryAddress: s.currentOutputScript,
MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(),
// TODO(yy): pass the strategy here.
}
// Reschedule the inputs that we just tried to sweep. This is done in
@ -812,13 +832,9 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
// publish attempts and rescue them in the next sweep.
s.markInputsPendingPublish(set)
log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v",
tx.TxHash(), len(tx.TxIn), s.currentHeight)
// Publish the sweeping tx with customized label.
err = s.cfg.Wallet.PublishTransaction(
tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil),
)
// Broadcast will return a read-only chan that we will listen to for
// this publish result and future RBF attempt.
resp, err := s.cfg.Publisher.Broadcast(req)
if err != nil {
outpoints := make([]wire.OutPoint, len(set.Inputs()))
for i, inp := range set.Inputs() {
@ -831,16 +847,11 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
return err
}
// Inputs have been successfully published so we update their states.
err = s.markInputsPublished(tr, tx.TxIn)
if err != nil {
return err
}
// If there's no error, remove the output script. Otherwise keep it so
// that it can be reused for the next transaction and causes no address
// inflation.
s.currentOutputScript = nil
// Successfully sent the broadcast attempt, we now handle the result by
// subscribing to the result chan and listen for future updates about
// this tx.
s.wg.Add(1)
go s.monitorFeeBumpResult(resp)
return nil
}
@ -1557,3 +1568,167 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs pendingInputs) {
}
}
}
// monitorFeeBumpResult subscribes to the passed result chan to listen for
// future updates about the sweeping tx.
//
// NOTE: must run as a goroutine.
func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) {
defer s.wg.Done()
for {
select {
case r := <-resultChan:
// Validate the result is valid.
if err := r.Validate(); err != nil {
log.Errorf("Received invalid result: %v", err)
continue
}
// Send the result back to the main event loop.
select {
case s.bumpResultChan <- r:
case <-s.quit:
log.Debug("Sweeper shutting down, skip " +
"sending bump result")
return
}
// The sweeping tx has been confirmed, we can exit the
// monitor now.
//
// TODO(yy): can instead remove the spend subscription
// in sweeper and rely solely on this event to mark
// inputs as Swept?
if r.Event == TxConfirmed || r.Event == TxFailed {
log.Debugf("Received %v for sweep tx %v, exit "+
"fee bump monitor", r.Event,
r.Tx.TxHash())
return
}
case <-s.quit:
log.Debugf("Sweeper shutting down, exit fee " +
"bump handler")
return
}
}
}
// handleBumpEventTxFailed handles the case where the tx has been failed to
// publish.
func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error {
tx, err := r.Tx, r.Err
log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err)
outpoints := make([]wire.OutPoint, 0, len(tx.TxIn))
for _, inp := range tx.TxIn {
outpoints = append(outpoints, inp.PreviousOutPoint)
}
// TODO(yy): should we also remove the failed tx from db?
s.markInputsPublishFailed(outpoints)
return err
}
// handleBumpEventTxReplaced handles the case where the sweeping tx has been
// replaced by a new one.
func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error {
oldTx := r.ReplacedTx
newTx := r.Tx
// Prepare a new record to replace the old one.
tr := &TxRecord{
Txid: newTx.TxHash(),
FeeRate: uint64(r.FeeRate),
Fee: uint64(r.Fee),
}
// Get the old record for logging purpose.
oldTxid := oldTx.TxHash()
record, err := s.cfg.Store.GetTx(oldTxid)
if err != nil {
log.Errorf("Fetch tx record for %v: %v", oldTxid, err)
return err
}
log.Infof("RBFed tx=%v(fee=%v, feerate=%v) with new tx=%v(fee=%v, "+
"feerate=%v)", record.Txid, record.Fee, record.FeeRate,
tr.Txid, tr.Fee, tr.FeeRate)
// The old sweeping tx has been replaced by a new one, we will update
// the tx record in the sweeper db.
//
// TODO(yy): we may also need to update the inputs in this tx to a new
// state. Suppose a replacing tx only spends a subset of the inputs
// here, we'd end up with the rest being marked as `StatePublished` and
// won't be aggregated in the next sweep. Atm it's fine as we always
// RBF the same input set.
if err := s.cfg.Store.DeleteTx(oldTxid); err != nil {
log.Errorf("Delete tx record for %v: %v", oldTxid, err)
return err
}
// Mark the inputs as published using the replacing tx.
return s.markInputsPublished(tr, r.Tx.TxIn)
}
// handleBumpEventTxPublished handles the case where the sweeping tx has been
// successfully published.
func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error {
tx := r.Tx
tr := &TxRecord{
Txid: tx.TxHash(),
FeeRate: uint64(r.FeeRate),
Fee: uint64(r.Fee),
}
// Inputs have been successfully published so we update their
// states.
err := s.markInputsPublished(tr, tx.TxIn)
if err != nil {
return err
}
log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v",
tx.TxHash(), len(tx.TxIn), s.currentHeight)
// If there's no error, remove the output script. Otherwise
// keep it so that it can be reused for the next transaction
// and causes no address inflation.
s.currentOutputScript = nil
return nil
}
// handleBumpEvent handles the result sent from the bumper based on its event
// type.
//
// NOTE: TxConfirmed event is not handled, since we already subscribe to the
// input's spending event, we don't need to do anything here.
func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error {
log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash())
switch r.Event {
// The tx has been published, we update the inputs' state and create a
// record to be stored in the sweeper db.
case TxPublished:
return s.handleBumpEventTxPublished(r)
// The tx has failed, we update the inputs' state.
case TxFailed:
return s.handleBumpEventTxFailed(r)
// The tx has been replaced, we will remove the old tx and replace it
// with the new one.
case TxReplaced:
return s.handleBumpEventTxReplaced(r)
}
return nil
}

View file

@ -33,6 +33,8 @@ var (
testMaxInputsPerTx = uint32(3)
defaultFeePref = Params{Fee: FeeEstimateInfo{ConfTarget: 1}}
errDummy = errors.New("dummy error")
)
type sweeperTestContext struct {
@ -137,6 +139,12 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
currentHeight: mockChainHeight,
}
// Create a mock fee bumper.
mockBumper := &MockBumper{}
t.Cleanup(func() {
mockBumper.AssertExpectations(t)
})
ctx.sweeper = New(&UtxoSweeperConfig{
Notifier: notifier,
Wallet: backend,
@ -153,6 +161,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
MaxSweepAttempts: testMaxSweepAttempts,
MaxFeeRate: DefaultMaxFeeRate,
Aggregator: aggregator,
Publisher: mockBumper,
})
ctx.sweeper.Start()
@ -2410,16 +2419,27 @@ func TestSweepPendingInputs(t *testing.T) {
// Create a mock wallet and aggregator.
wallet := &MockWallet{}
defer wallet.AssertExpectations(t)
aggregator := &mockUtxoAggregator{}
defer aggregator.AssertExpectations(t)
publisher := &MockBumper{}
defer publisher.AssertExpectations(t)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Wallet: wallet,
Aggregator: aggregator,
Publisher: publisher,
GenSweepScript: func() ([]byte, error) {
return testPubKey.SerializeCompressed(), nil
},
})
// Create an input set that needs wallet inputs.
setNeedWallet := &MockInputSet{}
defer setNeedWallet.AssertExpectations(t)
// Mock this set to ask for wallet input.
setNeedWallet.On("NeedWalletInput").Return(true).Once()
@ -2430,15 +2450,18 @@ func TestSweepPendingInputs(t *testing.T) {
// Create an input set that doesn't need wallet inputs.
normalSet := &MockInputSet{}
defer normalSet.AssertExpectations(t)
normalSet.On("NeedWalletInput").Return(false).Once()
// Mock the methods used in `sweep`. This is not important for this
// unit test.
feeRate := chainfee.SatPerKWeight(1000)
setNeedWallet.On("Inputs").Return(nil).Once()
setNeedWallet.On("FeeRate").Return(feeRate).Once()
normalSet.On("Inputs").Return(nil).Once()
normalSet.On("FeeRate").Return(feeRate).Once()
setNeedWallet.On("Inputs").Return(nil).Times(4)
setNeedWallet.On("DeadlineHeight").Return(fn.None[int32]()).Once()
setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once()
normalSet.On("Inputs").Return(nil).Times(4)
normalSet.On("DeadlineHeight").Return(fn.None[int32]()).Once()
normalSet.On("Budget").Return(btcutil.Amount(1)).Once()
// Make pending inputs for testing. We don't need real values here as
// the returned clusters are mocked.
@ -2449,19 +2472,369 @@ func TestSweepPendingInputs(t *testing.T) {
setNeedWallet, normalSet,
})
// Set change output script to an invalid value. This should cause the
// Mock `Broadcast` to return an error. This should cause the
// `createSweepTx` inside `sweep` to fail. This is done so we can
// terminate the method early as we are only interested in testing the
// workflow in `sweepPendingInputs`. We don't need to test `sweep` here
// as it should be tested in its own unit test.
s.currentOutputScript = []byte{1}
dummyErr := errors.New("dummy error")
publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice()
// Call the method under test.
s.sweepPendingInputs(pis)
// Assert mocked methods are called as expected.
wallet.AssertExpectations(t)
aggregator.AssertExpectations(t)
setNeedWallet.AssertExpectations(t)
normalSet.AssertExpectations(t)
}
// TestHandleBumpEventTxFailed checks that the sweeper correctly handles the
// case where the bump event tx fails to be published.
func TestHandleBumpEventTxFailed(t *testing.T) {
t.Parallel()
// Create a test sweeper.
s := New(&UtxoSweeperConfig{})
var (
// Create four testing outpoints.
op1 = wire.OutPoint{Hash: chainhash.Hash{1}}
op2 = wire.OutPoint{Hash: chainhash.Hash{2}}
op3 = wire.OutPoint{Hash: chainhash.Hash{3}}
opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}}
)
// Create three mock inputs.
input1 := &input.MockInput{}
defer input1.AssertExpectations(t)
input2 := &input.MockInput{}
defer input2.AssertExpectations(t)
input3 := &input.MockInput{}
defer input3.AssertExpectations(t)
// Construct the initial state for the sweeper.
s.pendingInputs = pendingInputs{
op1: &pendingInput{Input: input1, state: StatePendingPublish},
op2: &pendingInput{Input: input2, state: StatePendingPublish},
op3: &pendingInput{Input: input3, state: StatePendingPublish},
}
// Create a testing tx that spends the first two inputs.
tx := &wire.MsgTx{
TxIn: []*wire.TxIn{
{PreviousOutPoint: op1},
{PreviousOutPoint: op2},
{PreviousOutPoint: opNotExist},
},
}
// Create a testing bump result.
br := &BumpResult{
Tx: tx,
Event: TxFailed,
Err: errDummy,
}
// Call the method under test.
err := s.handleBumpEvent(br)
require.ErrorIs(t, err, errDummy)
// Assert the states of the first two inputs are updated.
require.Equal(t, StatePublishFailed, s.pendingInputs[op1].state)
require.Equal(t, StatePublishFailed, s.pendingInputs[op2].state)
// Assert the state of the third input is not updated.
require.Equal(t, StatePendingPublish, s.pendingInputs[op3].state)
// Assert the non-existing input is not added to the pending inputs.
require.NotContains(t, s.pendingInputs, opNotExist)
}
// TestHandleBumpEventTxReplaced checks that the sweeper correctly handles the
// case where the bump event tx is replaced.
func TestHandleBumpEventTxReplaced(t *testing.T) {
t.Parallel()
// Create a mock store.
store := &MockSweeperStore{}
defer store.AssertExpectations(t)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
})
// Create a testing outpoint.
op := wire.OutPoint{Hash: chainhash.Hash{1}}
// Create a mock input.
inp := &input.MockInput{}
defer inp.AssertExpectations(t)
// Construct the initial state for the sweeper.
s.pendingInputs = pendingInputs{
op: &pendingInput{Input: inp, state: StatePendingPublish},
}
// Create a testing tx that spends the input.
tx := &wire.MsgTx{
LockTime: 1,
TxIn: []*wire.TxIn{
{PreviousOutPoint: op},
},
}
// Create a replacement tx.
replacementTx := &wire.MsgTx{
LockTime: 2,
TxIn: []*wire.TxIn{
{PreviousOutPoint: op},
},
}
// Create a testing bump result.
br := &BumpResult{
Tx: replacementTx,
ReplacedTx: tx,
Event: TxReplaced,
}
// Mock the store to return an error.
dummyErr := errors.New("dummy error")
store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once()
// Call the method under test and assert the error is returned.
err := s.handleBumpEventTxReplaced(br)
require.ErrorIs(t, err, dummyErr)
// Mock the store to return the old tx record.
store.On("GetTx", tx.TxHash()).Return(&TxRecord{
Txid: tx.TxHash(),
}, nil).Once()
// Mock an error returned when deleting the old tx record.
store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once()
// Call the method under test and assert the error is returned.
err = s.handleBumpEventTxReplaced(br)
require.ErrorIs(t, err, dummyErr)
// Mock the store to return the old tx record and delete it without
// error.
store.On("GetTx", tx.TxHash()).Return(&TxRecord{
Txid: tx.TxHash(),
}, nil).Once()
store.On("DeleteTx", tx.TxHash()).Return(nil).Once()
// Mock the store to save the new tx record.
store.On("StoreTx", &TxRecord{
Txid: replacementTx.TxHash(),
Published: true,
}).Return(nil).Once()
// Call the method under test.
err = s.handleBumpEventTxReplaced(br)
require.NoError(t, err)
// Assert the state of the input is updated.
require.Equal(t, StatePublished, s.pendingInputs[op].state)
}
// TestHandleBumpEventTxPublished checks that the sweeper correctly handles the
// case where the bump event tx is published.
func TestHandleBumpEventTxPublished(t *testing.T) {
t.Parallel()
// Create a mock store.
store := &MockSweeperStore{}
defer store.AssertExpectations(t)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
})
// Create a testing outpoint.
op := wire.OutPoint{Hash: chainhash.Hash{1}}
// Create a mock input.
inp := &input.MockInput{}
defer inp.AssertExpectations(t)
// Construct the initial state for the sweeper.
s.pendingInputs = pendingInputs{
op: &pendingInput{Input: inp, state: StatePendingPublish},
}
// Create a testing tx that spends the input.
tx := &wire.MsgTx{
LockTime: 1,
TxIn: []*wire.TxIn{
{PreviousOutPoint: op},
},
}
// Create a testing bump result.
br := &BumpResult{
Tx: tx,
Event: TxPublished,
}
// Mock the store to save the new tx record.
store.On("StoreTx", &TxRecord{
Txid: tx.TxHash(),
Published: true,
}).Return(nil).Once()
// Call the method under test.
err := s.handleBumpEventTxPublished(br)
require.NoError(t, err)
// Assert the state of the input is updated.
require.Equal(t, StatePublished, s.pendingInputs[op].state)
}
// TestMonitorFeeBumpResult checks that the fee bump monitor loop correctly
// exits when the sweeper is stopped, the tx is confirmed or failed.
func TestMonitorFeeBumpResult(t *testing.T) {
// Create a mock store.
store := &MockSweeperStore{}
defer store.AssertExpectations(t)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
})
// Create a testing outpoint.
op := wire.OutPoint{Hash: chainhash.Hash{1}}
// Create a mock input.
inp := &input.MockInput{}
defer inp.AssertExpectations(t)
// Construct the initial state for the sweeper.
s.pendingInputs = pendingInputs{
op: &pendingInput{Input: inp, state: StatePendingPublish},
}
// Create a testing tx that spends the input.
tx := &wire.MsgTx{
LockTime: 1,
TxIn: []*wire.TxIn{
{PreviousOutPoint: op},
},
}
testCases := []struct {
name string
setupResultChan func() <-chan *BumpResult
shouldExit bool
}{
{
// When a tx confirmed event is received, we expect to
// exit the monitor loop.
name: "tx confirmed",
// We send a result with TxConfirmed event to the
// result channel.
setupResultChan: func() <-chan *BumpResult {
// Create a result chan.
resultChan := make(chan *BumpResult, 1)
resultChan <- &BumpResult{
Tx: tx,
Event: TxConfirmed,
Fee: 10000,
FeeRate: 100,
}
return resultChan
},
shouldExit: true,
},
{
// When a tx failed event is received, we expect to
// exit the monitor loop.
name: "tx failed",
// We send a result with TxConfirmed event to the
// result channel.
setupResultChan: func() <-chan *BumpResult {
// Create a result chan.
resultChan := make(chan *BumpResult, 1)
resultChan <- &BumpResult{
Tx: tx,
Event: TxFailed,
Err: errDummy,
}
return resultChan
},
shouldExit: true,
},
{
// When processing non-confirmed events, the monitor
// should not exit.
name: "no exit on normal event",
// We send a result with TxPublished and mock the
// method `StoreTx` to return nil.
setupResultChan: func() <-chan *BumpResult {
// Create a result chan.
resultChan := make(chan *BumpResult, 1)
resultChan <- &BumpResult{
Tx: tx,
Event: TxPublished,
}
return resultChan
},
shouldExit: false,
}, {
// When the sweeper is shutting down, the monitor loop
// should exit.
name: "exit on sweeper shutdown",
// We don't send anything but quit the sweeper.
setupResultChan: func() <-chan *BumpResult {
close(s.quit)
return nil
},
shouldExit: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Setup the testing result channel.
resultChan := tc.setupResultChan()
// Create a done chan that's used to signal the monitor
// has exited.
done := make(chan struct{})
s.wg.Add(1)
go func() {
s.monitorFeeBumpResult(resultChan)
close(done)
}()
// The monitor is expected to exit, we check it's done
// in one second or fail.
if tc.shouldExit {
select {
case <-done:
case <-time.After(1 * time.Second):
require.Fail(t, "monitor not exited")
}
return
}
// The monitor should not exit, check it doesn't close
// the `done` channel within one second.
select {
case <-done:
require.Fail(t, "monitor exited")
case <-time.After(1 * time.Second):
}
})
}
}