sweep: account for all script types in craftSweepTx

With this change, transactions created via craftSweepTx will be
standard. Previously, p2wsh/p2pkh scripts passed in via SendCoins would
be weighted as p2wpkh scripts. With a feerate of 1 sat/vbyte,
transactions returned would be non-standard. Luckily, the critical
sweeper subsystem only used p2wpkh scripts so this only affected
callers from the rpcserver.

Also added is an integration test that fails if SendCoins manages
to generate a non-standard transaction. All script types are now
accounted for in getWeightEstimate, which now errors if an unknown
script type is passed in.
This commit is contained in:
eugene 2022-07-15 16:25:31 -04:00
parent 6060e05d7c
commit 6ec2826f6c
No known key found for this signature in database
GPG key ID: 118759E83439A9B1
6 changed files with 311 additions and 15 deletions

View file

@ -0,0 +1,147 @@
package itest
import (
"context"
"testing"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/stretchr/testify/require"
)
func testNonstdSweep(net *lntest.NetworkHarness, t *harnessTest) {
p2shAddr, err := btcutil.NewAddressScriptHash(
make([]byte, 1), harnessNetParams,
)
require.NoError(t.t, err)
p2pkhAddr, err := btcutil.NewAddressPubKeyHash(
make([]byte, 20), harnessNetParams,
)
require.NoError(t.t, err)
p2wshAddr, err := btcutil.NewAddressWitnessScriptHash(
make([]byte, 32), harnessNetParams,
)
require.NoError(t.t, err)
p2wkhAddr, err := btcutil.NewAddressWitnessPubKeyHash(
make([]byte, 20), harnessNetParams,
)
require.NoError(t.t, err)
p2trAddr, err := btcutil.NewAddressTaproot(
make([]byte, 32), harnessNetParams,
)
require.NoError(t.t, err)
tests := []struct {
name string
address string
}{
{
name: "p2sh SendCoins standardness",
address: p2shAddr.EncodeAddress(),
},
{
name: "p2pkh SendCoins standardness",
address: p2pkhAddr.EncodeAddress(),
},
{
name: "p2wsh SendCoins standardness",
address: p2wshAddr.EncodeAddress(),
},
{
name: "p2wkh SendCoins standardness",
address: p2wkhAddr.EncodeAddress(),
},
{
name: "p2tr SendCoins standardness",
address: p2trAddr.EncodeAddress(),
},
}
for _, test := range tests {
test := test
success := t.t.Run(test.name, func(t *testing.T) {
h := newHarnessTest(t, net)
testNonStdSweepInner(net, h, test.address)
})
if !success {
break
}
}
}
func testNonStdSweepInner(net *lntest.NetworkHarness, t *harnessTest,
address string) {
ctxb := context.Background()
carol := net.NewNode(t.t, "carol", nil)
// Give Carol a UTXO so SendCoins will behave as expected.
net.SendCoins(t.t, btcutil.SatoshiPerBitcoin, carol)
// Set the fee estimate to 1sat/vbyte.
net.SetFeeEstimate(250)
// Make Carol call SendCoins with the SendAll flag and the created
// address.
sendReq := &lnrpc.SendCoinsRequest{
Addr: address,
SatPerVbyte: 1,
SendAll: true,
}
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
// If a non-standard transaction was created, then this SendCoins call
// will fail.
_, err := carol.SendCoins(ctxt, sendReq)
require.NoError(t.t, err)
// Fetch the txid so we can grab the raw transaction.
txid, err := waitForTxInMempool(net.Miner.Client, minerMempoolTimeout)
require.NoError(t.t, err)
tx, err := net.Miner.Client.GetRawTransaction(txid)
require.NoError(t.t, err)
msgTx := tx.MsgTx()
// Fetch the fee of the transaction.
var (
inputVal int
outputVal int
fee int
)
for _, inp := range msgTx.TxIn {
// Fetch the previous outpoint's value.
prevOut := inp.PreviousOutPoint
ptx, err := net.Miner.Client.GetRawTransaction(&prevOut.Hash)
require.NoError(t.t, err)
pout := ptx.MsgTx().TxOut[prevOut.Index]
inputVal += int(pout.Value)
}
for _, outp := range msgTx.TxOut {
outputVal += int(outp.Value)
}
fee = inputVal - outputVal
// Fetch the vsize of the transaction so we can determine if the
// transaction pays >= 1 sat/vbyte.
rawTx, err := net.Miner.Client.GetRawTransactionVerbose(txid)
require.NoError(t.t, err)
// Require fee >= vbytes.
require.True(t.t, fee >= int(rawTx.Vsize))
}

