diff --git a/lntest/itest/lnd_nonstd_sweep_test.go b/lntest/itest/lnd_nonstd_sweep_test.go new file mode 100644 index 000000000..36dc7be8c --- /dev/null +++ b/lntest/itest/lnd_nonstd_sweep_test.go @@ -0,0 +1,147 @@ +package itest + +import ( + "context" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/stretchr/testify/require" +) + +func testNonstdSweep(net *lntest.NetworkHarness, t *harnessTest) { + p2shAddr, err := btcutil.NewAddressScriptHash( + make([]byte, 1), harnessNetParams, + ) + require.NoError(t.t, err) + + p2pkhAddr, err := btcutil.NewAddressPubKeyHash( + make([]byte, 20), harnessNetParams, + ) + require.NoError(t.t, err) + + p2wshAddr, err := btcutil.NewAddressWitnessScriptHash( + make([]byte, 32), harnessNetParams, + ) + require.NoError(t.t, err) + + p2wkhAddr, err := btcutil.NewAddressWitnessPubKeyHash( + make([]byte, 20), harnessNetParams, + ) + require.NoError(t.t, err) + + p2trAddr, err := btcutil.NewAddressTaproot( + make([]byte, 32), harnessNetParams, + ) + require.NoError(t.t, err) + + tests := []struct { + name string + address string + }{ + { + name: "p2sh SendCoins standardness", + address: p2shAddr.EncodeAddress(), + }, + { + name: "p2pkh SendCoins standardness", + address: p2pkhAddr.EncodeAddress(), + }, + { + name: "p2wsh SendCoins standardness", + address: p2wshAddr.EncodeAddress(), + }, + { + name: "p2wkh SendCoins standardness", + address: p2wkhAddr.EncodeAddress(), + }, + { + name: "p2tr SendCoins standardness", + address: p2trAddr.EncodeAddress(), + }, + } + + for _, test := range tests { + test := test + success := t.t.Run(test.name, func(t *testing.T) { + h := newHarnessTest(t, net) + + testNonStdSweepInner(net, h, test.address) + }) + if !success { + break + } + } +} + +func testNonStdSweepInner(net *lntest.NetworkHarness, t *harnessTest, + address string) { + + ctxb := context.Background() + + carol := net.NewNode(t.t, "carol", nil) + + // Give Carol a UTXO so SendCoins will behave as expected. + net.SendCoins(t.t, btcutil.SatoshiPerBitcoin, carol) + + // Set the fee estimate to 1sat/vbyte. + net.SetFeeEstimate(250) + + // Make Carol call SendCoins with the SendAll flag and the created + // address. + sendReq := &lnrpc.SendCoinsRequest{ + Addr: address, + SatPerVbyte: 1, + SendAll: true, + } + + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + + // If a non-standard transaction was created, then this SendCoins call + // will fail. + _, err := carol.SendCoins(ctxt, sendReq) + require.NoError(t.t, err) + + // Fetch the txid so we can grab the raw transaction. + txid, err := waitForTxInMempool(net.Miner.Client, minerMempoolTimeout) + require.NoError(t.t, err) + + tx, err := net.Miner.Client.GetRawTransaction(txid) + require.NoError(t.t, err) + + msgTx := tx.MsgTx() + + // Fetch the fee of the transaction. + var ( + inputVal int + outputVal int + fee int + ) + + for _, inp := range msgTx.TxIn { + // Fetch the previous outpoint's value. + prevOut := inp.PreviousOutPoint + + ptx, err := net.Miner.Client.GetRawTransaction(&prevOut.Hash) + require.NoError(t.t, err) + + pout := ptx.MsgTx().TxOut[prevOut.Index] + inputVal += int(pout.Value) + } + + for _, outp := range msgTx.TxOut { + outputVal += int(outp.Value) + } + + fee = inputVal - outputVal + + // Fetch the vsize of the transaction so we can determine if the + // transaction pays >= 1 sat/vbyte. + rawTx, err := net.Miner.Client.GetRawTransactionVerbose(txid) + require.NoError(t.t, err) + + // Require fee >= vbytes. + require.True(t.t, fee >= int(rawTx.Vsize)) +} diff --git a/lntest/itest/lnd_test_list_on_test.go b/lntest/itest/lnd_test_list_on_test.go index 81202d844..71d6df0d4 100644 --- a/lntest/itest/lnd_test_list_on_test.go +++ b/lntest/itest/lnd_test_list_on_test.go @@ -427,4 +427,8 @@ var allTestCases = []*testCase{ name: "scid alias upgrade", test: testOptionScidUpgrade, }, + { + name: "nonstd sweep", + test: testNonstdSweep, + }, } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 0eebb3799..d13790a00 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -124,7 +124,6 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { timeoutChan: make(chan chan time.Time, 1), } - var outputScriptCount byte ctx.sweeper = New(&UtxoSweeperConfig{ Notifier: notifier, Wallet: backend, @@ -137,8 +136,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { Signer: &mock.DummySigner{}, GenSweepScript: func() ([]byte, error) { script := make([]byte, input.P2WPKHSize) - script[0] = outputScriptCount - outputScriptCount++ + script[0] = 0 + script[1] = 20 return script, nil }, FeeEstimator: estimator, @@ -330,7 +329,8 @@ func assertTxSweepsInputs(t *testing.T, sweepTx *wire.MsgTx, // NOTE: This assumes that transactions only have one output, as this is the // only type of transaction the UtxoSweeper can create at the moment. func assertTxFeeRate(t *testing.T, tx *wire.MsgTx, - expectedFeeRate chainfee.SatPerKWeight, inputs ...input.Input) { + expectedFeeRate chainfee.SatPerKWeight, changePk []byte, + inputs ...input.Input) { t.Helper() @@ -355,7 +355,9 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx, outputAmt := tx.TxOut[0].Value fee := btcutil.Amount(inputAmt - outputAmt) - _, estimator := getWeightEstimate(inputs, nil, 0, nil) + _, estimator, err := getWeightEstimate(inputs, nil, 0, changePk) + require.NoError(t, err) + txWeight := estimator.weight() expectedFee := expectedFeeRate.FeeForWeight(int64(txWeight)) @@ -1092,14 +1094,19 @@ func TestDifferentFeePreferences(t *testing.T) { // transactions to be broadcast in order of high to low fee preference. ctx.tick() + // Generate the same type of sweep script that was used for weight + // estimation. + changePk, err := ctx.sweeper.cfg.GenSweepScript() + require.NoError(t, err) + // The first transaction broadcast should be the one spending the higher // fee rate inputs. sweepTx1 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx1, highFeeRate, input1, input2) + assertTxFeeRate(t, &sweepTx1, highFeeRate, changePk, input1, input2) // The second should be the one spending the lower fee rate inputs. sweepTx2 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx2, lowFeeRate, input3) + assertTxFeeRate(t, &sweepTx2, lowFeeRate, changePk, input3) // With the transactions broadcast, we'll mine a block to so that the // result is delivered to each respective client. @@ -1218,10 +1225,15 @@ func TestBumpFeeRBF(t *testing.T) { t.Fatal(err) } + // Generate the same type of change script used so we can have accurate + // weight estimation. + changePk, err := ctx.sweeper.cfg.GenSweepScript() + require.NoError(t, err) + // Ensure that a transaction is broadcast with the lower fee preference. ctx.tick() lowFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &lowFeeTx, lowFeeRate, &input) + assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input) // We'll then attempt to bump its fee rate. highFeePref := FeePreference{ConfTarget: 6} @@ -1242,7 +1254,7 @@ func TestBumpFeeRBF(t *testing.T) { // A higher fee rate transaction should be immediately broadcast. ctx.tick() highFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &highFeeTx, highFeeRate, &input) + assertTxFeeRate(t, &highFeeTx, highFeeRate, changePk, &input) // We'll finish our test by mining the sweep transaction. ctx.backend.mine() diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 95da8a764..3b07f2d64 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -1,6 +1,7 @@ package sweep import ( + "errors" "fmt" "sort" "strings" @@ -139,9 +140,13 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, feePerKw chainfee.SatPerKWeight, signer input.Signer) (*wire.MsgTx, error) { - inputs, estimator := getWeightEstimate( + inputs, estimator, err := getWeightEstimate( inputs, outputs, feePerKw, changePkScript, ) + if err != nil { + return nil, err + } + txFee := estimator.fee() var ( @@ -315,7 +320,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, // Additionally, it returns counts for the number of csv and cltv inputs. func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, feeRate chainfee.SatPerKWeight, outputPkScript []byte) ([]input.Input, - *weightEstimator) { + *weightEstimator, error) { // We initialize a weight estimator so we can accurately asses the // amount of fees we need to pay for this sweep transaction. @@ -337,10 +342,25 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, // change output to the weight estimate regardless, since the estimated // fee will just be subtracted from this already dust output, and // trimmed. - if txscript.IsPayToTaproot(outputPkScript) { + switch { + case txscript.IsPayToTaproot(outputPkScript): weightEstimate.addP2TROutput() - } else { + + case txscript.IsPayToWitnessScriptHash(outputPkScript): + weightEstimate.addP2WSHOutput() + + case txscript.IsPayToWitnessPubKeyHash(outputPkScript): weightEstimate.addP2WKHOutput() + + case txscript.IsPayToPubKeyHash(outputPkScript): + weightEstimate.estimator.AddP2PKHOutput() + + case txscript.IsPayToScriptHash(outputPkScript): + weightEstimate.estimator.AddP2SHOutput() + + default: + // Unknown script type. + return nil, nil, errors.New("unknown script type") } // For each output, use its witness type to determine the estimate @@ -368,7 +388,7 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, sweepInputs = append(sweepInputs, inp) } - return sweepInputs, weightEstimate + return sweepInputs, weightEstimate, nil } // inputSummary returns a string containing a human readable summary about the diff --git a/sweep/txgenerator_test.go b/sweep/txgenerator_test.go index 960da63d3..cb575e5a6 100644 --- a/sweep/txgenerator_test.go +++ b/sweep/txgenerator_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/input" + "github.com/stretchr/testify/require" ) var ( @@ -39,7 +40,19 @@ func TestWeightEstimate(t *testing.T) { )) } - _, estimator := getWeightEstimate(inputs, nil, 0, nil) + // Create a sweep script that is always fed into the weight estimator, + // regardless if it's actually included in the tx. It will be a P2WKH + // script. + changePkScript := []byte{ + 0x00, 0x14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + } + + _, estimator, err := getWeightEstimate(inputs, nil, 0, changePkScript) + require.NoError(t, err) + weight := int64(estimator.weight()) if weight != expectedWeight { t.Fatalf("unexpected weight. expected %d but got %d.", @@ -51,3 +64,97 @@ func TestWeightEstimate(t *testing.T) { expectedSummary, summary) } } + +// TestWeightEstimatorUnknownScript tests that the weight estimator fails when +// given an unknown script and succeeds when given a known script. +func TestWeightEstimatorUnknownScript(t *testing.T) { + tests := []struct { + name string + pkscript []byte + expectFail bool + }{ + { + name: "p2tr output", + pkscript: []byte{ + 0x51, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "p2wsh output", + pkscript: []byte{ + 0x00, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "p2wkh output", + pkscript: []byte{ + 0x00, 0x14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "p2pkh output", + pkscript: []byte{ + 0x76, 0xa9, 0x14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x88, 0xac, + }, + }, + { + name: "p2sh output", + pkscript: []byte{ + 0xa9, 0x14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x87, + }, + }, + { + name: "unknown script", + pkscript: []byte{0x00}, + expectFail: true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testUnknownScriptInner( + t, test.pkscript, test.expectFail, + ) + }) + } +} + +func testUnknownScriptInner(t *testing.T, pkscript []byte, expectFail bool) { + var inputs []input.Input + for i, witnessType := range witnessTypes { + inputs = append(inputs, input.NewBaseInput( + &wire.OutPoint{ + Hash: chainhash.Hash{byte(i)}, + Index: uint32(i) + 10, + }, witnessType, + &input.SignDescriptor{}, 0, + )) + } + + _, _, err := getWeightEstimate(inputs, nil, 0, pkscript) + if expectFail { + require.Error(t, err) + } else { + require.NoError(t, err) + } +} diff --git a/sweep/weight_estimator.go b/sweep/weight_estimator.go index aba61cd92..c09adabdc 100644 --- a/sweep/weight_estimator.go +++ b/sweep/weight_estimator.go @@ -83,6 +83,12 @@ func (w *weightEstimator) addP2TROutput() { w.estimator.AddP2TROutput() } +// addP2WSHOutput updates the weight estimate to account for an additional +// segwit v0 P2WSH output. +func (w *weightEstimator) addP2WSHOutput() { + w.estimator.AddP2WSHOutput() +} + // addOutput updates the weight estimate to account for the known // output given. func (w *weightEstimator) addOutput(txOut *wire.TxOut) {