From ac0c24aa7bd60c644739182f1feabc3885c70da7 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 12 Nov 2024 15:06:02 -0700 Subject: [PATCH] htlcswitch: don't pass pending update counts into quiescer This change simplifies some of the quiescer responsibilities in favor of making the link check whether or not it has a clean state to be able to send or receive an stfu. This change was made on the basis that the only use the quiescer makes of this information is to assess that it is or is not zero. Further the difficulty of checking this condition in the link is barely more burdensome than selecting the proper information to pass to the quiescer anyway. --- htlcswitch/link.go | 69 ++++++++-------- htlcswitch/quiescer.go | 66 +++++++-------- htlcswitch/quiescer_test.go | 159 +++++++----------------------------- 3 files changed, 93 insertions(+), 201 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index c62b6f163..3eb398c1a 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -1529,17 +1529,15 @@ func (l *channelLink) htlcManager() { case qReq := <-l.quiescenceReqs: l.quiescer.InitStfu(qReq) - pendingOnLocal := l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Local, - ) - pendingOnRemote := l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Remote, - ) - if err := l.quiescer.SendOwedStfu( - pendingOnLocal + pendingOnRemote, - ); err != nil { - l.stfuFailf("%s", err.Error()) - qReq.Resolve(fn.Err[lntypes.ChannelParty](err)) + if l.noDanglingUpdates(lntypes.Local) { + err := l.quiescer.SendOwedStfu() + if err != nil { + l.stfuFailf( + "SendOwedStfu: %s", err.Error(), + ) + res := fn.Err[lntypes.ChannelParty](err) + qReq.Resolve(res) + } } case <-l.Quit: @@ -2436,15 +2434,11 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // If we need to send out an Stfu, this would be the time to do // so. - pendingOnLocal := l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Local, - ) - pendingOnRemote := l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Remote, - ) - err = l.quiescer.SendOwedStfu(pendingOnLocal + pendingOnRemote) - if err != nil { - l.stfuFailf("sendOwedStfu: %v", err.Error()) + if l.noDanglingUpdates(lntypes.Local) { + err = l.quiescer.SendOwedStfu() + if err != nil { + l.stfuFailf("sendOwedStfu: %v", err.Error()) + } } // Now that we have finished processing the incoming CommitSig @@ -2635,26 +2629,20 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // handleStfu implements the top-level logic for handling the Stfu message from // our peer. func (l *channelLink) handleStfu(stfu *lnwire.Stfu) error { - pendingOnLocal := l.channel.NumPendingUpdates( - lntypes.Remote, lntypes.Local, - ) - pendingOnRemote := l.channel.NumPendingUpdates( - lntypes.Remote, lntypes.Remote, - ) - err := l.quiescer.RecvStfu(*stfu, pendingOnLocal+pendingOnRemote) + if !l.noDanglingUpdates(lntypes.Remote) { + return ErrPendingRemoteUpdates + } + err := l.quiescer.RecvStfu(*stfu) if err != nil { return err } // If we can immediately send an Stfu response back, we will. - pendingOnLocal = l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Local, - ) - pendingOnRemote = l.channel.NumPendingUpdates( - lntypes.Local, lntypes.Remote, - ) + if l.noDanglingUpdates(lntypes.Local) { + return l.quiescer.SendOwedStfu() + } - return l.quiescer.SendOwedStfu(pendingOnLocal + pendingOnRemote) + return nil } // stfuFailf fails the link in the case where the requirements of the quiescence @@ -2669,6 +2657,19 @@ func (l *channelLink) stfuFailf(format string, args ...interface{}) { }, format, args...) } +// noDanglingUpdates returns true when there are 0 updates that were originally +// issued by whose on either the Local or Remote commitment transaction. +func (l *channelLink) noDanglingUpdates(whose lntypes.ChannelParty) bool { + pendingOnLocal := l.channel.NumPendingUpdates( + whose, lntypes.Local, + ) + pendingOnRemote := l.channel.NumPendingUpdates( + whose, lntypes.Remote, + ) + + return pendingOnLocal == 0 && pendingOnRemote == 0 +} + // ackDownStreamPackets is responsible for removing htlcs from a link's mailbox // for packets delivered from server, and cleaning up any circuits closed by // signing a previous commitment txn. This method ensures that the circuits are diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go index 4b3518ee2..7e6269b12 100644 --- a/htlcswitch/quiescer.go +++ b/htlcswitch/quiescer.go @@ -76,7 +76,7 @@ type Quiescer interface { InitStfu(req StfuReq) // RecvStfu is called when we receive an Stfu message from the remote. - RecvStfu(stfu lnwire.Stfu, numRemotePendingUpdates uint64) error + RecvStfu(stfu lnwire.Stfu) error // CanRecvUpdates returns true if we haven't yet received an Stfu which // would mark the end of the remote's ability to send updates. @@ -88,7 +88,7 @@ type Quiescer interface { // SendOwedStfu sends Stfu if it owes one. It returns an error if the // state machine is in an invalid state. - SendOwedStfu(numPendingLocalUpdates uint64) error + SendOwedStfu() error // OnResume accepts a no return closure that will run when the quiescer // is resumed. @@ -175,19 +175,15 @@ func NewQuiescer(cfg QuiescerCfg) Quiescer { } // RecvStfu is called when we receive an Stfu message from the remote. -func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu, - numPendingRemoteUpdates uint64) error { - +func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu) error { q.Lock() defer q.Unlock() - return q.recvStfu(msg, numPendingRemoteUpdates) + return q.recvStfu(msg) } // recvStfu is called when we receive an Stfu message from the remote. -func (q *QuiescerLive) recvStfu(msg lnwire.Stfu, - numPendingRemoteUpdates uint64) error { - +func (q *QuiescerLive) recvStfu(msg lnwire.Stfu) error { // At the time of this writing, this check that we have already received // an Stfu is not strictly necessary, according to the specification. // However, it is fishy if we do and it is unclear how we should handle @@ -203,7 +199,7 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu, q.cfg.chanID) } - if !q.canRecvStfu(numPendingRemoteUpdates) { + if !q.canRecvStfu() { return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates, q.cfg.chanID) } @@ -228,26 +224,22 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu, // MakeStfu is called when we are ready to send an Stfu message. It returns the // Stfu message to be sent. -func (q *QuiescerLive) MakeStfu( - numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] { - +func (q *QuiescerLive) MakeStfu() fn.Result[lnwire.Stfu] { q.RLock() defer q.RUnlock() - return q.makeStfu(numPendingLocalUpdates) + return q.makeStfu() } // makeStfu is called when we are ready to send an Stfu message. It returns the // Stfu message to be sent. -func (q *QuiescerLive) makeStfu( - numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] { - +func (q *QuiescerLive) makeStfu() fn.Result[lnwire.Stfu] { if q.sent { return fn.Errf[lnwire.Stfu]("%w for channel %v", ErrStfuAlreadySent, q.cfg.chanID) } - if !q.canSendStfu(numPendingLocalUpdates) { + if !q.canSendStfu() { return fn.Errf[lnwire.Stfu]("%w for channel %v", ErrPendingLocalUpdates, q.cfg.chanID) } @@ -380,44 +372,44 @@ func (q *QuiescerLive) CanSendStfu(numPendingLocalUpdates uint64) bool { q.RLock() defer q.RUnlock() - return q.canSendStfu(numPendingLocalUpdates) + return q.canSendStfu() } // canSendStfu returns true if we can send an Stfu. -func (q *QuiescerLive) canSendStfu(numPendingLocalUpdates uint64) bool { - return numPendingLocalUpdates == 0 && !q.sent +func (q *QuiescerLive) canSendStfu() bool { + return !q.sent } // CanRecvStfu returns true if we can receive an Stfu. -func (q *QuiescerLive) CanRecvStfu(numPendingRemoteUpdates uint64) bool { +func (q *QuiescerLive) CanRecvStfu() bool { q.RLock() defer q.RUnlock() - return q.canRecvStfu(numPendingRemoteUpdates) + return q.canRecvStfu() } // canRecvStfu returns true if we can receive an Stfu. -func (q *QuiescerLive) canRecvStfu(numPendingRemoteUpdates uint64) bool { - return numPendingRemoteUpdates == 0 && !q.received +func (q *QuiescerLive) canRecvStfu() bool { + return !q.received } // SendOwedStfu sends Stfu if it owes one. It returns an error if the state // machine is in an invalid state. -func (q *QuiescerLive) SendOwedStfu(numPendingLocalUpdates uint64) error { +func (q *QuiescerLive) SendOwedStfu() error { q.Lock() defer q.Unlock() - return q.sendOwedStfu(numPendingLocalUpdates) + return q.sendOwedStfu() } // sendOwedStfu sends Stfu if it owes one. It returns an error if the state // machine is in an invalid state. -func (q *QuiescerLive) sendOwedStfu(numPendingLocalUpdates uint64) error { - if !q.oweStfu() || !q.canSendStfu(numPendingLocalUpdates) { +func (q *QuiescerLive) sendOwedStfu() error { + if !q.oweStfu() || !q.canSendStfu() { return nil } - err := q.makeStfu(numPendingLocalUpdates).Sink(q.cfg.sendMsg) + err := q.makeStfu().Sink(q.cfg.sendMsg) if err == nil { q.sent = true @@ -561,13 +553,13 @@ var _ Quiescer = (*quiescerNoop)(nil) func (q *quiescerNoop) InitStfu(req StfuReq) { req.Resolve(fn.Errf[lntypes.ChannelParty]("quiescence not supported")) } -func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu, _ uint64) error { return nil } -func (q *quiescerNoop) CanRecvUpdates() bool { return true } -func (q *quiescerNoop) CanSendUpdates() bool { return true } -func (q *quiescerNoop) SendOwedStfu(_ uint64) error { return nil } -func (q *quiescerNoop) IsQuiescent() bool { return false } -func (q *quiescerNoop) OnResume(hook func()) { hook() } -func (q *quiescerNoop) Resume() {} +func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu) error { return nil } +func (q *quiescerNoop) CanRecvUpdates() bool { return true } +func (q *quiescerNoop) CanSendUpdates() bool { return true } +func (q *quiescerNoop) SendOwedStfu() error { return nil } +func (q *quiescerNoop) IsQuiescent() bool { return false } +func (q *quiescerNoop) OnResume(hook func()) { hook() } +func (q *quiescerNoop) Resume() {} func (q *quiescerNoop) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] { return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator) } diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go index 08e201ddd..da08909d5 100644 --- a/htlcswitch/quiescer_test.go +++ b/htlcswitch/quiescer_test.go @@ -14,9 +14,8 @@ import ( var cid = lnwire.ChannelID(bytes.Repeat([]byte{0x00}, 32)) type quiescerTestHarness struct { - pendingUpdates lntypes.Dual[uint64] - quiescer *QuiescerLive - conn <-chan lnwire.Stfu + quiescer *QuiescerLive + conn <-chan lnwire.Stfu } func initQuiescerTestHarness( @@ -24,8 +23,7 @@ func initQuiescerTestHarness( conn := make(chan lnwire.Stfu, 1) harness := &quiescerTestHarness{ - pendingUpdates: lntypes.Dual[uint64]{}, - conn: conn, + conn: conn, } quiescer, _ := NewQuiescer(QuiescerCfg{ @@ -54,30 +52,12 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) { Initiator: true, } - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err := harness.quiescer.RecvStfu(msg) require.NoError(t, err) - err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err = harness.quiescer.RecvStfu(msg) require.Error(t, err, ErrStfuAlreadyRcvd) } -// TestQuiescerPendingUpdatesRecvInvalid ensures that we get an error if we -// receive the Stfu message while the Remote party has panding updates on the -// channel. -func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) { - t.Parallel() - - harness := initQuiescerTestHarness(lntypes.Local) - - msg := lnwire.Stfu{ - ChanID: cid, - Initiator: true, - } - - harness.pendingUpdates.SetForParty(lntypes.Remote, 1) - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) - require.ErrorIs(t, err, ErrPendingRemoteUpdates) -} - // TestQuiescenceRemoteInit ensures that we can successfully traverse the state // graph of quiescence beginning with the Remote party initiating quiescence. func TestQuiescenceRemoteInit(t *testing.T) { @@ -90,22 +70,10 @@ func TestQuiescenceRemoteInit(t *testing.T) { Initiator: true, } - harness.pendingUpdates.SetForParty(lntypes.Local, 1) - - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err := harness.quiescer.RecvStfu(msg) require.NoError(t, err) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) - require.NoError(t, err) - - select { - case <-harness.conn: - t.Fatalf("stfu sent when not expected") - default: - } - - harness.pendingUpdates.SetForParty(lntypes.Local, 0) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + err = harness.quiescer.SendOwedStfu() require.NoError(t, err) select { @@ -125,25 +93,13 @@ func TestQuiescenceLocalInit(t *testing.T) { ChanID: cid, Initiator: true, } - harness.pendingUpdates.SetForParty(lntypes.Local, 1) stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( fn.Unit{}, ) harness.quiescer.InitStfu(stfuReq) - harness.pendingUpdates.SetForParty(lntypes.Local, 1) - err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) - require.NoError(t, err) - - select { - case <-harness.conn: - t.Fatalf("stfu sent when not expected") - default: - } - - harness.pendingUpdates.SetForParty(lntypes.Local, 0) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + err := harness.quiescer.SendOwedStfu() require.NoError(t, err) select { @@ -153,7 +109,7 @@ func TestQuiescenceLocalInit(t *testing.T) { t.Fatalf("stfu not sent when expected") } - err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err = harness.quiescer.RecvStfu(msg) require.NoError(t, err) select { @@ -178,17 +134,11 @@ func TestQuiescenceInitiator(t *testing.T) { ChanID: cid, Initiator: true, } - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) require.True(t, harness.quiescer.QuiescenceInitiator().IsErr()) // Send - require.NoError( - t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local), - ) + require.NoError(t, harness.quiescer.SendOwedStfu()) require.Equal( t, harness.quiescer.QuiescenceInitiator(), fn.Ok(lntypes.Remote), @@ -214,7 +164,7 @@ func TestQuiescenceInitiator(t *testing.T) { } require.NoError( - t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local), + t, harness.quiescer.sendOwedStfu(), ) require.True(t, harness.quiescer.quiescenceInitiator().IsErr()) @@ -222,11 +172,7 @@ func TestQuiescenceInitiator(t *testing.T) { ChanID: cid, Initiator: false, } - require.NoError( - t, harness.quiescer.recvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) + require.NoError(t, harness.quiescer.recvStfu(msg)) require.True(t, harness.quiescer.quiescenceInitiator().IsOk()) select { @@ -249,11 +195,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) { ChanID: cid, Initiator: true, } - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) require.False(t, harness.quiescer.CanRecvUpdates()) } @@ -270,10 +212,10 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) { Initiator: true, } - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err := harness.quiescer.RecvStfu(msg) require.NoError(t, err) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + err = harness.quiescer.SendOwedStfu() require.NoError(t, err) require.False(t, harness.quiescer.CanSendUpdates()) @@ -293,11 +235,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) { } require.False(t, harness.quiescer.NeedStfu()) - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) require.False(t, harness.quiescer.NeedStfu()) } @@ -309,38 +247,15 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) { harness := initQuiescerTestHarness(lntypes.Local) - harness.pendingUpdates.SetForParty(lntypes.Local, 1) - - require.True( - t, harness.quiescer.MakeStfu( - harness.pendingUpdates.Local, - ).IsErr(), - ) - - harness.pendingUpdates.SetForParty(lntypes.Local, 0) msg := lnwire.Stfu{ ChanID: cid, Initiator: true, } - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) - require.True( - t, harness.quiescer.MakeStfu( - harness.pendingUpdates.Local, - ).IsOk(), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) + require.True(t, harness.quiescer.MakeStfu().IsOk()) - require.NoError( - t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local), - ) - require.True( - t, harness.quiescer.MakeStfu( - harness.pendingUpdates.Local, - ).IsErr(), - ) + require.NoError(t, harness.quiescer.SendOwedStfu()) + require.True(t, harness.quiescer.MakeStfu().IsErr()) } // TestQuiescerTieBreaker ensures that if both parties attempt to claim the @@ -364,16 +279,8 @@ func TestQuiescerTieBreaker(t *testing.T) { ) harness.quiescer.InitStfu(req) - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) - require.NoError( - t, harness.quiescer.SendOwedStfu( - harness.pendingUpdates.Local, - ), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) + require.NoError(t, harness.quiescer.SendOwedStfu()) select { case party := <-res: @@ -396,16 +303,8 @@ func TestQuiescerResume(t *testing.T) { Initiator: true, } - require.NoError( - t, harness.quiescer.RecvStfu( - msg, harness.pendingUpdates.Remote, - ), - ) - require.NoError( - t, harness.quiescer.SendOwedStfu( - harness.pendingUpdates.Local, - ), - ) + require.NoError(t, harness.quiescer.RecvStfu(msg)) + require.NoError(t, harness.quiescer.SendOwedStfu()) require.True(t, harness.quiescer.IsQuiescent()) var resumeHooksCalled = false @@ -434,9 +333,9 @@ func TestQuiescerTimeoutTriggers(t *testing.T) { harness.quiescer.cfg.timeoutDuration = time.Second harness.quiescer.cfg.onTimeout = func() { close(timeoutGate) } - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err := harness.quiescer.RecvStfu(msg) require.NoError(t, err) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + err = harness.quiescer.SendOwedStfu() require.NoError(t, err) select { @@ -461,9 +360,9 @@ func TestQuiescerTimeoutAborts(t *testing.T) { harness.quiescer.cfg.timeoutDuration = time.Second harness.quiescer.cfg.onTimeout = func() { close(timeoutGate) } - err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + err := harness.quiescer.RecvStfu(msg) require.NoError(t, err) - err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + err = harness.quiescer.SendOwedStfu() require.NoError(t, err) harness.quiescer.Resume()