routing+lnd: prepare closed channel SCIDs in server

The method `FetchClosedChannels` sometimes prematurely mark a pending
force closing channel as finalized, therefore we need to furthur check
`FetchPendingChannels` to make sure the channel is indeed finalized.
This commit is contained in:
yyforyongyu 2023-11-17 21:07:53 +08:00 committed by yyforyongyu
parent e8f292edf4
commit 188aa9a4d4
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 83 additions and 53 deletions

View File

@ -286,12 +286,10 @@ type Config struct {
// graph that we received from a payment failure. // graph that we received from a payment failure.
ApplyChannelUpdate func(msg *lnwire.ChannelUpdate) bool ApplyChannelUpdate func(msg *lnwire.ChannelUpdate) bool
// FetchClosedChannels is used by the router to fetch closed channels. // ClosedSCIDs is used by the router to fetch closed channels.
// //
// TODO(yy): remove this method once the root cause of stuck payments // TODO(yy): remove it once the root cause of stuck payments is found.
// is found. ClosedSCIDs map[lnwire.ShortChannelID]struct{}
FetchClosedChannels func(pendingOnly bool) (
[]*channeldb.ChannelCloseSummary, error)
} }
// EdgeLocator is a struct used to identify a specific edge. // EdgeLocator is a struct used to identify a specific edge.
@ -1391,19 +1389,6 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi],
// resumePayments fetches inflight payments and resumes their payment // resumePayments fetches inflight payments and resumes their payment
// lifecycles. // lifecycles.
func (r *ChannelRouter) resumePayments() error { func (r *ChannelRouter) resumePayments() error {
// Get a list of closed channels.
channels, err := r.cfg.FetchClosedChannels(false)
if err != nil {
return err
}
closedSCIDs := make(map[lnwire.ShortChannelID]struct{}, len(channels))
for _, c := range channels {
if !c.IsPending {
closedSCIDs[c.ShortChanID] = struct{}{}
}
}
// Get all payments that are inflight. // Get all payments that are inflight.
payments, err := r.cfg.Control.FetchInFlightPayments() payments, err := r.cfg.Control.FetchInFlightPayments()
if err != nil { if err != nil {
@ -1422,9 +1407,7 @@ func (r *ChannelRouter) resumePayments() error {
// Try to fail the attempt if the route contains a dead // Try to fail the attempt if the route contains a dead
// channel. // channel.
r.failStaleAttempt( r.failStaleAttempt(a, p.Info.PaymentIdentifier)
a, p.Info.PaymentIdentifier, closedSCIDs,
)
} }
} }
@ -1510,7 +1493,7 @@ func (r *ChannelRouter) resumePayments() error {
// - https://github.com/lightningnetwork/lnd/issues/8146 // - https://github.com/lightningnetwork/lnd/issues/8146
// - https://github.com/lightningnetwork/lnd/pull/8174 // - https://github.com/lightningnetwork/lnd/pull/8174
func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt, func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt,
payHash lntypes.Hash, closedSCIDs map[lnwire.ShortChannelID]struct{}) { payHash lntypes.Hash) {
// We can only fail inflight HTLCs so we skip the settled/failed ones. // We can only fail inflight HTLCs so we skip the settled/failed ones.
if a.Failure != nil || a.Settle != nil { if a.Failure != nil || a.Settle != nil {
@ -1571,12 +1554,12 @@ func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt,
} }
// The channel link is not active, we now check whether this // The channel link is not active, we now check whether this
// channel is already closed. If so, we fail it as there's no // channel is already closed. If so, we fail the HTLC attempt
// need to wait for the network result because it won't be // as there's no need to wait for its network result because
// re-sent. If the channel is still pending, we'll keep waiting // there's no link. If the channel is still pending, we'll keep
// for the result as we may get a contract resolution for this // waiting for the result as we may get a contract resolution
// HTLC. // for this HTLC.
if _, ok := closedSCIDs[scid]; ok { if _, ok := r.cfg.ClosedSCIDs[scid]; ok {
shouldFail = true shouldFail = true
} }
} }

View File

