sweep+lnd: introduce UtxoAggregator to handle clustering inputs

This commit refactors the grouping logic into a new interface
`UtxoAggregator`, which makes it easier to write tests and opens
possibility for future customized clustering strategies.

The old clustering logic is kept as and moved into `SimpleAggregator`.
This commit is contained in:
yyforyongyu 2023-10-24 12:32:17 +08:00
parent 3bcac318eb
commit 1870caf39c
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
6 changed files with 807 additions and 728 deletions

View File

@ -1063,6 +1063,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, err
}
aggregator := sweep.NewSimpleUtxoAggregator(
cc.FeeEstimator, cfg.Sweeper.MaxFeeRate.FeePerKWeight(),
)
s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{
FeeEstimator: cc.FeeEstimator,
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
@ -1075,7 +1079,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MaxSweepAttempts: sweep.DefaultMaxSweepAttempts,
NextAttemptDeltaFunc: sweep.DefaultNextAttemptDeltaFunc,
MaxFeeRate: cfg.Sweeper.MaxFeeRate,
FeeRateBucketSize: sweep.DefaultFeeRateBucketSize,
Aggregator: aggregator,
})
s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{

351
sweep/aggregator.go Normal file
View File

@ -0,0 +1,351 @@
package sweep
import (
"sort"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
const (
// DefaultFeeRateBucketSize is the default size of fee rate buckets
// we'll use when clustering inputs into buckets with similar fee rates
// within the SimpleAggregator.
//
// Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10
// would result in the following fee rate buckets up to the maximum fee
// rate:
//
// #1: min = 1 sat/vbyte, max = 10 sat/vbyte
// #2: min = 11 sat/vbyte, max = 20 sat/vbyte...
DefaultFeeRateBucketSize = 10
)
// UtxoAggregator defines an interface that takes a list of inputs and
// aggregate them into groups. Each group is used as the inputs to create a
// sweeping transaction.
type UtxoAggregator interface {
// ClusterInputs takes a list of inputs and groups them into clusters.
ClusterInputs(pendingInputs) []inputCluster
}
// SimpleAggregator aggregates inputs known by the Sweeper based on each
// input's locktime and feerate.
type SimpleAggregator struct {
// FeeEstimator is used when crafting sweep transactions to estimate
// the necessary fee relative to the expected size of the sweep
// transaction.
FeeEstimator chainfee.Estimator
// MaxFeeRate is the maximum fee rate allowed within the
// SimpleAggregator.
MaxFeeRate chainfee.SatPerKWeight
// FeeRateBucketSize is the default size of fee rate buckets we'll use
// when clustering inputs into buckets with similar fee rates within
// the SimpleAggregator.
//
// Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket
// size of 10 would result in the following fee rate buckets up to the
// maximum fee rate:
//
// #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte
// #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte...
FeeRateBucketSize int
}
// Compile-time constraint to ensure SimpleAggregator implements UtxoAggregator.
var _ UtxoAggregator = (*SimpleAggregator)(nil)
// NewSimpleUtxoAggregator creates a new instance of a SimpleAggregator.
func NewSimpleUtxoAggregator(estimator chainfee.Estimator,
max chainfee.SatPerKWeight) *SimpleAggregator {
return &SimpleAggregator{
FeeEstimator: estimator,
MaxFeeRate: max,
FeeRateBucketSize: DefaultFeeRateBucketSize,
}
}
// ClusterInputs creates a list of input clusters from the set of pending
// inputs known by the UtxoSweeper. It clusters inputs by
// 1) Required tx locktime
// 2) Similar fee rates.
//
// TODO(yy): remove this nolint once done refactoring.
//
//nolint:revive
func (s *SimpleAggregator) ClusterInputs(inputs pendingInputs) []inputCluster {
// We start by getting the inputs clusters by locktime. Since the
// inputs commit to the locktime, they can only be clustered together
// if the locktime is equal.
lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs)
// Cluster the remaining inputs by sweep fee rate.
feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs)
// Since the inputs that we clustered by fee rate don't commit to a
// specific locktime, we can try to merge a locktime cluster with a fee
// cluster.
return zipClusters(lockTimeClusters, feeClusters)
}
// clusterByLockTime takes the given set of pending inputs and clusters those
// with equal locktime together. Each cluster contains a sweep fee rate, which
// is determined by calculating the average fee rate of all inputs within that
// cluster. In addition to the created clusters, inputs that did not specify a
// required locktime are returned.
func (s *SimpleAggregator) clusterByLockTime(
inputs pendingInputs) ([]inputCluster, pendingInputs) {
locktimes := make(map[uint32]pendingInputs)
rem := make(pendingInputs)
// Go through all inputs and check if they require a certain locktime.
for op, input := range inputs {
lt, ok := input.RequiredLockTime()
if !ok {
rem[op] = input
continue
}
// Check if we already have inputs with this locktime.
cluster, ok := locktimes[lt]
if !ok {
cluster = make(pendingInputs)
}
// Get the fee rate based on the fee preference. If an error is
// returned, we'll skip sweeping this input for this round of
// cluster creation and retry it when we create the clusters
// from the pending inputs again.
feeRate, err := input.params.Fee.Estimate(
s.FeeEstimator, s.MaxFeeRate,
)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
log.Debugf("Adding input %v to cluster with locktime=%v, "+
"feeRate=%v", op, lt, feeRate)
// Attach the fee rate to the input.
input.lastFeeRate = feeRate
// Update the cluster about the updated input.
cluster[op] = input
locktimes[lt] = cluster
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(locktimes))
for lt, cluster := range locktimes {
lt := lt
var sweepFeeRate chainfee.SatPerKWeight
for _, input := range cluster {
sweepFeeRate += input.lastFeeRate
}
sweepFeeRate /= chainfee.SatPerKWeight(len(cluster))
inputClusters = append(inputClusters, inputCluster{
lockTime: &lt,
sweepFeeRate: sweepFeeRate,
inputs: cluster,
})
}
return inputClusters, rem
}
// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper
// and clusters those together with similar fee rates. Each cluster contains a
// sweep fee rate, which is determined by calculating the average fee rate of
// all inputs within that cluster.
func (s *SimpleAggregator) clusterBySweepFeeRate(
inputs pendingInputs) []inputCluster {
bucketInputs := make(map[int]*bucketList)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
// First, we'll group together all inputs with similar fee rates. This
// is done by determining the fee rate bucket they should belong in.
for op, input := range inputs {
feeRate, err := input.params.Fee.Estimate(
s.FeeEstimator, s.MaxFeeRate,
)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
// Only try to sweep inputs with an unconfirmed parent if the
// current sweep fee rate exceeds the parent tx fee rate. This
// assumes that such inputs are offered to the sweeper solely
// for the purpose of anchoring down the parent tx using cpfp.
parentTx := input.UnconfParent()
if parentTx != nil {
parentFeeRate :=
chainfee.SatPerKWeight(parentTx.Fee*1000) /
chainfee.SatPerKWeight(parentTx.Weight)
if parentFeeRate >= feeRate {
log.Debugf("Skipping cpfp input %v: "+
"fee_rate=%v, parent_fee_rate=%v", op,
feeRate, parentFeeRate)
continue
}
}
feeGroup := s.bucketForFeeRate(feeRate)
// Create a bucket list for this fee rate if there isn't one
// yet.
buckets, ok := bucketInputs[feeGroup]
if !ok {
buckets = &bucketList{}
bucketInputs[feeGroup] = buckets
}
// Request the bucket list to add this input. The bucket list
// will take into account exclusive group constraints.
buckets.add(input)
input.lastFeeRate = feeRate
inputFeeRates[op] = feeRate
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(bucketInputs))
for _, buckets := range bucketInputs {
for _, inputs := range buckets.buckets {
var sweepFeeRate chainfee.SatPerKWeight
for op := range inputs {
sweepFeeRate += inputFeeRates[op]
}
sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
inputClusters = append(inputClusters, inputCluster{
sweepFeeRate: sweepFeeRate,
inputs: inputs,
})
}
}
return inputClusters
}
// bucketForFeeReate determines the proper bucket for a fee rate. This is done
// in order to batch inputs with similar fee rates together.
func (s *SimpleAggregator) bucketForFeeRate(
feeRate chainfee.SatPerKWeight) int {
relayFeeRate := s.FeeEstimator.RelayFeePerKW()
// Create an isolated bucket for sweeps at the minimum fee rate. This
// is to prevent very small outputs (anchors) from becoming
// uneconomical if their fee rate would be averaged with higher fee
// rate inputs in a regular bucket.
if feeRate == relayFeeRate {
return 0
}
return 1 + int(feeRate-relayFeeRate)/s.FeeRateBucketSize
}
// mergeClusters attempts to merge cluster a and b if they are compatible. The
// new cluster will have the locktime set if a or b had a locktime set, and a
// sweep fee rate that is the maximum of a and b's. If the two clusters are not
// compatible, they will be returned unchanged.
func mergeClusters(a, b inputCluster) []inputCluster {
newCluster := inputCluster{}
switch {
// Incompatible locktimes, return the sets without merging them.
case a.lockTime != nil && b.lockTime != nil &&
*a.lockTime != *b.lockTime:
return []inputCluster{a, b}
case a.lockTime != nil:
newCluster.lockTime = a.lockTime
case b.lockTime != nil:
newCluster.lockTime = b.lockTime
}
if a.sweepFeeRate > b.sweepFeeRate {
newCluster.sweepFeeRate = a.sweepFeeRate
} else {
newCluster.sweepFeeRate = b.sweepFeeRate
}
newCluster.inputs = make(pendingInputs)
for op, in := range a.inputs {
newCluster.inputs[op] = in
}
for op, in := range b.inputs {
newCluster.inputs[op] = in
}
return []inputCluster{newCluster}
}
// zipClusters merges pairwise clusters from as and bs such that cluster a from
// as is merged with a cluster from bs that has at least the fee rate of a.
// This to ensure we don't delay confirmation by decreasing the fee rate (the
// lock time inputs are typically second level HTLC transactions, that are time
// sensitive).
func zipClusters(as, bs []inputCluster) []inputCluster {
// Sort the clusters by decreasing fee rates.
sort.Slice(as, func(i, j int) bool {
return as[i].sweepFeeRate >
as[j].sweepFeeRate
})
sort.Slice(bs, func(i, j int) bool {
return bs[i].sweepFeeRate >
bs[j].sweepFeeRate
})
var (
finalClusters []inputCluster
j int
)
// Go through each cluster in as, and merge with the next one from bs
// if it has at least the fee rate needed.
for i := range as {
a := as[i]
switch {
// If the fee rate for the next one from bs is at least a's, we
// merge.
case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate:
merged := mergeClusters(a, bs[j])
finalClusters = append(finalClusters, merged...)
// Increment j for the next round.
j++
// We did not merge, meaning all the remaining clusters from bs
// have lower fee rate. Instead we add a directly to the final
// clusters.
default:
finalClusters = append(finalClusters, a)
}
}
// Add any remaining clusters from bs.
for ; j < len(bs); j++ {
b := bs[j]
finalClusters = append(finalClusters, b)
}
return finalClusters
}

