routing+server: use cached graph interface

This commit is contained in:
Oliver Gugger 2021-09-21 19:18:24 +02:00
parent 1d1c42f9ba
commit bf27d05aa8
No known key found for this signature in database
GPG key ID: 8E4256593F177720
10 changed files with 56 additions and 80 deletions

View file

@ -9,7 +9,8 @@ import (
// routingGraph is an abstract interface that provides information about nodes // routingGraph is an abstract interface that provides information about nodes
// and edges to pathfinding. // and edges to pathfinding.
type routingGraph interface { type routingGraph interface {
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given
// node.
forEachNodeChannel(nodePub route.Vertex, forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error cb func(channel *channeldb.DirectedChannel) error) error
@ -20,22 +21,26 @@ type routingGraph interface {
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
} }
// dbRoutingTx is a routingGraph implementation that retrieves from the // CachedGraph is a routingGraph implementation that retrieves from the
// database. // database.
type dbRoutingTx struct { type CachedGraph struct {
graph *channeldb.ChannelGraph graph *channeldb.ChannelGraph
source route.Vertex source route.Vertex
} }
// newDbRoutingTx instantiates a new db-connected routing graph. It implictly // A compile time assertion to make sure CachedGraph implements the routingGraph
// interface.
var _ routingGraph = (*CachedGraph)(nil)
// NewCachedGraph instantiates a new db-connected routing graph. It implictly
// instantiates a new read transaction. // instantiates a new read transaction.
func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) {
sourceNode, err := graph.SourceNode() sourceNode, err := graph.SourceNode()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &dbRoutingTx{ return &CachedGraph{
graph: graph, graph: graph,
source: sourceNode.PubKeyBytes, source: sourceNode.PubKeyBytes,
}, nil }, nil
@ -44,7 +49,7 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given node.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error { cb func(channel *channeldb.DirectedChannel) error) error {
return g.graph.ForEachNodeChannel(nodePub, cb) return g.graph.ForEachNodeChannel(nodePub, cb)
@ -53,7 +58,7 @@ func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
// sourceNode returns the source node of the graph. // sourceNode returns the source node of the graph.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) sourceNode() route.Vertex { func (g *CachedGraph) sourceNode() route.Vertex {
return g.source return g.source
} }
@ -61,7 +66,7 @@ func (g *dbRoutingTx) sourceNode() route.Vertex {
// unknown, assume no additional features are supported. // unknown, assume no additional features are supported.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) { *lnwire.FeatureVector, error) {
return g.graph.FetchNodeFeatures(nodePub) return g.graph.FetchNodeFeatures(nodePub)

View file

@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
} }
session, err := newPaymentSession( session, err := newPaymentSession(
&payment, getBandwidthHints, &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
mc, c.pathFindingCfg,
) )
if err != nil { if err != nil {
c.t.Fatal(err) c.t.Fatal(err)

View file

@ -3021,7 +3021,7 @@ func dbFindPath(graph *channeldb.ChannelGraph,
source, target route.Vertex, amt lnwire.MilliSatoshi, source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
routingTx, err := newDbRoutingTx(graph) routingGraph, err := NewCachedGraph(graph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -3030,7 +3030,7 @@ func dbFindPath(graph *channeldb.ChannelGraph,
&graphParams{ &graphParams{
additionalEdges: additionalEdges, additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingTx, graph: routingGraph,
}, },
r, cfg, source, target, amt, finalHtlcExpiry, r, cfg, source, target, amt, finalHtlcExpiry,
) )

View file

@ -172,7 +172,7 @@ type paymentSession struct {
pathFinder pathFinder pathFinder pathFinder
getRoutingGraph func() (routingGraph, func(), error) routingGraph routingGraph
// pathFindingConfig defines global parameters that control the // pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity. // trade-off in path finding between fees and probabiity.
@ -193,7 +193,7 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session. // newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, func newPaymentSession(p *LightningPayment,
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error),
getRoutingGraph func() (routingGraph, func(), error), routingGraph routingGraph,
missionControl MissionController, pathFindingConfig PathFindingConfig) ( missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) { *paymentSession, error) {
@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment,
getBandwidthHints: getBandwidthHints, getBandwidthHints: getBandwidthHints,
payment: p, payment: p,
pathFinder: findPath, pathFinder: findPath,
getRoutingGraph: getRoutingGraph, routingGraph: routingGraph,
pathFindingConfig: pathFindingConfig, pathFindingConfig: pathFindingConfig,
missionControl: missionControl, missionControl: missionControl,
minShardAmt: DefaultShardMinAmt, minShardAmt: DefaultShardMinAmt,
@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.log.Debugf("pathfinding for amt=%v", maxAmt) p.log.Debugf("pathfinding for amt=%v", maxAmt)
// Get a routing graph. sourceVertex := p.routingGraph.sourceNode()
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}
sourceVertex := routingGraph.sourceNode()
// Find a route for the current amount. // Find a route for the current amount.
path, err := p.pathFinder( path, err := p.pathFinder(
&graphParams{ &graphParams{
additionalEdges: p.additionalEdges, additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingGraph, graph: p.routingGraph,
}, },
restrictions, &p.pathFindingConfig, restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target, sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry, maxAmt, finalHtlcExpiry,
) )
// Close routing graph.
cleanup()
switch { switch {
case err == errNoPathFound: case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp // Don't split if this is a legacy payment without mpp

View file

@ -17,7 +17,7 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
type SessionSource struct { type SessionSource struct {
// Graph is the channel graph that will be used to gather metrics from // Graph is the channel graph that will be used to gather metrics from
// and also to carry out path finding queries. // and also to carry out path finding queries.
Graph *channeldb.ChannelGraph Graph routingGraph
// QueryBandwidth is a method that allows querying the lower link layer // QueryBandwidth is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link // to determine the up to date available bandwidth at a prospective link
@ -40,16 +40,6 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig PathFindingConfig PathFindingConfig
} }
// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := newDbRoutingTx(m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune // NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided // view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the // in order to populate additional edges to explore when finding a path to the
@ -57,21 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) { PaymentSession, error) {
sourceNode, err := m.Graph.SourceNode()
if err != nil {
return nil, err
}
getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi,
error) { error) {
return generateBandwidthHints( return generateBandwidthHints(
sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth, m.Graph.sourceNode(), m.Graph, m.QueryBandwidth,
) )
} }
session, err := newPaymentSession( session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph, p, getBandwidthHints, m.Graph,
m.MissionControl, m.PathFindingConfig, m.MissionControl, m.PathFindingConfig,
) )
if err != nil { if err != nil {

View file

@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
return nil, nil return nil, nil
}, },
func() (routingGraph, func(), error) { &sessionGraph{},
return &sessionGraph{}, func() {}, nil
},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )
@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) {
return nil, nil return nil, nil
}, },
func() (routingGraph, func(), error) { &sessionGraph{},
return &sessionGraph{}, func() {}, nil
},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )

