lnd/routing/payment_session_test.go
Elle Mouton e004447da6
multi: remove the need for the graphsession package
In this commit, we add a `GraphSession` method to the `ChannelGraph`.
This method provides a caller with access to a `NodeTraverser`. This is
used by pathfinding to create a graph "session" overwhich to perform a
set of queries for a pathfinding attempt. With this refactor, we hide
details such as DB transaction creation and transaction commits from the
caller. So with this, pathfinding does not need to remember to "close
the graph session". With this commit, the `graphsession` package may be
completely removed.
2025-02-18 10:15:41 -03:00

266 lines
6.3 KiB
Go

package routing
import (
"testing"
"time"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
func TestValidateCLTVLimit(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
cltvLimit uint32
finalCltvDelta uint16
includePadding bool
expectError bool
}{
{
name: "bad limit with padding",
cltvLimit: uint32(103),
finalCltvDelta: uint16(100),
includePadding: true,
expectError: true,
},
{
name: "good limit with padding",
cltvLimit: uint32(104),
finalCltvDelta: uint16(100),
includePadding: true,
expectError: false,
},
{
name: "bad limit no padding",
cltvLimit: uint32(100),
finalCltvDelta: uint16(100),
includePadding: false,
expectError: true,
},
{
name: "good limit no padding",
cltvLimit: uint32(101),
finalCltvDelta: uint16(100),
includePadding: false,
expectError: false,
},
}
for _, testCase := range testCases {
testCase := testCase
success := t.Run(testCase.name, func(t *testing.T) {
err := ValidateCLTVLimit(
testCase.cltvLimit, testCase.finalCltvDelta,
testCase.includePadding,
)
if testCase.expectError {
require.NotEmpty(t, err)
} else {
require.NoError(t, err)
}
})
if !success {
break
}
}
}
// TestUpdateAdditionalEdge checks that we can update the additional edges as
// expected.
func TestUpdateAdditionalEdge(t *testing.T) {
var (
testChannelID = uint64(12345)
oldFeeBaseMSat = uint32(1000)
newFeeBaseMSat = uint32(1100)
oldExpiryDelta = uint16(100)
newExpiryDelta = uint16(120)
payHash lntypes.Hash
)
// Create a minimal test node using the private key priv1.
pub := priv1.PubKey().SerializeCompressed()
testNode := &models.LightningNode{}
copy(testNode.PubKeyBytes[:], pub)
nodeID, err := testNode.PubKey()
require.NoError(t, err, "failed to get node id")
// Create a payment with a route hint.
payment := &LightningPayment{
Target: testNode.PubKeyBytes,
Amount: 1000,
RouteHints: [][]zpay32.HopHint{{
zpay32.HopHint{
// The nodeID is actually the target itself. It
// doesn't matter as we are not doing routing
// in this test.
NodeID: nodeID,
ChannelID: testChannelID,
FeeBaseMSat: oldFeeBaseMSat,
CLTVExpiryDelta: oldExpiryDelta,
},
}},
paymentHash: &payHash,
}
// Create the paymentsession.
session, err := newPaymentSession(
payment, route.Vertex{},
func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
require.NoError(t, err, "failed to create payment session")
// We should have 1 additional edge.
require.Equal(t, 1, len(session.additionalEdges))
// The edge should use nodeID as key, and its value should have 1 edge
// policy.
vertex := route.NewVertex(nodeID)
policies, ok := session.additionalEdges[vertex]
require.True(t, ok, "cannot find policy")
require.Equal(t, 1, len(policies), "should have 1 edge policy")
// Check that the policy has been created as expected.
policy := policies[0].EdgePolicy()
require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
require.Equal(t,
oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
)
require.Equal(t,
lnwire.MilliSatoshi(oldFeeBaseMSat),
policy.FeeBaseMSat, "fee base msat mismatch",
)
// Create the channel update message and sign.
msg := &lnwire.ChannelUpdate1{
ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID),
Timestamp: uint32(time.Now().Unix()),
BaseFee: newFeeBaseMSat,
TimeLockDelta: newExpiryDelta,
}
signErrChanUpdate(t, priv1, msg)
// Apply the update.
require.True(t,
session.UpdateAdditionalEdge(msg, nodeID, policy),
"failed to update additional edge",
)
// Check that the policy has been updated as expected.
require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
require.Equal(t,
newExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
)
require.Equal(t,
lnwire.MilliSatoshi(newFeeBaseMSat),
policy.FeeBaseMSat, "fee base msat mismatch",
)
}
func TestRequestRoute(t *testing.T) {
const (
height = 10
)
cltvLimit := uint32(30)
finalCltvDelta := uint16(8)
payment := &LightningPayment{
CltvLimit: cltvLimit,
FinalCLTVDelta: finalCltvDelta,
Amount: 1000,
FeeLimit: 1000,
}
var paymentHash [32]byte
if err := payment.SetPaymentHash(paymentHash); err != nil {
t.Fatal(err)
}
session, err := newPaymentSession(
payment, route.Vertex{},
func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
if err != nil {
t.Fatal(err)
}
// Override pathfinder with a mock.
session.pathFinder = func(_ *graphParams, r *RestrictParams,
_ *PathFindingConfig, _, _, _ route.Vertex,
_ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge,
float64, error) {
// We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding).
if r.CltvLimit != 22-uint32(BlockPadding) {
t.Fatal("wrong cltv limit")
}
path := []*unifiedEdge{
{
policy: &models.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return route.Vertex{}
},
ToNodeFeatures: lnwire.NewFeatureVector(
nil, nil,
),
},
},
}
return path, 1.0, nil
}
route, err := session.RequestRoute(
payment.Amount, payment.FeeLimit, 0, height,
lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3},
},
)
if err != nil {
t.Fatal(err)
}
// We expect an absolute route lock value of height + finalCltvDelta
// + BlockPadding.
if route.TotalTimeLock != 18+uint32(BlockPadding) {
t.Fatalf("unexpected total time lock of %v",
route.TotalTimeLock)
}
}
type sessionGraph struct {
Graph
}
func (g *sessionGraph) sourceNode() route.Vertex {
return route.Vertex{}
}
func (g *sessionGraph) GraphSession(
cb func(graph graphdb.NodeTraverser) error) error {
return cb(g)
}