lnd/htlcswitch/switch_test.go
Joost Jager e9440a24a2
htlcswitch/test: more realistic mock encryption
This mock is used in the switch test TestUpdateFailMalformedHTLCErrorConversion.
But because the mock isn't very realistic, it doesn't detect problems
in the handling of malformed failures in the link.
2022-12-02 09:04:59 +01:00

5536 lines
149 KiB
Go

package htlcswitch
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
mrand "math/rand"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/require"
)
var (
	// zeroCircuit is the zero-valued circuit key.
	zeroCircuit = channeldb.CircuitKey{}

	// emptyScid is the zero-valued short channel ID. In this file it is
	// passed to newMockChannelLink as the confirmed-SCID argument for links
	// that have no separate confirmed SCID to report.
	emptyScid = lnwire.ShortChannelID{}
)
// genPreimage fills a 32-byte array from crypto/rand and returns it. If the
// random source cannot supply all 32 bytes, the (possibly partially filled)
// array is returned together with the read error.
func genPreimage() ([32]byte, error) {
	var p [32]byte

	// io.ReadFull only reports a nil error when every byte was read.
	_, err := io.ReadFull(rand.Reader, p[:])

	return p, err
}
// TestSwitchAddDuplicateLink verifies that the switch refuses to register a
// link that is already live, and that the very same link can be registered
// again once it has been removed from the switch.
func TestSwitchAddDuplicateLink(t *testing.T) {
	t.Parallel()

	peer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")

	if startErr := s.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer s.Stop()

	chanID, scid := genID()
	link := newMockChannelLink(
		s, chanID, scid, emptyScid, peer, false, false, false, false,
	)

	// register adds the link to the switch and aborts the test on error.
	register := func() {
		if addErr := s.AddLink(link); addErr != nil {
			t.Fatalf("unable to add alice link: %v", addErr)
		}
	}

	// The first registration is expected to succeed.
	register()

	// The link is live now, so registering it a second time must be
	// rejected.
	if s.AddLink(link) == nil {
		t.Fatalf("adding duplicate link should have failed")
	}

	// Removing the live link clears it from the switch's indexes, after
	// which registering it again must succeed.
	s.RemoveLink(chanID)
	register()
}
// TestSwitchHasActiveLink tests the behavior of HasActiveLink, and asserts that
// it only returns true if a link's short channel id has confirmed (meaning the
// channel is no longer pending) and its EligibleToForward method returns true,
// i.e. it has received FundingLocked from the remote peer.
func TestSwitchHasActiveLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, aliceScid := genID()

	// The mock link starts out ineligible to forward; its eligible flag is
	// only flipped to true further below.
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceScid, emptyScid, alicePeer, false, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// The link has been added but is not yet eligible to forward, so
	// HasActiveLink must report it as inactive.
	if s.HasActiveLink(chanID1) {
		t.Fatalf("link should not be active yet, still pending")
	}

	// Finally, simulate the link receiving funding locked by setting its
	// eligibility to true.
	aliceChannelLink.eligible = true

	// The link should now be reported as active, since EligibleToForward
	// returns true and the link is registered with the switch.
	if !s.HasActiveLink(chanID1) {
		t.Fatalf("link should be active now")
	}
}
// TestSwitchSendPending checks the inability of htlc switch to forward adds
// over pending links. The ADD must be failed back to the incoming link with
// CodeUnknownNextPeer, nothing may be delivered to the pending link, and no
// circuit may remain open.
func TestSwitchSendPending(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	// Alice's link is registered with a zero (pending) short channel ID,
	// so the ADD below, addressed to aliceChanID, is not expected to find
	// an eligible link.
	pendingChanID := lnwire.ShortChannelID{}

	aliceChannelLink := newMockChannelLink(
		s, chanID1, pendingChanID, emptyScid, alicePeer, false, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create a request which should be forwarded from Bob's channel link
	// to Alice's channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	packet := &htlcPacket{
		incomingChanID: bobChanID,
		incomingHTLCID: 0,
		outgoingChanID: aliceChanID,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Send the ADD packet, this should not be forwarded out to the link
	// since there are no eligible links.
	if err = s.ForwardPackets(nil, packet); err != nil {
		t.Fatal(err)
	}

	// The failure is delivered back to Bob's (incoming) link.
	select {
	case p := <-bobChannelLink.packets:
		if p.linkFailure != nil {
			err = p.linkFailure
		}

	case <-time.After(time.Second):
		t.Fatal("no timely reply from switch")
	}

	linkErr, ok := err.(*LinkError)
	if !ok {
		t.Fatalf("expected link error, got: %T", err)
	}

	// Print the failure code's value (not its type) on mismatch.
	if linkErr.WireMessage().Code() != lnwire.CodeUnknownNextPeer {
		t.Fatalf("expected fail unknown next peer, got: %v",
			linkErr.WireMessage().Code())
	}

	// No message should be sent, since the packet was failed.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("expected not to receive message")
	case <-time.After(time.Second):
	}

	// Since the packet should have been failed, there should be no active
	// circuits.
	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchForwardMapping verifies that the Switch consults its alias and
// base-SCID maps when forwarding an ADD from Bob to Alice. It covers every
// combination of private/public, zero-conf/option-scid-alias/plain channels,
// and forwarding by alias versus by confirmed SCID. Combinations where
// forwarding would leak privacy are expected to be failed back to Bob.
func TestSwitchForwardMapping(t *testing.T) {
	testCases := []struct {
		name string

		// alicePrivate marks Alice's channel as private.
		alicePrivate bool

		// zeroConf makes Alice's channel a zero-conf channel.
		zeroConf bool

		// optionScid makes Alice's channel a non-zero-conf channel that
		// negotiated the option-scid-alias feature bit.
		optionScid bool

		// useAlias forwards to Alice's alias rather than to her
		// confirmed SCID.
		useAlias bool

		// aliceAlias is Alice's channel alias. It may be left zero for
		// channels that are not option_scid_alias (feature bit).
		aliceAlias lnwire.ShortChannelID

		// aliceReal is Alice's confirmed SCID. It may be left zero for a
		// zero-conf channel that has not confirmed yet.
		aliceReal lnwire.ShortChannelID

		// expectErr signals that Bob's forward to Alice should fail.
		expectErr bool
	}{
		{
			name:         "private unconfirmed zero-conf",
			alicePrivate: true,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_002,
				TxIndex:     2,
				TxPosition:  2,
			},
			aliceReal: lnwire.ShortChannelID{},
			expectErr: false,
		},
		{
			name:         "private confirmed zero-conf",
			alicePrivate: true,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_003,
				TxIndex:     3,
				TxPosition:  3,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 300000,
				TxIndex:     3,
				TxPosition:  3,
			},
			expectErr: false,
		},
		{
			name:         "private confirmed zero-conf failure",
			alicePrivate: true,
			zeroConf:     true,
			useAlias:     false,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_004,
				TxIndex:     4,
				TxPosition:  4,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 300002,
				TxIndex:     4,
				TxPosition:  4,
			},
			expectErr: true,
		},
		{
			name:         "public unconfirmed zero-conf",
			alicePrivate: false,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_005,
				TxIndex:     5,
				TxPosition:  5,
			},
			aliceReal: lnwire.ShortChannelID{},
			expectErr: false,
		},
		{
			name:         "public confirmed zero-conf w/ alias",
			alicePrivate: false,
			zeroConf:     true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_006,
				TxIndex:     6,
				TxPosition:  6,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 500000,
				TxIndex:     6,
				TxPosition:  6,
			},
			expectErr: false,
		},
		{
			name:         "public confirmed zero-conf w/ real",
			alicePrivate: false,
			zeroConf:     true,
			useAlias:     false,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_007,
				TxIndex:     7,
				TxPosition:  7,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 502000,
				TxIndex:     7,
				TxPosition:  7,
			},
			expectErr: false,
		},
		{
			name:         "private non-option channel",
			alicePrivate: true,
			aliceAlias:   lnwire.ShortChannelID{},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 505000,
				TxIndex:     8,
				TxPosition:  8,
			},
		},
		{
			name:         "private option channel w/ alias",
			alicePrivate: true,
			optionScid:   true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_015,
				TxIndex:     9,
				TxPosition:  9,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 506000,
				TxIndex:     10,
				TxPosition:  10,
			},
			expectErr: false,
		},
		{
			name:         "private option channel failure",
			alicePrivate: true,
			optionScid:   true,
			useAlias:     false,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_016,
				TxIndex:     16,
				TxPosition:  16,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 507000,
				TxIndex:     17,
				TxPosition:  17,
			},
			expectErr: true,
		},
		{
			name:         "public non-option channel",
			alicePrivate: false,
			useAlias:     false,
			aliceAlias:   lnwire.ShortChannelID{},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 508000,
				TxIndex:     17,
				TxPosition:  17,
			},
			expectErr: false,
		},
		{
			name:         "public option channel w/ alias",
			alicePrivate: false,
			optionScid:   true,
			useAlias:     true,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_018,
				TxIndex:     18,
				TxPosition:  18,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 509000,
				TxIndex:     19,
				TxPosition:  19,
			},
			expectErr: false,
		},
		{
			name:         "public option channel w/ real",
			alicePrivate: false,
			optionScid:   true,
			useAlias:     false,
			aliceAlias: lnwire.ShortChannelID{
				BlockHeight: 16_000_019,
				TxIndex:     19,
				TxPosition:  19,
			},
			aliceReal: lnwire.ShortChannelID{
				BlockHeight: 510000,
				TxIndex:     20,
				TxPosition:  20,
			},
			expectErr: false,
		},
	}

	for _, tc := range testCases {
		// Pin the loop variable for the parallel subtest closure.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			testSwitchForwardMapping(
				t, tc.alicePrivate, tc.zeroConf, tc.useAlias,
				tc.optionScid, tc.aliceAlias, tc.aliceReal,
				tc.expectErr,
			)
		})
	}
}
// testSwitchForwardMapping runs one TestSwitchForwardMapping case.
//
// It registers a link for Alice, which is zero-conf, option-scid-alias or
// plain depending on the flags, and a plain link for Bob. It then forwards an
// ADD from Bob addressed to either Alice's alias (useAlias) or her confirmed
// SCID. When expectErr is set, the switch must bounce the ADD back to Bob and
// never deliver it to Alice. Otherwise Alice must receive it and Bob must not
// get a failure.
func testSwitchForwardMapping(t *testing.T, alicePrivate, aliceZeroConf,
	useAlias, optionScid bool, aliceAlias, aliceReal lnwire.ShortChannelID,
	expectErr bool) {

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)
	require.NoError(t, s.Start())
	defer func() { _ = s.Stop() }()

	chanID1, chanID2, _, _ := genIDs()

	// A zero-conf link uses the alias as its primary SCID and tracks the
	// confirmed SCID separately. Any other link uses the confirmed SCID as
	// its primary SCID. An option-scid-alias link additionally carries the
	// alias.
	var aliceChannelLink *mockChannelLink
	switch {
	case aliceZeroConf:
		aliceChannelLink = newMockChannelLink(
			s, chanID1, aliceAlias, aliceReal, alicePeer, true,
			alicePrivate, true, false,
		)

	default:
		aliceChannelLink = newMockChannelLink(
			s, chanID1, aliceReal, emptyScid, alicePeer, true,
			alicePrivate, false, optionScid,
		)
		if optionScid {
			aliceChannelLink.addAlias(aliceAlias)
		}
	}
	require.NoError(t, s.AddLink(aliceChannelLink))

	// Bob has a plain, non-option_scid_alias channel, so no mapping is
	// involved on his side.
	bobScid := lnwire.ShortChannelID{
		BlockHeight: 501000,
		TxIndex:     200,
		TxPosition:  2,
	}
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobScid, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(t, s.AddLink(bobChannelLink))

	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])

	target := aliceReal
	if useAlias {
		target = aliceAlias
	}

	add := &htlcPacket{
		incomingChanID: bobScid,
		incomingHTLCID: 0,
		outgoingChanID: target,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}
	require.NoError(t, s.ForwardPackets(nil, add))

	// expectPacket fails the test with msg unless link receives a packet
	// within five seconds.
	expectPacket := func(link *mockChannelLink, msg string) {
		select {
		case <-link.packets:
		case <-time.After(time.Second * 5):
			t.Fatal(msg)
		}
	}

	// expectNoPacket fails the test with msg if link receives a packet
	// within five seconds.
	expectNoPacket := func(link *mockChannelLink, msg string) {
		select {
		case <-link.packets:
			t.Fatal(msg)
		case <-time.After(time.Second * 5):
		}
	}

	// option_scid_alias forwards may be refused when forwarding would be
	// a privacy leak. In that case the failure goes back to Bob.
	if expectErr {
		expectPacket(bobChannelLink, "expected a forwarding error")
		expectNoPacket(aliceChannelLink, "did not expect a packet")

		return
	}

	expectNoPacket(bobChannelLink, "did not expect a forwarding error")
	expectPacket(aliceChannelLink, "expected alice to receive packet")
}
// TestSwitchSendHTLCMapping verifies that SendHTLC routes a locally-initiated
// ADD to zero-conf and option-scid-alias (feature-bit) channels whether the
// alias or the confirmed SCID is used as the destination. It also checks that
// plain channels are unaffected by the alias mapping.
func TestSwitchSendHTLCMapping(t *testing.T) {
	testCases := []struct {
		name string

		// zeroConf makes the channel a zero-conf channel.
		zeroConf bool

		// optionFeature makes the channel a non-zero-conf channel that
		// negotiated the option-scid-alias feature bit.
		optionFeature bool

		// useAlias addresses the HTLC to the alias instead of the
		// confirmed SCID.
		useAlias bool

		// alias is the channel alias, if the channel has a mapping.
		alias lnwire.ShortChannelID

		// real is the confirmed SCID, if the channel has confirmed.
		real lnwire.ShortChannelID
	}{
		{
			name:          "non-zero-conf real scid w/ option",
			zeroConf:      false,
			optionFeature: true,
			useAlias:      false,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10010,
				TxIndex:     10,
				TxPosition:  10,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 500000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name:     "non-zero-conf real scid no option",
			zeroConf: false,
			useAlias: false,
			alias:    lnwire.ShortChannelID{},
			real: lnwire.ShortChannelID{
				BlockHeight: 400000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name:     "zero-conf alias scid w/ conf",
			zeroConf: true,
			useAlias: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10020,
				TxIndex:     20,
				TxPosition:  20,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 450000,
				TxIndex:     50,
				TxPosition:  50,
			},
		},
		{
			name:     "zero-conf alias scid no conf",
			zeroConf: true,
			useAlias: true,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10015,
				TxIndex:     25,
				TxPosition:  35,
			},
			real: lnwire.ShortChannelID{},
		},
		{
			name:     "zero-conf real scid",
			zeroConf: true,
			useAlias: false,
			alias: lnwire.ShortChannelID{
				BlockHeight: 10035,
				TxIndex:     35,
				TxPosition:  35,
			},
			real: lnwire.ShortChannelID{
				BlockHeight: 470000,
				TxIndex:     35,
				TxPosition:  45,
			},
		},
	}

	for _, tc := range testCases {
		// Pin the loop variable for the parallel subtest closure.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			testSwitchSendHtlcMapping(
				t, tc.zeroConf, tc.useAlias, tc.alias, tc.real,
				tc.optionFeature,
			)
		})
	}
}
// testSwitchSendHtlcMapping registers a single link for the given channel
// configuration and asserts that SendHTLC accepts an ADD addressed to either
// the channel's alias or its confirmed SCID.
//
// When zeroConf is set, the link's primary SCID is alias and its confirmed
// SCID is realScid. Otherwise the link's primary SCID is realScid.
// optionFeature controls whether that link negotiated option-scid-alias; if
// it did, alias is added to the link. useAlias selects which SCID the HTLC is
// addressed to.
func testSwitchSendHtlcMapping(t *testing.T, zeroConf, useAlias bool, alias,
	realScid lnwire.ShortChannelID, optionFeature bool) {

	peer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)
	err = s.Start()
	require.NoError(t, err)
	defer func() { _ = s.Stop() }()

	// Create the lnwire.ChannelID that we'll use.
	chanID, _ := genID()

	var link *mockChannelLink
	if zeroConf {
		link = newMockChannelLink(
			s, chanID, alias, realScid, peer, true, false, true,
			false,
		)
	} else {
		// Pass optionFeature through instead of hard-coding true, so that
		// the "no option" case really exercises a channel without
		// option-scid-alias (mirroring testSwitchForwardMapping).
		link = newMockChannelLink(
			s, chanID, realScid, emptyScid, peer, true, false,
			false, optionFeature,
		)
		if optionFeature {
			link.addAlias(alias)
		}
	}

	err = s.AddLink(link)
	require.NoError(t, err)

	// Generate preimage.
	preimage, err := genPreimage()
	require.NoError(t, err)
	rhash := sha256.Sum256(preimage[:])

	// Determine the outgoing SCID to use.
	outgoingSCID := realScid
	if useAlias {
		outgoingSCID = alias
	}

	// Send the HTLC and assert that we don't get an error.
	htlc := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}
	err = s.SendHTLC(outgoingSCID, 0, htlc)
	require.NoError(t, err)
}
// TestSwitchUpdateScid verifies that zero-conf and non-zero-conf
// option-scid-alias (feature bit) channels will have the expected entries in
// the aliasToReal and baseIndex maps, both before and after a zero-conf
// channel learns its confirmed SCID via UpdateShortChanID.
func TestSwitchUpdateScid(t *testing.T) {
	t.Parallel()

	peer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)
	err = s.Start()
	require.NoError(t, err)
	defer func() { _ = s.Stop() }()

	// withIndexRLock runs check while holding the switch's index read
	// lock. The unlock is deferred so that a failing require assertion,
	// which ends the test goroutine through t.FailNow, still releases the
	// lock. Otherwise the deferred s.Stop() above may block on the leaked
	// lock and hang the test instead of reporting the failure.
	withIndexRLock := func(check func()) {
		s.indexMtx.RLock()
		defer s.indexMtx.RUnlock()

		check()
	}

	// Create the IDs that we'll use.
	chanID, chanID2, _, _ := genIDs()

	alias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     0,
		TxPosition:  0,
	}
	alias2 := alias
	alias2.TxPosition = 1

	realScid := lnwire.ShortChannelID{
		BlockHeight: 500000,
		TxIndex:     0,
		TxPosition:  0,
	}

	link := newMockChannelLink(
		s, chanID, alias, emptyScid, peer, true, false, true, false,
	)
	link.addAlias(alias2)

	err = s.AddLink(link)
	require.NoError(t, err)

	withIndexRLock(func() {
		// Assert that the zero-conf link does not have entries in the
		// aliasToReal map.
		_, ok := s.aliasToReal[alias]
		require.False(t, ok)
		_, ok = s.aliasToReal[alias2]
		require.False(t, ok)

		// Assert that both aliases point to the "base" SCID, which is
		// actually just the first alias.
		baseScid, ok := s.baseIndex[alias]
		require.True(t, ok)
		require.Equal(t, alias, baseScid)

		baseScid, ok = s.baseIndex[alias2]
		require.True(t, ok)
		require.Equal(t, alias, baseScid)
	})

	// We'll set the mock link's confirmed SCID so that UpdateShortChanID
	// populates aliasToReal and adds an entry to baseIndex.
	link.realScid = realScid
	link.confirmedZC = true

	err = s.UpdateShortChanID(chanID)
	require.NoError(t, err)

	withIndexRLock(func() {
		// Assert that aliasToReal is populated and there is an entry in
		// baseIndex for realScid.
		realMapping, ok := s.aliasToReal[alias]
		require.True(t, ok)
		require.Equal(t, realScid, realMapping)

		realMapping, ok = s.aliasToReal[alias2]
		require.True(t, ok)
		require.Equal(t, realScid, realMapping)

		baseScid, ok := s.baseIndex[realScid]
		require.True(t, ok)
		require.Equal(t, alias, baseScid)
	})

	// Now we'll perform the same checks with a non-zero-conf
	// option-scid-alias channel (feature-bit).
	optionReal := lnwire.ShortChannelID{
		BlockHeight: 600000,
		TxIndex:     0,
		TxPosition:  0,
	}
	optionAlias := lnwire.ShortChannelID{
		BlockHeight: 12000,
		TxIndex:     0,
		TxPosition:  0,
	}
	optionAlias2 := optionAlias
	optionAlias2.TxPosition = 1

	link2 := newMockChannelLink(
		s, chanID2, optionReal, emptyScid, peer, true, false, false,
		true,
	)
	link2.addAlias(optionAlias)
	link2.addAlias(optionAlias2)

	err = s.AddLink(link2)
	require.NoError(t, err)

	withIndexRLock(func() {
		// Assert that the option-scid-alias link does have entries in
		// the aliasToReal and baseIndex maps.
		realMapping, ok := s.aliasToReal[optionAlias]
		require.True(t, ok)
		require.Equal(t, optionReal, realMapping)

		realMapping, ok = s.aliasToReal[optionAlias2]
		require.True(t, ok)
		require.Equal(t, optionReal, realMapping)

		baseScid, ok := s.baseIndex[optionReal]
		require.True(t, ok)
		require.Equal(t, optionReal, baseScid)

		baseScid, ok = s.baseIndex[optionAlias]
		require.True(t, ok)
		require.Equal(t, optionReal, baseScid)

		baseScid, ok = s.baseIndex[optionAlias2]
		require.True(t, ok)
		require.Equal(t, optionReal, baseScid)
	})
}
// TestSwitchForward exercises a full forwarding round trip through the
// switch. An ADD offered by Alice's link is delivered to Bob's link, which
// opens a circuit. The matching settle from Bob's link is then delivered back
// to Alice's link, which closes the circuit again.
func TestSwitchForward(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if startErr := s.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)

	// Register both links with the switch, Alice's first.
	for _, reg := range []struct {
		owner string
		link  *mockChannelLink
	}{
		{owner: "alice", link: aliceLink},
		{owner: "bob", link: bobLink},
	} {
		if addErr := s.AddLink(reg.link); addErr != nil {
			t.Fatalf("unable to add %s link: %v", reg.owner, addErr)
		}
	}

	// requireOpenCircuits fails the test unless exactly want circuits are
	// currently open in the switch's circuit map.
	requireOpenCircuits := func(want int) {
		if s.circuits.NumOpen() != want {
			t.Fatal("wrong amount of circuits")
		}
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := sha256.Sum256(preimage[:])

	// Alice's link offers an ADD that must go out over Bob's link.
	addPkt := &htlcPacket{
		incomingChanID: aliceLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}
	if fwdErr := s.ForwardPackets(nil, addPkt); fwdErr != nil {
		t.Fatal(fwdErr)
	}

	// Once Bob's link has the ADD, completing the circuit opens it.
	select {
	case <-bobLink.packets:
		if cErr := bobLink.completeCircuit(addPkt); cErr != nil {
			t.Fatalf("unable to complete payment circuit: %v", cErr)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	requireOpenCircuits(1)
	if !s.IsForwardedHTLC(bobLink.ShortChanID(), 0) {
		t.Fatal("htlc should be identified as forwarded")
	}

	// Pretend Bob's link handled the ADD and answered with a settle. The
	// switch must route the settle back to Alice's link.
	settlePkt := &htlcPacket{
		outgoingChanID: bobLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}
	if fwdErr := s.ForwardPackets(nil, settlePkt); fwdErr != nil {
		t.Fatal(fwdErr)
	}

	select {
	case received := <-aliceLink.packets:
		if dErr := aliceLink.deleteCircuit(received); dErr != nil {
			t.Fatalf("unable to remove circuit: %v", dErr)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	requireOpenCircuits(0)
}
// TestSwitchForwardFailAfterFullAdd checks the following sequence:
//   - An ADD is forwarded and its circuit fully added (opened).
//   - The switch is restarted on the same database; the circuit is still
//     open.
//   - A fail from the outgoing link is routed back to the incoming link and
//     tears the circuit down.
//   - A duplicate of that fail sent through the restarted switch is not
//     delivered again.
func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	// assertCircuits fails the test unless sw's circuit map holds exactly
	// the given number of half-added (pending) and open circuits.
	assertCircuits := func(sw *Switch, pending, open int) {
		t.Helper()

		if sw.circuits.NumPending() != pending {
			t.Fatal("wrong amount of half circuits")
		}
		if sw.circuits.NumOpen() != open {
			t.Fatal("wrong amount of circuits")
		}
	}

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	assertCircuits(s, 0, 0)

	// Handle the request and check that it creates a half circuit.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	assertCircuits(s, 1, 0)

	// Pull the ADD from Bob's link and complete the circuit, fully adding
	// it before the restart.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	assertCircuits(s, 1, 1)

	// Now restart the switch, leaving the circuit for this htlc persisted
	// in the fully-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	assertCircuits(s2, 1, 1)

	// Craft a failure message from the remote peer.
	fail := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}

	// Send the fail packet from the remote peer through the switch.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}

	// Pull packet from alice's link, as it should have gone through
	// successfully.
	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Circuit map should be empty now.
	assertCircuits(s2, 0, 0)

	// Send the same fail again, this time expecting it to be dropped. It
	// must go through s2: the current links are registered with s2, and s
	// has already been stopped, so forwarding via s could never reach
	// them and the assertion below would pass vacuously.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}

	select {
	case <-aliceChannelLink.packets:
		t.Fatalf("expected duplicate fail to not arrive at the destination")
	case <-time.After(time.Second):
	}
}
// TestSwitchForwardSettleAfterFullAdd checks the following sequence:
//   - An ADD is forwarded and its circuit fully added (opened).
//   - The switch is restarted on the same database; the circuit is still
//     open.
//   - A settle from the outgoing link is routed back to the incoming link
//     and tears the circuit down.
//   - A duplicate of that settle is not delivered again.
func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	// assertCircuits fails the test unless sw's circuit map holds exactly
	// the given number of half-added (pending) and open circuits.
	assertCircuits := func(sw *Switch, pending, open int) {
		t.Helper()

		if sw.circuits.NumPending() != pending {
			t.Fatal("wrong amount of half circuits")
		}
		if sw.circuits.NumOpen() != open {
			t.Fatal("wrong amount of circuits")
		}
	}

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	assertCircuits(s, 0, 0)

	// Handle the request and check that it creates a half circuit.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	assertCircuits(s, 1, 0)

	// Pull the ADD from Bob's link and complete the circuit, fully adding
	// it before the restart.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	assertCircuits(s, 1, 1)

	// Now restart the switch, leaving the circuit for this htlc persisted
	// in the fully-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	assertCircuits(s2, 1, 1)

	// Craft a settle message from the remote peer.
	settle := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}

	// Send the settle packet from the remote peer through the switch.
	if err := s2.ForwardPackets(nil, settle); err != nil {
		t.Fatal(err)
	}

	// Pull packet from alice's link, as it should have gone through
	// successfully.
	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete circuit with in key=%s: %v",
				packet.inKey(), err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Circuit map should be empty now.
	assertCircuits(s2, 0, 0)

	// Send the settle packet again; it must not arrive at the destination.
	if err := s2.ForwardPackets(nil, settle); err != nil {
		t.Fatal(err)
	}

	// A settle is routed to the incoming link, so a duplicate would show
	// up on Alice's link (as the first settle did above), not on Bob's.
	select {
	case <-aliceChannelLink.packets:
		t.Fatalf("expected duplicate settle to not arrive at the destination")
	case <-time.After(time.Second):
	}
}
// TestSwitchForwardDropAfterFullAdd checks that an add which was fully
// forwarded before a restart, i.e. whose circuit was both added and completed
// with an outgoing htlc id, is silently dropped when it is forwarded again
// after the switch restarts: it must neither reach the outgoing link nor be
// failed back to the incoming link.
func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Handle the request and check that bob channel link received it.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Pull the packet from bob's link and perform the full add by
	// completing the circuit.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Now we will restart the switch, leaving the forwarding decision for
	// this htlc in the fully-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Both the half circuit and the full circuit must survive the
	// restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Resend the original add. The packet should be dropped silently
	// since the switch will detect that it has already been fully added.
	if err := s2.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	// Since the forward was already completed, nothing should be sent to
	// either link: no fail back to the source and no duplicate add to the
	// destination.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("request should not have returned to source")
	case <-bobChannelLink.packets:
		t.Fatal("request should not have forwarded to destination")
	case <-time.After(time.Second):
	}
}
// TestSwitchForwardFailAfterHalfAdd checks that an add which was only half
// added before a restart (its circuit was persisted but never completed with
// an outgoing htlc id) is failed back to the incoming link with
// OutgoingFailureIncompleteForward when it is forwarded again after the
// switch restarts.
func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Handle the request and check that bob channel link received it.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Pull packet from bob's link, but do not perform a full add.
	select {
	case <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Now we will restart the switch, leaving the forwarding decision for
	// this htlc in the half-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Only the half circuit should have survived the restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Resend the half-added htlc. It should be failed back to alice since
	// the switch will detect that it has been half added previously.
	err = s2.ForwardPackets(nil, ogPacket)
	if err != nil {
		t.Fatal(err)
	}

	// After detecting an incomplete forward, the fail packet should have
	// been returned to the sender.
	select {
	case pkt := <-aliceChannelLink.packets:
		linkErr := pkt.linkFailure

		// Guard against a nil failure so the test reports a clear
		// error rather than panicking on the field access below.
		if linkErr == nil {
			t.Fatal("expected incomplete forward, got no link failure")
		}
		if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
			t.Fatalf("expected incomplete forward, got: %v",
				linkErr.FailureDetail)
		}

	case <-time.After(time.Second):
		t.Fatal("fail was not propagated back to source")
	}
}
// TestSwitchForwardCircuitPersistence checks the ability of htlc switch to
// maintain the proper entries in the circuit map in the face of restarts.
// The test restarts the switch three times: after the half add, after the
// full add and settle, and finally to verify the circuit map is empty.
func TestSwitchForwardCircuitPersistence(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Handle the request and check that bob channel link received it.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Retrieve packet from outgoing link and cache until after restart.
	var packet *htlcPacket
	select {
	case packet = <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Now that the switch has restarted, complete the payment circuit.
	if err := bobChannelLink.completeCircuit(packet); err != nil {
		t.Fatalf("unable to complete payment circuit: %v", err)
	}

	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Create settle request pretending that bob link handled the add htlc
	// request and sent the htlc settle request back. This request should
	// be forwarded back to Alice link.
	ogPacket = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}

	// Handle the request and check that the payment circuit works
	// properly.
	if err := s2.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	select {
	case packet = <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete circuit with in key=%s: %v",
				packet.inKey(), err)
		}
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	// Settling the htlc must tear down both the half and full circuit.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits, want 0, got %d",
			s2.circuits.NumPending())
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}

	if err := s2.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb2.Close(); err != nil {
		t.Fatal(err)
	}

	cdb3, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to reopen channeldb")
	t.Cleanup(func() { cdb3.Close() })

	s3, err := initSwitchWithDB(testStartingHeight, cdb3)
	require.NoError(t, err, "unable reinit switch")
	if err := s3.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s3.Stop()

	aliceChannelLink = newMockChannelLink(
		s3, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s3, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s3.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s3.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The removal of the circuits must have been persisted as well.
	if s3.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s3.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}
}
// multiHopFwdTest describes one case for testSkipIneligibleLinksMultiHopForward,
// which forwards an htlc from alice towards the first of two links to bob.
type multiHopFwdTest struct {
	// name is used as the subtest name.
	name string

	// eligible1 and eligible2 set whether bob's first and second link
	// report themselves as eligible to forward.
	eligible1 bool
	eligible2 bool

	// failure1 and failure2 are the results bob's first and second mock
	// link return from their htlc forward check.
	failure1 *LinkError
	failure2 *LinkError

	// expectedReply is the failure code expected to be recorded by the
	// incoming packet's obfuscator, or lnwire.CodeNone when the forward
	// should succeed.
	expectedReply lnwire.FailCode
}
// TestCircularForwards verifies that an htlc which enters and leaves through
// the very same channel is forwarded when the switch is configured to allow
// circular routes, and is failed back with OutgoingFailureCircularRoute
// otherwise. No circuit may remain open in either case.
func TestCircularForwards(t *testing.T) {
	chanID1, aliceChanID := genID()
	preimage := [sha256.Size]byte{1}
	hash := sha256.Sum256(preimage[:])

	testCases := []struct {
		name                 string
		allowCircularPayment bool
		expectedErr          error
	}{
		{
			name:                 "circular payment allowed",
			allowCircularPayment: true,
		},
		{
			name:                 "circular payment disallowed",
			allowCircularPayment: false,
			expectedErr: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureCircularRoute,
			),
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			alicePeer, err := newMockServer(
				t, "alice", testStartingHeight, nil,
				testDefaultDelta,
			)
			if err != nil {
				t.Fatalf("unable to create alice server: %v",
					err)
			}

			s, err := initSwitchWithTempDB(t, testStartingHeight)
			if err != nil {
				t.Fatalf("unable to init switch: %v", err)
			}
			if err := s.Start(); err != nil {
				t.Fatalf("unable to start switch: %v", err)
			}
			defer func() { _ = s.Stop() }()

			// Apply the circular route policy under test.
			s.cfg.AllowCircularRoute = tc.allowCircularPayment

			aliceLink := newMockChannelLink(
				s, chanID1, aliceChanID, emptyScid, alicePeer,
				true, false, false, false,
			)
			if err := s.AddLink(aliceLink); err != nil {
				t.Fatalf("unable to add alice link: %v", err)
			}

			// The packet arrives on alice's link and is also
			// destined to leave on it, forming a loop.
			aliceScid := aliceLink.ShortChanID()
			loopPkt := &htlcPacket{
				incomingChanID: aliceScid,
				outgoingChanID: aliceScid,
				htlc: &lnwire.UpdateAddHTLC{
					PaymentHash: hash,
					Amount:      1,
				},
				obfuscator: NewMockObfuscator(),
			}
			if err = s.ForwardPackets(nil, loopPkt); err != nil {
				t.Fatal(err)
			}

			// Only record a failure when one was actually set, so
			// that the success case compares against a nil error.
			var gotErr error
			select {
			case reply := <-aliceLink.packets:
				if reply.linkFailure != nil {
					gotErr = reply.linkFailure
				}
			case <-time.After(time.Second):
				t.Fatal("no timely reply from switch")
			}

			if !reflect.DeepEqual(gotErr, tc.expectedErr) {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedErr, gotErr)
			}

			// A circular forward must never leave a circuit open.
			if s.circuits.NumOpen() > 0 {
				t.Fatal("do not expect any open circuits")
			}
		})
	}
}
// TestCheckCircularForward exercises checkCircularForward directly. It covers
// forwards between distinct and identical scids, with and without an alias
// mapping that resolves both scids to the same base scid, under both settings
// of the circular route policy.
func TestCheckCircularForward(t *testing.T) {
	// scid is shorthand for building the short channel ids in the table.
	scid := lnwire.NewShortChanIDFromInt

	// circularRouteErr builds the error expected whenever a circular
	// forward is detected but not allowed.
	circularRouteErr := func() *LinkError {
		return NewDetailedLinkError(
			lnwire.NewTemporaryChannelFailure(nil),
			OutgoingFailureCircularRoute,
		)
	}

	testCases := []struct {
		name string

		// aliasMapping determines whether the test should add an alias
		// mapping to the Switch alias maps before calling
		// checkCircularForward.
		aliasMapping bool

		// allowCircular determines whether we should allow circular
		// forwards.
		allowCircular bool

		// incomingLink is the link that the htlc arrived on.
		incomingLink lnwire.ShortChannelID

		// outgoingLink is the link that the htlc forward is destined
		// to leave on.
		outgoingLink lnwire.ShortChannelID

		// expectedErr is the error we expect to be returned; nil when
		// the forward is permitted.
		expectedErr *LinkError
	}{
		{
			name:          "not circular, allowed in config",
			allowCircular: true,
			incomingLink:  scid(123),
			outgoingLink:  scid(321),
		},
		{
			name:         "not circular, not allowed in config",
			incomingLink: scid(123),
			outgoingLink: scid(321),
		},
		{
			name:          "circular, allowed in config",
			allowCircular: true,
			incomingLink:  scid(123),
			outgoingLink:  scid(123),
		},
		{
			name:         "circular, not allowed in config",
			incomingLink: scid(123),
			outgoingLink: scid(123),
			expectedErr:  circularRouteErr(),
		},
		{
			name:         "circular with map, not allowed",
			aliasMapping: true,
			incomingLink: scid(1 << 60),
			outgoingLink: scid(1 << 55),
			expectedErr:  circularRouteErr(),
		},
		{
			name:         "circular with map, not allowed 2",
			aliasMapping: true,
			incomingLink: scid(1 << 55),
			outgoingLink: scid(1 << 60),
			expectedErr:  circularRouteErr(),
		},
		{
			name:          "circular with map, allowed",
			aliasMapping:  true,
			allowCircular: true,
			incomingLink:  scid(1 << 60),
			outgoingLink:  scid(1 << 55),
		},
		{
			name:          "circular with map, allowed 2",
			aliasMapping:  true,
			allowCircular: true,
			incomingLink:  scid(1 << 55),
			outgoingLink:  scid(1 << 61),
		},
		{
			name:         "not circular, both confirmed SCID",
			incomingLink: scid(1 << 60),
			outgoingLink: scid(1 << 61),
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			s, err := initSwitchWithTempDB(t, testStartingHeight)
			require.NoError(t, err)
			require.NoError(t, s.Start())
			defer func() { _ = s.Stop() }()

			if tc.aliasMapping {
				// Resolve both the incoming and the outgoing
				// scid to the same base scid (the outgoing
				// one).
				s.indexMtx.Lock()
				s.baseIndex[tc.incomingLink] = tc.outgoingLink
				s.baseIndex[tc.outgoingLink] = tc.outgoingLink
				s.indexMtx.Unlock()
			}

			// A zero hash is passed; the original test notes it is
			// only used for logging.
			err = s.checkCircularForward(
				tc.incomingLink, tc.outgoingLink,
				tc.allowCircular, lntypes.Hash{},
			)
			if !reflect.DeepEqual(err, tc.expectedErr) {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedErr, err)
			}
		})
	}
}
// TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
// along, then we won't attempt to forward it down a link that isn't yet able
// to forward any HTLC's, and that the appropriate failure is sent back when
// no link can carry it.
func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
	testCases := []multiHopFwdTest{
		{
			// None of the channels is eligible.
			name:          "not eligible",
			expectedReply: lnwire.CodeUnknownNextPeer,
		},
		{
			// Channel one has a policy failure and the other
			// channel isn't available.
			name:      "policy fail",
			eligible1: true,
			failure1: NewLinkError(
				lnwire.NewFinalIncorrectCltvExpiry(0),
			),
			expectedReply: lnwire.CodeFinalIncorrectCltvExpiry,
		},
		{
			// The requested channel is not eligible, but the
			// packet is forwarded through the other channel.
			name:          "non-strict success",
			eligible2:     true,
			expectedReply: lnwire.CodeNone,
		},
		{
			// The requested channel has insufficient bandwidth and
			// the other channel's policy isn't satisfied.
			name:      "non-strict policy fail",
			eligible1: true,
			failure1: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureInsufficientBalance,
			),
			eligible2: true,
			failure2: NewLinkError(
				lnwire.NewFinalIncorrectCltvExpiry(0),
			),
			expectedReply: lnwire.CodeTemporaryChannelFailure,
		},
	}

	for i := range testCases {
		tc := testCases[i]

		t.Run(tc.name, func(t *testing.T) {
			testSkipIneligibleLinksMultiHopForward(t, &tc)
		})
	}
}
// testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
// along, then we won't attempt to forward it down a link that isn't yet able
// to forward any HTLC's.
//
// An htlc is forwarded from alice towards the first of two links to bob. The
// eligibility and forward check result of each bob link come from testCase.
// The test asserts that the forward either succeeds (expectedReply is
// lnwire.CodeNone) or is failed with testCase.expectedReply, and that no
// circuit is left open.
func testSkipIneligibleLinksMultiHopForward(t *testing.T,
	testCase *multiHopFwdTest) {

	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, aliceChanID := genID()
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)

	// We'll create two links for Bob whose eligibility and forward check
	// results are controlled by the test case.
	chanID2, bobChanID2 := genID()
	bobChannelLink1 := newMockChannelLink(
		s, chanID2, bobChanID2, emptyScid, bobPeer, testCase.eligible1,
		false, false, false,
	)
	bobChannelLink1.checkHtlcForwardResult = testCase.failure1

	chanID3, bobChanID3 := genID()
	bobChannelLink2 := newMockChannelLink(
		s, chanID3, bobChanID3, emptyScid, bobPeer, testCase.eligible2,
		false, false, false,
	)
	bobChannelLink2.checkHtlcForwardResult = testCase.failure2

	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink1); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}
	if err := s.AddLink(bobChannelLink2); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create a new packet that's destined for Bob as an incoming HTLC from
	// Alice.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	obfuscator := NewMockObfuscator()
	packet := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink1.ShortChanID(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
		obfuscator: obfuscator,
	}

	// Forward the packet. Depending on the test case it is either sent
	// over one of bob's links or failed back to alice.
	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatal(err)
	}

	// We select from all links and extract the error if exists.
	// The packet must be selected but we don't always expect a link error.
	var linkError *LinkError
	select {
	case p := <-aliceChannelLink.packets:
		linkError = p.linkFailure
	case p := <-bobChannelLink1.packets:
		linkError = p.linkFailure
	case p := <-bobChannelLink2.packets:
		linkError = p.linkFailure
	case <-time.After(time.Second):
		t.Fatal("no timely reply from switch")
	}

	failure := obfuscator.(*mockObfuscator).failure
	if testCase.expectedReply == lnwire.CodeNone {
		if linkError != nil {
			t.Fatalf("forwarding should have succeeded")
		}
		if failure != nil {
			t.Fatalf("unexpected failure %T", failure)
		}
	} else {
		if linkError == nil {
			t.Fatalf("forwarding should have failed")
		}

		// Guard against a missing failure so the test reports a clear
		// error instead of panicking on a nil interface.
		if failure == nil {
			t.Fatalf("expected failure code %v, got no failure",
				testCase.expectedReply)
		}
		if failure.Code() != testCase.expectedReply {
			t.Fatalf("expected failure code %v, got %v (%T)",
				testCase.expectedReply, failure.Code(), failure)
		}
	}

	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSkipIneligibleLinksLocalForward ensures that the switch refuses to send
