1
0
Fork 0
mirror of https://github.com/lightningnetwork/lnd.git synced 2025-03-26 08:55:59 +01:00

lnwallet+sweep: introduce TxPublisher to handle fee bump

This commit adds `TxPublisher` which implements `Bumper` interface. This
is part one of the implementation that focuses on implementing the
`Broadcast` method which guarantees a tx can be published with
RBF-compliant. It does so by leveraging the `testmempoolaccept` API,
keep increasing the fee rate until an RBF-compliant tx is made and
broadcasts it.

This tx will then be monitored by the `TxPublisher` and in the following
commit, the monitoring process will be added.
This commit is contained in:
yyforyongyu 2024-02-29 13:18:59 +08:00
parent ecd471ac75
commit 11f7e455d1
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
6 changed files with 1548 additions and 5 deletions

View file

@ -1,6 +1,7 @@
package chainntnfs
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn"
"github.com/stretchr/testify/mock"
@ -50,3 +51,73 @@ func (m *MockMempoolWatcher) LookupInputMempoolSpend(
return args.Get(0).(fn.Option[wire.MsgTx])
}
// MockNotifier is a mock implementation of the ChainNotifier interface.
type MockChainNotifier struct {
mock.Mock
}
// Compile-time check to ensure MockChainNotifier implements ChainNotifier.
var _ ChainNotifier = (*MockChainNotifier)(nil)
// RegisterConfirmationsNtfn registers an intent to be notified once txid
// reaches numConfs confirmations.
func (m *MockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts ...NotifierOption) (*ConfirmationEvent, error) {
args := m.Called(txid, pkScript, numConfs, heightHint)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*ConfirmationEvent), args.Error(1)
}
// RegisterSpendNtfn registers an intent to be notified once the target
// outpoint is successfully spent within a transaction.
func (m *MockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*SpendEvent, error) {
args := m.Called(outpoint, pkScript, heightHint)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*SpendEvent), args.Error(1)
}
// RegisterBlockEpochNtfn registers an intent to be notified of each new block
// connected to the tip of the main chain.
func (m *MockChainNotifier) RegisterBlockEpochNtfn(epoch *BlockEpoch) (
*BlockEpochEvent, error) {
args := m.Called(epoch)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*BlockEpochEvent), args.Error(1)
}
// Start the ChainNotifier. Once started, the implementation should be ready,
// and able to receive notification registrations from clients.
func (m *MockChainNotifier) Start() error {
args := m.Called()
return args.Error(0)
}
// Started returns true if this instance has been started, and false otherwise.
func (m *MockChainNotifier) Started() bool {
args := m.Called()
return args.Bool(0)
}
// Stops the concrete ChainNotifier.
func (m *MockChainNotifier) Stop() error {
args := m.Called()
return args.Error(0)
}

View file

@ -1,8 +1,14 @@
package input
import (
"crypto/sha256"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain"
"github.com/stretchr/testify/mock"
)
@ -168,3 +174,100 @@ func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error {
return args.Error(0)
}
// MockInputSigner is a mock implementation of the Signer interface.
type MockInputSigner struct {
mock.Mock
}
// Compile-time constraint to ensure MockInputSigner implements Signer.
var _ Signer = (*MockInputSigner)(nil)
// SignOutputRaw generates a signature for the passed transaction according to
// the data within the passed SignDescriptor.
func (m *MockInputSigner) SignOutputRaw(tx *wire.MsgTx,
signDesc *SignDescriptor) (Signature, error) {
args := m.Called(tx, signDesc)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(Signature), args.Error(1)
}
// ComputeInputScript generates a complete InputIndex for the passed
// transaction with the signature as defined within the passed SignDescriptor.
func (m *MockInputSigner) ComputeInputScript(tx *wire.MsgTx,
signDesc *SignDescriptor) (*Script, error) {
args := m.Called(tx, signDesc)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*Script), args.Error(1)
}
// MuSig2CreateSession creates a new MuSig2 signing session using the local key
// identified by the key locator.
func (m *MockInputSigner) MuSig2CreateSession(version MuSig2Version,
locator keychain.KeyLocator, pubkey []*btcec.PublicKey,
tweak *MuSig2Tweaks, pubNonces [][musig2.PubNonceSize]byte,
nonces *musig2.Nonces) (*MuSig2SessionInfo, error) {
args := m.Called(version, locator, pubkey, tweak, pubNonces, nonces)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*MuSig2SessionInfo), args.Error(1)
}
// MuSig2RegisterNonces registers one or more public nonces of other signing
// participants for a session identified by its ID.
func (m *MockInputSigner) MuSig2RegisterNonces(versio MuSig2SessionID,
pubNonces [][musig2.PubNonceSize]byte) (bool, error) {
args := m.Called(versio, pubNonces)
if args.Get(0) == nil {
return false, args.Error(1)
}
return args.Bool(0), args.Error(1)
}
// MuSig2Sign creates a partial signature using the local signing key that was
// specified when the session was created.
func (m *MockInputSigner) MuSig2Sign(sessionID MuSig2SessionID,
msg [sha256.Size]byte, withSortedKeys bool) (
*musig2.PartialSignature, error) {
args := m.Called(sessionID, msg, withSortedKeys)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*musig2.PartialSignature), args.Error(1)
}
// MuSig2CombineSig combines the given partial signature(s) with the local one,
// if it already exists.
func (m *MockInputSigner) MuSig2CombineSig(sessionID MuSig2SessionID,
partialSig []*musig2.PartialSignature) (
*schnorr.Signature, bool, error) {
args := m.Called(sessionID, partialSig)
if args.Get(0) == nil {
return nil, false, args.Error(2)
}
return args.Get(0).(*schnorr.Signature), args.Bool(1), args.Error(2)
}
// MuSig2Cleanup removes a session from memory to free up resources.
func (m *MockInputSigner) MuSig2Cleanup(sessionID MuSig2SessionID) error {
args := m.Called(sessionID)
return args.Error(0)
}

