lnd/sweep/mock_test.go
Olaoluwa Osuntokun cb93f8c01a
sweep: add new AuxSweeper interface
In this commit, we add a new AuxSweeper interface. This'll take a set of
inputs, and a change addr for the sweep transaction, then optionally
return a new sweep output to be added to the sweep transaction.

We also add a new NotifyBroadcast method.  This'll be used to notify
that we're _about_ to broadcast a sweeping transaction. The set of
inputs is passed in, which allows the caller to prepare for the ultimate
broadcast of the sweeping transaction.

We also add ExtraTxOut to BumpRequest, and pass fees to NotifyBroadcast. This
allows the callee to know the total fee of the sweeping transaction.
2024-10-02 18:09:57 -07:00

336 lines
9.3 KiB
Go

package sweep
import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
)
// MockSweeperStore is a mock implementation of the sweeper store, built on
// testify's mock.Mock: tests register expectations on it and its methods
// return whatever was configured for the call. This type is exported, because
// it is currently used in nursery tests too.
type MockSweeperStore struct {
mock.Mock
}
// NewMockSweeperStore creates an empty MockSweeperStore with no expectations
// registered on it.
func NewMockSweeperStore() *MockSweeperStore {
	return new(MockSweeperStore)
}
// IsOurTx reports whether the tx identified by hash was published by us. The
// answer and the error are the values the test configured for this call.
func (s *MockSweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
	ret := s.Called(hash)

	ours := ret.Bool(0)
	err := ret.Error(1)

	return ours, err
}
// StoreTx records the call for a tx we are about to publish and returns the
// error the test configured for it.
func (s *MockSweeperStore) StoreTx(tr *TxRecord) error {
	return s.Called(tr).Error(0)
}
// ListSweeps lists all the sweeps we have successfully published.
//
// The hashes and the error are whatever the test registered for this call. A
// nil first return value is passed through as a nil slice rather than being
// type-asserted, so a test can mock a failure with Return(nil, err) without
// triggering a panic.
func (s *MockSweeperStore) ListSweeps() ([]chainhash.Hash, error) {
	args := s.Called()

	// An untyped nil cannot be asserted to []chainhash.Hash, so handle it
	// explicitly, the same way GetTx does for its record.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).([]chainhash.Hash), args.Error(1)
}
// GetTx looks up the tx record matching the given txid. The real store
// returns ErrTxNotFound when there is no match; this mock returns whatever
// record and error the test configured, passing a nil record through as nil.
func (s *MockSweeperStore) GetTx(hash chainhash.Hash) (*TxRecord, error) {
	ret := s.Called(hash)

	// Guard first: a nil record must not reach the type assertion.
	if ret.Get(0) == nil {
		return nil, ret.Error(1)
	}

	return ret.Get(0).(*TxRecord), ret.Error(1)
}
// DeleteTx records a request to remove the given tx from the db and returns
// the error the test configured for it.
func (s *MockSweeperStore) DeleteTx(txid chainhash.Hash) error {
	return s.Called(txid).Error(0)
}
// Compile-time constraint to ensure *MockSweeperStore implements the
// SweeperStore interface; the build fails if a method is missing.
var _ SweeperStore = (*MockSweeperStore)(nil)
// MockFeePreference is a mock implementation of the FeePreference interface,
// built on testify's mock.Mock.
type MockFeePreference struct {
mock.Mock
}
// Compile-time constraint to ensure MockFeePreference implements FeePreference.
var _ FeePreference = (*MockFeePreference)(nil)
// String returns a fixed, human-readable description of the mock. It does not
// go through the mock's expectation machinery.
func (m *MockFeePreference) String() string {
	const desc = "mock fee preference"

	return desc
}
// Estimate returns the fee rate and error the test configured for the given
// estimator and max fee rate. A nil configured rate is reported as zero.
func (m *MockFeePreference) Estimate(estimator chainfee.Estimator,
	maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) {

	ret := m.Called(estimator, maxFeeRate)

	if rate := ret.Get(0); rate != nil {
		return rate.(chainfee.SatPerKWeight), ret.Error(1)
	}

	return 0, ret.Error(1)
}
// mockUtxoAggregator is a mock implementation of the UtxoAggregator
// interface, built on testify's mock.Mock.
type mockUtxoAggregator struct {
mock.Mock
}
// Compile-time constraint to ensure mockUtxoAggregator implements
// UtxoAggregator.
var _ UtxoAggregator = (*mockUtxoAggregator)(nil)
// ClusterInputs takes a list of inputs and groups them into clusters.
//
// The returned sets are whatever the test registered for this call. A nil
// return value is passed through as a nil slice rather than being
// type-asserted, so a test can use Return(nil) to mock "no clusters" without
// triggering a panic.
func (m *mockUtxoAggregator) ClusterInputs(inputs InputsMap) []InputSet {
	args := m.Called(inputs)

	// An untyped nil cannot be asserted to []InputSet, so handle it
	// explicitly.
	if args.Get(0) == nil {
		return nil
	}

	return args.Get(0).([]InputSet)
}
// MockWallet is a mock implementation of the Wallet interface, built on
// testify's mock.Mock.
type MockWallet struct {
mock.Mock
}
// Compile-time constraint to ensure MockWallet implements Wallet.
var _ Wallet = (*MockWallet)(nil)
// BackEnd returns the name of the wallet's backing chain service (e.g. btcd,
// bitcoind, neutrino, or another consensus service). The mock hands back the
// string the test configured for this call.
func (m *MockWallet) BackEnd() string {
	return m.Called().String(0)
}
// CheckMempoolAcceptance checks whether the transaction can be accepted to
// the mempool. The mock returns the error the test configured for tx.
func (m *MockWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error {
	return m.Called(tx).Error(0)
}
// PublishTransaction stands in for the wallet call that performs cursory
// validation (dust checks, etc) and broadcasts the passed transaction to the
// Bitcoin network. The mock only records tx and label and returns the error
// the test configured.
func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error {
	return m.Called(tx, label).Error(0)
}
// ListUnspentWitnessFromDefaultAccount stands in for the wallet call that
// returns all unspent version 0 witness program outputs from the default
// wallet account, where 'minConfs' and 'maxConfs' bound the number of
// confirmations an output needs in order to be returned. The mock returns the
// utxos and error the test configured, passing a nil list through as nil.
func (m *MockWallet) ListUnspentWitnessFromDefaultAccount(
	minConfs, maxConfs int32) ([]*lnwallet.Utxo, error) {

	ret := m.Called(minConfs, maxConfs)

	if utxos := ret.Get(0); utxos != nil {
		return utxos.([]*lnwallet.Utxo), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// WithCoinSelectLock stands in for the wallet call that executes the passed
// closure under an exclusive coin selection lock, preventing any coin
// selection operations from proceeding while the closure runs. The mock
// records the call, then invokes f directly and returns its error; any return
// values configured for the call are discarded.
func (m *MockWallet) WithCoinSelectLock(f func() error) error {
	// Only register the invocation so expectations can be asserted.
	_ = m.Called(f)

	err := f()

	return err
}
// RemoveDescendants stands in for the wallet call that removes any wallet
// transactions spending outputs created by the specified transaction. The
// mock returns the error the test configured for tx.
func (m *MockWallet) RemoveDescendants(tx *wire.MsgTx) error {
	return m.Called(tx).Error(0)
}
// FetchTx returns the transaction that corresponds to the passed transaction
// hash. The mock returns the tx and error the test configured; when the
// configured tx is nil, a nil transaction pointer is returned.
func (m *MockWallet) FetchTx(txid chainhash.Hash) (*wire.MsgTx, error) {
	ret := m.Called(txid)

	if tx := ret.Get(0); tx != nil {
		return tx.(*wire.MsgTx), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// CancelRebroadcast stands in for the call that tells the rebroadcaster
// sub-system it no longer needs to try to rebroadcast a transaction, which
// ensures invalid transactions (inputs spent) aren't retried in the
// background. The mock only records the call.
func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) {
	_ = m.Called(tx)
}
// GetTransactionDetails returns a detailed description of a tx given its
// transaction hash. The mock returns the detail and error the test
// configured, passing a nil detail through as nil.
func (m *MockWallet) GetTransactionDetails(txHash *chainhash.Hash) (
	*lnwallet.TransactionDetail, error) {

	ret := m.Called(txHash)

	if detail := ret.Get(0); detail != nil {
		return detail.(*lnwallet.TransactionDetail), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// MockInputSet is a mock implementation of the InputSet interface, built on
// testify's mock.Mock.
type MockInputSet struct {
mock.Mock
}
// Compile-time constraint to ensure MockInputSet implements InputSet.
var _ InputSet = (*MockInputSet)(nil)
// Inputs returns the set of inputs that should be used to create a tx. The
// mock hands back the slice the test configured, or nil when none was set.
func (m *MockInputSet) Inputs() []input.Input {
	ret := m.Called()

	if ins := ret.Get(0); ins != nil {
		return ins.([]input.Input)
	}

	return nil
}
// FeeRate returns the fee rate that should be used for the tx. The configured
// return value must be a chainfee.SatPerKWeight.
func (m *MockInputSet) FeeRate() chainfee.SatPerKWeight {
	rate := m.Called().Get(0)

	return rate.(chainfee.SatPerKWeight)
}
// AddWalletInputs stands in for the call that adds wallet inputs to the set
// until a non-dust change output can be made, failing when there are not
// enough wallet inputs. The mock returns the error the test configured.
func (m *MockInputSet) AddWalletInputs(wallet Wallet) error {
	return m.Called(wallet).Error(0)
}
// NeedWalletInput reports whether the input set needs more wallet inputs, as
// configured by the test.
func (m *MockInputSet) NeedWalletInput() bool {
	return m.Called().Bool(0)
}
// DeadlineHeight returns the deadline height for the set. The configured
// return value must be an int32.
func (m *MockInputSet) DeadlineHeight() int32 {
	height := m.Called().Get(0)

	return height.(int32)
}
// Budget gives the total amount that can be used as fees by this input set.
// The configured return value must be a btcutil.Amount.
func (m *MockInputSet) Budget() btcutil.Amount {
	budget := m.Called().Get(0)

	return budget.(btcutil.Amount)
}
// StartingFeeRate returns the max starting fee rate found in the inputs. The
// configured return value must be an fn.Option[chainfee.SatPerKWeight].
func (m *MockInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] {
	rate := m.Called().Get(0)

	return rate.(fn.Option[chainfee.SatPerKWeight])
}
// MockBumper is a mock implementation of the Bumper interface, built on
// testify's mock.Mock.
type MockBumper struct {
mock.Mock
}
// Compile-time constraint to ensure MockBumper implements Bumper.
var _ Bumper = (*MockBumper)(nil)
// Broadcast stands in for broadcasting the transaction to the network. The
// mock returns the result channel and error the test configured for req; the
// channel must be configured as a bidirectional chan *BumpResult, and a nil
// channel is passed through as nil.
func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
	ret := m.Called(req)

	if resultChan := ret.Get(0); resultChan != nil {
		return resultChan.(chan *BumpResult), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// MockFeeFunction is a mock implementation of the FeeFunction interface,
// built on testify's mock.Mock.
type MockFeeFunction struct {
mock.Mock
}
// Compile-time constraint to ensure MockFeeFunction implements FeeFunction.
var _ FeeFunction = (*MockFeeFunction)(nil)
// FeeRate returns the current fee rate calculated by the fee function. The
// configured return value must be a chainfee.SatPerKWeight.
func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight {
	rate := m.Called().Get(0)

	return rate.(chainfee.SatPerKWeight)
}
// Increment stands in for adding one delta to the current fee rate. The mock
// returns the bool and error the test configured.
func (m *MockFeeFunction) Increment() (bool, error) {
	ret := m.Called()

	increased := ret.Bool(0)
	err := ret.Error(1)

	return increased, err
}
// IncreaseFeeRate stands in for increasing the fee rate by one step. The mock
// returns the bool and error the test configured for confTarget.
func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) {
	ret := m.Called(confTarget)

	increased := ret.Bool(0)
	err := ret.Error(1)

	return increased, err
}
// MockAuxSweeper is a stateless stub whose DeriveSweepAddr and
// NotifyBroadcast methods return fixed results. Unlike the other mocks in this
// file it does not embed mock.Mock, so no expectations can be set on it.
type MockAuxSweeper struct{}
// DeriveSweepAddr takes a set of inputs, and the change address we'd use to
// sweep them, and maybe returns an extra sweep output that we should add to
// the sweeping transaction. This stub ignores both arguments and always
// succeeds with a zero-value SweepOutput.
func (*MockAuxSweeper) DeriveSweepAddr(_ []input.Input,
	_ lnwallet.AddrWithKey) fn.Result[SweepOutput] {

	var emptyOutput SweepOutput

	return fn.Ok(emptyOutput)
}
// NotifyBroadcast is used to notify external callers of the broadcast
// of a sweep transaction, generated by the passed BumpRequest. This stub
// ignores all three arguments and always returns nil.
func (*MockAuxSweeper) NotifyBroadcast(_ *BumpRequest, _ *wire.MsgTx,
_ btcutil.Amount) error {
return nil
}