htlcswitch: don't pass pending update counts into quiescer

This change simplifies some of the quiescer responsibilities in
favor of making the link check whether or not it has a clean state
to be able to send or receive an stfu. This change was made on the
basis that the only use the quiescer makes of this information is
to assess that it is or is not zero. Further the difficulty of
checking this condition in the link is barely more burdensome than
selecting the proper information to pass to the quiescer anyway.
This commit is contained in:
Keagan McClelland 2024-11-12 15:06:02 -07:00
parent a4c49a88f1
commit ac0c24aa7b
No known key found for this signature in database
GPG key ID: FA7E65C951F12439
3 changed files with 93 additions and 201 deletions

View file

@ -1529,17 +1529,15 @@ func (l *channelLink) htlcManager() {
case qReq := <-l.quiescenceReqs:
l.quiescer.InitStfu(qReq)
pendingOnLocal := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
if err := l.quiescer.SendOwedStfu(
pendingOnLocal + pendingOnRemote,
); err != nil {
l.stfuFailf("%s", err.Error())
qReq.Resolve(fn.Err[lntypes.ChannelParty](err))
if l.noDanglingUpdates(lntypes.Local) {
err := l.quiescer.SendOwedStfu()
if err != nil {
l.stfuFailf(
"SendOwedStfu: %s", err.Error(),
)
res := fn.Err[lntypes.ChannelParty](err)
qReq.Resolve(res)
}
}
case <-l.Quit:
@ -2436,15 +2434,11 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// If we need to send out an Stfu, this would be the time to do
// so.
pendingOnLocal := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
err = l.quiescer.SendOwedStfu(pendingOnLocal + pendingOnRemote)
if err != nil {
l.stfuFailf("sendOwedStfu: %v", err.Error())
if l.noDanglingUpdates(lntypes.Local) {
err = l.quiescer.SendOwedStfu()
if err != nil {
l.stfuFailf("sendOwedStfu: %v", err.Error())
}
}
// Now that we have finished processing the incoming CommitSig
@ -2635,26 +2629,20 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// handleStfu implements the top-level logic for handling the Stfu message from
// our peer.
func (l *channelLink) handleStfu(stfu *lnwire.Stfu) error {
pendingOnLocal := l.channel.NumPendingUpdates(
lntypes.Remote, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
lntypes.Remote, lntypes.Remote,
)
err := l.quiescer.RecvStfu(*stfu, pendingOnLocal+pendingOnRemote)
if !l.noDanglingUpdates(lntypes.Remote) {
return ErrPendingRemoteUpdates
}
err := l.quiescer.RecvStfu(*stfu)
if err != nil {
return err
}
// If we can immediately send an Stfu response back, we will.
pendingOnLocal = l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Local,
)
pendingOnRemote = l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
if l.noDanglingUpdates(lntypes.Local) {
return l.quiescer.SendOwedStfu()
}
return l.quiescer.SendOwedStfu(pendingOnLocal + pendingOnRemote)
return nil
}
// stfuFailf fails the link in the case where the requirements of the quiescence
@ -2669,6 +2657,19 @@ func (l *channelLink) stfuFailf(format string, args ...interface{}) {
}, format, args...)
}
// noDanglingUpdates returns true when there are 0 updates that were originally
// issued by whose on either the Local or Remote commitment transaction.
func (l *channelLink) noDanglingUpdates(whose lntypes.ChannelParty) bool {
pendingOnLocal := l.channel.NumPendingUpdates(
whose, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
whose, lntypes.Remote,
)
return pendingOnLocal == 0 && pendingOnRemote == 0
}
// ackDownStreamPackets is responsible for removing htlcs from a link's mailbox
// for packets delivered from server, and cleaning up any circuits closed by
// signing a previous commitment txn. This method ensures that the circuits are

View file