View file

@ -3,16 +3,29 @@ package sweep
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
// ErrInvalidBumpResult is returned when the bump result is invalid.
ErrInvalidBumpResult = errors.New("invalid bump result")
// ErrNotEnoughBudget is returned when the fee bumper decides the
// current budget cannot cover the fee.
ErrNotEnoughBudget = errors.New("not enough budget")
)
// Bumper defines an interface that can be used by other subsystems for fee
@ -165,6 +178,9 @@ type BumpResult struct {
// Err is the error that occurred during the broadcast.
Err error
// requestID is the ID of the request that created this record.
requestID uint64
}
// Validate validates the BumpResult so it's safe to use.
@ -197,3 +213,460 @@ func (b *BumpResult) Validate() error {
return nil
}
// TxPublisherConfig is the config used to create a new TxPublisher.
type TxPublisherConfig struct {
// Signer is used to create the tx signature.
Signer input.Signer
// Wallet is used primarily to publish the tx.
Wallet Wallet
// Estimator is used to estimate the fee rate for the new tx based on
// its deadline conf target.
Estimator chainfee.Estimator
// Notifier is used to monitor the confirmation status of the tx.
Notifier chainntnfs.ChainNotifier
}
// TxPublisher is an implementation of the Bumper interface. It utilizes the
// `testmempoolaccept` RPC to bump the fee of txns it created based on
// different fee function selected or configed by the caller. Its purpose is to
// take a list of inputs specified, and create a tx that spends them to a
// specified output. It will then monitor the confirmation status of the tx,
// and if it's not confirmed within a certain time frame, it will attempt to
// bump the fee of the tx by creating a new tx that spends the same inputs to
// the same output, but with a higher fee rate. It will continue to do this
// until the tx is confirmed or the fee rate reaches the maximum fee rate
// specified by the caller.
type TxPublisher struct {
wg sync.WaitGroup
// cfg specifies the configuration of the TxPublisher.
cfg *TxPublisherConfig
// currentHeight is the current block height.
currentHeight int32
// records is a map keyed by the requestCounter and the value is the tx
// being monitored.
records lnutils.SyncMap[uint64, *monitorRecord]
// requestCounter is a monotonically increasing counter used to keep
// track of how many requests have been made.
requestCounter atomic.Uint64
// subscriberChans is a map keyed by the requestCounter, each item is
// the chan that the publisher sends the fee bump result to.
subscriberChans lnutils.SyncMap[uint64, chan *BumpResult]
// quit is used to signal the publisher to stop.
quit chan struct{}
}
// Compile-time constraint to ensure TxPublisher implements Bumper.
var _ Bumper = (*TxPublisher)(nil)
// NewTxPublisher creates a new TxPublisher.
func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher {
return &TxPublisher{
cfg: &cfg,
records: lnutils.SyncMap[uint64, *monitorRecord]{},
subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{},
quit: make(chan struct{}),
}
}
// Broadcast is used to publish the tx created from the given inputs. It will,
// 1. init a fee function based on the given strategy.
// 2. create an RBF-compliant tx and monitor it for confirmation.
// 3. notify the initial broadcast result back to the caller.
// The initial broadcast is guaranteed to be RBF-compliant unless the budget
// specified cannot cover the fee.
//
// NOTE: part of the Bumper interface.
func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
log.Tracef("Received broadcast request: %s", newLogClosure(
func() string {
return spew.Sdump(req)
})())
// Attempt an initial broadcast which is guaranteed to comply with the
// RBF rules.
result, err := t.initialBroadcast(req)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)
return nil, err
}
// Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1)
t.subscriberChans.Store(result.requestID, subscriber)
// Send the initial broadcast result to the caller.
t.handleResult(result)
return subscriber, nil
}
// initialBroadcast initializes a fee function, creates an RBF-compliant tx and
// broadcasts it.
func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(req)
if err != nil {
return nil, fmt.Errorf("init fee function: %w", err)
}
// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
requestID, err := t.createRBFCompliantTx(req, feeAlgo)
if err != nil {
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
}
// Broadcast the tx and return the monitored record.
result, err := t.broadcast(requestID)
if err != nil {
return nil, fmt.Errorf("broadcast sweep tx: %w", err)
}
return result, nil
}
// initializeFeeFunction initializes a fee function to be used for this request
// for future fee bumping.
func (t *TxPublisher) initializeFeeFunction(
req *BumpRequest) (FeeFunction, error) {
// Get the max allowed feerate.
maxFeeRateAllowed, err := req.MaxFeeRateAllowed()
if err != nil {
return nil, err
}
// Get the initial conf target.
confTarget := calcCurrentConfTarget(t.currentHeight, req.DeadlineHeight)
// Initialize the fee function and return it.
//
// TODO(yy): return based on differet req.Strategy?
return NewLinearFeeFunction(
maxFeeRateAllowed, confTarget, t.cfg.Estimator,
)
}
// createRBFCompliantTx creates a tx that is compliant with RBF rules. It does
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
f FeeFunction) (uint64, error) {
for {
// Create a new tx with the given fee rate and check its
// mempool acceptance.
tx, fee, err := t.createAndCheckTx(req, f)
switch {
case err == nil:
// The tx is valid, return the request ID.
requestID := t.storeRecord(tx, req, f, fee)
log.Infof("Created tx %v for %v inputs: feerate=%v, "+
"fee=%v, inputs=%v", tx.TxHash(),
len(req.Inputs), f.FeeRate(), fee,
inputTypeSummary(req.Inputs))
return requestID, nil
// If the error indicates the fees paid is not enough, we will
// ask the fee function to increase the fee rate and retry.
case errors.Is(err, lnwallet.ErrMempoolFee):
// We should at least start with a feerate above the
// mempool min feerate, so if we get this error, it
// means something is wrong earlier in the pipeline.
log.Errorf("Current fee=%v, feerate=%v, %v", fee,
f.FeeRate(), err)
fallthrough
// We are not paying enough fees so we increase it.
case errors.Is(err, rpcclient.ErrInsufficientFee):
increased := false
// Keep calling the fee function until the fee rate is
// increased or maxed out.
for !increased {
log.Debugf("Increasing fee for next round, "+
"current fee=%v, feerate=%v", fee,
f.FeeRate())
// If the fee function tells us that we have
// used up the budget, we will return an error
// indicating this tx cannot be made. The
// sweeper should handle this error and try to
// cluster these inputs differetly.
increased, err = f.Increment()
if err != nil {
return 0, err
}
}
// TODO(yy): suppose there's only one bad input, we can do a
// binary search to find out which input is causing this error
// by recreating a tx using half of the inputs and check its
// mempool acceptance.
default:
log.Debugf("Failed to create RBF-compliant tx: %v", err)
return 0, err
}
}
}
// storeRecord stores the given record in the records map.
func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
f FeeFunction, fee btcutil.Amount) uint64 {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the
// counter.
requestID := t.requestCounter.Add(1)
// Register the record.
t.records.Store(requestID, &monitorRecord{
tx: tx,
req: req,
feeFunction: f,
fee: fee,
})
return requestID
}
// createAndCheckTx creates a tx based on the given inputs, change output
// script, and the fee rate. In addition, it validates the tx's mempool
// acceptance before returning a tx that can be published directly, along with
// its fee.
func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) (
*wire.MsgTx, btcutil.Amount, error) {
// Create the sweep tx with max fee rate of 0 as the fee function
// guarantees the fee rate used here won't exceed the max fee rate.
//
// TODO(yy): refactor this function to not require a max fee rate.
tx, fee, err := createSweepTx(
req.Inputs, nil, req.DeliveryAddress, uint32(t.currentHeight),
f.FeeRate(), 0, t.cfg.Signer,
)
if err != nil {
return nil, 0, fmt.Errorf("create sweep tx: %w", err)
}
// Sanity check the budget still covers the fee.
if fee > req.Budget {
return nil, 0, fmt.Errorf("%w: budget=%v, fee=%v",
ErrNotEnoughBudget, req.Budget, fee)
}
// Validate the tx's mempool acceptance.
err = t.cfg.Wallet.CheckMempoolAcceptance(tx)
// Exit early if the tx is valid.
if err == nil {
return tx, fee, nil
}
// Print an error log if the chain backend doesn't support the mempool
// acceptance test RPC.
if errors.Is(err, rpcclient.ErrBackendVersion) {
log.Errorf("TestMempoolAccept not supported by backend, " +
"consider upgrading it to a newer version")
return tx, fee, nil
}
// We are running on a backend that doesn't implement the RPC
// testmempoolaccept, eg, neutrino, so we'll skip the check.
if errors.Is(err, chain.ErrUnimplemented) {
log.Debug("Skipped testmempoolaccept due to not implemented")
return tx, fee, nil
}
return nil, 0, err
}
// broadcast takes a monitored tx and publishes it to the network. Prior to the
// broadcast, it will subscribe the tx's confirmation notification and attach
// the event channel to the record. Any broadcast-related errors will not be
// returned here, instead, they will be put inside the `BumpResult` and
// returned to the caller.
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
// Get the record being monitored.
record, ok := t.records.Load(requestID)
if !ok {
return nil, fmt.Errorf("tx record %v not found", requestID)
}
txid := record.tx.TxHash()
// Subscribe to its confirmation notification.
confEvent, err := t.cfg.Notifier.RegisterConfirmationsNtfn(
&txid, nil, 1, uint32(t.currentHeight),
)
if err != nil {
return nil, fmt.Errorf("register confirmation ntfn: %w", err)
}
// Attach the confirmation event channel to the record.
record.confEvent = confEvent
tx := record.tx
log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v",
txid, len(tx.TxIn), t.currentHeight)
// Set the event, and change it to TxFailed if the wallet fails to
// publish it.
event := TxPublished
// Publish the sweeping tx with customized label. If the publish fails,
// this error will be saved in the `BumpResult` and it will be removed
// from being monitored.
err = t.cfg.Wallet.PublishTransaction(
tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil),
)
if err != nil {
// NOTE: we decide to attach this error to the result instead
// of returning it here because by the time the tx reaches
// here, it should have passed the mempool acceptance check. If
// it still fails to be broadcast, it's likely a non-RBF
// related error happened. So we send this error back to the
// caller so that it can handle it properly.
//
// TODO(yy): find out which input is causing the failure.
log.Errorf("Failed to publish tx %v: %v", txid, err)
event = TxFailed
}
result := &BumpResult{
Event: event,
Tx: record.tx,
Fee: record.fee,
FeeRate: record.feeFunction.FeeRate(),
Err: err,
requestID: requestID,
}
return result, nil
}
// notifyResult sends the result to the resultChan specified by the requestID.
// This channel is expected to be read by the caller.
func (t *TxPublisher) notifyResult(result *BumpResult) {
id := result.requestID
subscriber, ok := t.subscriberChans.Load(id)
if !ok {
log.Errorf("Result chan for id=%v not found", id)
return
}
log.Debugf("Sending result for requestID=%v, tx=%v", id,
result.Tx.TxHash())
select {
// Send the result to the subscriber.
//
// TODO(yy): Add timeout in case it's blocking?
case subscriber <- result:
case <-t.quit:
log.Debug("Fee bumper stopped")
}
}
// removeResult removes the tracking of the result if the result contains a
// non-nil error, or the tx is confirmed, the record will be removed from the
// maps.
func (t *TxPublisher) removeResult(result *BumpResult) {
id := result.requestID
// Remove the record from the maps if there's an error. This means this
// tx has failed its broadcast and cannot be retried. There are two
// cases,
// - when the budget cannot cover the fee.
// - when a non-RBF related error occurs.
switch result.Event {
case TxFailed:
log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v",
id, result.Tx.TxHash(), result.Err)
case TxConfirmed:
// Remove the record is the tx is confirmed.
log.Debugf("Removing confirmed monitor record=%v, tx=%v", id,
result.Tx.TxHash())
// Do nothing if it's neither failed or confirmed.
default:
log.Tracef("Skipping record removal for id=%v, event=%v", id,
result.Event)
return
}
t.records.Delete(id)
t.subscriberChans.Delete(id)
}
// handleResult handles the result of a tx broadcast. It will notify the
// subscriber and remove the record if the tx is confirmed or failed to be
// broadcast.
func (t *TxPublisher) handleResult(result *BumpResult) {
// Notify the subscriber.
t.notifyResult(result)
// Remove the record if it's failed or confirmed.
t.removeResult(result)
}
// monitorRecord is used to keep track of the tx being monitored by the
// publisher internally.
type monitorRecord struct {
// tx is the tx being monitored.
tx *wire.MsgTx
// req is the original request.
req *BumpRequest
// confEvent is the subscription to the confirmation event of the tx.
confEvent *chainntnfs.ConfirmationEvent
// feeFunction is the fee bumping algorithm used by the publisher.
feeFunction FeeFunction
// fee is the fee paid by the tx.
fee btcutil.Amount
}
// calcCurrentConfTarget calculates the current confirmation target based on
// the deadline height. The conf target is capped at 0 if the deadline has
// already been past.
func calcCurrentConfTarget(currentHeight, deadline int32) uint32 {
var confTarget uint32
// Calculate how many blocks left until the deadline.
deadlineDelta := deadline - currentHeight
// If we are already past the deadline, we will set the conf target to
// be 1.
if deadlineDelta <= 0 {
log.Warnf("Deadline is %d blocks behind current height %v",
-deadlineDelta, currentHeight)
confTarget = 1
} else {
confTarget = uint32(deadlineDelta)
}
return confTarget
}

