lnd/lnrpc/walletrpc/walletkit_server_test.go
2024-05-25 13:37:13 +08:00

634 lines
16 KiB
Go

//go:build walletrpc
// +build walletrpc
package walletrpc
import (
"bytes"
"fmt"
"strings"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
"github.com/stretchr/testify/require"
)
// TestWitnessTypeMapping tests that the two witness type enums in the `input`
// package and the `walletrpc` package remain equal.
func TestWitnessTypeMapping(t *testing.T) {
t.Parallel()
// Tests that both enum types have the same length except the
// UNKNOWN_WITNESS type which is only present in the walletrpc
// witness type enum.
require.Equal(
t, len(allWitnessTypes), len(WitnessType_name)-1,
"number of witness types should match proto definition",
)
// Tests that the string representations of both enum types are
// equivalent.
for witnessType, witnessTypeProto := range allWitnessTypes {
// Redeclare to avoid loop variables being captured
// by func literal.
witnessType := witnessType
witnessTypeProto := witnessTypeProto
t.Run(witnessType.String(), func(tt *testing.T) {
tt.Parallel()
witnessTypeName := witnessType.String()
witnessTypeName = strings.ToUpper(witnessTypeName)
witnessTypeProtoName := witnessTypeProto.String()
witnessTypeProtoName = strings.ReplaceAll(
witnessTypeProtoName, "_", "",
)
require.Equal(
t, witnessTypeName, witnessTypeProtoName,
"mapped witness types should be named the same",
)
})
}
}
type mockCoinSelectionLocker struct {
fail bool
}
func (m *mockCoinSelectionLocker) WithCoinSelectLock(f func() error) error {
if err := f(); err != nil {
return err
}
if m.fail {
return fmt.Errorf("kek")
}
return nil
}
// TestFundPsbtCoinSelect tests that the coin selection for a PSBT template
// works as expected.
func TestFundPsbtCoinSelect(t *testing.T) {
t.Parallel()
const fundAmt = 50_000
var (
p2wkhDustLimit = lnwallet.DustLimitForSize(input.P2WPKHSize)
p2trDustLimit = lnwallet.DustLimitForSize(input.P2TRSize)
p2wkhScript, _ = input.WitnessPubKeyHash([]byte{})
p2trScript, _ = txscript.PayToTaprootScript(
&input.TaprootNUMSKey,
)
)
makePacket := func(outs ...*wire.TxOut) *psbt.Packet {
p := &psbt.Packet{
UnsignedTx: &wire.MsgTx{},
}
for _, out := range outs {
p.UnsignedTx.TxOut = append(p.UnsignedTx.TxOut, out)
p.Outputs = append(p.Outputs, psbt.POutput{})
}
return p
}
updatePacket := func(p *psbt.Packet,
f func(*psbt.Packet) *psbt.Packet) *psbt.Packet {
return f(p)
}
calcFee := func(p2trIn, p2wkhIn, p2trOut, p2wkhOut int,
dust btcutil.Amount) btcutil.Amount {
estimator := input.TxWeightEstimator{}
for i := 0; i < p2trIn; i++ {
estimator.AddTaprootKeySpendInput(
txscript.SigHashDefault,
)
}
for i := 0; i < p2wkhIn; i++ {
estimator.AddP2WKHInput()
}
for i := 0; i < p2trOut; i++ {
estimator.AddP2TROutput()
}
for i := 0; i < p2wkhOut; i++ {
estimator.AddP2WKHOutput()
}
weight := estimator.Weight()
fee := chainfee.FeePerKwFloor.FeeForWeight(weight)
return fee + dust
}
testCases := []struct {
name string
utxos []*lnwallet.Utxo
packet *psbt.Packet
changeIndex int32
changeType chanfunding.ChangeAddressType
feeRate chainfee.SatPerKWeight
// expectedUtxoIndexes is the list of utxo indexes that are
// expected to be used for funding the psbt.
expectedUtxoIndexes []int
// expectChangeOutputIndex is the expected output index that is
// returned from the tested method.
expectChangeOutputIndex int32
// expectedChangeOutputAmount is the expected final total amount
// of the output marked as the change output. This will only be
// checked if the expected amount is non-zero.
expectedChangeOutputAmount btcutil.Amount
// expectedFee is the total amount of fees paid by the funded
// packet in bytes.
expectedFee btcutil.Amount
// expectedErr is the expected concrete error. If not nil, then
// the error must match exactly.
expectedErr error
// expectedErrType is the expected error type. If not nil, then
// the error must be of this type.
expectedErrType error
}{{
name: "no utxos",
utxos: []*lnwallet.Utxo{},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
expectedErrType: &chanfunding.ErrInsufficientFunds{},
}, {
name: "1 p2wpkh utxo, add p2wkh change",
utxos: []*lnwallet.Utxo{
{
Value: 100_000,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: 1,
expectedFee: calcFee(0, 1, 1, 1, 0),
}, {
name: "1 p2wpkh utxo, add p2tr change",
utxos: []*lnwallet.Utxo{
{
Value: 100_000,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2TRChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: 1,
expectedFee: calcFee(0, 1, 2, 0, 0),
}, {
name: "1 p2wpkh utxo, no change, exact amount",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + 123,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: -1,
expectedFee: calcFee(0, 1, 1, 0, 0),
}, {
name: "1 p2wpkh utxo, no change, p2wpkh change dust to fee",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + calcFee(
0, 1, 1, 0, p2wkhDustLimit-1,
),
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2WKHChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: -1,
expectedFee: calcFee(0, 1, 1, 0, p2wkhDustLimit-1),
}, {
name: "1 p2wpkh utxo, no change, p2tr change dust to fee",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + calcFee(
0, 1, 1, 0, p2trDustLimit-1,
),
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2TRChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: -1,
expectedFee: calcFee(0, 1, 1, 0, p2trDustLimit-1),
}, {
name: "1 p2wpkh utxo, existing p2tr change",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + 50_000,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}),
changeIndex: 0,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.ExistingChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: 0,
expectedFee: calcFee(0, 1, 1, 0, 0),
}, {
name: "1 p2wpkh utxo, existing p2wkh change",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + 50_000,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2wkhScript,
}),
changeIndex: 0,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.ExistingChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: 0,
expectedFee: calcFee(0, 1, 0, 1, 0),
}, {
name: "1 p2wpkh utxo, existing p2wkh change, dust change",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt + calcFee(0, 1, 0, 1, 0) + 50,
PkScript: p2wkhScript,
},
},
packet: makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2wkhScript,
}),
changeIndex: 0,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.ExistingChangeAddress,
expectedUtxoIndexes: []int{0},
expectChangeOutputIndex: 0,
expectedFee: calcFee(0, 1, 0, 1, 0),
}, {
name: "1 p2wpkh + 1 p2tr utxo, existing p2tr input, existing " +
"p2tr change",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt / 2,
PkScript: p2wkhScript,
}, {
Value: fundAmt / 2,
PkScript: p2trScript,
},
},
packet: updatePacket(makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}), func(p *psbt.Packet) *psbt.Packet {
p.UnsignedTx.TxIn = append(
p.UnsignedTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{1, 2, 3},
},
},
)
p2TrDerivations := []*psbt.TaprootBip32Derivation{
{
XOnlyPubKey: schnorr.SerializePubKey(
&input.TaprootNUMSKey,
),
Bip32Path: []uint32{1, 2, 3},
},
}
p.Inputs = append(p.Inputs, psbt.PInput{
WitnessUtxo: &wire.TxOut{
Value: 1000,
PkScript: p2trScript,
},
SighashType: txscript.SigHashSingle,
TaprootBip32Derivation: p2TrDerivations,
})
return p
}),
changeIndex: 0,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.ExistingChangeAddress,
expectedUtxoIndexes: []int{0, 1},
expectChangeOutputIndex: 0,
expectedFee: calcFee(2, 1, 1, 0, 0),
}, {
name: "1 p2wpkh + 1 p2tr utxo, existing p2tr input, add p2tr " +
"change",
utxos: []*lnwallet.Utxo{
{
Value: fundAmt / 2,
PkScript: p2wkhScript,
}, {
Value: fundAmt / 2,
PkScript: p2trScript,
},
},
packet: updatePacket(makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}), func(p *psbt.Packet) *psbt.Packet {
p.UnsignedTx.TxIn = append(
p.UnsignedTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{1, 2, 3},
},
},
)
p2TrDerivations := []*psbt.TaprootBip32Derivation{
{
XOnlyPubKey: schnorr.SerializePubKey(
&input.TaprootNUMSKey,
),
Bip32Path: []uint32{1, 2, 3},
},
}
p.Inputs = append(p.Inputs, psbt.PInput{
WitnessUtxo: &wire.TxOut{
Value: 1000,
PkScript: p2trScript,
},
SighashType: txscript.SigHashSingle,
TaprootBip32Derivation: p2TrDerivations,
})
return p
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2TRChangeAddress,
expectedUtxoIndexes: []int{0, 1},
expectChangeOutputIndex: 1,
expectedFee: calcFee(2, 1, 2, 0, 0),
}, {
name: "large existing p2tr input, fee estimation p2wpkh " +
"change",
utxos: []*lnwallet.Utxo{},
packet: updatePacket(makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}), func(p *psbt.Packet) *psbt.Packet {
p.UnsignedTx.TxIn = append(
p.UnsignedTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{1, 2, 3},
},
},
)
p2TrDerivations := []*psbt.TaprootBip32Derivation{
{
XOnlyPubKey: schnorr.SerializePubKey(
&input.TaprootNUMSKey,
),
Bip32Path: []uint32{1, 2, 3},
},
}
p.Inputs = append(p.Inputs, psbt.PInput{
WitnessUtxo: &wire.TxOut{
Value: fundAmt * 3,
PkScript: p2trScript,
},
TaprootBip32Derivation: p2TrDerivations,
})
return p
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2WKHChangeAddress,
expectedUtxoIndexes: []int{},
expectChangeOutputIndex: 1,
expectedChangeOutputAmount: fundAmt*3 - fundAmt -
calcFee(1, 0, 1, 1, 0),
expectedFee: calcFee(1, 0, 1, 1, 0),
}, {
name: "large existing p2tr input, fee estimation no change",
utxos: []*lnwallet.Utxo{},
packet: updatePacket(makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}), func(p *psbt.Packet) *psbt.Packet {
p.UnsignedTx.TxIn = append(
p.UnsignedTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{1, 2, 3},
},
},
)
p2TrDerivations := []*psbt.TaprootBip32Derivation{
{
XOnlyPubKey: schnorr.SerializePubKey(
&input.TaprootNUMSKey,
),
Bip32Path: []uint32{1, 2, 3},
},
}
p.Inputs = append(p.Inputs, psbt.PInput{
WitnessUtxo: &wire.TxOut{
Value: fundAmt +
int64(calcFee(1, 0, 1, 0, 0)),
PkScript: p2trScript,
},
TaprootBip32Derivation: p2TrDerivations,
})
return p
}),
changeIndex: -1,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.P2TRChangeAddress,
expectedUtxoIndexes: []int{},
expectChangeOutputIndex: -1,
expectedFee: calcFee(1, 0, 1, 0, 0),
}, {
name: "large existing p2tr input, fee estimation existing " +
"change output",
utxos: []*lnwallet.Utxo{},
packet: updatePacket(makePacket(&wire.TxOut{
Value: fundAmt,
PkScript: p2trScript,
}), func(p *psbt.Packet) *psbt.Packet {
p.UnsignedTx.TxIn = append(
p.UnsignedTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{1, 2, 3},
},
},
)
p2TrDerivations := []*psbt.TaprootBip32Derivation{
{
XOnlyPubKey: schnorr.SerializePubKey(
&input.TaprootNUMSKey,
),
Bip32Path: []uint32{1, 2, 3},
},
}
p.Inputs = append(p.Inputs, psbt.PInput{
WitnessUtxo: &wire.TxOut{
Value: fundAmt * 2,
PkScript: p2trScript,
},
TaprootBip32Derivation: p2TrDerivations,
})
return p
}),
changeIndex: 0,
feeRate: chainfee.FeePerKwFloor,
changeType: chanfunding.ExistingChangeAddress,
expectedUtxoIndexes: []int{},
expectChangeOutputIndex: 0,
expectedChangeOutputAmount: fundAmt*2 - calcFee(1, 0, 1, 0, 0),
expectedFee: calcFee(1, 0, 1, 0, 0),
}}
for _, tc := range testCases {
tc := tc
privKey, err := btcec.NewPrivateKey()
require.NoError(t, err)
walletMock := &mock.WalletController{
RootKey: privKey,
Utxos: tc.utxos,
}
rpcServer, _, err := New(&Config{
Wallet: walletMock,
CoinSelectionLocker: &mockCoinSelectionLocker{},
CoinSelectionStrategy: wallet.CoinSelectionLargest,
})
require.NoError(t, err)
t.Run(tc.name, func(tt *testing.T) {
// To avoid our packet being mutated, we'll make a deep
// copy of it, so we can still use the original in the
// test case to compare the results to.
var buf bytes.Buffer
err := tc.packet.Serialize(&buf)
require.NoError(tt, err)
copiedPacket, err := psbt.NewFromRawBytes(&buf, false)
require.NoError(tt, err)
resp, err := rpcServer.fundPsbtCoinSelect(
"", tc.changeIndex, copiedPacket, 0,
tc.changeType, tc.feeRate,
rpcServer.cfg.CoinSelectionStrategy,
)
switch {
case tc.expectedErr != nil:
require.Error(tt, err)
require.ErrorIs(tt, err, tc.expectedErr)
return
case tc.expectedErrType != nil:
require.Error(tt, err)
require.ErrorAs(tt, err, &tc.expectedErr)
return
}
require.NoError(tt, err)
require.NotNil(tt, resp)
resultPacket, err := psbt.NewFromRawBytes(
bytes.NewReader(resp.FundedPsbt), false,
)
require.NoError(tt, err)
resultTx := resultPacket.UnsignedTx
expectedNumInputs := len(tc.expectedUtxoIndexes) +
len(tc.packet.Inputs)
require.Len(tt, resultPacket.Inputs, expectedNumInputs)
require.Len(tt, resultTx.TxIn, expectedNumInputs)
require.Equal(
tt, tc.expectChangeOutputIndex,
resp.ChangeOutputIndex,
)
fee, err := resultPacket.GetTxFee()
require.NoError(tt, err)
require.EqualValues(tt, tc.expectedFee, fee)
if tc.expectedChangeOutputAmount != 0 {
changeIdx := resp.ChangeOutputIndex
require.GreaterOrEqual(tt, changeIdx, int32(-1))
require.Less(
tt, changeIdx,
int32(len(resultTx.TxOut)),
)
changeOut := resultTx.TxOut[changeIdx]
require.EqualValues(
tt, tc.expectedChangeOutputAmount,
changeOut.Value,
)
}
})
}
}