lnd/routing/payment_session_test.go

260 lines
6.2 KiB
Go

package routing
import (
"testing"
"time"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
func TestValidateCLTVLimit(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
cltvLimit uint32
finalCltvDelta uint16
includePadding bool
expectError bool
}{
{
name: "bad limit with padding",
cltvLimit: uint32(103),
finalCltvDelta: uint16(100),
includePadding: true,
expectError: true,
},
{
name: "good limit with padding",
cltvLimit: uint32(104),
finalCltvDelta: uint16(100),
includePadding: true,
expectError: false,
},
{
name: "bad limit no padding",
cltvLimit: uint32(100),
finalCltvDelta: uint16(100),
includePadding: false,
expectError: true,
},
{
name: "good limit no padding",
cltvLimit: uint32(101),
finalCltvDelta: uint16(100),
includePadding: false,
expectError: false,
},
}
for _, testCase := range testCases {
testCase := testCase
success := t.Run(testCase.name, func(t *testing.T) {
err := ValidateCLTVLimit(
testCase.cltvLimit, testCase.finalCltvDelta,
testCase.includePadding,
)
if testCase.expectError {
require.NotEmpty(t, err)
} else {
require.NoError(t, err)
}
})
if !success {
break
}
}
}
// TestUpdateAdditionalEdge checks that we can update the additional edges as
// expected.
func TestUpdateAdditionalEdge(t *testing.T) {
var (
testChannelID = uint64(12345)
oldFeeBaseMSat = uint32(1000)
newFeeBaseMSat = uint32(1100)
oldExpiryDelta = uint16(100)
newExpiryDelta = uint16(120)
payHash lntypes.Hash
)
// Create a minimal test node using the private key priv1.
pub := priv1.PubKey().SerializeCompressed()
testNode := &models.LightningNode{}
copy(testNode.PubKeyBytes[:], pub)
nodeID, err := testNode.PubKey()
require.NoError(t, err, "failed to get node id")
// Create a payment with a route hint.
payment := &LightningPayment{
Target: testNode.PubKeyBytes,
Amount: 1000,
RouteHints: [][]zpay32.HopHint{{
zpay32.HopHint{
// The nodeID is actually the target itself. It
// doesn't matter as we are not doing routing
// in this test.
NodeID: nodeID,
ChannelID: testChannelID,
FeeBaseMSat: oldFeeBaseMSat,
CLTVExpiryDelta: oldExpiryDelta,
},
}},
paymentHash: &payHash,
}
// Create the paymentsession.
session, err := newPaymentSession(
payment, route.Vertex{},
func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
newMockGraphSessionFactory(&sessionGraph{}),
&MissionControl{},
PathFindingConfig{},
)
require.NoError(t, err, "failed to create payment session")
// We should have 1 additional edge.
require.Equal(t, 1, len(session.additionalEdges))
// The edge should use nodeID as key, and its value should have 1 edge
// policy.
vertex := route.NewVertex(nodeID)
policies, ok := session.additionalEdges[vertex]
require.True(t, ok, "cannot find policy")
require.Equal(t, 1, len(policies), "should have 1 edge policy")
// Check that the policy has been created as expected.
policy := policies[0].EdgePolicy()
require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
require.Equal(t,
oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
)
require.Equal(t,
lnwire.MilliSatoshi(oldFeeBaseMSat),
policy.FeeBaseMSat, "fee base msat mismatch",
)
// Create the channel update message and sign.
msg := &lnwire.ChannelUpdate1{
ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID),
Timestamp: uint32(time.Now().Unix()),
BaseFee: newFeeBaseMSat,
TimeLockDelta: newExpiryDelta,
}
signErrChanUpdate(t, priv1, msg)
// Apply the update.
require.True(t,
session.UpdateAdditionalEdge(msg, nodeID, policy),
"failed to update additional edge",
)
// Check that the policy has been updated as expected.
require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
require.Equal(t,
newExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
)
require.Equal(t,
lnwire.MilliSatoshi(newFeeBaseMSat),
policy.FeeBaseMSat, "fee base msat mismatch",
)
}
func TestRequestRoute(t *testing.T) {
const (
height = 10
)
cltvLimit := uint32(30)
finalCltvDelta := uint16(8)
payment := &LightningPayment{
CltvLimit: cltvLimit,
FinalCLTVDelta: finalCltvDelta,
Amount: 1000,
FeeLimit: 1000,
}
var paymentHash [32]byte
if err := payment.SetPaymentHash(paymentHash); err != nil {
t.Fatal(err)
}
session, err := newPaymentSession(
payment, route.Vertex{},
func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
newMockGraphSessionFactory(&sessionGraph{}),
&MissionControl{},
PathFindingConfig{},
)
if err != nil {
t.Fatal(err)
}
// Override pathfinder with a mock.
session.pathFinder = func(_ *graphParams, r *RestrictParams,
_ *PathFindingConfig, _, _, _ route.Vertex,
_ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge,
float64, error) {
// We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding).
if r.CltvLimit != 22-uint32(BlockPadding) {
t.Fatal("wrong cltv limit")
}
path := []*unifiedEdge{
{
policy: &models.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return route.Vertex{}
},
ToNodeFeatures: lnwire.NewFeatureVector(
nil, nil,
),
},
},
}
return path, 1.0, nil
}
route, err := session.RequestRoute(
payment.Amount, payment.FeeLimit, 0, height,
lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3},
},
)
if err != nil {
t.Fatal(err)
}
// We expect an absolute route lock value of height + finalCltvDelta
// + BlockPadding.
if route.TotalTimeLock != 18+uint32(BlockPadding) {
t.Fatalf("unexpected total time lock of %v",
route.TotalTimeLock)
}
}
type sessionGraph struct {
Graph
}
func (g *sessionGraph) sourceNode() route.Vertex {
return route.Vertex{}
}