invoices+htlcswitch: add tests for relaxed link and invoice checks

This commit is contained in:
Keagan McClelland 2023-06-16 14:47:29 -06:00
parent 1b1eedb434
commit 36bf471a1f
2 changed files with 253 additions and 65 deletions

View file

@ -12,6 +12,7 @@ import (
"runtime" "runtime"
"sync" "sync"
"testing" "testing"
"testing/quick"
"time" "time"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
@ -19,7 +20,6 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion" sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
@ -125,7 +125,7 @@ func createInterceptorFunc(prefix, receiver string, messages []expectedMessage,
if messageChanID == chanID { if messageChanID == chanID {
if len(expectToReceive) == 0 { if len(expectToReceive) == 0 {
return false, errors.Errorf("%v received "+ return false, fmt.Errorf("%v received "+
"unexpected message out of range: %v", "unexpected message out of range: %v",
receiver, m.MsgType()) receiver, m.MsgType())
} }
@ -134,9 +134,13 @@ func createInterceptorFunc(prefix, receiver string, messages []expectedMessage,
expectToReceive = expectToReceive[1:] expectToReceive = expectToReceive[1:]
if expectedMessage.message.MsgType() != m.MsgType() { if expectedMessage.message.MsgType() != m.MsgType() {
return false, errors.Errorf("%v received wrong message: \n"+ return false, fmt.Errorf(
"real: %v\nexpected: %v", receiver, m.MsgType(), "%v received wrong message: \n"+
expectedMessage.message.MsgType()) "real: %v\nexpected: %v",
receiver,
m.MsgType(),
expectedMessage.message.MsgType(),
)
} }
if debug { if debug {
@ -721,11 +725,10 @@ func TestChannelLinkCancelFullCommitment(t *testing.T) {
} }
} }
// TestExitNodeTimelockPayloadMismatch tests that when an exit node receives an // TestExitNodeHLTCTimelockExceedsPayload tests that when an exit node receives
// incoming HTLC, if the time lock encoded in the payload of the forwarded HTLC // an incoming HTLC, if the timelock of the incoming HTLC is greater than or
// doesn't match the expected payment value, then the HTLC will be rejected // equal to the timelock encoded in the payload, then the HTLC will be accepted.
// with the appropriate error. func TestExitNodeHTLCTimelockExceedsPayload(t *testing.T) {
func TestExitNodeTimelockPayloadMismatch(t *testing.T) {
t.Parallel() t.Parallel()
channels, _, err := createClusterChannels( channels, _, err := createClusterChannels(
@ -733,35 +736,75 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) {
) )
require.NoError(t, err, "unable to create channel") require.NoError(t, err, "unable to create channel")
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, n := newThreeHopNetwork(
channels.bobToCarol, channels.carolToBob, testStartingHeight) t, channels.aliceToBob, channels.bobToAlice,
if err := n.start(); err != nil { channels.bobToCarol, channels.carolToBob, testStartingHeight,
t.Fatal(err) )
} require.NoError(t, n.start())
t.Cleanup(n.stop) t.Cleanup(n.stop)
const amount = btcutil.SatoshiPerBitcoin const amount = btcutil.SatoshiPerBitcoin
htlcAmt, htlcExpiry, hops := generateHops(amount, htlcAmt, htlcExpiry, hops := generateHops(
testStartingHeight, n.firstBobChannelLink) amount, testStartingHeight, n.firstBobChannelLink,
)
// In order to exercise this case, we'll now _manually_ modify the // In order to exercise this case, we'll now _manually_ modify the
// per-hop payload for outgoing time lock to be the incorrect value. // per-hop payload for outgoing time lock to be a compatible value that
// differs from the specified expiry.
// The proper value of the outgoing CLTV should be the policy set by // The proper value of the outgoing CLTV should be the policy set by
// the receiving node, instead we set it to be a random value. // the receiving node, instead we set it to be a value less than the
hops[0].FwdInfo.OutgoingCTLV = 500 // incoming HTLC timelock.
hops[0].FwdInfo.OutgoingCTLV = htlcExpiry - 1
firstHop := n.firstBobChannelLink.ShortChanID() firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment( _, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
htlcExpiry, htlcExpiry,
).Wait(30 * time.Second) ).Wait(30 * time.Second)
if err == nil { require.NoError(t, err, "payment should have succeeded but didn't")
t.Fatalf("payment should have failed but didn't")
} }
rtErr, ok := err.(ClearTextError) // TestExitNodeTimelockPayloadExceedsHTLC tests that when an exit node receives
if !ok { // an incoming HTLC, if the timelock encoded in the payload of the forwarded
t.Fatalf("expected a ClearTextError, instead got: %T", err) // HTLC exceeds the timelock on the incoming HTLC, then the HTLC will be
} // rejected with the appropriate error.
func TestExitNodeTimelockPayloadExceedsHTLC(t *testing.T) {
t.Parallel()
channels, _, err := createClusterChannels(
t, btcutil.SatoshiPerBitcoin*5, btcutil.SatoshiPerBitcoin*5,
)
require.NoError(t, err, "unable to create channel")
n := newThreeHopNetwork(
t, channels.aliceToBob, channels.bobToAlice,
channels.bobToCarol, channels.carolToBob, testStartingHeight,
)
require.NoError(t, n.start())
t.Cleanup(n.stop)
const amount = btcutil.SatoshiPerBitcoin
htlcAmt, htlcExpiry, hops := generateHops(
amount, testStartingHeight, n.firstBobChannelLink,
)
// In order to exercise this case, we'll now _manually_ modify the
// per-hop payload for outgoing time lock to be the incorrect value.
// The proper value of the outgoing CLTV should be the policy set by
// the receiving node, instead we set it to be a value greater than the
// incoming HTLC timelock.
hops[0].FwdInfo.OutgoingCTLV = htlcExpiry + 1
firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
htlcExpiry,
).Wait(30 * time.Second)
require.NotNil(t, err, "payment should have failed but didn't")
rtErr := &ForwardingError{}
require.ErrorAs(
t, err, &rtErr, "expected a ClearTextError, instead got: %T",
err,
)
switch rtErr.WireMessage().(type) { switch rtErr.WireMessage().(type) {
case *lnwire.FailFinalIncorrectCltvExpiry: case *lnwire.FailFinalIncorrectCltvExpiry:
@ -771,43 +814,95 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) {
} }
} }
// TestExitNodeAmountPayloadMismatch tests that when an exit node receives an // TestExitNodeHTLCUnderpaysPayloadAmount tests that when an exit node receives
// incoming HTLC, if the amount encoded in the onion payload of the forwarded // an incoming HTLC, if the amount offered in the HTLC is less than the amount
// HTLC doesn't match the expected payment value, then the HTLC will be // encoded in the onion payload then the HTLC will be rejected with the
// rejected. // appropriate error.
func TestExitNodeAmountPayloadMismatch(t *testing.T) { func TestExitNodeHTLCUnderpaysPayloadAmount(t *testing.T) {
t.Parallel() t.Parallel()
f := func(underpaymentRand uint64) bool {
underpayment := lnwire.MilliSatoshi(
underpaymentRand%(btcutil.SatoshiPerBitcoin-1) + 1,
)
channels, _, err := createClusterChannels( channels, _, err := createClusterChannels(
t, btcutil.SatoshiPerBitcoin*5, btcutil.SatoshiPerBitcoin*5, t, btcutil.SatoshiPerBitcoin*5,
btcutil.SatoshiPerBitcoin*5,
) )
require.NoError(t, err, "unable to create channel") require.NoError(t, err, "unable to create channel")
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, n := newThreeHopNetwork(
channels.bobToCarol, channels.carolToBob, testStartingHeight) t, channels.aliceToBob, channels.bobToAlice,
if err := n.start(); err != nil { channels.bobToCarol, channels.carolToBob,
t.Fatal(err) testStartingHeight,
} )
require.NoError(t, n.start())
t.Cleanup(n.stop) t.Cleanup(n.stop)
const amount = btcutil.SatoshiPerBitcoin const amount = btcutil.SatoshiPerBitcoin
htlcAmt, htlcExpiry, hops := generateHops(amount, testStartingHeight, htlcAmt, htlcExpiry, hops := generateHops(
n.firstBobChannelLink) amount, testStartingHeight, n.firstBobChannelLink,
)
// In order to exercise this case, we'll now _manually_ modify the // In order to exercise this case, we'll now _manually_ modify
// per-hop payload for amount to be the incorrect value. The proper // the per-hop payload for amount to be the incorrect value.
// value of the amount to forward should be the amount that the // The acceptable values of the amount to forward should be less
// receiving node expects to receive. // than the incoming HTLC value.
hops[0].FwdInfo.AmountToForward = 1 hops[0].FwdInfo.AmountToForward = amount + underpayment
firstHop := n.firstBobChannelLink.ShortChanID() firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment( _, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, n.aliceServer, n.bobServer, firstHop, hops, amount,
htlcExpiry, htlcAmt, htlcExpiry,
).Wait(30 * time.Second) ).Wait(30 * time.Second)
if err == nil {
t.Fatalf("payment should have failed but didn't")
}
assertFailureCode(t, err, lnwire.CodeFinalIncorrectHtlcAmount) assertFailureCode(t, err, lnwire.CodeFinalIncorrectHtlcAmount)
return err != nil
}
err := quick.Check(f, &quick.Config{MaxCount: 20})
require.NoError(t, err, "payment should have failed but didn't")
}
// TestExitNodeHTLCExceedsAmountPayload tests that when an exit node receives an
// incoming HTLC, if the amount encoded in the onion payload of the forwarded
// HTLC is lower than the incoming HTLC value, then the HTLC will be accepted.
func TestExitNodeHTLCExceedsAmountPayload(t *testing.T) {
t.Parallel()
f := func(overpaymentRand uint64) bool {
overpayment := lnwire.MilliSatoshi(
overpaymentRand%(btcutil.SatoshiPerBitcoin-1) + 1,
)
channels, _, err := createClusterChannels(
t, btcutil.SatoshiPerBitcoin*5,
btcutil.SatoshiPerBitcoin*5,
)
require.NoError(t, err, "unable to create channel")
n := newThreeHopNetwork(t, channels.aliceToBob,
channels.bobToAlice, channels.bobToCarol,
channels.carolToBob, testStartingHeight)
require.NoError(t, n.start())
t.Cleanup(n.stop)
const amount = btcutil.SatoshiPerBitcoin
htlcAmt, htlcExpiry, hops := generateHops(amount,
testStartingHeight, n.firstBobChannelLink)
// In order to exercise this case, we'll now _manually_ modify
// the per-hop payload for amount to be the incorrect value.
// The acceptable values of the amount to forward should be
// lower than the incoming HTLC value.
hops[0].FwdInfo.AmountToForward = amount - overpayment
firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount,
htlcAmt, htlcExpiry,
).Wait(30 * time.Second)
return err == nil
}
err := quick.Check(f, &quick.Config{MaxCount: 20})
require.NoError(t, err, "payment should have succeeded but didn't")
} }
// TestLinkForwardTimelockPolicyMismatch tests that if a node is an // TestLinkForwardTimelockPolicyMismatch tests that if a node is an
@ -3512,25 +3607,37 @@ func TestChannelRetransmission(t *testing.T) {
// bandwidth of htlc links hasn't been changed. // bandwidth of htlc links hasn't been changed.
invoice, err = receiver.registry.LookupInvoice(rhash) invoice, err = receiver.registry.LookupInvoice(rhash)
if err != nil { if err != nil {
err = errors.Errorf("unable to get invoice: %v", err) err = fmt.Errorf(
"unable to get invoice: %w", err,
)
continue continue
} }
if invoice.State != invpkg.ContractSettled { if invoice.State != invpkg.ContractSettled {
err = errors.Errorf("alice invoice haven't been settled") err = fmt.Errorf(
"alice invoice haven't been settled",
)
continue continue
} }
aliceExpectedBandwidth := aliceBandwidthBefore - htlcAmt aliceExpectedBandwidth := aliceBandwidthBefore - htlcAmt
if aliceExpectedBandwidth != n.aliceChannelLink.Bandwidth() { if aliceExpectedBandwidth != n.aliceChannelLink.Bandwidth() {
err = errors.Errorf("expected alice to have %v, instead has %v", err = fmt.Errorf(
aliceExpectedBandwidth, n.aliceChannelLink.Bandwidth()) "expected alice to have %v,"+
" instead has %v",
aliceExpectedBandwidth,
n.aliceChannelLink.Bandwidth(),
)
continue continue
} }
bobExpectedBandwidth := bobBandwidthBefore + htlcAmt bobExpectedBandwidth := bobBandwidthBefore + htlcAmt
if bobExpectedBandwidth != n.firstBobChannelLink.Bandwidth() { if bobExpectedBandwidth != n.firstBobChannelLink.Bandwidth() {
err = errors.Errorf("expected bob to have %v, instead has %v", err = fmt.Errorf(
bobExpectedBandwidth, n.firstBobChannelLink.Bandwidth()) "expected bob to have %v,"+
" instead has %v",
bobExpectedBandwidth,
n.firstBobChannelLink.Bandwidth(),
)
continue continue
} }
@ -5517,8 +5624,10 @@ func TestExpectedFee(t *testing.T) {
} }
fee := ExpectedFee(f, test.htlcAmt) fee := ExpectedFee(f, test.htlcAmt)
if fee != test.expected { if fee != test.expected {
t.Errorf("expected fee to be (%v), instead got (%v)", test.expected, t.Errorf(
fee) "expected fee to be (%v), instead got (%v)",
test.expected, fee,
)
} }
} }
} }

