From 92747e839a54c7f719af9358efe690341a2f5c4e Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 10 Feb 2025 18:34:34 -0800 Subject: [PATCH] multi: thread thru RPC caller context from CloseChannel --- htlcswitch/switch.go | 9 ++++- peer/brontide.go | 88 ++++++++++++++++++++++---------------------- peer/test_utils.go | 2 +- rpcserver.go | 5 ++- server.go | 4 +- 5 files changed, 59 insertions(+), 49 deletions(-) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index cf53f728a..fea5496c9 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,6 +2,7 @@ package htlcswitch import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -125,6 +126,9 @@ type ChanClose struct { // Err is used by request creator to receive request execution error. Err chan error + + // Ctx is a context linked to the lifetime of the caller. + Ctx context.Context //nolint:containedctx } // Config defines the configuration for the service. ALL elements within the @@ -1413,7 +1417,7 @@ func (s *Switch) teardownCircuit(pkt *htlcPacket) error { // targetFeePerKw parameter should be the ideal fee-per-kw that will be used as // a starting point for close negotiation. The deliveryScript parameter is an // optional parameter which sets a user specified script to close out to. -func (s *Switch) CloseLink(chanPoint *wire.OutPoint, +func (s *Switch) CloseLink(ctx context.Context, chanPoint *wire.OutPoint, closeType contractcourt.ChannelCloseType, targetFeePerKw, maxFee chainfee.SatPerKWeight, deliveryScript lnwire.DeliveryAddress) (chan interface{}, chan error) { @@ -1427,9 +1431,10 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, ChanPoint: chanPoint, Updates: updateChan, TargetFeePerKw: targetFeePerKw, - MaxFee: maxFee, DeliveryScript: deliveryScript, Err: errChan, + MaxFee: maxFee, + Ctx: ctx, } select { diff --git a/peer/brontide.go b/peer/brontide.go index f0399b4e8..8686fcd3b 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -582,8 +582,11 @@ type Brontide struct { globalMsgRouter bool startReady chan struct{} - quit chan struct{} - wg sync.WaitGroup + + // cg is a helper that encapsulates a wait group and quit channel and + // allows contexts that either block or cancel on those depending on + // the use case. + cg *fn.ContextGuard // log is a peer-specific logging instance. log btclog.Logger @@ -627,10 +630,10 @@ func NewBrontide(cfg Config) *Brontide { chanCloseMsgs: make(chan *closeMsg), resentChanSyncMsg: make(map[lnwire.ChannelID]struct{}), startReady: make(chan struct{}), - quit: make(chan struct{}), log: peerLog.WithPrefix(logPrefix), msgRouter: msgRouter, globalMsgRouter: globalMsgRouter, + cg: fn.NewContextGuard(), } if cfg.Conn != nil && cfg.Conn.RemoteAddr() != nil { @@ -754,9 +757,9 @@ func (p *Brontide) Start() error { // message MUST be sent before any other message. readErr := make(chan error, 1) msgChan := make(chan lnwire.Message, 1) - p.wg.Add(1) + p.cg.WgAdd(1) go func() { - defer p.wg.Done() + defer p.cg.WgDone() msg, err := p.readNextMessage() if err != nil { @@ -845,7 +848,7 @@ func (p *Brontide) Start() error { return fmt.Errorf("could not start ping manager %w", err) } - p.wg.Add(4) + p.cg.WgAdd(4) go p.queueHandler() go p.writeHandler() go p.channelManager() @@ -865,7 +868,7 @@ func (p *Brontide) Start() error { // // TODO(wilmer): Remove this once we're able to query for node // announcements through their timestamps. - p.wg.Add(2) + p.cg.WgAdd(2) go p.maybeSendNodeAnn(activeChans) go p.maybeSendChannelUpdates() @@ -914,7 +917,7 @@ func (p *Brontide) taprootShutdownAllowed() bool { // // NOTE: Part of the lnpeer.Peer interface. func (p *Brontide) QuitSignal() <-chan struct{} { - return p.quit + return p.cg.Done() } // addrWithInternalKey takes a delivery script, then attempts to supplement it @@ -1276,7 +1279,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, select { case p.linkFailures <- failure: - case <-p.quit: + case <-p.cg.Done(): case <-p.cfg.Quit: } } @@ -1353,7 +1356,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, // maybeSendNodeAnn sends our node announcement to the remote peer if at least // one confirmed public channel exists with them. func (p *Brontide) maybeSendNodeAnn(channels []*channeldb.OpenChannel) { - defer p.wg.Done() + defer p.cg.WgDone() hasConfirmedPublicChan := false for _, channel := range channels { @@ -1385,7 +1388,7 @@ func (p *Brontide) maybeSendNodeAnn(channels []*channeldb.OpenChannel) { // maybeSendChannelUpdates sends our channel updates to the remote peer if we // have any active channels with them. func (p *Brontide) maybeSendChannelUpdates() { - defer p.wg.Done() + defer p.cg.WgDone() // If we don't have any active channels, then we can exit early. if p.activeChannels.Len() == 0 { @@ -1461,16 +1464,16 @@ func (p *Brontide) WaitForDisconnect(ready chan struct{}) { // set of goroutines are already active. select { case <-p.startReady: - case <-p.quit: + case <-p.cg.Done(): return } select { case <-ready: - case <-p.quit: + case <-p.cg.Done(): } - p.wg.Wait() + p.cg.WgWait() } // Disconnect terminates the connection with the remote peer. Additionally, a @@ -1492,7 +1495,7 @@ func (p *Brontide) Disconnect(reason error) { select { case <-p.startReady: - case <-p.quit: + case <-p.cg.Done(): return } } @@ -1508,7 +1511,7 @@ func (p *Brontide) Disconnect(reason error) { // Ensure that the TCP connection is properly closed before continuing. p.cfg.Conn.Close() - close(p.quit) + p.cg.Quit() // If our msg router isn't global (local to this instance), then we'll // stop it. Otherwise, we'll leave it running. @@ -1689,7 +1692,7 @@ func (ms *msgStream) msgConsumer() { // Otherwise, we'll check the message queue for any new // items. select { - case <-ms.peer.quit: + case <-ms.peer.cg.Done(): ms.msgCond.L.Unlock() return case <-ms.quit: @@ -1716,7 +1719,7 @@ func (ms *msgStream) msgConsumer() { // grow indefinitely. select { case ms.producerSema <- struct{}{}: - case <-ms.peer.quit: + case <-ms.peer.cg.Done(): return case <-ms.quit: return @@ -1734,7 +1737,7 @@ func (ms *msgStream) AddMsg(msg lnwire.Message) { // we're signalled to quit, or a slot is freed up. select { case <-ms.producerSema: - case <-ms.peer.quit: + case <-ms.peer.cg.Done(): return case <-ms.quit: return @@ -1811,7 +1814,7 @@ func waitUntilLinkActive(p *Brontide, // calling function should catch it. return p.fetchLinkFromKeyAndCid(cid) - case <-p.quit: + case <-p.cg.Done(): return nil } } @@ -1845,7 +1848,7 @@ func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream { // as the peer is exiting, we'll check quickly to see // if we need to exit. select { - case <-p.quit: + case <-p.cg.Done(): return default: } @@ -1885,7 +1888,7 @@ func newDiscMsgStream(p *Brontide) *msgStream { // // NOTE: This method MUST be run as a goroutine. func (p *Brontide) readHandler() { - defer p.wg.Done() + defer p.cg.WgDone() // We'll stop the timer after a new messages is received, and also // reset it after we process the next message. @@ -2006,13 +2009,13 @@ out: case *lnwire.Shutdown: select { case p.chanCloseMsgs <- &closeMsg{msg.ChannelID, msg}: - case <-p.quit: + case <-p.cg.Done(): break out } case *lnwire.ClosingSigned: select { case p.chanCloseMsgs <- &closeMsg{msg.ChannelID, msg}: - case <-p.quit: + case <-p.cg.Done(): break out } @@ -2584,7 +2587,7 @@ out: break out } - case <-p.quit: + case <-p.cg.Done(): exitErr = lnpeer.ErrPeerExiting break out } @@ -2592,7 +2595,7 @@ out: // Avoid an exit deadlock by ensuring WaitGroups are decremented before // disconnect. - p.wg.Done() + p.cg.WgDone() p.Disconnect(exitErr) @@ -2604,7 +2607,7 @@ out: // // NOTE: This method MUST be run as a goroutine. func (p *Brontide) queueHandler() { - defer p.wg.Done() + defer p.cg.WgDone() // priorityMsgs holds an in order list of messages deemed high-priority // to be added to the sendQueue. This predominately includes messages @@ -2645,7 +2648,7 @@ func (p *Brontide) queueHandler() { } else { lazyMsgs.PushBack(msg) } - case <-p.quit: + case <-p.cg.Done(): return } } else { @@ -2659,7 +2662,7 @@ func (p *Brontide) queueHandler() { } else { lazyMsgs.PushBack(msg) } - case <-p.quit: + case <-p.cg.Done(): return } } @@ -2693,7 +2696,7 @@ func (p *Brontide) queue(priority bool, msg lnwire.Message, select { case p.outgoingQueue <- outgoingMsg{priority, msg, errChan}: - case <-p.quit: + case <-p.cg.Done(): p.log.Tracef("Peer shutting down, could not enqueue msg: %v.", spew.Sdump(msg)) if errChan != nil { @@ -2761,7 +2764,7 @@ func (p *Brontide) genDeliveryScript() ([]byte, error) { // // NOTE: This method MUST be run as a goroutine. func (p *Brontide) channelManager() { - defer p.wg.Done() + defer p.cg.WgDone() // reenableTimeout will fire once after the configured channel status // interval has elapsed. This will trigger us to sign new channel @@ -2837,7 +2840,7 @@ out: p.channelEventClient.Cancel() } - case <-p.quit: + case <-p.cg.Done(): // As, we've been signalled to exit, we'll reset all // our active channel back to their default state. p.activeChannels.ForEach(func(_ lnwire.ChannelID, @@ -3120,7 +3123,7 @@ func (p *Brontide) retryRequestEnable(activeChans map[wire.OutPoint]struct{}) { p.log.Warnf("Re-enable channel %v failed, received "+ "inactive link event", chanPoint) - case <-p.quit: + case <-p.cg.Done(): p.log.Debugf("Peer shutdown during retry enabling") return } @@ -3291,7 +3294,6 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, return p.cfg.DisconnectPeer(p.IdentityKey()) }, ChainParams: &p.cfg.Wallet.Cfg.NetParams, - Quit: p.quit, }, *deliveryScript, fee, @@ -3856,7 +3858,7 @@ func (p *Brontide) sendMessage(sync, priority bool, msgs ...lnwire.Message) erro select { case err := <-errChan: return err - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting case <-p.cfg.Quit: return lnpeer.ErrPeerExiting @@ -3904,7 +3906,7 @@ func (p *Brontide) AddNewChannel(newChan *lnpeer.NewChannel, case p.newActiveChannel <- newChanMsg: case <-cancel: return errors.New("canceled adding new channel") - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } @@ -3913,7 +3915,7 @@ func (p *Brontide) AddNewChannel(newChan *lnpeer.NewChannel, select { case err := <-errChan: return err - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } } @@ -3937,7 +3939,7 @@ func (p *Brontide) AddPendingChannel(cid lnwire.ChannelID, case <-cancel: return errors.New("canceled adding pending channel") - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } @@ -3951,7 +3953,7 @@ func (p *Brontide) AddPendingChannel(cid lnwire.ChannelID, case <-cancel: return errors.New("canceled adding pending channel") - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } } @@ -3968,7 +3970,7 @@ func (p *Brontide) RemovePendingChannel(cid lnwire.ChannelID) error { select { case p.removePendingChannel <- newChanMsg: - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } @@ -3979,7 +3981,7 @@ func (p *Brontide) RemovePendingChannel(cid lnwire.ChannelID) error { case err := <-errChan: return err - case <-p.quit: + case <-p.cg.Done(): return lnpeer.ErrPeerExiting } } @@ -4131,7 +4133,7 @@ func (p *Brontide) HandleLocalCloseChanReqs(req *htlcswitch.ChanClose) { case p.localCloseChanReqs <- req: p.log.Info("Local close channel request is going to be " + "delivered to the peer") - case <-p.quit: + case <-p.cg.Done(): p.log.Info("Unable to deliver local close channel request " + "to peer") } @@ -4461,7 +4463,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { // Stop the stream when quit. go func() { - <-p.quit + <-p.cg.Done() chanStream.Stop() }() } diff --git a/peer/test_utils.go b/peer/test_utils.go index 34c42e2f7..bcde83bf9 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -337,7 +337,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint()) alicePeer.activeChannels.Store(chanID, channelAlice) - alicePeer.wg.Add(1) + alicePeer.cg.WgAdd(1) go alicePeer.channelManager() return &peerTestCtx{ diff --git a/rpcserver.go b/rpcserver.go index 7236a5dad..7b5847fac 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -2887,8 +2887,9 @@ func (r *rpcServer) CloseChannel(in *lnrpc.CloseChannelRequest, } updateChan, errChan = r.server.htlcSwitch.CloseLink( - chanPoint, contractcourt.CloseRegular, feeRate, - maxFee, deliveryScript, + updateStream.Context(), chanPoint, + contractcourt.CloseRegular, feeRate, maxFee, + deliveryScript, ) } diff --git a/server.go b/server.go index ecb3eceae..04cafd265 100644 --- a/server.go +++ b/server.go @@ -1236,7 +1236,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Instruct the switch to close the channel. Provide no close out // delivery script or target fee per kw because user input is not // available when the remote peer closes the channel. - s.htlcSwitch.CloseLink(chanPoint, closureType, 0, 0, nil) + s.htlcSwitch.CloseLink( + context.Background(), chanPoint, closureType, 0, 0, nil, + ) } // We will use the following channel to reliably hand off contract