lnd/lnwallet/chancloser/chancloser_test.go
positiveblue 2204cbfd30
rpc: validate closing channel address in open channel requests
Our OpenChannelRPC was accepting invalid values for the closing address
field. If we were able to decode the address we would use it in the
script even if the address is for another bitcoin net.
2023-01-04 04:42:39 -08:00

380 lines
9.4 KiB
Go

package chancloser
import (
"bytes"
"fmt"
"testing"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
// TestMaybeMatchScript tests that validateShutdownScript errors appropriately
// when an upfront shutdown script is set and the provided shutdown script does
// not match it, or when the provided script is not an acceptable shutdown
// script. It must not error in any other case.
func TestMaybeMatchScript(t *testing.T) {
	t.Parallel()

	pubHash := bytes.Repeat([]byte{0x0}, 20)
	scriptHash := bytes.Repeat([]byte{0x0}, 32)

	// OP_0 <20 bytes>: a version 0 witness pubkey hash program.
	p2wkh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
		AddData(pubHash).Script()
	require.NoError(t, err)

	// OP_0 <32 bytes>: a version 0 witness script hash program.
	p2wsh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
		AddData(scriptHash).Script()
	require.NoError(t, err)

	// OP_1 <32 bytes>: a version 1 (taproot) witness program.
	p2tr, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
		AddData(scriptHash).Script()
	require.NoError(t, err)

	// OP_1 <20 bytes>: a version 1 witness program that is not taproot.
	p2OtherV1, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
		AddData(pubHash).Script()
	require.NoError(t, err)

	// OP_NOP is not a witness version opcode (OP_0..OP_16), so this
	// script is not a witness program and must be rejected.
	invalidFork, err := txscript.NewScriptBuilder().AddOp(txscript.OP_NOP).
		AddData(scriptHash).Script()
	require.NoError(t, err)

	type testCase struct {
		name           string
		shutdownScript lnwire.DeliveryAddress
		upfrontScript  lnwire.DeliveryAddress
		expectedErr    error
	}

	tests := []testCase{
		{
			name:           "no upfront shutdown set, script ok",
			shutdownScript: p2wkh,
			upfrontScript:  []byte{},
			expectedErr:    nil,
		},
		{
			name:           "upfront shutdown set, script ok",
			shutdownScript: p2wkh,
			upfrontScript:  p2wkh,
			expectedErr:    nil,
		},
		{
			name:           "upfront shutdown set, script not ok",
			shutdownScript: p2wkh,
			upfrontScript:  p2wsh,
			expectedErr:    ErrUpfrontShutdownScriptMismatch,
		},
		{
			name:           "nil shutdown and empty upfront",
			shutdownScript: nil,
			upfrontScript:  []byte{},
			expectedErr:    nil,
		},
		{
			name:           "p2tr is ok",
			shutdownScript: p2tr,
		},
		{
			name:           "segwit v1 is ok",
			shutdownScript: p2OtherV1,
		},
		{
			name:           "invalid script not allowed",
			shutdownScript: invalidFork,
			expectedErr:    ErrInvalidShutdownScript,
		},
	}

	// All future segwit softforks should also be ok.
	futureForks := []byte{
		txscript.OP_1, txscript.OP_2, txscript.OP_3, txscript.OP_4,
		txscript.OP_5, txscript.OP_6, txscript.OP_7, txscript.OP_8,
		txscript.OP_9, txscript.OP_10, txscript.OP_11, txscript.OP_12,
		txscript.OP_13, txscript.OP_14, txscript.OP_15, txscript.OP_16,
	}
	for _, witnessVersion := range futureForks {
		p2FutureFork, err := txscript.NewScriptBuilder().AddOp(witnessVersion).
			AddData(scriptHash).Script()
		require.NoError(t, err)

		opString, err := txscript.DisasmString([]byte{witnessVersion})
		require.NoError(t, err)

		tests = append(tests, testCase{
			name:           fmt.Sprintf("witness_version=%v", opString),
			shutdownScript: p2FutureFork,
		})
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			err := validateShutdownScript(
				func() error { return nil }, test.upfrontScript,
				test.shutdownScript, &chaincfg.SimNetParams,
			)

			if test.expectedErr == nil {
				require.NoError(t, err)
				return
			}

			// Use errors.Is semantics so the assertion still holds
			// if the sentinel error is wrapped with extra context.
			require.ErrorIs(t, err, test.expectedErr)
		})
	}
}
// mockChannel is a minimal test double for the channel used by the
// ChanCloser. Only the outpoint, the initiator flag and the short channel ID
// are configurable; every other method returns a fixed zero or empty result.
type mockChannel struct {
	chanPoint wire.OutPoint
	initiator bool
	scid      lnwire.ShortChannelID
}

