mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 18:10:34 +01:00
2204cbfd30
Our OpenChannel RPC was accepting invalid values for the closing address field: if we were able to decode the address, we would use it in the script even when the address belonged to a different Bitcoin network.
380 lines
9.4 KiB
Go
380 lines
9.4 KiB
Go
package chancloser
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/btcsuite/btcd/btcutil"
|
|
"github.com/btcsuite/btcd/chaincfg"
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
"github.com/btcsuite/btcd/txscript"
|
|
"github.com/btcsuite/btcd/wire"
|
|
"github.com/lightningnetwork/lnd/channeldb"
|
|
"github.com/lightningnetwork/lnd/input"
|
|
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestMaybeMatchScript tests that the maybeMatchScript errors appropriately
|
|
// when an upfront shutdown script is set and the script provided does not
|
|
// match, and does not error in any other case.
|
|
func TestMaybeMatchScript(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pubHash := bytes.Repeat([]byte{0x0}, 20)
|
|
scriptHash := bytes.Repeat([]byte{0x0}, 32)
|
|
|
|
p2wkh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
|
|
AddData(pubHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
p2wsh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
|
|
AddData(scriptHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
p2tr, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
|
|
AddData(scriptHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
p2OtherV1, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
|
|
AddData(pubHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
invalidFork, err := txscript.NewScriptBuilder().AddOp(txscript.OP_NOP).
|
|
AddData(scriptHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
type testCase struct {
|
|
name string
|
|
shutdownScript lnwire.DeliveryAddress
|
|
upfrontScript lnwire.DeliveryAddress
|
|
expectedErr error
|
|
}
|
|
tests := []testCase{
|
|
{
|
|
name: "no upfront shutdown set, script ok",
|
|
shutdownScript: p2wkh,
|
|
upfrontScript: []byte{},
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "upfront shutdown set, script ok",
|
|
shutdownScript: p2wkh,
|
|
upfrontScript: p2wkh,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "upfront shutdown set, script not ok",
|
|
shutdownScript: p2wkh,
|
|
upfrontScript: p2wsh,
|
|
expectedErr: ErrUpfrontShutdownScriptMismatch,
|
|
},
|
|
{
|
|
name: "nil shutdown and empty upfront",
|
|
shutdownScript: nil,
|
|
upfrontScript: []byte{},
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "p2tr is ok",
|
|
shutdownScript: p2tr,
|
|
},
|
|
{
|
|
name: "segwit v1 is ok",
|
|
shutdownScript: p2OtherV1,
|
|
},
|
|
{
|
|
name: "invalid script not allowed",
|
|
shutdownScript: invalidFork,
|
|
expectedErr: ErrInvalidShutdownScript,
|
|
},
|
|
}
|
|
|
|
// All future segwit softforks should also be ok.
|
|
futureForks := []byte{
|
|
txscript.OP_1, txscript.OP_2, txscript.OP_3, txscript.OP_4,
|
|
txscript.OP_5, txscript.OP_6, txscript.OP_7, txscript.OP_8,
|
|
txscript.OP_9, txscript.OP_10, txscript.OP_11, txscript.OP_12,
|
|
txscript.OP_13, txscript.OP_14, txscript.OP_15, txscript.OP_16,
|
|
}
|
|
for _, witnessVersion := range futureForks {
|
|
p2FutureFork, err := txscript.NewScriptBuilder().AddOp(witnessVersion).
|
|
AddData(scriptHash).Script()
|
|
require.NoError(t, err)
|
|
|
|
opString, err := txscript.DisasmString([]byte{witnessVersion})
|
|
require.NoError(t, err)
|
|
|
|
tests = append(tests, testCase{
|
|
name: fmt.Sprintf("witness_version=%v", opString),
|
|
shutdownScript: p2FutureFork,
|
|
})
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
err := validateShutdownScript(
|
|
func() error { return nil }, test.upfrontScript,
|
|
test.shutdownScript, &chaincfg.SimNetParams,
|
|
)
|
|
|
|
if err != test.expectedErr {
|
|
t.Fatalf("Error: %v, expected error: %v", err, test.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type mockChannel struct {
|
|
chanPoint wire.OutPoint
|
|
initiator bool
|
|
scid lnwire.ShortChannelID
|
|
}
|
|
|
|
func (m *mockChannel) ChannelPoint() *wire.OutPoint {
|
|
return &m.chanPoint
|
|
}
|
|
|
|
func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockChannel) IsInitiator() bool {
|
|
return m.initiator
|
|
}
|
|
|
|
func (m *mockChannel) ShortChanID() lnwire.ShortChannelID {
|
|
return m.scid
|
|
}
|
|
|
|
func (m *mockChannel) AbsoluteThawHeight() (uint32, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (m *mockChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress {
|
|
return lnwire.DeliveryAddress{}
|
|
}
|
|
|
|
func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount,
|
|
localScript, remoteScript []byte,
|
|
) (input.Signature, *chainhash.Hash, btcutil.Amount, error) {
|
|
|
|
return nil, nil, 0, nil
|
|
}
|
|
|
|
func (m *mockChannel) CompleteCooperativeClose(localSig,
|
|
remoteSig input.Signature, localScript, remoteScript []byte,
|
|
proposedFee btcutil.Amount) (*wire.MsgTx, btcutil.Amount, error) {
|
|
|
|
return nil, 0, nil
|
|
}
|
|
|
|
func (m *mockChannel) LocalBalanceDust() bool {
|
|
return false
|
|
}
|
|
|
|
func (m *mockChannel) RemoteBalanceDust() bool {
|
|
return false
|
|
}
|
|
|
|
type mockCoopFeeEstimator struct {
|
|
targetFee btcutil.Amount
|
|
}
|
|
|
|
func (m *mockCoopFeeEstimator) EstimateFee(chanType channeldb.ChannelType,
|
|
localTxOut, remoteTxOut *wire.TxOut,
|
|
idealFeeRate chainfee.SatPerKWeight) btcutil.Amount {
|
|
|
|
return m.targetFee
|
|
}
|
|
|
|
// TestMaxFeeClamp tests that if a max fee is specified, then it's used instead
|
|
// of the default max fee multiplier.
|
|
func TestMaxFeeClamp(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const (
|
|
absoluteFeeOneSatByte = 126
|
|
absoluteFeeTenSatByte = 1265
|
|
)
|
|
|
|
tests := []struct {
|
|
name string
|
|
|
|
idealFee chainfee.SatPerKWeight
|
|
inputMaxFee chainfee.SatPerKWeight
|
|
|
|
maxFee btcutil.Amount
|
|
}{
|
|
{
|
|
// No max fee specified, we should see 3x the ideal fee.
|
|
name: "no max fee",
|
|
|
|
idealFee: chainfee.SatPerKWeight(253),
|
|
maxFee: absoluteFeeOneSatByte * defaultMaxFeeMultiplier,
|
|
},
|
|
{
|
|
// Max fee specified, this should be used in place.
|
|
name: "max fee clamp",
|
|
|
|
idealFee: chainfee.SatPerKWeight(253),
|
|
inputMaxFee: chainfee.SatPerKWeight(2530),
|
|
|
|
// We should get the resulting absolute fee based on a
|
|
// factor of 10 sat/byte (our new max fee).
|
|
maxFee: absoluteFeeTenSatByte,
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
test := test
|
|
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
channel := mockChannel{}
|
|
|
|
chanCloser := NewChanCloser(
|
|
ChanCloseCfg{
|
|
Channel: &channel,
|
|
MaxFee: test.inputMaxFee,
|
|
FeeEstimator: &SimpleCoopFeeEstimator{},
|
|
}, nil, test.idealFee, 0, nil, false,
|
|
)
|
|
|
|
// We'll call initFeeBaseline early here since we need
|
|
// the populate these internal variables.
|
|
chanCloser.initFeeBaseline()
|
|
|
|
require.Equal(t, test.maxFee, chanCloser.maxFee)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMaxFeeBailOut tests that once the negotiated fee rate rises above our
|
|
// maximum fee, we'll return an error and refuse to process a co-op close
|
|
// message.
|
|
func TestMaxFeeBailOut(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const (
|
|
absoluteFee = btcutil.Amount(1000)
|
|
idealFee = chainfee.SatPerKWeight(253)
|
|
)
|
|
|
|
for _, isInitiator := range []bool{true, false} {
|
|
isInitiator := isInitiator
|
|
|
|
t.Run(fmt.Sprintf("initiator=%v", isInitiator), func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// First, we'll make our mock channel, and use that to
|
|
// instantiate our channel closer.
|
|
closeCfg := ChanCloseCfg{
|
|
Channel: &mockChannel{
|
|
initiator: isInitiator,
|
|
},
|
|
FeeEstimator: &mockCoopFeeEstimator{
|
|
targetFee: absoluteFee,
|
|
},
|
|
MaxFee: idealFee * 2,
|
|
}
|
|
chanCloser := NewChanCloser(
|
|
closeCfg, nil, idealFee, 0, nil, false,
|
|
)
|
|
|
|
// We'll now force the channel state into the
|
|
// closeFeeNegotiation state so we can skip straight to
|
|
// the juicy part. We'll also set our last fee sent so
|
|
// we'll attempt to actually "negotiate" here.
|
|
chanCloser.state = closeFeeNegotiation
|
|
chanCloser.lastFeeProposal = absoluteFee
|
|
|
|
// Next, we'll make a ClosingSigned message that
|
|
// proposes a fee that's above the specified max fee.
|
|
//
|
|
// NOTE: We use the absoluteFee here since our mock
|
|
// always returns this fee for the CalcFee method which
|
|
// is used to translate a fee rate
|
|
// into an absolute fee amount in sats.
|
|
closeMsg := &lnwire.ClosingSigned{
|
|
FeeSatoshis: absoluteFee * 2,
|
|
}
|
|
|
|
_, _, err := chanCloser.ProcessCloseMsg(closeMsg)
|
|
|
|
switch isInitiator {
|
|
// If we're the initiator, then we expect an error at
|
|
// this point.
|
|
case true:
|
|
require.ErrorIs(t, err, ErrProposalExeceedsMaxFee)
|
|
|
|
// Otherwise, we expect things to fail for some other
|
|
// reason (invalid sig, etc).
|
|
case false:
|
|
require.NotErrorIs(t, err, ErrProposalExeceedsMaxFee)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestParseUpfrontShutdownAddress tests the we are able to parse the upfront
|
|
// shutdown address properly.
|
|
func TestParseUpfrontShutdownAddress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
testnetAddress = "tb1qdfkmwwgdaa5dnezrlhtftvmj5qn2kwgp7n0z6r"
|
|
regtestAddress = "bcrt1q09crvvuj95x5nk64wsxf5n6ky0kr8358vpx4d8"
|
|
)
|
|
|
|
tests := []struct {
|
|
name string
|
|
address string
|
|
params chaincfg.Params
|
|
expectedErr string
|
|
}{
|
|
{
|
|
name: "invalid closing address",
|
|
address: "non-valid-address",
|
|
params: chaincfg.RegressionNetParams,
|
|
expectedErr: "invalid address",
|
|
},
|
|
{
|
|
name: "closing address from another net",
|
|
address: testnetAddress,
|
|
params: chaincfg.RegressionNetParams,
|
|
expectedErr: "not a regtest address",
|
|
},
|
|
{
|
|
name: "valid p2wkh closing address",
|
|
address: regtestAddress,
|
|
params: chaincfg.RegressionNetParams,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := ParseUpfrontShutdownAddress(
|
|
tc.address, &tc.params,
|
|
)
|
|
|
|
if tc.expectedErr != "" {
|
|
require.ErrorContains(t, err, tc.expectedErr)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
}
|