View file

@ -427,4 +427,8 @@ var allTestCases = []*testCase{
name: "scid alias upgrade",
test: testOptionScidUpgrade,
},
{
name: "nonstd sweep",
test: testNonstdSweep,
},
}

View file

@ -124,7 +124,6 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
timeoutChan: make(chan chan time.Time, 1),
}
var outputScriptCount byte
ctx.sweeper = New(&UtxoSweeperConfig{
Notifier: notifier,
Wallet: backend,
@ -137,8 +136,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
Signer: &mock.DummySigner{},
GenSweepScript: func() ([]byte, error) {
script := make([]byte, input.P2WPKHSize)
script[0] = outputScriptCount
outputScriptCount++
script[0] = 0
script[1] = 20
return script, nil
},
FeeEstimator: estimator,
@ -330,7 +329,8 @@ func assertTxSweepsInputs(t *testing.T, sweepTx *wire.MsgTx,
// NOTE: This assumes that transactions only have one output, as this is the
// only type of transaction the UtxoSweeper can create at the moment.
func assertTxFeeRate(t *testing.T, tx *wire.MsgTx,
expectedFeeRate chainfee.SatPerKWeight, inputs ...input.Input) {
expectedFeeRate chainfee.SatPerKWeight, changePk []byte,
inputs ...input.Input) {
t.Helper()
@ -355,7 +355,9 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx,
outputAmt := tx.TxOut[0].Value
fee := btcutil.Amount(inputAmt - outputAmt)
_, estimator := getWeightEstimate(inputs, nil, 0, nil)
_, estimator, err := getWeightEstimate(inputs, nil, 0, changePk)
require.NoError(t, err)
txWeight := estimator.weight()
expectedFee := expectedFeeRate.FeeForWeight(int64(txWeight))
@ -1092,14 +1094,19 @@ func TestDifferentFeePreferences(t *testing.T) {
// transactions to be broadcast in order of high to low fee preference.
ctx.tick()
// Generate the same type of sweep script that was used for weight
// estimation.
changePk, err := ctx.sweeper.cfg.GenSweepScript()
require.NoError(t, err)
// The first transaction broadcast should be the one spending the higher
// fee rate inputs.
sweepTx1 := ctx.receiveTx()
assertTxFeeRate(t, &sweepTx1, highFeeRate, input1, input2)
assertTxFeeRate(t, &sweepTx1, highFeeRate, changePk, input1, input2)
// The second should be the one spending the lower fee rate inputs.
sweepTx2 := ctx.receiveTx()
assertTxFeeRate(t, &sweepTx2, lowFeeRate, input3)
assertTxFeeRate(t, &sweepTx2, lowFeeRate, changePk, input3)
// With the transactions broadcast, we'll mine a block to so that the
// result is delivered to each respective client.
@ -1218,10 +1225,15 @@ func TestBumpFeeRBF(t *testing.T) {
t.Fatal(err)
}
// Generate the same type of change script used so we can have accurate
// weight estimation.
changePk, err := ctx.sweeper.cfg.GenSweepScript()
require.NoError(t, err)
// Ensure that a transaction is broadcast with the lower fee preference.
ctx.tick()
lowFeeTx := ctx.receiveTx()
assertTxFeeRate(t, &lowFeeTx, lowFeeRate, &input)
assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input)
// We'll then attempt to bump its fee rate.
highFeePref := FeePreference{ConfTarget: 6}
@ -1242,7 +1254,7 @@ func TestBumpFeeRBF(t *testing.T) {
// A higher fee rate transaction should be immediately broadcast.
ctx.tick()
highFeeTx := ctx.receiveTx()
assertTxFeeRate(t, &highFeeTx, highFeeRate, &input)
assertTxFeeRate(t, &highFeeTx, highFeeRate, changePk, &input)
// We'll finish our test by mining the sweep transaction.
ctx.backend.mine()

View file

@ -1,6 +1,7 @@
package sweep
import (
"errors"
"fmt"
"sort"
"strings"
@ -139,9 +140,13 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
feePerKw chainfee.SatPerKWeight, signer input.Signer) (*wire.MsgTx,
error) {
inputs, estimator := getWeightEstimate(
inputs, estimator, err := getWeightEstimate(
inputs, outputs, feePerKw, changePkScript,
)
if err != nil {
return nil, err
}
txFee := estimator.fee()
var (
@ -315,7 +320,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
// Additionally, it returns counts for the number of csv and cltv inputs.
func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
feeRate chainfee.SatPerKWeight, outputPkScript []byte) ([]input.Input,
*weightEstimator) {
*weightEstimator, error) {
// We initialize a weight estimator so we can accurately asses the
// amount of fees we need to pay for this sweep transaction.
@ -337,10 +342,25 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
// change output to the weight estimate regardless, since the estimated
// fee will just be subtracted from this already dust output, and
// trimmed.
if txscript.IsPayToTaproot(outputPkScript) {
switch {
case txscript.IsPayToTaproot(outputPkScript):
weightEstimate.addP2TROutput()
} else {
case txscript.IsPayToWitnessScriptHash(outputPkScript):
weightEstimate.addP2WSHOutput()
case txscript.IsPayToWitnessPubKeyHash(outputPkScript):
weightEstimate.addP2WKHOutput()
case txscript.IsPayToPubKeyHash(outputPkScript):
weightEstimate.estimator.AddP2PKHOutput()
case txscript.IsPayToScriptHash(outputPkScript):
weightEstimate.estimator.AddP2SHOutput()
default:
// Unknown script type.
return nil, nil, errors.New("unknown script type")
}
// For each output, use its witness type to determine the estimate
@ -368,7 +388,7 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
sweepInputs = append(sweepInputs, inp)
}
return sweepInputs, weightEstimate
return sweepInputs, weightEstimate, nil
}
// inputSummary returns a string containing a human readable summary about the

View file

@ -6,6 +6,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/stretchr/testify/require"
)
var (
@ -39,7 +40,19 @@ func TestWeightEstimate(t *testing.T) {
))
}
_, estimator := getWeightEstimate(inputs, nil, 0, nil)
// Create a sweep script that is always fed into the weight estimator,
// regardless if it's actually included in the tx. It will be a P2WKH
// script.
changePkScript := []byte{
0x00, 0x14,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
}
_, estimator, err := getWeightEstimate(inputs, nil, 0, changePkScript)
require.NoError(t, err)
weight := int64(estimator.weight())
if weight != expectedWeight {
t.Fatalf("unexpected weight. expected %d but got %d.",
@ -51,3 +64,97 @@ func TestWeightEstimate(t *testing.T) {
expectedSummary, summary)
}
}
// TestWeightEstimatorUnknownScript tests that the weight estimator fails when
// given an unknown script and succeeds when given a known script.
func TestWeightEstimatorUnknownScript(t *testing.T) {
tests := []struct {
name string
pkscript []byte
expectFail bool
}{
{
name: "p2tr output",
pkscript: []byte{
0x51, 0x20,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
{
name: "p2wsh output",
pkscript: []byte{
0x00, 0x20,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
{
name: "p2wkh output",
pkscript: []byte{
0x00, 0x14,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
},
{
name: "p2pkh output",
pkscript: []byte{
0x76, 0xa9, 0x14,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x88, 0xac,
},
},
{
name: "p2sh output",
pkscript: []byte{
0xa9, 0x14,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x87,
},
},
{
name: "unknown script",
pkscript: []byte{0x00},
expectFail: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testUnknownScriptInner(
t, test.pkscript, test.expectFail,
)
})
}
}
func testUnknownScriptInner(t *testing.T, pkscript []byte, expectFail bool) {
var inputs []input.Input
for i, witnessType := range witnessTypes {
inputs = append(inputs, input.NewBaseInput(
&wire.OutPoint{
Hash: chainhash.Hash{byte(i)},
Index: uint32(i) + 10,
}, witnessType,
&input.SignDescriptor{}, 0,
))
}
_, _, err := getWeightEstimate(inputs, nil, 0, pkscript)
if expectFail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}

View file

@ -83,6 +83,12 @@ func (w *weightEstimator) addP2TROutput() {
w.estimator.AddP2TROutput()
}
// addP2WSHOutput updates the weight estimate to account for an additional
// segwit v0 P2WSH output.
func (w *weightEstimator) addP2WSHOutput() {
w.estimator.AddP2WSHOutput()
}
// addOutput updates the weight estimate to account for the known
// output given.
func (w *weightEstimator) addOutput(txOut *wire.TxOut) {