View file

@ -406,6 +406,10 @@ type ChannelRouter struct {
// when doing any path finding. // when doing any path finding.
selfNode *channeldb.LightningNode selfNode *channeldb.LightningNode
// cachedGraph is an instance of routingGraph that caches the source node as
// well as the channel graph itself in memory.
cachedGraph routingGraph
// newBlocks is a channel in which new blocks connected to the end of // newBlocks is a channel in which new blocks connected to the end of
// the main chain are sent over, and blocks updated after a call to // the main chain are sent over, and blocks updated after a call to
// UpdateFilter. // UpdateFilter.
@ -460,7 +464,6 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil)
// channel graph is a subset of the UTXO set) set, then the router will proceed // channel graph is a subset of the UTXO set) set, then the router will proceed
// to fully sync to the latest state of the UTXO set. // to fully sync to the latest state of the UTXO set.
func New(cfg Config) (*ChannelRouter, error) { func New(cfg Config) (*ChannelRouter, error) {
selfNode, err := cfg.Graph.SourceNode() selfNode, err := cfg.Graph.SourceNode()
if err != nil { if err != nil {
return nil, err return nil, err
@ -468,6 +471,10 @@ func New(cfg Config) (*ChannelRouter, error) {
r := &ChannelRouter{ r := &ChannelRouter{
cfg: &cfg, cfg: &cfg,
cachedGraph: &CachedGraph{
graph: cfg.Graph,
source: selfNode.PubKeyBytes,
},
networkUpdates: make(chan *routingMsg), networkUpdates: make(chan *routingMsg),
topologyClients: make(map[uint64]*topologyClient), topologyClients: make(map[uint64]*topologyClient),
ntfnClientUpdates: make(chan *topologyClientUpdate), ntfnClientUpdates: make(chan *topologyClientUpdate),
@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// We'll attempt to obtain a set of bandwidth hints that can help us // We'll attempt to obtain a set of bandwidth hints that can help us
// eliminate certain routes early on in the path finding process. // eliminate certain routes early on in the path finding process.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -1752,16 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// execute our path finding algorithm. // execute our path finding algorithm.
finalHtlcExpiry := currentHeight + int32(finalExpiry) finalHtlcExpiry := currentHeight + int32(finalExpiry)
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
path, err := findPath( path, err := findPath(
&graphParams{ &graphParams{
additionalEdges: routeHints, additionalEdges: routeHints,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingTx, graph: r.cachedGraph,
}, },
restrictions, restrictions,
&r.cfg.PathFindingConfig, &r.cfg.PathFindingConfig,
@ -2657,14 +2659,14 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
// these hints allows us to reduce the number of extraneous attempts as we can // these hints allows us to reduce the number of extraneous attempts as we can
// skip channels that are inactive, or just don't have enough bandwidth to // skip channels that are inactive, or just don't have enough bandwidth to
// carry the payment. // carry the payment.
func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph, func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph,
queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) (
map[uint64]lnwire.MilliSatoshi, error) { map[uint64]lnwire.MilliSatoshi, error) {
// First, we'll collect the set of outbound edges from the target // First, we'll collect the set of outbound edges from the target
// source node. // source node.
var localChans []*channeldb.DirectedChannel var localChans []*channeldb.DirectedChannel
err := graph.ForEachNodeChannel( err := graph.forEachNodeChannel(
sourceNode, func(channel *channeldb.DirectedChannel) error { sourceNode, func(channel *channeldb.DirectedChannel) error {
localChans = append(localChans, channel) localChans = append(localChans, channel)
return nil return nil
@ -2722,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// We'll attempt to obtain a set of bandwidth hints that helps us select // We'll attempt to obtain a set of bandwidth hints that helps us select
// the best outgoing channel to use in case no outgoing channel is set. // the best outgoing channel to use in case no outgoing channel is set.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -2752,12 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
runningAmt = *amt runningAmt = *amt
} }
// Open a transaction to execute the graph queries in.
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
// Traverse hops backwards to accumulate fees in the running amounts. // Traverse hops backwards to accumulate fees in the running amounts.
source := r.selfNode.PubKeyBytes source := r.selfNode.PubKeyBytes
for i := len(hops) - 1; i >= 0; i-- { for i := len(hops) - 1; i >= 0; i-- {
@ -2776,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// known in the graph. // known in the graph.
u := newUnifiedPolicies(source, toNode, outgoingChans) u := newUnifiedPolicies(source, toNode, outgoingChans)
err := u.addGraphPolicies(routingTx) err := u.addGraphPolicies(r.cachedGraph)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -129,8 +129,11 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
) )
require.NoError(t, err, "failed to create missioncontrol") require.NoError(t, err, "failed to create missioncontrol")
cachedGraph, err := NewCachedGraph(graphInstance.graph)
require.NoError(t, err)
sessionSource := &SessionSource{ sessionSource := &SessionSource{
Graph: graphInstance.graph, Graph: cachedGraph,
QueryBandwidth: func( QueryBandwidth: func(
c *channeldb.DirectedChannel) lnwire.MilliSatoshi { c *channeldb.DirectedChannel) lnwire.MilliSatoshi {

View file

@ -776,8 +776,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MinProbability: routingConfig.MinRouteProbability, MinProbability: routingConfig.MinRouteProbability,
} }
cachedGraph, err := routing.NewCachedGraph(chanGraph)
if err != nil {
return nil, err
}
paymentSessionSource := &routing.SessionSource{ paymentSessionSource := &routing.SessionSource{
Graph: chanGraph, Graph: cachedGraph,
MissionControl: s.missionControl, MissionControl: s.missionControl,
QueryBandwidth: queryBandwidth, QueryBandwidth: queryBandwidth,
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,