mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 08:55:59 +01:00
lnwallet+sweep: introduce TxPublisher
to handle fee bump
This commit adds `TxPublisher` which implements `Bumper` interface. This is part one of the implementation that focuses on implementing the `Broadcast` method which guarantees a tx can be published with RBF-compliant. It does so by leveraging the `testmempoolaccept` API, keep increasing the fee rate until an RBF-compliant tx is made and broadcasts it. This tx will then be monitored by the `TxPublisher` and in the following commit, the monitoring process will be added.
This commit is contained in:
parent
ecd471ac75
commit
11f7e455d1
6 changed files with 1548 additions and 5 deletions
|
@ -1,6 +1,7 @@
|
|||
package chainntnfs
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
@ -50,3 +51,73 @@ func (m *MockMempoolWatcher) LookupInputMempoolSpend(
|
|||
|
||||
return args.Get(0).(fn.Option[wire.MsgTx])
|
||||
}
|
||||
|
||||
// MockNotifier is a mock implementation of the ChainNotifier interface.
|
||||
type MockChainNotifier struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Compile-time check to ensure MockChainNotifier implements ChainNotifier.
|
||||
var _ ChainNotifier = (*MockChainNotifier)(nil)
|
||||
|
||||
// RegisterConfirmationsNtfn registers an intent to be notified once txid
|
||||
// reaches numConfs confirmations.
|
||||
func (m *MockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
|
||||
pkScript []byte, numConfs, heightHint uint32,
|
||||
opts ...NotifierOption) (*ConfirmationEvent, error) {
|
||||
|
||||
args := m.Called(txid, pkScript, numConfs, heightHint)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*ConfirmationEvent), args.Error(1)
|
||||
}
|
||||
|
||||
// RegisterSpendNtfn registers an intent to be notified once the target
|
||||
// outpoint is successfully spent within a transaction.
|
||||
func (m *MockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
|
||||
pkScript []byte, heightHint uint32) (*SpendEvent, error) {
|
||||
|
||||
args := m.Called(outpoint, pkScript, heightHint)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*SpendEvent), args.Error(1)
|
||||
}
|
||||
|
||||
// RegisterBlockEpochNtfn registers an intent to be notified of each new block
|
||||
// connected to the tip of the main chain.
|
||||
func (m *MockChainNotifier) RegisterBlockEpochNtfn(epoch *BlockEpoch) (
|
||||
*BlockEpochEvent, error) {
|
||||
|
||||
args := m.Called(epoch)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*BlockEpochEvent), args.Error(1)
|
||||
}
|
||||
|
||||
// Start the ChainNotifier. Once started, the implementation should be ready,
|
||||
// and able to receive notification registrations from clients.
|
||||
func (m *MockChainNotifier) Start() error {
|
||||
args := m.Called()
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Started returns true if this instance has been started, and false otherwise.
|
||||
func (m *MockChainNotifier) Started() bool {
|
||||
args := m.Called()
|
||||
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
// Stops the concrete ChainNotifier.
|
||||
func (m *MockChainNotifier) Stop() error {
|
||||
args := m.Called()
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
|
103
input/mocks.go
103
input/mocks.go
|
@ -1,8 +1,14 @@
|
|||
package input
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
|
@ -168,3 +174,100 @@ func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error {
|
|||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// MockInputSigner is a mock implementation of the Signer interface.
|
||||
type MockInputSigner struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure MockInputSigner implements Signer.
|
||||
var _ Signer = (*MockInputSigner)(nil)
|
||||
|
||||
// SignOutputRaw generates a signature for the passed transaction according to
|
||||
// the data within the passed SignDescriptor.
|
||||
func (m *MockInputSigner) SignOutputRaw(tx *wire.MsgTx,
|
||||
signDesc *SignDescriptor) (Signature, error) {
|
||||
|
||||
args := m.Called(tx, signDesc)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(Signature), args.Error(1)
|
||||
}
|
||||
|
||||
// ComputeInputScript generates a complete InputIndex for the passed
|
||||
// transaction with the signature as defined within the passed SignDescriptor.
|
||||
func (m *MockInputSigner) ComputeInputScript(tx *wire.MsgTx,
|
||||
signDesc *SignDescriptor) (*Script, error) {
|
||||
|
||||
args := m.Called(tx, signDesc)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*Script), args.Error(1)
|
||||
}
|
||||
|
||||
// MuSig2CreateSession creates a new MuSig2 signing session using the local key
|
||||
// identified by the key locator.
|
||||
func (m *MockInputSigner) MuSig2CreateSession(version MuSig2Version,
|
||||
locator keychain.KeyLocator, pubkey []*btcec.PublicKey,
|
||||
tweak *MuSig2Tweaks, pubNonces [][musig2.PubNonceSize]byte,
|
||||
nonces *musig2.Nonces) (*MuSig2SessionInfo, error) {
|
||||
|
||||
args := m.Called(version, locator, pubkey, tweak, pubNonces, nonces)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*MuSig2SessionInfo), args.Error(1)
|
||||
}
|
||||
|
||||
// MuSig2RegisterNonces registers one or more public nonces of other signing
|
||||
// participants for a session identified by its ID.
|
||||
func (m *MockInputSigner) MuSig2RegisterNonces(versio MuSig2SessionID,
|
||||
pubNonces [][musig2.PubNonceSize]byte) (bool, error) {
|
||||
|
||||
args := m.Called(versio, pubNonces)
|
||||
if args.Get(0) == nil {
|
||||
return false, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
// MuSig2Sign creates a partial signature using the local signing key that was
|
||||
// specified when the session was created.
|
||||
func (m *MockInputSigner) MuSig2Sign(sessionID MuSig2SessionID,
|
||||
msg [sha256.Size]byte, withSortedKeys bool) (
|
||||
*musig2.PartialSignature, error) {
|
||||
|
||||
args := m.Called(sessionID, msg, withSortedKeys)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*musig2.PartialSignature), args.Error(1)
|
||||
}
|
||||
|
||||
// MuSig2CombineSig combines the given partial signature(s) with the local one,
|
||||
// if it already exists.
|
||||
func (m *MockInputSigner) MuSig2CombineSig(sessionID MuSig2SessionID,
|
||||
partialSig []*musig2.PartialSignature) (
|
||||
*schnorr.Signature, bool, error) {
|
||||
|
||||
args := m.Called(sessionID, partialSig)
|
||||
if args.Get(0) == nil {
|
||||
return nil, false, args.Error(2)
|
||||
}
|
||||
|
||||
return args.Get(0).(*schnorr.Signature), args.Bool(1), args.Error(2)
|
||||
}
|
||||
|
||||
// MuSig2Cleanup removes a session from memory to free up resources.
|
||||
func (m *MockInputSigner) MuSig2Cleanup(sessionID MuSig2SessionID) error {
|
||||
args := m.Called(sessionID)
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
|
|
@ -3,16 +3,29 @@ package sweep
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/btcsuite/btcd/rpcclient"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcwallet/chain"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/labels"
|
||||
"github.com/lightningnetwork/lnd/lnutils"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidBumpResult is returned when the bump result is invalid.
|
||||
ErrInvalidBumpResult = errors.New("invalid bump result")
|
||||
|
||||
// ErrNotEnoughBudget is returned when the fee bumper decides the
|
||||
// current budget cannot cover the fee.
|
||||
ErrNotEnoughBudget = errors.New("not enough budget")
|
||||
)
|
||||
|
||||
// Bumper defines an interface that can be used by other subsystems for fee
|
||||
|
@ -165,6 +178,9 @@ type BumpResult struct {
|
|||
|
||||
// Err is the error that occurred during the broadcast.
|
||||
Err error
|
||||
|
||||
// requestID is the ID of the request that created this record.
|
||||
requestID uint64
|
||||
}
|
||||
|
||||
// Validate validates the BumpResult so it's safe to use.
|
||||
|
@ -197,3 +213,460 @@ func (b *BumpResult) Validate() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TxPublisherConfig is the config used to create a new TxPublisher.
|
||||
type TxPublisherConfig struct {
|
||||
// Signer is used to create the tx signature.
|
||||
Signer input.Signer
|
||||
|
||||
// Wallet is used primarily to publish the tx.
|
||||
Wallet Wallet
|
||||
|
||||
// Estimator is used to estimate the fee rate for the new tx based on
|
||||
// its deadline conf target.
|
||||
Estimator chainfee.Estimator
|
||||
|
||||
// Notifier is used to monitor the confirmation status of the tx.
|
||||
Notifier chainntnfs.ChainNotifier
|
||||
}
|
||||
|
||||
// TxPublisher is an implementation of the Bumper interface. It utilizes the
|
||||
// `testmempoolaccept` RPC to bump the fee of txns it created based on
|
||||
// different fee function selected or configed by the caller. Its purpose is to
|
||||
// take a list of inputs specified, and create a tx that spends them to a
|
||||
// specified output. It will then monitor the confirmation status of the tx,
|
||||
// and if it's not confirmed within a certain time frame, it will attempt to
|
||||
// bump the fee of the tx by creating a new tx that spends the same inputs to
|
||||
// the same output, but with a higher fee rate. It will continue to do this
|
||||
// until the tx is confirmed or the fee rate reaches the maximum fee rate
|
||||
// specified by the caller.
|
||||
type TxPublisher struct {
|
||||
wg sync.WaitGroup
|
||||
|
||||
// cfg specifies the configuration of the TxPublisher.
|
||||
cfg *TxPublisherConfig
|
||||
|
||||
// currentHeight is the current block height.
|
||||
currentHeight int32
|
||||
|
||||
// records is a map keyed by the requestCounter and the value is the tx
|
||||
// being monitored.
|
||||
records lnutils.SyncMap[uint64, *monitorRecord]
|
||||
|
||||
// requestCounter is a monotonically increasing counter used to keep
|
||||
// track of how many requests have been made.
|
||||
requestCounter atomic.Uint64
|
||||
|
||||
// subscriberChans is a map keyed by the requestCounter, each item is
|
||||
// the chan that the publisher sends the fee bump result to.
|
||||
subscriberChans lnutils.SyncMap[uint64, chan *BumpResult]
|
||||
|
||||
// quit is used to signal the publisher to stop.
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure TxPublisher implements Bumper.
|
||||
var _ Bumper = (*TxPublisher)(nil)
|
||||
|
||||
// NewTxPublisher creates a new TxPublisher.
|
||||
func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher {
|
||||
return &TxPublisher{
|
||||
cfg: &cfg,
|
||||
records: lnutils.SyncMap[uint64, *monitorRecord]{},
|
||||
subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{},
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast is used to publish the tx created from the given inputs. It will,
|
||||
// 1. init a fee function based on the given strategy.
|
||||
// 2. create an RBF-compliant tx and monitor it for confirmation.
|
||||
// 3. notify the initial broadcast result back to the caller.
|
||||
// The initial broadcast is guaranteed to be RBF-compliant unless the budget
|
||||
// specified cannot cover the fee.
|
||||
//
|
||||
// NOTE: part of the Bumper interface.
|
||||
func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
|
||||
log.Tracef("Received broadcast request: %s", newLogClosure(
|
||||
func() string {
|
||||
return spew.Sdump(req)
|
||||
})())
|
||||
|
||||
// Attempt an initial broadcast which is guaranteed to comply with the
|
||||
// RBF rules.
|
||||
result, err := t.initialBroadcast(req)
|
||||
if err != nil {
|
||||
log.Errorf("Initial broadcast failed: %v", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a chan to send the result to the caller.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
t.subscriberChans.Store(result.requestID, subscriber)
|
||||
|
||||
// Send the initial broadcast result to the caller.
|
||||
t.handleResult(result)
|
||||
|
||||
return subscriber, nil
|
||||
}
|
||||
|
||||
// initialBroadcast initializes a fee function, creates an RBF-compliant tx and
|
||||
// broadcasts it.
|
||||
func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
|
||||
// Create a fee bumping algorithm to be used for future RBF.
|
||||
feeAlgo, err := t.initializeFeeFunction(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init fee function: %w", err)
|
||||
}
|
||||
|
||||
// Create the initial tx to be broadcasted. This tx is guaranteed to
|
||||
// comply with the RBF restrictions.
|
||||
requestID, err := t.createRBFCompliantTx(req, feeAlgo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
|
||||
}
|
||||
|
||||
// Broadcast the tx and return the monitored record.
|
||||
result, err := t.broadcast(requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("broadcast sweep tx: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// initializeFeeFunction initializes a fee function to be used for this request
|
||||
// for future fee bumping.
|
||||
func (t *TxPublisher) initializeFeeFunction(
|
||||
req *BumpRequest) (FeeFunction, error) {
|
||||
|
||||
// Get the max allowed feerate.
|
||||
maxFeeRateAllowed, err := req.MaxFeeRateAllowed()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the initial conf target.
|
||||
confTarget := calcCurrentConfTarget(t.currentHeight, req.DeadlineHeight)
|
||||
|
||||
// Initialize the fee function and return it.
|
||||
//
|
||||
// TODO(yy): return based on differet req.Strategy?
|
||||
return NewLinearFeeFunction(
|
||||
maxFeeRateAllowed, confTarget, t.cfg.Estimator,
|
||||
)
|
||||
}
|
||||
|
||||
// createRBFCompliantTx creates a tx that is compliant with RBF rules. It does
|
||||
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
|
||||
// and redo the process until the tx is valid, or return an error when non-RBF
|
||||
// related errors occur or the budget has been used up.
|
||||
func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
|
||||
f FeeFunction) (uint64, error) {
|
||||
|
||||
for {
|
||||
// Create a new tx with the given fee rate and check its
|
||||
// mempool acceptance.
|
||||
tx, fee, err := t.createAndCheckTx(req, f)
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
// The tx is valid, return the request ID.
|
||||
requestID := t.storeRecord(tx, req, f, fee)
|
||||
|
||||
log.Infof("Created tx %v for %v inputs: feerate=%v, "+
|
||||
"fee=%v, inputs=%v", tx.TxHash(),
|
||||
len(req.Inputs), f.FeeRate(), fee,
|
||||
inputTypeSummary(req.Inputs))
|
||||
|
||||
return requestID, nil
|
||||
|
||||
// If the error indicates the fees paid is not enough, we will
|
||||
// ask the fee function to increase the fee rate and retry.
|
||||
case errors.Is(err, lnwallet.ErrMempoolFee):
|
||||
// We should at least start with a feerate above the
|
||||
// mempool min feerate, so if we get this error, it
|
||||
// means something is wrong earlier in the pipeline.
|
||||
log.Errorf("Current fee=%v, feerate=%v, %v", fee,
|
||||
f.FeeRate(), err)
|
||||
|
||||
fallthrough
|
||||
|
||||
// We are not paying enough fees so we increase it.
|
||||
case errors.Is(err, rpcclient.ErrInsufficientFee):
|
||||
increased := false
|
||||
|
||||
// Keep calling the fee function until the fee rate is
|
||||
// increased or maxed out.
|
||||
for !increased {
|
||||
log.Debugf("Increasing fee for next round, "+
|
||||
"current fee=%v, feerate=%v", fee,
|
||||
f.FeeRate())
|
||||
|
||||
// If the fee function tells us that we have
|
||||
// used up the budget, we will return an error
|
||||
// indicating this tx cannot be made. The
|
||||
// sweeper should handle this error and try to
|
||||
// cluster these inputs differetly.
|
||||
increased, err = f.Increment()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yy): suppose there's only one bad input, we can do a
|
||||
// binary search to find out which input is causing this error
|
||||
// by recreating a tx using half of the inputs and check its
|
||||
// mempool acceptance.
|
||||
default:
|
||||
log.Debugf("Failed to create RBF-compliant tx: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// storeRecord stores the given record in the records map.
|
||||
func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
|
||||
f FeeFunction, fee btcutil.Amount) uint64 {
|
||||
|
||||
// Increase the request counter.
|
||||
//
|
||||
// NOTE: this is the only place where we increase the
|
||||
// counter.
|
||||
requestID := t.requestCounter.Add(1)
|
||||
|
||||
// Register the record.
|
||||
t.records.Store(requestID, &monitorRecord{
|
||||
tx: tx,
|
||||
req: req,
|
||||
feeFunction: f,
|
||||
fee: fee,
|
||||
})
|
||||
|
||||
return requestID
|
||||
}
|
||||
|
||||
// createAndCheckTx creates a tx based on the given inputs, change output
|
||||
// script, and the fee rate. In addition, it validates the tx's mempool
|
||||
// acceptance before returning a tx that can be published directly, along with
|
||||
// its fee.
|
||||
func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) (
|
||||
*wire.MsgTx, btcutil.Amount, error) {
|
||||
|
||||
// Create the sweep tx with max fee rate of 0 as the fee function
|
||||
// guarantees the fee rate used here won't exceed the max fee rate.
|
||||
//
|
||||
// TODO(yy): refactor this function to not require a max fee rate.
|
||||
tx, fee, err := createSweepTx(
|
||||
req.Inputs, nil, req.DeliveryAddress, uint32(t.currentHeight),
|
||||
f.FeeRate(), 0, t.cfg.Signer,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create sweep tx: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check the budget still covers the fee.
|
||||
if fee > req.Budget {
|
||||
return nil, 0, fmt.Errorf("%w: budget=%v, fee=%v",
|
||||
ErrNotEnoughBudget, req.Budget, fee)
|
||||
}
|
||||
|
||||
// Validate the tx's mempool acceptance.
|
||||
err = t.cfg.Wallet.CheckMempoolAcceptance(tx)
|
||||
|
||||
// Exit early if the tx is valid.
|
||||
if err == nil {
|
||||
return tx, fee, nil
|
||||
}
|
||||
|
||||
// Print an error log if the chain backend doesn't support the mempool
|
||||
// acceptance test RPC.
|
||||
if errors.Is(err, rpcclient.ErrBackendVersion) {
|
||||
log.Errorf("TestMempoolAccept not supported by backend, " +
|
||||
"consider upgrading it to a newer version")
|
||||
return tx, fee, nil
|
||||
}
|
||||
|
||||
// We are running on a backend that doesn't implement the RPC
|
||||
// testmempoolaccept, eg, neutrino, so we'll skip the check.
|
||||
if errors.Is(err, chain.ErrUnimplemented) {
|
||||
log.Debug("Skipped testmempoolaccept due to not implemented")
|
||||
return tx, fee, nil
|
||||
}
|
||||
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// broadcast takes a monitored tx and publishes it to the network. Prior to the
|
||||
// broadcast, it will subscribe the tx's confirmation notification and attach
|
||||
// the event channel to the record. Any broadcast-related errors will not be
|
||||
// returned here, instead, they will be put inside the `BumpResult` and
|
||||
// returned to the caller.
|
||||
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
|
||||
// Get the record being monitored.
|
||||
record, ok := t.records.Load(requestID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tx record %v not found", requestID)
|
||||
}
|
||||
|
||||
txid := record.tx.TxHash()
|
||||
|
||||
// Subscribe to its confirmation notification.
|
||||
confEvent, err := t.cfg.Notifier.RegisterConfirmationsNtfn(
|
||||
&txid, nil, 1, uint32(t.currentHeight),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register confirmation ntfn: %w", err)
|
||||
}
|
||||
|
||||
// Attach the confirmation event channel to the record.
|
||||
record.confEvent = confEvent
|
||||
|
||||
tx := record.tx
|
||||
log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v",
|
||||
txid, len(tx.TxIn), t.currentHeight)
|
||||
|
||||
// Set the event, and change it to TxFailed if the wallet fails to
|
||||
// publish it.
|
||||
event := TxPublished
|
||||
|
||||
// Publish the sweeping tx with customized label. If the publish fails,
|
||||
// this error will be saved in the `BumpResult` and it will be removed
|
||||
// from being monitored.
|
||||
err = t.cfg.Wallet.PublishTransaction(
|
||||
tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil),
|
||||
)
|
||||
if err != nil {
|
||||
// NOTE: we decide to attach this error to the result instead
|
||||
// of returning it here because by the time the tx reaches
|
||||
// here, it should have passed the mempool acceptance check. If
|
||||
// it still fails to be broadcast, it's likely a non-RBF
|
||||
// related error happened. So we send this error back to the
|
||||
// caller so that it can handle it properly.
|
||||
//
|
||||
// TODO(yy): find out which input is causing the failure.
|
||||
log.Errorf("Failed to publish tx %v: %v", txid, err)
|
||||
event = TxFailed
|
||||
}
|
||||
|
||||
result := &BumpResult{
|
||||
Event: event,
|
||||
Tx: record.tx,
|
||||
Fee: record.fee,
|
||||
FeeRate: record.feeFunction.FeeRate(),
|
||||
Err: err,
|
||||
requestID: requestID,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// notifyResult sends the result to the resultChan specified by the requestID.
|
||||
// This channel is expected to be read by the caller.
|
||||
func (t *TxPublisher) notifyResult(result *BumpResult) {
|
||||
id := result.requestID
|
||||
subscriber, ok := t.subscriberChans.Load(id)
|
||||
if !ok {
|
||||
log.Errorf("Result chan for id=%v not found", id)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending result for requestID=%v, tx=%v", id,
|
||||
result.Tx.TxHash())
|
||||
|
||||
select {
|
||||
// Send the result to the subscriber.
|
||||
//
|
||||
// TODO(yy): Add timeout in case it's blocking?
|
||||
case subscriber <- result:
|
||||
case <-t.quit:
|
||||
log.Debug("Fee bumper stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// removeResult removes the tracking of the result if the result contains a
|
||||
// non-nil error, or the tx is confirmed, the record will be removed from the
|
||||
// maps.
|
||||
func (t *TxPublisher) removeResult(result *BumpResult) {
|
||||
id := result.requestID
|
||||
|
||||
// Remove the record from the maps if there's an error. This means this
|
||||
// tx has failed its broadcast and cannot be retried. There are two
|
||||
// cases,
|
||||
// - when the budget cannot cover the fee.
|
||||
// - when a non-RBF related error occurs.
|
||||
switch result.Event {
|
||||
case TxFailed:
|
||||
log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v",
|
||||
id, result.Tx.TxHash(), result.Err)
|
||||
|
||||
case TxConfirmed:
|
||||
// Remove the record is the tx is confirmed.
|
||||
log.Debugf("Removing confirmed monitor record=%v, tx=%v", id,
|
||||
result.Tx.TxHash())
|
||||
|
||||
// Do nothing if it's neither failed or confirmed.
|
||||
default:
|
||||
log.Tracef("Skipping record removal for id=%v, event=%v", id,
|
||||
result.Event)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
t.records.Delete(id)
|
||||
t.subscriberChans.Delete(id)
|
||||
}
|
||||
|
||||
// handleResult handles the result of a tx broadcast. It will notify the
|
||||
// subscriber and remove the record if the tx is confirmed or failed to be
|
||||
// broadcast.
|
||||
func (t *TxPublisher) handleResult(result *BumpResult) {
|
||||
// Notify the subscriber.
|
||||
t.notifyResult(result)
|
||||
|
||||
// Remove the record if it's failed or confirmed.
|
||||
t.removeResult(result)
|
||||
}
|
||||
|
||||
// monitorRecord is used to keep track of the tx being monitored by the
|
||||
// publisher internally.
|
||||
type monitorRecord struct {
|
||||
// tx is the tx being monitored.
|
||||
tx *wire.MsgTx
|
||||
|
||||
// req is the original request.
|
||||
req *BumpRequest
|
||||
|
||||
// confEvent is the subscription to the confirmation event of the tx.
|
||||
confEvent *chainntnfs.ConfirmationEvent
|
||||
|
||||
// feeFunction is the fee bumping algorithm used by the publisher.
|
||||
feeFunction FeeFunction
|
||||
|
||||
// fee is the fee paid by the tx.
|
||||
fee btcutil.Amount
|
||||
}
|
||||
|
||||
// calcCurrentConfTarget calculates the current confirmation target based on
|
||||
// the deadline height. The conf target is capped at 0 if the deadline has
|
||||
// already been past.
|
||||
func calcCurrentConfTarget(currentHeight, deadline int32) uint32 {
|
||||
var confTarget uint32
|
||||
|
||||
// Calculate how many blocks left until the deadline.
|
||||
deadlineDelta := deadline - currentHeight
|
||||
|
||||
// If we are already past the deadline, we will set the conf target to
|
||||
// be 1.
|
||||
if deadlineDelta <= 0 {
|
||||
log.Warnf("Deadline is %d blocks behind current height %v",
|
||||
-deadlineDelta, currentHeight)
|
||||
|
||||
confTarget = 1
|
||||
} else {
|
||||
confTarget = uint32(deadlineDelta)
|
||||
}
|
||||
|
||||
return confTarget
|
||||
}
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
package sweep
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/btcsuite/btcd/rpcclient"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -171,3 +177,865 @@ func TestBumpRequestMaxFeeRateAllowed(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalcCurrentConfTarget checks that the current confirmation target is
|
||||
// calculated correctly.
|
||||
func TestCalcCurrentConfTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When the current block height is 100 and deadline height is 200, the
|
||||
// conf target should be 100.
|
||||
conf := calcCurrentConfTarget(int32(100), int32(200))
|
||||
require.EqualValues(t, 100, conf)
|
||||
|
||||
// When the current block height is 200 and deadline height is 100, the
|
||||
// conf target should be 1 since the deadline has passed.
|
||||
conf = calcCurrentConfTarget(int32(200), int32(100))
|
||||
require.EqualValues(t, 1, conf)
|
||||
}
|
||||
|
||||
// TestInitializeFeeFunction tests the initialization of the fee function.
|
||||
func TestInitializeFeeFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test input.
|
||||
inp := createTestInput(100, input.WitnessKeyHash)
|
||||
|
||||
// Create a mock fee estimator.
|
||||
estimator := &chainfee.MockEstimator{}
|
||||
defer estimator.AssertExpectations(t)
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp := NewTxPublisher(TxPublisherConfig{
|
||||
Estimator: estimator,
|
||||
})
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Create a testing bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate,
|
||||
}
|
||||
|
||||
// Mock the fee estimator to return an error.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
dummyErr := fmt.Errorf("dummy error")
|
||||
estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
chainfee.SatPerKWeight(0), dummyErr).Once()
|
||||
|
||||
// Call the method under test and assert the error is returned.
|
||||
f, err := tp.initializeFeeFunction(req)
|
||||
require.ErrorIs(t, err, dummyErr)
|
||||
require.Nil(t, f)
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Once()
|
||||
estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
|
||||
|
||||
// Call the method under test.
|
||||
f, err = tp.initializeFeeFunction(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, feerate, f.FeeRate())
|
||||
}
|
||||
|
||||
// TestStoreRecord correctly increases the request counter and saves the
|
||||
// record.
|
||||
func TestStoreRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test input.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
}
|
||||
|
||||
// Create a naive fee function.
|
||||
feeFunc := &LinearFeeFunction{}
|
||||
|
||||
// Create a test fee and tx.
|
||||
fee := btcutil.Amount(1000)
|
||||
tx := &wire.MsgTx{}
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp := NewTxPublisher(TxPublisherConfig{})
|
||||
|
||||
// Get the current counter and check it's increased later.
|
||||
initialCounter := tp.requestCounter.Load()
|
||||
|
||||
// Call the method under test.
|
||||
requestID := tp.storeRecord(tx, req, feeFunc, fee)
|
||||
|
||||
// Check the request ID is as expected.
|
||||
require.Equal(t, initialCounter+1, requestID)
|
||||
|
||||
// Read the saved record and compare.
|
||||
record, ok := tp.records.Load(requestID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tx, record.tx)
|
||||
require.Equal(t, feeFunc, record.feeFunction)
|
||||
require.Equal(t, fee, record.fee)
|
||||
require.Equal(t, req, record.req)
|
||||
}
|
||||
|
||||
// mockers wraps a list of mocked interfaces used inside tx publisher.
|
||||
type mockers struct {
|
||||
signer *input.MockInputSigner
|
||||
wallet *MockWallet
|
||||
estimator *chainfee.MockEstimator
|
||||
notifier *chainntnfs.MockChainNotifier
|
||||
|
||||
feeFunc *MockFeeFunction
|
||||
}
|
||||
|
||||
// createTestPublisher creates a new tx publisher using the provided mockers.
|
||||
func createTestPublisher(t *testing.T) (*TxPublisher, *mockers) {
|
||||
// Create a mock fee estimator.
|
||||
estimator := &chainfee.MockEstimator{}
|
||||
|
||||
// Create a mock fee function.
|
||||
feeFunc := &MockFeeFunction{}
|
||||
|
||||
// Create a mock signer.
|
||||
signer := &input.MockInputSigner{}
|
||||
|
||||
// Create a mock wallet.
|
||||
wallet := &MockWallet{}
|
||||
|
||||
// Create a mock chain notifier.
|
||||
notifier := &chainntnfs.MockChainNotifier{}
|
||||
|
||||
t.Cleanup(func() {
|
||||
estimator.AssertExpectations(t)
|
||||
feeFunc.AssertExpectations(t)
|
||||
signer.AssertExpectations(t)
|
||||
wallet.AssertExpectations(t)
|
||||
notifier.AssertExpectations(t)
|
||||
})
|
||||
|
||||
m := &mockers{
|
||||
signer: signer,
|
||||
wallet: wallet,
|
||||
estimator: estimator,
|
||||
notifier: notifier,
|
||||
feeFunc: feeFunc,
|
||||
}
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp := NewTxPublisher(TxPublisherConfig{
|
||||
Estimator: m.estimator,
|
||||
Signer: m.signer,
|
||||
Wallet: m.wallet,
|
||||
Notifier: m.notifier,
|
||||
})
|
||||
|
||||
return tp, m
|
||||
}
|
||||
|
||||
// TestCreateAndCheckTx checks `createAndCheckTx` behaves as expected.
|
||||
func TestCreateAndCheckTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate and return it from the mock fee function.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
m.feeFunc.On("FeeRate").Return(feerate)
|
||||
|
||||
// Mock the wallet to fail on testmempoolaccept on the first call, and
|
||||
// succeed on the second.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(errDummy).Once()
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
req *BumpRequest
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
// When the budget cannot cover the fee, an error
|
||||
// should be returned.
|
||||
name: "not enough budget",
|
||||
req: &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
},
|
||||
expectedErr: ErrNotEnoughBudget,
|
||||
},
|
||||
{
|
||||
// When the mempool rejects the transaction, an error
|
||||
// should be returned.
|
||||
name: "testmempoolaccept fail",
|
||||
req: &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
},
|
||||
expectedErr: errDummy,
|
||||
},
|
||||
{
|
||||
// When the mempool accepts the transaction, no error
|
||||
// should be returned.
|
||||
name: "testmempoolaccept pass",
|
||||
req: &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Call the method under test.
|
||||
_, _, err := tp.createAndCheckTx(tc.req, m.feeFunc)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestBumpRequest creates a new bump request.
|
||||
func createTestBumpRequest() *BumpRequest {
|
||||
// Create a test input.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
return &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateRBFCompliantTx checks that `createRBFCompliantTx` behaves as
|
||||
// expected.
|
||||
func TestCreateRBFCompliantTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test bump request.
|
||||
req := createTestBumpRequest()
|
||||
|
||||
// Create a test feerate and return it from the mock fee function.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
m.feeFunc.On("FeeRate").Return(feerate)
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
// When testmempoolaccept accepts the tx, no error
|
||||
// should be returned.
|
||||
name: "success case",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to pass.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(nil).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
// When testmempoolaccept fails due to a non-fee
|
||||
// related error, an error should be returned.
|
||||
name: "non-fee related testmempoolaccept fail",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to fail.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(errDummy).Once()
|
||||
},
|
||||
expectedErr: errDummy,
|
||||
},
|
||||
{
|
||||
// When increase feerate gives an error, the error
|
||||
// should be returned.
|
||||
name: "fail on increase fee",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to fail on fee.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(
|
||||
lnwallet.ErrMempoolFee).Once()
|
||||
|
||||
// Mock the fee function to return an error.
|
||||
m.feeFunc.On("Increment").Return(
|
||||
false, errDummy).Once()
|
||||
},
|
||||
expectedErr: errDummy,
|
||||
},
|
||||
{
|
||||
// Test that after one round of increasing the feerate
|
||||
// the tx passes testmempoolaccept.
|
||||
name: "increase fee and success on min mempool fee",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to fail on fee
|
||||
// for the first call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(
|
||||
lnwallet.ErrMempoolFee).Once()
|
||||
|
||||
// Mock the fee function to increase feerate.
|
||||
m.feeFunc.On("Increment").Return(
|
||||
true, nil).Once()
|
||||
|
||||
// Mock the testmempoolaccept to pass on the
|
||||
// second call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(nil).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
// Test that after one round of increasing the feerate
|
||||
// the tx passes testmempoolaccept.
|
||||
name: "increase fee and success on insufficient fee",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to fail on fee
|
||||
// for the first call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(
|
||||
rpcclient.ErrInsufficientFee).Once()
|
||||
|
||||
// Mock the fee function to increase feerate.
|
||||
m.feeFunc.On("Increment").Return(
|
||||
true, nil).Once()
|
||||
|
||||
// Mock the testmempoolaccept to pass on the
|
||||
// second call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(nil).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
// Test that the fee function increases the fee rate
|
||||
// after one round.
|
||||
name: "increase fee on second round",
|
||||
setupMock: func() {
|
||||
// Mock the testmempoolaccept to fail on fee
|
||||
// for the first call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(
|
||||
rpcclient.ErrInsufficientFee).Once()
|
||||
|
||||
// Mock the fee function to NOT increase
|
||||
// feerate on the first round.
|
||||
m.feeFunc.On("Increment").Return(
|
||||
false, nil).Once()
|
||||
|
||||
// Mock the fee function to increase feerate.
|
||||
m.feeFunc.On("Increment").Return(
|
||||
true, nil).Once()
|
||||
|
||||
// Mock the testmempoolaccept to pass on the
|
||||
// second call.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(nil).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.setupMock()
|
||||
|
||||
// Call the method under test.
|
||||
id, err := tp.createRBFCompliantTx(req, m.feeFunc)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
|
||||
// If there's an error, expect the requestID to be
|
||||
// empty.
|
||||
if tc.expectedErr != nil {
|
||||
require.Zero(t, id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTxPublisherBroadcast checks the internal `broadcast` method behaves as
|
||||
// expected.
|
||||
func TestTxPublisherBroadcast(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test bump request.
|
||||
req := createTestBumpRequest()
|
||||
|
||||
// Create a test tx.
|
||||
tx := &wire.MsgTx{LockTime: 1}
|
||||
txid := tx.TxHash()
|
||||
|
||||
// Create a test feerate and return it from the mock fee function.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
m.feeFunc.On("FeeRate").Return(feerate)
|
||||
|
||||
// Create a test conf event.
|
||||
confEvent := &chainntnfs.ConfirmationEvent{}
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
|
||||
// Quickly check when the requestID cannot be found, an error is
|
||||
// returned.
|
||||
result, err := tp.broadcast(uint64(1000))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
// Define params to be used in RegisterConfirmationsNtfn. Not important
|
||||
// for this test.
|
||||
var pkScript []byte
|
||||
confs := uint32(1)
|
||||
height := uint32(tp.currentHeight)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedErr error
|
||||
expectedResult *BumpResult
|
||||
}{
|
||||
{
|
||||
// When the notifier cannot register this spend, an
|
||||
// error should be returned
|
||||
name: "fail to register nftn",
|
||||
setupMock: func() {
|
||||
// Mock the RegisterConfirmationsNtfn to fail.
|
||||
m.notifier.On("RegisterConfirmationsNtfn",
|
||||
&txid, pkScript, confs, height).Return(
|
||||
nil, errDummy).Once()
|
||||
},
|
||||
expectedErr: errDummy,
|
||||
expectedResult: nil,
|
||||
},
|
||||
{
|
||||
// When the wallet cannot publish this tx, the error
|
||||
// should be put inside the result.
|
||||
name: "fail to publish",
|
||||
setupMock: func() {
|
||||
// Mock the RegisterConfirmationsNtfn to pass.
|
||||
m.notifier.On("RegisterConfirmationsNtfn",
|
||||
&txid, pkScript, confs, height).Return(
|
||||
confEvent, nil).Once()
|
||||
|
||||
// Mock the wallet to fail to publish.
|
||||
m.wallet.On("PublishTransaction",
|
||||
tx, mock.Anything).Return(
|
||||
errDummy).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectedResult: &BumpResult{
|
||||
Event: TxFailed,
|
||||
Tx: tx,
|
||||
Fee: fee,
|
||||
FeeRate: feerate,
|
||||
Err: errDummy,
|
||||
requestID: requestID,
|
||||
},
|
||||
},
|
||||
{
|
||||
// When nothing goes wrong, the result is returned.
|
||||
name: "publish success",
|
||||
setupMock: func() {
|
||||
// Mock the RegisterConfirmationsNtfn to pass.
|
||||
m.notifier.On("RegisterConfirmationsNtfn",
|
||||
&txid, pkScript, confs, height).Return(
|
||||
confEvent, nil).Once()
|
||||
|
||||
// Mock the wallet to publish successfully.
|
||||
m.wallet.On("PublishTransaction",
|
||||
tx, mock.Anything).Return(nil).Once()
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectedResult: &BumpResult{
|
||||
Event: TxPublished,
|
||||
Tx: tx,
|
||||
Fee: fee,
|
||||
FeeRate: feerate,
|
||||
Err: nil,
|
||||
requestID: requestID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.setupMock()
|
||||
|
||||
// Call the method under test.
|
||||
result, err := tp.broadcast(requestID)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
require.Equal(t, tc.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveResult checks the records and subscriptions are removed when a tx
|
||||
// is confirmed or failed.
|
||||
func TestRemoveResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test bump request.
|
||||
req := createTestBumpRequest()
|
||||
|
||||
// Create a test tx.
|
||||
tx := &wire.MsgTx{LockTime: 1}
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRecord func() uint64
|
||||
result *BumpResult
|
||||
removed bool
|
||||
}{
|
||||
{
|
||||
// When the tx is confirmed, the records will be
|
||||
// removed.
|
||||
name: "remove on TxConfirmed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
|
||||
return id
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxConfirmed,
|
||||
Tx: tx,
|
||||
},
|
||||
removed: true,
|
||||
},
|
||||
{
|
||||
// When the tx is failed, the records will be removed.
|
||||
name: "remove on TxFailed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
|
||||
return id
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxFailed,
|
||||
Err: errDummy,
|
||||
Tx: tx,
|
||||
},
|
||||
removed: true,
|
||||
},
|
||||
{
|
||||
// Noop when the tx is neither confirmed or failed.
|
||||
name: "noop when tx is not confirmed or failed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
|
||||
return id
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxPublished,
|
||||
Tx: tx,
|
||||
},
|
||||
removed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
requestID := tc.setupRecord()
|
||||
|
||||
// Attach the requestID from the setup.
|
||||
tc.result.requestID = requestID
|
||||
|
||||
// Remove the result.
|
||||
tp.removeResult(tc.result)
|
||||
|
||||
// Check if the record is removed.
|
||||
_, found := tp.records.Load(requestID)
|
||||
require.Equal(t, !tc.removed, found)
|
||||
|
||||
_, found = tp.subscriberChans.Load(requestID)
|
||||
require.Equal(t, !tc.removed, found)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyResult checks the subscribers are notified when a result is sent.
|
||||
func TestNotifyResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test bump request.
|
||||
req := createTestBumpRequest()
|
||||
|
||||
// Create a test tx.
|
||||
tx := &wire.MsgTx{LockTime: 1}
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
|
||||
// Create a subscription to the event.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
tp.subscriberChans.Store(requestID, subscriber)
|
||||
|
||||
// Create a test result.
|
||||
result := &BumpResult{
|
||||
requestID: requestID,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
// Notify the result and expect the subscriber to receive it.
|
||||
//
|
||||
// NOTE: must be done inside a goroutine in case it blocks.
|
||||
go tp.notifyResult(result)
|
||||
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case received := <-subscriber:
|
||||
require.Equal(t, result, received)
|
||||
}
|
||||
|
||||
// Notify two results. This time it should block because the channel is
|
||||
// full. We then shutdown TxPublisher to test the quit behavior.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Call notifyResult twice, which blocks at the second call.
|
||||
tp.notifyResult(result)
|
||||
tp.notifyResult(result)
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Shutdown the publisher and expect notifyResult to exit.
|
||||
close(tp.quit)
|
||||
|
||||
// We expect to done chan.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for notifyResult to exit")
|
||||
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcastSuccess checks the public `Broadcast` method can successfully
|
||||
// broadcast a tx based on the request.
|
||||
func TestBroadcastSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Once()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to pass.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Create a test conf event.
|
||||
confEvent := &chainntnfs.ConfirmationEvent{}
|
||||
|
||||
// Mock the RegisterConfirmationsNtfn to pass.
|
||||
m.notifier.On("RegisterConfirmationsNtfn",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(confEvent, nil).Once()
|
||||
|
||||
// Mock the wallet to publish successfully.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(nil).Once()
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a testing bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate,
|
||||
}
|
||||
|
||||
// Send the req and expect no error.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the first result to be TxPublished.
|
||||
require.Equal(t, TxPublished, result.Event)
|
||||
}
|
||||
|
||||
// Validate the record was stored.
|
||||
require.Equal(t, 1, tp.records.Len())
|
||||
require.Equal(t, 1, tp.subscriberChans.Len())
|
||||
}
|
||||
|
||||
// TestBroadcastFail checks the public `Broadcast` returns the error or a
|
||||
// failed result when the broadcast fails.
|
||||
func TestBroadcastFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a testing bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate,
|
||||
}
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Twice()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to return an error.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Send the req and expect an error returned.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.ErrorIs(t, err, errDummy)
|
||||
require.Nil(t, resultChan)
|
||||
|
||||
// Validate the record was NOT stored.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
|
||||
// Mock the testmempoolaccept again, this time it passes.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Create a test conf event.
|
||||
confEvent := &chainntnfs.ConfirmationEvent{}
|
||||
|
||||
// Mock the RegisterConfirmationsNtfn to pass.
|
||||
m.notifier.On("RegisterConfirmationsNtfn",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(confEvent, nil).Once()
|
||||
|
||||
// Mock the wallet to fail on publish.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Send the req and expect no error returned.
|
||||
resultChan, err = tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the result to be TxFailed and the error is set in
|
||||
// the result.
|
||||
require.Equal(t, TxFailed, result.Event)
|
||||
require.ErrorIs(t, result.Err, errDummy)
|
||||
}
|
||||
|
||||
// Validate the record was removed.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
}
|
||||
|
|
|
@ -493,3 +493,32 @@ func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
|
|||
|
||||
return args.Get(0).(chan *BumpResult), args.Error(1)
|
||||
}
|
||||
|
||||
// MockFeeFunction is a mock implementation of the FeeFunction interface.
|
||||
type MockFeeFunction struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure MockFeeFunction implements FeeFunction.
|
||||
var _ FeeFunction = (*MockFeeFunction)(nil)
|
||||
|
||||
// FeeRate returns the current fee rate calculated by the fee function.
|
||||
func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight {
|
||||
args := m.Called()
|
||||
|
||||
return args.Get(0).(chainfee.SatPerKWeight)
|
||||
}
|
||||
|
||||
// Increment adds one delta to the current fee rate.
|
||||
func (m *MockFeeFunction) Increment() (bool, error) {
|
||||
args := m.Called()
|
||||
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
// IncreaseFeeRate increases the fee rate by one step.
|
||||
func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) {
|
||||
args := m.Called(confTarget)
|
||||
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
|
|
@ -205,15 +205,14 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
|
|||
}
|
||||
}
|
||||
|
||||
log.Infof("Creating sweep transaction %v for %v inputs (%s) "+
|
||||
"using %v sat/kw, tx_weight=%v, tx_fee=%v, parents_count=%v, "+
|
||||
"parents_fee=%v, parents_weight=%v",
|
||||
log.Debugf("Creating sweep transaction %v for %v inputs (%s) "+
|
||||
"using %v, tx_weight=%v, tx_fee=%v, parents_count=%v, "+
|
||||
"parents_fee=%v, parents_weight=%v, current_height=%v",
|
||||
sweepTx.TxHash(), len(inputs),
|
||||
inputTypeSummary(inputs), feeRate,
|
||||
estimator.weight(), txFee,
|
||||
len(estimator.parents), estimator.parentsFee,
|
||||
estimator.parentsWeight,
|
||||
)
|
||||
estimator.parentsWeight, currentBlockHeight)
|
||||
|
||||
return sweepTx, txFee, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue