mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
routing/router_test: add TestSendToRouteMultiShardSend
This commit is contained in:
parent
864e64e725
commit
95c5a123c8
@ -15,6 +15,8 @@ import (
|
|||||||
type mockPaymentAttemptDispatcher struct {
|
type mockPaymentAttemptDispatcher struct {
|
||||||
onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
|
onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
|
||||||
results map[uint64]*htlcswitch.PaymentResult
|
results map[uint64]*htlcswitch.PaymentResult
|
||||||
|
|
||||||
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||||
@ -27,10 +29,6 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.results == nil {
|
|
||||||
m.results = make(map[uint64]*htlcswitch.PaymentResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *htlcswitch.PaymentResult
|
var result *htlcswitch.PaymentResult
|
||||||
preimage, err := m.onPayment(firstHop)
|
preimage, err := m.onPayment(firstHop)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -45,7 +43,13 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
|
|||||||
result = &htlcswitch.PaymentResult{Preimage: preimage}
|
result = &htlcswitch.PaymentResult{Preimage: preimage}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
if m.results == nil {
|
||||||
|
m.results = make(map[uint64]*htlcswitch.PaymentResult)
|
||||||
|
}
|
||||||
|
|
||||||
m.results[pid] = result
|
m.results[pid] = result
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -55,7 +59,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
|
|||||||
<-chan *htlcswitch.PaymentResult, error) {
|
<-chan *htlcswitch.PaymentResult, error) {
|
||||||
|
|
||||||
c := make(chan *htlcswitch.PaymentResult, 1)
|
c := make(chan *htlcswitch.PaymentResult, 1)
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
res, ok := m.results[paymentID]
|
res, ok := m.results[paymentID]
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, htlcswitch.ErrPaymentIDNotFound
|
return nil, htlcswitch.ErrPaymentIDNotFound
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
"github.com/lightningnetwork/lnd/zpay32"
|
"github.com/lightningnetwork/lnd/zpay32"
|
||||||
)
|
)
|
||||||
@ -2725,6 +2726,138 @@ func TestSendToRouteStructuredError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSendToRouteMultiShardSend checks that a 3-shard payment can be executed
|
||||||
|
// using SendToRoute.
|
||||||
|
func TestSendToRouteMultiShardSend(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cleanup, err := createTestCtxSingleNode(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
const numShards = 3
|
||||||
|
const payAmt = lnwire.MilliSatoshi(numShards * 10000)
|
||||||
|
node, err := createTestNode()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a simple 1-hop route that we will use for all three shards.
|
||||||
|
hops := []*route.Hop{
|
||||||
|
{
|
||||||
|
ChannelID: 1,
|
||||||
|
PubKeyBytes: node.PubKeyBytes,
|
||||||
|
AmtToForward: payAmt / numShards,
|
||||||
|
MPP: record.NewMPP(payAmt, [32]byte{}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceNode, err := ctx.graph.SourceNode()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt, err := route.NewRouteFromHops(
|
||||||
|
payAmt, 100, sourceNode.PubKeyBytes, hops,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The first shard we send we'll fail immediately, to check that we are
|
||||||
|
// still allowed to retry with other shards after a failed one.
|
||||||
|
ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult(
|
||||||
|
func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
|
||||||
|
return [32]byte{}, htlcswitch.NewForwardingError(
|
||||||
|
&lnwire.FailFeeInsufficient{
|
||||||
|
Update: lnwire.ChannelUpdate{},
|
||||||
|
}, 1,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// The payment parameter is mostly redundant in SendToRoute. Can be left
|
||||||
|
// empty for this test.
|
||||||
|
var payment lntypes.Hash
|
||||||
|
|
||||||
|
// Send the shard using the created route, and expect an error to be
|
||||||
|
// returned.
|
||||||
|
_, err = ctx.router.SendToRoute(payment, rt)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected forwarding error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we'll modify the SendToSwitch method again to wait until all
|
||||||
|
// three shards are initiated before returning a result. We do this by
|
||||||
|
// signalling when the method has been called, and then stop to wait
|
||||||
|
// for the test to deliver the final result on the channel below.
|
||||||
|
waitForResultSignal := make(chan struct{}, numShards)
|
||||||
|
results := make(chan lntypes.Preimage, numShards)
|
||||||
|
|
||||||
|
ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult(
|
||||||
|
func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
|
||||||
|
|
||||||
|
// Signal that the shard has been initiated and is
|
||||||
|
// waiting for a result.
|
||||||
|
waitForResultSignal <- struct{}{}
|
||||||
|
|
||||||
|
// Wait for a result before returning it.
|
||||||
|
res, ok := <-results
|
||||||
|
if !ok {
|
||||||
|
return [32]byte{}, fmt.Errorf("failure")
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Launch three shards by calling SendToRoute in three goroutines,
|
||||||
|
// returning their final error on the channel.
|
||||||
|
errChan := make(chan error)
|
||||||
|
successes := make(chan lntypes.Preimage)
|
||||||
|
|
||||||
|
for i := 0; i < numShards; i++ {
|
||||||
|
go func() {
|
||||||
|
preimg, err := ctx.router.SendToRoute(payment, rt)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
successes <- preimg
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all shards to signal they have been initiated.
|
||||||
|
for i := 0; i < numShards; i++ {
|
||||||
|
select {
|
||||||
|
case <-waitForResultSignal:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatalf("not waiting for results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deliver a dummy preimage to all the shard handlers.
|
||||||
|
preimage := lntypes.Preimage{}
|
||||||
|
preimage[4] = 42
|
||||||
|
for i := 0; i < numShards; i++ {
|
||||||
|
results <- preimage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally expect all shards to return with the above preimage.
|
||||||
|
for i := 0; i < numShards; i++ {
|
||||||
|
select {
|
||||||
|
case p := <-successes:
|
||||||
|
if p != preimage {
|
||||||
|
t.Fatalf("preimage mismatch")
|
||||||
|
}
|
||||||
|
case err := <-errChan:
|
||||||
|
t.Fatalf("unexpected error from SendToRoute: %v", err)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatalf("result not received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestSendToRouteMaxHops asserts that SendToRoute fails when using a route that
|
// TestSendToRouteMaxHops asserts that SendToRoute fails when using a route that
|
||||||
// exceeds the maximum number of hops.
|
// exceeds the maximum number of hops.
|
||||||
func TestSendToRouteMaxHops(t *testing.T) {
|
func TestSendToRouteMaxHops(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user