View file

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"math" "math"
"testing" "testing"
"testing/quick"
"time" "time"
"github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/amp"
@ -925,6 +926,84 @@ func TestMppPayment(t *testing.T) {
} }
} }
// TestMppPaymentWithOverpayment tests settling of an invoice with multiple
// partial payments. It covers the case where the mpp overpays what is in the
// invoice.
func TestMppPaymentWithOverpayment(t *testing.T) {
t.Parallel()
defer timeout()()
f := func(overpayment_rand uint64) bool {
ctx := newTestContext(t, nil)
// Add the invoice.
_, err := ctx.registry.AddInvoice(
testInvoice, testInvoicePaymentHash,
)
if err != nil {
t.Fatal(err)
}
mppPayload := &mockPayload{
mpp: record.NewMPP(testInvoiceAmt, [32]byte{}),
}
// We constrain overpayment amount to be [1,1000].
overpayment := lnwire.MilliSatoshi((overpayment_rand % 999) + 1)
// Send htlc 1.
hodlChan1 := make(chan interface{}, 1)
resolution, err := ctx.registry.NotifyExitHopHtlc(
testInvoicePaymentHash, testInvoice.Terms.Value/2,
testHtlcExpiry, testCurrentHeight, getCircuitKey(11),
hodlChan1, mppPayload,
)
if err != nil {
t.Fatal(err)
}
if resolution != nil {
t.Fatal("expected no direct resolution")
}
// Send htlc 2.
hodlChan2 := make(chan interface{}, 1)
resolution, err = ctx.registry.NotifyExitHopHtlc(
testInvoicePaymentHash,
testInvoice.Terms.Value/2+overpayment, testHtlcExpiry,
testCurrentHeight, getCircuitKey(12), hodlChan2,
mppPayload,
)
if err != nil {
t.Fatal(err)
}
settleResolution, ok :=
resolution.(*invpkg.HtlcSettleResolution)
if !ok {
t.Fatalf("expected settle resolution, got: %T",
resolution)
}
if settleResolution.Outcome != invpkg.ResultSettled {
t.Fatalf("expected result settled, got: %v",
settleResolution.Outcome)
}
// Check that settled amount is equal to the sum of values of
// the htlcs 1 and 2.
inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
if err != nil {
t.Fatal(err)
}
if inv.State != invpkg.ContractSettled {
t.Fatal("expected invoice to be settled")
}
return inv.AmtPaid == testInvoice.Terms.Value+overpayment
}
if err := quick.Check(f, &quick.Config{MaxCount: 50}); err != nil {
t.Fatalf("amount incorrect: %v", err)
}
}
// Tests that invoices are canceled after expiration. // Tests that invoices are canceled after expiration.
func TestInvoiceExpiryWithRegistry(t *testing.T) { func TestInvoiceExpiryWithRegistry(t *testing.T) {
t.Parallel() t.Parallel()