lnd/sweep/mock_test.go
yyforyongyu 4c13ea1747
sweep: pass default deadline height when clustering inputs
This commit changes the method `ClusterInputs` to also take a default
deadline height. Previously, when calculating the default deadline
height for a non-time sensitive input, we would first cluster it with
other non-time sensitive inputs, then give it a deadline before we are
about to `sweep`. This is now moved to the step where we decide to
cluster inputs, allowing time-sensitive and non-sensitive inputs to be
grouped together, if they happen to share the same deadline heights.
2024-04-19 21:33:32 +08:00

548 lines
14 KiB
Go

package sweep
import (
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
)
// mockBackend simulates a chain backend for realistic behaviour in unit tests
// around double spends. It tracks which outpoints are spent by unconfirmed
// and confirmed transactions so that conflicting publishes can be rejected.
type mockBackend struct {
// t is the test handle used to fail the test when a published tx cannot
// be delivered on publishChan in time.
t *testing.T
// lock is held by the methods that mutate the maps, the wallet utxos and
// the counter below.
lock sync.Mutex
// notifier receives a spend notification for every outpoint spent by a
// transaction that gets mined.
notifier *MockNotifier
// confirmedSpendInputs is the set of outpoints spent by mined txes.
confirmedSpendInputs map[wire.OutPoint]struct{}
// unconfirmedTxes holds the published, not yet mined txes keyed by hash.
unconfirmedTxes map[chainhash.Hash]*wire.MsgTx
// unconfirmedSpendInputs is the set of outpoints spent by the txes in
// unconfirmedTxes.
unconfirmedSpendInputs map[wire.OutPoint]struct{}
// publishChan receives a copy of every tx handed to PublishTransaction.
publishChan chan wire.MsgTx
// walletUtxos is the list returned by
// ListUnspentWitnessFromDefaultAccount.
walletUtxos []*lnwallet.Utxo
// utxoCnt counts how many times the wallet utxos have been listed.
utxoCnt int
}
// newMockBackend creates a mockBackend bound to the given test and notifier.
// All tracking maps start out empty and the publish channel is buffered with
// room for two transactions.
func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend {
	backend := &mockBackend{t: t, notifier: notifier}

	backend.publishChan = make(chan wire.MsgTx, 2)
	backend.unconfirmedTxes = make(map[chainhash.Hash]*wire.MsgTx)
	backend.unconfirmedSpendInputs = make(map[wire.OutPoint]struct{})
	backend.confirmedSpendInputs = make(map[wire.OutPoint]struct{})

	return backend
}
// CheckMempoolAcceptance is a stub that accepts every transaction: it ignores
// tx and always returns nil.
func (b *mockBackend) CheckMempoolAcceptance(tx *wire.MsgTx) error {
return nil
}
// publishTransaction records tx as unconfirmed unless it conflicts with what
// the backend already knows about. lnwallet.ErrDoubleSpend is returned when
// the exact same tx is already unconfirmed, or when any of its inputs is
// already spent by an unconfirmed or a confirmed tx. On success the tx and
// all of its inputs are added to the unconfirmed tracking maps.
func (b *mockBackend) publishTransaction(tx *wire.MsgTx) error {
	b.lock.Lock()
	defer b.lock.Unlock()

	// Hash the tx once and reuse the result for the lookup, the insert and
	// every log line below instead of re-hashing it each time.
	txHash := tx.TxHash()

	if _, ok := b.unconfirmedTxes[txHash]; ok {
		// Tx already exists.
		testLog.Tracef("mockBackend duplicate tx %v", txHash)
		return lnwallet.ErrDoubleSpend
	}

	// Validate every input before touching any state so a rejected tx
	// leaves the backend unchanged.
	for _, in := range tx.TxIn {
		if _, ok := b.unconfirmedSpendInputs[in.PreviousOutPoint]; ok {
			// Double spend.
			testLog.Tracef("mockBackend double spend tx %v",
				txHash)
			return lnwallet.ErrDoubleSpend
		}

		if _, ok := b.confirmedSpendInputs[in.PreviousOutPoint]; ok {
			// Already included in block.
			testLog.Tracef("mockBackend already in block tx %v",
				txHash)
			return lnwallet.ErrDoubleSpend
		}
	}

	b.unconfirmedTxes[txHash] = tx
	for _, in := range tx.TxIn {
		b.unconfirmedSpendInputs[in.PreviousOutPoint] = struct{}{}
	}

	testLog.Tracef("mockBackend publish tx %v", txHash)

	return nil
}
// PublishTransaction attempts to publish tx and then forwards a copy of it on
// publishChan, regardless of whether publishing succeeded. The label argument
// is ignored. If nobody drains publishChan within defaultTestTimeout the test
// is failed. The result of the publish attempt is returned to the caller.
//
// NOTE(review): Fatalf is invoked on whatever goroutine calls this method;
// confirm callers run on the test goroutine.
func (b *mockBackend) PublishTransaction(tx *wire.MsgTx, _ string) error {
	log.Tracef("Publishing tx %v", tx.TxHash())

	publishErr := b.publishTransaction(tx)

	timeout := time.After(defaultTestTimeout)
	select {
	case b.publishChan <- *tx:
	case <-timeout:
		b.t.Fatalf("unexpected tx published")
	}

	return publishErr
}
// ListUnspentWitnessFromDefaultAccount returns the configured wallet utxos.
// The minConfs and maxConfs arguments are ignored. Every call bumps an
// internal counter and stamps it into the first hash byte of each utxo's
// outpoint, so consecutive calls never hand out the same outpoints.
func (b *mockBackend) ListUnspentWitnessFromDefaultAccount(minConfs,
	maxConfs int32) ([]*lnwallet.Utxo, error) {

	b.lock.Lock()
	defer b.lock.Unlock()

	b.utxoCnt++
	marker := byte(b.utxoCnt)

	for _, utxo := range b.walletUtxos {
		utxo.OutPoint.Hash[0] = marker
	}

	return b.walletUtxos, nil
}
// WithCoinSelectLock runs f immediately and returns its error. No lock is
// taken by this mock.
func (b *mockBackend) WithCoinSelectLock(f func() error) error {
return f()
}
// deleteUnconfirmed drops the unconfirmed tx identified by txHash and frees
// the outpoints it was spending. An error is logged and nothing changes when
// the tx is not currently tracked as unconfirmed.
func (b *mockBackend) deleteUnconfirmed(txHash chainhash.Hash) {
	b.lock.Lock()
	defer b.lock.Unlock()

	tx, found := b.unconfirmedTxes[txHash]
	if !found {
		// The tx is unknown, so there is nothing to delete.
		testLog.Errorf("mockBackend delete tx not existing %v", txHash)
		return
	}

	testLog.Tracef("mockBackend delete tx %v", tx.TxHash())

	delete(b.unconfirmedTxes, txHash)
	for i := range tx.TxIn {
		delete(b.unconfirmedSpendInputs, tx.TxIn[i].PreviousOutPoint)
	}
}
// mine confirms every unconfirmed tx: all of their inputs are marked as spent
// in a block, the unconfirmed tracking maps are reset, and the notifier is
// told about each spent outpoint together with the tx that spent it. The
// backend lock is held for the whole operation, including the notifications.
func (b *mockBackend) mine() {
	b.lock.Lock()
	defer b.lock.Unlock()

	// Collect the spender of each outpoint first so notifications are only
	// sent once the backend state is fully updated.
	spends := make(map[wire.OutPoint]*wire.MsgTx)
	for _, pending := range b.unconfirmedTxes {
		testLog.Tracef("mockBackend mining tx %v", pending.TxHash())

		for _, txIn := range pending.TxIn {
			op := txIn.PreviousOutPoint
			b.confirmedSpendInputs[op] = struct{}{}
			spends[op] = pending
		}
	}

	b.unconfirmedSpendInputs = make(map[wire.OutPoint]struct{})
	b.unconfirmedTxes = make(map[chainhash.Hash]*wire.MsgTx)

	for op, spender := range spends {
		testLog.Tracef("mockBackend delivering spend ntfn for %v",
			op)
		b.notifier.SpendOutpoint(op, *spender)
	}
}
// isDone reports whether there are no unconfirmed transactions left in the
// backend. The lock is taken because other methods replace and mutate the
// unconfirmedTxes map while holding it; reading the map without the lock
// would be a data race.
func (b *mockBackend) isDone() bool {
	b.lock.Lock()
	defer b.lock.Unlock()

	return len(b.unconfirmedTxes) == 0
}
// RemoveDescendants is a no-op stub that ignores its argument and always
// returns nil.
func (b *mockBackend) RemoveDescendants(*wire.MsgTx) error {
return nil
}
// FetchTx is a stub that ignores the given hash and always returns a nil
// transaction together with a nil error.
func (b *mockBackend) FetchTx(chainhash.Hash) (*wire.MsgTx, error) {
return nil, nil
}
// CancelRebroadcast is a no-op stub; the given tx hash is ignored.
func (b *mockBackend) CancelRebroadcast(tx chainhash.Hash) {
}
// GetTransactionDetails returns a detailed description of a tx given its
// transaction hash. This mock is a stub: it ignores txHash and always returns
// nil details and a nil error.
func (b *mockBackend) GetTransactionDetails(txHash *chainhash.Hash) (
*lnwallet.TransactionDetail, error) {
return nil, nil
}
// mockFeeEstimator implements a mock fee estimator. It closely resembles
// lnwallet.StaticFeeEstimator with the addition that fees can be changed for
// testing purposes in a thread safe manner.
//
// TODO(yy): replace it with chainfee.MockEstimator once it's merged.
type mockFeeEstimator struct {
// feePerKW is the fee rate returned by EstimateFeePerKW when neither the
// closure nor blocksToFee provides a value.
feePerKW chainfee.SatPerKWeight
// relayFee is the fee rate returned by RelayFeePerKW.
relayFee chainfee.SatPerKWeight
// blocksToFee maps a block target to the fee rate that EstimateFeePerKW
// returns for exactly that target.
blocksToFee map[uint32]chainfee.SatPerKWeight
// A closure that when set is used instead of the
// mockFeeEstimator.EstimateFeePerKW method.
estimateFeePerKW func(numBlocks uint32) (chainfee.SatPerKWeight, error)
// lock is held by updateFees, EstimateFeePerKW and RelayFeePerKW.
lock sync.Mutex
}
// newMockFeeEstimator builds a mockFeeEstimator that starts out with the
// given default and relay fee rates and an empty per-target fee map.
func newMockFeeEstimator(feePerKW,
	relayFee chainfee.SatPerKWeight) *mockFeeEstimator {

	estimator := &mockFeeEstimator{
		blocksToFee: make(map[uint32]chainfee.SatPerKWeight),
	}
	estimator.feePerKW = feePerKW
	estimator.relayFee = relayFee

	return estimator
}
// updateFees replaces both the default fee rate and the relay fee rate while
// holding the estimator lock.
func (e *mockFeeEstimator) updateFees(feePerKW,
	relayFee chainfee.SatPerKWeight) {

	e.lock.Lock()
	defer e.lock.Unlock()

	e.feePerKW, e.relayFee = feePerKW, relayFee
}
// EstimateFeePerKW returns the fee rate for the given block target. The
// estimateFeePerKW closure takes precedence when set (and is invoked while the
// lock is held), followed by an exact match in blocksToFee, and finally the
// default feePerKW.
func (e *mockFeeEstimator) EstimateFeePerKW(numBlocks uint32) (
	chainfee.SatPerKWeight, error) {

	e.lock.Lock()
	defer e.lock.Unlock()

	if override := e.estimateFeePerKW; override != nil {
		return override(numBlocks)
	}

	rate, found := e.blocksToFee[numBlocks]
	if !found {
		rate = e.feePerKW
	}

	return rate, nil
}
// RelayFeePerKW returns the currently configured relay fee rate, read under
// the estimator lock.
func (e *mockFeeEstimator) RelayFeePerKW() chainfee.SatPerKWeight {
	e.lock.Lock()
	relay := e.relayFee
	e.lock.Unlock()

	return relay
}
// Start is a no-op that always returns nil.
func (e *mockFeeEstimator) Start() error {
return nil
}
// Stop is a no-op that always returns nil.
func (e *mockFeeEstimator) Stop() error {
return nil
}
// Compile-time constraint to ensure mockFeeEstimator implements
// chainfee.Estimator.
var _ chainfee.Estimator = (*mockFeeEstimator)(nil)
// MockSweeperStore is a mock implementation of sweeper store. This type is
// exported, because it is currently used in nursery tests too.
//
// It embeds mock.Mock, and every method forwards its arguments to Called and
// returns whatever the test configured for that call.
type MockSweeperStore struct {
mock.Mock
}
// NewMockSweeperStore returns a new, zero-valued MockSweeperStore with no
// expectations registered.
func NewMockSweeperStore() *MockSweeperStore {
	return new(MockSweeperStore)
}
// IsOurTx determines whether a tx is published by us, based on its hash. The
// mock returns the bool and error configured for the given hash.
func (s *MockSweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
	ret := s.Called(hash)
	ours, err := ret.Bool(0), ret.Error(1)

	return ours, err
}
// StoreTx stores a tx we are about to publish. The mock records the call and
// returns the configured error.
func (s *MockSweeperStore) StoreTx(tr *TxRecord) error {
	return s.Called(tr).Error(0)
}
// ListSweeps lists all the sweeps we have successfully published. The mock
// returns the hashes and error configured by the test. A nil first return
// value is passed through as a nil slice instead of triggering a type
// assertion panic, so error paths can be set up with Return(nil, err) just
// like the other mocks in this file.
func (s *MockSweeperStore) ListSweeps() ([]chainhash.Hash, error) {
	args := s.Called()

	if args.Get(0) == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).([]chainhash.Hash), args.Error(1)
}
// GetTx queries the database to find the tx that matches the given txid.
// Returns ErrTxNotFound if it cannot be found. In this mock the record and
// the error are whatever the test configured; a nil record is passed through
// without a type assertion.
func (s *MockSweeperStore) GetTx(hash chainhash.Hash) (*TxRecord, error) {
	ret := s.Called(hash)

	record := ret.Get(0)
	if record == nil {
		return nil, ret.Error(1)
	}

	return record.(*TxRecord), ret.Error(1)
}
// DeleteTx removes the given tx from db. The mock records the call and
// returns the configured error.
func (s *MockSweeperStore) DeleteTx(txid chainhash.Hash) error {
	return s.Called(txid).Error(0)
}
// Compile-time constraint to ensure MockSweeperStore implements SweeperStore.
// The build fails here if the mock falls out of sync with the interface.
var _ SweeperStore = (*MockSweeperStore)(nil)
// MockFeePreference is a mock implementation of the FeePreference interface.
// It embeds mock.Mock so tests can configure what Estimate returns.
type MockFeePreference struct {
mock.Mock
}
// Compile-time constraint to ensure MockFeePreference implements FeePreference.
var _ FeePreference = (*MockFeePreference)(nil)
// String returns a fixed description of the mock. The call is not recorded
// on the embedded mock.
func (m *MockFeePreference) String() string {
return "mock fee preference"
}
// Estimate returns the fee rate and error configured by the test for the
// given estimator and max fee rate. When the configured fee rate is nil, a
// zero fee rate is returned alongside the configured error.
func (m *MockFeePreference) Estimate(estimator chainfee.Estimator,
	maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) {

	ret := m.Called(estimator, maxFeeRate)

	if rate := ret.Get(0); rate != nil {
		return rate.(chainfee.SatPerKWeight), ret.Error(1)
	}

	return 0, ret.Error(1)
}
// mockUtxoAggregator is a mock implementation of the UtxoAggregator
// interface. It embeds mock.Mock so tests can configure the clusters returned
// by ClusterInputs.
type mockUtxoAggregator struct {
mock.Mock
}
// Compile-time constraint to ensure mockUtxoAggregator implements
// UtxoAggregator.
var _ UtxoAggregator = (*mockUtxoAggregator)(nil)
// ClusterInputs takes a list of inputs and groups them into clusters, using
// defaultDeadline as the deadline height passed along with the inputs. The
// mock returns the input sets configured by the test. A nil configured value
// is returned as a nil slice rather than panicking on the type assertion,
// matching how the other mocks in this file treat nil return values.
func (m *mockUtxoAggregator) ClusterInputs(inputs InputsMap,
	defaultDeadline int32) []InputSet {

	args := m.Called(inputs, defaultDeadline)

	if args.Get(0) == nil {
		return nil
	}

	return args.Get(0).([]InputSet)
}
// MockWallet is a mock implementation of the Wallet interface. It embeds
// mock.Mock; each method records its call through Called and returns the
// values configured by the test.
type MockWallet struct {
mock.Mock
}
// Compile-time constraint to ensure MockWallet implements Wallet.
var _ Wallet = (*MockWallet)(nil)
// CheckMempoolAcceptance checks if the transaction can be accepted to the
// mempool. The mock returns the error configured for the given tx.
func (m *MockWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error {
	return m.Called(tx).Error(0)
}
// PublishTransaction performs cursory validation (dust checks, etc) and
// broadcasts the passed transaction to the Bitcoin network. The mock only
// records the tx and label and returns the configured error.
func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error {
	return m.Called(tx, label).Error(0)
}
// ListUnspentWitnessFromDefaultAccount returns all unspent outputs which are
// version 0 witness programs from the default wallet account. The 'minConfs'
// and 'maxConfs' parameters indicate the minimum and maximum number of
// confirmations an output needs in order to be returned by this method.
//
// The mock returns the utxos and error configured by the test; a nil utxo
// list is passed through without a type assertion.
func (m *MockWallet) ListUnspentWitnessFromDefaultAccount(
	minConfs, maxConfs int32) ([]*lnwallet.Utxo, error) {

	ret := m.Called(minConfs, maxConfs)

	if utxos := ret.Get(0); utxos != nil {
		return utxos.([]*lnwallet.Utxo), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// WithCoinSelectLock will execute the passed function closure in a
// synchronized manner preventing any coin selection operations from proceeding
// while the closure is executing. This can be seen as the ability to execute a
// function closure under an exclusive coin selection lock.
//
// The mock records the call, then runs f directly and returns its error; the
// values configured on the mock for this call are not used as the result.
func (m *MockWallet) WithCoinSelectLock(f func() error) error {
	m.Called(f)

	err := f()

	return err
}
// RemoveDescendants removes any wallet transactions that spends
// outputs created by the specified transaction. The mock records the call and
// returns the configured error.
func (m *MockWallet) RemoveDescendants(tx *wire.MsgTx) error {
	return m.Called(tx).Error(0)
}
// FetchTx returns the transaction that corresponds to the transaction
// hash passed in. If the transaction can't be found then a nil
// transaction pointer is returned. The mock returns the tx and error
// configured for the given txid.
func (m *MockWallet) FetchTx(txid chainhash.Hash) (*wire.MsgTx, error) {
	ret := m.Called(txid)

	if tx := ret.Get(0); tx != nil {
		return tx.(*wire.MsgTx), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// CancelRebroadcast is used to inform the rebroadcaster sub-system
// that it no longer needs to try to rebroadcast a transaction. This is
// used to ensure that invalid transactions (inputs spent) aren't
// retried in the background.
//
// The mock only records the call with the given tx hash.
func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) {
m.Called(tx)
}
// GetTransactionDetails returns a detailed description of a tx given its
// transaction hash. The mock returns the details and error configured for
// txHash; nil details are passed through without a type assertion.
func (m *MockWallet) GetTransactionDetails(txHash *chainhash.Hash) (
	*lnwallet.TransactionDetail, error) {

	ret := m.Called(txHash)

	if details := ret.Get(0); details != nil {
		return details.(*lnwallet.TransactionDetail), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// MockInputSet is a mock implementation of the InputSet interface. It embeds
// mock.Mock; each method returns the value configured by the test for that
// call.
type MockInputSet struct {
mock.Mock
}
// Compile-time constraint to ensure MockInputSet implements InputSet.
var _ InputSet = (*MockInputSet)(nil)
// Inputs returns the set of inputs that should be used to create a tx. The
// mock returns the configured slice, or nil when nil was configured.
func (m *MockInputSet) Inputs() []input.Input {
	ret := m.Called()

	if inputs := ret.Get(0); inputs != nil {
		return inputs.([]input.Input)
	}

	return nil
}
// FeeRate returns the fee rate that should be used for the tx. The configured
// return value must be a chainfee.SatPerKWeight.
func (m *MockInputSet) FeeRate() chainfee.SatPerKWeight {
	return m.Called().Get(0).(chainfee.SatPerKWeight)
}
// AddWalletInputs adds wallet inputs to the set until a non-dust
// change output can be made. Return an error if there are not enough
// wallet inputs. The mock only records the wallet it was called with and
// returns the configured error.
func (m *MockInputSet) AddWalletInputs(wallet Wallet) error {
	return m.Called(wallet).Error(0)
}
// NeedWalletInput returns true if the input set needs more wallet
// inputs. The mock returns the configured bool.
func (m *MockInputSet) NeedWalletInput() bool {
	return m.Called().Bool(0)
}
// DeadlineHeight returns the deadline height for the set. The configured
// return value must be an int32.
func (m *MockInputSet) DeadlineHeight() int32 {
	return m.Called().Get(0).(int32)
}
// Budget gives the total amount that can be used as fees by this input set.
// The configured return value must be a btcutil.Amount.
func (m *MockInputSet) Budget() btcutil.Amount {
	return m.Called().Get(0).(btcutil.Amount)
}
// MockBumper is a mock implementation of the interface Bumper. It embeds
// mock.Mock so tests can configure the result of Broadcast.
type MockBumper struct {
mock.Mock
}
// Compile-time constraint to ensure MockBumper implements Bumper.
var _ Bumper = (*MockBumper)(nil)
// Broadcast broadcasts the transaction to the network. The mock returns the
// result channel and error configured for req. The configured channel must be
// a bidirectional `chan *BumpResult`; a nil value is passed through as a nil
// channel.
func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
	ret := m.Called(req)

	if resultChan := ret.Get(0); resultChan != nil {
		return resultChan.(chan *BumpResult), ret.Error(1)
	}

	return nil, ret.Error(1)
}
// MockFeeFunction is a mock implementation of the FeeFunction interface. It
// embeds mock.Mock; each method returns the values configured by the test.
type MockFeeFunction struct {
mock.Mock
}
// Compile-time constraint to ensure MockFeeFunction implements FeeFunction.
var _ FeeFunction = (*MockFeeFunction)(nil)
// FeeRate returns the current fee rate calculated by the fee function. The
// configured return value must be a chainfee.SatPerKWeight.
func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight {
	return m.Called().Get(0).(chainfee.SatPerKWeight)
}
// Increment adds one delta to the current fee rate. The mock returns the
// configured bool and error.
func (m *MockFeeFunction) Increment() (bool, error) {
	ret := m.Called()
	increased, err := ret.Bool(0), ret.Error(1)

	return increased, err
}
// IncreaseFeeRate increases the fee rate by one step. The mock records the
// given confTarget and returns the configured bool and error.
func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) {
	ret := m.Called(confTarget)
	increased, err := ret.Bool(0), ret.Error(1)

	return increased, err
}