Merge pull request #9513 from ellemouton/graph5

graph+routing: refactor to remove `graphsession`
This commit is contained in:
Elle 2025-02-18 11:54:24 -03:00 committed by GitHub
commit f9d29f90cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 224 additions and 331 deletions

View file

@ -248,13 +248,16 @@ The underlying functionality between those two options remain the same.
* Graph abstraction work: * Graph abstraction work:
- [Abstract autopilot access](https://github.com/lightningnetwork/lnd/pull/9480) - [Abstract autopilot access](https://github.com/lightningnetwork/lnd/pull/9480)
- [Abstract invoicerpc server access](https://github.com/lightningnetwork/lnd/pull/9516) - [Abstract invoicerpc server access](https://github.com/lightningnetwork/lnd/pull/9516)
- [Refactor to hide DB transactions](https://github.com/lightningnetwork/lnd/pull/9513)
* [Golang was updated to
`v1.22.11`](https://github.com/lightningnetwork/lnd/pull/9462).
* Move funding transaction validation to the gossiper * Move funding transaction validation to the gossiper
[1](https://github.com/lightningnetwork/lnd/pull/9476) [1](https://github.com/lightningnetwork/lnd/pull/9476)
[2](https://github.com/lightningnetwork/lnd/pull/9477) [2](https://github.com/lightningnetwork/lnd/pull/9477)
[3](https://github.com/lightningnetwork/lnd/pull/9478). [3](https://github.com/lightningnetwork/lnd/pull/9478).
## Breaking Changes ## Breaking Changes
## Performance Improvements ## Performance Improvements

View file

@ -403,16 +403,6 @@ func initChannelGraph(db kvdb.Backend) error {
return nil return nil
} }
// NewPathFindTx returns a new read transaction that can be used for a single
// path finding session. Will return nil if the graph cache is enabled.
func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) {
if c.graphCache != nil {
return nil, nil
}
return c.db.BeginReadTx()
}
// AddrsForNode returns all known addresses for the target node public key that // AddrsForNode returns all known addresses for the target node public key that
// the graph DB is aware of. The returned boolean indicates if the given node is // the graph DB is aware of. The returned boolean indicates if the given node is
// unknown to the graph DB or not. // unknown to the graph DB or not.
@ -500,13 +490,14 @@ func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo,
}, func() {}) }, func() {})
} }
// ForEachNodeDirectedChannel iterates through all channels of a given node, // forEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel // executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration // and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller. // is halted with the error propagated back up to the caller. An optional read
// transaction may be provided. If none is provided, a new one will be created.
// //
// Unknown policies are passed into the callback as nil values. // Unknown policies are passed into the callback as nil values.
func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, func (c *ChannelGraph) forEachNodeDirectedChannel(tx kvdb.RTx,
node route.Vertex, cb func(channel *DirectedChannel) error) error { node route.Vertex, cb func(channel *DirectedChannel) error) error {
if c.graphCache != nil { if c.graphCache != nil {
@ -517,7 +508,7 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx,
toNodeCallback := func() route.Vertex { toNodeCallback := func() route.Vertex {
return node return node
} }
toNodeFeatures, err := c.FetchNodeFeatures(node) toNodeFeatures, err := c.fetchNodeFeatures(tx, node)
if err != nil { if err != nil {
return err return err
} }
@ -561,9 +552,10 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx,
return nodeTraversal(tx, node[:], c.db, dbCallback) return nodeTraversal(tx, node[:], c.db, dbCallback)
} }
// FetchNodeFeatures returns the features of a given node. If no features are // fetchNodeFeatures returns the features of a given node. If no features are
// known for the node, an empty feature vector is returned. // known for the node, an empty feature vector is returned. An optional read
func (c *ChannelGraph) FetchNodeFeatures( // transaction may be provided. If none is provided, a new one will be created.
func (c *ChannelGraph) fetchNodeFeatures(tx kvdb.RTx,
node route.Vertex) (*lnwire.FeatureVector, error) { node route.Vertex) (*lnwire.FeatureVector, error) {
if c.graphCache != nil { if c.graphCache != nil {
@ -571,7 +563,7 @@ func (c *ChannelGraph) FetchNodeFeatures(
} }
// Fallback that uses the database. // Fallback that uses the database.
targetNode, err := c.FetchLightningNode(node) targetNode, err := c.FetchLightningNodeTx(tx, node)
switch err { switch err {
// If the node exists and has features, return them directly. // If the node exists and has features, return them directly.
case nil: case nil:
@ -588,6 +580,34 @@ func (c *ChannelGraph) FetchNodeFeatures(
} }
} }
// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller. If the graphCache
// is available, then it will be used to retrieve the node's channels instead
// of the database.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *ChannelGraph) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {
return c.forEachNodeDirectedChannel(nil, nodePub, cb)
}
// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
// If the graphCache is available, then it will be used to retrieve the node's
// features instead of the database.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *ChannelGraph) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return c.fetchNodeFeatures(nil, nodePub)
}
// ForEachNodeCached is similar to forEachNode, but it utilizes the channel // ForEachNodeCached is similar to forEachNode, but it utilizes the channel
// graph cache instead. Note that this doesn't return all the information the // graph cache instead. Note that this doesn't return all the information the
// regular forEachNode method does. // regular forEachNode method does.
@ -617,8 +637,8 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
toNodeCallback := func() route.Vertex { toNodeCallback := func() route.Vertex {
return node.PubKeyBytes return node.PubKeyBytes
} }
toNodeFeatures, err := c.FetchNodeFeatures( toNodeFeatures, err := c.fetchNodeFeatures(
node.PubKeyBytes, tx, node.PubKeyBytes,
) )
if err != nil { if err != nil {
return err return err
@ -3873,6 +3893,64 @@ func (c *ChannelGraph) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
return isClosed, nil return isClosed, nil
} }
// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph. If
// the graph cache is not enabled, then the call-back will be provided with
// access to the graph via a consistent read-only transaction.
func (c *ChannelGraph) GraphSession(cb func(graph NodeTraverser) error) error {
var (
tx kvdb.RTx
err error
commit = func() {}
)
if c.graphCache == nil {
tx, err = c.db.BeginReadTx()
if err != nil {
return err
}
commit = func() {
if err := tx.Rollback(); err != nil {
log.Errorf("Unable to rollback tx: %v", err)
}
}
}
defer commit()
return cb(&nodeTraverserSession{
db: c,
tx: tx,
})
}
// nodeTraverserSession implements the NodeTraverser interface but with a
// backing read only transaction for a consistent view of the graph in the case
// where the graph Cache has not been enabled.
type nodeTraverserSession struct {
tx kvdb.RTx
db *ChannelGraph
}
// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (c *nodeTraverserSession) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {
return c.db.forEachNodeDirectedChannel(c.tx, nodePub, cb)
}
// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (c *nodeTraverserSession) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return c.db.fetchNodeFeatures(c.tx, nodePub)
}
func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // nolint:dupl func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // nolint:dupl
updateIndex kvdb.RwBucket, node *models.LightningNode) error { updateIndex kvdb.RwBucket, node *models.LightningNode) error {

View file

@ -3915,7 +3915,7 @@ func BenchmarkForEachChannel(b *testing.B) {
} }
} }
// TestGraphCacheForEachNodeChannel tests that the ForEachNodeDirectedChannel // TestGraphCacheForEachNodeChannel tests that the forEachNodeDirectedChannel
// method works as expected, and is able to handle nil self edges. // method works as expected, and is able to handle nil self edges.
func TestGraphCacheForEachNodeChannel(t *testing.T) { func TestGraphCacheForEachNodeChannel(t *testing.T) {
graph, err := MakeTestGraph(t) graph, err := MakeTestGraph(t)
@ -3952,7 +3952,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
getSingleChannel := func() *DirectedChannel { getSingleChannel := func() *DirectedChannel {
var ch *DirectedChannel var ch *DirectedChannel
err = graph.ForEachNodeDirectedChannel(nil, node1.PubKeyBytes, err = graph.forEachNodeDirectedChannel(nil, node1.PubKeyBytes,
func(c *DirectedChannel) error { func(c *DirectedChannel) error {
require.Nil(t, ch) require.Nil(t, ch)
ch = c ch = c

View file

@ -2,6 +2,7 @@ package graphdb
import ( import (
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
@ -23,3 +24,16 @@ type NodeRTx interface {
// the same transaction. // the same transaction.
FetchNode(node route.Vertex) (NodeRTx, error) FetchNode(node route.Vertex) (NodeRTx, error)
} }
// NodeTraverser is an abstract read only interface that provides information
// about nodes and their edges. The interface is about providing fast read-only
// access to the graph and so if a cache is available, it should be used.
type NodeTraverser interface {
// ForEachNodeDirectedChannel calls the callback for every channel of
// the given node.
ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error
// FetchNodeFeatures returns the features of the given node.
FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}

View file

@ -1,141 +0,0 @@
package graphsession
import (
"fmt"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/routing/route"
)
// Factory implements the routing.GraphSessionFactory and can be used to start
// a session with a ReadOnlyGraph.
type Factory struct {
graph ReadOnlyGraph
}
// NewGraphSessionFactory constructs a new Factory which can then be used to
// start a new session.
func NewGraphSessionFactory(graph ReadOnlyGraph) routing.GraphSessionFactory {
return &Factory{
graph: graph,
}
}
// NewGraphSession will produce a new Graph to use for a path-finding session.
// It returns the Graph along with a call-back that must be called once Graph
// access is complete. This call-back will close any read-only transaction that
// was created at Graph construction time.
//
// NOTE: This is part of the routing.GraphSessionFactory interface.
func (g *Factory) NewGraphSession() (routing.Graph, func() error, error) {
tx, err := g.graph.NewPathFindTx()
if err != nil {
return nil, nil, err
}
session := &session{
graph: g.graph,
tx: tx,
}
return session, session.close, nil
}
// A compile-time check to ensure that Factory implements the
// routing.GraphSessionFactory interface.
var _ routing.GraphSessionFactory = (*Factory)(nil)
// session is an implementation of the routing.Graph interface where the same
// read-only transaction is held across calls to the graph and can be used to
// access the backing channel graph.
type session struct {
graph graph
tx kvdb.RTx
}
// NewRoutingGraph constructs a session that which does not first start a
// read-only transaction and so each call on the routing.Graph will create a
// new transaction.
func NewRoutingGraph(graph ReadOnlyGraph) routing.Graph {
return &session{
graph: graph,
}
}
// close closes the read-only transaction being used to access the backing
// graph. If no transaction was started then this is a no-op.
func (g *session) close() error {
if g.tx == nil {
return nil
}
err := g.tx.Rollback()
if err != nil {
return fmt.Errorf("error closing db tx: %w", err)
}
return nil
}
// ForEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routing.Graph interface.
func (g *session) ForEachNodeChannel(nodePub route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error {
return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb)
}
// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the routing.Graph interface.
func (g *session) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return g.graph.FetchNodeFeatures(nodePub)
}
// A compile-time check to ensure that *session implements the
// routing.Graph interface.
var _ routing.Graph = (*session)(nil)
// ReadOnlyGraph is a graph extended with a call to create a new read-only
// transaction that can then be used to make further queries to the graph.
type ReadOnlyGraph interface {
// NewPathFindTx returns a new read transaction that can be used for a
// single path finding session. Will return nil if the graph cache is
// enabled.
NewPathFindTx() (kvdb.RTx, error)
graph
}
// graph describes the API necessary for a graph source to have access to on a
// database implementation, like channeldb.ChannelGraph, in order to be used by
// the Router for pathfinding.
type graph interface {
// ForEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing
// the channel and its incoming policy. If the callback returns an
// error, then the iteration is halted with the error propagated back
// up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: if a nil tx is provided, then it is expected that the
// implementation create a read only tx.
ForEachNodeDirectedChannel(tx kvdb.RTx, node route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error
// FetchNodeFeatures returns the features of a given node. If no
// features are known for the node, an empty feature vector is returned.
FetchNodeFeatures(node route.Vertex) (*lnwire.FeatureVector, error)
}
// A compile-time check to ensure that *channeldb.ChannelGraph implements the
// graph interface.
var _ graph = (*graphdb.ChannelGraph)(nil)

View file

@ -63,7 +63,7 @@ func newBandwidthManager(graph Graph, sourceNode route.Vertex,
// First, we'll collect the set of outbound edges from the target // First, we'll collect the set of outbound edges from the target
// source node and add them to our bandwidth manager's map of channels. // source node and add them to our bandwidth manager's map of channels.
err := graph.ForEachNodeChannel(sourceNode, err := graph.ForEachNodeDirectedChannel(sourceNode,
func(channel *graphdb.DirectedChannel) error { func(channel *graphdb.DirectedChannel) error {
shortID := lnwire.NewShortChanIDFromInt( shortID := lnwire.NewShortChanIDFromInt(
channel.ChannelID, channel.ChannelID,

View file

@ -12,25 +12,24 @@ import (
// Graph is an abstract interface that provides information about nodes and // Graph is an abstract interface that provides information about nodes and
// edges to pathfinding. // edges to pathfinding.
type Graph interface { type Graph interface {
// ForEachNodeChannel calls the callback for every channel of the given // ForEachNodeDirectedChannel calls the callback for every channel of
// node. // the given node.
ForEachNodeChannel(nodePub route.Vertex, ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error cb func(channel *graphdb.DirectedChannel) error) error
// FetchNodeFeatures returns the features of the given node. // FetchNodeFeatures returns the features of the given node.
FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
} }
// GraphSessionFactory can be used to produce a new Graph instance which can // GraphSessionFactory can be used to gain access to a graphdb.NodeTraverser
// then be used for a path-finding session. Depending on the implementation, // instance which can then be used for a path-finding session. Depending on the
// the Graph session will represent a DB connection where a read-lock is being // implementation, the session will represent a DB connection where a read-lock
// held across calls to the backing Graph. // is being held across calls to the backing graph.
type GraphSessionFactory interface { type GraphSessionFactory interface {
// NewGraphSession will produce a new Graph to use for a path-finding // GraphSession will provide the call-back with access to a
// session. It returns the Graph along with a call-back that must be // graphdb.NodeTraverser instance which can be used to perform queries
// called once Graph access is complete. This call-back will close any // against the channel graph.
// read-only transaction that was created at Graph construction time. GraphSession(cb func(graph graphdb.NodeTraverser) error) error
NewGraphSession() (Graph, func() error, error)
} }
// FetchAmountPairCapacity determines the maximal public capacity between two // FetchAmountPairCapacity determines the maximal public capacity between two

View file

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@ -211,7 +210,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
session, err := newPaymentSession( session, err := newPaymentSession(
&payment, c.graph.source.pubkey, getBandwidthHints, &payment, c.graph.source.pubkey, getBandwidthHints,
newMockGraphSessionFactory(c.graph), mc, c.pathFindingCfg, c.graph, mc, c.pathFindingCfg,
) )
if err != nil { if err != nil {
c.t.Fatal(err) c.t.Fatal(err)
@ -317,88 +316,3 @@ func getNodeIndex(route *route.Route, failureSource route.Vertex) *int {
} }
return nil return nil
} }
type mockGraphSessionFactory struct {
Graph
}
func newMockGraphSessionFactory(graph Graph) GraphSessionFactory {
return &mockGraphSessionFactory{Graph: graph}
}
func (m *mockGraphSessionFactory) NewGraphSession() (Graph, func() error,
error) {
return m, func() error {
return nil
}, nil
}
var _ GraphSessionFactory = (*mockGraphSessionFactory)(nil)
var _ Graph = (*mockGraphSessionFactory)(nil)
type mockGraphSessionFactoryChanDB struct {
graph *graphdb.ChannelGraph
}
func newMockGraphSessionFactoryFromChanDB(
graph *graphdb.ChannelGraph) *mockGraphSessionFactoryChanDB {
return &mockGraphSessionFactoryChanDB{
graph: graph,
}
}
func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error,
error) {
tx, err := g.graph.NewPathFindTx()
if err != nil {
return nil, nil, err
}
session := &mockGraphSessionChanDB{
graph: g.graph,
tx: tx,
}
return session, session.close, nil
}
var _ GraphSessionFactory = (*mockGraphSessionFactoryChanDB)(nil)
type mockGraphSessionChanDB struct {
graph *graphdb.ChannelGraph
tx kvdb.RTx
}
func newMockGraphSessionChanDB(graph *graphdb.ChannelGraph) Graph {
return &mockGraphSessionChanDB{
graph: graph,
}
}
func (g *mockGraphSessionChanDB) close() error {
if g.tx == nil {
return nil
}
err := g.tx.Rollback()
if err != nil {
return fmt.Errorf("error closing db tx: %w", err)
}
return nil
}
func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error {
return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb)
}
func (g *mockGraphSessionChanDB) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return g.graph.FetchNodeFeatures(nodePub)
}

View file

@ -404,5 +404,5 @@ func TestPaymentAddrOnlyNoSplit(t *testing.T) {
// The payment should have failed since we need to split in order to // The payment should have failed since we need to split in order to
// route a payment to the destination, but they don't actually support // route a payment to the destination, but they don't actually support
// MPP. // MPP.
require.Equal(t, err.Error(), errNoPathFound.Error()) require.ErrorIs(t, err, errNoPathFound)
} }

View file

@ -165,7 +165,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given node.
// //
// NOTE: Part of the Graph interface. // NOTE: Part of the Graph interface.
func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, func (m *mockGraph) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error { cb func(channel *graphdb.DirectedChannel) error) error {
// Look up the mock node. // Look up the mock node.
@ -227,6 +227,17 @@ func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) (
return lnwire.EmptyFeatureVector(), nil return lnwire.EmptyFeatureVector(), nil
} }
// GraphSession will provide the call-back with access to a
// graphdb.NodeTraverser instance which can be used to perform queries against
// the channel graph.
//
// NOTE: Part of the GraphSessionFactory interface.
func (m *mockGraph) GraphSession(
cb func(graph graphdb.NodeTraverser) error) error {
return cb(m)
}
// htlcResult describes the resolution of an htlc. If failure is nil, the htlc // htlcResult describes the resolution of an htlc. If failure is nil, the htlc
// was settled. // was settled.
type htlcResult struct { type htlcResult struct {

View file

@ -557,7 +557,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
err := g.ForEachNodeChannel(node, cb) err := g.ForEachNodeDirectedChannel(node, cb)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
@ -1325,7 +1325,7 @@ func processNodeForBlindedPath(g Graph, node route.Vertex,
// Now, iterate over the node's channels in search for paths to this // Now, iterate over the node's channels in search for paths to this
// node that can be used for blinded paths // node that can be used for blinded paths
err = g.ForEachNodeChannel(node, err = g.ForEachNodeDirectedChannel(node,
func(channel *graphdb.DirectedChannel) error { func(channel *graphdb.DirectedChannel) error {
// Keep track of how many incoming channels this node // Keep track of how many incoming channels this node
// has. We only use a node as an introduction node if it // has. We only use a node as an introduction node if it

View file

@ -3221,30 +3221,25 @@ func dbFindPath(graph *graphdb.ChannelGraph,
return nil, err return nil, err
} }
graphSessFactory := newMockGraphSessionFactoryFromChanDB(graph) var route []*unifiedEdge
err = graph.GraphSession(func(graph graphdb.NodeTraverser) error {
route, _, err = findPath(
&graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: graph,
},
r, cfg, sourceNode.PubKeyBytes, source, target, amt,
timePref, finalHtlcExpiry,
)
graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession() return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { return route, nil
if err := closeGraphSess(); err != nil {
log.Errorf("Error closing graph session: %v", err)
}
}()
route, _, err := findPath(
&graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: graphSess,
},
r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref,
finalHtlcExpiry,
)
return route, err
} }
// dbFindBlindedPaths calls findBlindedPaths after getting a db transaction from // dbFindBlindedPaths calls findBlindedPaths after getting a db transaction from
@ -3258,8 +3253,7 @@ func dbFindBlindedPaths(graph *graphdb.ChannelGraph,
} }
return findBlindedPaths( return findBlindedPaths(
newMockGraphSessionChanDB(graph), sourceNode.PubKeyBytes, graph, sourceNode.PubKeyBytes, restrictions,
restrictions,
) )
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btclog/v2" "github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -235,6 +236,17 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
}, nil }, nil
} }
// pathFindingError is a wrapper error type that is used to distinguish path
// finding errors from other errors in path finding loop.
type pathFindingError struct {
error
}
// Unwrap returns the underlying error.
func (e *pathFindingError) Unwrap() error {
return e.error
}
// RequestRoute returns a route which is likely to be capable for successfully // RequestRoute returns a route which is likely to be capable for successfully
// routing the specified HTLC payment to the target node. Initially the first // routing the specified HTLC payment to the target node. Initially the first
// set of paths returned from this method may encounter routing failure along // set of paths returned from this method may encounter routing failure along
@ -295,13 +307,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
maxAmt = *p.payment.MaxShardAmt maxAmt = *p.payment.MaxShardAmt
} }
for { var path []*unifiedEdge
// Get a routing graph session. findPath := func(graph graphdb.NodeTraverser) error {
graph, closeGraph, err := p.graphSessFactory.NewGraphSession()
if err != nil {
return nil, err
}
// We'll also obtain a set of bandwidthHints from the lower // We'll also obtain a set of bandwidthHints from the lower
// layer for each of our outbound channels. This will allow the // layer for each of our outbound channels. This will allow the
// path finding to skip any links that aren't active or just // path finding to skip any links that aren't active or just
@ -310,19 +317,13 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// attempt, because concurrent payments may change balances. // attempt, because concurrent payments may change balances.
bandwidthHints, err := p.getBandwidthHints(graph) bandwidthHints, err := p.getBandwidthHints(graph)
if err != nil { if err != nil {
// Close routing graph session. return err
if graphErr := closeGraph(); graphErr != nil {
log.Errorf("could not close graph session: %v",
graphErr)
}
return nil, err
} }
p.log.Debugf("pathfinding for amt=%v", maxAmt) p.log.Debugf("pathfinding for amt=%v", maxAmt)
// Find a route for the current amount. // Find a route for the current amount.
path, _, err := p.pathFinder( path, _, err = p.pathFinder(
&graphParams{ &graphParams{
additionalEdges: p.additionalEdges, additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
@ -332,12 +333,31 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.selfNode, p.selfNode, p.payment.Target, p.selfNode, p.selfNode, p.payment.Target,
maxAmt, p.payment.TimePref, finalHtlcExpiry, maxAmt, p.payment.TimePref, finalHtlcExpiry,
) )
if err != nil {
// Close routing graph session. // Wrap the error to distinguish path finding errors
if err := closeGraph(); err != nil { // from other errors in this closure.
log.Errorf("could not close graph session: %v", err) return &pathFindingError{err}
} }
return nil
}
for {
err := p.graphSessFactory.GraphSession(findPath)
// If there is an error, and it is not a path finding error, we
// return it immediately.
if err != nil && !lnutils.ErrorAs[*pathFindingError](err) {
return nil, err
} else if err != nil {
// If the error is a path finding error, we'll unwrap it
// to check the underlying error.
//
//nolint:errorlint
pErr, _ := err.(*pathFindingError)
err = pErr.Unwrap()
}
// Otherwise, we'll switch on the path finding error.
switch { switch {
case err == errNoPathFound: case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp // Don't split if this is a legacy payment without mpp

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -118,7 +119,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
func(Graph) (bandwidthHints, error) { func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
newMockGraphSessionFactory(&sessionGraph{}), &sessionGraph{},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )
@ -196,7 +197,7 @@ func TestRequestRoute(t *testing.T) {
func(Graph) (bandwidthHints, error) { func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
newMockGraphSessionFactory(&sessionGraph{}), &sessionGraph{},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )
@ -257,3 +258,9 @@ type sessionGraph struct {
func (g *sessionGraph) sourceNode() route.Vertex { func (g *sessionGraph) sourceNode() route.Vertex {
return route.Vertex{} return route.Vertex{}
} }
func (g *sessionGraph) GraphSession(
cb func(graph graphdb.NodeTraverser) error) error {
return cb(g)
}

View file

@ -135,20 +135,18 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
sourceNode, err := graphInstance.graph.SourceNode() sourceNode, err := graphInstance.graph.SourceNode()
require.NoError(t, err) require.NoError(t, err)
sessionSource := &SessionSource{ sessionSource := &SessionSource{
GraphSessionFactory: newMockGraphSessionFactoryFromChanDB( GraphSessionFactory: graphInstance.graph,
graphInstance.graph, SourceNode: sourceNode,
), GetLink: graphInstance.getLink,
SourceNode: sourceNode, PathFindingConfig: pathFindingConfig,
GetLink: graphInstance.getLink, MissionControl: mc,
PathFindingConfig: pathFindingConfig,
MissionControl: mc,
} }
graphBuilder := newMockGraphBuilder(graphInstance.graph) graphBuilder := newMockGraphBuilder(graphInstance.graph)
router, err := New(Config{ router, err := New(Config{
SelfNode: sourceNode.PubKeyBytes, SelfNode: sourceNode.PubKeyBytes,
RoutingGraph: newMockGraphSessionChanDB(graphInstance.graph), RoutingGraph: graphInstance.graph,
Chain: chain, Chain: chain,
Payer: &mockPaymentAttemptDispatcherOld{}, Payer: &mockPaymentAttemptDispatcherOld{},
Control: makeMockControlTower(), Control: makeMockControlTower(),

View file

@ -125,7 +125,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error {
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
return g.ForEachNodeChannel(u.toNode, cb) return g.ForEachNodeDirectedChannel(u.toNode, cb)
} }
// unifiedEdge is the individual channel data that is kept inside an edgeUnifier // unifiedEdge is the individual channel data that is kept inside an edgeUnifier

View file

@ -51,7 +51,6 @@ import (
"github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db" graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/graph/graphsession"
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
@ -711,8 +710,8 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service,
amount lnwire.MilliSatoshi) (btcutil.Amount, error) { amount lnwire.MilliSatoshi) (btcutil.Amount, error) {
return routing.FetchAmountPairCapacity( return routing.FetchAmountPairCapacity(
graphsession.NewRoutingGraph(graph), graph, selfNode.PubKeyBytes, nodeFrom, nodeTo,
selfNode.PubKeyBytes, nodeFrom, nodeTo, amount, amount,
) )
}, },
FetchChannelEndpoints: func(chanID uint64) (route.Vertex, FetchChannelEndpoints: func(chanID uint64) (route.Vertex,

View file

@ -45,7 +45,6 @@ import (
"github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db" graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/graph/graphsession"
"github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/healthcheck"
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
@ -1038,13 +1037,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, fmt.Errorf("error getting source node: %w", err) return nil, fmt.Errorf("error getting source node: %w", err)
} }
paymentSessionSource := &routing.SessionSource{ paymentSessionSource := &routing.SessionSource{
GraphSessionFactory: graphsession.NewGraphSessionFactory( GraphSessionFactory: dbs.GraphDB,
dbs.GraphDB, SourceNode: sourceNode,
), MissionControl: s.defaultMC,
SourceNode: sourceNode, GetLink: s.htlcSwitch.GetLinkByShortID,
MissionControl: s.defaultMC, PathFindingConfig: pathFindingConfig,
GetLink: s.htlcSwitch.GetLinkByShortID,
PathFindingConfig: pathFindingConfig,
} }
paymentControl := channeldb.NewPaymentControl(dbs.ChanStateDB) paymentControl := channeldb.NewPaymentControl(dbs.ChanStateDB)
@ -1073,7 +1070,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.chanRouter, err = routing.New(routing.Config{ s.chanRouter, err = routing.New(routing.Config{
SelfNode: selfNode.PubKeyBytes, SelfNode: selfNode.PubKeyBytes,
RoutingGraph: graphsession.NewRoutingGraph(dbs.GraphDB), RoutingGraph: dbs.GraphDB,
Chain: cc.ChainIO, Chain: cc.ChainIO,
Payer: s.htlcSwitch, Payer: s.htlcSwitch,
Control: s.controlTower, Control: s.controlTower,