423
sweep/aggregator_test.go Normal file
View File

@ -0,0 +1,423 @@
package sweep
import (
"errors"
"reflect"
"sort"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/require"
)
//nolint:lll
var (
testInputsA = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
}
testInputsB = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
testInputsC = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
)
// TestMergeClusters check that we properly can merge clusters together,
// according to their required locktime.
func TestMergeClusters(t *testing.T) {
t.Parallel()
lockTime1 := uint32(100)
lockTime2 := uint32(200)
testCases := []struct {
name string
a inputCluster
b inputCluster
res []inputCluster
}{
{
name: "max fee rate",
a: inputCluster{
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "same locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "diff locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
},
},
}
for _, test := range testCases {
merged := mergeClusters(test.a, test.b)
if !reflect.DeepEqual(merged, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(merged))
}
}
}
// TestZipClusters tests that we can merge lists of inputs clusters correctly.
func TestZipClusters(t *testing.T) {
t.Parallel()
createCluster := func(inp pendingInputs,
f chainfee.SatPerKWeight) inputCluster {
return inputCluster{
sweepFeeRate: f,
inputs: inp,
}
}
testCases := []struct {
name string
as []inputCluster
bs []inputCluster
res []inputCluster
}{
{
name: "merge A into B",
as: []inputCluster{
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
},
},
{
name: "A can't merge with B",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsA, 7000),
createCluster(testInputsB, 5000),
},
},
{
name: "empty bs",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{},
res: []inputCluster{
createCluster(testInputsA, 7000),
},
},
{
name: "empty as",
as: []inputCluster{},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsB, 5000),
},
},
{
name: "zip 3xA into 3xB",
as: []inputCluster{
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
},
},
{
name: "zip A into 3xB",
as: []inputCluster{
createCluster(testInputsA, 2500),
},
bs: []inputCluster{
createCluster(testInputsB, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
res: []inputCluster{
createCluster(testInputsC, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
},
}
for _, test := range testCases {
zipped := zipClusters(test.as, test.bs)
if !reflect.DeepEqual(zipped, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(zipped))
}
}
}
// TestClusterByLockTime tests the method clusterByLockTime works as expected.
func TestClusterByLockTime(t *testing.T) {
t.Parallel()
// Create a mock FeePreference.
mockFeePref := &MockFeePreference{}
// Create a test param with a dummy fee preference. This is needed so
// `feeRateForPreference` won't throw an error.
param := Params{Fee: mockFeePref}
// We begin the test by creating three clusters of inputs, the first
// cluster has a locktime of 1, the second has a locktime of 2, and the
// final has no locktime.
lockTime1 := uint32(1)
lockTime2 := uint32(2)
// Create cluster one, which has a locktime of 1.
input1LockTime1 := &input.MockInput{}
input2LockTime1 := &input.MockInput{}
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
// Create cluster two, which has a locktime of 2.
input3LockTime2 := &input.MockInput{}
input4LockTime2 := &input.MockInput{}
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
// Create cluster three, which has no locktime.
input5NoLockTime := &input.MockInput{}
input6NoLockTime := &input.MockInput{}
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
// With the inner Input being mocked, we can now create the pending
// inputs.
input1 := &pendingInput{Input: input1LockTime1, params: param}
input2 := &pendingInput{Input: input2LockTime1, params: param}
input3 := &pendingInput{Input: input3LockTime2, params: param}
input4 := &pendingInput{Input: input4LockTime2, params: param}
input5 := &pendingInput{Input: input5NoLockTime, params: param}
input6 := &pendingInput{Input: input6NoLockTime, params: param}
// Create the pending inputs map, which will be passed to the method
// under test.
//
// NOTE: we don't care the actual outpoint values as long as they are
// unique.
inputs := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
// Create expected clusters so we can shorten the line length in the
// test cases below.
cluster1 := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
}
cluster2 := pendingInputs{
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
}
// cluster3 should be the remaining inputs since they don't have
// locktime.
cluster3 := pendingInputs{
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
const (
// Set the min fee rate to be 1000 sat/kw.
minFeeRate = chainfee.SatPerKWeight(1000)
// Set the max fee rate to be 10,000 sat/kw.
maxFeeRate = chainfee.SatPerKWeight(10_000)
)
// Create a test aggregator.
s := NewSimpleUtxoAggregator(nil, maxFeeRate)
testCases := []struct {
name string
// setupMocker takes a testing fee rate and makes a mocker over
// `Estimate` that always return the testing fee rate.
setupMocker func()
testFeeRate chainfee.SatPerKWeight
expectedClusters []inputCluster
expectedRemainingInputs pendingInputs
}{
{
// Test a successful case where the locktime clusters
// are created and the no-locktime cluster is returned
// as the remaining inputs.
name: "successfully create clusters",
setupMocker: func() {
// Expect the four inputs with locktime to call
// this method.
mockFeePref.On("Estimate", nil, maxFeeRate).
Return(minFeeRate+1, nil).Times(4)
},
// Use a fee rate above the min value so we don't hit
// an error when performing fee estimation.
//
// TODO(yy): we should customize the returned fee rate
// for each input to further test the averaging logic.
// Or we can split the method into two, one for
// grouping the clusters and the other for averaging
// the fee rates so it's easier to be tested.
testFeeRate: minFeeRate + 1,
expectedClusters: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: minFeeRate + 1,
inputs: cluster1,
},
{
lockTime: &lockTime2,
sweepFeeRate: minFeeRate + 1,
inputs: cluster2,
},
},
expectedRemainingInputs: cluster3,
},
{
// Test that when the input is skipped when the fee
// estimation returns an error.
name: "error from fee estimation",
setupMocker: func() {
mockFeePref.On("Estimate", nil, maxFeeRate).
Return(chainfee.SatPerKWeight(0),
errors.New("dummy")).Times(4)
},
// Use a fee rate below the min value so we hit an
// error when performing fee estimation.
testFeeRate: minFeeRate - 1,
expectedClusters: []inputCluster{},
// Remaining inputs should stay untouched.
expectedRemainingInputs: cluster3,
},
}
//nolint:paralleltest
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Apply the test fee rate so `feeRateForPreference` is
// mocked to return the specified value.
tc.setupMocker()
// Assert the mocked methods are called as expeceted.
defer mockFeePref.AssertExpectations(t)
// Call the method under test.
clusters, remainingInputs := s.clusterByLockTime(inputs)
// Sort by locktime as the order is not guaranteed.
sort.Slice(clusters, func(i, j int) bool {
return *clusters[i].lockTime <
*clusters[j].lockTime
})
// Validate the values are returned as expected.
require.Equal(t, tc.expectedClusters, clusters)
require.Equal(t, tc.expectedRemainingInputs,
remainingInputs,
)
// Assert the mocked methods are called as expeceted.
input1LockTime1.AssertExpectations(t)
input2LockTime1.AssertExpectations(t)
input3LockTime2.AssertExpectations(t)
input4LockTime2.AssertExpectations(t)
input5NoLockTime.AssertExpectations(t)
input6NoLockTime.AssertExpectations(t)
})
}
}

