package routing

import (
	"fmt"
	"math"
	"os"
	"testing"
	"time"

	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/zpay32"
	"github.com/stretchr/testify/require"
)

const (
	sourceNodeID = 1
	targetNodeID = 2
)

// mockBandwidthHints is a bandwidthHints implementation backed by a static
// map from channel ID to available bandwidth.
type mockBandwidthHints struct {
	hints map[uint64]lnwire.MilliSatoshi
}

// availableChanBandwidth returns the bandwidth recorded for channelID and
// whether an entry exists for it. The amount argument is ignored.
func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64,
	_ lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) {

	// Indexing a nil map yields the zero value and false, so a nil hints
	// map needs no special casing.
	balance, ok := m.hints[channelID]
	return balance, ok
}

// integratedRoutingContext defines the context in which integrated routing
// tests run.
type integratedRoutingContext struct {
	graph  *mockGraph
	t      *testing.T
	source *mockNode
	target *mockNode

	// amt is the total amount the test payment tries to deliver.
	amt lnwire.MilliSatoshi

	// maxShardAmt, when non-nil, overrides the payment's MaxShardAmt.
	maxShardAmt *lnwire.MilliSatoshi

	// finalExpiry is used as the payment's FinalCLTVDelta.
	finalExpiry int32

	mcCfg          MissionControlConfig
	pathFindingCfg PathFindingConfig

	// routeHints are passed through as the payment's RouteHints.
	routeHints [][]zpay32.HopHint
}

// newIntegratedRoutingContext instantiates a new integrated routing test
// context with a source and a target node.
func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext {
	// Instantiate a mock graph containing the source and target nodes.
	source := newMockNode(sourceNodeID)
	target := newMockNode(targetNodeID)

	graph := newMockGraph(t)
	graph.addNode(source)
	graph.addNode(target)
	graph.source = source

	// Initiate the test context with a set of default configuration values.
	// We don't use the lnd defaults here, because otherwise changing the
	// defaults would break the unit tests. The actual values picked aren't
	// critical to excite certain behavior, but do need to be aligned with
	// the test case assertions.
	aCfg := AprioriConfig{
		PenaltyHalfLife:       30 * time.Minute,
		AprioriHopProbability: 0.6,
		AprioriWeight:         0.5,
		CapacityFraction:      testCapacityFraction,
	}

	estimator, err := NewAprioriEstimator(aCfg)
	require.NoError(t, err)

	ctx := integratedRoutingContext{
		t:           t,
		graph:       graph,
		amt:         100000,
		finalExpiry: 40,
		mcCfg: MissionControlConfig{
			Estimator: estimator,
		},
		pathFindingCfg: PathFindingConfig{
			AttemptCost:    1000,
			MinProbability: 0.01,
		},
		source: source,
		target: target,
	}

	return &ctx
}

// htlcAttempt records the route and outcome of an attempted htlc.
type htlcAttempt struct {
	route   *route.Route
	success bool
}

func (h htlcAttempt) String() string {
	return fmt.Sprintf("success=%v, route=%v", h.success, h.route)
}