@ -93,9 +93,7 @@ func (c *testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 {
return channelID return channelID
} }
func mockFetchClosedChannels(_ bool) ([]*channeldb.ChannelCloseSummary, error) { var mockClosedSCIDs map[lnwire.ShortChannelID]struct{}
return nil, nil
}
func createTestCtxFromGraphInstance(t *testing.T, startingHeight uint32, func createTestCtxFromGraphInstance(t *testing.T, startingHeight uint32,
graphInstance *testGraphInstance) *testCtx { graphInstance *testGraphInstance) *testCtx {
@ -162,10 +160,10 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
next := atomic.AddUint64(&uniquePaymentID, 1) next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil return next, nil
}, },
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
Clock: clock.NewTestClock(time.Unix(1, 0)), Clock: clock.NewTestClock(time.Unix(1, 0)),
ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate,
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}) })
require.NoError(t, router.Start(), "unable to start router") require.NoError(t, router.Start(), "unable to start router")
@ -2175,7 +2173,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
return 0, nil return 0, nil
}, },
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}} }}
// Register mockers with the expected method calls. // Register mockers with the expected method calls.
@ -2259,7 +2257,7 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) {
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
return 0, nil return 0, nil
}, },
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}} }}
// Expect an error to be returned. // Expect an error to be returned.
@ -2314,7 +2312,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
return 0, nil return 0, nil
}, },
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}} }}
// Create the error to be returned. // Create the error to be returned.
@ -2397,7 +2395,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
return 0, nil return 0, nil
}, },
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}} }}
// Create the error to be returned. // Create the error to be returned.
@ -2484,7 +2482,7 @@ func TestSendToRouteTempFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
return 0, nil return 0, nil
}, },
FetchClosedChannels: mockFetchClosedChannels, ClosedSCIDs: mockClosedSCIDs,
}} }}
// Create the error to be returned. // Create the error to be returned.

View File

@ -993,19 +993,19 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
} }
s.chanRouter, err = routing.New(routing.Config{ s.chanRouter, err = routing.New(routing.Config{
SelfNode: selfNode.PubKeyBytes, SelfNode: selfNode.PubKeyBytes,
RoutingGraph: graphsession.NewRoutingGraph(chanGraph), RoutingGraph: graphsession.NewRoutingGraph(chanGraph),
Chain: cc.ChainIO, Chain: cc.ChainIO,
Payer: s.htlcSwitch, Payer: s.htlcSwitch,
Control: s.controlTower, Control: s.controlTower,
MissionControl: s.missionControl, MissionControl: s.missionControl,
SessionSource: paymentSessionSource, SessionSource: paymentSessionSource,
GetLink: s.htlcSwitch.GetLinkByShortID, GetLink: s.htlcSwitch.GetLinkByShortID,
NextPaymentID: sequencer.NextID, NextPaymentID: sequencer.NextID,
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
Clock: clock.NewDefaultClock(), Clock: clock.NewDefaultClock(),
ApplyChannelUpdate: s.graphBuilder.ApplyChannelUpdate, ApplyChannelUpdate: s.graphBuilder.ApplyChannelUpdate,
FetchClosedChannels: s.chanStateDB.FetchClosedChannels, ClosedSCIDs: s.fetchClosedChannelSCIDs(),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("can't create router: %w", err) return nil, fmt.Errorf("can't create router: %w", err)
@ -4830,3 +4830,52 @@ func shouldPeerBootstrap(cfg *Config) bool {
// covering the bootstrapping process. // covering the bootstrapping process.
return !cfg.NoNetBootstrap && !isDevNetwork return !cfg.NoNetBootstrap && !isDevNetwork
} }
// fetchClosedChannelSCIDs returns a set of SCIDs that have their force closing
// finished.
func (s *server) fetchClosedChannelSCIDs() map[lnwire.ShortChannelID]struct{} {
// Get a list of closed channels.
channels, err := s.chanStateDB.FetchClosedChannels(false)
if err != nil {
srvrLog.Errorf("Failed to fetch closed channels: %v", err)
return nil
}
// Save the SCIDs in a map.
closedSCIDs := make(map[lnwire.ShortChannelID]struct{}, len(channels))
for _, c := range channels {
// If the channel is not pending, its FC has been finalized.
if !c.IsPending {
closedSCIDs[c.ShortChanID] = struct{}{}
}
}
// Double check whether the reported closed channel has indeed finished
// closing.
//
// NOTE: There are misalignments regarding when a channel's FC is
// marked as finalized. We double check the pending channels to make
// sure the returned SCIDs are indeed terminated.
//
// TODO(yy): fix the misalignments in `FetchClosedChannels`.
pendings, err := s.chanStateDB.FetchPendingChannels()
if err != nil {
srvrLog.Errorf("Failed to fetch pending channels: %v", err)
return nil
}
for _, c := range pendings {
if _, ok := closedSCIDs[c.ShortChannelID]; !ok {
continue
}
// If the channel is still reported as pending, remove it from
// the map.
delete(closedSCIDs, c.ShortChannelID)
srvrLog.Warnf("Channel=%v is prematurely marked as finalized",
c.ShortChannelID)
}
return closedSCIDs
}