// a locally initiated HTLC down a link that is not yet eligible to forward.
func TestSkipIneligibleLinksLocalForward(t *testing.T) {
	t.Parallel()

	// The link is marked ineligible and no policy failure message is
	// supplied.
	const eligible = false
	testSkipLinkLocalForward(t, eligible, nil)
}
// TestSkipPolicyUnsatisfiedLinkLocalForward ensures that the switch will not
// attempt to send locally initiated HTLCs that would violate the channel policy
// down a link.
func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) {
t.Parallel()
testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil))
}
// testSkipLinkLocalForward creates a switch with a single link for alice.
// The link's eligibility is set by eligible. When policyResult is non-nil,
// the link's htlc transit check fails with it.
// It then asserts that sending a locally initiated htlc over that link fails
// and that no circuit is opened.
func testSkipLinkLocalForward(t *testing.T, eligible bool,
	policyResult lnwire.FailureMessage) {

	// We'll create a single link for this test, configured according to
	// the eligibility and policy result supplied by the caller.
	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, eligible, false,
		false, false,
	)

	// Only install a transit check failure when the caller supplied one.
	// Otherwise the ineligible-link case would also carry a (nil-message)
	// LinkError as its transit result, and could not show that
	// ineligibility alone rejects the htlc.
	if policyResult != nil {
		aliceChannelLink.checkHtlcTransitResult = NewLinkError(
			policyResult,
		)
	}
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	addMsg := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}

	// We'll attempt to send out a new HTLC that has Alice as the first
	// outgoing link. This should fail as Alice is either not eligible to
	// forward or the HTLC violates her policy.
	err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg)
	if err == nil {
		t.Fatalf("local forward should have failed")
	}

	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchCancel checks that if htlc was rejected we remove unused
