From 1870caf39c57f114fa61958fad726873a4467df0 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 24 Oct 2023 12:32:17 +0800 Subject: [PATCH] sweep+lnd: introduce `UtxoAggregator` to handle clustering inputs This commit refactors the grouping logic into a new interface `UtxoAggregator`, which makes it easier to write tests and opens possibility for future customized clustering strategies. The old clustering logic is kept as and moved into `SimpleAggregator`. --- server.go | 6 +- sweep/aggregator.go | 351 +++++++++++++++++++++++++++++++ sweep/aggregator_test.go | 423 +++++++++++++++++++++++++++++++++++++ sweep/mocks.go | 15 ++ sweep/sweeper.go | 305 +-------------------------- sweep/sweeper_test.go | 435 +-------------------------------------- 6 files changed, 807 insertions(+), 728 deletions(-) create mode 100644 sweep/aggregator.go create mode 100644 sweep/aggregator_test.go diff --git a/server.go b/server.go index 19d2c389c..1819cc67a 100644 --- a/server.go +++ b/server.go @@ -1063,6 +1063,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } + aggregator := sweep.NewSimpleUtxoAggregator( + cc.FeeEstimator, cfg.Sweeper.MaxFeeRate.FeePerKWeight(), + ) + s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ FeeEstimator: cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), @@ -1075,7 +1079,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, NextAttemptDeltaFunc: sweep.DefaultNextAttemptDeltaFunc, MaxFeeRate: cfg.Sweeper.MaxFeeRate, - FeeRateBucketSize: sweep.DefaultFeeRateBucketSize, + Aggregator: aggregator, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ diff --git a/sweep/aggregator.go b/sweep/aggregator.go new file mode 100644 index 000000000..6797e3573 --- /dev/null +++ b/sweep/aggregator.go @@ -0,0 +1,351 @@ +package sweep + +import ( + "sort" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +const ( + // DefaultFeeRateBucketSize is the default size of fee rate buckets + // we'll use when clustering inputs into buckets with similar fee rates + // within the SimpleAggregator. + // + // Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10 + // would result in the following fee rate buckets up to the maximum fee + // rate: + // + // #1: min = 1 sat/vbyte, max = 10 sat/vbyte + // #2: min = 11 sat/vbyte, max = 20 sat/vbyte... + DefaultFeeRateBucketSize = 10 +) + +// UtxoAggregator defines an interface that takes a list of inputs and +// aggregate them into groups. Each group is used as the inputs to create a +// sweeping transaction. +type UtxoAggregator interface { + // ClusterInputs takes a list of inputs and groups them into clusters. + ClusterInputs(pendingInputs) []inputCluster +} + +// SimpleAggregator aggregates inputs known by the Sweeper based on each +// input's locktime and feerate. +type SimpleAggregator struct { + // FeeEstimator is used when crafting sweep transactions to estimate + // the necessary fee relative to the expected size of the sweep + // transaction. + FeeEstimator chainfee.Estimator + + // MaxFeeRate is the maximum fee rate allowed within the + // SimpleAggregator. + MaxFeeRate chainfee.SatPerKWeight + + // FeeRateBucketSize is the default size of fee rate buckets we'll use + // when clustering inputs into buckets with similar fee rates within + // the SimpleAggregator. + // + // Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket + // size of 10 would result in the following fee rate buckets up to the + // maximum fee rate: + // + // #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte + // #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte... + FeeRateBucketSize int +} + +// Compile-time constraint to ensure SimpleAggregator implements UtxoAggregator. +var _ UtxoAggregator = (*SimpleAggregator)(nil) + +// NewSimpleUtxoAggregator creates a new instance of a SimpleAggregator. +func NewSimpleUtxoAggregator(estimator chainfee.Estimator, + max chainfee.SatPerKWeight) *SimpleAggregator { + + return &SimpleAggregator{ + FeeEstimator: estimator, + MaxFeeRate: max, + FeeRateBucketSize: DefaultFeeRateBucketSize, + } +} + +// ClusterInputs creates a list of input clusters from the set of pending +// inputs known by the UtxoSweeper. It clusters inputs by +// 1) Required tx locktime +// 2) Similar fee rates. +// +// TODO(yy): remove this nolint once done refactoring. +// +//nolint:revive +func (s *SimpleAggregator) ClusterInputs(inputs pendingInputs) []inputCluster { + // We start by getting the inputs clusters by locktime. Since the + // inputs commit to the locktime, they can only be clustered together + // if the locktime is equal. + lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) + + // Cluster the remaining inputs by sweep fee rate. + feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) + + // Since the inputs that we clustered by fee rate don't commit to a + // specific locktime, we can try to merge a locktime cluster with a fee + // cluster. + return zipClusters(lockTimeClusters, feeClusters) +} + +// clusterByLockTime takes the given set of pending inputs and clusters those +// with equal locktime together. Each cluster contains a sweep fee rate, which +// is determined by calculating the average fee rate of all inputs within that +// cluster. In addition to the created clusters, inputs that did not specify a +// required locktime are returned. +func (s *SimpleAggregator) clusterByLockTime( + inputs pendingInputs) ([]inputCluster, pendingInputs) { + + locktimes := make(map[uint32]pendingInputs) + rem := make(pendingInputs) + + // Go through all inputs and check if they require a certain locktime. + for op, input := range inputs { + lt, ok := input.RequiredLockTime() + if !ok { + rem[op] = input + continue + } + + // Check if we already have inputs with this locktime. + cluster, ok := locktimes[lt] + if !ok { + cluster = make(pendingInputs) + } + + // Get the fee rate based on the fee preference. If an error is + // returned, we'll skip sweeping this input for this round of + // cluster creation and retry it when we create the clusters + // from the pending inputs again. + feeRate, err := input.params.Fee.Estimate( + s.FeeEstimator, s.MaxFeeRate, + ) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + log.Debugf("Adding input %v to cluster with locktime=%v, "+ + "feeRate=%v", op, lt, feeRate) + + // Attach the fee rate to the input. + input.lastFeeRate = feeRate + + // Update the cluster about the updated input. + cluster[op] = input + locktimes[lt] = cluster + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(locktimes)) + for lt, cluster := range locktimes { + lt := lt + + var sweepFeeRate chainfee.SatPerKWeight + for _, input := range cluster { + sweepFeeRate += input.lastFeeRate + } + + sweepFeeRate /= chainfee.SatPerKWeight(len(cluster)) + inputClusters = append(inputClusters, inputCluster{ + lockTime: <, + sweepFeeRate: sweepFeeRate, + inputs: cluster, + }) + } + + return inputClusters, rem +} + +// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper +// and clusters those together with similar fee rates. Each cluster contains a +// sweep fee rate, which is determined by calculating the average fee rate of +// all inputs within that cluster. +func (s *SimpleAggregator) clusterBySweepFeeRate( + inputs pendingInputs) []inputCluster { + + bucketInputs := make(map[int]*bucketList) + inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) + + // First, we'll group together all inputs with similar fee rates. This + // is done by determining the fee rate bucket they should belong in. + for op, input := range inputs { + feeRate, err := input.params.Fee.Estimate( + s.FeeEstimator, s.MaxFeeRate, + ) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + // Only try to sweep inputs with an unconfirmed parent if the + // current sweep fee rate exceeds the parent tx fee rate. This + // assumes that such inputs are offered to the sweeper solely + // for the purpose of anchoring down the parent tx using cpfp. + parentTx := input.UnconfParent() + if parentTx != nil { + parentFeeRate := + chainfee.SatPerKWeight(parentTx.Fee*1000) / + chainfee.SatPerKWeight(parentTx.Weight) + + if parentFeeRate >= feeRate { + log.Debugf("Skipping cpfp input %v: "+ + "fee_rate=%v, parent_fee_rate=%v", op, + feeRate, parentFeeRate) + + continue + } + } + + feeGroup := s.bucketForFeeRate(feeRate) + + // Create a bucket list for this fee rate if there isn't one + // yet. + buckets, ok := bucketInputs[feeGroup] + if !ok { + buckets = &bucketList{} + bucketInputs[feeGroup] = buckets + } + + // Request the bucket list to add this input. The bucket list + // will take into account exclusive group constraints. + buckets.add(input) + + input.lastFeeRate = feeRate + inputFeeRates[op] = feeRate + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(bucketInputs)) + for _, buckets := range bucketInputs { + for _, inputs := range buckets.buckets { + var sweepFeeRate chainfee.SatPerKWeight + for op := range inputs { + sweepFeeRate += inputFeeRates[op] + } + sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) + inputClusters = append(inputClusters, inputCluster{ + sweepFeeRate: sweepFeeRate, + inputs: inputs, + }) + } + } + + return inputClusters +} + +// bucketForFeeReate determines the proper bucket for a fee rate. This is done +// in order to batch inputs with similar fee rates together. +func (s *SimpleAggregator) bucketForFeeRate( + feeRate chainfee.SatPerKWeight) int { + + relayFeeRate := s.FeeEstimator.RelayFeePerKW() + + // Create an isolated bucket for sweeps at the minimum fee rate. This + // is to prevent very small outputs (anchors) from becoming + // uneconomical if their fee rate would be averaged with higher fee + // rate inputs in a regular bucket. + if feeRate == relayFeeRate { + return 0 + } + + return 1 + int(feeRate-relayFeeRate)/s.FeeRateBucketSize +} + +// mergeClusters attempts to merge cluster a and b if they are compatible. The +// new cluster will have the locktime set if a or b had a locktime set, and a +// sweep fee rate that is the maximum of a and b's. If the two clusters are not +// compatible, they will be returned unchanged. +func mergeClusters(a, b inputCluster) []inputCluster { + newCluster := inputCluster{} + + switch { + // Incompatible locktimes, return the sets without merging them. + case a.lockTime != nil && b.lockTime != nil && + *a.lockTime != *b.lockTime: + + return []inputCluster{a, b} + + case a.lockTime != nil: + newCluster.lockTime = a.lockTime + + case b.lockTime != nil: + newCluster.lockTime = b.lockTime + } + + if a.sweepFeeRate > b.sweepFeeRate { + newCluster.sweepFeeRate = a.sweepFeeRate + } else { + newCluster.sweepFeeRate = b.sweepFeeRate + } + + newCluster.inputs = make(pendingInputs) + + for op, in := range a.inputs { + newCluster.inputs[op] = in + } + + for op, in := range b.inputs { + newCluster.inputs[op] = in + } + + return []inputCluster{newCluster} +} + +// zipClusters merges pairwise clusters from as and bs such that cluster a from +// as is merged with a cluster from bs that has at least the fee rate of a. +// This to ensure we don't delay confirmation by decreasing the fee rate (the +// lock time inputs are typically second level HTLC transactions, that are time +// sensitive). +func zipClusters(as, bs []inputCluster) []inputCluster { + // Sort the clusters by decreasing fee rates. + sort.Slice(as, func(i, j int) bool { + return as[i].sweepFeeRate > + as[j].sweepFeeRate + }) + sort.Slice(bs, func(i, j int) bool { + return bs[i].sweepFeeRate > + bs[j].sweepFeeRate + }) + + var ( + finalClusters []inputCluster + j int + ) + + // Go through each cluster in as, and merge with the next one from bs + // if it has at least the fee rate needed. + for i := range as { + a := as[i] + + switch { + // If the fee rate for the next one from bs is at least a's, we + // merge. + case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: + merged := mergeClusters(a, bs[j]) + finalClusters = append(finalClusters, merged...) + + // Increment j for the next round. + j++ + + // We did not merge, meaning all the remaining clusters from bs + // have lower fee rate. Instead we add a directly to the final + // clusters. + default: + finalClusters = append(finalClusters, a) + } + } + + // Add any remaining clusters from bs. + for ; j < len(bs); j++ { + b := bs[j] + finalClusters = append(finalClusters, b) + } + + return finalClusters +} diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go new file mode 100644 index 000000000..f3bf2cd28 --- /dev/null +++ b/sweep/aggregator_test.go @@ -0,0 +1,423 @@ +package sweep + +import ( + "errors" + "reflect" + "sort" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +//nolint:lll +var ( + testInputsA = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + } + + testInputsB = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } + + testInputsC = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } +) + +// TestMergeClusters check that we properly can merge clusters together, +// according to their required locktime. +func TestMergeClusters(t *testing.T) { + t.Parallel() + + lockTime1 := uint32(100) + lockTime2 := uint32(200) + + testCases := []struct { + name string + a inputCluster + b inputCluster + res []inputCluster + }{ + { + name: "max fee rate", + a: inputCluster{ + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "same locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "diff locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + { + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + }, + }, + } + + for _, test := range testCases { + merged := mergeClusters(test.a, test.b) + if !reflect.DeepEqual(merged, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(merged)) + } + } +} + +// TestZipClusters tests that we can merge lists of inputs clusters correctly. +func TestZipClusters(t *testing.T) { + t.Parallel() + + createCluster := func(inp pendingInputs, + f chainfee.SatPerKWeight) inputCluster { + + return inputCluster{ + sweepFeeRate: f, + inputs: inp, + } + } + + testCases := []struct { + name string + as []inputCluster + bs []inputCluster + res []inputCluster + }{ + { + name: "merge A into B", + as: []inputCluster{ + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + }, + }, + { + name: "A can't merge with B", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsA, 7000), + createCluster(testInputsB, 5000), + }, + }, + { + name: "empty bs", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{}, + res: []inputCluster{ + createCluster(testInputsA, 7000), + }, + }, + { + name: "empty as", + as: []inputCluster{}, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsB, 5000), + }, + }, + + { + name: "zip 3xA into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + }, + }, + { + name: "zip A into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 2500), + }, + bs: []inputCluster{ + createCluster(testInputsB, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + res: []inputCluster{ + createCluster(testInputsC, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + }, + } + + for _, test := range testCases { + zipped := zipClusters(test.as, test.bs) + if !reflect.DeepEqual(zipped, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(zipped)) + } + } +} + +// TestClusterByLockTime tests the method clusterByLockTime works as expected. +func TestClusterByLockTime(t *testing.T) { + t.Parallel() + + // Create a mock FeePreference. + mockFeePref := &MockFeePreference{} + + // Create a test param with a dummy fee preference. This is needed so + // `feeRateForPreference` won't throw an error. + param := Params{Fee: mockFeePref} + + // We begin the test by creating three clusters of inputs, the first + // cluster has a locktime of 1, the second has a locktime of 2, and the + // final has no locktime. + lockTime1 := uint32(1) + lockTime2 := uint32(2) + + // Create cluster one, which has a locktime of 1. + input1LockTime1 := &input.MockInput{} + input2LockTime1 := &input.MockInput{} + input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) + input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) + + // Create cluster two, which has a locktime of 2. + input3LockTime2 := &input.MockInput{} + input4LockTime2 := &input.MockInput{} + input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) + input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) + + // Create cluster three, which has no locktime. + input5NoLockTime := &input.MockInput{} + input6NoLockTime := &input.MockInput{} + input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + + // With the inner Input being mocked, we can now create the pending + // inputs. + input1 := &pendingInput{Input: input1LockTime1, params: param} + input2 := &pendingInput{Input: input2LockTime1, params: param} + input3 := &pendingInput{Input: input3LockTime2, params: param} + input4 := &pendingInput{Input: input4LockTime2, params: param} + input5 := &pendingInput{Input: input5NoLockTime, params: param} + input6 := &pendingInput{Input: input6NoLockTime, params: param} + + // Create the pending inputs map, which will be passed to the method + // under test. + // + // NOTE: we don't care the actual outpoint values as long as they are + // unique. + inputs := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + // Create expected clusters so we can shorten the line length in the + // test cases below. + cluster1 := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + } + cluster2 := pendingInputs{ + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + } + + // cluster3 should be the remaining inputs since they don't have + // locktime. + cluster3 := pendingInputs{ + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + const ( + // Set the min fee rate to be 1000 sat/kw. + minFeeRate = chainfee.SatPerKWeight(1000) + + // Set the max fee rate to be 10,000 sat/kw. + maxFeeRate = chainfee.SatPerKWeight(10_000) + ) + + // Create a test aggregator. + s := NewSimpleUtxoAggregator(nil, maxFeeRate) + + testCases := []struct { + name string + // setupMocker takes a testing fee rate and makes a mocker over + // `Estimate` that always return the testing fee rate. + setupMocker func() + testFeeRate chainfee.SatPerKWeight + expectedClusters []inputCluster + expectedRemainingInputs pendingInputs + }{ + { + // Test a successful case where the locktime clusters + // are created and the no-locktime cluster is returned + // as the remaining inputs. + name: "successfully create clusters", + setupMocker: func() { + // Expect the four inputs with locktime to call + // this method. + mockFeePref.On("Estimate", nil, maxFeeRate). + Return(minFeeRate+1, nil).Times(4) + }, + // Use a fee rate above the min value so we don't hit + // an error when performing fee estimation. + // + // TODO(yy): we should customize the returned fee rate + // for each input to further test the averaging logic. + // Or we can split the method into two, one for + // grouping the clusters and the other for averaging + // the fee rates so it's easier to be tested. + testFeeRate: minFeeRate + 1, + expectedClusters: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: minFeeRate + 1, + inputs: cluster1, + }, + { + lockTime: &lockTime2, + sweepFeeRate: minFeeRate + 1, + inputs: cluster2, + }, + }, + expectedRemainingInputs: cluster3, + }, + { + // Test that when the input is skipped when the fee + // estimation returns an error. + name: "error from fee estimation", + setupMocker: func() { + mockFeePref.On("Estimate", nil, maxFeeRate). + Return(chainfee.SatPerKWeight(0), + errors.New("dummy")).Times(4) + }, + + // Use a fee rate below the min value so we hit an + // error when performing fee estimation. + testFeeRate: minFeeRate - 1, + expectedClusters: []inputCluster{}, + // Remaining inputs should stay untouched. + expectedRemainingInputs: cluster3, + }, + } + + //nolint:paralleltest + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Apply the test fee rate so `feeRateForPreference` is + // mocked to return the specified value. + tc.setupMocker() + + // Assert the mocked methods are called as expeceted. + defer mockFeePref.AssertExpectations(t) + + // Call the method under test. + clusters, remainingInputs := s.clusterByLockTime(inputs) + + // Sort by locktime as the order is not guaranteed. + sort.Slice(clusters, func(i, j int) bool { + return *clusters[i].lockTime < + *clusters[j].lockTime + }) + + // Validate the values are returned as expected. + require.Equal(t, tc.expectedClusters, clusters) + require.Equal(t, tc.expectedRemainingInputs, + remainingInputs, + ) + + // Assert the mocked methods are called as expeceted. + input1LockTime1.AssertExpectations(t) + input2LockTime1.AssertExpectations(t) + input3LockTime2.AssertExpectations(t) + input4LockTime2.AssertExpectations(t) + input5NoLockTime.AssertExpectations(t) + input6NoLockTime.AssertExpectations(t) + }) + } +} diff --git a/sweep/mocks.go b/sweep/mocks.go index 516e35837..3c8882308 100644 --- a/sweep/mocks.go +++ b/sweep/mocks.go @@ -27,3 +27,18 @@ func (m *MockFeePreference) Estimate(estimator chainfee.Estimator, return args.Get(0).(chainfee.SatPerKWeight), args.Error(1) } + +type mockUtxoAggregator struct { + mock.Mock +} + +// Compile-time constraint to ensure mockUtxoAggregator implements +// UtxoAggregator. +var _ UtxoAggregator = (*mockUtxoAggregator)(nil) + +// ClusterInputs takes a list of inputs and groups them into clusters. +func (m *mockUtxoAggregator) ClusterInputs(pendingInputs) []inputCluster { + args := m.Called(pendingInputs{}) + + return args.Get(0).([]inputCluster) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 186b2684c..69f3403e2 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -20,20 +20,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) -const ( - // DefaultFeeRateBucketSize is the default size of fee rate buckets - // we'll use when clustering inputs into buckets with similar fee rates - // within the UtxoSweeper. - // - // Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10 - // would result in the following fee rate buckets up to the maximum fee - // rate: - // - // #1: min = 1 sat/vbyte, max = 10 sat/vbyte - // #2: min = 11 sat/vbyte, max = 20 sat/vbyte... - DefaultFeeRateBucketSize = 10 -) - var ( // ErrRemoteSpend is returned in case an output that we try to sweep is // confirmed in a tx of the remote party. @@ -287,17 +273,9 @@ type UtxoSweeperConfig struct { // UtxoSweeper. MaxFeeRate chainfee.SatPerVByte - // FeeRateBucketSize is the default size of fee rate buckets we'll use - // when clustering inputs into buckets with similar fee rates within the - // UtxoSweeper. - // - // Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket size - // of 10 would result in the following fee rate buckets up to the - // maximum fee rate: - // - // #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte - // #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte... - FeeRateBucketSize int + // Aggregator is used to group inputs into clusters based on its + // implemention-specific strategy. + Aggregator UtxoAggregator } // Result is the struct that is pushed through the result channel. Callers can @@ -717,280 +695,6 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error { }) } -// bucketForFeeReate determines the proper bucket for a fee rate. This is done -// in order to batch inputs with similar fee rates together. -func (s *UtxoSweeper) bucketForFeeRate( - feeRate chainfee.SatPerKWeight) int { - - // Create an isolated bucket for sweeps at the minimum fee rate. This is - // to prevent very small outputs (anchors) from becoming uneconomical if - // their fee rate would be averaged with higher fee rate inputs in a - // regular bucket. - if feeRate == s.relayFeeRate { - return 0 - } - - return 1 + int(feeRate-s.relayFeeRate)/s.cfg.FeeRateBucketSize -} - -// createInputClusters creates a list of input clusters from the set of pending -// inputs known by the UtxoSweeper. It clusters inputs by -// 1) Required tx locktime -// 2) Similar fee rates. -func (s *UtxoSweeper) createInputClusters() []inputCluster { - inputs := s.pendingInputs - - // We start by getting the inputs clusters by locktime. Since the - // inputs commit to the locktime, they can only be clustered together - // if the locktime is equal. - lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) - - // Cluster the remaining inputs by sweep fee rate. - feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) - - // Since the inputs that we clustered by fee rate don't commit to a - // specific locktime, we can try to merge a locktime cluster with a fee - // cluster. - return zipClusters(lockTimeClusters, feeClusters) -} - -// clusterByLockTime takes the given set of pending inputs and clusters those -// with equal locktime together. Each cluster contains a sweep fee rate, which -// is determined by calculating the average fee rate of all inputs within that -// cluster. In addition to the created clusters, inputs that did not specify a -// required lock time are returned. -func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster, - pendingInputs) { - - locktimes := make(map[uint32]pendingInputs) - rem := make(pendingInputs) - - // Go through all inputs and check if they require a certain locktime. - for op, input := range inputs { - lt, ok := input.RequiredLockTime() - if !ok { - rem[op] = input - continue - } - - // Check if we already have inputs with this locktime. - cluster, ok := locktimes[lt] - if !ok { - cluster = make(pendingInputs) - } - - // Get the fee rate based on the fee preference. If an error is - // returned, we'll skip sweeping this input for this round of - // cluster creation and retry it when we create the clusters - // from the pending inputs again. - feeRate, err := input.params.Fee.Estimate( - s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(), - ) - if err != nil { - log.Warnf("Skipping input %v: %v", op, err) - continue - } - - log.Debugf("Adding input %v to cluster with locktime=%v, "+ - "feeRate=%v", op, lt, feeRate) - - // Attach the fee rate to the input. - input.lastFeeRate = feeRate - - // Update the cluster about the updated input. - cluster[op] = input - locktimes[lt] = cluster - } - - // We'll then determine the sweep fee rate for each set of inputs by - // calculating the average fee rate of the inputs within each set. - inputClusters := make([]inputCluster, 0, len(locktimes)) - for lt, cluster := range locktimes { - lt := lt - - var sweepFeeRate chainfee.SatPerKWeight - for _, input := range cluster { - sweepFeeRate += input.lastFeeRate - } - - sweepFeeRate /= chainfee.SatPerKWeight(len(cluster)) - inputClusters = append(inputClusters, inputCluster{ - lockTime: <, - sweepFeeRate: sweepFeeRate, - inputs: cluster, - }) - } - - return inputClusters, rem -} - -// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper -// and clusters those together with similar fee rates. Each cluster contains a -// sweep fee rate, which is determined by calculating the average fee rate of -// all inputs within that cluster. -func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster { - bucketInputs := make(map[int]*bucketList) - inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) - - // First, we'll group together all inputs with similar fee rates. This - // is done by determining the fee rate bucket they should belong in. - for op, input := range inputs { - feeRate, err := input.params.Fee.Estimate( - s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(), - ) - if err != nil { - log.Warnf("Skipping input %v: %v", op, err) - continue - } - - // Only try to sweep inputs with an unconfirmed parent if the - // current sweep fee rate exceeds the parent tx fee rate. This - // assumes that such inputs are offered to the sweeper solely - // for the purpose of anchoring down the parent tx using cpfp. - parentTx := input.UnconfParent() - if parentTx != nil { - parentFeeRate := - chainfee.SatPerKWeight(parentTx.Fee*1000) / - chainfee.SatPerKWeight(parentTx.Weight) - - if parentFeeRate >= feeRate { - log.Debugf("Skipping cpfp input %v: fee_rate=%v, "+ - "parent_fee_rate=%v", op, feeRate, - parentFeeRate) - - continue - } - } - - feeGroup := s.bucketForFeeRate(feeRate) - - // Create a bucket list for this fee rate if there isn't one - // yet. - buckets, ok := bucketInputs[feeGroup] - if !ok { - buckets = &bucketList{} - bucketInputs[feeGroup] = buckets - } - - // Request the bucket list to add this input. The bucket list - // will take into account exclusive group constraints. - buckets.add(input) - - input.lastFeeRate = feeRate - inputFeeRates[op] = feeRate - } - - // We'll then determine the sweep fee rate for each set of inputs by - // calculating the average fee rate of the inputs within each set. - inputClusters := make([]inputCluster, 0, len(bucketInputs)) - for _, buckets := range bucketInputs { - for _, inputs := range buckets.buckets { - var sweepFeeRate chainfee.SatPerKWeight - for op := range inputs { - sweepFeeRate += inputFeeRates[op] - } - sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) - inputClusters = append(inputClusters, inputCluster{ - sweepFeeRate: sweepFeeRate, - inputs: inputs, - }) - } - } - - return inputClusters -} - -// zipClusters merges pairwise clusters from as and bs such that cluster a from -// as is merged with a cluster from bs that has at least the fee rate of a. -// This to ensure we don't delay confirmation by decreasing the fee rate (the -// lock time inputs are typically second level HTLC transactions, that are time -// sensitive). -func zipClusters(as, bs []inputCluster) []inputCluster { - // Sort the clusters by decreasing fee rates. - sort.Slice(as, func(i, j int) bool { - return as[i].sweepFeeRate > - as[j].sweepFeeRate - }) - sort.Slice(bs, func(i, j int) bool { - return bs[i].sweepFeeRate > - bs[j].sweepFeeRate - }) - - var ( - finalClusters []inputCluster - j int - ) - - // Go through each cluster in as, and merge with the next one from bs - // if it has at least the fee rate needed. - for i := range as { - a := as[i] - - switch { - // If the fee rate for the next one from bs is at least a's, we - // merge. - case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: - merged := mergeClusters(a, bs[j]) - finalClusters = append(finalClusters, merged...) - - // Increment j for the next round. - j++ - - // We did not merge, meaning all the remaining clusters from bs - // have lower fee rate. Instead we add a directly to the final - // clusters. - default: - finalClusters = append(finalClusters, a) - } - } - - // Add any remaining clusters from bs. - for ; j < len(bs); j++ { - b := bs[j] - finalClusters = append(finalClusters, b) - } - - return finalClusters -} - -// mergeClusters attempts to merge cluster a and b if they are compatible. The -// new cluster will have the locktime set if a or b had a locktime set, and a -// sweep fee rate that is the maximum of a and b's. If the two clusters are not -// compatible, they will be returned unchanged. -func mergeClusters(a, b inputCluster) []inputCluster { - newCluster := inputCluster{} - - switch { - // Incompatible locktimes, return the sets without merging them. - case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime: - return []inputCluster{a, b} - - case a.lockTime != nil: - newCluster.lockTime = a.lockTime - - case b.lockTime != nil: - newCluster.lockTime = b.lockTime - } - - if a.sweepFeeRate > b.sweepFeeRate { - newCluster.sweepFeeRate = a.sweepFeeRate - } else { - newCluster.sweepFeeRate = b.sweepFeeRate - } - - newCluster.inputs = make(pendingInputs) - - for op, in := range a.inputs { - newCluster.inputs[op] = in - } - - for op, in := range b.inputs { - newCluster.inputs[op] = in - } - - return []inputCluster{newCluster} -} - // signalAndRemove notifies the listeners of the final result of the input // sweep. It cancels any pending spend notification and removes the input from // the list of pending inputs. When this function returns, the sweeper has @@ -1465,7 +1169,6 @@ func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) { // handleNewInput processes a new input by registering spend notification and // scheduling sweeping for it. func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) { - outpoint := *input.input.OutPoint() pendInput, pending := s.pendingInputs[outpoint] if pending { @@ -1630,7 +1333,7 @@ func (s *UtxoSweeper) handleSweep() { // Before attempting to sweep them, we'll sort them in descending fee // rate order. We do this to ensure any inputs which have had their fee // rate bumped are broadcast first in order enforce the RBF policy. - inputClusters := s.createInputClusters() + inputClusters := s.cfg.Aggregator.ClusterInputs(s.pendingInputs) sort.Slice(inputClusters, func(i, j int) bool { return inputClusters[i].sweepFeeRate > inputClusters[j].sweepFeeRate diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index e14e776a9..c12b04aae 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,11 +1,8 @@ package sweep import ( - "errors" "os" - "reflect" "runtime/pprof" - "sort" "testing" "time" @@ -14,7 +11,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" @@ -121,6 +117,10 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { estimator := newMockFeeEstimator(10000, chainfee.FeePerKwFloor) + aggregator := NewSimpleUtxoAggregator( + estimator, DefaultMaxFeeRate.FeePerKWeight(), + ) + ctx := &sweeperTestContext{ notifier: notifier, publishChan: backend.publishChan, @@ -149,8 +149,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { // Use delta func without random factor. return 1 << uint(attempts-1) }, - MaxFeeRate: DefaultMaxFeeRate, - FeeRateBucketSize: DefaultFeeRateBucketSize, + MaxFeeRate: DefaultMaxFeeRate, + Aggregator: aggregator, }) ctx.sweeper.Start() @@ -384,9 +384,7 @@ func TestDust(t *testing.T) { dustInput := createTestInput(5260, input.CommitmentTimeLock) _, err := ctx.sweeper.SweepInput(&dustInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // No sweep transaction is expected now. The sweeper should recognize // that the sweep output will not be relayed and not generate the tx. It @@ -398,18 +396,13 @@ func TestDust(t *testing.T) { largeInput := createTestInput(100000, input.CommitmentTimeLock) _, err = ctx.sweeper.SweepInput(&largeInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // The second input brings the sweep output above the dust limit. We // expect a sweep tx now. sweepTx := ctx.receiveTx() - if len(sweepTx.TxIn) != 2 { - t.Fatalf("Expected tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx.TxIn)) - } + require.Len(t, sweepTx.TxIn, 2, "unexpected num of tx inputs") ctx.backend.mine() @@ -1249,224 +1242,6 @@ func TestCpfp(t *testing.T) { ctx.finish(1) } -var ( - testInputsA = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, - } - - testInputsB = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, - } - - testInputsC = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, - } -) - -// TestMergeClusters check that we properly can merge clusters together, -// according to their required locktime. -func TestMergeClusters(t *testing.T) { - t.Parallel() - - lockTime1 := uint32(100) - lockTime2 := uint32(200) - - testCases := []struct { - name string - a inputCluster - b inputCluster - res []inputCluster - }{ - { - name: "max fee rate", - a: inputCluster{ - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - sweepFeeRate: 7000, - inputs: testInputsC, - }, - }, - }, - { - name: "same locktime", - a: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: 7000, - inputs: testInputsC, - }, - }, - }, - { - name: "diff locktime", - a: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - lockTime: &lockTime2, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - { - lockTime: &lockTime2, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - }, - }, - } - - for _, test := range testCases { - merged := mergeClusters(test.a, test.b) - if !reflect.DeepEqual(merged, test.res) { - t.Fatalf("[%s] unexpected result: %v", - test.name, spew.Sdump(merged)) - } - } -} - -// TestZipClusters tests that we can merge lists of inputs clusters correctly. -func TestZipClusters(t *testing.T) { - t.Parallel() - - createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster { - return inputCluster{ - sweepFeeRate: f, - inputs: inp, - } - } - - testCases := []struct { - name string - as []inputCluster - bs []inputCluster - res []inputCluster - }{ - { - name: "merge A into B", - as: []inputCluster{ - createCluster(testInputsA, 5000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 7000), - }, - res: []inputCluster{ - createCluster(testInputsC, 7000), - }, - }, - { - name: "A can't merge with B", - as: []inputCluster{ - createCluster(testInputsA, 7000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 5000), - }, - res: []inputCluster{ - createCluster(testInputsA, 7000), - createCluster(testInputsB, 5000), - }, - }, - { - name: "empty bs", - as: []inputCluster{ - createCluster(testInputsA, 7000), - }, - bs: []inputCluster{}, - res: []inputCluster{ - createCluster(testInputsA, 7000), - }, - }, - { - name: "empty as", - as: []inputCluster{}, - bs: []inputCluster{ - createCluster(testInputsB, 5000), - }, - res: []inputCluster{ - createCluster(testInputsB, 5000), - }, - }, - - { - name: "zip 3xA into 3xB", - as: []inputCluster{ - createCluster(testInputsA, 5000), - createCluster(testInputsA, 5000), - createCluster(testInputsA, 5000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 7000), - createCluster(testInputsB, 7000), - createCluster(testInputsB, 7000), - }, - res: []inputCluster{ - createCluster(testInputsC, 7000), - createCluster(testInputsC, 7000), - createCluster(testInputsC, 7000), - }, - }, - { - name: "zip A into 3xB", - as: []inputCluster{ - createCluster(testInputsA, 2500), - }, - bs: []inputCluster{ - createCluster(testInputsB, 3000), - createCluster(testInputsB, 2000), - createCluster(testInputsB, 1000), - }, - res: []inputCluster{ - createCluster(testInputsC, 3000), - createCluster(testInputsB, 2000), - createCluster(testInputsB, 1000), - }, - }, - } - - for _, test := range testCases { - zipped := zipClusters(test.as, test.bs) - if !reflect.DeepEqual(zipped, test.res) { - t.Fatalf("[%s] unexpected result: %v", - test.name, spew.Sdump(zipped)) - } - } -} - type testInput struct { *input.BaseInput @@ -2142,198 +1917,6 @@ func TestSweeperShutdownHandling(t *testing.T) { require.Error(t, err) } -// TestClusterByLockTime tests the method clusterByLockTime works as expected. -func TestClusterByLockTime(t *testing.T) { - t.Parallel() - - // Create a mock FeePreference. - mockFeePref := &MockFeePreference{} - - // Create a test param with a dummy fee preference. This is needed so - // `feeRateForPreference` won't throw an error. - param := Params{Fee: mockFeePref} - - // We begin the test by creating three clusters of inputs, the first - // cluster has a locktime of 1, the second has a locktime of 2, and the - // final has no locktime. - lockTime1 := uint32(1) - lockTime2 := uint32(2) - - // Create cluster one, which has a locktime of 1. - input1LockTime1 := &input.MockInput{} - input2LockTime1 := &input.MockInput{} - input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) - input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) - - // Create cluster two, which has a locktime of 2. - input3LockTime2 := &input.MockInput{} - input4LockTime2 := &input.MockInput{} - input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) - input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) - - // Create cluster three, which has no locktime. - input5NoLockTime := &input.MockInput{} - input6NoLockTime := &input.MockInput{} - input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) - input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) - - // With the inner Input being mocked, we can now create the pending - // inputs. - input1 := &pendingInput{Input: input1LockTime1, params: param} - input2 := &pendingInput{Input: input2LockTime1, params: param} - input3 := &pendingInput{Input: input3LockTime2, params: param} - input4 := &pendingInput{Input: input4LockTime2, params: param} - input5 := &pendingInput{Input: input5NoLockTime, params: param} - input6 := &pendingInput{Input: input6NoLockTime, params: param} - - // Create the pending inputs map, which will be passed to the method - // under test. - // - // NOTE: we don't care the actual outpoint values as long as they are - // unique. - inputs := pendingInputs{ - wire.OutPoint{Index: 1}: input1, - wire.OutPoint{Index: 2}: input2, - wire.OutPoint{Index: 3}: input3, - wire.OutPoint{Index: 4}: input4, - wire.OutPoint{Index: 5}: input5, - wire.OutPoint{Index: 6}: input6, - } - - // Create expected clusters so we can shorten the line length in the - // test cases below. - cluster1 := pendingInputs{ - wire.OutPoint{Index: 1}: input1, - wire.OutPoint{Index: 2}: input2, - } - cluster2 := pendingInputs{ - wire.OutPoint{Index: 3}: input3, - wire.OutPoint{Index: 4}: input4, - } - - // cluster3 should be the remaining inputs since they don't have - // locktime. - cluster3 := pendingInputs{ - wire.OutPoint{Index: 5}: input5, - wire.OutPoint{Index: 6}: input6, - } - - // Set the min fee rate to be 1000 sat/kw. - const minFeeRate = chainfee.SatPerKWeight(1000) - - // Create a test sweeper. - s := New(&UtxoSweeperConfig{ - MaxFeeRate: minFeeRate.FeePerVByte() * 10, - }) - - // Set the relay fee to be the minFeeRate. Any fee rate below the - // minFeeRate will cause an error to be returned. - s.relayFeeRate = minFeeRate - - testCases := []struct { - name string - // setupMocker takes a testing fee rate and makes a mocker over - // `Estimate` that always return the testing fee rate. - setupMocker func() - testFeeRate chainfee.SatPerKWeight - expectedClusters []inputCluster - expectedRemainingInputs pendingInputs - }{ - { - // Test a successful case where the locktime clusters - // are created and the no-locktime cluster is returned - // as the remaining inputs. - name: "successfully create clusters", - setupMocker: func() { - mockFeePref.On("Estimate", - s.cfg.FeeEstimator, - s.cfg.MaxFeeRate.FeePerKWeight(), - // Expect the four inputs with locktime to call - // this method. - ).Return(minFeeRate+1, nil).Times(4) - }, - // Use a fee rate above the min value so we don't hit - // an error when performing fee estimation. - // - // TODO(yy): we should customize the returned fee rate - // for each input to further test the averaging logic. - // Or we can split the method into two, one for - // grouping the clusters and the other for averaging - // the fee rates so it's easier to be tested. - testFeeRate: minFeeRate + 1, - expectedClusters: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: minFeeRate + 1, - inputs: cluster1, - }, - { - lockTime: &lockTime2, - sweepFeeRate: minFeeRate + 1, - inputs: cluster2, - }, - }, - expectedRemainingInputs: cluster3, - }, - { - // Test that when the input is skipped when the fee - // estimation returns an error. - name: "error from fee estimation", - setupMocker: func() { - mockFeePref.On("Estimate", - s.cfg.FeeEstimator, - s.cfg.MaxFeeRate.FeePerKWeight(), - ).Return(chainfee.SatPerKWeight(0), - errors.New("dummy")).Times(4) - }, - - // Use a fee rate below the min value so we hit an - // error when performing fee estimation. - testFeeRate: minFeeRate - 1, - expectedClusters: []inputCluster{}, - // Remaining inputs should stay untouched. - expectedRemainingInputs: cluster3, - }, - } - - //nolint:paralleltest - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - // Apply the test fee rate so `feeRateForPreference` is - // mocked to return the specified value. - tc.setupMocker() - - // Assert the mocked methods are called as expeceted. - defer mockFeePref.AssertExpectations(t) - - // Call the method under test. - clusters, remainingInputs := s.clusterByLockTime(inputs) - - // Sort by locktime as the order is not guaranteed. - sort.Slice(clusters, func(i, j int) bool { - return *clusters[i].lockTime < - *clusters[j].lockTime - }) - - // Validate the values are returned as expected. - require.Equal(t, tc.expectedClusters, clusters) - require.Equal(t, tc.expectedRemainingInputs, - remainingInputs, - ) - - // Assert the mocked methods are called as expected. - input1LockTime1.AssertExpectations(t) - input2LockTime1.AssertExpectations(t) - input3LockTime2.AssertExpectations(t) - input4LockTime2.AssertExpectations(t) - input5NoLockTime.AssertExpectations(t) - input6NoLockTime.AssertExpectations(t) - }) - } -} - // TestGetInputLists checks that the expected input sets are returned based on // whether there are retried inputs or not. func TestGetInputLists(t *testing.T) {