lnd/htlcswitch/switch_test.go
2024-03-05 16:43:29 -08:00

5542 lines
149 KiB
Go

package htlcswitch
import (
"context"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
mrand "math/rand"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/require"
)
var (
	// zeroCircuit is the zero-value circuit key.
	zeroCircuit = models.CircuitKey{}

	// emptyScid is the zero-value short channel ID. The tests in this
	// file hand it to newMockChannelLink whenever a link has no second
	// SCID to register.
	emptyScid = lnwire.ShortChannelID{}
)
// genPreimage fills a 32-byte array from the crypto/rand reader and returns
// it. A non-nil error is returned if the reader could not supply all 32
// bytes.
func genPreimage() ([32]byte, error) {
	var buf [32]byte

	// io.ReadFull only reports a nil error once the whole slice has been
	// filled, so its error can be handed back as-is.
	_, err := io.ReadFull(rand.Reader, buf[:])

	return buf, err
}
// TestSwitchAddDuplicateLink verifies that the switch refuses to register a
// link that is already live, and that the very same link can be registered
// again once it has been removed.
func TestSwitchAddDuplicateLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	err = s.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer func() { _ = s.Stop() }()

	chanID, scid := genID()
	link := newMockChannelLink(
		s, chanID, scid, emptyScid, alicePeer, false, false, false,
		false,
	)

	// The first registration must go through.
	err = s.AddLink(link)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// The link is live now, so registering it a second time must be
	// rejected.
	err = s.AddLink(link)
	if err == nil {
		t.Fatalf("adding duplicate link should have failed")
	}

	// Drop the live link so that the switch's indexes are cleared.
	s.RemoveLink(chanID)

	// With the link gone, registering it once more must succeed.
	err = s.AddLink(link)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
}
// TestSwitchHasActiveLink tests the behavior of HasActiveLink, and asserts that
// it only returns true if a link's short channel id has confirmed (meaning the
// channel is no longer pending) and its EligibleToForward method returns true,
// i.e. it has received ChannelReady from the remote peer.
func TestSwitchHasActiveLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, aliceScid := genID()

	// NOTE(review): the first boolean handed to newMockChannelLink is
	// presumably the link's initial eligibility (false here) — confirm
	// against the mock's constructor.
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceScid, emptyScid, alicePeer, false, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// The link has been added, but its eligibility has not been switched
	// on yet, so HasActiveLink should return false.
	if s.HasActiveLink(chanID1) {
		t.Fatalf("link should not be active yet, still pending")
	}

	// Finally, simulate the link receiving channel_ready by setting its
	// eligibility to true.
	aliceChannelLink.eligible = true

	// The link should now be reported as active, since it has been added
	// to the switch and is eligible to forward.
	if !s.HasActiveLink(chanID1) {
		t.Fatalf("link should be active now")
	}
}
// TestSwitchSendPending checks the inability of htlc switch to forward adds
// over pending links.
func TestSwitchSendPending(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	// Alice's link is registered with an empty short channel ID, while the
	// packet below is addressed to aliceChanID.
	pendingChanID := lnwire.ShortChannelID{}

	aliceChannelLink := newMockChannelLink(
		s, chanID1, pendingChanID, emptyScid, alicePeer, false, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create a request which is being forwarded from Bob's channel link
	// to Alice's channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	packet := &htlcPacket{
		incomingChanID: bobChanID,
		incomingHTLCID: 0,
		outgoingChanID: aliceChanID,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Send the ADD packet, this should not be forwarded out to the link
	// since there are no eligible links.
	if err = s.ForwardPackets(nil, packet); err != nil {
		t.Fatal(err)
	}

	// The failure is expected to come back on Bob's (incoming) link. Only
	// overwrite err when a link failure is actually attached, so that a
	// missing failure is reported by the type assertion below.
	select {
	case p := <-bobChannelLink.packets:
		if p.linkFailure != nil {
			err = p.linkFailure
		}
	case <-time.After(time.Second):
		t.Fatal("no timely reply from switch")
	}

	linkErr, ok := err.(*LinkError)
	if !ok {
		t.Fatalf("expected link error, got: %T", err)
	}

	// Print the failure code itself; formatting it with %T would only
	// show its type and hide which code was actually received.
	if linkErr.WireMessage().Code() != lnwire.CodeUnknownNextPeer {
		t.Fatalf("expected fail unknown next peer, got: %v",
			linkErr.WireMessage().Code())
	}

	// No message should be sent, since the packet was failed.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("expected not to receive message")
	case <-time.After(time.Second):
	}

	// Since the packet should have been failed, there should be no active
	// circuits.
	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchForwardMapping runs a table of forwarding scenarios from Bob to
// Alice, varying whether Alice's channel is private, zero-conf or
// option-scid-alias and whether the packet addresses her alias or her
// confirmed SCID, to check that the Switch consults its maps when forwarding.
func TestSwitchForwardMapping(t *testing.T) {
	// Fields that are left out of a case take their zero value: false for
	// the flags and the empty SCID for aliceAlias/aliceReal.
	testCases := []struct {
		name string

		// alicePrivate makes Alice's channel private.
		alicePrivate bool

		// zeroConf makes Alice's channel a zero-conf channel.
		zeroConf bool

		// optionScid makes Alice's channel an option-scid-alias
		// feature-bit, non-zero-conf channel.
		optionScid bool

		// useAlias selects the alias, rather than the confirmed SCID,
		// as the outgoing channel of the forwarded packet.
		useAlias bool

		// aliceAlias is Alice's channel alias. It may be unset if this
		// is not an option_scid_alias channel (feature bit).
		aliceAlias lnwire.ShortChannelID

		// aliceReal is Alice's confirmed SCID. It may be unset if this
		// is a zero-conf channel before confirmation.
		aliceReal lnwire.ShortChannelID

		// expectErr signals that the forward from Bob to Alice is
		// expected to fail.
		expectErr bool
	}{
		{
			name:         "private unconfirmed zero-conf",
			alicePrivate: true,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_002,
				TxIndex:     2,
				TxPosition:  2,
			},
		},
		{
			name:         "private confirmed zero-conf",
			alicePrivate: true,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_003,
				TxIndex:     3,
				TxPosition:  3,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 300000,
				TxIndex:     3,
				TxPosition:  3,
			},
		},
		{
			name:         "private confirmed zero-conf failure",
			alicePrivate: true,
			zeroConf:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_004,
				TxIndex:     4,
				TxPosition:  4,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 300002,
				TxIndex:     4,
				TxPosition:  4,
			},
			expectErr: true,
		},
		{
			name:     "public unconfirmed zero-conf",
			zeroConf: true,
			useAlias: true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_005,
				TxIndex:     5,
				TxPosition:  5,
			},
		},
		{
			name:     "public confirmed zero-conf w/ alias",
			zeroConf: true,
			useAlias: true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_006,
				TxIndex:     6,
				TxPosition:  6,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 500000,
				TxIndex:     6,
				TxPosition:  6,
			},
		},
		{
			name:     "public confirmed zero-conf w/ real",
			zeroConf: true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_007,
				TxIndex:     7,
				TxPosition:  7,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 502000,
				TxIndex:     7,
				TxPosition:  7,
			},
		},
		{
			name:         "private non-option channel",
			alicePrivate: true,
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 505000,
				TxIndex:     8,
				TxPosition:  8,
			},
		},
		{
			name:         "private option channel w/ alias",
			alicePrivate: true,
			optionScid:   true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_015,
				TxIndex:     9,
				TxPosition:  9,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 506000,
				TxIndex:     10,
				TxPosition:  10,
			},
		},
		{
			name:         "private option channel failure",
			alicePrivate: true,
			optionScid:   true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_016,
				TxIndex:     16,
				TxPosition:  16,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 507000,
				TxIndex:     17,
				TxPosition:  17,
			},
			expectErr: true,
		},
		{
			name: "public non-option channel",
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 508000,
				TxIndex:     17,
				TxPosition:  17,
			},
		},
		{
			name:       "public option channel w/ alias",
			optionScid: true,
			useAlias:   true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_018,
				TxIndex:     18,
				TxPosition:  18,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 509000,
				TxIndex:     19,
				TxPosition:  19,
			},
		},
		{
			name:       "public option channel w/ real",
			optionScid: true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_019,
				TxIndex:     19,
				TxPosition:  19,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 510000,
				TxIndex:     20,
				TxPosition:  20,
			},
		},
	}

	for _, tc := range testCases {
		// Pin the case for the parallel closure below.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			testSwitchForwardMapping(
				t, tc.alicePrivate, tc.zeroConf, tc.useAlias,
				tc.optionScid, tc.aliceAlias, tc.aliceReal,
				tc.expectErr,
			)
		})
	}
}
// testSwitchForwardMapping runs one forwarding scenario. It registers a link
// for Alice (built from the private/zero-conf/option-scid flags and the two
// SCIDs) and a plain link for Bob, forwards an add from Bob towards Alice
// using either her alias or her confirmed SCID, and then checks which of the
// two links receives a packet: Bob's when expectErr is set, Alice's
// otherwise.
func testSwitchForwardMapping(t *testing.T, alicePrivate, aliceZeroConf,
	useAlias, optionScid bool, aliceAlias, aliceReal lnwire.ShortChannelID,
	expectErr bool) {

	// How long we wait on a link's packet channel before deciding that
	// nothing is going to arrive.
	const waitTime = time.Second * 5

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)

	require.NoError(t, s.Start())
	defer func() { _ = s.Stop() }()

	// Only the lnwire.ChannelIDs are needed from the generated set.
	aliceChanID, bobChanID, _, _ := genIDs()

	var aliceLink *mockChannelLink
	switch {
	// A zero-conf link is registered under its alias and carries the
	// confirmed SCID separately.
	case aliceZeroConf:
		aliceLink = newMockChannelLink(
			s, aliceChanID, aliceAlias, aliceReal, alicePeer,
			true, alicePrivate, true, false,
		)

	// Otherwise the link is registered under its confirmed SCID, and
	// only learns the alias when the option-scid-alias bit is set.
	default:
		aliceLink = newMockChannelLink(
			s, aliceChanID, aliceReal, emptyScid, alicePeer, true,
			alicePrivate, false, optionScid,
		)
		if optionScid {
			aliceLink.addAlias(aliceAlias)
		}
	}
	require.NoError(t, s.AddLink(aliceLink))

	// Bob gets a plain, non-option_scid_alias channel, so no mapping is
	// needed on his side.
	bobScid := lnwire.ShortChannelID{
		BlockHeight: 501000,
		TxIndex:     200,
		TxPosition:  2,
	}
	bobLink := newMockChannelLink(
		s, bobChanID, bobScid, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(t, s.AddLink(bobLink))

	// Payment hash for the add.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	payHash := sha256.Sum256(preimage[:])

	// Address Alice either by her alias or by her confirmed SCID.
	var outScid lnwire.ShortChannelID
	if useAlias {
		outScid = aliceAlias
	} else {
		outScid = aliceReal
	}

	addPkt := &htlcPacket{
		incomingChanID: bobScid,
		incomingHTLCID: 0,
		outgoingChanID: outScid,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: payHash,
			Amount:      1,
		},
	}
	require.NoError(t, s.ForwardPackets(nil, addPkt))

	// When a forwarding error is expected, a packet must come back on
	// Bob's link and nothing may reach Alice. option_scid_alias forwards
	// may fail if forwarding would be a privacy leak.
	if expectErr {
		select {
		case <-bobLink.packets:
		case <-time.After(waitTime):
			t.Fatal("expected a forwarding error")
		}

		select {
		case <-aliceLink.packets:
			t.Fatal("did not expect a packet")
		case <-time.After(waitTime):
		}

		return
	}

	// Otherwise nothing may come back on Bob's link, and Alice must
	// receive the packet.
	select {
	case <-bobLink.packets:
		t.Fatal("did not expect a forwarding error")
	case <-time.After(waitTime):
	}

	select {
	case <-aliceLink.packets:
	case <-time.After(waitTime):
		t.Fatal("expected alice to receive packet")
	}
}
// TestSwitchSendHTLCMapping runs a table of SendHTLC scenarios against
// zero-conf and option-scid-alias (feature-bit) channels, addressing the
// channel by its alias or by its confirmed SCID, and expects every send to
// succeed. It also checks that nothing breaks with the mapping change.
func TestSwitchSendHTLCMapping(t *testing.T) {
	// Fields that are left out of a case take their zero value: false for
	// the flags and the empty SCID for alias/real.
	testCases := []struct {
		name string

		// zeroConf makes the channel a zero-conf channel.
		zeroConf bool

		// optionFeature denotes an option-scid-alias, non zero-conf
		// feature-bit channel.
		optionFeature bool

		// useAlias selects the alias, rather than the confirmed SCID,
		// as the outgoing channel of the HTLC.
		useAlias bool

		// alias is the channel alias, if there is a mapping.
		alias lnwire.ShortChannelID

		// real is the confirmed SCID, if the channel is confirmed.
		real lnwire.ShortChannelID
	}{
		{
			name:          "non-zero-conf real scid w/ option",
			optionFeature: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10010,
				TxIndex:     10,
				TxPosition:  10,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 500000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name: "non-zero-conf real scid no option",
			real: lnwire.ShortChannelID{
				BlockHeight: 400000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name:     "zero-conf alias scid w/ conf",
			zeroConf: true,
			useAlias: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10020,
				TxIndex:     20,
				TxPosition:  20,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 450000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name:     "zero-conf alias scid no conf",
			zeroConf: true,
			useAlias: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10015,
				TxIndex:     25,
				TxPosition:  35,
			},
		},
		{
			name:     "zero-conf real scid",
			zeroConf: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10035,
				TxIndex:     35,
				TxPosition:  35,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 470000,
				TxIndex:     35,
				TxPosition:  45,
			},
		},
	}

	for _, tc := range testCases {
		// Pin the case for the parallel closure below.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			testSwitchSendHtlcMapping(
				t, tc.zeroConf, tc.useAlias, tc.alias, tc.real,
				tc.optionFeature,
			)
		})
	}
}
// testSwitchSendHtlcMapping runs one SendHTLC scenario. It registers a single
// link (zero-conf, or confirmed with an optional alias), then sends an HTLC
// addressed to either the alias or the confirmed SCID and requires that
// SendHTLC returns no error.
func testSwitchSendHtlcMapping(t *testing.T, zeroConf, useAlias bool, alias,
	realScid lnwire.ShortChannelID, optionFeature bool) {

	peer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)

	require.NoError(t, s.Start())
	defer func() { _ = s.Stop() }()

	// Only the lnwire.ChannelID is needed here.
	chanID, _ := genID()

	var link *mockChannelLink
	switch {
	// A zero-conf link is registered under its alias and carries the
	// confirmed SCID separately.
	case zeroConf:
		link = newMockChannelLink(
			s, chanID, alias, realScid, peer, true, false, true,
			false,
		)

	// NOTE(review): the final flag is hard-coded to true here even when
	// optionFeature is false, while only the alias registration is gated
	// on optionFeature — confirm that this is intended.
	default:
		link = newMockChannelLink(
			s, chanID, realScid, emptyScid, peer, true, false,
			false, true,
		)
		if optionFeature {
			link.addAlias(alias)
		}
	}
	require.NoError(t, s.AddLink(link))

	// Payment hash for the HTLC.
	preimage, err := genPreimage()
	require.NoError(t, err)
	payHash := sha256.Sum256(preimage[:])

	// Address the channel either by its alias or by its confirmed SCID.
	var outScid lnwire.ShortChannelID
	if useAlias {
		outScid = alias
	} else {
		outScid = realScid
	}

	// The send itself must not return an error.
	add := &lnwire.UpdateAddHTLC{
		PaymentHash: payHash,
		Amount:      1,
	}
	require.NoError(t, s.SendHTLC(outScid, 0, add))
}
// TestSwitchUpdateScid checks the contents of the switch's aliasToReal and
// baseIndex maps for a zero-conf link (before and after UpdateShortChanID is
// called with a confirmed SCID set on the mock) and for a non-zero-conf
// option-scid-alias (feature bit) link.
func TestSwitchUpdateScid(t *testing.T) {
	t.Parallel()

	peer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)

	require.NoError(t, s.Start())
	defer func() { _ = s.Stop() }()

	// The helpers below read the switch's index maps directly; callers
	// must hold s.indexMtx while invoking them.
	requireNoReal := func(scid lnwire.ShortChannelID) {
		_, found := s.aliasToReal[scid]
		require.False(t, found)
	}
	requireReal := func(scid, want lnwire.ShortChannelID) {
		got, found := s.aliasToReal[scid]
		require.True(t, found)
		require.Equal(t, want, got)
	}
	requireBase := func(scid, want lnwire.ShortChannelID) {
		got, found := s.baseIndex[scid]
		require.True(t, found)
		require.Equal(t, want, got)
	}

	// IDs and SCIDs for the zero-conf link.
	zcChanID, optChanID, _, _ := genIDs()

	alias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     0,
		TxPosition:  0,
	}
	alias2 := alias
	alias2.TxPosition = 1

	realScid := lnwire.ShortChannelID{
		BlockHeight: 500000,
		TxIndex:     0,
		TxPosition:  0,
	}

	zcLink := newMockChannelLink(
		s, zcChanID, alias, emptyScid, peer, true, false, true, false,
	)
	zcLink.addAlias(alias2)
	require.NoError(t, s.AddLink(zcLink))

	s.indexMtx.RLock()

	// Before confirmation the zero-conf link has no aliasToReal entries.
	requireNoReal(alias)
	requireNoReal(alias2)

	// Both aliases point at the "base" SCID, which is simply the first
	// alias.
	requireBase(alias, alias)
	requireBase(alias2, alias)

	s.indexMtx.RUnlock()

	// Set the mock link's confirmed SCID so that UpdateShortChanID fills
	// in aliasToReal and adds a baseIndex entry.
	zcLink.realScid = realScid
	zcLink.confirmedZC = true

	require.NoError(t, s.UpdateShortChanID(zcChanID))

	s.indexMtx.RLock()

	// aliasToReal now maps both aliases to the confirmed SCID, and
	// baseIndex maps the confirmed SCID back to the first alias.
	requireReal(alias, realScid)
	requireReal(alias2, realScid)
	requireBase(realScid, alias)

	s.indexMtx.RUnlock()

	// Repeat the checks with a non-zero-conf option-scid-alias channel
	// (feature-bit).
	optionReal := lnwire.ShortChannelID{
		BlockHeight: 600000,
		TxIndex:     0,
		TxPosition:  0,
	}
	optionAlias := lnwire.ShortChannelID{
		BlockHeight: 12000,
		TxIndex:     0,
		TxPosition:  0,
	}
	optionAlias2 := optionAlias
	optionAlias2.TxPosition = 1

	optLink := newMockChannelLink(
		s, optChanID, optionReal, emptyScid, peer, true, false, false,
		true,
	)
	optLink.addAlias(optionAlias)
	optLink.addAlias(optionAlias2)
	require.NoError(t, s.AddLink(optLink))

	s.indexMtx.RLock()

	// The option-scid-alias link has aliasToReal entries right away, and
	// all of its SCIDs resolve to the confirmed SCID in baseIndex.
	requireReal(optionAlias, optionReal)
	requireReal(optionAlias2, optionReal)
	requireBase(optionReal, optionReal)
	requireBase(optionAlias, optionReal)
	requireBase(optionAlias2, optionReal)

	s.indexMtx.RUnlock()
}
// TestSwitchForward exercises the forwarding round trip: an add sent from
// Alice's link must show up on Bob's link and open a circuit, and the settle
// coming back from Bob must show up on Alice's link, after which the circuit
// is removed.
func TestSwitchForward(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}

	err = s.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer func() { _ = s.Stop() }()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)

	err = s.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	err = s.AddLink(bobLink)
	if err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add that is to travel from Alice's link to Bob's link.
	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	payHash := sha256.Sum256(preimage[:])

	addPkt := &htlcPacket{
		incomingChanID: aliceLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: payHash,
			Amount:      1,
		},
	}

	// Hand the add to the switch and wait for it on Bob's link.
	err = s.ForwardPackets(nil, addPkt)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case <-bobLink.packets:
		err := bobLink.completeCircuit(addPkt)
		if err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	if !s.IsForwardedHTLC(bobLink.ShortChanID(), 0) {
		t.Fatal("htlc should be identified as forwarded")
	}

	// Pretend that Bob's link handled the add and answered with a settle.
	// The switch should forward that settle back to Alice's link.
	settlePkt := &htlcPacket{
		outgoingChanID: bobLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}

	// Hand the settle to the switch and wait for it on Alice's link.
	err = s.ForwardPackets(nil, settlePkt)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case pkt := <-aliceLink.packets:
		err := aliceLink.deleteCircuit(pkt)
		if err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchForwardFailAfterFullAdd forwards an add from Alice's link to
