diff --git a/config_builder.go b/config_builder.go index 7d6235f39..1c3a842ef 100644 --- a/config_builder.go +++ b/config_builder.go @@ -50,6 +50,7 @@ import ( "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower/wtclient" @@ -188,6 +189,10 @@ type AuxComponents struct { // modify the way a coop-close transaction is constructed. AuxChanCloser fn.Option[chancloser.AuxChanCloser] + // AuxSweeper is an optional interface that can be used to modify the + // way sweep transaction are generated. + AuxSweeper fn.Option[sweep.AuxSweeper] + // AuxContractResolver is an optional interface that can be used to // modify the way contracts are resolved. AuxContractResolver fn.Option[lnwallet.AuxContractResolver] diff --git a/server.go b/server.go index 32266ee62..17a926169 100644 --- a/server.go +++ b/server.go @@ -1116,13 +1116,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr, aggregator := sweep.NewBudgetAggregator( cc.FeeEstimator, sweep.DefaultMaxInputsPerTx, + s.implCfg.AuxSweeper, ) s.txPublisher = sweep.NewTxPublisher(sweep.TxPublisherConfig{ - Signer: cc.Wallet.Cfg.Signer, - Wallet: cc.Wallet, - Estimator: cc.FeeEstimator, - Notifier: cc.ChainNotifier, + Signer: cc.Wallet.Cfg.Signer, + Wallet: cc.Wallet, + Estimator: cc.FeeEstimator, + Notifier: cc.ChainNotifier, + AuxSweeper: s.implCfg.AuxSweeper, }) s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ diff --git a/sweep/aggregator.go b/sweep/aggregator.go index 6028adca3..cd809e81a 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" @@ -31,6 +32,10 @@ type BudgetAggregator struct { // maxInputs specifies the maximum number of inputs allowed in a single // sweep tx. maxInputs uint32 + + // auxSweeper is an optional interface that can be used to modify the + // way sweep transaction are generated. + auxSweeper fn.Option[AuxSweeper] } // Compile-time constraint to ensure BudgetAggregator implements UtxoAggregator. @@ -38,11 +43,12 @@ var _ UtxoAggregator = (*BudgetAggregator)(nil) // NewBudgetAggregator creates a new instance of a BudgetAggregator. func NewBudgetAggregator(estimator chainfee.Estimator, - maxInputs uint32) *BudgetAggregator { + maxInputs uint32, auxSweeper fn.Option[AuxSweeper]) *BudgetAggregator { return &BudgetAggregator{ - estimator: estimator, - maxInputs: maxInputs, + estimator: estimator, + maxInputs: maxInputs, + auxSweeper: auxSweeper, } } @@ -159,7 +165,7 @@ func (b *BudgetAggregator) createInputSets(inputs []SweeperInput, // Create an InputSet using the max allowed number of inputs. set, err := NewBudgetInputSet( - currentInputs, deadlineHeight, + currentInputs, deadlineHeight, b.auxSweeper, ) if err != nil { log.Errorf("unable to create input set: %v", err) @@ -173,7 +179,7 @@ func (b *BudgetAggregator) createInputSets(inputs []SweeperInput, // Create an InputSet from the remaining inputs. if len(remainingInputs) > 0 { set, err := NewBudgetInputSet( - remainingInputs, deadlineHeight, + remainingInputs, deadlineHeight, b.auxSweeper, ) if err != nil { log.Errorf("unable to create input set: %v", err) diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index a32c849a0..6df0d73fa 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -150,7 +150,7 @@ func TestBudgetAggregatorFilterInputs(t *testing.T) { // Init the budget aggregator with the mocked estimator and zero max // num of inputs. - b := NewBudgetAggregator(estimator, 0) + b := NewBudgetAggregator(estimator, 0, fn.None[AuxSweeper]()) // Call the method under test. result := b.filterInputs(inputs) @@ -214,7 +214,7 @@ func TestBudgetAggregatorSortInputs(t *testing.T) { } // Init the budget aggregator with zero max num of inputs. - b := NewBudgetAggregator(nil, 0) + b := NewBudgetAggregator(nil, 0, fn.None[AuxSweeper]()) // Call the method under test. result := b.sortInputs(inputs) @@ -279,7 +279,7 @@ func TestBudgetAggregatorCreateInputSets(t *testing.T) { } // Create a budget aggregator with max number of inputs set to 2. - b := NewBudgetAggregator(nil, 2) + b := NewBudgetAggregator(nil, 2, fn.None[AuxSweeper]()) // Create test cases. testCases := []struct { @@ -540,7 +540,9 @@ func TestBudgetInputSetClusterInputs(t *testing.T) { } // Create a budget aggregator with a max number of inputs set to 100. - b := NewBudgetAggregator(estimator, DefaultMaxInputsPerTx) + b := NewBudgetAggregator( + estimator, DefaultMaxInputsPerTx, fn.None[AuxSweeper](), + ) // Call the method under test. result := b.ClusterInputs(inputs) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 5afa95bac..b1e34c0bf 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -131,6 +131,8 @@ type BumpRequest struct { func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { // Get the size of the sweep tx, which will be used to calculate the // budget fee rate. + // + // TODO(roasbeef): also wants the extra change output? size, err := calcSweepTxWeight( r.Inputs, r.DeliveryAddress.DeliveryAddress, ) @@ -175,7 +177,7 @@ func calcSweepTxWeight(inputs []input.Input, // TODO(yy): we should refactor the weight estimator to not require a // fee rate and max fee rate and make it a pure tx weight calculator. _, estimator, err := getWeightEstimate( - inputs, nil, feeRate, 0, outputPkScript, + inputs, nil, feeRate, 0, [][]byte{outputPkScript}, ) if err != nil { return 0, err @@ -1171,9 +1173,9 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, feeRate chainfee.SatPerKWeight) (*sweepTxCtx, error) { // Validate and calculate the fee and change amount. - txFee, changeAmtOpt, locktimeOpt, err := prepareSweepTx( - inputs, changePkScript.DeliveryAddress, feeRate, - t.currentHeight.Load(), + txFee, changeOutputsOpt, locktimeOpt, err := prepareSweepTx( + inputs, changePkScript, feeRate, t.currentHeight.Load(), + t.cfg.AuxSweeper, ) if err != nil { return nil, err @@ -1219,12 +1221,12 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, }) } - // If there's a change amount, add it to the transaction. - changeAmtOpt.WhenSome(func(changeAmt btcutil.Amount) { - sweepTx.AddTxOut(&wire.TxOut{ - PkScript: changePkScript.DeliveryAddress, - Value: int64(changeAmt), - }) + // If we have change outputs to add, then add it the sweep transaction + // here. + changeOutputsOpt.WhenSome(func(changeOuts []SweepOutput) { + for i := range changeOuts { + sweepTx.AddTxOut(&changeOuts[i].TxOut) + } }) // We'll default to using the current block height as locktime, if none @@ -1268,31 +1270,80 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, log.Debugf("Created sweep tx %v for inputs:\n%v", sweepTx.TxHash(), inputTypeSummary(inputs)) + // Try to locate the extra change output, though there might be None. + extraTxOut := fn.MapOption( + func(sweepOuts []SweepOutput) fn.Option[SweepOutput] { + for _, sweepOut := range sweepOuts { + if !sweepOut.IsExtra { + continue + } + + // If we sweep outputs of a custom channel, the + // custom leaves in those outputs will be merged + // into a single output, even if we sweep + // multiple outputs (e.g. to_remote and breached + // to_local of a breached channel) at the same + // time. So there will only ever be one extra + // output. + log.Debugf("Sweep produced extra_sweep_out=%v", + lnutils.SpewLogClosure(sweepOut)) + + return fn.Some(sweepOut) + } + + return fn.None[SweepOutput]() + }, + )(changeOutputsOpt) + return &sweepTxCtx{ - tx: sweepTx, - fee: txFee, + tx: sweepTx, + fee: txFee, + extraTxOut: fn.FlattenOption(extraTxOut), }, nil } -// prepareSweepTx returns the tx fee, an optional change amount and an optional -// locktime after a series of validations: +// prepareSweepTx returns the tx fee, a set of optional change outputs and an +// optional locktime after a series of validations: // 1. check the locktime has been reached. // 2. check the locktimes are the same. // 3. check the inputs cover the outputs. // // NOTE: if the change amount is below dust, it will be added to the tx fee. -func prepareSweepTx(inputs []input.Input, changePkScript []byte, - feeRate chainfee.SatPerKWeight, currentHeight int32) ( - btcutil.Amount, fn.Option[btcutil.Amount], fn.Option[int32], error) { +func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey, + feeRate chainfee.SatPerKWeight, currentHeight int32, + auxSweeper fn.Option[AuxSweeper]) ( + btcutil.Amount, fn.Option[[]SweepOutput], fn.Option[int32], error) { - noChange := fn.None[btcutil.Amount]() + noChange := fn.None[[]SweepOutput]() noLocktime := fn.None[int32]() + // Given the set of inputs we have, if we have an aux sweeper, then + // we'll attempt to see if we have any other change outputs we'll need + // to add to the sweep transaction. + changePkScripts := [][]byte{changePkScript.DeliveryAddress} + + var extraChangeOut fn.Option[SweepOutput] + err := fn.MapOptionZ( + auxSweeper, func(aux AuxSweeper) error { + extraOut := aux.DeriveSweepAddr(inputs, changePkScript) + if err := extraOut.Err(); err != nil { + return err + } + + extraChangeOut = extraOut.LeftToOption() + + return nil + }, + ) + if err != nil { + return 0, noChange, noLocktime, err + } + // Creating a weight estimator with nil outputs and zero max fee rate. // We don't allow adding customized outputs in the sweeping tx, and the // fee rate is already being managed before we get here. inputs, estimator, err := getWeightEstimate( - inputs, nil, feeRate, 0, changePkScript, + inputs, nil, feeRate, 0, changePkScripts, ) if err != nil { return 0, noChange, noLocktime, err @@ -1310,6 +1361,12 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, requiredOutput btcutil.Amount ) + // If we have an extra change output, then we'll add it as a required + // output amt. + extraChangeOut.WhenSome(func(o SweepOutput) { + requiredOutput += btcutil.Amount(o.Value) + }) + // Go through each input and check if the required lock times have // reached and are the same. for _, o := range inputs { @@ -1356,14 +1413,21 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, // The value remaining after the required output and fees is the // change output. changeAmt := totalInput - requiredOutput - txFee - changeAmtOpt := fn.Some(changeAmt) + changeOuts := make([]SweepOutput, 0, 2) + + extraChangeOut.WhenSome(func(o SweepOutput) { + changeOuts = append(changeOuts, o) + }) // We'll calculate the dust limit for the given changePkScript since it // is variable. - changeFloor := lnwallet.DustLimitForSize(len(changePkScript)) + changeFloor := lnwallet.DustLimitForSize( + len(changePkScript.DeliveryAddress), + ) - // If the change amount is dust, we'll move it into the fees. switch { + // If the change amount is dust, we'll move it into the fees, and + // ignore it. case changeAmt < changeFloor: log.Infof("Change amt %v below dustlimit %v, not adding "+ "change output", changeAmt, changeFloor) @@ -1379,12 +1443,16 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, // The dust amount is added to the fee. txFee += changeAmt - // Set the change amount to none. - changeAmtOpt = fn.None[btcutil.Amount]() - // Otherwise, we'll actually recognize it as a change output. default: - // TODO(roasbeef): Implement (later commit in this PR). + changeOuts = append(changeOuts, SweepOutput{ + TxOut: wire.TxOut{ + Value: int64(changeAmt), + PkScript: changePkScript.DeliveryAddress, + }, + IsExtra: false, + InternalKey: changePkScript.InternalKey, + }) } // Optionally set the locktime. @@ -1393,6 +1461,11 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, locktimeOpt = noLocktime } + var changeOutsOpt fn.Option[[]SweepOutput] + if len(changeOuts) > 0 { + changeOutsOpt = fn.Some(changeOuts) + } + log.Debugf("Creating sweep tx for %v inputs (%s) using %v, "+ "tx_weight=%v, tx_fee=%v, locktime=%v, parents_count=%v, "+ "parents_fee=%v, parents_weight=%v, current_height=%v", @@ -1400,5 +1473,5 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, estimator.weight(), txFee, locktimeOpt, len(estimator.parents), estimator.parentsFee, estimator.parentsWeight, currentHeight) - return txFee, changeAmtOpt, locktimeOpt, nil + return txFee, changeOutsOpt, locktimeOpt, nil } diff --git a/sweep/interface.go b/sweep/interface.go index 41120613b..acece3143 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -84,6 +84,12 @@ type AuxSweeper interface { DeriveSweepAddr(inputs []input.Input, change lnwallet.AddrWithKey) fn.Result[SweepOutput] + // ExtraBudgetForInputs is used to determine the extra budget that + // should be allocated to sweep the given set of inputs. This can be + // used to add extra funds to the sweep transaction, for example to + // cover fees for additional outputs of custom channels. + ExtraBudgetForInputs(inputs []input.Input) fn.Result[btcutil.Amount] + // NotifyBroadcast is used to notify external callers of the broadcast // of a sweep transaction, generated by the passed BumpRequest. NotifyBroadcast(req *BumpRequest, tx *wire.MsgTx, diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 8552fbadc..c623ca3c0 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -315,15 +316,37 @@ func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) { return args.Bool(0), args.Error(1) } -type MockAuxSweeper struct{} +type MockAuxSweeper struct { + mock.Mock +} // DeriveSweepAddr takes a set of inputs, and the change address we'd // use to sweep them, and maybe results an extra sweep output that we // should add to the sweeping transaction. -func (*MockAuxSweeper) DeriveSweepAddr(_ []input.Input, +func (m *MockAuxSweeper) DeriveSweepAddr(_ []input.Input, _ lnwallet.AddrWithKey) fn.Result[SweepOutput] { - return fn.Ok(SweepOutput{}) + return fn.Ok(SweepOutput{ + TxOut: wire.TxOut{ + Value: 123, + PkScript: changePkScript.DeliveryAddress, + }, + IsExtra: false, + InternalKey: fn.None[keychain.KeyDescriptor](), + }) +} + +// ExtraBudgetForInputs is used to determine the extra budget that +// should be allocated to sweep the given set of inputs. This can be +// used to add extra funds to the sweep transaction, for example to +// cover fees for additional outputs of custom channels. +func (m *MockAuxSweeper) ExtraBudgetForInputs( + _ []input.Input) fn.Result[btcutil.Amount] { + + args := m.Called() + amt := args.Get(0) + + return amt.(fn.Result[btcutil.Amount]) } // NotifyBroadcast is used to notify external callers of the broadcast diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index 31f20b7db..3a95fff2f 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -111,17 +111,26 @@ type BudgetInputSet struct { // deadlineHeight is the height which the inputs in this set must be // confirmed by. deadlineHeight int32 + + // extraBudget is a value that should be allocated to sweep the given + // set of inputs. This can be used to add extra funds to the sweep + // transaction, for example to cover fees for additional outputs of + // custom channels. + extraBudget btcutil.Amount } // Compile-time constraint to ensure budgetInputSet implements InputSet. var _ InputSet = (*BudgetInputSet)(nil) +// errEmptyInputs is returned when the input slice is empty. +var errEmptyInputs = fmt.Errorf("inputs slice is empty") + // validateInputs is used when creating new BudgetInputSet to ensure there are // no duplicate inputs and they all share the same deadline heights, if set. func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { // Sanity check the input slice to ensure it's non-empty. if len(inputs) == 0 { - return fmt.Errorf("inputs slice is empty") + return errEmptyInputs } // inputDeadline tracks the input's deadline height. It will be updated @@ -167,8 +176,8 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { } // NewBudgetInputSet creates a new BudgetInputSet. -func NewBudgetInputSet(inputs []SweeperInput, - deadlineHeight int32) (*BudgetInputSet, error) { +func NewBudgetInputSet(inputs []SweeperInput, deadlineHeight int32, + auxSweeper fn.Option[AuxSweeper]) (*BudgetInputSet, error) { // Validate the supplied inputs. if err := validateInputs(inputs, deadlineHeight); err != nil { @@ -186,9 +195,32 @@ func NewBudgetInputSet(inputs []SweeperInput, log.Tracef("Created %v", bi.String()) + // Attach an optional budget. This will be a no-op if the auxSweeper + // is not set. + if err := bi.attachExtraBudget(auxSweeper); err != nil { + return nil, err + } + return bi, nil } +// attachExtraBudget attaches an extra budget to the input set, if the passed +// aux sweeper is set. +func (b *BudgetInputSet) attachExtraBudget(s fn.Option[AuxSweeper]) error { + extraBudget, err := fn.MapOptionZ( + s, func(aux AuxSweeper) fn.Result[btcutil.Amount] { + return aux.ExtraBudgetForInputs(b.Inputs()) + }, + ).Unpack() + if err != nil { + return err + } + + b.extraBudget = extraBudget + + return nil +} + // String returns a human-readable description of the input set. func (b *BudgetInputSet) String() string { inputsDesc := "" @@ -212,8 +244,10 @@ func (b *BudgetInputSet) addInput(input SweeperInput) { func (b *BudgetInputSet) NeedWalletInput() bool { var ( // budgetNeeded is the amount that needs to be covered from - // other inputs. - budgetNeeded btcutil.Amount + // other inputs. We start at the value of the extra budget, + // which might be needed for custom channels that add extra + // outputs. + budgetNeeded = b.extraBudget // budgetBorrowable is the amount that can be borrowed from // other inputs. diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index b6a87b378..b774eedff 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -28,7 +28,9 @@ func TestNewBudgetInputSet(t *testing.T) { rt := require.New(t) // Pass an empty slice and expect an error. - set, err := NewBudgetInputSet([]SweeperInput{}, testHeight) + set, err := NewBudgetInputSet( + []SweeperInput{}, testHeight, fn.None[AuxSweeper](), + ) rt.ErrorContains(err, "inputs slice is empty") rt.Nil(set) @@ -66,23 +68,35 @@ func TestNewBudgetInputSet(t *testing.T) { } // Pass a slice of inputs with different deadline heights. - set, err = NewBudgetInputSet([]SweeperInput{input1, input2}, testHeight) + set, err = NewBudgetInputSet( + []SweeperInput{input1, input2}, testHeight, + fn.None[AuxSweeper](), + ) rt.ErrorContains(err, "input deadline height not matched") rt.Nil(set) // Pass a slice of inputs that only one input has the deadline height, // but it has a different value than the specified testHeight. - set, err = NewBudgetInputSet([]SweeperInput{input0, input2}, testHeight) + set, err = NewBudgetInputSet( + []SweeperInput{input0, input2}, testHeight, + fn.None[AuxSweeper](), + ) rt.ErrorContains(err, "input deadline height not matched") rt.Nil(set) // Pass a slice of inputs that are duplicates. - set, err = NewBudgetInputSet([]SweeperInput{input3, input3}, testHeight) + set, err = NewBudgetInputSet( + []SweeperInput{input3, input3}, testHeight, + fn.None[AuxSweeper](), + ) rt.ErrorContains(err, "duplicate inputs") rt.Nil(set) // Pass a slice of inputs that only one input has the deadline height, - set, err = NewBudgetInputSet([]SweeperInput{input0, input3}, testHeight) + set, err = NewBudgetInputSet( + []SweeperInput{input0, input3}, testHeight, + fn.None[AuxSweeper](), + ) rt.NoError(err) rt.NotNil(set) } @@ -102,7 +116,9 @@ func TestBudgetInputSetAddInput(t *testing.T) { } // Initialize an input set, which adds the above input. - set, err := NewBudgetInputSet([]SweeperInput{*pi}, testHeight) + set, err := NewBudgetInputSet( + []SweeperInput{*pi}, testHeight, fn.None[AuxSweeper](), + ) require.NoError(t, err) // Add the input to the set again. @@ -125,48 +141,55 @@ func TestNeedWalletInput(t *testing.T) { // Create a mock input that doesn't have required outputs. mockInput := &input.MockInput{} mockInput.On("RequiredTxOut").Return(nil) + mockInput.On("OutPoint").Return(wire.OutPoint{Hash: chainhash.Hash{1}}) defer mockInput.AssertExpectations(t) // Create a mock input that has required outputs. mockInputRequireOutput := &input.MockInput{} mockInputRequireOutput.On("RequiredTxOut").Return(&wire.TxOut{}) + mockInputRequireOutput.On("OutPoint").Return( + wire.OutPoint{Hash: chainhash.Hash{2}}, + ) defer mockInputRequireOutput.AssertExpectations(t) // We now create two pending inputs each has a budget of 100 satoshis. const budget = 100 // Create the pending input that doesn't have a required output. - piBudget := &SweeperInput{ + piBudget := SweeperInput{ Input: mockInput, params: Params{Budget: budget}, } // Create the pending input that has a required output. - piRequireOutput := &SweeperInput{ + piRequireOutput := SweeperInput{ Input: mockInputRequireOutput, params: Params{Budget: budget}, } testCases := []struct { name string - setupInputs func() []*SweeperInput + setupInputs func() []SweeperInput + extraBudget btcutil.Amount need bool + err error }{ { // When there are no pending inputs, we won't need a - // wallet input. Technically this should be an invalid + // wallet input. Technically this is be an invalid // state. name: "no inputs", - setupInputs: func() []*SweeperInput { + setupInputs: func() []SweeperInput { return nil }, need: false, + err: errEmptyInputs, }, { // When there's no required output, we don't need a // wallet input. name: "no required outputs", - setupInputs: func() []*SweeperInput { + setupInputs: func() []SweeperInput { // Create a sign descriptor to be used in the // pending input when calculating budgets can // be borrowed. @@ -177,15 +200,36 @@ func TestNeedWalletInput(t *testing.T) { } mockInput.On("SignDesc").Return(sd).Once() - return []*SweeperInput{piBudget} + return []SweeperInput{piBudget} }, need: false, }, + { + // When there's no required normal outputs, but an extra + // budget from custom channels, we will need a wallet + // input. + name: "no required normal outputs but extra budget", + setupInputs: func() []SweeperInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + return []SweeperInput{piBudget} + }, + extraBudget: 1000, + need: true, + }, { // When the output value cannot cover the budget, we // need a wallet input. name: "output value cannot cover budget", - setupInputs: func() []*SweeperInput { + setupInputs: func() []SweeperInput { // Create a sign descriptor to be used in the // pending input when calculating budgets can // be borrowed. @@ -194,8 +238,8 @@ func TestNeedWalletInput(t *testing.T) { Value: budget - 1, }, } - mockInput.On("SignDesc").Return(sd).Once() + mockInput.On("SignDesc").Return(sd).Once() // These two methods are only invoked when the // unit test is running with a logger. mockInput.On("OutPoint").Return( @@ -205,7 +249,7 @@ func TestNeedWalletInput(t *testing.T) { input.CommitmentAnchor, ).Maybe() - return []*SweeperInput{piBudget} + return []SweeperInput{piBudget} }, need: true, }, @@ -213,8 +257,8 @@ func TestNeedWalletInput(t *testing.T) { // When there's only inputs that require outputs, we // need wallet inputs. name: "only required outputs", - setupInputs: func() []*SweeperInput { - return []*SweeperInput{piRequireOutput} + setupInputs: func() []SweeperInput { + return []SweeperInput{piRequireOutput} }, need: true, }, @@ -223,7 +267,7 @@ func TestNeedWalletInput(t *testing.T) { // budget cannot cover the required, we need a wallet // input. name: "not enough budget to be borrowed", - setupInputs: func() []*SweeperInput { + setupInputs: func() []SweeperInput { // Create a sign descriptor to be used in the // pending input when calculating budgets can // be borrowed. @@ -237,7 +281,7 @@ func TestNeedWalletInput(t *testing.T) { } mockInput.On("SignDesc").Return(sd).Once() - return []*SweeperInput{ + return []SweeperInput{ piBudget, piRequireOutput, } }, @@ -248,7 +292,7 @@ func TestNeedWalletInput(t *testing.T) { // borrowed covers the required, we don't need wallet // inputs. name: "enough budget to be borrowed", - setupInputs: func() []*SweeperInput { + setupInputs: func() []SweeperInput { // Create a sign descriptor to be used in the // pending input when calculating budgets can // be borrowed. @@ -263,7 +307,7 @@ func TestNeedWalletInput(t *testing.T) { mockInput.On("SignDesc").Return(sd).Once() piBudget.Input = mockInput - return []*SweeperInput{ + return []SweeperInput{ piBudget, piRequireOutput, } }, @@ -276,12 +320,27 @@ func TestNeedWalletInput(t *testing.T) { // Setup testing inputs. inputs := tc.setupInputs() + // If an extra budget is set, then we'll update the mock + // to expect the extra budget. + mockAuxSweeper := &MockAuxSweeper{} + mockAuxSweeper.On("ExtraBudgetForInputs").Return( + fn.Ok(tc.extraBudget), + ) + // Initialize an input set, which adds the testing // inputs. - set := &BudgetInputSet{inputs: inputs} + set, err := NewBudgetInputSet( + inputs, 0, fn.Some[AuxSweeper](mockAuxSweeper), + ) + if err != nil { + require.ErrorIs(t, err, tc.err) + return + } result := set.NeedWalletInput() + require.Equal(t, tc.need, result) + mockAuxSweeper.AssertExpectations(t) }) } } @@ -434,7 +493,9 @@ func TestAddWalletInputSuccess(t *testing.T) { min, max).Return([]*lnwallet.Utxo{utxo, utxo}, nil).Once() // Initialize an input set with the pending input. - set, err := NewBudgetInputSet([]SweeperInput{*pi}, deadline) + set, err := NewBudgetInputSet( + []SweeperInput{*pi}, deadline, fn.None[AuxSweeper](), + ) require.NoError(t, err) // Add wallet inputs to the input set, which should give us an error as diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 30e11023e..43fc802ba 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -38,7 +38,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, signer input.Signer) (*wire.MsgTx, btcutil.Amount, error) { inputs, estimator, err := getWeightEstimate( - inputs, outputs, feeRate, maxFeeRate, changePkScript, + inputs, outputs, feeRate, maxFeeRate, [][]byte{changePkScript}, ) if err != nil { return nil, 0, err @@ -221,7 +221,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, // Additionally, it returns counts for the number of csv and cltv inputs. func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, feeRate, maxFeeRate chainfee.SatPerKWeight, - outputPkScript []byte) ([]input.Input, *weightEstimator, error) { + outputPkScripts [][]byte) ([]input.Input, *weightEstimator, error) { // We initialize a weight estimator so we can accurately asses the // amount of fees we need to pay for this sweep transaction. @@ -237,31 +237,33 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, // If there is any leftover change after paying to the given outputs // and required outputs, it will go to a single segwit p2wkh or p2tr - // address. This will be our change address, so ensure it contributes to - // our weight estimate. Note that if we have other outputs, we might end - // up creating a sweep tx without a change output. It is okay to add the - // change output to the weight estimate regardless, since the estimated - // fee will just be subtracted from this already dust output, and - // trimmed. - switch { - case txscript.IsPayToTaproot(outputPkScript): - weightEstimate.addP2TROutput() + // address. This will be our change address, so ensure it contributes + // to our weight estimate. Note that if we have other outputs, we might + // end up creating a sweep tx without a change output. It is okay to + // add the change output to the weight estimate regardless, since the + // estimated fee will just be subtracted from this already dust output, + // and trimmed. + for _, outputPkScript := range outputPkScripts { + switch { + case txscript.IsPayToTaproot(outputPkScript): + weightEstimate.addP2TROutput() - case txscript.IsPayToWitnessScriptHash(outputPkScript): - weightEstimate.addP2WSHOutput() + case txscript.IsPayToWitnessScriptHash(outputPkScript): + weightEstimate.addP2WSHOutput() - case txscript.IsPayToWitnessPubKeyHash(outputPkScript): - weightEstimate.addP2WKHOutput() + case txscript.IsPayToWitnessPubKeyHash(outputPkScript): + weightEstimate.addP2WKHOutput() - case txscript.IsPayToPubKeyHash(outputPkScript): - weightEstimate.estimator.AddP2PKHOutput() + case txscript.IsPayToPubKeyHash(outputPkScript): + weightEstimate.estimator.AddP2PKHOutput() - case txscript.IsPayToScriptHash(outputPkScript): - weightEstimate.estimator.AddP2SHOutput() + case txscript.IsPayToScriptHash(outputPkScript): + weightEstimate.estimator.AddP2SHOutput() - default: - // Unknown script type. - return nil, nil, errors.New("unknown script type") + default: + // Unknown script type. + return nil, nil, errors.New("unknown script type") + } } // For each output, use its witness type to determine the estimate diff --git a/sweep/txgenerator_test.go b/sweep/txgenerator_test.go index 48dcacd49..71477bd6e 100644 --- a/sweep/txgenerator_test.go +++ b/sweep/txgenerator_test.go @@ -51,7 +51,7 @@ func TestWeightEstimate(t *testing.T) { } _, estimator, err := getWeightEstimate( - inputs, nil, 0, 0, changePkScript, + inputs, nil, 0, 0, [][]byte{changePkScript}, ) require.NoError(t, err) @@ -153,7 +153,7 @@ func testUnknownScriptInner(t *testing.T, pkscript []byte, expectFail bool) { )) } - _, _, err := getWeightEstimate(inputs, nil, 0, 0, pkscript) + _, _, err := getWeightEstimate(inputs, nil, 0, 0, [][]byte{pkscript}) if expectFail { require.Error(t, err) } else {