mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-22 14:22:37 +01:00
sweep: account for all script types in craftSweepTx
With this change, transactions created via craftSweepTx will be standard. Previously, p2wsh/p2pkh scripts passed in via SendCoins would be weighted as p2wpkh scripts. With a feerate of 1 sat/vbyte, transactions returned would be non-standard. Luckily, the critical sweeper subsystem only used p2wpkh scripts so this only affected callers from the rpcserver. Also added is an integration test that fails if SendCoins manages to generate a non-standard transaction. All script types are now accounted for in getWeightEstimate, which now errors if an unknown script type is passed in.
This commit is contained in:
parent
6060e05d7c
commit
6ec2826f6c
6 changed files with 311 additions and 15 deletions
147
lntest/itest/lnd_nonstd_sweep_test.go
Normal file
147
lntest/itest/lnd_nonstd_sweep_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package itest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lntest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testNonstdSweep(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
p2shAddr, err := btcutil.NewAddressScriptHash(
|
||||
make([]byte, 1), harnessNetParams,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
p2pkhAddr, err := btcutil.NewAddressPubKeyHash(
|
||||
make([]byte, 20), harnessNetParams,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
p2wshAddr, err := btcutil.NewAddressWitnessScriptHash(
|
||||
make([]byte, 32), harnessNetParams,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
p2wkhAddr, err := btcutil.NewAddressWitnessPubKeyHash(
|
||||
make([]byte, 20), harnessNetParams,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
p2trAddr, err := btcutil.NewAddressTaproot(
|
||||
make([]byte, 32), harnessNetParams,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
}{
|
||||
{
|
||||
name: "p2sh SendCoins standardness",
|
||||
address: p2shAddr.EncodeAddress(),
|
||||
},
|
||||
{
|
||||
name: "p2pkh SendCoins standardness",
|
||||
address: p2pkhAddr.EncodeAddress(),
|
||||
},
|
||||
{
|
||||
name: "p2wsh SendCoins standardness",
|
||||
address: p2wshAddr.EncodeAddress(),
|
||||
},
|
||||
{
|
||||
name: "p2wkh SendCoins standardness",
|
||||
address: p2wkhAddr.EncodeAddress(),
|
||||
},
|
||||
{
|
||||
name: "p2tr SendCoins standardness",
|
||||
address: p2trAddr.EncodeAddress(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
success := t.t.Run(test.name, func(t *testing.T) {
|
||||
h := newHarnessTest(t, net)
|
||||
|
||||
testNonStdSweepInner(net, h, test.address)
|
||||
})
|
||||
if !success {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testNonStdSweepInner(net *lntest.NetworkHarness, t *harnessTest,
|
||||
address string) {
|
||||
|
||||
ctxb := context.Background()
|
||||
|
||||
carol := net.NewNode(t.t, "carol", nil)
|
||||
|
||||
// Give Carol a UTXO so SendCoins will behave as expected.
|
||||
net.SendCoins(t.t, btcutil.SatoshiPerBitcoin, carol)
|
||||
|
||||
// Set the fee estimate to 1sat/vbyte.
|
||||
net.SetFeeEstimate(250)
|
||||
|
||||
// Make Carol call SendCoins with the SendAll flag and the created
|
||||
// address.
|
||||
sendReq := &lnrpc.SendCoinsRequest{
|
||||
Addr: address,
|
||||
SatPerVbyte: 1,
|
||||
SendAll: true,
|
||||
}
|
||||
|
||||
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// If a non-standard transaction was created, then this SendCoins call
|
||||
// will fail.
|
||||
_, err := carol.SendCoins(ctxt, sendReq)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Fetch the txid so we can grab the raw transaction.
|
||||
txid, err := waitForTxInMempool(net.Miner.Client, minerMempoolTimeout)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
tx, err := net.Miner.Client.GetRawTransaction(txid)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
msgTx := tx.MsgTx()
|
||||
|
||||
// Fetch the fee of the transaction.
|
||||
var (
|
||||
inputVal int
|
||||
outputVal int
|
||||
fee int
|
||||
)
|
||||
|
||||
for _, inp := range msgTx.TxIn {
|
||||
// Fetch the previous outpoint's value.
|
||||
prevOut := inp.PreviousOutPoint
|
||||
|
||||
ptx, err := net.Miner.Client.GetRawTransaction(&prevOut.Hash)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
pout := ptx.MsgTx().TxOut[prevOut.Index]
|
||||
inputVal += int(pout.Value)
|
||||
}
|
||||
|
||||
for _, outp := range msgTx.TxOut {
|
||||
outputVal += int(outp.Value)
|
||||
}
|
||||
|
||||
fee = inputVal - outputVal
|
||||
|
||||
// Fetch the vsize of the transaction so we can determine if the
|
||||
// transaction pays >= 1 sat/vbyte.
|
||||
rawTx, err := net.Miner.Client.GetRawTransactionVerbose(txid)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Require fee >= vbytes.
|
||||
require.True(t.t, fee >= int(rawTx.Vsize))
|
||||
}
|
|
@ -427,4 +427,8 @@ var allTestCases = []*testCase{
|
|||
name: "scid alias upgrade",
|
||||
test: testOptionScidUpgrade,
|
||||
},
|
||||
{
|
||||
name: "nonstd sweep",
|
||||
test: testNonstdSweep,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -124,7 +124,6 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
|
|||
timeoutChan: make(chan chan time.Time, 1),
|
||||
}
|
||||
|
||||
var outputScriptCount byte
|
||||
ctx.sweeper = New(&UtxoSweeperConfig{
|
||||
Notifier: notifier,
|
||||
Wallet: backend,
|
||||
|
@ -137,8 +136,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
|
|||
Signer: &mock.DummySigner{},
|
||||
GenSweepScript: func() ([]byte, error) {
|
||||
script := make([]byte, input.P2WPKHSize)
|
||||
script[0] = outputScriptCount
|
||||
outputScriptCount++
|
||||
script[0] = 0
|
||||
script[1] = 20
|
||||
return script, nil
|
||||
},
|
||||
FeeEstimator: estimator,
|
||||
|
@ -330,7 +329,8 @@ func assertTxSweepsInputs(t *testing.T, sweepTx *wire.MsgTx,
|
|||
// NOTE: This assumes that transactions only have one output, as this is the
|
||||
// only type of transaction the UtxoSweeper can create at the moment.
|
||||
func assertTxFeeRate(t *testing.T, tx *wire.MsgTx,
|
||||
expectedFeeRate chainfee.SatPerKWeight, inputs ...input.Input) {
|
||||
expectedFeeRate chainfee.SatPerKWeight, changePk []byte,
|
||||
inputs ...input.Input) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
|
@ -355,7 +355,9 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx,
|
|||
outputAmt := tx.TxOut[0].Value
|
||||
|
||||
fee := btcutil.Amount(inputAmt - outputAmt)
|
||||
_, estimator := getWeightEstimate(inputs, nil, 0, nil)
|
||||
_, estimator, err := getWeightEstimate(inputs, nil, 0, changePk)
|
||||
require.NoError(t, err)
|
||||
|
||||
txWeight := estimator.weight()
|
||||
|
||||
expectedFee := expectedFeeRate.FeeForWeight(int64(txWeight))
|
||||
|
@ -1092,14 +1094,19 @@ func TestDifferentFeePreferences(t *testing.T) {
|
|||
// transactions to be broadcast in order of high to low fee preference.
|
||||
ctx.tick()
|
||||
|
||||
// Generate the same type of sweep script that was used for weight
|
||||
// estimation.
|
||||
changePk, err := ctx.sweeper.cfg.GenSweepScript()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The first transaction broadcast should be the one spending the higher
|
||||
// fee rate inputs.
|
||||
sweepTx1 := ctx.receiveTx()
|
||||
assertTxFeeRate(t, &sweepTx1, highFeeRate, input1, input2)
|
||||
assertTxFeeRate(t, &sweepTx1, highFeeRate, changePk, input1, input2)
|
||||
|
||||
// The second should be the one spending the lower fee rate inputs.
|
||||
sweepTx2 := ctx.receiveTx()
|
||||
assertTxFeeRate(t, &sweepTx2, lowFeeRate, input3)
|
||||
assertTxFeeRate(t, &sweepTx2, lowFeeRate, changePk, input3)
|
||||
|
||||
// With the transactions broadcast, we'll mine a block to so that the
|
||||
// result is delivered to each respective client.
|
||||
|
@ -1218,10 +1225,15 @@ func TestBumpFeeRBF(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate the same type of change script used so we can have accurate
|
||||
// weight estimation.
|
||||
changePk, err := ctx.sweeper.cfg.GenSweepScript()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure that a transaction is broadcast with the lower fee preference.
|
||||
ctx.tick()
|
||||
lowFeeTx := ctx.receiveTx()
|
||||
assertTxFeeRate(t, &lowFeeTx, lowFeeRate, &input)
|
||||
assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input)
|
||||
|
||||
// We'll then attempt to bump its fee rate.
|
||||
highFeePref := FeePreference{ConfTarget: 6}
|
||||
|
@ -1242,7 +1254,7 @@ func TestBumpFeeRBF(t *testing.T) {
|
|||
// A higher fee rate transaction should be immediately broadcast.
|
||||
ctx.tick()
|
||||
highFeeTx := ctx.receiveTx()
|
||||
assertTxFeeRate(t, &highFeeTx, highFeeRate, &input)
|
||||
assertTxFeeRate(t, &highFeeTx, highFeeRate, changePk, &input)
|
||||
|
||||
// We'll finish our test by mining the sweep transaction.
|
||||
ctx.backend.mine()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package sweep
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -139,9 +140,13 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
|
|||
feePerKw chainfee.SatPerKWeight, signer input.Signer) (*wire.MsgTx,
|
||||
error) {
|
||||
|
||||
inputs, estimator := getWeightEstimate(
|
||||
inputs, estimator, err := getWeightEstimate(
|
||||
inputs, outputs, feePerKw, changePkScript,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
txFee := estimator.fee()
|
||||
|
||||
var (
|
||||
|
@ -315,7 +320,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
|
|||
// Additionally, it returns counts for the number of csv and cltv inputs.
|
||||
func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
|
||||
feeRate chainfee.SatPerKWeight, outputPkScript []byte) ([]input.Input,
|
||||
*weightEstimator) {
|
||||
*weightEstimator, error) {
|
||||
|
||||
// We initialize a weight estimator so we can accurately asses the
|
||||
// amount of fees we need to pay for this sweep transaction.
|
||||
|
@ -337,10 +342,25 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
|
|||
// change output to the weight estimate regardless, since the estimated
|
||||
// fee will just be subtracted from this already dust output, and
|
||||
// trimmed.
|
||||
if txscript.IsPayToTaproot(outputPkScript) {
|
||||
switch {
|
||||
case txscript.IsPayToTaproot(outputPkScript):
|
||||
weightEstimate.addP2TROutput()
|
||||
} else {
|
||||
|
||||
case txscript.IsPayToWitnessScriptHash(outputPkScript):
|
||||
weightEstimate.addP2WSHOutput()
|
||||
|
||||
case txscript.IsPayToWitnessPubKeyHash(outputPkScript):
|
||||
weightEstimate.addP2WKHOutput()
|
||||
|
||||
case txscript.IsPayToPubKeyHash(outputPkScript):
|
||||
weightEstimate.estimator.AddP2PKHOutput()
|
||||
|
||||
case txscript.IsPayToScriptHash(outputPkScript):
|
||||
weightEstimate.estimator.AddP2SHOutput()
|
||||
|
||||
default:
|
||||
// Unknown script type.
|
||||
return nil, nil, errors.New("unknown script type")
|
||||
}
|
||||
|
||||
// For each output, use its witness type to determine the estimate
|
||||
|
@ -368,7 +388,7 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
|
|||
sweepInputs = append(sweepInputs, inp)
|
||||
}
|
||||
|
||||
return sweepInputs, weightEstimate
|
||||
return sweepInputs, weightEstimate, nil
|
||||
}
|
||||
|
||||
// inputSummary returns a string containing a human readable summary about the
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -39,7 +40,19 @@ func TestWeightEstimate(t *testing.T) {
|
|||
))
|
||||
}
|
||||
|
||||
_, estimator := getWeightEstimate(inputs, nil, 0, nil)
|
||||
// Create a sweep script that is always fed into the weight estimator,
|
||||
// regardless if it's actually included in the tx. It will be a P2WKH
|
||||
// script.
|
||||
changePkScript := []byte{
|
||||
0x00, 0x14,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
_, estimator, err := getWeightEstimate(inputs, nil, 0, changePkScript)
|
||||
require.NoError(t, err)
|
||||
|
||||
weight := int64(estimator.weight())
|
||||
if weight != expectedWeight {
|
||||
t.Fatalf("unexpected weight. expected %d but got %d.",
|
||||
|
@ -51,3 +64,97 @@ func TestWeightEstimate(t *testing.T) {
|
|||
expectedSummary, summary)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWeightEstimatorUnknownScript tests that the weight estimator fails when
|
||||
// given an unknown script and succeeds when given a known script.
|
||||
func TestWeightEstimatorUnknownScript(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pkscript []byte
|
||||
expectFail bool
|
||||
}{
|
||||
{
|
||||
name: "p2tr output",
|
||||
pkscript: []byte{
|
||||
0x51, 0x20,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "p2wsh output",
|
||||
pkscript: []byte{
|
||||
0x00, 0x20,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "p2wkh output",
|
||||
pkscript: []byte{
|
||||
0x00, 0x14,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "p2pkh output",
|
||||
pkscript: []byte{
|
||||
0x76, 0xa9, 0x14,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x88, 0xac,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "p2sh output",
|
||||
pkscript: []byte{
|
||||
0xa9, 0x14,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x87,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown script",
|
||||
pkscript: []byte{0x00},
|
||||
expectFail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testUnknownScriptInner(
|
||||
t, test.pkscript, test.expectFail,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUnknownScriptInner(t *testing.T, pkscript []byte, expectFail bool) {
|
||||
var inputs []input.Input
|
||||
for i, witnessType := range witnessTypes {
|
||||
inputs = append(inputs, input.NewBaseInput(
|
||||
&wire.OutPoint{
|
||||
Hash: chainhash.Hash{byte(i)},
|
||||
Index: uint32(i) + 10,
|
||||
}, witnessType,
|
||||
&input.SignDescriptor{}, 0,
|
||||
))
|
||||
}
|
||||
|
||||
_, _, err := getWeightEstimate(inputs, nil, 0, pkscript)
|
||||
if expectFail {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,6 +83,12 @@ func (w *weightEstimator) addP2TROutput() {
|
|||
w.estimator.AddP2TROutput()
|
||||
}
|
||||
|
||||
// addP2WSHOutput updates the weight estimate to account for an additional
|
||||
// segwit v0 P2WSH output.
|
||||
func (w *weightEstimator) addP2WSHOutput() {
|
||||
w.estimator.AddP2WSHOutput()
|
||||
}
|
||||
|
||||
// addOutput updates the weight estimate to account for the known
|
||||
// output given.
|
||||
func (w *weightEstimator) addOutput(txOut *wire.TxOut) {
|
||||
|
|
Loading…
Add table
Reference in a new issue