diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 19c608701..086825568 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -39,7 +39,7 @@ type bandwidthManager struct { // hints for the edges we directly have open ourselves. Obtaining these hints // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. -func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, +func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery) (*bandwidthManager, error) { manager := &bandwidthManager{ @@ -49,7 +49,7 @@ func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, // First, we'll collect the set of outbound edges from the target // source node and add them to our bandwidth manager's map of channels. - err := graph.forEachNodeChannel(sourceNode, + err := graph.ForEachNodeChannel(sourceNode, func(channel *channeldb.DirectedChannel) error { shortID := lnwire.NewShortChanIDFromInt( channel.ChannelID, diff --git a/routing/graph.go b/routing/graph.go index 1f0abf9c0..1f4b24bb5 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -10,19 +10,19 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) -// routingGraph is an abstract interface that provides information about nodes -// and edges to pathfinding. -type routingGraph interface { - // forEachNodeChannel calls the callback for every channel of the given +// Graph is an abstract interface that provides information about nodes and +// edges to pathfinding. +type Graph interface { + // ForEachNodeChannel calls the callback for every channel of the given // node. - forEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error - // fetchNodeFeatures returns the features of the given node. - fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) + // FetchNodeFeatures returns the features of the given node. + FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// CachedGraph is a routingGraph implementation that retrieves from the +// CachedGraph is a Graph implementation that retrieves from the // database. type CachedGraph struct { graph *channeldb.ChannelGraph @@ -30,9 +30,9 @@ type CachedGraph struct { source route.Vertex } -// A compile time assertion to make sure CachedGraph implements the routingGraph +// A compile time assertion to make sure CachedGraph implements the Graph // interface. -var _ routingGraph = (*CachedGraph)(nil) +var _ Graph = (*CachedGraph)(nil) // NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. @@ -61,20 +61,20 @@ func (g *CachedGraph) Close() error { return g.tx.Rollback() } -// forEachNodeChannel calls the callback for every channel of the given node. +// ForEachNodeChannel calls the callback for every channel of the given node. // -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, +// NOTE: Part of the Graph interface. +func (g *CachedGraph) ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -// fetchNodeFeatures returns the features of the given node. If the node is +// FetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( +// NOTE: Part of the Graph interface. +func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(nodePub) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 95a5eaf65..02cd6f047 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -163,7 +163,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, c.t.Fatal(err) } - getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) { + getBandwidthHints := func(_ Graph) (bandwidthHints, error) { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -201,7 +201,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, session, err := newPaymentSession( &payment, c.graph.source.pubkey, getBandwidthHints, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return c.graph, func() {}, nil }, mc, c.pathFindingCfg, diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 2ec9a0f98..348eb3746 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -164,8 +164,8 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // forEachNodeChannel calls the callback for every channel of the given node. // -// NOTE: Part of the routingGraph interface. -func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, +// NOTE: Part of the Graph interface. +func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { // Look up the mock node. @@ -213,15 +213,15 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // sourceNode returns the source node of the graph. // -// NOTE: Part of the routingGraph interface. +// NOTE: Part of the Graph interface. func (m *mockGraph) sourceNode() route.Vertex { return m.source.pubkey } // fetchNodeFeatures returns the features of the given node. // -// NOTE: Part of the routingGraph interface. -func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( +// NOTE: Part of the Graph interface. +func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return lnwire.EmptyFeatureVector(), nil @@ -230,7 +230,7 @@ func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( // FetchAmountPairCapacity returns the maximal capacity between nodes in the // graph. // -// NOTE: Part of the routingGraph interface. +// NOTE: Part of the Graph interface. func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { @@ -244,7 +244,7 @@ func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, return nil } - err := m.forEachNodeChannel(nodeFrom, cb) + err := m.ForEachNodeChannel(nodeFrom, cb) if err != nil { return 0, err } @@ -295,5 +295,5 @@ func (m *mockGraph) sendHtlc(route *route.Route) (htlcResult, error) { return source.fwd(nil, next) } -// Compile-time check for the routingGraph interface. -var _ routingGraph = &mockGraph{} +// Compile-time check for the Graph interface. +var _ Graph = &mockGraph{} diff --git a/routing/pathfind.go b/routing/pathfind.go index d7d2893b0..083af04db 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -369,7 +369,7 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi, // graphParams wraps the set of graph parameters passed to findPath. type graphParams struct { // graph is the ChannelGraph to be used during path finding. - graph routingGraph + graph Graph // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the @@ -464,7 +464,7 @@ type PathFindingConfig struct { // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, - g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { + g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi cb := func(channel *channeldb.DirectedChannel) error { @@ -502,7 +502,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } // Iterate over all channels of the to node. - err := g.forEachNodeChannel(node, cb) + err := g.ForEachNodeChannel(node, cb) if err != nil { return 0, 0, err } @@ -542,7 +542,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, features := r.DestFeatures if features == nil { var err error - features, err = g.graph.fetchNodeFeatures(target) + features, err = g.graph.FetchNodeFeatures(target) if err != nil { return nil, 0, err } @@ -920,7 +920,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } // Fetch node features fresh from the graph. - fromFeatures, err := g.graph.fetchNodeFeatures(node) + fromFeatures, err := g.graph.FetchNodeFeatures(node) if err != nil { return nil, err } diff --git a/routing/payment_session.go b/routing/payment_session.go index bdd194812..6cfbeddf4 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -167,7 +167,7 @@ type paymentSession struct { additionalEdges map[route.Vertex][]AdditionalEdge - getBandwidthHints func(routingGraph) (bandwidthHints, error) + getBandwidthHints func(Graph) (bandwidthHints, error) payment *LightningPayment @@ -175,7 +175,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + getRoutingGraph func() (Graph, func(), error) // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. @@ -195,8 +195,8 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, - getBandwidthHints func(routingGraph) (bandwidthHints, error), - getRoutingGraph func() (routingGraph, func(), error), + getBandwidthHints func(Graph) (bandwidthHints, error), + getRoutingGraph func() (Graph, func(), error), missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index ba010391b..51bfc9781 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -46,7 +46,7 @@ type SessionSource struct { // getRoutingGraph returns a routing graph and a clean-up function for // pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { +func (m *SessionSource) getRoutingGraph() (Graph, func(), error) { routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) if err != nil { return nil, nil, err @@ -66,7 +66,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) { + getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( graph, m.SourceNode.PubKeyBytes, m.GetLink, ) diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index b7efed5b7..9356a2be0 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -116,10 +116,10 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( payment, route.Vertex{}, - func(routingGraph) (bandwidthHints, error) { + func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return &sessionGraph{}, func() {}, nil }, &MissionControl{}, @@ -196,10 +196,10 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, route.Vertex{}, - func(routingGraph) (bandwidthHints, error) { + func(Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { + func() (Graph, func(), error) { return &sessionGraph{}, func() {}, nil }, &MissionControl{}, @@ -253,7 +253,7 @@ func TestRequestRoute(t *testing.T) { } type sessionGraph struct { - routingGraph + Graph } func (g *sessionGraph) sourceNode() route.Vertex { diff --git a/routing/router.go b/routing/router.go index 597705754..9af047f4e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -453,9 +453,9 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode - // cachedGraph is an instance of routingGraph that caches the source + // cachedGraph is an instance of Graph that caches the source // node as well as the channel graph itself in memory. - cachedGraph routingGraph + cachedGraph Graph // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to @@ -3177,7 +3177,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // getRouteUnifiers returns a list of edge unifiers for the given route. func getRouteUnifiers(source route.Vertex, hops []route.Vertex, useMinAmt bool, runningAmt lnwire.MilliSatoshi, - outgoingChans map[uint64]struct{}, graph routingGraph, + outgoingChans map[uint64]struct{}, graph Graph, bandwidthHints *bandwidthManager) ([]*edgeUnifier, lnwire.MilliSatoshi, error) { diff --git a/routing/unified_edges.go b/routing/unified_edges.go index d39eda1ef..a0300eea4 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -94,7 +94,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { +func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { cb := func(channel *channeldb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have @@ -120,7 +120,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { } // Iterate over all channels of the to node. - return g.forEachNodeChannel(u.toNode, cb) + return g.ForEachNodeChannel(u.toNode, cb) } // unifiedEdge is the individual channel data that is kept inside an edgeUnifier