// testPayment launches a test payment over the mock graph and returns every
// htlc attempt that was made, in order.
//
// maxParts limits the number of shards the payment session may split the
// payment into. destFeatureBits, if given, entirely replaces the default MPP
// feature bits advertised for the destination.
//
// A non-nil error is returned when the payment session fails to produce a
// route. When mission control reports a final failure for the payment, the
// loop stops and the attempts made so far are returned with a nil error.
func (c *integratedRoutingContext) testPayment(maxParts uint32,
	destFeatureBits ...lnwire.FeatureBit) ([]htlcAttempt, error) {

	// We start out with the base set of MPP feature bits. If the caller
	// overrides this set of bits, then we'll use their feature bits
	// entirely.
	baseFeatureBits := mppFeatures
	if len(destFeatureBits) != 0 {
		baseFeatureBits = lnwire.NewRawFeatureVector(destFeatureBits...)
	}

	var (
		nextPid  uint64
		attempts []htlcAttempt
	)

	// Create a temporary database file for mission control. Only the path
	// is needed, so the handle returned by CreateTemp is closed right away
	// instead of being held open next to the bolt database for the whole
	// test. Removal is registered before closing so the file is cleaned up
	// even if the close fails.
	file, err := os.CreateTemp("", "*.db")
	require.NoError(c.t, err)

	dbPath := file.Name()
	c.t.Cleanup(func() {
		require.NoError(c.t, os.Remove(dbPath))
	})
	require.NoError(c.t, file.Close())

	db, err := kvdb.Open(
		kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
	)
	require.NoError(c.t, err)

	// Cleanups run in LIFO order, so the database is closed before the
	// file backing it is removed.
	c.t.Cleanup(func() {
		require.NoError(c.t, db.Close())
	})

	// Instantiate a new mission control with the current configuration
	// values.
	mc, err := NewMissionControl(db, c.source.pubkey, &c.mcCfg)
	require.NoError(c.t, err)

	// Bandwidth hints are derived from the source node's channel balances
	// at the time the session asks for them.
	getBandwidthHints := func(_ Graph) (bandwidthHints, error) {
		channels := c.graph.nodes[c.source.pubkey].channels

		hints := make(map[uint64]lnwire.MilliSatoshi, len(channels))
		for _, ch := range channels {
			hints[ch.id] = ch.balance
		}

		return &mockBandwidthHints{
			hints: hints,
		}, nil
	}

	var paymentAddr [32]byte
	payment := LightningPayment{
		FinalCLTVDelta: uint16(c.finalExpiry),
		FeeLimit:       lnwire.MaxMilliSatoshi,
		Target:         c.target.pubkey,
		PaymentAddr:    &paymentAddr,
		DestFeatures: lnwire.NewFeatureVector(
			baseFeatureBits, lnwire.Features,
		),
		Amount:     c.amt,
		CltvLimit:  math.MaxUint32,
		MaxParts:   maxParts,
		RouteHints: c.routeHints,
	}

	var paymentHash [32]byte
	if err := payment.SetPaymentHash(paymentHash); err != nil {
		return nil, err
	}

	if c.maxShardAmt != nil {
		payment.MaxShardAmt = c.maxShardAmt
	}

	session, err := newPaymentSession(
		&payment, c.graph.source.pubkey, getBandwidthHints,
		newMockGraphSessionFactory(c.graph), mc, c.pathFindingCfg,
	)
	require.NoError(c.t, err)

	// Override default minimum shard amount.
	session.minShardAmt = lnwire.NewMSatFromSatoshis(5000)

	// Now the payment control loop starts. It will keep trying routes until
	// the payment succeeds, route finding fails, or mission control
	// declares the payment failed.
	var (
		amtRemaining  = payment.Amount
		inFlightHtlcs uint32
	)
	for {
		// Find a route for whatever amount is still outstanding.
		rt, err := session.RequestRoute(
			amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
		)
		if err != nil {
			return attempts, err
		}

		// Send out the htlc on the mock graph.
		pid := nextPid
		nextPid++

		htlcResult, err := c.graph.sendHtlc(rt)
		require.NoError(c.t, err)

		success := htlcResult.failure == nil
		attempts = append(attempts, htlcAttempt{
			route:   rt,
			success: success,
		})

		// Process the result. In normal Lightning operations, the
		// sender doesn't get an acknowledgement from the recipient that
		// the htlc arrived. In integrated routing tests, this
		// acknowledgement is available. It is a simplification of
		// reality that still allows certain classes of tests to be
		// performed.
		if success {
			inFlightHtlcs++

			require.NoError(c.t, mc.ReportPaymentSuccess(pid, rt))

			amtRemaining -= rt.ReceiverAmt()

			// If the full amount has been paid, the payment is
			// successful and the control loop can be terminated.
			if amtRemaining == 0 {
				break
			}

			// Otherwise try to send the remaining amount.
			continue
		}

		// Failure, update mission control and retry.
		finalResult, err := mc.ReportPaymentFail(
			pid, rt, getNodeIndex(rt, htlcResult.failureSource),
			htlcResult.failure,
		)
		require.NoError(c.t, err)

		// A non-nil final result means mission control considers the
		// payment as a whole to have failed; stop retrying.
		if finalResult != nil {
			break
		}
	}

	return attempts, nil
}

// getNodeIndex returns the zero-based index of the given node in the route,
// where index 0 is the route's source node and index i+1 is the i-th hop. It
// returns nil if the node is not part of the route.
func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { if failureSource == route.SourcePubKey { idx := 0 return &idx } for i, h := range route.Hops { if h.PubKeyBytes == failureSource { idx := i + 1 return &idx } } return nil } type mockGraphSessionFactory struct { Graph } func newMockGraphSessionFactory(graph Graph) GraphSessionFactory { return &mockGraphSessionFactory{Graph: graph} } func (m *mockGraphSessionFactory) NewGraphSession() (Graph, func() error, error) { return m, func() error { return nil }, nil } var _ GraphSessionFactory = (*mockGraphSessionFactory)(nil) var _ Graph = (*mockGraphSessionFactory)(nil) type mockGraphSessionFactoryChanDB struct { graph *channeldb.ChannelGraph } func newMockGraphSessionFactoryFromChanDB( graph *channeldb.ChannelGraph) *mockGraphSessionFactoryChanDB { return &mockGraphSessionFactoryChanDB{ graph: graph, } } func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error, error) { tx, err := g.graph.NewPathFindTx() if err != nil { return nil, nil, err } session := &mockGraphSessionChanDB{ graph: g.graph, tx: tx, } return session, session.close, nil } var _ GraphSessionFactory = (*mockGraphSessionFactoryChanDB)(nil) type mockGraphSessionChanDB struct { graph *channeldb.ChannelGraph tx kvdb.RTx } func newMockGraphSessionChanDB(graph *channeldb.ChannelGraph) Graph { return &mockGraphSessionChanDB{ graph: graph, } } func (g *mockGraphSessionChanDB) close() error { if g.tx == nil { return nil } err := g.tx.Rollback() if err != nil { return fmt.Errorf("error closing db tx: %w", err) } return nil } func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } func (g *mockGraphSessionChanDB) FetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(nodePub) }