// ChannelPoint returns a pointer to the mock's stored outpoint.
func (c *mockChannel) ChannelPoint() *wire.OutPoint {
	return &c.chanPoint
}

// MarkCoopBroadcasted is a no-op that always succeeds.
func (*mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error {
	return nil
}

// IsInitiator reports the configured initiator flag.
func (c *mockChannel) IsInitiator() bool {
	return c.initiator
}

// ShortChanID returns the configured short channel ID.
func (c *mockChannel) ShortChanID() lnwire.ShortChannelID {
	return c.scid
}

// AbsoluteThawHeight always reports a thaw height of zero and no error.
func (*mockChannel) AbsoluteThawHeight() (uint32, error) {
	return 0, nil
}

// RemoteUpfrontShutdownScript returns an empty (non-nil) delivery address.
func (*mockChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress {
	return lnwire.DeliveryAddress{}
}

// CreateCloseProposal returns nil signature and hash, a zero fee and no error.
func (*mockChannel) CreateCloseProposal(fee btcutil.Amount,
	localScript, remoteScript []byte,
) (input.Signature, *chainhash.Hash, btcutil.Amount, error) {

	return nil, nil, 0, nil
}

// CompleteCooperativeClose returns a nil transaction, a zero fee and no error.
func (*mockChannel) CompleteCooperativeClose(localSig,
	remoteSig input.Signature, localScript, remoteScript []byte,
	proposedFee btcutil.Amount) (*wire.MsgTx, btcutil.Amount, error) {

	return nil, 0, nil
}

// LocalBalanceDust always reports that the local balance is not dust.
func (*mockChannel) LocalBalanceDust() bool {
	return false
}

// RemoteBalanceDust always reports that the remote balance is not dust.
func (*mockChannel) RemoteBalanceDust() bool {
	return false
}
// mockCoopFeeEstimator is a fee estimator test double that ignores all of its
// inputs and always reports targetFee as the absolute closing fee.
type mockCoopFeeEstimator struct {
	targetFee btcutil.Amount
}

// EstimateFee returns the configured targetFee regardless of the channel
// type, the transaction outputs or the ideal fee rate supplied.
func (e *mockCoopFeeEstimator) EstimateFee(chanType channeldb.ChannelType,
	localTxOut, remoteTxOut *wire.TxOut,
	idealFeeRate chainfee.SatPerKWeight) btcutil.Amount {

	fee := e.targetFee
	return fee
}
// TestMaxFeeClamp tests that when a max fee rate is supplied in the config it
// determines the closer's max fee; otherwise the default max fee multiplier is
// applied to the ideal fee.
func TestMaxFeeClamp(t *testing.T) {
	t.Parallel()

	// Expected absolute fees (in satoshis) for fee rates of roughly
	// 1 sat/byte and 10 sat/byte respectively.
	const (
		absoluteFeeOneSatByte = 126
		absoluteFeeTenSatByte = 1265
	)

	testCases := []struct {
		name        string
		idealFee    chainfee.SatPerKWeight
		inputMaxFee chainfee.SatPerKWeight
		maxFee      btcutil.Amount
	}{
		{
			// With no max fee configured, the ideal fee scaled by
			// defaultMaxFeeMultiplier is the expected ceiling.
			name:     "no max fee",
			idealFee: chainfee.SatPerKWeight(253),
			maxFee:   absoluteFeeOneSatByte * defaultMaxFeeMultiplier,
		},
		{
			// With a max fee configured, it replaces the
			// multiplier-based ceiling: 2530 sat/kw corresponds to
			// the ~10 sat/byte absolute fee.
			name:        "max fee clamp",
			idealFee:    chainfee.SatPerKWeight(253),
			inputMaxFee: chainfee.SatPerKWeight(2530),
			maxFee:      absoluteFeeTenSatByte,
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			cfg := ChanCloseCfg{
				Channel:      &mockChannel{},
				MaxFee:       tc.inputMaxFee,
				FeeEstimator: &SimpleCoopFeeEstimator{},
			}
			closer := NewChanCloser(
				cfg, nil, tc.idealFee, 0, nil, false,
			)

			// Invoke initFeeBaseline directly so that the closer's
			// internal fee fields (including maxFee) are populated
			// before we inspect them.
			closer.initFeeBaseline()

			require.Equal(t, tc.maxFee, closer.maxFee)
		})
	}
}
// TestMaxFeeBailOut tests that once the negotiated fee rate rises above our
// maximum fee, we'll return an error and refuse to process a co-op close
// message. The test is run both as the channel initiator and as the
// non-initiator.
func TestMaxFeeBailOut(t *testing.T) {
	t.Parallel()

	const (
		absoluteFee = btcutil.Amount(1000)
		idealFee    = chainfee.SatPerKWeight(253)
	)

	for _, isInitiator := range []bool{true, false} {
		isInitiator := isInitiator

		t.Run(fmt.Sprintf("initiator=%v", isInitiator), func(t *testing.T) {
			t.Parallel()

			// First, we'll make our mock channel, and use that to
			// instantiate our channel closer.
			closeCfg := ChanCloseCfg{
				Channel: &mockChannel{
					initiator: isInitiator,
				},
				FeeEstimator: &mockCoopFeeEstimator{
					targetFee: absoluteFee,
				},
				MaxFee: idealFee * 2,
			}
			chanCloser := NewChanCloser(
				closeCfg, nil, idealFee, 0, nil, false,
			)

			// We'll now force the channel state into the
			// closeFeeNegotiation state so we can skip straight to
			// the juicy part. We'll also set our last fee sent so
			// we'll attempt to actually "negotiate" here.
			chanCloser.state = closeFeeNegotiation
			chanCloser.lastFeeProposal = absoluteFee

			// Next, we'll make a ClosingSigned message that
			// proposes a fee that's above the specified max fee.
			//
			// NOTE: We use absoluteFee as the baseline here since
			// our mock fee estimator's EstimateFee method always
			// returns it, regardless of the fee rate it is asked to
			// translate into an absolute amount in sats.
			closeMsg := &lnwire.ClosingSigned{
				FeeSatoshis: absoluteFee * 2,
			}
			_, _, err := chanCloser.ProcessCloseMsg(closeMsg)

			// If we're the initiator, then we expect the proposal to
			// be rejected for exceeding our max fee.
			if isInitiator {
				require.ErrorIs(t, err, ErrProposalExeceedsMaxFee)
				return
			}

			// Otherwise, we expect things to fail for some other
			// reason (invalid sig, etc).
			require.NotErrorIs(t, err, ErrProposalExeceedsMaxFee)
		})
	}
}
// TestParseUpfrontShutdownAddress tests that we are able to parse the upfront
// shutdown address properly. Undecodable addresses and addresses that belong
// to a different network than the one supplied must be rejected.
func TestParseUpfrontShutdownAddress(t *testing.T) {
	t.Parallel()

	var (
		testnetAddress = "tb1qdfkmwwgdaa5dnezrlhtftvmj5qn2kwgp7n0z6r"
		regtestAddress = "bcrt1q09crvvuj95x5nk64wsxf5n6ky0kr8358vpx4d8"
	)

	// params points at the chaincfg package globals rather than holding a
	// copy, which avoids duplicating the large Params struct per case.
	tests := []struct {
		name        string
		address     string
		params      *chaincfg.Params
		expectedErr string
	}{
		{
			name:        "invalid closing address",
			address:     "non-valid-address",
			params:      &chaincfg.RegressionNetParams,
			expectedErr: "invalid address",
		},
		{
			name:        "closing address from another net",
			address:     testnetAddress,
			params:      &chaincfg.RegressionNetParams,
			expectedErr: "not a regtest address",
		},
		{
			name:    "valid p2wkh closing address",
			address: regtestAddress,
			params:  &chaincfg.RegressionNetParams,
		},
		{
			// The same testnet address that is rejected on regtest
			// above must be accepted on its own network.
			name:    "valid p2wkh closing address on testnet",
			address: testnetAddress,
			params:  &chaincfg.TestNet3Params,
		},
	}

	for _, tc := range tests {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			_, err := ParseUpfrontShutdownAddress(
				tc.address, tc.params,
			)

			if tc.expectedErr != "" {
				require.ErrorContains(t, err, tc.expectedErr)
				return
			}

			require.NoError(t, err)
		})
	}
}