diff --git a/routing/blinding.go b/routing/blinding.go index 401d7f3ee..270f998d9 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -154,10 +154,36 @@ func (s *BlindedPaymentPathSet) Features() *lnwire.FeatureVector { return s.features } -// GetPath is a temporary getter for the single path that the set holds. -// This will be removed later on in this PR. -func (s *BlindedPaymentPathSet) GetPath() *BlindedPayment { - return s.paths[0] +// IntroNodeOnlyPath can be called if it is expected that the path set only +// contains a single payment path which itself only has one hop. It errors if +// this is not the case. +func (s *BlindedPaymentPathSet) IntroNodeOnlyPath() (*BlindedPayment, error) { + if len(s.paths) != 1 { + return nil, fmt.Errorf("expected only a single path in the "+ + "blinded payment set, got %d", len(s.paths)) + } + + if len(s.paths[0].BlindedPath.BlindedHops) > 1 { + return nil, fmt.Errorf("an intro node only path cannot have " + + "more than one hop") + } + + return s.paths[0], nil +} + +// IsIntroNode returns true if the given vertex is an introduction node for one +// of the paths in the blinded payment path set. +func (s *BlindedPaymentPathSet) IsIntroNode(source route.Vertex) bool { + for _, path := range s.paths { + introVertex := route.NewVertex( + path.BlindedPath.IntroductionPoint, + ) + if source == introVertex { + return true + } + } + + return false } // FinalCLTVDelta is the minimum CLTV delta to use for the final hop on the diff --git a/routing/pathfind.go b/routing/pathfind.go index 5169bdab9..35b44cf6b 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -309,7 +309,11 @@ func newRoute(sourceVertex route.Vertex, // we can assume the relevant payment is the only one in the // payment set. if blindedPayment == nil { - blindedPayment = blindedPathSet.GetPath() + var err error + blindedPayment, err = blindedPathSet.IntroNodeOnlyPath() + if err != nil { + return nil, err + } } var ( diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 802385351..8fc50bb4f 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3296,10 +3296,21 @@ func TestBlindedRouteConstruction(t *testing.T) { daveEveEdge := blindedEdges[daveBlindedVertex][0] edges := []*unifiedEdge{ - {policy: aliceBobEdge}, - {policy: bobCarolEdge}, - {policy: carolDaveEdge.EdgePolicy()}, - {policy: daveEveEdge.EdgePolicy()}, + { + policy: aliceBobEdge, + }, + { + policy: bobCarolEdge, + blindedPayment: blindedPayment, + }, + { + policy: carolDaveEdge.EdgePolicy(), + blindedPayment: blindedPayment, + }, + { + policy: daveEveEdge.EdgePolicy(), + blindedPayment: blindedPayment, + }, } // Total timelock for the route should include: diff --git a/routing/router.go b/routing/router.go index bca78b5ad..0a0af0d86 100644 --- a/routing/router.go +++ b/routing/router.go @@ -505,12 +505,7 @@ func NewRouteRequest(source route.Vertex, target *route.Vertex, ) if blindedPathSet != nil { - blindedPayment := blindedPathSet.GetPath() - - introVertex := route.NewVertex( - blindedPayment.BlindedPath.IntroductionPoint, - ) - if source == introVertex { + if blindedPathSet.IsIntroNode(source) { return nil, ErrSelfIntro }