View file

@ -1,12 +1,18 @@
package sweep
import (
"fmt"
"testing"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -171,3 +177,865 @@ func TestBumpRequestMaxFeeRateAllowed(t *testing.T) {
})
}
}
// TestCalcCurrentConfTarget checks that the current confirmation target is
// calculated correctly.
func TestCalcCurrentConfTarget(t *testing.T) {
t.Parallel()
// When the current block height is 100 and deadline height is 200, the
// conf target should be 100.
conf := calcCurrentConfTarget(int32(100), int32(200))
require.EqualValues(t, 100, conf)
// When the current block height is 200 and deadline height is 100, the
// conf target should be 1 since the deadline has passed.
conf = calcCurrentConfTarget(int32(200), int32(100))
require.EqualValues(t, 1, conf)
}
// TestInitializeFeeFunction tests the initialization of the fee function.
func TestInitializeFeeFunction(t *testing.T) {
t.Parallel()
// Create a test input.
inp := createTestInput(100, input.WitnessKeyHash)
// Create a mock fee estimator.
estimator := &chainfee.MockEstimator{}
defer estimator.AssertExpectations(t)
// Create a publisher using the mocks.
tp := NewTxPublisher(TxPublisherConfig{
Estimator: estimator,
})
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate,
}
// Mock the fee estimator to return an error.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
dummyErr := fmt.Errorf("dummy error")
estimator.On("EstimateFeePerKW", mock.Anything).Return(
chainfee.SatPerKWeight(0), dummyErr).Once()
// Call the method under test and assert the error is returned.
f, err := tp.initializeFeeFunction(req)
require.ErrorIs(t, err, dummyErr)
require.Nil(t, f)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Call the method under test.
f, err = tp.initializeFeeFunction(req)
require.NoError(t, err)
require.Equal(t, feerate, f.FeeRate())
}
// TestStoreRecord correctly increases the request counter and saves the
// record.
func TestStoreRecord(t *testing.T) {
t.Parallel()
// Create a test input.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
}
// Create a naive fee function.
feeFunc := &LinearFeeFunction{}
// Create a test fee and tx.
fee := btcutil.Amount(1000)
tx := &wire.MsgTx{}
// Create a publisher using the mocks.
tp := NewTxPublisher(TxPublisherConfig{})
// Get the current counter and check it's increased later.
initialCounter := tp.requestCounter.Load()
// Call the method under test.
requestID := tp.storeRecord(tx, req, feeFunc, fee)
// Check the request ID is as expected.
require.Equal(t, initialCounter+1, requestID)
// Read the saved record and compare.
record, ok := tp.records.Load(requestID)
require.True(t, ok)
require.Equal(t, tx, record.tx)
require.Equal(t, feeFunc, record.feeFunction)
require.Equal(t, fee, record.fee)
require.Equal(t, req, record.req)
}
// mockers wraps a list of mocked interfaces used inside tx publisher.
type mockers struct {
signer *input.MockInputSigner
wallet *MockWallet
estimator *chainfee.MockEstimator
notifier *chainntnfs.MockChainNotifier
feeFunc *MockFeeFunction
}
// createTestPublisher creates a new tx publisher using the provided mockers.
func createTestPublisher(t *testing.T) (*TxPublisher, *mockers) {
// Create a mock fee estimator.
estimator := &chainfee.MockEstimator{}
// Create a mock fee function.
feeFunc := &MockFeeFunction{}
// Create a mock signer.
signer := &input.MockInputSigner{}
// Create a mock wallet.
wallet := &MockWallet{}
// Create a mock chain notifier.
notifier := &chainntnfs.MockChainNotifier{}
t.Cleanup(func() {
estimator.AssertExpectations(t)
feeFunc.AssertExpectations(t)
signer.AssertExpectations(t)
wallet.AssertExpectations(t)
notifier.AssertExpectations(t)
})
m := &mockers{
signer: signer,
wallet: wallet,
estimator: estimator,
notifier: notifier,
feeFunc: feeFunc,
}
// Create a publisher using the mocks.
tp := NewTxPublisher(TxPublisherConfig{
Estimator: m.estimator,
Signer: m.signer,
Wallet: m.wallet,
Notifier: m.notifier,
})
return tp, m
}
// TestCreateAndCheckTx checks `createAndCheckTx` behaves as expected.
func TestCreateAndCheckTx(t *testing.T) {
t.Parallel()
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate and return it from the mock fee function.
feerate := chainfee.SatPerKWeight(1000)
m.feeFunc.On("FeeRate").Return(feerate)
// Mock the wallet to fail on testmempoolaccept on the first call, and
// succeed on the second.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
testCases := []struct {
name string
req *BumpRequest
expectedErr error
}{
{
// When the budget cannot cover the fee, an error
// should be returned.
name: "not enough budget",
req: &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
},
expectedErr: ErrNotEnoughBudget,
},
{
// When the mempool rejects the transaction, an error
// should be returned.
name: "testmempoolaccept fail",
req: &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
},
expectedErr: errDummy,
},
{
// When the mempool accepts the transaction, no error
// should be returned.
name: "testmempoolaccept pass",
req: &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
},
expectedErr: nil,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Call the method under test.
_, _, err := tp.createAndCheckTx(tc.req, m.feeFunc)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
})
}
}
// createTestBumpRequest creates a new bump request.
func createTestBumpRequest() *BumpRequest {
// Create a test input.
inp := createTestInput(1000, input.WitnessKeyHash)
return &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
}
}
// TestCreateRBFCompliantTx checks that `createRBFCompliantTx` behaves as
// expected.
func TestCreateRBFCompliantTx(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test bump request.
req := createTestBumpRequest()
// Create a test feerate and return it from the mock fee function.
feerate := chainfee.SatPerKWeight(1000)
m.feeFunc.On("FeeRate").Return(feerate)
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
testCases := []struct {
name string
setupMock func()
expectedErr error
}{
{
// When testmempoolaccept accepts the tx, no error
// should be returned.
name: "success case",
setupMock: func() {
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(nil).Once()
},
expectedErr: nil,
},
{
// When testmempoolaccept fails due to a non-fee
// related error, an error should be returned.
name: "non-fee related testmempoolaccept fail",
setupMock: func() {
// Mock the testmempoolaccept to fail.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
},
expectedErr: errDummy,
},
{
// When increase feerate gives an error, the error
// should be returned.
name: "fail on increase fee",
setupMock: func() {
// Mock the testmempoolaccept to fail on fee.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(
lnwallet.ErrMempoolFee).Once()
// Mock the fee function to return an error.
m.feeFunc.On("Increment").Return(
false, errDummy).Once()
},
expectedErr: errDummy,
},
{
// Test that after one round of increasing the feerate
// the tx passes testmempoolaccept.
name: "increase fee and success on min mempool fee",
setupMock: func() {
// Mock the testmempoolaccept to fail on fee
// for the first call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(
lnwallet.ErrMempoolFee).Once()
// Mock the fee function to increase feerate.
m.feeFunc.On("Increment").Return(
true, nil).Once()
// Mock the testmempoolaccept to pass on the
// second call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(nil).Once()
},
expectedErr: nil,
},
{
// Test that after one round of increasing the feerate
// the tx passes testmempoolaccept.
name: "increase fee and success on insufficient fee",
setupMock: func() {
// Mock the testmempoolaccept to fail on fee
// for the first call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(
rpcclient.ErrInsufficientFee).Once()
// Mock the fee function to increase feerate.
m.feeFunc.On("Increment").Return(
true, nil).Once()
// Mock the testmempoolaccept to pass on the
// second call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(nil).Once()
},
expectedErr: nil,
},
{
// Test that the fee function increases the fee rate
// after one round.
name: "increase fee on second round",
setupMock: func() {
// Mock the testmempoolaccept to fail on fee
// for the first call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(
rpcclient.ErrInsufficientFee).Once()
// Mock the fee function to NOT increase
// feerate on the first round.
m.feeFunc.On("Increment").Return(
false, nil).Once()
// Mock the fee function to increase feerate.
m.feeFunc.On("Increment").Return(
true, nil).Once()
// Mock the testmempoolaccept to pass on the
// second call.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(nil).Once()
},
expectedErr: nil,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()
// Call the method under test.
id, err := tp.createRBFCompliantTx(req, m.feeFunc)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
// If there's an error, expect the requestID to be
// empty.
if tc.expectedErr != nil {
require.Zero(t, id)
}
})
}
}
// TestTxPublisherBroadcast checks the internal `broadcast` method behaves as
// expected.
func TestTxPublisherBroadcast(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test bump request.
req := createTestBumpRequest()
// Create a test tx.
tx := &wire.MsgTx{LockTime: 1}
txid := tx.TxHash()
// Create a test feerate and return it from the mock fee function.
feerate := chainfee.SatPerKWeight(1000)
m.feeFunc.On("FeeRate").Return(feerate)
// Create a test conf event.
confEvent := &chainntnfs.ConfirmationEvent{}
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
// Quickly check when the requestID cannot be found, an error is
// returned.
result, err := tp.broadcast(uint64(1000))
require.Error(t, err)
require.Nil(t, result)
// Define params to be used in RegisterConfirmationsNtfn. Not important
// for this test.
var pkScript []byte
confs := uint32(1)
height := uint32(tp.currentHeight)
testCases := []struct {
name string
setupMock func()
expectedErr error
expectedResult *BumpResult
}{
{
// When the notifier cannot register this spend, an
// error should be returned
name: "fail to register nftn",
setupMock: func() {
// Mock the RegisterConfirmationsNtfn to fail.
m.notifier.On("RegisterConfirmationsNtfn",
&txid, pkScript, confs, height).Return(
nil, errDummy).Once()
},
expectedErr: errDummy,
expectedResult: nil,
},
{
// When the wallet cannot publish this tx, the error
// should be put inside the result.
name: "fail to publish",
setupMock: func() {
// Mock the RegisterConfirmationsNtfn to pass.
m.notifier.On("RegisterConfirmationsNtfn",
&txid, pkScript, confs, height).Return(
confEvent, nil).Once()
// Mock the wallet to fail to publish.
m.wallet.On("PublishTransaction",
tx, mock.Anything).Return(
errDummy).Once()
},
expectedErr: nil,
expectedResult: &BumpResult{
Event: TxFailed,
Tx: tx,
Fee: fee,
FeeRate: feerate,
Err: errDummy,
requestID: requestID,
},
},
{
// When nothing goes wrong, the result is returned.
name: "publish success",
setupMock: func() {
// Mock the RegisterConfirmationsNtfn to pass.
m.notifier.On("RegisterConfirmationsNtfn",
&txid, pkScript, confs, height).Return(
confEvent, nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
tx, mock.Anything).Return(nil).Once()
},
expectedErr: nil,
expectedResult: &BumpResult{
Event: TxPublished,
Tx: tx,
Fee: fee,
FeeRate: feerate,
Err: nil,
requestID: requestID,
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()
// Call the method under test.
result, err := tp.broadcast(requestID)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
require.Equal(t, tc.expectedResult, result)
})
}
}
// TestRemoveResult checks the records and subscriptions are removed when a tx
// is confirmed or failed.
func TestRemoveResult(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test bump request.
req := createTestBumpRequest()
// Create a test tx.
tx := &wire.MsgTx{LockTime: 1}
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
testCases := []struct {
name string
setupRecord func() uint64
result *BumpResult
removed bool
}{
{
// When the tx is confirmed, the records will be
// removed.
name: "remove on TxConfirmed",
setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(id, nil)
return id
},
result: &BumpResult{
Event: TxConfirmed,
Tx: tx,
},
removed: true,
},
{
// When the tx is failed, the records will be removed.
name: "remove on TxFailed",
setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(id, nil)
return id
},
result: &BumpResult{
Event: TxFailed,
Err: errDummy,
Tx: tx,
},
removed: true,
},
{
// Noop when the tx is neither confirmed or failed.
name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(id, nil)
return id
},
result: &BumpResult{
Event: TxPublished,
Tx: tx,
},
removed: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
requestID := tc.setupRecord()
// Attach the requestID from the setup.
tc.result.requestID = requestID
// Remove the result.
tp.removeResult(tc.result)
// Check if the record is removed.
_, found := tp.records.Load(requestID)
require.Equal(t, !tc.removed, found)
_, found = tp.subscriberChans.Load(requestID)
require.Equal(t, !tc.removed, found)
})
}
}
// TestNotifyResult checks the subscribers are notified when a result is sent.
func TestNotifyResult(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test bump request.
req := createTestBumpRequest()
// Create a test tx.
tx := &wire.MsgTx{LockTime: 1}
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
tp.subscriberChans.Store(requestID, subscriber)
// Create a test result.
result := &BumpResult{
requestID: requestID,
Tx: tx,
}
// Notify the result and expect the subscriber to receive it.
//
// NOTE: must be done inside a goroutine in case it blocks.
go tp.notifyResult(result)
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case received := <-subscriber:
require.Equal(t, result, received)
}
// Notify two results. This time it should block because the channel is
// full. We then shutdown TxPublisher to test the quit behavior.
done := make(chan struct{})
go func() {
// Call notifyResult twice, which blocks at the second call.
tp.notifyResult(result)
tp.notifyResult(result)
close(done)
}()
// Shutdown the publisher and expect notifyResult to exit.
close(tp.quit)
// We expect to done chan.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for notifyResult to exit")
case <-done:
}
}
// TestBroadcastSuccess checks the public `Broadcast` method can successfully
// broadcast a tx based on the request.
func TestBroadcastSuccess(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Create a test conf event.
confEvent := &chainntnfs.ConfirmationEvent{}
// Mock the RegisterConfirmationsNtfn to pass.
m.notifier.On("RegisterConfirmationsNtfn",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(confEvent, nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(nil).Once()
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate,
}
// Send the req and expect no error.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxPublished.
require.Equal(t, TxPublished, result.Event)
}
// Validate the record was stored.
require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len())
}
// TestBroadcastFail checks the public `Broadcast` returns the error or a
// failed result when the broadcast fails.
func TestBroadcastFail(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate,
}
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Twice()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to return an error.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
// Send the req and expect an error returned.
resultChan, err := tp.Broadcast(req)
require.ErrorIs(t, err, errDummy)
require.Nil(t, resultChan)
// Validate the record was NOT stored.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Mock the testmempoolaccept again, this time it passes.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Create a test conf event.
confEvent := &chainntnfs.ConfirmationEvent{}
// Mock the RegisterConfirmationsNtfn to pass.
m.notifier.On("RegisterConfirmationsNtfn",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(confEvent, nil).Once()
// Mock the wallet to fail on publish.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(errDummy).Once()
// Send the req and expect no error returned.
resultChan, err = tp.Broadcast(req)
require.NoError(t, err)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the result to be TxFailed and the error is set in
// the result.
require.Equal(t, TxFailed, result.Event)
require.ErrorIs(t, result.Err, errDummy)
}
// Validate the record was removed.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
}