// circuits: after an add is forwarded and then failed back, both the half
// and the full circuit must be gone.
func TestSwitchCancel(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from alice channel link
	// to bob channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	request := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Handle the request and check that bob channel link received it.
	if err := s.ForwardPackets(nil, request); err != nil {
		t.Fatal(err)
	}

	select {
	case packet := <-bobChannelLink.packets:
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Create a fail request pretending that bob channel link rejected
	// the add htlc request and sent an htlc fail back. This request
	// should be forwarded back to alice channel link.
	request = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}

	// Handle the request and check that the payment circuit works
	// properly.
	if err := s.ForwardPackets(nil, request); err != nil {
		t.Fatal(err)
	}

	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	// The failed htlc must not leave any circuit behind.
	if s.circuits.NumPending() != 0 {
		t.Fatal("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchAddSamePayment tests that the switch can forward two htlcs which
// share the same payment hash from alice's link to bob's link, opening a
// separate circuit for each, and that failing each of them back to alice
// tears its circuit down independently.
func TestSwitchAddSamePayment(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Both htlcs below carry this single payment hash.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])

	// sendAdd forwards an add with the given incoming htlc ID from alice's
	// link to bob's link and completes the circuit once bob receives it.
	sendAdd := func(incomingHTLCID uint64) {
		add := &htlcPacket{
			incomingChanID: aliceChannelLink.ShortChanID(),
			incomingHTLCID: incomingHTLCID,
			outgoingChanID: bobChannelLink.ShortChanID(),
			obfuscator:     NewMockObfuscator(),
			htlc: &lnwire.UpdateAddHTLC{
				PaymentHash: rhash,
				Amount:      1,
			},
		}
		if err := s.ForwardPackets(nil, add); err != nil {
			t.Fatal(err)
		}

		select {
		case pkt := <-bobChannelLink.packets:
			err := bobChannelLink.completeCircuit(pkt)
			if err != nil {
				t.Fatalf("unable to complete payment "+
					"circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to destination")
		}
	}

	// sendFail pretends bob's link failed the htlc with the given outgoing
	// htlc ID, and removes the circuit once the fail reaches alice.
	sendFail := func(outgoingHTLCID uint64) {
		fail := &htlcPacket{
			outgoingChanID: bobChannelLink.ShortChanID(),
			outgoingHTLCID: outgoingHTLCID,
			amount:         1,
			htlc:           &lnwire.UpdateFailHTLC{},
		}
		if err := s.ForwardPackets(nil, fail); err != nil {
			t.Fatal(err)
		}

		select {
		case pkt := <-aliceChannelLink.packets:
			err := aliceChannelLink.completeCircuit(pkt)
			if err != nil {
				t.Fatalf("unable to remove circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to channelPoint")
		}
	}

	// expectOpen asserts the number of open circuits in the switch.
	expectOpen := func(want int) {
		if s.circuits.NumOpen() != want {
			t.Fatal("wrong amount of circuits")
		}
	}

	// Each add opens one more circuit, even with a shared payment hash.
	sendAdd(0)
	expectOpen(1)
	sendAdd(1)
	expectOpen(2)

	// Each fail closes exactly the circuit it belongs to.
	sendFail(0)
	expectOpen(1)
	sendFail(1)
	expectOpen(0)
}
// TestSwitchSendPayment tests that the htlc switch hands the result of a
// locally initiated payment back to the caller once the response comes back
// from the channel link.
func TestSwitchSendPayment(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add link: %v", err)
	}

	// Create a locally initiated add which should be sent out over
	// alice's channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	update := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}

	paymentID := uint64(123)

	// First check that the switch will correctly respond that this payment
	// ID is unknown.
	_, err = s.GetPaymentResult(
		paymentID, rhash, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	// Send the payment and wait for its result from a separate goroutine,
	// since waiting for the result blocks until the response comes back.
	// The goroutine sends exactly one value. Buffering the channel keeps it
	// from blocking forever (leaking) if the test exits early.
	errChan := make(chan error, 1)
	go func() {
		err := s.SendHTLC(
			aliceChannelLink.ShortChanID(), paymentID, update,
		)
		if err != nil {
			errChan <- err
			return
		}

		resultChan, err := s.GetPaymentResult(
			paymentID, rhash, newMockDeobfuscator(),
		)
		if err != nil {
			errChan <- err
			return
		}

		result, ok := <-resultChan
		if !ok {
			// The value received from a closed channel is
			// meaningless, so it must not be inspected below.
			errChan <- fmt.Errorf("shutting down")
			return
		}

		if result.Error != nil {
			errChan <- result.Error
			return
		}

		errChan <- nil
	}()

	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case err := <-errChan:
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Create a fail packet, pretending that the remote side rejected the
	// add htlc with an error. It is sent back through alice's channel
	// link towards the local sender.
	obfuscator := NewMockObfuscator()
	failure := lnwire.NewFailIncorrectDetails(update.Amount, 100)
	reason, err := obfuscator.EncryptFirstHop(failure)
	require.NoError(t, err, "unable obfuscate failure")

	if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) {
		t.Fatal("htlc should be identified as not forwarded")
	}
	packet := &htlcPacket{
		outgoingChanID: aliceChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: reason,
		},
	}

	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	select {
	case err := <-errChan:
		assertFailureCode(
			t, err, lnwire.CodeIncorrectOrUnknownPaymentDetails,
		)
	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// TestLocalPaymentNoForwardingEvents verifies that payments initiated locally
// by a node are never recorded in that node's forwarding log.
func TestLocalPaymentNoForwardingEvents(t *testing.T) {
	t.Parallel()

	// Build the usual three hop network. Only the first end point (Alice)
	// is inspected by this test.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// Alice pays Bob directly over their shared link. Block until Alice
	// has received the preimage.
	amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
	htlcAmt, totalTimelock, hops := generateHops(
		amount, testStartingHeight, n.firstBobChannelLink,
	)
	_, err = makePayment(
		n.aliceServer, n.bobServer,
		n.firstBobChannelLink.ShortChanID(), hops, amount, htlcAmt,
		totalTimelock,
	).Wait(30 * time.Second)
	require.NoError(t, err, "unable to make the payment")

	// Stopping the network makes every switch flush the forwarding
	// events it may still be holding.
	n.stop()

	aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	aliceLog.Lock()
	defer aliceLog.Unlock()

	// Alice only originated a payment, so her log must be empty.
	if numEvents := len(aliceLog.events); numEvents != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(aliceLog.events))
	}
}
// TestMultiHopPaymentForwardingEvents tests that when a series of multi-hop
// payments is sent via Alice->Bob->Carol, Bob properly logs forwarding events
// while Alice and Carol don't.
func TestMultiHopPaymentForwardingEvents(t *testing.T) {
	t.Parallel()

	// First, we'll create our traditional three hop network.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// We'll now make 10 payments of 100k satoshis each from Alice to
	// Carol via Bob.
	const numPayments = 10
	finalAmt := lnwire.NewMSatFromSatoshis(100000)
	htlcAmt, totalTimelock, hops := generateHops(
		finalAmt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)
	firstHop := n.firstBobChannelLink.ShortChanID()

	// sendPayments sends count payments from Alice to Carol, waiting for
	// each one to complete before sending the next.
	sendPayments := func(count int) {
		for i := 0; i < count; i++ {
			_, err := makePayment(
				n.aliceServer, n.carolServer, firstHop, hops,
				finalAmt, htlcAmt, totalTimelock,
			).Wait(30 * time.Second)
			if err != nil {
				t.Fatalf("unable to send payment: %v", err)
			}
		}
	}

	// Send the first half of the payments.
	sendPayments(numPayments / 2)

	bobLog, ok := n.bobServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	// After sending 5 of the payments, trigger the forwarding ticker, to
	// make sure the events are properly flushed.
	bobTicker, ok := n.bobServer.htlcSwitch.cfg.FwdEventTicker.(*ticker.Force)
	if !ok {
		t.Fatalf("mockTicker assertion failed")
	}

	// We'll trigger the ticker, and wait for the events to appear in Bob's
	// forwarding log.
	timeout := time.After(15 * time.Second)
	for {
		select {
		case bobTicker.Force <- time.Now():
		case <-time.After(1 * time.Second):
			t.Fatalf("unable to force tick")
		}

		// If all 5 events are found in Bob's log, we can break out and
		// continue the test.
		bobLog.Lock()
		if len(bobLog.events) == 5 {
			bobLog.Unlock()
			break
		}
		bobLog.Unlock()

		// Otherwise wait a little bit before checking again.
		select {
		case <-time.After(50 * time.Millisecond):
		case <-timeout:
			bobLog.Lock()
			defer bobLog.Unlock()
			t.Fatalf("expected 5 events in event log, instead "+
				"found: %v", spew.Sdump(bobLog.events))
		}
	}

	// Send the remaining payments.
	sendPayments(numPayments - numPayments/2)

	// With all 10 payments sent, we'll now manually stop each of the
	// switches so we can examine their end state.
	n.stop()

	// Alice and Carol shouldn't have any recorded forwarding events, as
	// they were the source and the sink for these payment flows.
	aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	aliceLog.Lock()
	defer aliceLog.Unlock()
	if len(aliceLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(aliceLog.events))
	}

	carolLog, ok := n.carolServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	carolLog.Lock()
	defer carolLog.Unlock()
	if len(carolLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(carolLog.events))
	}

	// Bob, on the other hand, should have 10 events.
	bobLog.Lock()
	defer bobLog.Unlock()
	if len(bobLog.events) != 10 {
		t.Fatalf("log should have 10 events, instead has: %v",
			spew.Sdump(bobLog.events))
	}

	// Each of the 10 events should have had all fields set properly. The
	// failure messages print the expected value first, then the observed
	// event field.
	for _, event := range bobLog.events {
		// The incoming and outgoing channels should properly be set for
		// the event.
		wantIncoming := n.aliceChannelLink.ShortChanID()
		if event.IncomingChanID != wantIncoming {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				wantIncoming, event.IncomingChanID)
		}
		wantOutgoing := n.carolChannelLink.ShortChanID()
		if event.OutgoingChanID != wantOutgoing {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				wantOutgoing, event.OutgoingChanID)
		}

		// Additionally, the incoming and outgoing amounts should also
		// be properly set.
		if event.AmtIn != htlcAmt {
			t.Fatalf("incoming amt mismatch: expected %v, got %v",
				htlcAmt, event.AmtIn)
		}
		if event.AmtOut != finalAmt {
			t.Fatalf("outgoing amt mismatch: expected %v, got %v",
				finalAmt, event.AmtOut)
		}
	}
}
// TestUpdateFailMalformedHTLCErrorConversion tests that we're able to properly
// convert malformed HTLC errors that originate at the direct link, as well as
// during multi-hop HTLC forwarding.
func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) {
	t.Parallel()

	// First, we'll create our traditional three hop network.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// Tear the network down once the test and its subtests finish, so the
	// started servers are not leaked.
	t.Cleanup(n.stop)

	// assertPaymentFailure sends a payment from Alice to Carol and asserts
	// that it fails with an invalid onion key failure.
	assertPaymentFailure := func(t *testing.T) {
		t.Helper()

		// With the decoder modified, we'll now attempt to send a
		// payment from Alice to Carol.
		finalAmt := lnwire.NewMSatFromSatoshis(100000)
		htlcAmt, totalTimelock, hops := generateHops(
			finalAmt, testStartingHeight, n.firstBobChannelLink,
			n.carolChannelLink,
		)
		firstHop := n.firstBobChannelLink.ShortChanID()

		// Use a local err so the closure doesn't write to the
		// enclosing test's variable.
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)

		// The payment should fail as the onion blob cannot be
		// decoded.
		if err == nil {
			t.Fatalf("expected payment to fail, but it succeeded")
		}

		// Use a checked assertion so an unexpected error type fails
		// the test cleanly instead of panicking.
		routingErr, ok := err.(ClearTextError)
		if !ok {
			t.Fatalf("expected ClearTextError, got %T: %v", err,
				err)
		}

		failureMsg := routingErr.WireMessage()
		if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok {
			t.Fatalf("expected onion failure instead got: %v",
				failureMsg)
		}
	}

	t.Run("multi-hop error conversion", func(t *testing.T) {
		// Now that we have our network up, we'll modify the hop
		// iterator for the Bob <-> Carol channel to fail to decode in
		// order to simulate either a replay attack or an issue
		// decoding the onion.
		n.carolOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})

	t.Run("direct channel error conversion", func(t *testing.T) {
		// Similar to the above test case, we'll now make the Alice <->
		// Bob link always fail to decode an onion. This differs from
		// the above test case in that there's no encryption on the
		// error at all since Alice will directly receive a
		// UpdateFailMalformedHTLC message.
		n.bobOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})
}
// TestSwitchGetPaymentResult tests that the switch interacts as expected with
// the circuit map and network result store when looking up the result of a
// payment ID. This matters so that results are not lost when a lookup races
// with the arrival of the result.
func TestSwitchGetPaymentResult(t *testing.T) {
	t.Parallel()

	const paymentID = 123

	var preimg lntypes.Preimage
	preimg[0] = 3

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	// The mock circuit map answers each lookup with whatever value is
	// queued on this channel.
	lookup := make(chan *PaymentCircuit, 1)
	s.circuits = &mockCircuitMap{
		lookup: lookup,
	}

	// With no circuit in the map and nothing stored yet, the payment ID
	// must be reported as unknown.
	lookup <- nil
	_, err = s.GetPaymentResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	phases := []struct {
		// circuit is what the circuit map lookup returns.
		circuit *PaymentCircuit

		// storeAfterSubscribe stores the result only after the
		// caller has subscribed.
		storeAfterSubscribe bool
	}{
		// The circuit is still in the map: the caller subscribes and
		// receives the result once it is stored.
		{circuit: &PaymentCircuit{}, storeAfterSubscribe: true},

		// The circuit is gone: the stored result is served directly.
		{circuit: nil, storeAfterSubscribe: false},
	}

	for _, phase := range phases {
		lookup <- phase.circuit
		resultChan, err := s.GetPaymentResult(
			paymentID, lntypes.Hash{}, newMockDeobfuscator(),
		)
		require.NoError(t, err, "unable to get payment result")

		if phase.storeAfterSubscribe {
			n := &networkResult{
				msg: &lnwire.UpdateFulfillHTLC{
					PaymentPreimage: preimg,
				},
				unencrypted:  true,
				isResolution: true,
			}
			err = s.networkResults.storeResult(paymentID, n)
			require.NoError(t, err, "unable to store result")
		}

		select {
		case res, ok := <-resultChan:
			switch {
			case !ok:
				t.Fatalf("channel was closed")
			case res.Error != nil:
				t.Fatalf("got unexpected error result")
			case res.Preimage != preimg:
				t.Fatalf("expected preimg %v, got %v",
					preimg, res.Preimage)
			}

		case <-time.After(1 * time.Second):
			t.Fatalf("result not received")
		}
	}
}
// TestInvalidFailure tests that the switch returns an unreadable failure error
// if the failure cannot be decrypted, and a ForwardingError without a wire
// message if the failure decrypts but cannot be decoded.
func TestInvalidFailure(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	// Set up a mock channel link.
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add link: %v", err)
	}

	// Create a request which should be forwarded to the mock channel link.
	preimage, err := genPreimage()
	require.NoError(t, err, "unable to generate preimage")
	rhash := sha256.Sum256(preimage[:])
	update := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}

	paymentID := uint64(123)

	// Send the request.
	err = s.SendHTLC(
		aliceChannelLink.ShortChanID(), paymentID, update,
	)
	require.NoError(t, err, "unable to send payment")

	// Catch the packet and complete the circuit so that the switch is ready
	// for a response.
	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Send a response packet with an unreadable failure message to the
	// switch. The failure reason itself is not relevant, because the
	// decryption is mocked.
	packet := &htlcPacket{
		outgoingChanID: aliceChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: []byte{1, 2, 3},
		},
	}

	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	// Get payment result from switch. We expect an unreadable failure
	// message error.
	deobfuscator := SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			err: ErrUnreadableFailureMessage,
		},
	}

	resultChan, err := s.GetPaymentResult(
		paymentID, rhash, &deobfuscator,
	)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case result, ok := <-resultChan:
		// Guard against a closed channel, whose received value must
		// not be inspected.
		if !ok {
			t.Fatal("result channel was closed")
		}
		if result.Error != ErrUnreadableFailureMessage {
			t.Fatal("expected unreadable failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}

	// Modify the decryption to simulate that decryption went alright, but
	// the failure cannot be decoded.
	deobfuscator = SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			sourceIdx: 2,
			message:   []byte{200},
		},
	}

	resultChan, err = s.GetPaymentResult(
		paymentID, rhash, &deobfuscator,
	)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case result, ok := <-resultChan:
		if !ok {
			t.Fatal("result channel was closed")
		}

		rtErr, ok := result.Error.(ClearTextError)
		if !ok {
			t.Fatal("expected ClearTextError")
		}
		source, ok := rtErr.(*ForwardingError)
		if !ok {
			t.Fatalf("expected forwarding error, got: %T", rtErr)
		}
		if source.FailureSourceIdx != 2 {
			t.Fatal("unexpected error source index")
		}
		if rtErr.WireMessage() != nil {
			t.Fatal("expected empty failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// htlcNotifierEvents is a function that generates a set of expected htlc
// notifier events for each node in a three hop network with the dynamic
// values provided. These functions take dynamic values so that changes to
// external systems (such as our default timelock delta) do not break
// these tests.
//
// The three returned slices hold the expected events for Alice, Bob and
// Carol respectively, in the order they are expected to be delivered.
type htlcNotifierEvents func(channels *clusterChannels, htlcID uint64,
	ts time.Time, htlc *lnwire.UpdateAddHTLC,
	hops []*hop.Payload,
	preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{})
// TestHtlcNotifier tests the notifying of htlc events that are routed over a
// three hop network. It sets up an Alice -> Bob -> Carol network and routes
// payments from Alice -> Carol to test events from the perspective of a
// sending (Alice), forwarding (Bob) and receiving (Carol) node. Test cases
// are present for saduccessful and failed payments.
func TestHtlcNotifier(t *testing.T) {
tests := []struct {
name string
// Options is a set of options to apply to the three hop
// network's servers.
options []serverOption
// expectedEvents is a function which returns an expected set
// of events for the test.
expectedEvents htlcNotifierEvents
// iterations is the number of times we will send a payment,
// this is used to send more than one payment to force non-
// zero htlc indexes to make sure we aren't just checking
// default values.
iterations int
}{
{
name: "successful three hop payment",
options: nil,
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload,
preimage *lntypes.Preimage) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops, nil, preimage,
)
},
iterations: 2,
},
{
name: "failed at forwarding link",
// Set a functional option which disables bob as a
// forwarding node to force a payment error.
options: []serverOption{
serverOptionRejectHtlc(false, true, false),
},
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload,
preimage *lntypes.Preimage) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops,
&LinkError{
msg: &lnwire.FailChannelDisabled{},
FailureDetail: OutgoingFailureForwardsDisabled,
},
preimage,
)
},
iterations: 1,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testHtcNotifier(
t, test.options, test.iterations,
test.expectedEvents,
)
})
}
}
// testHtcNotifier runs a htlc notifier test: it starts a three hop network
// whose servers use dedicated htlc notifiers, sends the requested number of
// payments from Alice to Carol, and checks each node's events against those
// returned by getEvents.
func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
	getEvents htlcNotifierEvents) {

	t.Parallel()

	// First, we'll create our traditional three hop network.
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	require.NoError(t, err, "unable to create channel")

	// Mock time so that all events are reported with a static timestamp.
	now := time.Now()
	mockTime := func() time.Time {
		return now
	}

	// Create htlc notifiers for each server in the three hop network and
	// start them.
	aliceNotifier := NewHtlcNotifier(mockTime)
	if err := aliceNotifier.Start(); err != nil {
		t.Fatalf("could not start alice notifier: %v", err)
	}
	t.Cleanup(func() {
		if err := aliceNotifier.Stop(); err != nil {
			t.Fatalf("failed to stop alice notifier: %v", err)
		}
	})

	bobNotifier := NewHtlcNotifier(mockTime)
	if err := bobNotifier.Start(); err != nil {
		t.Fatalf("could not start bob notifier: %v", err)
	}
	t.Cleanup(func() {
		if err := bobNotifier.Stop(); err != nil {
			t.Fatalf("failed to stop bob notifier: %v", err)
		}
	})

	carolNotifier := NewHtlcNotifier(mockTime)
	if err := carolNotifier.Start(); err != nil {
		t.Fatalf("could not start carol notifier: %v", err)
	}
	t.Cleanup(func() {
		if err := carolNotifier.Stop(); err != nil {
			t.Fatalf("failed to stop carol notifier: %v", err)
		}
	})

	// Create a notifier server option which will set our htlc notifiers
	// for the three hop network.
	notifierOption := serverOptionWithHtlcNotifier(
		aliceNotifier, bobNotifier, carolNotifier,
	)

	// Add the htlcNotifier option to any other options set in the test.
	// Copy into a fresh slice rather than appending to testOpts, so the
	// caller's backing array can never be written to.
	options := make([]serverOption, 0, len(testOpts)+1)
	options = append(options, testOpts...)
	options = append(options, notifierOption)

	n := newThreeHopNetwork(
		t, channels.aliceToBob,
		channels.bobToAlice, channels.bobToCarol,
		channels.carolToBob, testStartingHeight,
		options...,
	)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop "+
			"network: %v", err)
	}
	t.Cleanup(n.stop)

	// Before we forward anything, subscribe to htlc events
	// from each notifier.
	aliceEvents, err := aliceNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to alice's"+
			" events: %v", err)
	}
	t.Cleanup(aliceEvents.Cancel)

	bobEvents, err := bobNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to bob's"+
			" events: %v", err)
	}
	t.Cleanup(bobEvents.Cancel)

	carolEvents, err := carolNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to carol's"+
			" events: %v", err)
	}
	t.Cleanup(carolEvents.Cancel)

	// Send multiple payments, as specified by the test, to exercise
	// incrementing htlc ids.
	for i := 0; i < iterations; i++ {
		// We'll start off by making a payment from
		// Alice -> Bob -> Carol. The preimage, generated
		// by Carol's Invoice, is expected in the Settle events.
		htlc, hops, preimage := n.sendThreeHopPayment(t)

		alice, bob, carol := getEvents(
			channels, uint64(i), now, htlc, hops, preimage,
		)

		checkHtlcEvents(t, aliceEvents.Updates(), alice)
		checkHtlcEvents(t, bobEvents.Updates(), bob)
		checkHtlcEvents(t, carolEvents.Updates(), carol)
	}
}
// checkHtlcEvents asserts that the subscription delivers exactly
// expectedEvents, in order, with each one arriving within five seconds, and
// that no further event is already queued afterwards.
func checkHtlcEvents(t *testing.T, events <-chan interface{},
	expectedEvents []interface{}) {

	t.Helper()

	for i := range expectedEvents {
		want := expectedEvents[i]

		select {
		case got := <-events:
			if !reflect.DeepEqual(got, want) {
				t.Fatalf("expected %v, got: %v", want, got)
			}

		case <-time.After(5 * time.Second):
			t.Fatalf("expected event: %v", want)
		}
	}

	// Non-blocking probe: any event still queued is unexpected.
	select {
	case extra := <-events:
		t.Fatalf("unexpected event: %v", extra)
	default:
	}
}
// sendThreeHopPayment is a helper function which sends a payment over
// Alice -> Bob -> Carol in a three hop network after registering the matching
// invoice with Carol. It returns Alice's first htlc, the hop payloads of the
// route, and the invoice preimage.
func (n *threeHopNetwork) sendThreeHopPayment(
	t *testing.T) (*lnwire.UpdateAddHTLC, []*hop.Payload, *lntypes.Preimage) {

	// Report failures at the caller's line, not inside this helper.
	t.Helper()

	amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
	htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
		n.firstBobChannelLink, n.carolChannelLink)

	blob, err := generateRoute(hops...)
	require.NoError(t, err, "unable to generate route")

	invoice, htlc, pid, err := generatePayment(
		amount, htlcAmt, totalTimelock, blob,
	)
	require.NoError(t, err, "unable to generate payment")

	err = n.carolServer.registry.AddInvoice(*invoice, htlc.PaymentHash)
	require.NoError(t, err, "unable to add invoice in carol registry")

	// Keep the underlying error in the failure output.
	err = n.aliceServer.htlcSwitch.SendHTLC(
		n.firstBobChannelLink.ShortChanID(), pid, htlc,
	)
	require.NoError(t, err, "could not send htlc")

	return htlc, hops, invoice.Terms.PaymentPreimage
}
// getThreeHopEvents builds the htlc events expected for a payment from
// Alice -> Bob -> Carol, returned as (alice, bob, carol). With a nil linkError
// the payment settles end to end. Otherwise it fails on Bob's outgoing link
// with that error, and Carol gets no events.
func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
	ts time.Time, htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
	linkError *LinkError,
	preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{}) {

	// circuitOn returns the circuit key for this payment's htlc on the
	// given channel.
	circuitOn := func(scid lnwire.ShortChannelID) channeldb.CircuitKey {
		return channeldb.CircuitKey{
			ChanID: scid,
			HtlcID: htlcID,
		}
	}

	aliceKey := HtlcKey{
		IncomingCircuit: zeroCircuit,
		OutgoingCircuit: circuitOn(channels.aliceToBob.ShortChanID()),
	}
	bobKey := HtlcKey{
		IncomingCircuit: circuitOn(channels.bobToAlice.ShortChanID()),
		OutgoingCircuit: circuitOn(channels.bobToCarol.ShortChanID()),
	}
	bobInfo := HtlcInfo{
		IncomingTimeLock: htlc.Expiry,
		IncomingAmt:      htlc.Amount,
		OutgoingTimeLock: hops[1].FwdInfo.OutgoingCTLV,
		OutgoingAmt:      hops[1].FwdInfo.AmountToForward,
	}

	// Whatever the outcome, Alice first reports sending the htlc.
	aliceSend := &ForwardingEvent{
		HtlcKey: aliceKey,
		HtlcInfo: HtlcInfo{
			OutgoingTimeLock: htlc.Expiry,
			OutgoingAmt:      htlc.Amount,
		},
		HtlcEventType: HtlcEventTypeSend,
		Timestamp:     ts,
	}

	if linkError != nil {
		// Failure at Bob's outgoing link: Alice sees her send fail,
		// Bob reports the link failure and a final (unsettled)
		// resolution, and the htlc never reaches Carol.
		aliceFailed := []interface{}{
			aliceSend,
			&ForwardingFailEvent{
				HtlcKey:       aliceKey,
				HtlcEventType: HtlcEventTypeSend,
				Timestamp:     ts,
			},
		}
		bobFailed := []interface{}{
			&LinkFailEvent{
				HtlcKey:       bobKey,
				HtlcInfo:      bobInfo,
				HtlcEventType: HtlcEventTypeForward,
				LinkError:     linkError,
				Incoming:      false,
				Timestamp:     ts,
			},
			&FinalHtlcEvent{
				CircuitKey: bobKey.IncomingCircuit,
				Settled:    false,
				Offchain:   true,
				Timestamp:  ts,
			},
		}

		return aliceFailed, bobFailed, nil
	}

	// Success: Alice's send settles, Bob forwards and settles, and Carol
	// receives and settles the htlc.
	aliceSettled := []interface{}{
		aliceSend,
		&SettleEvent{
			HtlcKey:       aliceKey,
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeSend,
			Timestamp:     ts,
		},
	}

	bobSettled := []interface{}{
		&ForwardingEvent{
			HtlcKey:       bobKey,
			HtlcInfo:      bobInfo,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
		&SettleEvent{
			HtlcKey:       bobKey,
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
		&FinalHtlcEvent{
			CircuitKey: bobKey.IncomingCircuit,
			Settled:    true,
			Offchain:   true,
			Timestamp:  ts,
		},
	}

	carolCircuit := circuitOn(channels.carolToBob.ShortChanID())
	carolSettled := []interface{}{
		&SettleEvent{
			HtlcKey: HtlcKey{
				IncomingCircuit: carolCircuit,
				OutgoingCircuit: zeroCircuit,
			},
			Preimage:      *preimage,
			HtlcEventType: HtlcEventTypeReceive,
			Timestamp:     ts,
		},
		&FinalHtlcEvent{
			CircuitKey: carolCircuit,
			Settled:    true,
			Offchain:   true,
			Timestamp:  ts,
		},
	}

	return aliceSettled, bobSettled, carolSettled
}
// mockForwardInterceptor is a test interceptor that passes every intercepted
// packet to interceptedChan, where getIntercepted can collect it.
type mockForwardInterceptor struct {
	t *testing.T

	// interceptedChan carries each packet handed to
	// InterceptForwardHtlc.
	interceptedChan chan InterceptedPacket
}

// InterceptForwardHtlc sends the intercepted packet on interceptedChan and
// always reports success. The send blocks until the channel can accept it.
func (m *mockForwardInterceptor) InterceptForwardHtlc(
	intercepted InterceptedPacket) error {

	m.interceptedChan <- intercepted

	return nil
}

// getIntercepted waits up to one second for the next intercepted packet and
// fails the test if none arrives.
func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
	m.t.Helper()

	deadline := time.NewTimer(time.Second)
	defer deadline.Stop()

	select {
	case pkt := <-m.interceptedChan:
		return pkt

	case <-deadline.C:
	}

	require.Fail(m.t, "timeout")

	return InterceptedPacket{}
}
// assertNumCircuits fails the test unless the switch's circuit map reports
// exactly `pending` half circuits (NumPending) and `opened` open circuits
// (NumOpen).
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
	t.Helper()

	// Read each counter once so the failure message shows exactly the value
	// that was compared.
	if numPending := s.circuits.NumPending(); numPending != pending {
		t.Fatalf("wrong amount of half circuits, expected %v but "+
			"got %v", pending, numPending)
	}

	if numOpen := s.circuits.NumOpen(); numOpen != opened {
		t.Fatalf("wrong amount of circuits, expected %v but got %v",
			opened, numOpen)
	}
}
// assertOutgoingLinkReceive waits up to one second for a packet on
// targetLink.packets.
//
// If expectReceive is true, the received packet's circuit is completed via
// targetLink.completeCircuit and the packet is returned. The test fails if
// nothing arrives in time or if completing the circuit fails.
//
// If expectReceive is false, the test fails if any packet arrives. Otherwise
// nil is returned after the one-second wait.
func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
	expectReceive bool) *htlcPacket {

	t.Helper()

	select {
	case packet := <-targetLink.packets:
		if !expectReceive {
			t.Fatal("forward was intercepted, shouldn't land at bob link")
		}

		if err := targetLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

		return packet

	case <-time.After(time.Second):
		if expectReceive {
			t.Fatal("request was not propagated to destination")
		}
	}

	return nil
}
// assertOutgoingLinkReceiveIntercepted fails the test unless some packet
// arrives on targetLink.packets within one second. The packet is discarded;
// no circuit is completed for it.
func assertOutgoingLinkReceiveIntercepted(t *testing.T,
	targetLink *mockChannelLink) {

	t.Helper()

	timeout := time.After(time.Second)

	select {
	case <-timeout:
		t.Fatal("request was not propagated to destination")

	case <-targetLink.packets:
	}
}
// interceptableSwitchTestContext bundles a switch and two mock links (alice
// and bob) with the fixed values the interceptable switch tests use to build
// packets and configs.
type interceptableSwitchTestContext struct {
	t *testing.T

	// preimage is used to settle test HTLCs; rhash is the payment hash
	// placed in every test add.
	preimage [sha256.Size]byte
	rhash    [32]byte

	// onionBlob is attached to every test add.
	onionBlob [1366]byte

	// incomingHtlcID is bumped for every packet created by
	// createTestPacket.
	incomingHtlcID uint64

	// cltvRejectDelta and cltvInterceptDelta feed the
	// InterceptableSwitchConfig fields of the same names.
	cltvRejectDelta    uint32
	cltvInterceptDelta uint32

	// forwardInterceptor collects intercepted packets for the test.
	forwardInterceptor *mockForwardInterceptor

	// aliceChannelLink and bobChannelLink are the incoming and outgoing
	// links of test packets respectively, both registered with s.
	aliceChannelLink *mockChannelLink
	bobChannelLink   *mockChannelLink
	s                *Switch
}
// newInterceptableSwitchTestContext builds the shared fixture for the
// interceptable switch tests. It opens a channeldb in a temporary directory
// (closed via t.Cleanup) and starts a switch on top of it. It then registers
// mock alice and bob links with the switch.
//
// The returned context also carries a fixed preimage, payment hash, onion blob
// and CLTV deltas. The switch is left running; callers stop it with finish.
func newInterceptableSwitchTestContext(
	t *testing.T) *interceptableSwitchTestContext {

	t.Helper()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err, "unable to open channeldb")
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")

	err = s.Start()
	require.NoError(t, err, "unable to start switch")

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)

	err = s.AddLink(aliceChannelLink)
	require.NoError(t, err, "unable to add alice link")

	err = s.AddLink(bobChannelLink)
	require.NoError(t, err, "unable to add bob link")

	preimage := [sha256.Size]byte{1}

	return &interceptableSwitchTestContext{
		t:                  t,
		preimage:           preimage,
		rhash:              sha256.Sum256(preimage[:]),
		onionBlob:          [1366]byte{4, 5, 6},
		incomingHtlcID:     uint64(0),
		cltvRejectDelta:    10,
		cltvInterceptDelta: 13,
		forwardInterceptor: &mockForwardInterceptor{
			t:               t,
			interceptedChan: make(chan InterceptedPacket),
		},
		aliceChannelLink: aliceChannelLink,
		bobChannelLink:   bobChannelLink,
		s:                s,
	}
}
// createTestPacket returns a new add packet travelling from alice's link to
// bob's link.
//
// The incoming HTLC ID is bumped before use, so every call yields a distinct
// incoming circuit key; the first packet uses ID 1. The incoming timeout is
// one block past the intercept delta.
func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket {
	c.incomingHtlcID++
	htlcID := c.incomingHtlcID

	expiry := testStartingHeight + c.cltvInterceptDelta + 1

	return &htlcPacket{
		incomingChanID:  c.aliceChannelLink.ShortChanID(),
		incomingHTLCID:  htlcID,
		incomingTimeout: expiry,
		outgoingChanID:  c.bobChannelLink.ShortChanID(),
		obfuscator:      NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: c.rhash,
			Amount:      1,
			OnionBlob:   c.onionBlob,
		},
	}
}
// finish stops the context's switch and fails the test if stopping returns
// an error.
func (c *interceptableSwitchTestContext) finish() {
	c.t.Helper()

	if err := c.s.Stop(); err != nil {
		// Use a constant format string: passing err.Error() as the
		// format would misrender any '%' in the message and is flagged
		// by go vet's printf check.
		c.t.Fatalf("unable to stop switch: %v", err)
	}
}
// createSettlePacket builds a settle packet with amount 1. Its outgoing
// channel is bob's link, it carries the given outgoing HTLC index, and it
// fulfills with the context's preimage.
func (c *interceptableSwitchTestContext) createSettlePacket(
	outgoingHTLCID uint64) *htlcPacket {

	bobScid := c.bobChannelLink.ShortChanID()
	fulfill := &lnwire.UpdateFulfillHTLC{
		PaymentPreimage: c.preimage,
	}

	return &htlcPacket{
		outgoingChanID: bobScid,
		outgoingHTLCID: outgoingHTLCID,
		amount:         1,
		htlc:           fulfill,
	}
}
// TestSwitchHoldForward exercises an InterceptableSwitch layered on a switch
// with mock alice/bob links. Each phase below forwards a fresh packet from
// alice towards bob and checks which link receives what, plus the circuit
// counts.
//
// With an interceptor registered it covers:
//   - rejecting a forward whose incoming expiry is too close;
//   - falling back to a plain forward when composing that rejection fails;
//   - resuming a held forward;
//   - resuming held forwards when the interceptor disconnects;
//   - failing a held forward by code, by opaque reason and by a malformed
//     onion code;
//   - settling a held forward.
//
// The interceptable switch is then recreated with RequireInterceptor set. In
// that mode, fresh packets are failed while no interceptor is connected, and
// replayed packets are held until one connects.
func TestSwitchHoldForward(t *testing.T) {
	t.Parallel()

	c := newInterceptableSwitchTestContext(t)
	defer c.finish()

	// The interceptable switch reads the current height from the notifier,
	// so seed the epoch channel with the starting height.
	notifier := &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	switchForwardInterceptor, err := NewInterceptableSwitch(
		&InterceptableSwitchConfig{
			Switch:             c.s,
			CltvRejectDelta:    c.cltvRejectDelta,
			CltvInterceptDelta: c.cltvInterceptDelta,
			Notifier:           notifier,
		},
	)
	require.NoError(t, err)
	require.NoError(t, switchForwardInterceptor.Start())

	switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc)
	linkQuit := make(chan struct{})

	// Test a forward that expires too soon. It must be failed back to
	// alice without reaching bob or opening any circuit.
	packet := c.createTestPacket()
	packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1

	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
	require.NoError(t, err, "can't forward htlc packet")
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink)
	assertNumCircuits(t, c.s, 0, 0)

	// Test a forward that expires too soon and can't be failed. In that
	// case the packet is expected to be forwarded to bob as usual.
	packet = c.createTestPacket()
	packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1

	// Simulate an error during the composition of the failure message.
	// The original callback is restored once this phase is done.
	currentCallback := c.s.cfg.FetchLastChannelUpdate
	c.s.cfg.FetchLastChannelUpdate = func(
		lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {

		return nil, errors.New("cannot fetch update")
	}

	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
	require.NoError(t, err, "can't forward htlc packet")
	receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true)
	assertNumCircuits(t, c.s, 1, 1)

	// Settle it so the circuit is torn down before the next phase.
	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false,
		c.createSettlePacket(receivedPkt.outgoingHTLCID),
	))

	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	c.s.cfg.FetchLastChannelUpdate = currentCallback

	// Test resume a hold forward. While held, nothing reaches bob and no
	// circuit exists; after FwdActionResume the packet is forwarded.
	assertNumCircuits(t, c.s, 0, 0)
	err = switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	)
	require.NoError(t, err)

	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action: FwdActionResume,
		Key:    c.forwardInterceptor.getIntercepted().IncomingCircuit,
	}))

	receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
	assertNumCircuits(t, c.s, 1, 1)

	// Settle the htlc to close the circuit.
	err = switchForwardInterceptor.ForwardPackets(
		linkQuit, false,
		c.createSettlePacket(receivedPkt.outgoingHTLCID),
	)
	require.NoError(t, err)

	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	// Test resume a hold forward after disconnection.
	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	))

	// Wait until the packet is offered to the interceptor.
	_ = c.forwardInterceptor.getIntercepted()

	// No forward expected yet.
	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	// Disconnect should resume the forwarding.
	switchForwardInterceptor.SetInterceptor(nil)

	receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
	assertNumCircuits(t, c.s, 1, 1)

	// Settle the htlc to close the circuit.
	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false,
		c.createSettlePacket(receivedPkt.outgoingHTLCID),
	))

	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	// Test failing a hold forward with a failure code. Reconnect the
	// interceptor first, since it was disconnected above.
	switchForwardInterceptor.SetInterceptor(
		c.forwardInterceptor.InterceptForwardHtlc,
	)

	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	))
	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action:      FwdActionFail,
		Key:         c.forwardInterceptor.getIntercepted().IncomingCircuit,
		FailureCode: lnwire.CodeTemporaryChannelFailure,
	}))
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	// Test failing a hold forward with a failure message. The opaque
	// reason must be passed back to alice unchanged.
	require.NoError(t,
		switchForwardInterceptor.ForwardPackets(
			linkQuit, false, c.createTestPacket(),
		),
	)
	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	reason := lnwire.OpaqueReason([]byte{1, 2, 3})
	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action:         FwdActionFail,
		Key:            c.forwardInterceptor.getIntercepted().IncomingCircuit,
		FailureMessage: reason,
	}))

	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)

	require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)

	assertNumCircuits(t, c.s, 0, 0)

	// Test failing a hold forward with a malformed htlc failure. Alice is
	// expected to receive a regular UpdateFailHTLC whose reason, once
	// decrypted, is FailInvalidOnionKey over the sha256 of the onion blob.
	err = switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	)
	require.NoError(t, err)

	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	code := lnwire.CodeInvalidOnionKey
	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action:      FwdActionFail,
		Key:         c.forwardInterceptor.getIntercepted().IncomingCircuit,
		FailureCode: code,
	}))

	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	failPacket := packet.htlc.(*lnwire.UpdateFailHTLC)

	shaOnionBlob := sha256.Sum256(c.onionBlob[:])
	expectedFailure := &lnwire.FailInvalidOnionKey{
		OnionSHA256: shaOnionBlob,
	}

	fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason)
	require.NoError(t, err)
	require.Equal(t, expectedFailure, fwdErr.WireMessage())

	assertNumCircuits(t, c.s, 0, 0)

	// Test settling a hold forward. The settle goes straight back to
	// alice; bob never sees the add.
	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	))
	assertNumCircuits(t, c.s, 0, 0)
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)

	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Key:      c.forwardInterceptor.getIntercepted().IncomingCircuit,
		Action:   FwdActionSettle,
		Preimage: c.preimage,
	}))
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	require.NoError(t, switchForwardInterceptor.Stop())

	// Test always-on interception: recreate the interceptable switch on
	// the same underlying switch with RequireInterceptor set.
	notifier = &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	switchForwardInterceptor, err = NewInterceptableSwitch(
		&InterceptableSwitchConfig{
			Switch:             c.s,
			CltvRejectDelta:    c.cltvRejectDelta,
			CltvInterceptDelta: c.cltvInterceptDelta,
			RequireInterceptor: true,
			Notifier:           notifier,
		},
	)
	require.NoError(t, err)
	require.NoError(t, switchForwardInterceptor.Start())

	// Forward a fresh packet. It is expected to be failed immediately,
	// because there is no interceptor registered.
	require.NoError(t, switchForwardInterceptor.ForwardPackets(
		linkQuit, false, c.createTestPacket(),
	))

	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	// Forward a replayed packet. It is expected to be held until the
	// interceptor connects. To continue the test, it needs to be run in a
	// goroutine.
	errChan := make(chan error)
	go func() {
		errChan <- switchForwardInterceptor.ForwardPackets(
			linkQuit, true, c.createTestPacket(),
		)
	}()

	// Assert that nothing is forwarded to the switch.
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertNumCircuits(t, c.s, 0, 0)

	// Register an interceptor.
	switchForwardInterceptor.SetInterceptor(
		c.forwardInterceptor.InterceptForwardHtlc,
	)

	// Expect the ForwardPackets call to unblock.
	require.NoError(t, <-errChan)

	// Now expect the queued packet to come through.
	c.forwardInterceptor.getIntercepted()

	// Disconnect and reconnect interceptor.
	switchForwardInterceptor.SetInterceptor(nil)
	switchForwardInterceptor.SetInterceptor(
		c.forwardInterceptor.InterceptForwardHtlc,
	)

	// A replay of the held packet is expected.
	intercepted := c.forwardInterceptor.getIntercepted()

	// Settle the packet.
	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Key:      intercepted.IncomingCircuit,
		Action:   FwdActionSettle,
		Preimage: c.preimage,
	}))
	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
	assertNumCircuits(t, c.s, 0, 0)

	require.NoError(t, switchForwardInterceptor.Stop())

	// After stopping, no further packet may have been handed to the
	// interceptor.
	select {
	case <-c.forwardInterceptor.interceptedChan:
		require.Fail(t, "unexpected interception")

	default:
	}
}
// TestInterceptableSwitchWatchDog checks three things:
//   - a held forward reports an AutoFailHeight of
//     incomingTimeout - cltvRejectDelta;
//   - the forward is failed back to the incoming link once a block at that
//     height is notified;
//   - resolving the forward afterwards returns an error.
func TestInterceptableSwitchWatchDog(t *testing.T) {
	t.Parallel()

	c := newInterceptableSwitchTestContext(t)
	defer c.finish()

	// Start interceptable switch.
	notifier := &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	switchForwardInterceptor, err := NewInterceptableSwitch(
		&InterceptableSwitchConfig{
			Switch:             c.s,
			CltvRejectDelta:    c.cltvRejectDelta,
			CltvInterceptDelta: c.cltvInterceptDelta,
			Notifier:           notifier,
		},
	)
	require.NoError(t, err)
	require.NoError(t, switchForwardInterceptor.Start())

	// Stop the interceptable switch when the test returns. Deferred calls
	// run in reverse order, so this happens before c.finish stops the
	// underlying switch.
	defer func() {
		if err := switchForwardInterceptor.Stop(); err != nil {
			t.Errorf("unable to stop interceptable switch: %v", err)
		}
	}()

	// Set interceptor.
	switchForwardInterceptor.SetInterceptor(
		c.forwardInterceptor.InterceptForwardHtlc,
	)

	// Receive a packet.
	linkQuit := make(chan struct{})
	packet := c.createTestPacket()
	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
	require.NoError(t, err, "can't forward htlc packet")

	// Intercept the packet.
	intercepted := c.forwardInterceptor.getIntercepted()

	require.Equal(t,
		int32(packet.incomingTimeout-c.cltvRejectDelta),
		intercepted.AutoFailHeight,
	)

	// Htlc expires before a resolution from the interceptor.
	notifier.EpochChan <- &chainntnfs.BlockEpoch{
		Height: int32(packet.incomingTimeout) -
			int32(c.cltvRejectDelta),
	}

	// Expect the htlc to be failed back.
	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)

	// It is too late now to resolve. Expect an error.
	require.Error(t, switchForwardInterceptor.Resolve(&FwdResolution{
		Action:   FwdActionSettle,
		Key:      intercepted.IncomingCircuit,
		Preimage: c.preimage,
	}))
}
// TestSwitchDustForwarding tests that the switch properly fails HTLC's which
// have incoming or outgoing links that breach their dust thresholds.
func TestSwitchDustForwarding(t *testing.T) {
	t.Parallel()

	// numDustHtlcs is how many 700sat HTLC's each of Alice and Bob sends
	// through sendDustHtlcs. It must match the count used by that helper.
	// The attempt IDs 0..numDustHtlcs-1 are consumed by those sends.
	const numDustHtlcs = 357

	// We'll create a three-hop network:
	// - Alice has a dust limit of 200sats with Bob
	// - Bob has a dust limit of 800sats with Alice
	// - Bob has a dust limit of 200sats with Carol
	// - Carol has a dust limit of 800sats with Bob
	channels, _, err := createClusterChannels(
		t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin,
	)
	require.NoError(t, err)

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	err = n.start()
	require.NoError(t, err)

	// We'll also put Alice and Bob into hodl.ExitSettle mode, such that
	// they won't settle incoming exit-hop HTLC's automatically.
	n.aliceChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()
	n.firstBobChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()

	// We'll test that once the default threshold is exceeded on the
	// Alice -> Bob channel, either side's calls to SendHTLC will fail.
	//
	// Alice will send 357 HTLC's of 700sats. Bob will also send 357 HTLC's
	// of 700sats. 700sats is below Bob's 800sat dust limit, so these
	// HTLC's count as dust, and a further dust HTLC from either side will
	// breach the dust threshold.
	amt := lnwire.NewMSatFromSatoshis(700)
	aliceBobFirstHop := n.aliceChannelLink.ShortChanID()

	sendDustHtlcs(t, n, true, amt, aliceBobFirstHop)
	sendDustHtlcs(t, n, false, amt, aliceBobFirstHop)

	// Generate the parameters needed for Bob to send another dust HTLC.
	_, timelock, hops := generateHops(
		amt, testStartingHeight, n.aliceChannelLink,
	)

	blob, err := generateRoute(hops...)
	require.NoError(t, err)

	// Assert that if Bob sends a dust HTLC it will fail.
	failingPreimage := lntypes.Preimage{0, 0, 3}
	failingHash := failingPreimage.Hash()
	failingHtlc := &lnwire.UpdateAddHTLC{
		PaymentHash: failingHash,
		Amount:      amt,
		Expiry:      timelock,
		OnionBlob:   blob,
	}

	// checkAlmostDust polls every 300ms, for up to 15 seconds, until the
	// link's dust sum plus the mailbox's dust for the same side (remote or
	// local commitment) equals the dust of all HTLC's sent above. It
	// reports whether that state was reached in time.
	checkAlmostDust := func(link *channelLink, mbox MailBox,
		remote bool) bool {

		timeout := time.After(15 * time.Second)
		pollInterval := 300 * time.Millisecond
		expectedDust := numDustHtlcs * 2 * amt

		for {
			select {
			case <-timeout:
				return false

			case <-time.After(pollInterval):
			}

			linkDust := link.getDustSum(remote)
			localMailDust, remoteMailDust := mbox.DustPackets()

			totalDust := linkDust
			if remote {
				totalDust += remoteMailDust
			} else {
				totalDust += localMailDust
			}

			if totalDust == expectedDust {
				return true
			}
		}
	}

	// Wait until Bob is almost at the dust threshold.
	bobMbox := n.bobServer.htlcSwitch.mailOrchestrator.GetOrCreateMailBox(
		n.firstBobChannelLink.ChanID(),
		n.firstBobChannelLink.ShortChanID(),
	)
	require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false))

	// Assert that the HTLC is failed due to the dust threshold.
	err = n.bobServer.htlcSwitch.SendHTLC(
		aliceBobFirstHop, uint64(numDustHtlcs), failingHtlc,
	)
	require.ErrorIs(t, err, errDustThresholdExceeded)

	// Generate the parameters needed for bob to send a non-dust HTLC.
	nondustAmt := lnwire.NewMSatFromSatoshis(10_000)
	_, _, hops = generateHops(
		nondustAmt, testStartingHeight, n.aliceChannelLink,
	)

	blob, err = generateRoute(hops...)
	require.NoError(t, err)

	// Now attempt to send an HTLC above Bob's dust limit. It should
	// succeed.
	nondustPreimage := lntypes.Preimage{0, 0, 4}
	nondustHash := nondustPreimage.Hash()
	nondustHtlc := &lnwire.UpdateAddHTLC{
		PaymentHash: nondustHash,
		Amount:      nondustAmt,
		Expiry:      timelock,
		OnionBlob:   blob,
	}

	// Assert that SendHTLC succeeds and evaluateDustThreshold returns
	// false.
	err = n.bobServer.htlcSwitch.SendHTLC(
		aliceBobFirstHop, uint64(numDustHtlcs+1), nondustHtlc,
	)
	require.NoError(t, err)

	// Introduce Carol into the mix and assert that sending a multi-hop
	// dust HTLC to Alice will fail. Bob should fail back the HTLC with a
	// temporary channel failure.
	carolAmt, carolTimelock, carolHops := generateHops(
		amt, testStartingHeight, n.secondBobChannelLink,
		n.aliceChannelLink,
	)

	carolBlob, err := generateRoute(carolHops...)
	require.NoError(t, err)

	carolPreimage := lntypes.Preimage{0, 0, 5}
	carolHash := carolPreimage.Hash()
	carolHtlc := &lnwire.UpdateAddHTLC{
		PaymentHash: carolHash,
		Amount:      carolAmt,
		Expiry:      carolTimelock,
		OnionBlob:   carolBlob,
	}

	// Carol makes a single payment attempt, using attempt ID 0.
	const carolAttemptID = uint64(0)

	err = n.carolServer.htlcSwitch.SendHTLC(
		n.carolChannelLink.ShortChanID(), carolAttemptID, carolHtlc,
	)
	require.NoError(t, err)

	carolResultChan, err := n.carolServer.htlcSwitch.GetPaymentResult(
		carolAttemptID, carolHash, newMockDeobfuscator(),
	)
	require.NoError(t, err)

	// Bound the wait so a missing result fails the test instead of
	// hanging it.
	select {
	case result, ok := <-carolResultChan:
		require.True(t, ok)
		assertFailureCode(
			t, result.Error, lnwire.CodeTemporaryChannelFailure,
		)

	case <-time.After(15 * time.Second):
		t.Fatal("no payment result received for carol's htlc")
	}

	// Send an HTLC from Alice to Carol and assert that it is failed at the
	// call to SendHTLC.
	htlcAmt, totalTimelock, aliceHops := generateHops(
		amt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)

	blob, err = generateRoute(aliceHops...)
	require.NoError(t, err)

	aliceMultihopPreimage := lntypes.Preimage{0, 0, 6}
	aliceMultihopHash := aliceMultihopPreimage.Hash()
	aliceMultihopHtlc := &lnwire.UpdateAddHTLC{
		PaymentHash: aliceMultihopHash,
		Amount:      htlcAmt,
		Expiry:      totalTimelock,
		OnionBlob:   blob,
	}

	// Wait until Alice's expected dust for the remote commitment is just
	// under the dust threshold.
	aliceOrch := n.aliceServer.htlcSwitch.mailOrchestrator
	aliceMbox := aliceOrch.GetOrCreateMailBox(
		n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(),
	)
	require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true))

	err = n.aliceServer.htlcSwitch.SendHTLC(
		n.aliceChannelLink.ShortChanID(), uint64(numDustHtlcs),
		aliceMultihopHtlc,
	)
	require.ErrorIs(t, err, errDustThresholdExceeded)
}
// sendDustHtlcs is a helper function used to send many dust HTLC's to test the
// Switch's dust-threshold logic. It takes a boolean denoting whether or not
// Alice is the sender.
//
// The sender's switch sends 357 HTLC's of amt over the channel identified by
// sid, using attempt IDs 0 through 356. Each send is retried until SendHTLC
// succeeds; the test fails if an HTLC still cannot be sent once the retry
// deadline passes.
func sendDustHtlcs(t *testing.T, n *threeHopNetwork, alice bool,
	amt lnwire.MilliSatoshi, sid lnwire.ShortChannelID) {

	t.Helper()

	// The number of dust HTLC's we'll send for both Alice and Bob.
	const numHTLCs = 357

	// retryTimeout bounds how long a single HTLC may be retried, so a
	// persistent failure fails the test instead of spinning forever.
	const retryTimeout = 30 * time.Second

	// Pick the sender, the destination and the third preimage byte. If
	// alice is the sender, the destination is Bob and vice versa. A
	// different third byte keeps Alice's and Bob's preimages distinct.
	destLink := n.aliceChannelLink
	sendingSwitch := n.bobServer.htlcSwitch
	endByte := byte(2)
	if alice {
		destLink = n.firstBobChannelLink
		sendingSwitch = n.aliceServer.htlcSwitch
		endByte = byte(3)
	}

	// Create hops that will be used in the onion payload.
	htlcAmt, totalTimelock, hops := generateHops(
		amt, testStartingHeight, destLink,
	)

	// Convert the hops to a blob that will be put in the Add message.
	blob, err := generateRoute(hops...)
	require.NoError(t, err)

	for i := 0; i < numHTLCs; i++ {
		// Deterministically generate the preimage. The non-zero third
		// byte avoids the all-zeroes preimage, which the database
		// rejects.
		preimage := lntypes.Preimage{byte(i >> 8), byte(i), endByte}
		attemptID := uint64(i)

		htlc := &lnwire.UpdateAddHTLC{
			PaymentHash: preimage.Hash(),
			Amount:      htlcAmt,
			Expiry:      totalTimelock,
			OnionBlob:   blob,
		}

		// It may be the case that the dust threshold is hit before all
		// 357*2 HTLC's are sent due to double counting. Get around this
		// by continuing to send until successful, but give up once the
		// deadline passes.
		deadline := time.Now().Add(retryTimeout)
		for {
			err = sendingSwitch.SendHTLC(sid, attemptID, htlc)
			if err == nil {
				break
			}

			if time.Now().After(deadline) {
				t.Fatalf("unable to send dust htlc %d: %v", i, err)
			}
		}
	}
}
// TestSwitchMailboxDust tests that the switch takes into account the mailbox
// dust when evaluating the dust threshold. The mockChannelLink does not have
// channel state, so this only tests the switch-mailbox interaction.
func TestSwitchMailboxDust(t *testing.T) {
	t.Parallel()

	// Spin up one mock peer each for alice, bob and carol.
	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	carolPeer, err := newMockServer(
		t, "carol", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	// Run a switch backed by a temporary database for the whole test.
	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)
	require.NoError(t, s.Start())
	defer func() {
		_ = s.Stop()
	}()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
	chanID3, carolChanID := genID()

	// Register one mock link per peer.
	aliceLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	require.NoError(t, s.AddLink(aliceLink))

	bobLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	require.NoError(t, s.AddLink(bobLink))

	carolLink := newMockChannelLink(
		s, chanID3, carolChanID, emptyScid, carolPeer, true, false,
		false, false,
	)
	require.NoError(t, s.AddLink(carolLink))

	// mockChannelLink sets the local and remote dust limits of the mailbox
	// to 400 satoshis and the feerate to 0. We'll fill the mailbox up with
	// dust packets and assert that calls to SendHTLC will fail.
	preimage, err := genPreimage()
	require.NoError(t, err)

	rhash := sha256.Sum256(preimage[:])
	amt := lnwire.NewMSatFromSatoshis(350)
	addMsg := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      amt,
		ChanID:      chanID1,
	}

	// It takes aliceCount HTLC's of 350sats to fill up Alice's mailbox to
	// the point where another would put Alice over the dust threshold.
	// Each queued packet comes from carol with its own incoming HTLC ID.
	const aliceCount = 1428

	mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID1, aliceChanID)
	for carolHTLCID := uint64(0); carolHTLCID < aliceCount; carolHTLCID++ {
		err = mailbox.AddPacket(&htlcPacket{
			incomingChanID: carolChanID,
			incomingHTLCID: carolHTLCID,
			outgoingChanID: aliceChanID,
			obfuscator:     NewMockObfuscator(),
			incomingAmount: amt,
			amount:         amt,
			htlc:           addMsg,
		})
		require.NoError(t, err)
	}

	// Sending one more HTLC to Alice should result in the dust threshold
	// being breached.
	err = s.SendHTLC(aliceChanID, 0, addMsg)
	require.ErrorIs(t, err, errDustThresholdExceeded)

	// We'll now call ForwardPackets from Bob to ensure that the mailbox
	// sum is also accounted for in the forwarding case.
	bobAdd := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      amt,
		ChanID:      chanID1,
	}
	packet := &htlcPacket{
		incomingChanID: bobChanID,
		incomingHTLCID: 0,
		outgoingChanID: aliceChanID,
		obfuscator:     NewMockObfuscator(),
		incomingAmount: amt,
		amount:         amt,
		htlc:           bobAdd,
	}

	err = s.ForwardPackets(nil, packet)
	require.NoError(t, err)

	// Bob should receive a failure from the switch.
	select {
	case <-time.After(5 * time.Second):
		t.Fatal("no timely reply from switch")

	case p := <-bobLink.packets:
		require.NotEmpty(t, p.linkFailure)
		assertFailureCode(
			t, p.linkFailure, lnwire.CodeTemporaryChannelFailure,
		)
	}
}
// TestSwitchResolution checks the ability of the switch to persist and handle
// resolution messages.
func TestSwitchResolution(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err)

	// Even though we intend to Stop s later in the test, it is safe to
	// register this Stop as a cleanup, since its execution is protected by
	// an atomic guard, guaranteeing it executes at most once. The closure
	// reads s when it runs, so once s is reassigned below it stops the
	// restarted switch instead.
	t.Cleanup(func() { _ = s.Stop() })

	err = s.Start()
	require.NoError(t, err)

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)

	err = s.AddLink(aliceChannelLink)
	require.NoError(t, err)

	err = s.AddLink(bobChannelLink)
	require.NoError(t, err)

	// Create an add htlcPacket that Alice will send to Bob.
	preimage, err := genPreimage()
	require.NoError(t, err)

	rhash := sha256.Sum256(preimage[:])
	packet := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	err = s.ForwardPackets(nil, packet)
	require.NoError(t, err)

	// Bob will receive the packet and open the circuit. Complete the
	// circuit with the packet Bob actually received, as is done for
	// Alice below.
	select {
	case bobPkt := <-bobChannelLink.packets:
		err = bobChannelLink.completeCircuit(bobPkt)
		require.NoError(t, err)

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Check that only one circuit is open.
	require.Equal(t, 1, s.circuits.NumOpen())

	// We'll send a settle resolution to Switch that should go to Alice.
	settleResMsg := contractcourt.ResolutionMsg{
		SourceChan: bobChanID,
		HtlcIndex:  0,
		PreImage:   &preimage,
	}

	// Before the resolution is sent, remove alice's link so we can assert
	// that the resolution is actually stored. Otherwise, it would be
	// deleted shortly after being sent.
	s.RemoveLink(chanID1)

	// Send the resolution message.
	err = s.ProcessContractResolution(settleResMsg)
	require.NoError(t, err)

	// Assert that the resolution store contains the settle resolution.
	resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
	require.NoError(t, err)
	require.Len(t, resMsgs, 1)
	require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
	require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
	require.Nil(t, resMsgs[0].Failure)
	require.Equal(t, preimage, *resMsgs[0].PreImage)

	// Now we'll restart Alice's link and delete the circuit.
	err = s.AddLink(aliceChannelLink)
	require.NoError(t, err)

	// Alice will receive the packet and open the circuit.
	select {
	case alicePkt := <-aliceChannelLink.packets:
		err = aliceChannelLink.completeCircuit(alicePkt)
		require.NoError(t, err)

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Assert that there are no more circuits.
	require.Equal(t, 0, s.circuits.NumOpen())

	// We'll restart the Switch and assert that Alice does not receive
	// another packet.
	switchDB := s.cfg.DB.(*channeldb.DB)
	err = s.Stop()
	require.NoError(t, err)

	s, err = initSwitchWithDB(testStartingHeight, switchDB)
	require.NoError(t, err)

	// No extra defer is needed to stop the restarted switch: the cleanup
	// registered above reads s when it runs and therefore stops this one.
	err = s.Start()
	require.NoError(t, err)

	err = s.AddLink(aliceChannelLink)
	require.NoError(t, err)

	err = s.AddLink(bobChannelLink)
	require.NoError(t, err)

	// Alice should not receive a packet since the Switch should have
	// deleted the resolution message since the circuit was closed.
	select {
	case alicePkt := <-aliceChannelLink.packets:
		t.Fatalf("received erroneous packet: %v", alicePkt)

	case <-time.After(time.Second * 5):
	}

	// Check that the resolution message no longer exists in the store.
	resMsgs, err = s.resMsgStore.fetchAllResolutionMsg()
	require.NoError(t, err)
	require.Empty(t, resMsgs)
}
// TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure
// before actually forwarding, the ChannelUpdate uses the SCID from the
// incoming channel and does not leak private information like the UTXO.
func TestSwitchForwardFailAlias(t *testing.T) {
	// Each case runs testSwitchForwardFailAlias with Alice's channel set up
	// either as an option-scid-alias (feature-bit) channel or as a
	// zero-conf channel.
	cases := []struct {
		name     string
		zeroConf bool
	}{
		{name: "option-scid-alias forwarding failure"},
		{name: "zero-conf forwarding failure", zeroConf: true},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			testSwitchForwardFailAlias(t, tc.zeroConf)
		})
	}
}
// testSwitchForwardFailAlias checks that a failure for a reforwarded packet
// does not leak Alice's real SCID. It proceeds as follows:
//   - forward an add from Alice to Bob;
//   - restart the switch on the same database so the pending circuit is
//     loaded from disk;
//   - reforward the same packet;
//   - assert that the FailTemporaryChannelFailure sent back to Alice carries
//     a ChannelUpdate whose ShortChannelID is Alice's alias.
//
// zeroConf selects whether Alice's link is built as a zero-conf channel or as
// an option-scid-alias channel.
func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err)
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err)

	err = s.Start()
	require.NoError(t, err)

	// Make Alice's channel zero-conf or option-scid-alias (feature bit).
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}

	// newAliceLink builds Alice's link on sw. For zero-conf, the alias and
	// aliceChanID are passed directly to the constructor. Otherwise the
	// link is created with aliceChanID and the alias is added afterwards.
	newAliceLink := func(sw *Switch) *mockChannelLink {
		if zeroConf {
			return newMockChannelLink(
				sw, chanID1, aliceAlias, aliceChanID, alicePeer,
				true, true, true, false,
			)
		}

		link := newMockChannelLink(
			sw, chanID1, aliceChanID, emptyScid, alicePeer, true,
			true, false, true,
		)
		link.addAlias(aliceAlias)

		return link
	}

	// newBobLink builds Bob's link on sw.
	newBobLink := func(sw *Switch) *mockChannelLink {
		return newMockChannelLink(
			sw, chanID2, bobChanID, emptyScid, bobPeer, true, false,
			false, false,
		)
	}

	aliceLink := newAliceLink(s)
	err = s.AddLink(aliceLink)
	require.NoError(t, err)

	bobLink := newBobLink(s)
	err = s.AddLink(bobLink)
	require.NoError(t, err)

	// Create a packet that will be sent from Alice to Bob via the switch.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Forward the packet and check that Bob's channel link received it.
	err = s.ForwardPackets(nil, ogPacket)
	require.NoError(t, err)

	// Assert that the circuits are in the expected state.
	require.Equal(t, 1, s.circuits.NumPending())
	require.Equal(t, 0, s.circuits.NumOpen())

	// Pull packet from Bob's link, and do nothing with it. s.quit only
	// fires on shutdown, so also bound the wait with a timeout.
	select {
	case <-bobLink.packets:
	case <-s.quit:
		t.Fatal("switch shutting down, failed to forward packet")
	case <-time.After(5 * time.Second):
		t.Fatal("bob link did not receive the forwarded packet")
	}

	// Now we will restart the Switch to trigger the LoadedFromDisk logic.
	err = s.Stop()
	require.NoError(t, err)

	err = cdb.Close()
	require.NoError(t, err)

	cdb2, err := channeldb.Open(tempPath)
	require.NoError(t, err)
	t.Cleanup(func() { cdb2.Close() })

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err)

	err = s2.Start()
	require.NoError(t, err)
	defer func() {
		_ = s2.Stop()
	}()

	aliceLink2 := newAliceLink(s2)
	err = s2.AddLink(aliceLink2)
	require.NoError(t, err)

	bobLink2 := newBobLink(s2)
	err = s2.AddLink(bobLink2)
	require.NoError(t, err)

	// Reforward the ogPacket and wait for Alice to receive a failure
	// packet.
	err = s2.ForwardPackets(nil, ogPacket)
	require.NoError(t, err)

	select {
	case failPacket := <-aliceLink2.packets:
		// Assert that the failPacket does not leak UTXO information.
		// This means checking that aliceChanID was not returned.
		require.NotNil(t, failPacket.linkFailure)

		msg := failPacket.linkFailure.msg
		failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
		require.True(t, ok)
		require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID)

	case <-s2.quit:
		t.Fatal("switch shutting down, failed to forward packet")

	case <-time.After(5 * time.Second):
		t.Fatal("alice link did not receive a failure packet")
	}
}
// TestSwitchAliasFailAdd tests that the mailbox does not leak UTXO information
// when failing back an HTLC whose mailbox delivery deadline expired (the
// helper shortens the mailbox expiry to one second). This is tested in the
// switch rather than the mailbox because the mailbox tests do not have the
// proper context (e.g. the Switch's failAliasUpdate function). The caveat here
// is that if the private UTXO is already known, it is fine to send a failure
// back. This tests option-scid-alias (feature-bit) and zero-conf channels.
func TestSwitchAliasFailAdd(t *testing.T) {
	tests := []struct {
		name string

		// Denotes whether the opened channel will be zero-conf.
		zeroConf bool

		// Denotes whether the opened channel will be private.
		private bool

		// Denotes whether an alias was used during forwarding. When
		// false, the HTLC is forwarded via the confirmed SCID.
		useAlias bool
	}{
		{
			name:     "public zero-conf using alias",
			zeroConf: true,
			private:  false,
			useAlias: true,
		},
		{
			// This case must forward via the confirmed SCID.
			// Setting useAlias here would only duplicate the case
			// above and leave the real-SCID path untested.
			name:     "public zero-conf using real",
			zeroConf: true,
			private:  false,
			useAlias: false,
		},
		{
			name:     "private zero-conf using alias",
			zeroConf: true,
			private:  true,
			useAlias: true,
		},
		{
			name:     "public option-scid-alias using alias",
			zeroConf: false,
			private:  false,
			useAlias: true,
		},
		{
			name:     "public option-scid-alias using real",
			zeroConf: false,
			private:  false,
			useAlias: false,
		},
		{
			name:     "private option-scid-alias using alias",
			zeroConf: false,
			private:  true,
			useAlias: true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			testSwitchAliasFailAdd(
				t, test.zeroConf, test.private, test.useAlias,
			)
		})
	}
}
// testSwitchAliasFailAdd forwards an add from Bob towards Alice's alias
// channel and waits for Alice's mailbox to fail it back after the shortened
// mailbox expiry. The channel update in the failure Bob receives must
// reference exactly the SCID that was used as outgoingChanID.
func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	db, err := channeldb.Open(t.TempDir())
	require.NoError(t, err)
	defer db.Close()

	sw, err := initSwitchWithDB(testStartingHeight, db)
	require.NoError(t, err)

	// Shorten the mailbox expiry to one second so the held add is failed
	// back quickly.
	sw.mailOrchestrator.cfg.expiry = time.Second

	require.NoError(t, sw.Start())
	defer func() {
		_ = sw.Stop()
	}()

	// Two aliases for Alice that differ only in TxPosition.
	firstAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	secondAlias := firstAlias
	secondAlias.TxPosition = 6

	// For zero-conf, firstAlias is the link's SCID and aliceChanID is the
	// real one. For option-scid-alias (feature bit), aliceChanID is the
	// link's SCID and both aliases are added on top of it.
	var aliceLink *mockChannelLink
	if !zeroConf {
		aliceLink = newMockChannelLink(
			sw, chanID1, aliceChanID, emptyScid, alicePeer, true,
			private, false, true,
		)
		aliceLink.addAlias(firstAlias)
	} else {
		aliceLink = newMockChannelLink(
			sw, chanID1, firstAlias, aliceChanID, alicePeer, true,
			private, true, false,
		)
	}
	aliceLink.addAlias(secondAlias)
	require.NoError(t, sw.AddLink(aliceLink))

	bobLink := newMockChannelLink(
		sw, chanID2, bobChanID, emptyScid, bobPeer, true, true, false,
		false,
	)
	require.NoError(t, sw.AddLink(bobLink))

	// Address the HTLC either to aliceChanID or to one of Alice's aliases
	// picked at random.
	outgoingChanID := aliceChanID
	if useAlias {
		candidates := aliceLink.getAliases()
		outgoingChanID = candidates[mrand.Intn(len(candidates))]
	}

	preimage := [sha256.Size]byte{1}
	paymentHash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: bobLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: outgoingChanID,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: paymentHash,
			Amount:      1,
		},
	}

	// Forward the add; Alice's mailbox is expected to fail it backwards.
	require.NoError(t, sw.ForwardPackets(nil, ogPacket))

	require.Equal(t, 1, sw.circuits.NumPending())
	require.Equal(t, 0, sw.circuits.NumOpen())

	select {
	case failPacket := <-bobLink.packets:
		msg := failPacket.linkFailure.msg
		failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
		require.True(t, ok)
		require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)

	case <-sw.quit:
		t.Fatal("switch shutting down, failed to receive fail packet")
	}
}
// TestSwitchHandlePacketForward checks that handlePacketForward (which calls
// CheckHtlcForward) does not leak the UTXO in a failure message for alias
// channels. This test requires us to have a REAL link, which we also must
// modify in order to test it properly (e.g. making it a private channel).
// This doesn't lead to good code, but short of refactoring the link-generation
// code there is not a good alternative.
func TestSwitchHandlePacketForward(t *testing.T) {
	tests := []struct {
		name string

		// Denotes whether or not the channel will be zero-conf.
		zeroConf bool

		// Denotes whether or not the channel will have negotiated the
		// option-scid-alias feature-bit and is not zero-conf.
		optionFeature bool

		// Denotes whether or not the channel will be private.
		private bool

		// Denotes whether or not the alias will be used for
		// forwarding.
		useAlias bool
	}{
		{
			name:     "public zero-conf using alias",
			zeroConf: true,
			private:  false,
			useAlias: true,
		},
		{
			name:     "public zero-conf using real",
			zeroConf: true,
			private:  false,
			useAlias: false,
		},
		{
			name:     "private zero-conf using alias",
			zeroConf: true,
			private:  true,
			useAlias: true,
		},
		{
			name:          "public option-scid-alias using alias",
			zeroConf:      false,
			optionFeature: true,
			private:       false,
			useAlias:      true,
		},
		{
			name:          "public option-scid-alias using real",
			zeroConf:      false,
			optionFeature: true,
			private:       false,
			useAlias:      false,
		},
		{
			name:          "private option-scid-alias using alias",
			zeroConf:      false,
			optionFeature: true,
			private:       true,
			useAlias:      true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			testSwitchHandlePacketForward(
				t, test.zeroConf, test.private, test.useAlias,
				test.optionFeature,
			)
		})
	}
}
// testSwitchHandlePacketForward adds a real link for Alice and a mock link for
// Bob to a switch. Alice's link is reconfigured according to the private,
// zeroConf and optionFeature arguments. It then forwards an HTLC with Amount 1
// from Bob to Alice. Alice is expected to reject it with
// FailAmountBelowMinimum, and the channel update in that failure must carry
// the SCID Bob used as outgoingChanID (an alias when useAlias is set).
func testSwitchHandlePacketForward(t *testing.T, zeroConf, private,
	useAlias, optionFeature bool) {

	t.Parallel()

	// Create a link for Alice that we'll add to the switch.
	aliceLink, _, _, _, _, err :=
		newSingleLinkTestHarness(t, btcutil.SatoshiPerBitcoin, 0)
	require.NoError(t, err)

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	require.NoError(t, s.Start(), "unable to start switch")
	defer func() {
		_ = s.Stop()
	}()

	// Change Alice's ShortChanID and OtherShortChanID here.
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	aliceAlias2 := aliceAlias
	aliceAlias2.TxPosition = 6

	aliceChannelLink := aliceLink.(*channelLink)
	aliceChannelState := aliceChannelLink.channel.State()

	// Set the link's GetAliases function so both aliases are reported
	// regardless of the base SCID asked for.
	aliceChannelLink.cfg.GetAliases = func(
		base lnwire.ShortChannelID) []lnwire.ShortChannelID {

		return []lnwire.ShortChannelID{aliceAlias, aliceAlias2}
	}

	if !private {
		// Change the channel to public depending on the test.
		aliceChannelState.ChannelFlags = lnwire.FFAnnounceChannel
	}

	// If this is an option-scid-alias feature-bit non-zero-conf channel,
	// we'll mark the channel as such.
	if optionFeature {
		aliceChannelState.ChanType |= channeldb.ScidAliasFeatureBit
	}

	// This is the ShortChannelID field in the OpenChannel struct.
	aliceScid := aliceLink.ShortChanID()
	if zeroConf {
		// Store the alias in the shortChanID field and mark the real
		// scid in the database.
		aliceChannelLink.shortChanID = aliceAlias
		err = aliceChannelState.MarkRealScid(aliceScid)
		require.NoError(t, err)

		aliceChannelState.ChanType |= channeldb.ZeroConfBit
	}

	// The link must be fully configured before it is handed to the
	// switch.
	err = s.AddLink(aliceLink)
	require.NoError(t, err)

	// Add a mockChannelLink for Bob.
	bobChanID, bobScid := genID()
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	bobLink := newMockChannelLink(
		s, bobChanID, bobScid, emptyScid, bobPeer, true, false, false,
		false,
	)
	err = s.AddLink(bobLink)
	require.NoError(t, err)

	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: bobLink.ShortChanID(),
		incomingHTLCID: 0,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Determine which outgoingChanID to set based on the useAlias bool.
	outgoingChanID := aliceScid
	if useAlias {
		// Choose from the possible aliases.
		aliases := aliceLink.getAliases()
		idx := mrand.Intn(len(aliases))
		outgoingChanID = aliases[idx]
	}

	ogPacket.outgoingChanID = outgoingChanID

	// Forward the packet to Alice and she should fail it back with an
	// AmountBelowMinimum FailureMessage.
	err = s.ForwardPackets(nil, ogPacket)
	require.NoError(t, err)

	select {
	case failPacket := <-bobLink.packets:
		// Assert that failPacket returns the expected ChannelUpdate.
		msg := failPacket.linkFailure.msg
		failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum)
		require.True(t, ok)
		require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)
	case <-s.quit:
		t.Fatal("switch shutting down, failed to receive failure")
	}
}
// TestSwitchAliasInterceptFail checks that when the InterceptableSwitch fails
// an incoming HTLC, the resulting failure does not expose the on-chain UTXO of
// the incoming channel. Both option-scid-alias (feature bit) and zero-conf
// incoming channels are covered.
func TestSwitchAliasInterceptFail(t *testing.T) {
	type testCase struct {
		name string

		// zeroConf selects a zero-conf incoming channel. When false,
		// the incoming channel negotiated option-scid-alias (feature
		// bit) instead.
		zeroConf bool
	}

	cases := []testCase{
		{name: "option-scid-alias", zeroConf: false},
		{name: "zero-conf", zeroConf: true},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			testSwitchAliasInterceptFail(t, tc.zeroConf)
		})
	}
}
// testSwitchAliasInterceptFail sends an add from Alice's alias channel through
// an InterceptableSwitch and has the interceptor fail it with
// CodeTemporaryChannelFailure. It then decrypts the failure that reaches
// Alice's link and checks that its channel update references one of Alice's
// aliases.
func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
	t.Parallel()

	// failTimeout bounds the wait for the failure to reach Alice. The
	// switch's quit channel is only closed by the deferred Stop, so without
	// this bound a missing failure would hang the test forever.
	const failTimeout = 5 * time.Second

	chanID, aliceScid := genID()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err)

	tempPath := t.TempDir()

	cdb, err := channeldb.Open(tempPath)
	require.NoError(t, err)
	t.Cleanup(func() { cdb.Close() })

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err)
	err = s.Start()
	require.NoError(t, err)

	defer func() {
		_ = s.Stop()
	}()

	// Make Alice's alias here.
	aliceAlias := lnwire.ShortChannelID{
		BlockHeight: 16_000_000,
		TxIndex:     5,
		TxPosition:  5,
	}
	aliceAlias2 := aliceAlias
	aliceAlias2.TxPosition = 6

	var aliceLink *mockChannelLink
	if zeroConf {
		aliceLink = newMockChannelLink(
			s, chanID, aliceAlias, aliceScid, alicePeer, true,
			true, true, false,
		)
		aliceLink.addAlias(aliceAlias2)
	} else {
		aliceLink = newMockChannelLink(
			s, chanID, aliceScid, emptyScid, alicePeer, true,
			true, false, true,
		)
		aliceLink.addAlias(aliceAlias)
		aliceLink.addAlias(aliceAlias2)
	}
	err = s.AddLink(aliceLink)
	require.NoError(t, err)

	// Now we'll create the packet that will be sent from the Alice link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID:  aliceLink.ShortChanID(),
		incomingTimeout: 1000,
		incomingHTLCID:  0,
		outgoingChanID:  lnwire.ShortChannelID{},
		obfuscator:      NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Now setup the interceptable switch so that we can reject this
	// packet.
	forwardInterceptor := &mockForwardInterceptor{
		t:               t,
		interceptedChan: make(chan InterceptedPacket),
	}

	notifier := &mock.ChainNotifier{
		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
	}
	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}

	interceptSwitch, err := NewInterceptableSwitch(
		&InterceptableSwitchConfig{
			Switch:             s,
			Notifier:           notifier,
			CltvRejectDelta:    10,
			CltvInterceptDelta: 13,
		},
	)
	require.NoError(t, err)
	require.NoError(t, interceptSwitch.Start())

	// Stop the interceptable switch even if an assertion below fails.
	// This defer is registered after the switch's deferred Stop, so it
	// runs first, preserving the original shutdown order.
	defer func() {
		require.NoError(t, interceptSwitch.Stop())
	}()

	interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)

	err = interceptSwitch.ForwardPackets(nil, false, ogPacket)
	require.NoError(t, err)

	inCircuit := forwardInterceptor.getIntercepted().IncomingCircuit
	require.NoError(t, interceptSwitch.resolve(&FwdResolution{
		Action:      FwdActionFail,
		Key:         inCircuit,
		FailureCode: lnwire.CodeTemporaryChannelFailure,
	}))

	select {
	case failPacket := <-aliceLink.packets:
		// Assert that failPacket returns the expected ChannelUpdate.
		failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC)
		require.True(t, ok)

		fwdErr, err := newMockDeobfuscator().DecryptError(
			failHtlc.Reason,
		)
		require.NoError(t, err)

		failure := fwdErr.WireMessage()

		failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure)
		require.True(t, ok)

		failScid := failureMsg.Update.ShortChannelID
		isAlias := failScid == aliceAlias || failScid == aliceAlias2
		require.True(t, isAlias)

	case <-s.quit:
		t.Fatal("switch shutting down, failed to receive failure")

	case <-time.After(failTimeout):
		t.Fatal("timeout waiting for failure to reach alice")
	}
}