// Bob's link, completes the circuit, and then restarts the switch on top of
// the same database. It asserts that the restarted switch still reports the
// circuit, that a fail coming from Bob's link is delivered to Alice's link
// and empties the circuit map, and that a duplicate of that fail is not
// delivered again.
func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Handle the request and checks that bob channel link received it.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Pull packet from bob's link, and complete the circuit.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Now we will restart the switch, leaving the forwarding decision for
	// this htlc in the fully-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	// Re-create both links against the restarted switch.
	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The circuit must have survived the restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Craft a failure message from the remote peer.
	fail := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}

	// Send the fail packet from the remote peer through the switch.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}

	// Pull packet from alice's link, as it should have gone through
	// successfully.
	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Circuit map should be empty now.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Send the fail packet a second time. This must go through the
	// restarted switch s2: the links asserted on below are attached to s2,
	// and s has already been stopped, so sending through s would make the
	// check below pass regardless of how duplicates are handled.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}

	select {
	case <-aliceChannelLink.packets:
		t.Fatalf("expected duplicate fail to not arrive at the destination")
	case <-time.After(time.Second):
	}
}
func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err, "unable to create alice server")
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err, "unable to create bob server")
tempPath := t.TempDir()
cdb, err := channeldb.Open(tempPath)
require.NoError(t, err, "unable to open channeldb")
t.Cleanup(func() { cdb.Close() })
s, err := initSwitchWithDB(testStartingHeight, cdb)
require.NoError(t, err, "unable to init switch")
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
false, false,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
false,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := sha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.ForwardPackets(nil, ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Pull packet from bob's link, but do not perform a full add.
select {
case packet := <-bobChannelLink.packets:
// Complete the payment circuit and assign the outgoing htlc id
// before restarting.
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Now we will restart bob, leaving the forwarding decision for this
// htlc is in the half-added state.
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
require.NoError(t, err, "unable to reopen channeldb")
t.Cleanup(func() { cdb2.Close() })
s2, err := initSwitchWithDB(testStartingHeight, cdb2)
require.NoError(t, err, "unable reinit switch")
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
false, false,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
false,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Craft a settle message from the remote peer.
settle := &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: 0,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
},
}
// Send the settle packet from the remote peer through the switch.
if err := s2.ForwardPackets(nil, settle); err != nil {
t.Fatalf(err.Error())
}
// Pull packet from alice's link, as it should have gone through
// successfully.
select {
case packet := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete circuit with in key=%s: %v",
packet.inKey(), err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Circuit map should be empty now.
if s2.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Send the settle packet again, which should not arrive at the destination.
if err := s2.ForwardPackets(nil, settle); err != nil {
t.Fatal(err)
}
select {
case <-bobChannelLink.packets:
t.Fatalf("expected duplicate fail to not arrive at the destination")
case <-time.After(time.Second):
}
}
// TestSwitchForwardDropAfterFullAdd checks that when an add is re-forwarded
// after a restart, and its circuit had already been completed on the outgoing
// link before the restart, the switch silently drops the duplicate: nothing is
// delivered to the outgoing link and nothing is returned to the incoming link.
func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink), "unable to add bob link",
	)

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Nothing has been forwarded yet, so the circuit map must be empty.
	require.Equal(
		t, 0, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Handle the request and check that bob channel link received it.
	require.NoError(t, s.ForwardPackets(nil, ogPacket))

	// The add is pending but its circuit has not been opened yet.
	require.Equal(
		t, 1, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Pull packet from bob's link and perform the full add.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		err := bobChannelLink.completeCircuit(packet)
		require.NoError(t, err, "unable to complete payment circuit")

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Now we will restart the switch on top of the same database, leaving
	// the forwarding decision for this htlc in the fully-added state.
	require.NoError(t, s.Stop())
	require.NoError(t, cdb.Close())

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	require.NoError(t, s2.Start(), "unable to restart switch")

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s2.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s2.AddLink(bobChannelLink), "unable to add bob link",
	)

	// Both the pending add and the opened circuit must have survived the
	// restart.
	require.Equal(
		t, 1, s2.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 1, s2.circuits.NumOpen(), "wrong amount of circuits")

	// Resend the original add. The packet is expected to be dropped
	// silently, since it was already fully added before the restart.
	require.NoError(t, s2.ForwardPackets(nil, ogPacket))

	// A dropped duplicate must neither be failed back to the sender nor
	// be forwarded to the destination again.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("request should not have returned to source")

	case <-bobChannelLink.packets:
		t.Fatal("request should not have forwarded to destination")

	case <-time.After(time.Second):
	}
}
// TestSwitchForwardFailAfterHalfAdd checks that when an add is re-forwarded
// after a restart, and its circuit was never completed on the outgoing link
// before the restart, the switch returns the packet to the incoming link with
// an OutgoingFailureIncompleteForward link failure.
func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink), "unable to add bob link",
	)

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Nothing has been forwarded yet, so the circuit map must be empty.
	require.Equal(
		t, 0, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Handle the request and check that bob channel link received it.
	require.NoError(t, s.ForwardPackets(nil, ogPacket))

	// The add is pending but its circuit has not been opened yet.
	require.Equal(
		t, 1, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Pull packet from bob's link, but do not perform a full add: the
	// circuit is deliberately left uncompleted.
	select {
	case <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Now we will restart the switch on top of the same database, leaving
	// the forwarding decision for this htlc in the half-added state.
	require.NoError(t, s.Stop())
	require.NoError(t, cdb.Close())

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	require.NoError(t, s2.Start(), "unable to restart switch")

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s2.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s2.AddLink(bobChannelLink), "unable to add bob link",
	)

	// The pending add must have survived the restart, still without an
	// opened circuit.
	require.Equal(
		t, 1, s2.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s2.circuits.NumOpen(), "wrong amount of circuits")

	// Resend the original add, it should be returned to alice since the
	// switch will detect that it has been half added previously.
	require.NoError(t, s2.ForwardPackets(nil, ogPacket))

	// After detecting an incomplete forward, the fail packet should have
	// been returned to the sender.
	select {
	case pkt := <-aliceChannelLink.packets:
		linkErr := pkt.linkFailure

		// Guard against a packet without a failure so that the test
		// fails cleanly rather than panicking on a nil dereference.
		if linkErr == nil {
			t.Fatal("expected incomplete forward, got no link " +
				"failure")
		}
		if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
			t.Fatalf("expected incomplete forward, got: %v",
				linkErr.FailureDetail)
		}

	case <-time.After(time.Second):
		t.Fatal("failure was not returned to the sender")
	}
}
// TestSwitchForwardCircuitPersistence checks the ability of htlc switch to
// maintain the proper entries in the circuit map in the face of restarts. The
// switch is restarted twice on top of the same database: once while an add is
// pending with no open circuit, and once after the circuit has been settled
// and torn down.
func TestSwitchForwardCircuitPersistence(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink), "unable to add bob link",
	)

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Nothing has been forwarded yet, so the circuit map must be empty.
	require.Equal(
		t, 0, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Handle the request and check that bob channel link received it.
	require.NoError(t, s.ForwardPackets(nil, ogPacket))

	// The add is pending but its circuit has not been opened yet.
	require.Equal(
		t, 1, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")

	// Retrieve packet from outgoing link and cache until after restart.
	var packet *htlcPacket
	select {
	case packet = <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// First restart: bring up a new switch on top of the same database.
	require.NoError(t, s.Stop())
	require.NoError(t, cdb.Close())

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	require.NoError(t, s2.Start(), "unable to restart switch")

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s2.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s2.AddLink(bobChannelLink), "unable to add bob link",
	)

	// The pending add must have survived the restart, still without an
	// opened circuit.
	require.Equal(
		t, 1, s2.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s2.circuits.NumOpen(), "wrong amount of circuits")

	// Now that the switch has restarted, complete the payment circuit
	// using the packet cached from before the restart.
	require.NoError(
		t, bobChannelLink.completeCircuit(packet),
		"unable to complete payment circuit",
	)

	require.Equal(
		t, 1, s2.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 1, s2.circuits.NumOpen(), "wrong amount of circuits")

	// Create settle request pretending that bob link handled the add htlc
	// request and sent the htlc settle request back. This request should
	// be forwarded back to Alice link.
	ogPacket = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}

	// Handle the request and check that payment circuit works properly.
	require.NoError(t, s2.ForwardPackets(nil, ogPacket))

	select {
	case packet = <-aliceChannelLink.packets:
		err := aliceChannelLink.completeCircuit(packet)
		if err != nil {
			t.Fatalf("unable to complete circuit with in key=%s: %v",
				packet.inKey(), err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	// Once the settle has reached alice the circuit map must be empty.
	require.Equal(
		t, 0, s2.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s2.circuits.NumOpen(), "wrong amount of circuits")

	// Second restart: the circuit map is expected to still be empty.
	require.NoError(t, s2.Stop())
	require.NoError(t, cdb2.Close())

	cdb3, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb3.Close() })

	s3, err := initSwitchWithDB(testStartingHeight, cdb3)
	require.NoError(t, err, "unable reinit switch")
	require.NoError(t, s3.Start(), "unable to restart switch")
	defer s3.Stop()

	aliceChannelLink = newMockChannelLink(
		s3, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s3, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s3.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s3.AddLink(bobChannelLink), "unable to add bob link",
	)

	require.Equal(
		t, 0, s3.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s3.circuits.NumOpen(), "wrong amount of circuits")
}
// multiHopFwdTest describes a single scenario run by
// testSkipIneligibleLinksMultiHopForward.
type multiHopFwdTest struct {
	// name is used as the subtest name.
	name string

	// eligible1 and eligible2 are handed to newMockChannelLink when
	// creating Bob's first and second link respectively. The packet under
	// test names the first link as its outgoing channel.
	eligible1 bool
	eligible2 bool

	// failure1 and failure2 are installed as the checkHtlcForwardResult
	// of Bob's first and second link respectively.
	failure1 *LinkError
	failure2 *LinkError

	// expectedReply is the failure code the mock obfuscator is expected
	// to have recorded. lnwire.CodeNone means that no failure at all is
	// expected.
	expectedReply lnwire.FailCode
}
// TestCircularForwards forwards an HTLC whose incoming and outgoing link are
// one and the same channel, and checks the reply of the switch both when it is
// configured to allow same channel circular forwards and when it is configured
// to disallow them.
func TestCircularForwards(t *testing.T) {
	chanID1, aliceChanID := genID()

	var (
		preimage = [sha256.Size]byte{1}
		hash     = sha256.Sum256(preimage[:])
	)

	type circularCase struct {
		name                 string
		allowCircularPayment bool
		expectedErr          error
	}

	testCases := []circularCase{
		{
			name:                 "circular payment allowed",
			allowCircularPayment: true,
			expectedErr:          nil,
		},
		{
			name:                 "circular payment disallowed",
			allowCircularPayment: false,
			expectedErr: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureCircularRoute,
			),
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			alicePeer, err := newMockServer(
				t, "alice", testStartingHeight, nil,
				testDefaultDelta,
			)
			if err != nil {
				t.Fatalf("unable to create alice server: %v",
					err)
			}

			s, err := initSwitchWithTempDB(t, testStartingHeight)
			if err != nil {
				t.Fatalf("unable to init switch: %v", err)
			}

			err = s.Start()
			if err != nil {
				t.Fatalf("unable to start switch: %v", err)
			}
			defer func() { _ = s.Stop() }()

			// Configure the switch for this case before any link
			// is attached to it.
			s.cfg.AllowCircularRoute = tc.allowCircularPayment

			aliceChannelLink := newMockChannelLink(
				s, chanID1, aliceChanID, emptyScid, alicePeer,
				true, false, false, false,
			)
			err = s.AddLink(aliceChannelLink)
			if err != nil {
				t.Fatalf("unable to add alice link: %v", err)
			}

			// The packet enters and leaves through alice's link,
			// which makes the route circular.
			aliceScid := aliceChannelLink.ShortChanID()
			packet := &htlcPacket{
				incomingChanID: aliceScid,
				outgoingChanID: aliceChannelLink.ShortChanID(),
				htlc: &lnwire.UpdateAddHTLC{
					PaymentHash: hash,
					Amount:      1,
				},
				obfuscator: NewMockObfuscator(),
			}

			// Hand the packet to the switch; the outcome is
			// observed on alice's link below.
			fwdErr := s.ForwardPackets(nil, packet)
			if fwdErr != nil {
				t.Fatal(fwdErr)
			}

			// gotErr stays nil unless the packet that comes back
			// carries a link failure.
			var gotErr error
			select {
			case reply := <-aliceChannelLink.packets:
				if reply.linkFailure != nil {
					gotErr = reply.linkFailure
				}

			case <-time.After(time.Second):
				t.Fatal("no timely reply from switch")
			}

			if !reflect.DeepEqual(gotErr, tc.expectedErr) {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedErr, gotErr)
			}

			// Neither outcome may leave an open circuit behind.
			if numOpen := s.circuits.NumOpen(); numOpen > 0 {
				t.Fatal("do not expect any open circuits")
			}
		})
	}
}
// TestCheckCircularForward verifies the value returned by checkCircularForward
// for circular and non-circular link pairs, with circular forwards allowed and
// disallowed, and with and without entries in the switch's baseIndex that map
// both links to the same SCID.
func TestCheckCircularForward(t *testing.T) {
	// scid shortens the construction of short channel IDs in the table
	// below.
	scid := lnwire.NewShortChanIDFromInt

	// circularRouteErr builds a fresh instance of the link error expected
	// whenever a circular forward is rejected.
	circularRouteErr := func() *LinkError {
		return NewDetailedLinkError(
			lnwire.NewTemporaryChannelFailure(nil),
			OutgoingFailureCircularRoute,
		)
	}

	testCases := []struct {
		name string

		// aliasMapping, when set, makes the test register both links
		// in the switch's baseIndex before the check is performed.
		aliasMapping bool

		// allowCircular is the allowCircular argument handed to
		// checkCircularForward.
		allowCircular bool

		// incomingLink is the link that the htlc arrived on.
		incomingLink lnwire.ShortChannelID

		// outgoingLink is the link that the htlc forward is destined
		// to leave on.
		outgoingLink lnwire.ShortChannelID

		// expectedErr is the link error the check should produce, nil
		// if the forward should be accepted.
		expectedErr *LinkError
	}{
		{
			name:          "not circular, allowed in config",
			aliasMapping:  false,
			allowCircular: true,
			incomingLink:  scid(123),
			outgoingLink:  scid(321),
		},
		{
			name:          "not circular, not allowed in config",
			aliasMapping:  false,
			allowCircular: false,
			incomingLink:  scid(123),
			outgoingLink:  scid(321),
		},
		{
			name:          "circular, allowed in config",
			aliasMapping:  false,
			allowCircular: true,
			incomingLink:  scid(123),
			outgoingLink:  scid(123),
		},
		{
			name:          "circular, not allowed in config",
			aliasMapping:  false,
			allowCircular: false,
			incomingLink:  scid(123),
			outgoingLink:  scid(123),
			expectedErr:   circularRouteErr(),
		},
		{
			name:          "circular with map, not allowed",
			aliasMapping:  true,
			allowCircular: false,
			incomingLink:  scid(1 << 60),
			outgoingLink:  scid(1 << 55),
			expectedErr:   circularRouteErr(),
		},
		{
			name:          "circular with map, not allowed 2",
			aliasMapping:  true,
			allowCircular: false,
			incomingLink:  scid(1 << 55),
			outgoingLink:  scid(1 << 60),
			expectedErr:   circularRouteErr(),
		},
		{
			name:          "circular with map, allowed",
			aliasMapping:  true,
			allowCircular: true,
			incomingLink:  scid(1 << 60),
			outgoingLink:  scid(1 << 55),
		},
		{
			name:          "circular with map, allowed 2",
			aliasMapping:  true,
			allowCircular: true,
			incomingLink:  scid(1 << 55),
			outgoingLink:  scid(1 << 61),
		},
		{
			name:          "not circular, both confirmed SCID",
			aliasMapping:  false,
			allowCircular: false,
			incomingLink:  scid(1 << 60),
			outgoingLink:  scid(1 << 61),
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			s, err := initSwitchWithTempDB(t, testStartingHeight)
			require.NoError(t, err)
			require.NoError(t, s.Start())
			defer func() { _ = s.Stop() }()

			// With alias mapping requested, point both the
			// incoming and the outgoing SCID at the same base
			// SCID.
			if tc.aliasMapping {
				s.indexMtx.Lock()
				s.baseIndex[tc.incomingLink] = tc.outgoingLink
				s.baseIndex[tc.outgoingLink] = tc.outgoingLink
				s.indexMtx.Unlock()
			}

			// A zero-value payment hash is passed to the check.
			got := s.checkCircularForward(
				tc.incomingLink, tc.outgoingLink,
				tc.allowCircular, lntypes.Hash{},
			)
			if reflect.DeepEqual(got, tc.expectedErr) {
				return
			}

			t.Fatalf("expected: %v, got: %v", tc.expectedErr, got)
		})
	}
}
// TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
// along, then we won't attempt to forward it down al ink that isn't yet able
// to forward any HTLC's.
func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
tests := []multiHopFwdTest{
// None of the channels is eligible.
{
name: "not eligible",
expectedReply: lnwire.CodeUnknownNextPeer,
},
// Channel one has a policy failure and the other channel isn't
// available.
{
name: "policy fail",
eligible1: true,
failure1: NewLinkError(
lnwire.NewFinalIncorrectCltvExpiry(0),
),
expectedReply: lnwire.CodeFinalIncorrectCltvExpiry,
},
// The requested channel is not eligible, but the packet is
// forwarded through the other channel.
{
name: "non-strict success",
eligible2: true,
expectedReply: lnwire.CodeNone,
},
// The requested channel has insufficient bandwidth and the
// other channel's policy isn't satisfied.
{
name: "non-strict policy fail",
eligible1: true,
failure1: NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureInsufficientBalance,
),
eligible2: true,
failure2: NewLinkError(
lnwire.NewFinalIncorrectCltvExpiry(0),
),
expectedReply: lnwire.CodeTemporaryChannelFailure,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testSkipIneligibleLinksMultiHopForward(t, &test)
})
}
}
// testSkipIneligibleLinksMultiHopForward runs a single multiHopFwdTest
// scenario. It connects one link for Alice and two links for Bob to a fresh
// switch, forwards an HTLC from Alice that names Bob's first link as its
// outgoing channel, and then checks the reply against the scenario's
// expectedReply. Bob's links take their eligibility and their
// checkHtlcForwardResult from the scenario. The function marks the calling
// test as parallel.
func testSkipIneligibleLinksMultiHopForward(t *testing.T,
	testCase *multiHopFwdTest) {

	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")
	defer s.Stop()

	chanID1, aliceChanID := genID()
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)

	// We'll create two links for Bob, whose eligibility and forwarding
	// check result are dictated by the test case.
	chanID2, bobChanID2 := genID()
	bobChannelLink1 := newMockChannelLink(
		s, chanID2, bobChanID2, emptyScid, bobPeer, testCase.eligible1,
		false, false, false,
	)
	bobChannelLink1.checkHtlcForwardResult = testCase.failure1

	chanID3, bobChanID3 := genID()
	bobChannelLink2 := newMockChannelLink(
		s, chanID3, bobChanID3, emptyScid, bobPeer, testCase.eligible2,
		false, false, false,
	)
	bobChannelLink2.checkHtlcForwardResult = testCase.failure2

	require.NoError(
		t, s.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink1), "unable to add bob link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink2), "unable to add bob link",
	)

	// Create a new packet that's destined for Bob as an incoming HTLC from
	// Alice.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	obfuscator := NewMockObfuscator()
	packet := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink1.ShortChanID(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
		obfuscator: obfuscator,
	}

	// Hand the packet to the switch. ForwardPackets itself must return
	// without error; the outcome of the forwarding attempt is observed on
	// the links below.
	require.NoError(t, s.ForwardPackets(nil, packet))

	// We select from all links and extract the error if exists.
	// The packet must be selected but we don't always expect a link error.
	var linkError *LinkError
	select {
	case p := <-aliceChannelLink.packets:
		linkError = p.linkFailure

	case p := <-bobChannelLink1.packets:
		linkError = p.linkFailure

	case p := <-bobChannelLink2.packets:
		linkError = p.linkFailure

	case <-time.After(time.Second):
		t.Fatal("no timely reply from switch")
	}

	failure := obfuscator.(*mockObfuscator).failure

	if testCase.expectedReply == lnwire.CodeNone {
		if linkError != nil {
			t.Fatalf("forwarding should have succeeded")
		}
		if failure != nil {
			t.Fatalf("unexpected failure %T", failure)
		}
	} else {
		if linkError == nil {
			t.Fatalf("forwarding should have failed")
		}

		// Guard against a missing failure message so that the test
		// fails cleanly instead of panicking on the Code call below.
		if failure == nil {
			t.Fatalf("expected failure code %v, but no failure "+
				"was recorded", testCase.expectedReply)
		}
		if failure.Code() != testCase.expectedReply {
			t.Fatalf("unexpected failure %T: expected code %v, "+
				"got %v", failure, testCase.expectedReply,
				failure.Code())
		}
	}

	// No scenario is expected to leave an open circuit behind.
	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSkipIneligibleLinksLocalForward ensures that the switch refuses to send