View file

@ -493,3 +493,32 @@ func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
return args.Get(0).(chan *BumpResult), args.Error(1)
}
// MockFeeFunction is a mock implementation of the FeeFunction interface.
type MockFeeFunction struct {
mock.Mock
}
// Compile-time constraint to ensure MockFeeFunction implements FeeFunction.
var _ FeeFunction = (*MockFeeFunction)(nil)
// FeeRate returns the current fee rate calculated by the fee function.
func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight {
args := m.Called()
return args.Get(0).(chainfee.SatPerKWeight)
}
// Increment adds one delta to the current fee rate.
func (m *MockFeeFunction) Increment() (bool, error) {
args := m.Called()
return args.Bool(0), args.Error(1)
}
// IncreaseFeeRate increases the fee rate by one step.
func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) {
args := m.Called(confTarget)
return args.Bool(0), args.Error(1)
}

View file

@ -205,15 +205,14 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
}
}
log.Infof("Creating sweep transaction %v for %v inputs (%s) "+
"using %v sat/kw, tx_weight=%v, tx_fee=%v, parents_count=%v, "+
"parents_fee=%v, parents_weight=%v",
log.Debugf("Creating sweep transaction %v for %v inputs (%s) "+
"using %v, tx_weight=%v, tx_fee=%v, parents_count=%v, "+
"parents_fee=%v, parents_weight=%v, current_height=%v",
sweepTx.TxHash(), len(inputs),
inputTypeSummary(inputs), feeRate,
estimator.weight(), txFee,
len(estimator.parents), estimator.parentsFee,
estimator.parentsWeight,
)
estimator.parentsWeight, currentBlockHeight)
return sweepTx, txFee, nil
}