View File

@ -27,3 +27,18 @@ func (m *MockFeePreference) Estimate(estimator chainfee.Estimator,
return args.Get(0).(chainfee.SatPerKWeight), args.Error(1)
}
type mockUtxoAggregator struct {
mock.Mock
}
// Compile-time constraint to ensure mockUtxoAggregator implements
// UtxoAggregator.
var _ UtxoAggregator = (*mockUtxoAggregator)(nil)
// ClusterInputs takes a list of inputs and groups them into clusters.
func (m *mockUtxoAggregator) ClusterInputs(pendingInputs) []inputCluster {
args := m.Called(pendingInputs{})
return args.Get(0).([]inputCluster)
}

View File

@ -20,20 +20,6 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
const (
// DefaultFeeRateBucketSize is the default size of fee rate buckets
// we'll use when clustering inputs into buckets with similar fee rates
// within the UtxoSweeper.
//
// Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10
// would result in the following fee rate buckets up to the maximum fee
// rate:
//
// #1: min = 1 sat/vbyte, max = 10 sat/vbyte
// #2: min = 11 sat/vbyte, max = 20 sat/vbyte...
DefaultFeeRateBucketSize = 10
)
var (
// ErrRemoteSpend is returned in case an output that we try to sweep is
// confirmed in a tx of the remote party.
@ -287,17 +273,9 @@ type UtxoSweeperConfig struct {
// UtxoSweeper.
MaxFeeRate chainfee.SatPerVByte
// FeeRateBucketSize is the default size of fee rate buckets we'll use
// when clustering inputs into buckets with similar fee rates within the
// UtxoSweeper.
//
// Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket size
// of 10 would result in the following fee rate buckets up to the
// maximum fee rate:
//
// #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte
// #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte...
FeeRateBucketSize int
// Aggregator is used to group inputs into clusters based on its
// implemention-specific strategy.
Aggregator UtxoAggregator
}
// Result is the struct that is pushed through the result channel. Callers can
@ -717,280 +695,6 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error {
})
}
// bucketForFeeReate determines the proper bucket for a fee rate. This is done
// in order to batch inputs with similar fee rates together.
func (s *UtxoSweeper) bucketForFeeRate(
feeRate chainfee.SatPerKWeight) int {
// Create an isolated bucket for sweeps at the minimum fee rate. This is
// to prevent very small outputs (anchors) from becoming uneconomical if
// their fee rate would be averaged with higher fee rate inputs in a
// regular bucket.
if feeRate == s.relayFeeRate {
return 0
}
return 1 + int(feeRate-s.relayFeeRate)/s.cfg.FeeRateBucketSize
}
// createInputClusters creates a list of input clusters from the set of pending
// inputs known by the UtxoSweeper. It clusters inputs by
// 1) Required tx locktime
// 2) Similar fee rates.
func (s *UtxoSweeper) createInputClusters() []inputCluster {
inputs := s.pendingInputs
// We start by getting the inputs clusters by locktime. Since the
// inputs commit to the locktime, they can only be clustered together
// if the locktime is equal.
lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs)
// Cluster the remaining inputs by sweep fee rate.
feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs)
// Since the inputs that we clustered by fee rate don't commit to a
// specific locktime, we can try to merge a locktime cluster with a fee
// cluster.
return zipClusters(lockTimeClusters, feeClusters)
}
// clusterByLockTime takes the given set of pending inputs and clusters those
// with equal locktime together. Each cluster contains a sweep fee rate, which
// is determined by calculating the average fee rate of all inputs within that
// cluster. In addition to the created clusters, inputs that did not specify a
// required lock time are returned.
func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
pendingInputs) {
locktimes := make(map[uint32]pendingInputs)
rem := make(pendingInputs)
// Go through all inputs and check if they require a certain locktime.
for op, input := range inputs {
lt, ok := input.RequiredLockTime()
if !ok {
rem[op] = input
continue
}
// Check if we already have inputs with this locktime.
cluster, ok := locktimes[lt]
if !ok {
cluster = make(pendingInputs)
}
// Get the fee rate based on the fee preference. If an error is
// returned, we'll skip sweeping this input for this round of
// cluster creation and retry it when we create the clusters
// from the pending inputs again.
feeRate, err := input.params.Fee.Estimate(
s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(),
)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
log.Debugf("Adding input %v to cluster with locktime=%v, "+
"feeRate=%v", op, lt, feeRate)
// Attach the fee rate to the input.
input.lastFeeRate = feeRate
// Update the cluster about the updated input.
cluster[op] = input
locktimes[lt] = cluster
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(locktimes))
for lt, cluster := range locktimes {
lt := lt
var sweepFeeRate chainfee.SatPerKWeight
for _, input := range cluster {
sweepFeeRate += input.lastFeeRate
}
sweepFeeRate /= chainfee.SatPerKWeight(len(cluster))
inputClusters = append(inputClusters, inputCluster{
lockTime: &lt,
sweepFeeRate: sweepFeeRate,
inputs: cluster,
})
}
return inputClusters, rem
}
// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper
// and clusters those together with similar fee rates. Each cluster contains a
// sweep fee rate, which is determined by calculating the average fee rate of
// all inputs within that cluster.
func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster {
bucketInputs := make(map[int]*bucketList)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
// First, we'll group together all inputs with similar fee rates. This
// is done by determining the fee rate bucket they should belong in.
for op, input := range inputs {
feeRate, err := input.params.Fee.Estimate(
s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(),
)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
// Only try to sweep inputs with an unconfirmed parent if the
// current sweep fee rate exceeds the parent tx fee rate. This
// assumes that such inputs are offered to the sweeper solely
// for the purpose of anchoring down the parent tx using cpfp.
parentTx := input.UnconfParent()
if parentTx != nil {
parentFeeRate :=
chainfee.SatPerKWeight(parentTx.Fee*1000) /
chainfee.SatPerKWeight(parentTx.Weight)
if parentFeeRate >= feeRate {
log.Debugf("Skipping cpfp input %v: fee_rate=%v, "+
"parent_fee_rate=%v", op, feeRate,
parentFeeRate)
continue
}
}
feeGroup := s.bucketForFeeRate(feeRate)
// Create a bucket list for this fee rate if there isn't one
// yet.
buckets, ok := bucketInputs[feeGroup]
if !ok {
buckets = &bucketList{}
bucketInputs[feeGroup] = buckets
}
// Request the bucket list to add this input. The bucket list
// will take into account exclusive group constraints.
buckets.add(input)
input.lastFeeRate = feeRate
inputFeeRates[op] = feeRate
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(bucketInputs))
for _, buckets := range bucketInputs {
for _, inputs := range buckets.buckets {
var sweepFeeRate chainfee.SatPerKWeight
for op := range inputs {
sweepFeeRate += inputFeeRates[op]
}
sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
inputClusters = append(inputClusters, inputCluster{
sweepFeeRate: sweepFeeRate,
inputs: inputs,
})
}
}
return inputClusters
}
// zipClusters merges pairwise clusters from as and bs such that cluster a from
// as is merged with a cluster from bs that has at least the fee rate of a.
// This to ensure we don't delay confirmation by decreasing the fee rate (the
// lock time inputs are typically second level HTLC transactions, that are time
// sensitive).
func zipClusters(as, bs []inputCluster) []inputCluster {
// Sort the clusters by decreasing fee rates.
sort.Slice(as, func(i, j int) bool {
return as[i].sweepFeeRate >
as[j].sweepFeeRate
})
sort.Slice(bs, func(i, j int) bool {
return bs[i].sweepFeeRate >
bs[j].sweepFeeRate
})
var (
finalClusters []inputCluster
j int
)
// Go through each cluster in as, and merge with the next one from bs
// if it has at least the fee rate needed.
for i := range as {
a := as[i]
switch {
// If the fee rate for the next one from bs is at least a's, we
// merge.
case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate:
merged := mergeClusters(a, bs[j])
finalClusters = append(finalClusters, merged...)
// Increment j for the next round.
j++
// We did not merge, meaning all the remaining clusters from bs
// have lower fee rate. Instead we add a directly to the final
// clusters.
default:
finalClusters = append(finalClusters, a)
}
}
// Add any remaining clusters from bs.
for ; j < len(bs); j++ {
b := bs[j]
finalClusters = append(finalClusters, b)
}
return finalClusters
}
// mergeClusters attempts to merge cluster a and b if they are compatible. The
// new cluster will have the locktime set if a or b had a locktime set, and a
// sweep fee rate that is the maximum of a and b's. If the two clusters are not
// compatible, they will be returned unchanged.
func mergeClusters(a, b inputCluster) []inputCluster {
newCluster := inputCluster{}
switch {
// Incompatible locktimes, return the sets without merging them.
case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime:
return []inputCluster{a, b}
case a.lockTime != nil:
newCluster.lockTime = a.lockTime
case b.lockTime != nil:
newCluster.lockTime = b.lockTime
}
if a.sweepFeeRate > b.sweepFeeRate {
newCluster.sweepFeeRate = a.sweepFeeRate
} else {
newCluster.sweepFeeRate = b.sweepFeeRate
}
newCluster.inputs = make(pendingInputs)
for op, in := range a.inputs {
newCluster.inputs[op] = in
}
for op, in := range b.inputs {
newCluster.inputs[op] = in
}
return []inputCluster{newCluster}
}
// signalAndRemove notifies the listeners of the final result of the input
// sweep. It cancels any pending spend notification and removes the input from
// the list of pending inputs. When this function returns, the sweeper has
@ -1465,7 +1169,6 @@ func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) {
// handleNewInput processes a new input by registering spend notification and
// scheduling sweeping for it.
func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) {
outpoint := *input.input.OutPoint()
pendInput, pending := s.pendingInputs[outpoint]
if pending {
@ -1630,7 +1333,7 @@ func (s *UtxoSweeper) handleSweep() {
// Before attempting to sweep them, we'll sort them in descending fee
// rate order. We do this to ensure any inputs which have had their fee
// rate bumped are broadcast first in order enforce the RBF policy.
inputClusters := s.createInputClusters()
inputClusters := s.cfg.Aggregator.ClusterInputs(s.pendingInputs)
sort.Slice(inputClusters, func(i, j int) bool {
return inputClusters[i].sweepFeeRate >
inputClusters[j].sweepFeeRate

View File

@ -1,11 +1,8 @@
package sweep
import (
"errors"
"os"
"reflect"
"runtime/pprof"
"sort"
"testing"
"time"
@ -14,7 +11,6 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
@ -121,6 +117,10 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
estimator := newMockFeeEstimator(10000, chainfee.FeePerKwFloor)
aggregator := NewSimpleUtxoAggregator(
estimator, DefaultMaxFeeRate.FeePerKWeight(),
)
ctx := &sweeperTestContext{
notifier: notifier,
publishChan: backend.publishChan,
@ -150,7 +150,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
return 1 << uint(attempts-1)
},
MaxFeeRate: DefaultMaxFeeRate,
FeeRateBucketSize: DefaultFeeRateBucketSize,
Aggregator: aggregator,
})
ctx.sweeper.Start()
@ -384,9 +384,7 @@ func TestDust(t *testing.T) {
dustInput := createTestInput(5260, input.CommitmentTimeLock)
_, err := ctx.sweeper.SweepInput(&dustInput, defaultFeePref)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// No sweep transaction is expected now. The sweeper should recognize
// that the sweep output will not be relayed and not generate the tx. It
@ -398,18 +396,13 @@ func TestDust(t *testing.T) {
largeInput := createTestInput(100000, input.CommitmentTimeLock)
_, err = ctx.sweeper.SweepInput(&largeInput, defaultFeePref)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// The second input brings the sweep output above the dust limit. We
// expect a sweep tx now.
sweepTx := ctx.receiveTx()
if len(sweepTx.TxIn) != 2 {
t.Fatalf("Expected tx to sweep 2 inputs, but contains %v "+
"inputs instead", len(sweepTx.TxIn))
}
require.Len(t, sweepTx.TxIn, 2, "unexpected num of tx inputs")
ctx.backend.mine()
@ -1249,224 +1242,6 @@ func TestCpfp(t *testing.T) {
ctx.finish(1)
}
var (
testInputsA = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
}
testInputsB = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
testInputsC = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
)
// TestMergeClusters check that we properly can merge clusters together,
// according to their required locktime.
func TestMergeClusters(t *testing.T) {
t.Parallel()
lockTime1 := uint32(100)
lockTime2 := uint32(200)
testCases := []struct {
name string
a inputCluster
b inputCluster
res []inputCluster
}{
{
name: "max fee rate",
a: inputCluster{
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "same locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "diff locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
},
},
}
for _, test := range testCases {
merged := mergeClusters(test.a, test.b)
if !reflect.DeepEqual(merged, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(merged))
}
}
}
// TestZipClusters tests that we can merge lists of inputs clusters correctly.
func TestZipClusters(t *testing.T) {
t.Parallel()
createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster {
return inputCluster{
sweepFeeRate: f,
inputs: inp,
}
}
testCases := []struct {
name string
as []inputCluster
bs []inputCluster
res []inputCluster
}{
{
name: "merge A into B",
as: []inputCluster{
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
},
},
{
name: "A can't merge with B",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsA, 7000),
createCluster(testInputsB, 5000),
},
},
{
name: "empty bs",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{},
res: []inputCluster{
createCluster(testInputsA, 7000),
},
},
{
name: "empty as",
as: []inputCluster{},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsB, 5000),
},
},
{
name: "zip 3xA into 3xB",
as: []inputCluster{
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
},
},
{
name: "zip A into 3xB",
as: []inputCluster{
createCluster(testInputsA, 2500),
},
bs: []inputCluster{
createCluster(testInputsB, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
res: []inputCluster{
createCluster(testInputsC, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
},
}
for _, test := range testCases {
zipped := zipClusters(test.as, test.bs)
if !reflect.DeepEqual(zipped, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(zipped))
}
}
}
type testInput struct {
*input.BaseInput
@ -2142,198 +1917,6 @@ func TestSweeperShutdownHandling(t *testing.T) {
require.Error(t, err)
}
// TestClusterByLockTime tests the method clusterByLockTime works as expected.
func TestClusterByLockTime(t *testing.T) {
t.Parallel()
// Create a mock FeePreference.
mockFeePref := &MockFeePreference{}
// Create a test param with a dummy fee preference. This is needed so
// `feeRateForPreference` won't throw an error.
param := Params{Fee: mockFeePref}
// We begin the test by creating three clusters of inputs, the first
// cluster has a locktime of 1, the second has a locktime of 2, and the
// final has no locktime.
lockTime1 := uint32(1)
lockTime2 := uint32(2)
// Create cluster one, which has a locktime of 1.
input1LockTime1 := &input.MockInput{}
input2LockTime1 := &input.MockInput{}
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
// Create cluster two, which has a locktime of 2.
input3LockTime2 := &input.MockInput{}
input4LockTime2 := &input.MockInput{}
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
// Create cluster three, which has no locktime.
input5NoLockTime := &input.MockInput{}
input6NoLockTime := &input.MockInput{}
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
// With the inner Input being mocked, we can now create the pending
// inputs.
input1 := &pendingInput{Input: input1LockTime1, params: param}
input2 := &pendingInput{Input: input2LockTime1, params: param}
input3 := &pendingInput{Input: input3LockTime2, params: param}
input4 := &pendingInput{Input: input4LockTime2, params: param}
input5 := &pendingInput{Input: input5NoLockTime, params: param}
input6 := &pendingInput{Input: input6NoLockTime, params: param}
// Create the pending inputs map, which will be passed to the method
// under test.
//
// NOTE: we don't care the actual outpoint values as long as they are
// unique.
inputs := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
// Create expected clusters so we can shorten the line length in the
// test cases below.
cluster1 := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
}
cluster2 := pendingInputs{
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
}
// cluster3 should be the remaining inputs since they don't have
// locktime.
cluster3 := pendingInputs{
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
// Set the min fee rate to be 1000 sat/kw.
const minFeeRate = chainfee.SatPerKWeight(1000)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
MaxFeeRate: minFeeRate.FeePerVByte() * 10,
})
// Set the relay fee to be the minFeeRate. Any fee rate below the
// minFeeRate will cause an error to be returned.
s.relayFeeRate = minFeeRate
testCases := []struct {
name string
// setupMocker takes a testing fee rate and makes a mocker over
// `Estimate` that always return the testing fee rate.
setupMocker func()
testFeeRate chainfee.SatPerKWeight
expectedClusters []inputCluster
expectedRemainingInputs pendingInputs
}{
{
// Test a successful case where the locktime clusters
// are created and the no-locktime cluster is returned
// as the remaining inputs.
name: "successfully create clusters",
setupMocker: func() {
mockFeePref.On("Estimate",
s.cfg.FeeEstimator,
s.cfg.MaxFeeRate.FeePerKWeight(),
// Expect the four inputs with locktime to call
// this method.
).Return(minFeeRate+1, nil).Times(4)
},
// Use a fee rate above the min value so we don't hit
// an error when performing fee estimation.
//
// TODO(yy): we should customize the returned fee rate
// for each input to further test the averaging logic.
// Or we can split the method into two, one for
// grouping the clusters and the other for averaging
// the fee rates so it's easier to be tested.
testFeeRate: minFeeRate + 1,
expectedClusters: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: minFeeRate + 1,
inputs: cluster1,
},
{
lockTime: &lockTime2,
sweepFeeRate: minFeeRate + 1,
inputs: cluster2,
},
},
expectedRemainingInputs: cluster3,
},
{
// Test that when the input is skipped when the fee
// estimation returns an error.
name: "error from fee estimation",
setupMocker: func() {
mockFeePref.On("Estimate",
s.cfg.FeeEstimator,
s.cfg.MaxFeeRate.FeePerKWeight(),
).Return(chainfee.SatPerKWeight(0),
errors.New("dummy")).Times(4)
},
// Use a fee rate below the min value so we hit an
// error when performing fee estimation.
testFeeRate: minFeeRate - 1,
expectedClusters: []inputCluster{},
// Remaining inputs should stay untouched.
expectedRemainingInputs: cluster3,
},
}
//nolint:paralleltest
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Apply the test fee rate so `feeRateForPreference` is
// mocked to return the specified value.
tc.setupMocker()
// Assert the mocked methods are called as expeceted.
defer mockFeePref.AssertExpectations(t)
// Call the method under test.
clusters, remainingInputs := s.clusterByLockTime(inputs)
// Sort by locktime as the order is not guaranteed.
sort.Slice(clusters, func(i, j int) bool {
return *clusters[i].lockTime <
*clusters[j].lockTime
})
// Validate the values are returned as expected.
require.Equal(t, tc.expectedClusters, clusters)
require.Equal(t, tc.expectedRemainingInputs,
remainingInputs,
)
// Assert the mocked methods are called as expected.
input1LockTime1.AssertExpectations(t)
input2LockTime1.AssertExpectations(t)
input3LockTime2.AssertExpectations(t)
input4LockTime2.AssertExpectations(t)
input5NoLockTime.AssertExpectations(t)
input6NoLockTime.AssertExpectations(t)
})
}
}
// TestGetInputLists checks that the expected input sets are returned based on
// whether there are retried inputs or not.
func TestGetInputLists(t *testing.T) {