@ -76,7 +76,7 @@ type Quiescer interface {
InitStfu(req StfuReq)
// RecvStfu is called when we receive an Stfu message from the remote.
RecvStfu(stfu lnwire.Stfu, numRemotePendingUpdates uint64) error
RecvStfu(stfu lnwire.Stfu) error
// CanRecvUpdates returns true if we haven't yet received an Stfu which
// would mark the end of the remote's ability to send updates.
@ -88,7 +88,7 @@ type Quiescer interface {
// SendOwedStfu sends Stfu if it owes one. It returns an error if the
// state machine is in an invalid state.
SendOwedStfu(numPendingLocalUpdates uint64) error
SendOwedStfu() error
// OnResume accepts a no return closure that will run when the quiescer
// is resumed.
@ -175,19 +175,15 @@ func NewQuiescer(cfg QuiescerCfg) Quiescer {
}
// RecvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu) error {
q.Lock()
defer q.Unlock()
return q.recvStfu(msg, numPendingRemoteUpdates)
return q.recvStfu(msg)
}
// recvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
func (q *QuiescerLive) recvStfu(msg lnwire.Stfu) error {
// At the time of this writing, this check that we have already received
// an Stfu is not strictly necessary, according to the specification.
// However, it is fishy if we do and it is unclear how we should handle
@ -203,7 +199,7 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
q.cfg.chanID)
}
if !q.canRecvStfu(numPendingRemoteUpdates) {
if !q.canRecvStfu() {
return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates,
q.cfg.chanID)
}
@ -228,26 +224,22 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
// MakeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) MakeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
func (q *QuiescerLive) MakeStfu() fn.Result[lnwire.Stfu] {
q.RLock()
defer q.RUnlock()
return q.makeStfu(numPendingLocalUpdates)
return q.makeStfu()
}
// makeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) makeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
func (q *QuiescerLive) makeStfu() fn.Result[lnwire.Stfu] {
if q.sent {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrStfuAlreadySent, q.cfg.chanID)
}
if !q.canSendStfu(numPendingLocalUpdates) {
if !q.canSendStfu() {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrPendingLocalUpdates, q.cfg.chanID)
}
@ -380,44 +372,44 @@ func (q *QuiescerLive) CanSendStfu(numPendingLocalUpdates uint64) bool {
q.RLock()
defer q.RUnlock()
return q.canSendStfu(numPendingLocalUpdates)
return q.canSendStfu()
}
// canSendStfu returns true if we can send an Stfu.
func (q *QuiescerLive) canSendStfu(numPendingLocalUpdates uint64) bool {
return numPendingLocalUpdates == 0 && !q.sent
func (q *QuiescerLive) canSendStfu() bool {
return !q.sent
}
// CanRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) CanRecvStfu(numPendingRemoteUpdates uint64) bool {
func (q *QuiescerLive) CanRecvStfu() bool {
q.RLock()
defer q.RUnlock()
return q.canRecvStfu(numPendingRemoteUpdates)
return q.canRecvStfu()
}
// canRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) canRecvStfu(numPendingRemoteUpdates uint64) bool {
return numPendingRemoteUpdates == 0 && !q.received
func (q *QuiescerLive) canRecvStfu() bool {
return !q.received
}
// SendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) SendOwedStfu(numPendingLocalUpdates uint64) error {
func (q *QuiescerLive) SendOwedStfu() error {
q.Lock()
defer q.Unlock()
return q.sendOwedStfu(numPendingLocalUpdates)
return q.sendOwedStfu()
}
// sendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) sendOwedStfu(numPendingLocalUpdates uint64) error {
if !q.oweStfu() || !q.canSendStfu(numPendingLocalUpdates) {
func (q *QuiescerLive) sendOwedStfu() error {
if !q.oweStfu() || !q.canSendStfu() {
return nil
}
err := q.makeStfu(numPendingLocalUpdates).Sink(q.cfg.sendMsg)
err := q.makeStfu().Sink(q.cfg.sendMsg)
if err == nil {
q.sent = true
@ -561,13 +553,13 @@ var _ Quiescer = (*quiescerNoop)(nil)
func (q *quiescerNoop) InitStfu(req StfuReq) {
req.Resolve(fn.Errf[lntypes.ChannelParty]("quiescence not supported"))
}
func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu, _ uint64) error { return nil }
func (q *quiescerNoop) CanRecvUpdates() bool { return true }
func (q *quiescerNoop) CanSendUpdates() bool { return true }
func (q *quiescerNoop) SendOwedStfu(_ uint64) error { return nil }
func (q *quiescerNoop) IsQuiescent() bool { return false }
func (q *quiescerNoop) OnResume(hook func()) { hook() }
func (q *quiescerNoop) Resume() {}
func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu) error { return nil }
func (q *quiescerNoop) CanRecvUpdates() bool { return true }
func (q *quiescerNoop) CanSendUpdates() bool { return true }
func (q *quiescerNoop) SendOwedStfu() error { return nil }
func (q *quiescerNoop) IsQuiescent() bool { return false }
func (q *quiescerNoop) OnResume(hook func()) { hook() }
func (q *quiescerNoop) Resume() {}
func (q *quiescerNoop) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] {
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
}

View file

@ -14,9 +14,8 @@ import (
var cid = lnwire.ChannelID(bytes.Repeat([]byte{0x00}, 32))
type quiescerTestHarness struct {
pendingUpdates lntypes.Dual[uint64]
quiescer *QuiescerLive
conn <-chan lnwire.Stfu
quiescer *QuiescerLive
conn <-chan lnwire.Stfu
}
func initQuiescerTestHarness(
@ -24,8 +23,7 @@ func initQuiescerTestHarness(
conn := make(chan lnwire.Stfu, 1)
harness := &quiescerTestHarness{
pendingUpdates: lntypes.Dual[uint64]{},
conn: conn,
conn: conn,
}
quiescer, _ := NewQuiescer(QuiescerCfg{
@ -54,30 +52,12 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) {
Initiator: true,
}
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err := harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err = harness.quiescer.RecvStfu(msg)
require.Error(t, err, ErrStfuAlreadyRcvd)
}
// TestQuiescerPendingUpdatesRecvInvalid ensures that we get an error if we
// receive the Stfu message while the Remote party has panding updates on the
// channel.
func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Remote, 1)
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.ErrorIs(t, err, ErrPendingRemoteUpdates)
}
// TestQuiescenceRemoteInit ensures that we can successfully traverse the state
// graph of quiescence beginning with the Remote party initiating quiescence.
func TestQuiescenceRemoteInit(t *testing.T) {
@ -90,22 +70,10 @@ func TestQuiescenceRemoteInit(t *testing.T) {
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err := harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
err = harness.quiescer.SendOwedStfu()
require.NoError(t, err)
select {
@ -125,25 +93,13 @@ func TestQuiescenceLocalInit(t *testing.T) {
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(stfuReq)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
err := harness.quiescer.SendOwedStfu()
require.NoError(t, err)
select {
@ -153,7 +109,7 @@ func TestQuiescenceLocalInit(t *testing.T) {
t.Fatalf("stfu not sent when expected")
}
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err = harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
select {
@ -178,17 +134,11 @@ func TestQuiescenceInitiator(t *testing.T) {
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Send
require.NoError(
t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local),
)
require.NoError(t, harness.quiescer.SendOwedStfu())
require.Equal(
t, harness.quiescer.QuiescenceInitiator(),
fn.Ok(lntypes.Remote),
@ -214,7 +164,7 @@ func TestQuiescenceInitiator(t *testing.T) {
}
require.NoError(
t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local),
t, harness.quiescer.sendOwedStfu(),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
@ -222,11 +172,7 @@ func TestQuiescenceInitiator(t *testing.T) {
ChanID: cid,
Initiator: false,
}
require.NoError(
t, harness.quiescer.recvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(t, harness.quiescer.recvStfu(msg))
require.True(t, harness.quiescer.quiescenceInitiator().IsOk())
select {
@ -249,11 +195,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.False(t, harness.quiescer.CanRecvUpdates())
}
@ -270,10 +212,10 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
Initiator: true,
}
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err := harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
err = harness.quiescer.SendOwedStfu()
require.NoError(t, err)
require.False(t, harness.quiescer.CanSendUpdates())
@ -293,11 +235,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
}
require.False(t, harness.quiescer.NeedStfu())
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.False(t, harness.quiescer.NeedStfu())
}
@ -309,38 +247,15 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
harness := initQuiescerTestHarness(lntypes.Local)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsErr(),
)
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsOk(),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.True(t, harness.quiescer.MakeStfu().IsOk())
require.NoError(
t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local),
)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsErr(),
)
require.NoError(t, harness.quiescer.SendOwedStfu())
require.True(t, harness.quiescer.MakeStfu().IsErr())
}
// TestQuiescerTieBreaker ensures that if both parties attempt to claim the
@ -364,16 +279,8 @@ func TestQuiescerTieBreaker(t *testing.T) {
)
harness.quiescer.InitStfu(req)
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(
harness.pendingUpdates.Local,
),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.NoError(t, harness.quiescer.SendOwedStfu())
select {
case party := <-res:
@ -396,16 +303,8 @@ func TestQuiescerResume(t *testing.T) {
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(
harness.pendingUpdates.Local,
),
)
require.NoError(t, harness.quiescer.RecvStfu(msg))
require.NoError(t, harness.quiescer.SendOwedStfu())
require.True(t, harness.quiescer.IsQuiescent())
var resumeHooksCalled = false
@ -434,9 +333,9 @@ func TestQuiescerTimeoutTriggers(t *testing.T) {
harness.quiescer.cfg.timeoutDuration = time.Second
harness.quiescer.cfg.onTimeout = func() { close(timeoutGate) }
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err := harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
err = harness.quiescer.SendOwedStfu()
require.NoError(t, err)
select {
@ -461,9 +360,9 @@ func TestQuiescerTimeoutAborts(t *testing.T) {
harness.quiescer.cfg.timeoutDuration = time.Second
harness.quiescer.cfg.onTimeout = func() { close(timeoutGate) }
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
err := harness.quiescer.RecvStfu(msg)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
err = harness.quiescer.SendOwedStfu()
require.NoError(t, err)
harness.quiescer.Resume()