From 5df776e80bfc51e309603477e590f46989fab9e2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 15 Apr 2021 23:58:02 +0800 Subject: [PATCH] routing: add method UpdateAdditionalEdge and GetAdditionalEdgePolicy This commit adds the method UpdateAdditionalEdge in PaymentSession, which allows the addtional channel edge policy to be updated from a ChannelUpdate message. Another method, GetAdditionalEdgePolicy is added to allow querying additional edge policies. --- routing/payment_session.go | 52 ++++++++++++++++ routing/payment_session_test.go | 106 ++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) diff --git a/routing/payment_session.go b/routing/payment_session.go index 9dc280fae..d2c022558 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -3,7 +3,9 @@ package routing import ( "fmt" + "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btclog" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" @@ -382,3 +384,53 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, return route, err } } + +// UpdateAdditionalEdge updates the channel edge policy for a private edge. It +// validates the message signature and checks it's up to date, then applies the +// updates to the supplied policy. It returns a boolean to indicate whether +// there's an error when applying the updates. +func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, + pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + + // Validate the message signature. + if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { + log.Errorf( + "Unable to validate channel update signature: %v", err, + ) + return false + } + + // Update channel policy for the additional edge. + policy.TimeLockDelta = msg.TimeLockDelta + policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee) + policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate) + + log.Debugf("New private channel update applied: %v", + newLogClosure(func() string { return spew.Sdump(msg) })) + + return true +} + +// GetAdditionalEdgePolicy uses the public key and channel ID to query the +// ephemeral channel edge policy for additional edges. Returns a nil if nothing +// found. +func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, + channelID uint64) *channeldb.ChannelEdgePolicy { + + target := route.NewVertex(pubKey) + + edges, ok := p.additionalEdges[target] + if !ok { + return nil + } + + for _, edge := range edges { + if edge.ChannelID != channelID { + continue + } + + return edge + } + + return nil +} diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 2fa103d38..edc4515b5 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -2,10 +2,13 @@ package routing import ( "testing" + "time" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/zpay32" "github.com/stretchr/testify/require" ) @@ -70,6 +73,109 @@ func TestValidateCLTVLimit(t *testing.T) { } } +// TestUpdateAdditionalEdge checks that we can update the additional edges as +// expected. +func TestUpdateAdditionalEdge(t *testing.T) { + + var ( + testChannelID = uint64(12345) + oldFeeBaseMSat = uint32(1000) + newFeeBaseMSat = uint32(1100) + oldExpiryDelta = uint16(100) + newExpiryDelta = uint16(120) + + payHash lntypes.Hash + ) + + // Create a minimal test node using the private key priv1. + pub := priv1.PubKey().SerializeCompressed() + testNode := &channeldb.LightningNode{} + copy(testNode.PubKeyBytes[:], pub) + + nodeID, err := testNode.PubKey() + require.NoError(t, err, "failed to get node id") + + // Create a payment with a route hint. + payment := &LightningPayment{ + Target: testNode.PubKeyBytes, + Amount: 1000, + RouteHints: [][]zpay32.HopHint{{ + zpay32.HopHint{ + // The nodeID is actually the target itself. It + // doesn't matter as we are not doing routing + // in this test. + NodeID: nodeID, + ChannelID: testChannelID, + FeeBaseMSat: oldFeeBaseMSat, + CLTVExpiryDelta: oldExpiryDelta, + }, + }}, + paymentHash: &payHash, + } + + // Create the paymentsession. + session, err := newPaymentSession( + payment, + func() (map[uint64]lnwire.MilliSatoshi, + error) { + + return nil, nil + }, + func() (routingGraph, func(), error) { + return &sessionGraph{}, func() {}, nil + }, + &MissionControl{}, + PathFindingConfig{}, + ) + require.NoError(t, err, "failed to create payment session") + + // We should have 1 additional edge. + require.Equal(t, 1, len(session.additionalEdges)) + + // The edge should use nodeID as key, and its value should have 1 edge + // policy. + vertex := route.NewVertex(nodeID) + policies, ok := session.additionalEdges[vertex] + require.True(t, ok, "cannot find policy") + require.Equal(t, 1, len(policies), "should have 1 edge policy") + + // Check that the policy has been created as expected. + policy := policies[0] + require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") + require.Equal(t, + oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", + ) + require.Equal(t, + lnwire.MilliSatoshi(oldFeeBaseMSat), + policy.FeeBaseMSat, "fee base msat mismatch", + ) + + // Create the channel update message and sign. + msg := &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID), + Timestamp: uint32(time.Now().Unix()), + BaseFee: newFeeBaseMSat, + TimeLockDelta: newExpiryDelta, + } + signErrChanUpdate(t, priv1, msg) + + // Apply the update. + require.True(t, + session.UpdateAdditionalEdge(msg, nodeID, policy), + "failed to update additional edge", + ) + + // Check that the policy has been updated as expected. + require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") + require.Equal(t, + newExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", + ) + require.Equal(t, + lnwire.MilliSatoshi(newFeeBaseMSat), + policy.FeeBaseMSat, "fee base msat mismatch", + ) +} + func TestRequestRoute(t *testing.T) { const ( height = 10