diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 431a650ce..a9222597c 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -103,7 +103,7 @@ type MissionControl interface { // GetProbability is expected to return the success probability of a // payment from fromNode to toNode. GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 // ResetHistory resets the history of MissionControl returning it to a // state as if no payment attempts have been made. @@ -258,7 +258,8 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, restrictions := &routing.RestrictParams{ FeeLimit: feeLimit, ProbabilitySource: func(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + amt lnwire.MilliSatoshi, + capacity btcutil.Amount) float64 { if _, ok := ignoredNodes[fromNode]; ok { return 0 @@ -277,7 +278,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, } return r.MissionControl.GetProbability( - fromNode, toNode, amt, + fromNode, toNode, amt, capacity, ) }, DestCustomRecords: record.CustomSet(in.DestCustomRecords), @@ -362,7 +363,7 @@ func (r *RouterBackend) getSuccessProbability(rt *route.Route) float64 { toNode := hop.PubKeyBytes probability := r.MissionControl.GetProbability( - fromNode, toNode, amtToFwd, + fromNode, toNode, amtToFwd, 0, ) successProb *= probability diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index a5b629fe2..89c7f8fff 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -145,18 +145,18 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } if restrictions.ProbabilitySource(route.Vertex{2}, - route.Vertex{1}, 0, + route.Vertex{1}, 0, 0, ) != 0 { t.Fatal("expecting 0% probability for ignored edge") } if restrictions.ProbabilitySource(ignoreNodeVertex, - route.Vertex{6}, 0, + route.Vertex{6}, 0, 0, ) != 0 { t.Fatal("expecting 0% probability for ignored node") } - if restrictions.ProbabilitySource(node1, node2, 0) != 0 { + if restrictions.ProbabilitySource(node1, node2, 0, 0) != 0 { t.Fatal("expecting 0% probability for ignored pair") } @@ -181,7 +181,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, expectedProb = testMissionControlProb } if restrictions.ProbabilitySource(route.Vertex{4}, - route.Vertex{5}, 0, + route.Vertex{5}, 0, 0, ) != expectedProb { t.Fatal("expecting 100% probability") } @@ -239,7 +239,7 @@ type mockMissionControl struct { } func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 { return testMissionControlProb } diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 1967a3dd6..29847e8da 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -710,7 +710,7 @@ func (s *Server) QueryProbability(ctx context.Context, amt := lnwire.MilliSatoshi(req.AmtMsat) mc := s.cfg.RouterBackend.MissionControl - prob := mc.GetProbability(fromNode, toNode, amt) + prob := mc.GetProbability(fromNode, toNode, amt, 0) history := mc.GetPairHistorySnapshot(fromNode, toNode) return &QueryProbabilityResponse{ diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index 2bce890d5..86de4bc87 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -333,7 +334,7 @@ func (m *MissionControl) ResetHistory() error { // GetProbability is expected to return the success probability of a payment // from fromNode along edge. func (m *MissionControl) GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 { m.Lock() defer m.Unlock() @@ -346,7 +347,9 @@ func (m *MissionControl) GetProbability(fromNode, toNode route.Vertex, return m.estimator.getLocalPairProbability(now, results, toNode) } - return m.estimator.getPairProbability(now, results, toNode, amt) + return m.estimator.getPairProbability( + now, results, toNode, amt, capacity, + ) } // GetHistorySnapshot takes a snapshot from the current mission control state diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index dac2cfebc..c60469198 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -112,7 +112,7 @@ func (ctx *mcTestContext) restartMc() { func (ctx *mcTestContext) expectP(amt lnwire.MilliSatoshi, expected float64) { ctx.t.Helper() - p := ctx.mc.GetProbability(mcTestNode1, mcTestNode2, amt) + p := ctx.mc.GetProbability(mcTestNode1, mcTestNode2, amt, testCapacity) if p != expected { ctx.t.Fatalf("expected probability %v but got %v", expected, p) } @@ -148,9 +148,11 @@ func TestMissionControl(t *testing.T) { testTime := time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) - // For local channels, we expect a higher probability than our a prior + // For local channels, we expect a higher probability than our apriori // test probability. - selfP := ctx.mc.GetProbability(mcTestSelf, mcTestNode1, 100) + selfP := ctx.mc.GetProbability( + mcTestSelf, mcTestNode1, 100, testCapacity, + ) if selfP != prevSuccessProbability { t.Fatalf("expected prev success prob for untried local chans") } diff --git a/routing/mock_test.go b/routing/mock_test.go index afd0126dd..72f037b03 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -139,7 +140,7 @@ func (m *mockMissionControlOld) ReportPaymentSuccess(paymentID uint64, } func (m *mockMissionControlOld) GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 { return 0 } @@ -650,9 +651,9 @@ func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, } func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 { - args := m.Called(fromNode, toNode, amt) + args := m.Called(fromNode, toNode, amt, capacity) return args.Get(0).(float64) } diff --git a/routing/pathfind.go b/routing/pathfind.go index fb11fa241..a3a3955f7 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -7,6 +7,7 @@ import ( "math" "time" + "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/feature" @@ -306,7 +307,7 @@ type RestrictParams struct { // ProbabilitySource is a callback that is expected to return the // success probability of traversing the channel from the node. ProbabilitySource func(route.Vertex, route.Vertex, - lnwire.MilliSatoshi) float64 + lnwire.MilliSatoshi, btcutil.Amount) float64 // FeeLimit is a maximum fee amount allowed to be used on the path from // the source to the target. @@ -639,6 +640,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Request the success probability for this edge. edgeProbability := r.ProbabilitySource( fromVertex, toNodeDist.node, amountToSend, + edge.capacity, ) log.Trace(newLogClosure(func() string { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index a3202d1e9..20dcf1470 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -105,7 +105,9 @@ var ( // noProbabilitySource is used in testing to return the same probability 1 for // all edges. -func noProbabilitySource(route.Vertex, route.Vertex, lnwire.MilliSatoshi) float64 { +func noProbabilitySource(route.Vertex, route.Vertex, lnwire.MilliSatoshi, + btcutil.Amount) float64 { + return 1 } @@ -2796,8 +2798,9 @@ func testProbabilityRouting(t *testing.T, useCache bool, target := ctx.testGraphInstance.aliasMap["target"] // Configure a probability source with the test parameters. - ctx.restrictParams.ProbabilitySource = func(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + ctx.restrictParams.ProbabilitySource = func(fromNode, + toNode route.Vertex, amt lnwire.MilliSatoshi, + capacity btcutil.Amount) float64 { if amt == 0 { t.Fatal("expected non-zero amount") @@ -2878,8 +2881,9 @@ func runEqualCostRouteSelection(t *testing.T, useCache bool) { paymentAmt := lnwire.NewMSatFromSatoshis(100) target := ctx.testGraphInstance.aliasMap["target"] - ctx.restrictParams.ProbabilitySource = func(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 { + ctx.restrictParams.ProbabilitySource = func(fromNode, + toNode route.Vertex, amt lnwire.MilliSatoshi, + capacity btcutil.Amount) float64 { switch { case fromNode == alias["source"] && toNode == alias["a"]: diff --git a/routing/probability_estimator.go b/routing/probability_estimator.go index 851cecf16..533499ee0 100644 --- a/routing/probability_estimator.go +++ b/routing/probability_estimator.go @@ -5,6 +5,7 @@ import ( "math" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -150,8 +151,8 @@ func (p *probabilityEstimator) getWeight(age time.Duration) float64 { // toNode based on historical payment outcomes for the from node. Those outcomes // are passed in via the results parameter. func (p *probabilityEstimator) getPairProbability( - now time.Time, results NodeResults, - toNode route.Vertex, amt lnwire.MilliSatoshi) float64 { + now time.Time, results NodeResults, toNode route.Vertex, + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 { nodeProbability := p.getNodeProbability(now, results, amt) diff --git a/routing/probability_estimator_test.go b/routing/probability_estimator_test.go index 3bfb0536a..c226fe97a 100644 --- a/routing/probability_estimator_test.go +++ b/routing/probability_estimator_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -23,6 +24,9 @@ const ( aprioriHopProb = 0.6 aprioriWeight = 0.75 aprioriPrevSucProb = 0.95 + + // testCapacity is used to define a capacity for some channels. + testCapacity = btcutil.Amount(100_000) ) type estimatorTestContext struct { @@ -53,7 +57,8 @@ func newEstimatorTestContext(t *testing.T) *estimatorTestContext { // assertPairProbability asserts that the calculated success probability is // correct. func (c *estimatorTestContext) assertPairProbability(now time.Time, - toNode byte, amt lnwire.MilliSatoshi, expectedProb float64) { + toNode byte, amt lnwire.MilliSatoshi, capacity btcutil.Amount, + expectedProb float64) { c.t.Helper() @@ -64,7 +69,9 @@ func (c *estimatorTestContext) assertPairProbability(now time.Time, const tolerance = 0.01 - p := c.estimator.getPairProbability(now, results, route.Vertex{toNode}, amt) + p := c.estimator.getPairProbability( + now, results, route.Vertex{toNode}, amt, capacity, + ) diff := p - expectedProb if diff > tolerance || diff < -tolerance { c.t.Fatalf("expected probability %v for node %v, but got %v", @@ -77,7 +84,7 @@ func (c *estimatorTestContext) assertPairProbability(now time.Time, func TestProbabilityEstimatorNoResults(t *testing.T) { ctx := newEstimatorTestContext(t) - ctx.assertPairProbability(testTime, 0, 0, aprioriHopProb) + ctx.assertPairProbability(testTime, 0, 0, testCapacity, aprioriHopProb) } // TestProbabilityEstimatorOneSuccess tests the probability estimation for nodes @@ -94,13 +101,15 @@ func TestProbabilityEstimatorOneSuccess(t *testing.T) { // Because of the previous success, this channel keep reporting a high // probability. ctx.assertPairProbability( - testTime, node1, 100, aprioriPrevSucProb, + testTime, node1, 100, testCapacity, aprioriPrevSucProb, ) // Untried channels are also influenced by the success. With a // aprioriWeight of 0.75, the a priori probability is assigned weight 3. expectedP := (3*aprioriHopProb + 1*aprioriPrevSucProb) / 4 - ctx.assertPairProbability(testTime, untriedNode, 100, expectedP) + ctx.assertPairProbability( + testTime, untriedNode, 100, testCapacity, expectedP, + ) } // TestProbabilityEstimatorOneFailure tests the probability estimation for nodes @@ -119,11 +128,15 @@ func TestProbabilityEstimatorOneFailure(t *testing.T) { // the failure after one hour is 0.5. This makes the node probability // 0.51: expectedNodeProb := (3*aprioriHopProb + 0.5*0) / 3.5 - ctx.assertPairProbability(testTime, untriedNode, 100, expectedNodeProb) + ctx.assertPairProbability( + testTime, untriedNode, 100, testCapacity, expectedNodeProb, + ) // The pair probability decays back to the node probability. With the // weight at 0.5, we expected a pair probability of 0.5 * 0.51 = 0.25. - ctx.assertPairProbability(testTime, node1, 100, expectedNodeProb/2) + ctx.assertPairProbability( + testTime, node1, 100, testCapacity, expectedNodeProb/2, + ) } // TestProbabilityEstimatorMix tests the probability estimation for nodes for @@ -147,7 +160,9 @@ func TestProbabilityEstimatorMix(t *testing.T) { // We expect the probability for a previously successful channel to // remain high. - ctx.assertPairProbability(testTime, node1, 100, prevSuccessProbability) + ctx.assertPairProbability( + testTime, node1, 100, testCapacity, prevSuccessProbability, + ) // For an untried node, we expected the node probability to be returned. // This is a weighted average of the results above and the a priori @@ -155,9 +170,13 @@ func TestProbabilityEstimatorMix(t *testing.T) { expectedNodeProb := (3*aprioriHopProb + 1*prevSuccessProbability) / (3 + 1 + 0.25 + 0.125) - ctx.assertPairProbability(testTime, untriedNode, 100, expectedNodeProb) + ctx.assertPairProbability( + testTime, untriedNode, 100, testCapacity, expectedNodeProb, + ) // For the previously failed connection with node 1, we expect 0.75 * // the node probability = 0.47. - ctx.assertPairProbability(testTime, node2, 100, expectedNodeProb*0.75) + ctx.assertPairProbability( + testTime, node2, 100, testCapacity, expectedNodeProb*0.75, + ) } diff --git a/routing/router.go b/routing/router.go index e8d2ab96a..949a4d8f1 100644 --- a/routing/router.go +++ b/routing/router.go @@ -227,7 +227,7 @@ type MissionController interface { // GetProbability is expected to return the success probability of a // payment from fromNode along edge. GetProbability(fromNode, toNode route.Vertex, - amt lnwire.MilliSatoshi) float64 + amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 } // FeeSchema is the set fee configuration for a Lightning Node on the network.