multi: thread thru RPC caller context from CloseChannel

Olaoluwa Osuntokun 2025-02-10 18:34:34 -08:00
parent 357b94aa2c
commit 92747e839a
5 changed files with 59 additions and 49 deletions

htlcswitch/switch.go

@@ -2,6 +2,7 @@ package htlcswitch

 import (
     "bytes"
+    "context"
     "errors"
     "fmt"
     "math/rand"
@@ -125,6 +126,9 @@ type ChanClose struct {

     // Err is used by request creator to receive request execution error.
     Err chan error
+
+    // Ctx is a context linked to the lifetime of the caller.
+    Ctx context.Context //nolint:containedctx
 }

 // Config defines the configuration for the service. ALL elements within the
@@ -1413,7 +1417,7 @@ func (s *Switch) teardownCircuit(pkt *htlcPacket) error {
 // targetFeePerKw parameter should be the ideal fee-per-kw that will be used as
 // a starting point for close negotiation. The deliveryScript parameter is an
 // optional parameter which sets a user specified script to close out to.
-func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
+func (s *Switch) CloseLink(ctx context.Context, chanPoint *wire.OutPoint,
     closeType contractcourt.ChannelCloseType,
     targetFeePerKw, maxFee chainfee.SatPerKWeight,
     deliveryScript lnwire.DeliveryAddress) (chan interface{}, chan error) {
@@ -1427,9 +1431,10 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
         ChanPoint: chanPoint,
         Updates: updateChan,
         TargetFeePerKw: targetFeePerKw,
-        MaxFee: maxFee,
         DeliveryScript: deliveryScript,
         Err: errChan,
+        MaxFee: maxFee,
+        Ctx: ctx,
     }

     select {
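Nothing in this diff consumes the new ChanClose.Ctx field yet; this commit only threads it through from the caller. As a hedged illustration of what the field enables, a hypothetical consumer living alongside ChanClose in package htlcswitch could watch the RPC caller's lifetime next to its usual quit signal (handleCloseRequest is not part of this change):

package htlcswitch

import "errors"

// handleCloseRequest is a hypothetical consumer of a ChanClose request. It
// shows how the new Ctx field lets close processing observe the RPC caller
// going away, in addition to the subsystem's own shutdown signal.
func handleCloseRequest(req *ChanClose, quit <-chan struct{}) {
    select {
    case <-req.Ctx.Done():
        // The RPC caller (e.g. the CloseChannel stream) went away.
        req.Err <- req.Ctx.Err()

    case <-quit:
        // The subsystem itself is shutting down.
        req.Err <- errors.New("shutting down")
    }
}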

peer/brontide.go

@@ -582,8 +582,11 @@ type Brontide struct {
     globalMsgRouter bool

     startReady chan struct{}
-    quit chan struct{}
-    wg sync.WaitGroup
+
+    // cg is a helper that encapsulates a wait group and quit channel and
+    // allows contexts that either block or cancel on those depending on
+    // the use case.
+    cg *fn.ContextGuard

     // log is a peer-specific logging instance.
     log btclog.Logger
@@ -627,10 +630,10 @@ func NewBrontide(cfg Config) *Brontide {
         chanCloseMsgs: make(chan *closeMsg),
         resentChanSyncMsg: make(map[lnwire.ChannelID]struct{}),
         startReady: make(chan struct{}),
-        quit: make(chan struct{}),
         log: peerLog.WithPrefix(logPrefix),
         msgRouter: msgRouter,
         globalMsgRouter: globalMsgRouter,
+        cg: fn.NewContextGuard(),
     }

     if cfg.Conn != nil && cfg.Conn.RemoteAddr() != nil {
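For readers unfamiliar with fn.ContextGuard, the sketch below shows the lifecycle this commit standardizes on, built only from the methods visible in the diff (NewContextGuard, WgAdd, WgDone, WgWait, Done, Quit). The comments map each call back to the quit-channel-plus-WaitGroup pair it replaces; the exact import path of the fn package (it may carry a version suffix) is an assumption.

package main

import (
    "fmt"

    "github.com/lightningnetwork/lnd/fn"
)

// worker mirrors the goroutine pattern in brontide: it registers with the
// guard's wait group and exits once the guard signals quit.
func worker(cg *fn.ContextGuard, id int) {
    defer cg.WgDone() // was: defer p.wg.Done()

    <-cg.Done() // was: <-p.quit
    fmt.Printf("worker %d exiting\n", id)
}

func main() {
    // One value replaces both the quit channel and the wait group.
    cg := fn.NewContextGuard() // was: quit := make(chan struct{}); var wg sync.WaitGroup

    cg.WgAdd(2) // was: wg.Add(2)
    go worker(cg, 1)
    go worker(cg, 2)

    cg.Quit()   // was: close(quit)
    cg.WgWait() // was: wg.Wait()
}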
@@ -754,9 +757,9 @@ func (p *Brontide) Start() error {
     // message MUST be sent before any other message.
     readErr := make(chan error, 1)
     msgChan := make(chan lnwire.Message, 1)
-    p.wg.Add(1)
+    p.cg.WgAdd(1)
     go func() {
-        defer p.wg.Done()
+        defer p.cg.WgDone()

         msg, err := p.readNextMessage()
         if err != nil {
@@ -845,7 +848,7 @@ func (p *Brontide) Start() error {
         return fmt.Errorf("could not start ping manager %w", err)
     }

-    p.wg.Add(4)
+    p.cg.WgAdd(4)
     go p.queueHandler()
     go p.writeHandler()
     go p.channelManager()
@@ -865,7 +868,7 @@ func (p *Brontide) Start() error {
     //
     // TODO(wilmer): Remove this once we're able to query for node
     // announcements through their timestamps.
-    p.wg.Add(2)
+    p.cg.WgAdd(2)
     go p.maybeSendNodeAnn(activeChans)
     go p.maybeSendChannelUpdates()

@@ -914,7 +917,7 @@ func (p *Brontide) taprootShutdownAllowed() bool {
 //
 // NOTE: Part of the lnpeer.Peer interface.
 func (p *Brontide) QuitSignal() <-chan struct{} {
-    return p.quit
+    return p.cg.Done()
 }

 // addrWithInternalKey takes a delivery script, then attempts to supplement it
@@ -1276,7 +1279,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,

         select {
         case p.linkFailures <- failure:
-        case <-p.quit:
+        case <-p.cg.Done():
         case <-p.cfg.Quit:
         }
     }
@@ -1353,7 +1356,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
 // maybeSendNodeAnn sends our node announcement to the remote peer if at least
 // one confirmed public channel exists with them.
 func (p *Brontide) maybeSendNodeAnn(channels []*channeldb.OpenChannel) {
-    defer p.wg.Done()
+    defer p.cg.WgDone()

     hasConfirmedPublicChan := false
     for _, channel := range channels {
@@ -1385,7 +1388,7 @@ func (p *Brontide) maybeSendNodeAnn(channels []*channeldb.OpenChannel) {
 // maybeSendChannelUpdates sends our channel updates to the remote peer if we
 // have any active channels with them.
 func (p *Brontide) maybeSendChannelUpdates() {
-    defer p.wg.Done()
+    defer p.cg.WgDone()

     // If we don't have any active channels, then we can exit early.
     if p.activeChannels.Len() == 0 {
@@ -1461,16 +1464,16 @@ func (p *Brontide) WaitForDisconnect(ready chan struct{}) {
     // set of goroutines are already active.
     select {
     case <-p.startReady:
-    case <-p.quit:
+    case <-p.cg.Done():
         return
     }

     select {
     case <-ready:
-    case <-p.quit:
+    case <-p.cg.Done():
     }

-    p.wg.Wait()
+    p.cg.WgWait()
 }

 // Disconnect terminates the connection with the remote peer. Additionally, a
@@ -1492,7 +1495,7 @@ func (p *Brontide) Disconnect(reason error) {

         select {
         case <-p.startReady:
-        case <-p.quit:
+        case <-p.cg.Done():
             return
         }
     }
@@ -1508,7 +1511,7 @@ func (p *Brontide) Disconnect(reason error) {
     // Ensure that the TCP connection is properly closed before continuing.
     p.cfg.Conn.Close()

-    close(p.quit)
+    p.cg.Quit()

     // If our msg router isn't global (local to this instance), then we'll
     // stop it. Otherwise, we'll leave it running.
@@ -1689,7 +1692,7 @@ func (ms *msgStream) msgConsumer() {
             // Otherwise, we'll check the message queue for any new
             // items.
             select {
-            case <-ms.peer.quit:
+            case <-ms.peer.cg.Done():
                 ms.msgCond.L.Unlock()
                 return
             case <-ms.quit:
@@ -1716,7 +1719,7 @@ func (ms *msgStream) msgConsumer() {
         // grow indefinitely.
         select {
         case ms.producerSema <- struct{}{}:
-        case <-ms.peer.quit:
+        case <-ms.peer.cg.Done():
             return
         case <-ms.quit:
             return
@@ -1734,7 +1737,7 @@ func (ms *msgStream) AddMsg(msg lnwire.Message) {
     // we're signalled to quit, or a slot is freed up.
     select {
     case <-ms.producerSema:
-    case <-ms.peer.quit:
+    case <-ms.peer.cg.Done():
         return
     case <-ms.quit:
         return
@@ -1811,7 +1814,7 @@ func waitUntilLinkActive(p *Brontide,
             // calling function should catch it.
             return p.fetchLinkFromKeyAndCid(cid)

-        case <-p.quit:
+        case <-p.cg.Done():
             return nil
         }
     }
@@ -1845,7 +1848,7 @@ func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream {
             // as the peer is exiting, we'll check quickly to see
             // if we need to exit.
             select {
-            case <-p.quit:
+            case <-p.cg.Done():
                 return
             default:
             }
@@ -1885,7 +1888,7 @@ func newDiscMsgStream(p *Brontide) *msgStream {
 //
 // NOTE: This method MUST be run as a goroutine.
 func (p *Brontide) readHandler() {
-    defer p.wg.Done()
+    defer p.cg.WgDone()

     // We'll stop the timer after a new messages is received, and also
     // reset it after we process the next message.
@@ -2006,13 +2009,13 @@ out:
         case *lnwire.Shutdown:
             select {
             case p.chanCloseMsgs <- &closeMsg{msg.ChannelID, msg}:
-            case <-p.quit:
+            case <-p.cg.Done():
                 break out
             }

         case *lnwire.ClosingSigned:
             select {
             case p.chanCloseMsgs <- &closeMsg{msg.ChannelID, msg}:
-            case <-p.quit:
+            case <-p.cg.Done():
                 break out
             }
@@ -2584,7 +2587,7 @@ out:
                 break out
             }

-        case <-p.quit:
+        case <-p.cg.Done():
             exitErr = lnpeer.ErrPeerExiting
             break out
         }
@@ -2592,7 +2595,7 @@ out:

     // Avoid an exit deadlock by ensuring WaitGroups are decremented before
     // disconnect.
-    p.wg.Done()
+    p.cg.WgDone()

     p.Disconnect(exitErr)
@@ -2604,7 +2607,7 @@ out:
 //
 // NOTE: This method MUST be run as a goroutine.
 func (p *Brontide) queueHandler() {
-    defer p.wg.Done()
+    defer p.cg.WgDone()

     // priorityMsgs holds an in order list of messages deemed high-priority
     // to be added to the sendQueue. This predominately includes messages
@@ -2645,7 +2648,7 @@ func (p *Brontide) queueHandler() {
                 } else {
                     lazyMsgs.PushBack(msg)
                 }
-            case <-p.quit:
+            case <-p.cg.Done():
                 return
             }
         } else {
@@ -2659,7 +2662,7 @@ func (p *Brontide) queueHandler() {
                 } else {
                     lazyMsgs.PushBack(msg)
                 }
-            case <-p.quit:
+            case <-p.cg.Done():
                 return
             }
         }
@@ -2693,7 +2696,7 @@ func (p *Brontide) queue(priority bool, msg lnwire.Message,
     select {
     case p.outgoingQueue <- outgoingMsg{priority, msg, errChan}:
-    case <-p.quit:
+    case <-p.cg.Done():
         p.log.Tracef("Peer shutting down, could not enqueue msg: %v.",
             spew.Sdump(msg))

         if errChan != nil {
@@ -2761,7 +2764,7 @@ func (p *Brontide) genDeliveryScript() ([]byte, error) {
 //
 // NOTE: This method MUST be run as a goroutine.
 func (p *Brontide) channelManager() {
-    defer p.wg.Done()
+    defer p.cg.WgDone()

     // reenableTimeout will fire once after the configured channel status
     // interval has elapsed. This will trigger us to sign new channel
@@ -2837,7 +2840,7 @@ out:
                 p.channelEventClient.Cancel()
             }

-        case <-p.quit:
+        case <-p.cg.Done():
             // As, we've been signalled to exit, we'll reset all
             // our active channel back to their default state.
             p.activeChannels.ForEach(func(_ lnwire.ChannelID,
@@ -3120,7 +3123,7 @@ func (p *Brontide) retryRequestEnable(activeChans map[wire.OutPoint]struct{}) {
             p.log.Warnf("Re-enable channel %v failed, received "+
                 "inactive link event", chanPoint)

-        case <-p.quit:
+        case <-p.cg.Done():
             p.log.Debugf("Peer shutdown during retry enabling")
             return
         }
@@ -3291,7 +3294,6 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel,
                 return p.cfg.DisconnectPeer(p.IdentityKey())
             },
             ChainParams: &p.cfg.Wallet.Cfg.NetParams,
-            Quit: p.quit,
         },
         *deliveryScript,
         fee,
@@ -3856,7 +3858,7 @@ func (p *Brontide) sendMessage(sync, priority bool, msgs ...lnwire.Message) erro
     select {
     case err := <-errChan:
         return err
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     case <-p.cfg.Quit:
         return lnpeer.ErrPeerExiting
@@ -3904,7 +3906,7 @@ func (p *Brontide) AddNewChannel(newChan *lnpeer.NewChannel,
     case p.newActiveChannel <- newChanMsg:
     case <-cancel:
         return errors.New("canceled adding new channel")
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
@@ -3913,7 +3915,7 @@ func (p *Brontide) AddNewChannel(newChan *lnpeer.NewChannel,
     select {
     case err := <-errChan:
         return err
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
 }
@@ -3937,7 +3939,7 @@ func (p *Brontide) AddPendingChannel(cid lnwire.ChannelID,
     case <-cancel:
         return errors.New("canceled adding pending channel")
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
@@ -3951,7 +3953,7 @@ func (p *Brontide) AddPendingChannel(cid lnwire.ChannelID,
     case <-cancel:
         return errors.New("canceled adding pending channel")
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
 }
@@ -3968,7 +3970,7 @@ func (p *Brontide) RemovePendingChannel(cid lnwire.ChannelID) error {
     select {
     case p.removePendingChannel <- newChanMsg:
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
@@ -3979,7 +3981,7 @@ func (p *Brontide) RemovePendingChannel(cid lnwire.ChannelID) error {
     case err := <-errChan:
         return err
-    case <-p.quit:
+    case <-p.cg.Done():
         return lnpeer.ErrPeerExiting
     }
 }
@@ -4131,7 +4133,7 @@ func (p *Brontide) HandleLocalCloseChanReqs(req *htlcswitch.ChanClose) {
     case p.localCloseChanReqs <- req:
         p.log.Info("Local close channel request is going to be " +
             "delivered to the peer")
-    case <-p.quit:
+    case <-p.cg.Done():
         p.log.Info("Unable to deliver local close channel request " +
             "to peer")
     }
@@ -4461,7 +4463,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {

         // Stop the stream when quit.
         go func() {
-            <-p.quit
+            <-p.cg.Done()
             chanStream.Stop()
         }()
     }

peer/test_utils.go

@@ -337,7 +337,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a,
     chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint())
     alicePeer.activeChannels.Store(chanID, channelAlice)

-    alicePeer.wg.Add(1)
+    alicePeer.cg.WgAdd(1)
     go alicePeer.channelManager()

     return &peerTestCtx{

rpcserver.go

@@ -2887,8 +2887,9 @@ func (r *rpcServer) CloseChannel(in *lnrpc.CloseChannelRequest,
         }

         updateChan, errChan = r.server.htlcSwitch.CloseLink(
-            chanPoint, contractcourt.CloseRegular, feeRate,
-            maxFee, deliveryScript,
+            updateStream.Context(), chanPoint,
+            contractcourt.CloseRegular, feeRate, maxFee,
+            deliveryScript,
         )
     }
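The practical effect of passing updateStream.Context() is that gRPC cancels that context when the CloseChannel client disconnects or the server stops, so the context handed to CloseLink tracks the caller's lifetime. Below is a minimal sketch of that caller-side shape, assuming only the exported lnd types that appear in this diff; the function, its package name, and the select loop are illustrative, not code from this change.

package rpcsketch

import (
    "context"

    "github.com/btcsuite/btcd/wire"
    "github.com/lightningnetwork/lnd/contractcourt"
    "github.com/lightningnetwork/lnd/htlcswitch"
    "github.com/lightningnetwork/lnd/lnwallet/chainfee"
    "github.com/lightningnetwork/lnd/lnwire"
)

// closeWithCallerContext is a hypothetical helper showing the caller-side
// pattern: the caller's context (here, a gRPC stream context) is threaded
// into the switch and also used to stop waiting once the caller goes away.
func closeWithCallerContext(ctx context.Context, sw *htlcswitch.Switch,
    chanPoint *wire.OutPoint, feeRate, maxFee chainfee.SatPerKWeight,
    deliveryScript lnwire.DeliveryAddress) error {

    updateChan, errChan := sw.CloseLink(
        ctx, chanPoint, contractcourt.CloseRegular, feeRate, maxFee,
        deliveryScript,
    )

    for {
        select {
        case <-updateChan:
            // Forward close updates to the RPC client (elided).

        case err := <-errChan:
            return err

        case <-ctx.Done():
            // The RPC caller went away; stop waiting on the close flow.
            return ctx.Err()
        }
    }
}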

server.go

@@ -1236,7 +1236,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
         // Instruct the switch to close the channel. Provide no close out
         // delivery script or target fee per kw because user input is not
         // available when the remote peer closes the channel.
-        s.htlcSwitch.CloseLink(chanPoint, closureType, 0, 0, nil)
+        s.htlcSwitch.CloseLink(
+            context.Background(), chanPoint, closureType, 0, 0, nil,
+        )
     }

     // We will use the following channel to reliably hand off contract