// a locally initiated HTLC down a link that isn't yet eligible for forwarding.
func TestSkipIneligibleLinksLocalForward(t *testing.T) {
	t.Parallel()

	// The link is created ineligible and without any policy failure.
	const eligible = false

	testSkipLinkLocalForward(t, eligible, nil)
}
// TestSkipPolicyUnsatisfiedLinkLocalForward ensures that the switch will not
// attempt to send locally initiated HTLCs that would violate the channel policy
// down a link.
func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) {
t.Parallel()
testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil))
}
// testSkipLinkLocalForward attaches a single Alice link to a fresh switch and
// asserts that a locally initiated HTLC sent over that link is rejected
// without opening a circuit. The eligible flag is handed to the mock link on
// creation, and policyResult is wrapped in a LinkError and installed as the
// link's checkHtlcTransitResult.
func testSkipLinkLocalForward(t *testing.T, eligible bool,
	policyResult lnwire.FailureMessage) {

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	err = s.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	// A single link is all this test needs; its behaviour is dictated by
	// the caller from the very start.
	chanID1, _, aliceChanID, _ := genIDs()
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, eligible, false,
		false, false,
	)
	aliceChannelLink.checkHtlcTransitResult = NewLinkError(policyResult)

	err = s.AddLink(aliceChannelLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")

	// Try to send out a new HTLC with Alice as the first outgoing link.
	// The switch is expected to reject it.
	err = s.SendHTLC(
		aliceChannelLink.ShortChanID(), 0, &lnwire.UpdateAddHTLC{
			PaymentHash: sha256.Sum256(preimage[:]),
			Amount:      1,
		},
	)
	if err == nil {
		t.Fatalf("local forward should fail due to inactive link")
	}

	// The rejected send must not have left an open circuit behind.
	if numOpen := s.circuits.NumOpen(); numOpen != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchCancel checks that if htlc was rejected we remove unused
// circuits. An add is forwarded from alice to bob and its circuit completed,
// then a fail for that htlc is sent back through the switch, after which the
// circuit map must be empty again.
func TestSwitchCancel(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(
		t, s.AddLink(aliceChannelLink), "unable to add alice link",
	)
	require.NoError(
		t, s.AddLink(bobChannelLink), "unable to add bob link",
	)

	// Create request which should be forwarded from alice channel link
	// to bob channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")

	rhash := sha256.Sum256(preimage[:])
	request := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Handle the request and check that bob channel link received it.
	require.NoError(t, s.ForwardPackets(nil, request))

	select {
	case packet := <-bobChannelLink.packets:
		err := bobChannelLink.completeCircuit(packet)
		require.NoError(t, err, "unable to complete payment circuit")

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// The add is now tracked both as a pending and as an open circuit.
	require.Equal(
		t, 1, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 1, s.circuits.NumOpen(), "wrong amount of circuits")

	// Create fail request pretending that bob channel link handled
	// the add htlc request and sent the htlc fail request back. This
	// request should be forwarded back to alice channel link.
	request = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}

	// Handle the request and check that payment circuit works properly.
	require.NoError(t, s.ForwardPackets(nil, request))

	select {
	case pkt := <-aliceChannelLink.packets:
		err := aliceChannelLink.completeCircuit(pkt)
		require.NoError(t, err, "unable to remove circuit")

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	// With the fail delivered to alice, no circuit may remain.
	require.Equal(
		t, 0, s.circuits.NumPending(), "wrong amount of half circuits",
	)
	require.Equal(t, 0, s.circuits.NumOpen(), "wrong amount of circuits")
}
// TestSwitchAddSamePayment verifies that the switch forwards two HTLCs that
// share a single payment hash, and that each one is opened and torn down as a
// circuit of its own.
func TestSwitchAddSamePayment(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	if startErr := s.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer s.Stop()

	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)

	if addErr := s.AddLink(aliceLink); addErr != nil {
		t.Fatalf("unable to add alice link: %v", addErr)
	}
	if addErr := s.AddLink(bobLink); addErr != nil {
		t.Fatalf("unable to add bob link: %v", addErr)
	}

	// Both adds below commit to this one payment hash.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])

	// forwardAdd pushes an add with the given incoming HTLC ID from
	// Alice's link towards Bob's link, lets Bob's link complete the
	// circuit, and then checks the number of open circuits.
	forwardAdd := func(htlcID uint64, wantOpen int) {
		add := &htlcPacket{
			incomingChanID: aliceLink.ShortChanID(),
			incomingHTLCID: htlcID,
			outgoingChanID: bobLink.ShortChanID(),
			obfuscator:     NewMockObfuscator(),
			htlc: &lnwire.UpdateAddHTLC{
				PaymentHash: rhash,
				Amount:      1,
			},
		}
		if err := s.ForwardPackets(nil, add); err != nil {
			t.Fatal(err)
		}

		select {
		case pkt := <-bobLink.packets:
			err := bobLink.completeCircuit(pkt)
			if err != nil {
				t.Fatalf("unable to complete payment "+
					"circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to destination")
		}

		if s.circuits.NumOpen() != wantOpen {
			t.Fatal("wrong amount of circuits")
		}
	}

	// forwardFail pretends Bob's link answered the HTLC with the given
	// outgoing ID with a failure, expects the response to reach Alice's
	// link, and then checks the number of open circuits.
	forwardFail := func(htlcID uint64, wantOpen int) {
		fail := &htlcPacket{
			outgoingChanID: bobLink.ShortChanID(),
			outgoingHTLCID: htlcID,
			amount:         1,
			htlc:           &lnwire.UpdateFailHTLC{},
		}
		if err := s.ForwardPackets(nil, fail); err != nil {
			t.Fatal(err)
		}

		select {
		case pkt := <-aliceLink.packets:
			err := aliceLink.completeCircuit(pkt)
			if err != nil {
				t.Fatalf("unable to remove circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to channelPoint")
		}

		if s.circuits.NumOpen() != wantOpen {
			t.Fatal("wrong amount of circuits")
		}
	}

	// Open two circuits for the same hash, then close them one by one.
	forwardAdd(0, 1)
	forwardAdd(1, 2)
	forwardFail(0, 1)
	forwardFail(1, 0)
}
// TestSwitchSendPayment tests the ability of the htlc switch to respond to
// the user once a response comes back from the channel link.
func TestSwitchSendPayment(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add link: %v", err)
	}

	// Create the HTLC that will be sent out over the alice channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	update := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}
	paymentID := uint64(123)

	// First check that the switch will correctly respond that this payment
	// ID is unknown.
	_, err = s.GetAttemptResult(
		paymentID, rhash, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	// Send the HTLC in the background and report the final outcome on
	// errChan. The goroutine sends exactly one value, so a buffer of one
	// lets it exit even when the test has already failed and nobody is
	// receiving any more.
	errChan := make(chan error, 1)
	go func() {
		err := s.SendHTLC(
			aliceChannelLink.ShortChanID(), paymentID, update,
		)
		if err != nil {
			errChan <- err
			return
		}

		resultChan, err := s.GetAttemptResult(
			paymentID, rhash, newMockDeobfuscator(),
		)
		if err != nil {
			errChan <- err
			return
		}

		result, ok := <-resultChan
		if !ok {
			// The channel was closed, so result is only a zero
			// value. Report the shutdown and stop here instead of
			// falling through to inspect it.
			errChan <- fmt.Errorf("shutting down")
			return
		}

		if result.Error != nil {
			errChan <- result.Error
			return
		}

		errChan <- nil
	}()

	// The alice channel link should receive the packet. An early value on
	// errChan means the send itself went wrong.
	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case err := <-errChan:
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Create a fail packet pretending that the remote side of the alice
	// channel link rejected the HTLC. The switch should hand the failure
	// back to the goroutine waiting on the attempt result.
	obfuscator := NewMockObfuscator()
	failure := lnwire.NewFailIncorrectDetails(update.Amount, 100)
	reason, err := obfuscator.EncryptFirstHop(failure)
	require.NoError(t, err, "unable obfuscate failure")

	if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) {
		t.Fatal("htlc should be identified as not forwarded")
	}

	packet := &htlcPacket{
		outgoingChanID: aliceChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: reason,
		},
	}
	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	select {
	case err := <-errChan:
		assertFailureCode(
			t, err, lnwire.CodeIncorrectOrUnknownPaymentDetails,
		)

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// TestLocalPaymentNoForwardingEvents checks that a payment initiated locally
// by Alice leaves nothing behind in her forwarding log.
func TestLocalPaymentNoForwardingEvents(t *testing.T) {
	t.Parallel()

	// Bring up the usual three hop network. Only the first end point is
	// exercised and inspected here.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	if startErr := n.start(); startErr != nil {
		t.Fatalf("unable to start three hop network: %v", startErr)
	}

	// Build a single hop payment from Alice to Bob.
	payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
	htlcAmt, totalTimelock, hops := generateHops(
		payAmt, testStartingHeight, n.firstBobChannelLink,
	)

	// Send it and block until the payment attempt has finished.
	_, err = makePayment(
		n.aliceServer, n.bobServer,
		n.firstBobChannelLink.ShortChanID(), hops, payAmt, htlcAmt,
		totalTimelock,
	).Wait(30 * time.Second)
	require.NoError(t, err, "unable to make the payment")

	// Stop the whole network before looking at the log, so that any
	// forwarding events still pending in the switches get flushed first.
	n.stop()

	fwdLog, isMock := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !isMock {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	fwdLog.Lock()
	defer fwdLog.Unlock()

	// A purely local payment must not have produced any events.
	if len(fwdLog.events) == 0 {
		return
	}
	t.Fatalf("log should have no events, instead has: %v",
		spew.Sdump(fwdLog.events))
}
// TestMultiHopPaymentForwardingEvents tests that when a series of multi-hop
// payments is sent via Alice->Bob->Carol, Bob properly logs forwarding
// events, while Alice and Carol don't.
func TestMultiHopPaymentForwardingEvents(t *testing.T) {
	t.Parallel()

	// First, we'll create our traditional three hop network.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// We'll now make 10 payments, of 100k satoshis each from Alice to
	// Carol via Bob.
	const numPayments = 10
	finalAmt := lnwire.NewMSatFromSatoshis(100000)
	htlcAmt, totalTimelock, hops := generateHops(
		finalAmt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)
	firstHop := n.firstBobChannelLink.ShortChanID()
	for i := 0; i < numPayments/2; i++ {
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}
	}

	bobLog, ok := n.bobServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	// After sending 5 of the payments, trigger the forwarding ticker, to
	// make sure the events are properly flushed.
	bobTicker, ok := n.bobServer.htlcSwitch.cfg.FwdEventTicker.(*ticker.Force)
	if !ok {
		t.Fatalf("mockTicker assertion failed")
	}

	// We'll trigger the ticker, and wait for the events to appear in Bob's
	// forwarding log.
	timeout := time.After(15 * time.Second)
	for {
		select {
		case bobTicker.Force <- time.Now():
		case <-time.After(1 * time.Second):
			t.Fatalf("unable to force tick")
		}

		// If all 5 events are found in Bob's log, we can break out
		// and continue the test.
		bobLog.Lock()
		if len(bobLog.events) == 5 {
			bobLog.Unlock()
			break
		}
		bobLog.Unlock()

		// Otherwise wait a little bit before checking again.
		select {
		case <-time.After(50 * time.Millisecond):
		case <-timeout:
			bobLog.Lock()
			defer bobLog.Unlock()
			t.Fatalf("expected 5 events in event log, instead "+
				"found: %v", spew.Sdump(bobLog.events))
		}
	}

	// Send the remaining payments.
	for i := numPayments / 2; i < numPayments; i++ {
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}
	}

	// With all 10 payments sent, we'll now manually stop each of the
	// switches so we can examine their end state.
	n.stop()

	// Alice and Carol shouldn't have any recorded forwarding events, as
	// they were the source and the sink for these payment flows.
	aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	aliceLog.Lock()
	defer aliceLog.Unlock()
	if len(aliceLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(aliceLog.events))
	}

	carolLog, ok := n.carolServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	carolLog.Lock()
	defer carolLog.Unlock()
	if len(carolLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(carolLog.events))
	}

	// Bob on the other hand, should have 10 events.
	bobLog.Lock()
	defer bobLog.Unlock()
	if len(bobLog.events) != 10 {
		t.Fatalf("log should have 10 events, instead has: %v",
			spew.Sdump(bobLog.events))
	}

	// Each of the 10 events should have had all fields set properly. In
	// every failure message below the expected value is printed first and
	// the value found in the event second.
	for _, event := range bobLog.events {
		// The incoming and outgoing channels should properly be set for
		// the event.
		if event.IncomingChanID != n.aliceChannelLink.ShortChanID() {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				n.aliceChannelLink.ShortChanID(),
				event.IncomingChanID)
		}
		if event.OutgoingChanID != n.carolChannelLink.ShortChanID() {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				n.carolChannelLink.ShortChanID(),
				event.OutgoingChanID)
		}

		// Additionally, the incoming and outgoing amounts should also
		// be properly set.
		if event.AmtIn != htlcAmt {
			t.Fatalf("incoming amt mismatch: expected %v, got %v",
				htlcAmt, event.AmtIn)
		}
		if event.AmtOut != finalAmt {
			t.Fatalf("outgoing amt mismatch: expected %v, got %v",
				finalAmt, event.AmtOut)
		}
	}
}
// TestUpdateFailMalformedHTLCErrorConversion tests that we're able to properly
// convert malformed HTLC errors that originate at the direct link, as well as
// during multi-hop HTLC forwarding.
func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) {
	t.Parallel()

	// First, we'll create our traditional three hop network.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// assertPaymentFailure sends a payment from Alice to Carol and
	// requires it to fail with a clear text FailInvalidOnionKey.
	assertPaymentFailure := func(t *testing.T) {
		t.Helper()

		finalAmt := lnwire.NewMSatFromSatoshis(100000)
		htlcAmt, totalTimelock, hops := generateHops(
			finalAmt, testStartingHeight, n.firstBobChannelLink,
			n.carolChannelLink,
		)
		firstHop := n.firstBobChannelLink.ShortChanID()

		// Keep the error local to this call so the two subtests do
		// not share state through the enclosing function's err.
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)

		// The payment should fail because a hop has been configured
		// to fail decoding the onion blob sent to it.
		if err == nil {
			t.Fatalf("expected payment to fail, but it succeeded")
		}

		// Check the assertion instead of panicking when the error is
		// of some other kind, e.g. a timeout.
		routingErr, ok := err.(ClearTextError)
		if !ok {
			t.Fatalf("expected ClearTextError, got: %T (%v)", err,
				err)
		}

		failureMsg := routingErr.WireMessage()
		if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok {
			t.Fatalf("expected onion failure instead got: %v",
				failureMsg)
		}
	}

	t.Run("multi-hop error conversion", func(t *testing.T) {
		// Now that we have our network up, we'll modify the hop
		// iterator for the Bob <-> Carol channel to fail to decode in
		// order to simulate either a replay attack or an issue
		// decoding the onion.
		n.carolOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})

	t.Run("direct channel error conversion", func(t *testing.T) {
		// Similar to the above test case, we'll now make the Alice <->
		// Bob link always fail to decode an onion. This differs from
		// the above test case in that there's no encryption on the
		// error at all since Alice will directly receive a
		// UpdateFailMalformedHTLC message.
		n.bobOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})
}
// TestSwitchGetAttemptResult exercises how GetAttemptResult combines the
// circuit map and the network result store when asked for the outcome of a
// payment ID, so that no result is lost when lookups and incoming results
// happen concurrently.
func TestSwitchGetAttemptResult(t *testing.T) {
	t.Parallel()

	const paymentID = 123

	preimg := lntypes.Preimage{3}

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	if startErr := s.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer s.Stop()

	// Swap in a mock circuit map whose lookup answers are fed through a
	// channel by the test.
	circuitLookup := make(chan *PaymentCircuit, 1)
	s.circuits = &mockCircuitMap{
		lookup: circuitLookup,
	}

	// No circuit and nothing in the store yet: the payment ID has to be
	// reported as unknown.
	circuitLookup <- nil
	_, err = s.GetAttemptResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	// With a circuit present the call should succeed and hand back a
	// channel on which the result is delivered once it exists.
	circuitLookup <- &PaymentCircuit{}
	resultChan, err := s.GetAttemptResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	require.NoError(t, err, "unable to get payment result")

	// assertSettled waits for one value on whatever resultChan currently
	// refers to and requires it to be an error-free result carrying
	// preimg.
	assertSettled := func() {
		select {
		case res, ok := <-resultChan:
			switch {
			case !ok:
				t.Fatalf("channel was closed")

			case res.Error != nil:
				t.Fatalf("got unexpected error result")

			case res.Preimage != preimg:
				t.Fatalf("expected preimg %v, got %v",
					preimg, res.Preimage)
			}

		case <-time.After(1 * time.Second):
			t.Fatalf("result not received")
		}
	}

	// Store a settle for the payment; the subscriber above should see it.
	settle := &networkResult{
		msg: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimg,
		},
		unencrypted:  true,
		isResolution: true,
	}
	err = s.networkResults.storeResult(paymentID, settle)
	require.NoError(t, err, "unable to store result")

	assertSettled()

	// Finally ask once more with the circuit gone. The result should now
	// come straight out of the store.
	circuitLookup <- nil
	resultChan, err = s.GetAttemptResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	require.NoError(t, err, "unable to get payment result")

	assertSettled()
}
// TestInvalidFailure checks what the switch reports for a failed payment when
// the failure message cannot be decrypted, and when it decrypts but cannot be
// decoded.
func TestInvalidFailure(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	if startErr := s.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	// A single mock link is all this test needs.
	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	if addErr := s.AddLink(aliceLink); addErr != nil {
		t.Fatalf("unable to add link: %v", addErr)
	}

	// Build the HTLC that goes out over the mock link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")

	rhash := sha256.Sum256(preimage[:])
	htlcAdd := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}

	var attemptID uint64 = 123

	err = s.SendHTLC(aliceLink.ShortChanID(), attemptID, htlcAdd)
	require.NoError(t, err, "unable to send payment")

	// Pick the packet up at the link and complete the circuit, so the
	// switch is able to accept a response for it.
	select {
	case pkt := <-aliceLink.packets:
		if err := aliceLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Answer with a failure. The reason bytes are arbitrary because the
	// decryption step is mocked below.
	failPkt := &htlcPacket{
		outgoingChanID: aliceLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: []byte{1, 2, 3},
		},
	}
	if err := s.ForwardPackets(nil, failPkt); err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	// First round: the decrypter itself fails, so the result must carry
	// the unreadable failure message error.
	decrypter := SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			err: ErrUnreadableFailureMessage,
		},
	}
	resultChan, err := s.GetAttemptResult(attemptID, rhash, &decrypter)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case result := <-resultChan:
		if result.Error != ErrUnreadableFailureMessage {
			t.Fatal("expected unreadable failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}

	// Second round: decryption succeeds and names a source index, but
	// the decrypted bytes do not decode into a failure message.
	decrypter = SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			sourceIdx: 2,
			message:   []byte{200},
		},
	}
	resultChan, err = s.GetAttemptResult(attemptID, rhash, &decrypter)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case result := <-resultChan:
		clearErr, isClear := result.Error.(ClearTextError)
		if !isClear {
			t.Fatal("expected ClearTextError")
		}

		fwdErr, isFwd := clearErr.(*ForwardingError)
		if !isFwd {
			t.Fatalf("expected forwarding error, got: %T", clearErr)
		}

		if fwdErr.FailureSourceIdx != 2 {
			t.Fatal("unexpected error source index")
		}

		if clearErr.WireMessage() != nil {
			t.Fatal("expected empty failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// htlcNotifierEvents is the signature of a function that produces the htlc
// notifier events a test expects, returned as three separate sets. All
// variable inputs (channels, htlc ID, timestamp, htlc, hops and preimage) are
// passed in rather than hard-coded, so that changes to external systems (such
// as our default timelock delta) do not break these tests.
type htlcNotifierEvents func(
	channels *clusterChannels, htlcID uint64, ts time.Time,
	htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
	preimage *lntypes.Preimage,
) ([]interface{}, []interface{}, []interface{})
// TestHtlcNotifier tests the notifying of htlc events that are routed over a
// three hop network. It sets up an Alice -> Bob -> Carol network and routes
// payments from Alice -> Carol to test events from the perspective of a
// sending (Alice), forwarding (Bob) and receiving (Carol) node. Test cases
// are present for saduccessful and failed payments.
func TestHtlcNotifier(t *testing.T) {
tests := []struct {
name string
// Options is a set of options to apply to the three hop
// network's servers.
options []serverOption
// expectedEvents is a function which returns an expected set
// of events for the test.
expectedEvents htlcNotifierEvents
// iterations is the number of times we will send a payment,
// this is used to send more than one payment to force non-
// zero htlc indexes to make sure we aren't just checking
// default values.
iterations int
}{
{
name: "successful three hop payment",
options: nil,
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload,
preimage *lntypes.Preimage) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops, nil, preimage,
)
},
iterations: 2,
},
{
name: "failed at forwarding link",
// Set a functional option which disables bob as a
// forwarding node to force a payment error.
options: []serverOption{
serverOptionRejectHtlc(false, true, false),
},
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload,
preimage *lntypes.Preimage) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops,
&LinkError{
msg: &lnwire.FailChannelDisabled{},
FailureDetail: OutgoingFailureForwardsDisabled,
},
preimage,
)
},
iterations: 1,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testHtcNotifier(
t, test.options, test.iterations,
test.expectedEvents,
)
})
}
}
// testHtcNotifier runs one htlc notifier scenario: it builds a three hop
// network whose servers use dedicated htlc notifiers, sends the requested
// number of payments and compares the events each notifier emits against the
// sets produced by getEvents.
func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
	getEvents htlcNotifierEvents) {

	t.Parallel()

	// Start from the usual three hop channel setup.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	// Pin the clock so every event carries the same timestamp.
	fixedNow := time.Now()
	clock := func() time.Time {
		return fixedNow
	}

	// One notifier per node, each started right away and stopped again
	// when the test finishes.
	aliceNotifier := NewHtlcNotifier(clock)
	if startErr := aliceNotifier.Start(); startErr != nil {
		t.Fatalf("could not start alice notifier")
	}
	t.Cleanup(func() {
		stopErr := aliceNotifier.Stop()
		if stopErr != nil {
			t.Fatalf("failed to stop alice notifier: %v", stopErr)
		}
	})

	bobNotifier := NewHtlcNotifier(clock)
	if startErr := bobNotifier.Start(); startErr != nil {
		t.Fatalf("could not start bob notifier")
	}
	t.Cleanup(func() {
		stopErr := bobNotifier.Stop()
		if stopErr != nil {
			t.Fatalf("failed to stop bob notifier: %v", stopErr)
		}
	})

	carolNotifier := NewHtlcNotifier(clock)
	if startErr := carolNotifier.Start(); startErr != nil {
		t.Fatalf("could not start carol notifier")
	}
	t.Cleanup(func() {
		stopErr := carolNotifier.Stop()
		if stopErr != nil {
			t.Fatalf("failed to stop carol notifier: %v", stopErr)
		}
	})

	// Hand the notifiers to the network through a server option, on top
	// of whatever options the test case asked for.
	withNotifiers := serverOptionWithHtlcNotifier(
		aliceNotifier, bobNotifier, carolNotifier,
	)
	netOpts := append(testOpts, withNotifiers) // nolint:gocritic

	network := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
		netOpts...,
	)
	if startErr := network.start(); startErr != nil {
		t.Fatalf("unable to start three hop network: %v", startErr)
	}
	t.Cleanup(network.stop)

	// Subscribe to every notifier before any payment is sent.
	aliceEvents, err := aliceNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to alice's events: %v", err)
	}
	t.Cleanup(aliceEvents.Cancel)

	bobEvents, err := bobNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to bob's events: %v", err)
	}
	t.Cleanup(bobEvents.Cancel)

	carolEvents, err := carolNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to carol's events: %v", err)
	}
	t.Cleanup(carolEvents.Cancel)

	// Send the requested number of payments. The loop counter doubles as
	// the htlc ID passed to getEvents, which is how incrementing htlc
	// ids get tested.
	for round := 0; round < iterations; round++ {
		// Pay Alice -> Bob -> Carol. The preimage comes from the
		// invoice registered with Carol and is what the settle
		// events are expected to carry.
		htlc, hops, preimage := network.sendThreeHopPayment(t)

		wantAlice, wantBob, wantCarol := getEvents(
			channels, uint64(round), fixedNow, htlc, hops,
			preimage,
		)

		checkHtlcEvents(t, aliceEvents.Updates(), wantAlice)
		checkHtlcEvents(t, bobEvents.Updates(), wantBob)
		checkHtlcEvents(t, carolEvents.Updates(), wantCarol)
	}
}
// checkHtlcEvents reads one event from the subscription channel for every
// entry in expectedEvents, in order, and fails the test when an event differs
// from its expectation or does not arrive within five seconds. Afterwards it
// fails the test if another event is already waiting on the channel.
func checkHtlcEvents(t *testing.T, events <-chan interface{},
	expectedEvents []interface{}) {

	t.Helper()

	for idx := 0; idx < len(expectedEvents); idx++ {
		want := expectedEvents[idx]

		select {
		case got := <-events:
			if reflect.DeepEqual(got, want) {
				continue
			}
			t.Fatalf("expected %v, got: %v", want, got)

		case <-time.After(5 * time.Second):
			t.Fatalf("expected event: %v", want)
		}
	}

	// Non-blocking peek: anything still queued is one event too many.
	select {
	case extra := <-events:
		t.Fatalf("unexpected event: %v", extra)
	default:
	}
}
// sendThreeHopPayment generates a payment routed over Alice -> Bob -> Carol,
// adds its invoice to Carol's registry and passes the htlc to Alice's switch.
// It returns that htlc, the generated hops and the preimage from the
// invoice's terms.
func (n *threeHopNetwork) sendThreeHopPayment(t *testing.T) (*lnwire.UpdateAddHTLC,
	[]*hop.Payload, *lntypes.Preimage) {

	payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)

	htlcAmt, totalTimelock, hops := generateHops(
		payAmt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)

	blob, err := generateRoute(hops...)
	if err != nil {
		t.Fatal(err)
	}

	invoice, htlc, pid, err := generatePayment(
		payAmt, htlcAmt, totalTimelock, blob,
	)
	if err != nil {
		t.Fatal(err)
	}

	// Carol has to know the invoice before the htlc is sent.
	err = n.carolServer.registry.AddInvoice(
		context.Background(), *invoice, htlc.PaymentHash,
	)
	require.NoError(t, err, "unable to add invoice in carol registry")

	sendErr := n.aliceServer.htlcSwitch.SendHTLC(
		n.firstBobChannelLink.ShortChanID(), pid, htlc,
	)
	if sendErr != nil {
		t.Fatalf("could not send htlc")
	}

	return htlc, hops, invoice.Terms.PaymentPreimage
}
// getThreeHopEvents builds the htlc events expected at Alice, Bob and Carol
// (returned in that order) for a payment from Alice -> Bob -> Carol. With a
// non-nil linkError the events describe a payment that fails on Bob's
// outgoing link, in which case Carol's set is nil. Otherwise they describe a
// settled payment and preimage must be non-nil.
func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
	ts time.Time, htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
	linkError *LinkError,
	preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{}) {

	// Alice has no incoming circuit, only an outgoing one towards Bob.
	aliceKey := HtlcKey{
		IncomingCircuit: zeroCircuit,
		OutgoingCircuit: models.CircuitKey{
			ChanID: channels.aliceToBob.ShortChanID(),
			HtlcID: htlcID,
		},
	}

	// Whatever the outcome, the send itself is always reported at Alice.
	aliceSend := &ForwardingEvent{
		HtlcKey: aliceKey,
		HtlcInfo: HtlcInfo{
			OutgoingTimeLock: htlc.Expiry,
			OutgoingAmt:      htlc.Amount,
		},
		HtlcEventType: HtlcEventTypeSend,
		Timestamp:     ts,
	}

	// Bob's circuits and amounts: in from Alice, out towards Carol.
	bobIncoming := models.CircuitKey{
		ChanID: channels.bobToAlice.ShortChanID(),
		HtlcID: htlcID,
	}
	bobKey := HtlcKey{
		IncomingCircuit: bobIncoming,
		OutgoingCircuit: models.CircuitKey{
			ChanID: channels.bobToCarol.ShortChanID(),
			HtlcID: htlcID,
		},
	}
	bobInfo := HtlcInfo{
		IncomingTimeLock: htlc.Expiry,
		IncomingAmt:      htlc.Amount,
		OutgoingTimeLock: hops[1].FwdInfo.OutgoingCTLV,
		OutgoingAmt:      hops[1].FwdInfo.AmountToForward,
	}

	// Failure case: Alice sees the forward fail, Bob reports the link
	// failure plus a final, unsettled htlc, and Carol sees nothing
	// because the payment never reaches her.
	if linkError != nil {
		aliceFailed := []interface{}{
			aliceSend,
			&ForwardingFailEvent{
				HtlcKey:       aliceKey,
				HtlcEventType: HtlcEventTypeSend,
				Timestamp:     ts,
			},
		}

		bobFailed := []interface{}{
			&LinkFailEvent{
				HtlcKey:       bobKey,
				HtlcInfo:      bobInfo,
				HtlcEventType: HtlcEventTypeForward,
				LinkError:     linkError,
				Incoming:      false,
				Timestamp:     ts,
			},
			&FinalHtlcEvent{
				CircuitKey: bobIncoming,
				Settled:    false,
				Offchain:   true,
				Timestamp:  ts,
			},
		}

		return aliceFailed, bobFailed, nil
	}

	// Success case: a settle at Alice; a forward, settle and final htlc
	// at Bob; a receive settle and final htlc at Carol.
	aliceSettled := []interface{}{
		aliceSend,
		&SettleEvent{
			HtlcKey:       aliceKey,
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeSend,
			Timestamp:     ts,
		},
	}

	bobSettled := []interface{}{
		&ForwardingEvent{
			HtlcKey:       bobKey,
			HtlcInfo:      bobInfo,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
		&SettleEvent{
			HtlcKey:       bobKey,
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
		&FinalHtlcEvent{
			CircuitKey: bobIncoming,
			Settled:    true,
			Offchain:   true,
			Timestamp:  ts,
		},
	}

	carolIncoming := models.CircuitKey{
		ChanID: channels.carolToBob.ShortChanID(),
		HtlcID: htlcID,
	}
	carolSettled := []interface{}{
		&SettleEvent{
			HtlcKey: HtlcKey{
				IncomingCircuit: carolIncoming,
				OutgoingCircuit: zeroCircuit,
			},
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeReceive,
			Timestamp:     ts,
		},
		&FinalHtlcEvent{
			CircuitKey: carolIncoming,
			Settled:    true,
			Offchain:   true,
			Timestamp:  ts,
		},
	}

	return aliceSettled, bobSettled, carolSettled
}
type mockForwardInterceptor struct {
t *testing.T
interceptedChan chan InterceptedPacket
}
// InterceptForwardHtlc hands the intercepted packet over to the test by
// sending it on interceptedChan. The send blocks until the channel accepts the
// packet. It always returns nil.
func (m *mockForwardInterceptor) InterceptForwardHtlc(
	intercepted InterceptedPacket) error {

	m.interceptedChan <- intercepted

	return nil
}
// getIntercepted returns the next packet that was offered to the interceptor.
// If no packet shows up within one second the test is failed with "timeout".
func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
	m.t.Helper()

	deadline := time.After(time.Second)

	select {
	case pkt := <-m.interceptedChan:
		return pkt

	case <-deadline:
	}

	// Only reached when the deadline fired before a packet arrived.
	require.Fail(m.t, "timeout")

	return InterceptedPacket{}
}
// assertNumCircuits fails the test unless the switch's circuit map currently
// reports exactly `pending` pending (half) circuits and `opened` open
// circuits.
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// Read each counter once so that the value that is compared is also
	// the value that is logged.
	if numPending := s.circuits.NumPending(); numPending != pending {
		t.Fatalf("wrong amount of half circuits, expected %v but "+
			"got %v", pending, numPending)
	}

	if numOpen := s.circuits.NumOpen(); numOpen != opened {
		t.Fatalf("wrong amount of circuits, expected %v but got %v",
			opened, numOpen)
	}
}
// assertOutgoingLinkReceive waits up to one second for a packet to arrive on
// targetLink.
//
// If expectReceive is true, the packet must arrive; its circuit is completed
// on targetLink and the packet is returned. If expectReceive is false, the
// arrival of any packet fails the test and nil is returned once the timeout
// elapses.
func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
	expectReceive bool) *htlcPacket {

	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// Pull packet from targetLink link.
	select {
	case packet := <-targetLink.packets:
		if !expectReceive {
			t.Fatal("forward was intercepted, shouldn't land at bob link")
		}

		if err := targetLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

		return packet

	case <-time.After(time.Second):
		if expectReceive {
			t.Fatal("request was not propagated to destination")
		}
	}

	return nil
}
// assertOutgoingLinkReceiveIntercepted drains a single packet from targetLink
// without completing its circuit. The test fails if no packet arrives within
// one second.
func assertOutgoingLinkReceiveIntercepted(t *testing.T,
	targetLink *mockChannelLink) {

	t.Helper()

	timeout := time.After(time.Second)

	select {
	case <-targetLink.packets:
		return

	case <-timeout:
		t.Fatal("request was not propagated to destination")
	}
}
// interceptableSwitchTestContext bundles the switch, the two mock links and
// the fixed payment parameters shared by the interceptable switch tests.
type interceptableSwitchTestContext struct {
// t is the owning test; finish uses it to report a failed shutdown.
t *testing.T
// preimage is the payment preimage used when settling test htlcs.
preimage [sha256.Size]byte
// rhash is the payment hash placed in packets built by createTestPacket.
rhash [32]byte
// onionBlob is the onion blob placed in packets built by createTestPacket.
onionBlob [1366]byte
// incomingHtlcID is incremented by createTestPacket for every new packet.
incomingHtlcID uint64
// cltvRejectDelta and cltvInterceptDelta are the deltas the tests pass to
// InterceptableSwitchConfig.
cltvRejectDelta uint32
cltvInterceptDelta uint32
// forwardInterceptor records the packets offered for interception.
forwardInterceptor *mockForwardInterceptor
// aliceChannelLink is the incoming link of packets built by
// createTestPacket, bobChannelLink the outgoing one.
aliceChannelLink *mockChannelLink
bobChannelLink *mockChannelLink
// s is the switch that both links are registered with.
s *Switch
}
// newInterceptableSwitchTestContext builds a started switch backed by a fresh
// channeldb in a temporary directory, registers a mock link for alice and one
// for bob with it, and returns a context holding them together with a mock
// interceptor and fixed payment parameters. The database is closed through
// t.Cleanup; the switch itself is stopped by calling finish.
func newInterceptableSwitchTestContext(
	t *testing.T) *interceptableSwitchTestContext {

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	// Back the switch with a throwaway database.
	cdb, err := channeldb.Open(t.TempDir())
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")

	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	aliceLink := newMockChannelLink(
		htlcSwitch, chanID1, aliceChanID, emptyScid, alicePeer, true,
		false, false, false,
	)
	bobLink := newMockChannelLink(
		htlcSwitch, chanID2, bobChanID, emptyScid, bobPeer, true,
		false, false, false,
	)

	err = htlcSwitch.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	err = htlcSwitch.AddLink(bobLink)
	if err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The interceptor hands packets over on an unbuffered channel.
	interceptor := &mockForwardInterceptor{
		t:               t,
		interceptedChan: make(chan InterceptedPacket),
	}

	preimage := [sha256.Size]byte{1}

	return &interceptableSwitchTestContext{
		t:                  t,
		preimage:           preimage,
		rhash:              sha256.Sum256(preimage[:]),
		onionBlob:          [1366]byte{4, 5, 6},
		incomingHtlcID:     uint64(0),
		cltvRejectDelta:    10,
		cltvInterceptDelta: 13,
		forwardInterceptor: interceptor,
		aliceChannelLink:   aliceLink,
		bobChannelLink:     bobLink,
		s:                  htlcSwitch,
	}
}
// createTestPacket returns a new add packet travelling from alice's link to
// bob's link. Each call bumps the incoming htlc id first, so every packet gets
// a distinct id. The incoming timeout is one block beyond
// testStartingHeight + cltvInterceptDelta.
func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket {
	c.incomingHtlcID++

	timeout := testStartingHeight + c.cltvInterceptDelta + 1

	add := &lnwire.UpdateAddHTLC{
		PaymentHash: c.rhash,
		Amount:      1,
		OnionBlob:   c.onionBlob,
	}

	return &htlcPacket{
		incomingChanID:  c.aliceChannelLink.ShortChanID(),
		incomingHTLCID:  c.incomingHtlcID,
		incomingTimeout: timeout,
		outgoingChanID:  c.bobChannelLink.ShortChanID(),
		obfuscator:      NewMockObfuscator(),
		htlc:            add,
	}
}
// finish stops the switch owned by the test context and fails the test if the
// switch reports an error while shutting down.
func (c *interceptableSwitchTestContext) finish() {
	c.t.Helper()

	if err := c.s.Stop(); err != nil {
		// Pass the error as an argument rather than as the format
		// string, so that a '%' in the error text is not interpreted
		// as a formatting verb.
		c.t.Fatalf("unable to stop switch: %v", err)
	}
}
// createSettlePacket returns a settle packet for the htlc with the given id on
// bob's link, carrying the test context's preimage.
func (c *interceptableSwitchTestContext) createSettlePacket(
	outgoingHTLCID uint64) *htlcPacket {

	settle := &lnwire.UpdateFulfillHTLC{
		PaymentPreimage: c.preimage,
	}

	return &htlcPacket{
		outgoingChanID: c.bobChannelLink.ShortChanID(),
		outgoingHTLCID: outgoingHTLCID,
		amount:         1,
		htlc:           settle,
	}
}
// TestSwitchHoldForward exercises the InterceptableSwitch end to end: forwards
// that expire too soon, and held forwards that are resumed, failed (by code,
// by opaque failure message and by malformed-htlc code) or settled by the
// interceptor. It finishes with the RequireInterceptor ("always-on") mode.
func TestSwitchHoldForward(t *testing.T) {
t.Parallel()
c := newInterceptableSwitchTestContext(t)
defer c.finish()
// Pre-load the notifier with one epoch at the starting height.
notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
switchForwardInterceptor, err := NewInterceptableSwitch(
&InterceptableSwitchConfig{
Switch: c.s,
CltvRejectDelta: c.cltvRejectDelta,
CltvInterceptDelta: c.cltvInterceptDelta,
Notifier: notifier,
},
)
require.NoError(t, err)
require.NoError(t, switchForwardInterceptor.Start())
switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc)
linkQuit := make(chan struct{})
// Test a forward that expires too soon.
packet := c.createTestPacket()
packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
require.NoError(t, err, "can't forward htlc packet")
// Nothing may reach Bob. Alice's link is expected to receive a packet
// in response (presumably the failure); it is only drained here, its
// circuit is not completed.
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink)
assertNumCircuits(t, c.s, 0, 0)
// Test a forward that expires too soon and can't be failed.
packet = c.createTestPacket()
packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
// Simulate an error during the composition of the failure message.
currentCallback := c.s.cfg.FetchLastChannelUpdate
c.s.cfg.FetchLastChannelUpdate = func(
lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, errors.New("cannot fetch update")
}
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
require.NoError(t, err, "can't forward htlc packet")
// Since the failure can't be composed, the packet is expected to be
// forwarded on to Bob instead.
receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false,
c.createSettlePacket(receivedPkt.outgoingHTLCID),
))
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Restore the original callback for the remaining sub-tests.
c.s.cfg.FetchLastChannelUpdate = currentCallback
// Test resume a hold forward.
assertNumCircuits(t, c.s, 0, 0)
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
)
require.NoError(t, err)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionResume,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
}))
receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
// Settle the htlc to close the circuit.
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false,
c.createSettlePacket(receivedPkt.outgoingHTLCID),
)
require.NoError(t, err)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test resume a hold forward after disconnection.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
))
// Wait until the packet is offered to the interceptor.
_ = c.forwardInterceptor.getIntercepted()
// No forward expected yet.
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
// Disconnect should resume the forwarding.
switchForwardInterceptor.SetInterceptor(nil)
receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
// Settle the htlc to close the circuit.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false,
c.createSettlePacket(receivedPkt.outgoingHTLCID),
))
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward
switchForwardInterceptor.SetInterceptor(
c.forwardInterceptor.InterceptForwardHtlc,
)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
))
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: lnwire.CodeTemporaryChannelFailure,
}))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward with a failure message.
require.NoError(t,
switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
),
)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureMessage: reason,
}))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
// The opaque reason must reach Alice unchanged.
require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward with a malformed htlc failure.
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
)
require.NoError(t, err)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
code := lnwire.CodeInvalidOnionKey
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: code,
}))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
failPacket := packet.htlc.(*lnwire.UpdateFailHTLC)
// The decrypted failure is expected to carry the hash of the onion blob
// of the original add.
shaOnionBlob := sha256.Sum256(c.onionBlob[:])
expectedFailure := &lnwire.FailInvalidOnionKey{
OnionSHA256: shaOnionBlob,
}
fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason)
require.NoError(t, err)
require.Equal(t, expectedFailure, fwdErr.WireMessage())
assertNumCircuits(t, c.s, 0, 0)
// Test settling a hold forward
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
))
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
Action: FwdActionSettle,
Preimage: c.preimage,
}))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
// Test always-on interception.
notifier = &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
switchForwardInterceptor, err = NewInterceptableSwitch(
&InterceptableSwitchConfig{
Switch: c.s,
CltvRejectDelta: c.cltvRejectDelta,
CltvInterceptDelta: c.cltvInterceptDelta,
RequireInterceptor: true,
Notifier: notifier,
},
)
require.NoError(t, err)
require.NoError(t, switchForwardInterceptor.Start())
// Forward a fresh packet. It is expected to be failed immediately,
// because there is no interceptor registered.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, c.createTestPacket(),
))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Forward a replayed packet. It is expected to be held until the
// interceptor connects. To continue the test, it needs to be run in a
// goroutine.
errChan := make(chan error)
go func() {
errChan <- switchForwardInterceptor.ForwardPackets(
linkQuit, true, c.createTestPacket(),
)
}()
// Assert that nothing is forwarded to the switch.
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
// Register an interceptor.
switchForwardInterceptor.SetInterceptor(
c.forwardInterceptor.InterceptForwardHtlc,
)
// Expect the ForwardPackets call to unblock.
require.NoError(t, <-errChan)
// Now expect the queued packet to come through.
c.forwardInterceptor.getIntercepted()
// Disconnect and reconnect interceptor.
switchForwardInterceptor.SetInterceptor(nil)
switchForwardInterceptor.SetInterceptor(
c.forwardInterceptor.InterceptForwardHtlc,
)
// A replay of the held packet is expected.
intercepted := c.forwardInterceptor.getIntercepted()
// Settle the packet.
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Key: intercepted.IncomingCircuit,
Action: FwdActionSettle,
Preimage: c.preimage,
}))
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
// Non-blocking check that no further packet is waiting to be handed to
// the interceptor.
select {
case <-c.forwardInterceptor.interceptedChan:
require.Fail(t, "unexpected interception")
default:
}
}
// TestInterceptableSwitchWatchDog tests that a held htlc is failed back once
// the chain reaches the auto-fail height reported to the interceptor, and that
// a resolution arriving after that point is rejected.
func TestInterceptableSwitchWatchDog(t *testing.T) {
	t.Parallel()

	c := newInterceptableSwitchTestContext(t)
	defer c.finish()

	// Start interceptable switch.
	notifier := &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	switchForwardInterceptor, err := NewInterceptableSwitch(
		&InterceptableSwitchConfig{
			Switch:             c.s,
			CltvRejectDelta:    c.cltvRejectDelta,
			CltvInterceptDelta: c.cltvInterceptDelta,
			Notifier:           notifier,
		},
	)
	require.NoError(t, err)
	require.NoError(t, switchForwardInterceptor.Start())

	// Stop the interceptable switch again when the test ends so that it
	// does not outlive the test. Deferred calls run in LIFO order, so this
	// executes before c.finish() stops the underlying switch.
	defer func() {
		if err := switchForwardInterceptor.Stop(); err != nil {
			t.Errorf("unable to stop interceptable switch: %v", err)
		}
	}()

	// Set interceptor.
	switchForwardInterceptor.SetInterceptor(
		c.forwardInterceptor.InterceptForwardHtlc,
	)

	// Receive a packet.
	linkQuit := make(chan struct{})

	packet := c.createTestPacket()

	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
	require.NoError(t, err, "can't forward htlc packet")

	// Intercept the packet. The advertised auto-fail height must be the
	// incoming timeout minus the reject delta.
	intercepted := c.forwardInterceptor.getIntercepted()

	require.Equal(t,
		int32(packet.incomingTimeout-c.cltvRejectDelta),
		intercepted.AutoFailHeight,
	)

	// Htlc expires before a resolution from the interceptor.
	notifier.EpochChan <- &chainntnfs.BlockEpoch{
		Height: int32(packet.incomingTimeout) -
			int32(c.cltvRejectDelta),
	}

	// Expect the htlc to be failed back.
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)

	// It is too late now to resolve. Expect an error.
	require.Error(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action:   FwdActionSettle,
		Key:      intercepted.IncomingCircuit,
		Preimage: c.preimage,
	}))
}
// TestSwitchDustForwarding tests that the switch properly fails HTLC's which
// have incoming or outgoing links that breach their dust thresholds.
func TestSwitchDustForwarding(t *testing.T) {
t.Parallel()
// We'll create a three-hop network:
// - Alice has a dust limit of 200sats with Bob
// - Bob has a dust limit of 800sats with Alice
// - Bob has a dust limit of 200sats with Carol
// - Carol has a dust limit of 800sats with Bob
channels, _, err := createClusterChannels(
t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin,
)
require.NoError(t, err)
n := newThreeHopNetwork(
t, channels.aliceToBob, channels.bobToAlice,
channels.bobToCarol, channels.carolToBob, testStartingHeight,
)
err = n.start()
require.NoError(t, err)
// We'll also put Alice and Bob into hodl.ExitSettle mode, such that
// they won't settle incoming exit-hop HTLC's automatically.
n.aliceChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()
n.firstBobChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()
// We'll test that once the default threshold is exceeded on the
// Alice -> Bob channel, either side's calls to SendHTLC will fail.
//
// Alice will send 357 HTLC's of 700sats. Bob will also send 357 HTLC's
// of 700sats. If either side attempts to send a dust HTLC, it will
// fail so amounts below 800sats will breach the dust threshold.
amt := lnwire.NewMSatFromSatoshis(700)
aliceBobFirstHop := n.aliceChannelLink.ShortChanID()
sendDustHtlcs(t, n, true, amt, aliceBobFirstHop)
sendDustHtlcs(t, n, false, amt, aliceBobFirstHop)
// Generate the parameters needed for Bob to send another dust HTLC.
_, timelock, hops := generateHops(
amt, testStartingHeight, n.aliceChannelLink,
)
blob, err := generateRoute(hops...)
require.NoError(t, err)
// Assert that if Bob sends a dust HTLC it will fail.
failingPreimage := lntypes.Preimage{0, 0, 3}
failingHash := failingPreimage.Hash()
failingHtlc := &lnwire.UpdateAddHTLC{
PaymentHash: failingHash,
Amount: amt,
Expiry: timelock,
OnionBlob: blob,
}
// checkAlmostDust polls every 300ms, for at most 15 seconds, until the
// link's dust sum plus the matching mailbox dust (remote or local,
// selected by the remote flag) equals the dust of all 357*2 HTLC's sent
// above. It returns true once that total is reached and false if the
// timeout fires first.
checkAlmostDust := func(link *channelLink, mbox MailBox,
remote bool) bool {
timeout := time.After(15 * time.Second)
pollInterval := 300 * time.Millisecond
expectedDust := 357 * 2 * amt
for {
<-time.After(pollInterval)
// Non-blocking check of the overall deadline.
select {
case <-timeout:
return false
default:
}
linkDust := link.getDustSum(remote)
localMailDust, remoteMailDust := mbox.DustPackets()
totalDust := linkDust
if remote {
totalDust += remoteMailDust
} else {
totalDust += localMailDust
}
// This break is outside the select above, so it leaves the
// polling loop.
if totalDust == expectedDust {
break
}
}
return true
}
// Wait until Bob is almost at the dust threshold.
bobMbox := n.bobServer.htlcSwitch.mailOrchestrator.GetOrCreateMailBox(
n.firstBobChannelLink.ChanID(),
n.firstBobChannelLink.ShortChanID(),
)
require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false))
// Assert that the HTLC is failed due to the dust threshold.
err = n.bobServer.htlcSwitch.SendHTLC(
aliceBobFirstHop, uint64(357), failingHtlc,
)
require.ErrorIs(t, err, errDustThresholdExceeded)
// Generate the parameters needed for bob to send a non-dust HTLC.
nondustAmt := lnwire.NewMSatFromSatoshis(10_000)
_, _, hops = generateHops(
nondustAmt, testStartingHeight, n.aliceChannelLink,
)
blob, err = generateRoute(hops...)
require.NoError(t, err)
// Now attempt to send an HTLC above Bob's dust limit. It should
// succeed.
nondustPreimage := lntypes.Preimage{0, 0, 4}
nondustHash := nondustPreimage.Hash()
nondustHtlc := &lnwire.UpdateAddHTLC{
PaymentHash: nondustHash,
Amount: nondustAmt,
Expiry: timelock,
OnionBlob: blob,
}
// Assert that SendHTLC succeeds and evaluateDustThreshold returns
// false.
err = n.bobServer.htlcSwitch.SendHTLC(
aliceBobFirstHop, uint64(358), nondustHtlc,
)
require.NoError(t, err)
// Introduce Carol into the mix and assert that sending a multi-hop
// dust HTLC to Alice will fail. Bob should fail back the HTLC with a
// temporary channel failure.
carolAmt, carolTimelock, carolHops := generateHops(
amt, testStartingHeight, n.secondBobChannelLink,
n.aliceChannelLink,
)
carolBlob, err := generateRoute(carolHops...)
require.NoError(t, err)
carolPreimage := lntypes.Preimage{0, 0, 5}
carolHash := carolPreimage.Hash()
carolHtlc := &lnwire.UpdateAddHTLC{
PaymentHash: carolHash,
Amount: carolAmt,
Expiry: carolTimelock,
OnionBlob: carolBlob,
}
// Initialize Carol's attempt ID.
carolAttemptID := 0
err = n.carolServer.htlcSwitch.SendHTLC(
n.carolChannelLink.ShortChanID(), uint64(carolAttemptID),
carolHtlc,
)
require.NoError(t, err)
carolAttemptID++
// The result is looked up under the attempt ID used for the send above,
// i.e. the value before the increment.
carolResultChan, err := n.carolServer.htlcSwitch.GetAttemptResult(
uint64(carolAttemptID-1), carolHash, newMockDeobfuscator(),
)
require.NoError(t, err)
result, ok := <-carolResultChan
require.True(t, ok)
assertFailureCode(
t, result.Error, lnwire.CodeTemporaryChannelFailure,
)
// Send an HTLC from Alice to Carol and assert that it is failed at the
// call to SendHTLC.
htlcAmt, totalTimelock, aliceHops := generateHops(
amt, testStartingHeight, n.firstBobChannelLink,
n.carolChannelLink,
)
blob, err = generateRoute(aliceHops...)
require.NoError(t, err)
aliceMultihopPreimage := lntypes.Preimage{0, 0, 6}
aliceMultihopHash := aliceMultihopPreimage.Hash()
aliceMultihopHtlc := &lnwire.UpdateAddHTLC{
PaymentHash: aliceMultihopHash,
Amount: htlcAmt,
Expiry: totalTimelock,
OnionBlob: blob,
}
// Wait until Alice's expected dust for the remote commitment is just
// under the dust threshold.
aliceOrch := n.aliceServer.htlcSwitch.mailOrchestrator
aliceMbox := aliceOrch.GetOrCreateMailBox(
n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(),
)
require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true))
err = n.aliceServer.htlcSwitch.SendHTLC(
n.aliceChannelLink.ShortChanID(), uint64(357),
aliceMultihopHtlc,
)
require.ErrorIs(t, err, errDustThresholdExceeded)
}
// sendDustHtlcs sends a fixed batch of 357 HTLC's of the given amount over the
// channel identified by sid, to exercise the Switch's dust-threshold logic.
// The alice flag selects the sender: when true Alice's switch sends to Bob's
// link, otherwise Bob's switch sends to Alice's link.
func sendDustHtlcs(t *testing.T, n *threeHopNetwork, alice bool,
	amt lnwire.MilliSatoshi, sid lnwire.ShortChannelID) {

	t.Helper()

	// The number of dust HTLC's sent by each of Alice and Bob.
	const numHTLCs = 357

	// Everything that depends on the sender is chosen up front. The
	// defaults describe Bob sending to Alice. The third preimage byte
	// differs per sender so that the two batches never share a hash.
	var (
		destLink      = n.aliceChannelLink
		sendingSwitch = n.bobServer.htlcSwitch
		endByte       = byte(2)
	)
	if alice {
		destLink = n.firstBobChannelLink
		sendingSwitch = n.aliceServer.htlcSwitch
		endByte = byte(3)
	}

	// Build the single-hop onion payload shared by every HTLC of the
	// batch.
	htlcAmt, totalTimelock, hops := generateHops(
		amt, testStartingHeight, destLink,
	)

	blob, err := generateRoute(hops...)
	require.NoError(t, err)

	for i := 0; i < numHTLCs; i++ {
		// Deterministic preimage derived from the index. The non-zero
		// end byte keeps it away from the all-zeroes preimage, which
		// the database rejects.
		preimage := lntypes.Preimage{byte(i >> 8), byte(i), endByte}

		htlc := &lnwire.UpdateAddHTLC{
			PaymentHash: preimage.Hash(),
			Amount:      htlcAmt,
			Expiry:      totalTimelock,
			OnionBlob:   blob,
		}

		// The attempt ID simply follows the HTLC index.
		attemptID := uint64(i)

		// The dust threshold may be hit before all 357*2 HTLC's are
		// sent due to double counting, so keep retrying the same
		// attempt until it is accepted.
		for {
			err := sendingSwitch.SendHTLC(sid, attemptID, htlc)
			if err == nil {
				break
			}
		}
	}
}
// TestSwitchMailboxDust checks that dust sitting in a link's mailbox counts
// towards the dust threshold, both for locally initiated sends (SendHTLC) and
// for forwards (ForwardPackets). The mockChannelLink does not have channel
// state, so only the switch-mailbox interaction is covered.
func TestSwitchMailboxDust(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	carolPeer, err := newMockServer(
		t, "carol", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)

	require.NoError(t, s.Start())
	defer func() {
		_ = s.Stop()
	}()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
	chanID3, carolChanID := genID()

	// Register one link per peer with the switch.
	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	require.NoError(t, s.AddLink(aliceLink))

	bobLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(t, s.AddLink(bobLink))

	carolLink := newMockChannelLink(
		s, chanID3, carolChanID, emptyScid, carolPeer, true, false,
		false, false,
	)
	require.NoError(t, s.AddLink(carolLink))

	// mockChannelLink sets the local and remote dust limits of the mailbox
	// to 400 satoshis and the feerate to 0, so every 350sat HTLC below
	// counts as dust.
	preimage, err := genPreimage()
	require.NoError(t, err)

	var (
		rhash  = sha256.Sum256(preimage[:])
		amt    = lnwire.NewMSatFromSatoshis(350)
		addMsg = &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      amt,
			ChanID:      chanID1,
		}
	)

	// aliceCount HTLC's of 350sats fill Alice's mailbox to the point where
	// one more would put Alice over the dust threshold.
	const aliceCount = 1428

	mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID1, aliceChanID)

	// Each packet arrives from Carol under its own incoming HTLC ID.
	for carolHTLCID := uint64(0); carolHTLCID < aliceCount; carolHTLCID++ {
		err = mailbox.AddPacket(&htlcPacket{
			incomingChanID: carolChanID,
			incomingHTLCID: carolHTLCID,
			outgoingChanID: aliceChanID,
			obfuscator:     NewMockObfuscator(),
			incomingAmount: amt,
			amount:         amt,
			htlc:           addMsg,
		})
		require.NoError(t, err)
	}

	// Sending one more HTLC to Alice should result in the dust threshold
	// being breached.
	require.ErrorIs(
		t, s.SendHTLC(aliceChanID, 0, addMsg),
		errDustThresholdExceeded,
	)

	// Repeat the check through ForwardPackets, coming from Bob, to ensure
	// that the mailbox sum is also accounted for in the forwarding case.
	fwdPacket := &htlcPacket{
		incomingChanID: bobChanID,
		incomingHTLCID: 0,
		outgoingChanID: aliceChanID,
		obfuscator:     NewMockObfuscator(),
		incomingAmount: amt,
		amount:         amt,
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      amt,
			ChanID:      chanID1,
		},
	}
	require.NoError(t, s.ForwardPackets(nil, fwdPacket))

	// Bob should receive a failure from the switch.
	select {
	case failPkt := <-bobLink.packets:
		require.NotEmpty(t, failPkt.linkFailure)
		assertFailureCode(
			t, failPkt.linkFailure,
			lnwire.CodeTemporaryChannelFailure,
		)

	case <-time.After(5 * time.Second):
		t.Fatal("no timely reply from switch")
	}
}
// TestSwitchResolution checks the ability of the switch to persist and handle
// resolution messages.
func TestSwitchResolution(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err)
s, err := initSwitchWithTempDB(t, testStartingHeight)
require.NoError(t, err)
// Even though we intend to Stop s later in the test, it is safe to
// register this Stop as a cleanup since its execution is protected by
// an atomic guard, guaranteeing it executes at most once.
t.Cleanup(func() { var _ = s.Stop() })
err = s.Start()
require.NoError(t, err)
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
false, false,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
false,
)
err = s.AddLink(aliceChannelLink)
require.NoError(t, err)
err = s.AddLink(bobChannelLink)
require.NoError(t, err)
// Create an add htlcPacket that Alice will send to Bob.
preimage, err := genPreimage()
require.NoError(t, err)
rhash := sha256.Sum256(preimage[:])
packet := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
err = s.ForwardPackets(nil, packet)
require.NoError(t, err)
// Bob will receive the packet and open the circuit.
select {
case <-bobChannelLink.packets:
err = bobChannelLink.completeCircuit(packet)
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Check that only one circuit is open.
require.Equal(t, 1, s.circuits.NumOpen())
// We'll send a settle resolution to Switch that should go to Alice.
settleResMsg := contractcourt.ResolutionMsg{
SourceChan: bobChanID,
HtlcIndex: 0,
PreImage: &preimage,
}
// Before the resolution is sent, remove alice's link so we can assert
// that the resolution is actually stored. Otherwise, it would be
// deleted shortly after being sent.
s.RemoveLink(chanID1)
// Send the resolution message.
err = s.ProcessContractResolution(settleResMsg)
require.NoError(t, err)
// Assert that the resolution store contains the settle resolution.
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
require.NoError(t, err)
require.Equal(t, 1, len(resMsgs))
require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
require.Nil(t, resMsgs[0].Failure)
require.Equal(t, preimage, *resMsgs[0].PreImage)
// Now we'll re-add Alice's link so the stored resolution can be
// delivered.
err = s.AddLink(aliceChannelLink)
require.NoError(t, err)
// Alice should receive the resolution as a packet. Completing it is
// expected to tear the circuit down, which the NumOpen assertion below
// checks.
select {
case alicePkt := <-aliceChannelLink.packets:
err = aliceChannelLink.completeCircuit(alicePkt)
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Assert that there are no more circuits.
require.Equal(t, 0, s.circuits.NumOpen())
// We'll restart the Switch and assert that Alice does not receive
// another packet.
switchDB := s.cfg.DB.(*channeldb.DB)
err = s.Stop()
require.NoError(t, err)
// From here on s refers to a new Switch instance backed by the same
// database.
s, err = initSwitchWithDB(testStartingHeight, switchDB)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
defer func() {
_ = s.Stop()
}()
err = s.AddLink(aliceChannelLink)
require.NoError(t, err)
err = s.AddLink(bobChannelLink)
require.NoError(t, err)
// Alice should not receive a packet since the Switch should have
// deleted the resolution message since the circuit was closed.
select {
case alicePkt := <-aliceChannelLink.packets:
t.Fatalf("received erroneous packet: %v", alicePkt)
case <-time.After(time.Second * 5):
}
// Check that the resolution message no longer exists in the store.
resMsgs, err = s.resMsgStore.fetchAllResolutionMsg()
require.NoError(t, err)
require.Equal(t, 0, len(resMsgs))
}
// TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure
// before actually forwarding, the ChannelUpdate uses the SCID from the
// incoming channel and does not leak private information like the UTXO. The
// scenario is run once for an option-scid-alias channel and once for a
// zero-conf channel.
func TestSwitchForwardFailAlias(t *testing.T) {
	type testCase struct {
		name string

		// Whether or not Alice will be a zero-conf channel or an
		// option-scid-alias channel (feature-bit).
		zeroConf bool
	}

	testCases := []testCase{
		{name: "option-scid-alias forwarding failure", zeroConf: false},
		{name: "zero-conf forwarding failure", zeroConf: true},
	}

	for _, tc := range testCases {
		// Pin the loop variable for the closure below.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			testSwitchForwardFailAlias(t, tc.zeroConf)
		})
	}
}
// testSwitchForwardFailAlias forwards an HTLC from Alice to Bob, restarts the
// switch on top of the same database, forwards the very same packet again and
// asserts that the failure delivered to Alice carries her alias SCID in the
// channel update. zeroConf selects whether Alice's link is set up as a
// zero-conf channel or as an option-scid-alias channel.
func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
require.NoError(t, err)
// The database path is kept around so the database can be re-opened for
// the second switch below.
tempPath := t.TempDir()
cdb, err := channeldb.Open(tempPath)
require.NoError(t, err)
t.Cleanup(func() { cdb.Close() })
s, err := initSwitchWithDB(testStartingHeight, cdb)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
// Make Alice's channel zero-conf or option-scid-alias (feature bit).
aliceAlias := lnwire.ShortChannelID{
BlockHeight: 16_000_000,
TxIndex: 5,
TxPosition: 5,
}
var aliceLink *mockChannelLink
if zeroConf {
aliceLink = newMockChannelLink(
s, chanID1, aliceAlias, aliceChanID, alicePeer, true,
true, true, false,
)
} else {
aliceLink = newMockChannelLink(
s, chanID1, aliceChanID, emptyScid, alicePeer, true,
true, false, true,
)
aliceLink.addAlias(aliceAlias)
}
err = s.AddLink(aliceLink)
require.NoError(t, err)
bobLink := newMockChannelLink(
s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
false,
)
err = s.AddLink(bobLink)
require.NoError(t, err)
// Create a packet that will be sent from Alice to Bob via the switch.
preimage := [sha256.Size]byte{1}
rhash := sha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
// Forward the packet and check that Bob's channel link received it.
err = s.ForwardPackets(nil, ogPacket)
require.NoError(t, err)
// Assert that the circuits are in the expected state.
require.Equal(t, 1, s.circuits.NumPending())
require.Equal(t, 0, s.circuits.NumOpen())
// Pull packet from Bob's link, and do nothing with it.
// NOTE(review): this select has no timeout case; it only returns once a
// packet arrives or the switch quits.
select {
case <-bobLink.packets:
case <-s.quit:
t.Fatal("switch shutting down, failed to forward packet")
}
// Now we will restart the Switch to trigger the LoadedFromDisk logic.
err = s.Stop()
require.NoError(t, err)
err = cdb.Close()
require.NoError(t, err)
cdb2, err := channeldb.Open(tempPath)
require.NoError(t, err)
t.Cleanup(func() { cdb2.Close() })
s2, err := initSwitchWithDB(testStartingHeight, cdb2)
require.NoError(t, err)
err = s2.Start()
require.NoError(t, err)
defer func() {
_ = s2.Stop()
}()
// Recreate both links with the same parameters as before, this time
// attached to the restarted switch.
var aliceLink2 *mockChannelLink
if zeroConf {
aliceLink2 = newMockChannelLink(
s2, chanID1, aliceAlias, aliceChanID, alicePeer, true,
true, true, false,
)
} else {
aliceLink2 = newMockChannelLink(
s2, chanID1, aliceChanID, emptyScid, alicePeer, true,
true, false, true,
)
aliceLink2.addAlias(aliceAlias)
}
err = s2.AddLink(aliceLink2)
require.NoError(t, err)
bobLink2 := newMockChannelLink(
s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
false,
)
err = s2.AddLink(bobLink2)
require.NoError(t, err)
// Reforward the ogPacket and wait for Alice to receive a failure
// packet.
err = s2.ForwardPackets(nil, ogPacket)
require.NoError(t, err)
// NOTE(review): as above, this select has no timeout case.
select {
case failPacket := <-aliceLink2.packets:
// Assert that the failPacket does not leak UTXO information.
// This means checking that aliceChanID was not returned.
msg := failPacket.linkFailure.msg
failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
require.True(t, ok)
require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID)
case <-s2.quit:
t.Fatal("switch shutting down, failed to forward packet")
}
}
// TestSwitchAliasFailAdd tests that the mailbox does not leak UTXO information
// when failing back an HTLC due to the 5-second timeout. This is tested in the
// switch rather than the mailbox because the mailbox tests do not have the
// proper context (e.g. the Switch's failAliasUpdate function). The caveat here
// is that if the private UTXO is already known, it is fine to send a failure
// back. This tests option-scid-alias (feature-bit) and zero-conf channels.
//
// Each table entry is run as its own subtest via testSwitchAliasFailAdd. The
// "using alias" entries address the outgoing channel by one of its aliases,
// while the "using real" entries address it by its real SCID.
func TestSwitchAliasFailAdd(t *testing.T) {
	tests := []struct {
		name string

		// Denotes whether the opened channel will be zero-conf.
		zeroConf bool

		// Denotes whether the opened channel will be private.
		private bool

		// Denotes whether an alias was used during forwarding.
		useAlias bool
	}{
		{
			name:     "public zero-conf using alias",
			zeroConf: true,
			private:  false,
			useAlias: true,
		},
		{
			// The real SCID is used for forwarding here, so
			// useAlias must be false; otherwise this entry would
			// merely repeat the alias case above.
			name:     "public zero-conf using real",
			zeroConf: true,
			private:  false,
			useAlias: false,
		},
		{
			name:     "private zero-conf using alias",
			zeroConf: true,
			private:  true,
			useAlias: true,
		},
		{
			name:     "public option-scid-alias using alias",
			zeroConf: false,
			private:  false,
			useAlias: true,
		},
		{
			name:     "public option-scid-alias using real",
			zeroConf: false,
			private:  false,
			useAlias: false,
		},
		{
			name:     "private option-scid-alias using alias",
			zeroConf: false,
			private:  true,
			useAlias: true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			testSwitchAliasFailAdd(
				t, test.zeroConf, test.private, test.useAlias,
			)
		})
	}
}
// testSwitchAliasFailAdd forwards a packet from Bob's link towards Alice's
// link with the mail orchestrator's expiry lowered to one second, and then
// waits for the failure to arrive on Bob's link. The failure has to be a
// FailTemporaryChannelFailure whose ChannelUpdate references exactly the SCID
// the packet was addressed to.
//
// zeroConf selects a zero-conf channel for Alice instead of an
// option-scid-alias (feature bit) one, private is passed through to Alice's
// mock link, and useAlias decides whether the packet is addressed to one of
// Alice's aliases or to aliceChanID.
func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	db, err := channeldb.Open(t.TempDir())
	require.NoError(t, err)
	defer db.Close()

	sw, err := initSwitchWithDB(testStartingHeight, db)
	require.NoError(t, err)

	// Shorten the mailOrchestrator's expiry to a second before the switch
	// is started.
	sw.mailOrchestrator.cfg.expiry = time.Second

	require.NoError(t, sw.Start())
	defer func() {
		_ = sw.Stop()
	}()

	// Two aliases for Alice's channel which differ only in TxPosition.
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	aliceAlias2 := aliceAlias
	aliceAlias2.TxPosition = 6

	// Alice's channel is either zero-conf or option-scid-alias (feature
	// bit).
	var aliceLink *mockChannelLink
	switch {
	case zeroConf:
		aliceLink = newMockChannelLink(
			sw, chanID1, aliceAlias, aliceChanID, alicePeer, true,
			private, true, false,
		)
		aliceLink.addAlias(aliceAlias2)

	default:
		aliceLink = newMockChannelLink(
			sw, chanID1, aliceChanID, emptyScid, alicePeer, true,
			private, false, true,
		)
		aliceLink.addAlias(aliceAlias)
		aliceLink.addAlias(aliceAlias2)
	}
	require.NoError(t, sw.AddLink(aliceLink))

	bobLink := newMockChannelLink(
		sw, chanID2, bobChanID, emptyScid, bobPeer, true, true, false,
		false,
	)
	require.NoError(t, sw.AddLink(bobLink))

	// Assemble the packet Bob sends to Alice via ForwardPackets. The
	// outgoing channel is filled in below.
	var preimage [sha256.Size]byte
	preimage[0] = 1
	rhash := sha256.Sum256(preimage[:])

	ogPacket := &htlcPacket{
		incomingChanID: bobLink.ShortChanID(),
		incomingHTLCID: 0,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Address the packet either to aliceChanID or, when useAlias is set,
	// to a randomly picked entry of the link's alias list.
	outgoingChanID := aliceChanID
	if useAlias {
		candidates := aliceLink.getAliases()
		outgoingChanID = candidates[mrand.Intn(len(candidates))]
	}
	ogPacket.outgoingChanID = outgoingChanID

	// Forward the packet so Alice's mailbox fails it backwards.
	require.NoError(t, sw.ForwardPackets(nil, ogPacket))

	// One circuit is pending and none has been opened yet.
	require.Equal(t, 1, sw.circuits.NumPending())
	require.Equal(t, 0, sw.circuits.NumOpen())

	// Block until the failure shows up on Bob's link.
	select {
	case failPacket := <-bobLink.packets:
		// The ChannelUpdate inside the failure must carry the SCID
		// that was used as outgoing channel.
		linkMsg := failPacket.linkFailure.msg
		tempFail, ok := linkMsg.(*lnwire.FailTemporaryChannelFailure)
		require.True(t, ok)
		require.Equal(
			t, outgoingChanID, tempFail.Update.ShortChannelID,
		)

	case <-sw.quit:
		t.Fatal("switch shutting down, failed to receive fail packet")
	}
}
// TestSwitchHandlePacketForward checks that handlePacketForward (which calls
// CheckHtlcForward) does not leak the UTXO in a failure message for alias
// channels. This test requires us to have a REAL link, which we also must
// modify in order to test it properly (e.g. making it a private channel). This
// doesn't lead to good code, but short of refactoring the link-generation code
// there is not a good alternative.
func TestSwitchHandlePacketForward(t *testing.T) {
	testCases := []struct {
		name string

		// zeroConf marks the channel as zero-conf.
		zeroConf bool

		// optionFeature marks the channel as having negotiated the
		// option-scid-alias feature-bit without being zero-conf.
		optionFeature bool

		// private marks the channel as private.
		private bool

		// useAlias makes the forward go through an alias instead of
		// the real SCID.
		useAlias bool
	}{
		{
			name:          "public zero-conf using alias",
			zeroConf:      true,
			optionFeature: false,
			private:       false,
			useAlias:      true,
		},
		{
			name:          "public zero-conf using real",
			zeroConf:      true,
			optionFeature: false,
			private:       false,
			useAlias:      false,
		},
		{
			name:          "private zero-conf using alias",
			zeroConf:      true,
			optionFeature: false,
			private:       true,
			useAlias:      true,
		},
		{
			name:          "public option-scid-alias using alias",
			zeroConf:      false,
			optionFeature: true,
			private:       false,
			useAlias:      true,
		},
		{
			name:          "public option-scid-alias using real",
			zeroConf:      false,
			optionFeature: true,
			private:       false,
			useAlias:      false,
		},
		{
			name:          "private option-scid-alias using alias",
			zeroConf:      false,
			optionFeature: true,
			private:       true,
			useAlias:      true,
		},
	}

	for _, tc := range testCases {
		// Pin the loop variable for the closure below.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			testSwitchHandlePacketForward(
				t, tc.zeroConf, tc.private, tc.useAlias,
				tc.optionFeature,
			)
		})
	}
}
// testSwitchHandlePacketForward adds a real channel link for Alice and a mock
// link for Bob to a fresh switch, forwards a packet from Bob towards Alice and
// expects Bob's link to receive a FailAmountBelowMinimum whose ChannelUpdate
// references exactly the SCID the packet was addressed to.
//
// Before Alice's link is added, its channel state is adjusted according to
// the flags: the channel is announced unless private is set, optionFeature
// adds the ScidAliasFeatureBit, and zeroConf marks the real SCID and adds the
// ZeroConfBit. useAlias decides whether the packet is addressed to one of the
// link's aliases or to the link's ShortChanID.
func testSwitchHandlePacketForward(t *testing.T, zeroConf, private,
	useAlias, optionFeature bool) {

	t.Parallel()

	// Alice's link comes from the single-link harness and is later added
	// to the switch.
	harness, err := newSingleLinkTestHarness(
		t, btcutil.SatoshiPerBitcoin, 0,
	)
	require.NoError(t, err)

	aliceLink := harness.aliceLink

	sw, err := initSwitchWithTempDB(t, testStartingHeight)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}

	err = sw.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer func() {
		_ = sw.Stop()
	}()

	// Two aliases for Alice's channel which differ only in TxPosition.
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	aliceAlias2 := aliceAlias
	aliceAlias2.TxPosition = 6

	aliceChannelLink := aliceLink.(*channelLink)
	aliceChannelState := aliceChannelLink.channel.State()

	// Override the link's GetAliases hook so that it reports both aliases
	// regardless of the base SCID it is asked about.
	aliceChannelLink.cfg.GetAliases = func(
		_ lnwire.ShortChannelID) []lnwire.ShortChannelID {

		return []lnwire.ShortChannelID{aliceAlias, aliceAlias2}
	}

	// Flip the channel to public unless the test asks for a private one.
	if !private {
		aliceChannelState.ChannelFlags = lnwire.FFAnnounceChannel
	}

	// Mark option-scid-alias feature-bit, non-zero-conf channels as such.
	if optionFeature {
		aliceChannelState.ChanType |= channeldb.ScidAliasFeatureBit
	}

	// This is the ShortChannelID field in the OpenChannel struct.
	aliceScid := aliceLink.ShortChanID()

	if zeroConf {
		// Record aliceScid as the real scid in the database and tag
		// the channel as zero-conf.
		require.NoError(t, aliceChannelState.MarkRealScid(aliceScid))

		aliceChannelState.ChanType |= channeldb.ZeroConfBit
	}

	require.NoError(t, sw.AddLink(aliceLink))

	// Bob is represented by a mockChannelLink.
	bobChanID, bobScid := genID()

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobLink := newMockChannelLink(
		sw, bobChanID, bobScid, emptyScid, bobPeer, true, false,
		false, false,
	)
	require.NoError(t, sw.AddLink(bobLink))

	// Assemble the packet coming in from Bob. The outgoing channel is
	// filled in below.
	var preimage [sha256.Size]byte
	preimage[0] = 1
	rhash := sha256.Sum256(preimage[:])

	ogPacket := &htlcPacket{
		incomingChanID: bobLink.ShortChanID(),
		incomingHTLCID: 0,
		incomingAmount: 1000,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Address the packet either to aliceScid or, when useAlias is set, to
	// a randomly picked entry of the link's alias list.
	outgoingChanID := aliceScid
	if useAlias {
		candidates := aliceLink.getAliases()
		outgoingChanID = candidates[mrand.Intn(len(candidates))]
	}
	ogPacket.outgoingChanID = outgoingChanID

	// Forward the packet to Alice; she should fail it back with an
	// AmountBelowMinimum FailureMessage.
	require.NoError(t, sw.ForwardPackets(nil, ogPacket))

	select {
	case failPacket := <-bobLink.packets:
		// The ChannelUpdate inside the failure must carry the SCID
		// that was used as outgoing channel.
		linkMsg := failPacket.linkFailure.msg
		belowMin, ok := linkMsg.(*lnwire.FailAmountBelowMinimum)
		require.True(t, ok)
		require.Equal(
			t, outgoingChanID, belowMin.Update.ShortChannelID,
		)

	case <-sw.quit:
		t.Fatal("switch shutting down, failed to receive failure")
	}
}
// TestSwitchAliasInterceptFail tests that when the InterceptableSwitch fails
// an incoming HTLC, it does not leak the on-chain UTXO for option-scid-alias
// (feature bit) or zero-conf channels. Both channel flavours are run as
// separate subtests.
func TestSwitchAliasInterceptFail(t *testing.T) {
	testCases := []struct {
		name string

		// zeroConf selects a zero-conf incoming channel; when unset
		// the incoming channel is an option-scid-alias (feature bit)
		// channel instead.
		zeroConf bool
	}{
		{
			name:     "option-scid-alias",
			zeroConf: false,
		},
		{
			name:     "zero-conf",
			zeroConf: true,
		},
	}

	for _, tc := range testCases {
		// Pin the loop variable for the closure below.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			testSwitchAliasInterceptFail(t, tc.zeroConf)
		})
	}
}
// testSwitchAliasInterceptFail sends a packet from Alice's link through an
// InterceptableSwitch, resolves the intercepted HTLC with FwdActionFail and
// CodeTemporaryChannelFailure, and then inspects the UpdateFailHTLC delivered
// back to Alice's link. After decrypting the failure reason, the contained
// FailTemporaryChannelFailure must reference one of Alice's two aliases. The
// zeroConf flag selects whether Alice's link is created as a zero-conf channel
// or as an option-scid-alias (feature bit) channel.
func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
	t.Parallel()

	chanID, aliceScid := genID()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	db, err := channeldb.Open(t.TempDir())
	require.NoError(t, err)
	t.Cleanup(func() { db.Close() })

	sw, err := initSwitchWithDB(testStartingHeight, db)
	require.NoError(t, err)

	require.NoError(t, sw.Start())
	defer func() {
		_ = sw.Stop()
	}()

	// Two aliases for Alice's channel which differ only in TxPosition.
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	aliceAlias2 := aliceAlias
	aliceAlias2.TxPosition = 6

	var aliceLink *mockChannelLink
	switch {
	case zeroConf:
		aliceLink = newMockChannelLink(
			sw, chanID, aliceAlias, aliceScid, alicePeer, true,
			true, true, false,
		)
		aliceLink.addAlias(aliceAlias2)

	default:
		aliceLink = newMockChannelLink(
			sw, chanID, aliceScid, emptyScid, alicePeer, true,
			true, false, true,
		)
		aliceLink.addAlias(aliceAlias)
		aliceLink.addAlias(aliceAlias2)
	}
	require.NoError(t, sw.AddLink(aliceLink))

	// Assemble the packet that is sent from the Alice link.
	var preimage [sha256.Size]byte
	preimage[0] = 1
	rhash := sha256.Sum256(preimage[:])

	ogPacket := &htlcPacket{
		incomingChanID:  aliceLink.ShortChanID(),
		incomingTimeout: 1000,
		incomingHTLCID:  0,
		outgoingChanID:  lnwire.ShortChannelID{},
		obfuscator:      NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Set up the interceptable switch so that the packet can be rejected.
	forwardInterceptor := &mockForwardInterceptor{
		t:               t,
		interceptedChan: make(chan InterceptedPacket),
	}

	// The epoch channel is buffered, so the starting height can be queued
	// before the interceptable switch exists.
	notifier := &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	interceptCfg := &InterceptableSwitchConfig{
		Switch:             sw,
		Notifier:           notifier,
		CltvRejectDelta:    10,
		CltvInterceptDelta: 13,
	}
	interceptSwitch, err := NewInterceptableSwitch(interceptCfg)
	require.NoError(t, err)
	require.NoError(t, interceptSwitch.Start())

	interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)

	require.NoError(t, interceptSwitch.ForwardPackets(nil, false, ogPacket))

	// Fail the intercepted HTLC with a temporary channel failure.
	intercepted := forwardInterceptor.getIntercepted()
	resolution := &FwdResolution{
		Action:      FwdActionFail,
		Key:         intercepted.IncomingCircuit,
		FailureCode: lnwire.CodeTemporaryChannelFailure,
	}
	require.NoError(t, interceptSwitch.resolve(resolution))

	select {
	case failPacket := <-aliceLink.packets:
		// Unwrap the failure and make sure its ChannelUpdate points
		// at one of the aliases.
		failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC)
		require.True(t, ok)

		fwdErr, decryptErr := newMockDeobfuscator().DecryptError(
			failHtlc.Reason,
		)
		require.NoError(t, decryptErr)

		wireMsg := fwdErr.WireMessage()
		tempFail, ok := wireMsg.(*lnwire.FailTemporaryChannelFailure)
		require.True(t, ok)

		failScid := tempFail.Update.ShortChannelID
		require.True(
			t, failScid == aliceAlias || failScid == aliceAlias2,
		)

	case <-sw.quit:
		t.Fatalf("switch shutting down, failed to receive failure")
	}

	require.NoError(t, interceptSwitch.Stop())
}