diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 5c16cbf0f..b87dddef1 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -932,7 +932,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, } // Apply channel update to the channel edge policy in our db. - if !p.router.applyChannelUpdate(update, errSource) { + if !p.router.applyChannelUpdate(update) { log.Debugf("Invalid channel update received: node=%v", errVertex) } diff --git a/routing/router.go b/routing/router.go index 348914db7..cf6357c68 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2393,16 +2393,32 @@ func (r *ChannelRouter) extractChannelUpdate( // applyChannelUpdate validates a channel update and if valid, applies it to the // database. It returns a bool indicating whether the updates were successful. -func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey) bool { - +func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate) bool { ch, _, _, err := r.GetChannelByID(msg.ShortChannelID) if err != nil { log.Errorf("Unable to retrieve channel by id: %v", err) return false } - if err := ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg); err != nil { + var pubKey *btcec.PublicKey + + switch msg.ChannelFlags & lnwire.ChanUpdateDirection { + case 0: + pubKey, _ = ch.NodeKey1() + + case 1: + pubKey, _ = ch.NodeKey2() + } + + // Exit early if the pubkey cannot be decided. + if pubKey == nil { + log.Errorf("Unable to decide pubkey with ChannelFlags=%v", + msg.ChannelFlags) + return false + } + + err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg) + if err != nil { log.Errorf("Unable to validate channel update: %v", err) return false } diff --git a/routing/router_test.go b/routing/router_test.go index 029db0c68..8c1edd253 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -385,7 +385,9 @@ func TestChannelUpdateValidation(t *testing.T) { }, 2), } - testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") + testGraph, err := createTestGraphFromChannels( + t, true, testChannels, "a", + ) require.NoError(t, err, "unable to create graph") const startingBlockHeight = 101 @@ -394,13 +396,13 @@ func TestChannelUpdateValidation(t *testing.T) { ) // Assert that the initially configured fee is retrieved correctly. - _, policy, _, err := ctx.router.GetChannelByID( - lnwire.NewShortChanIDFromInt(1)) + _, e1, e2, err := ctx.router.GetChannelByID( + lnwire.NewShortChanIDFromInt(1), + ) require.NoError(t, err, "cannot retrieve channel") - require.Equal(t, - feeRate, policy.FeeProportionalMillionths, "invalid fee", - ) + require.Equal(t, feeRate, e1.FeeProportionalMillionths, "invalid fee") + require.Equal(t, feeRate, e2.FeeProportionalMillionths, "invalid fee") // Setup a route from source a to destination c. The route will be used // in a call to SendToRoute. SendToRoute also applies channel updates, @@ -430,10 +432,13 @@ func TestChannelUpdateValidation(t *testing.T) { // returned to the sender. var invalidSignature [64]byte errChanUpdate := lnwire.ChannelUpdate{ - Signature: invalidSignature, - FeeRate: 500, - ShortChannelID: lnwire.NewShortChanIDFromInt(1), - Timestamp: uint32(testTime.Add(time.Minute).Unix()), + Signature: invalidSignature, + FeeRate: 500, + ShortChannelID: lnwire.NewShortChanIDFromInt(1), + Timestamp: uint32(testTime.Add(time.Minute).Unix()), + MessageFlags: e2.MessageFlags, + ChannelFlags: e2.ChannelFlags, + HtlcMaximumMsat: e2.MaxHTLC, } // We'll modify the SendToSwitch method so that it simulates a failed @@ -459,34 +464,34 @@ func TestChannelUpdateValidation(t *testing.T) { _, err = ctx.router.SendToRoute(payment, rt) require.Error(t, err, "expected route to fail with channel update") - _, policy, _, err = ctx.router.GetChannelByID( - lnwire.NewShortChanIDFromInt(1)) + _, e1, e2, err = ctx.router.GetChannelByID( + lnwire.NewShortChanIDFromInt(1), + ) require.NoError(t, err, "cannot retrieve channel") - require.Equal(t, - feeRate, policy.FeeProportionalMillionths, - "fee updated without valid signature", - ) + require.Equal(t, feeRate, e1.FeeProportionalMillionths, + "fee updated without valid signature") + require.Equal(t, feeRate, e2.FeeProportionalMillionths, + "fee updated without valid signature") // Next, add a signature to the channel update. signErrChanUpdate(t, testGraph.privKeyMap["b"], &errChanUpdate) // Retry the payment using the same route as before. _, err = ctx.router.SendToRoute(payment, rt) - if err == nil { - t.Fatalf("expected route to fail with channel update") - } + require.Error(t, err, "expected route to fail with channel update") // This time a valid signature was supplied and the policy change should // have been applied to the graph. - _, policy, _, err = ctx.router.GetChannelByID( - lnwire.NewShortChanIDFromInt(1)) + _, e1, e2, err = ctx.router.GetChannelByID( + lnwire.NewShortChanIDFromInt(1), + ) require.NoError(t, err, "cannot retrieve channel") - require.Equal(t, - lnwire.MilliSatoshi(500), policy.FeeProportionalMillionths, - "fee not updated even though signature is valid", - ) + require.Equal(t, feeRate, e1.FeeProportionalMillionths, + "fee should not be updated") + require.EqualValues(t, 500, int(e2.FeeProportionalMillionths), + "fee not updated even though signature is valid") } // TestSendPaymentErrorRepeatedFeeInsufficient tests that if we receive