htlcswitch: implement InitStfu link operation

This commit is contained in:
Keagan McClelland 2024-03-12 12:32:43 -07:00
parent bca1516429
commit 7255b7357c
No known key found for this signature in database
GPG key ID: FA7E65C951F12439
3 changed files with 244 additions and 15 deletions

View file

@ -396,6 +396,11 @@ type channelLink struct {
// respect to the quiescence protocol. // respect to the quiescence protocol.
quiescer Quiescer quiescer Quiescer
// quiescenceReqs is a queue of requests to quiesce this link. The
// members of the queue are send-only channels we should call back with
// the result.
quiescenceReqs chan StfuReq
// ContextGuard is a helper that encapsulates a wait group and quit // ContextGuard is a helper that encapsulates a wait group and quit
// channel and allows contexts that either block or cancel on those // channel and allows contexts that either block or cancel on those
// depending on the use case. // depending on the use case.
@ -481,6 +486,10 @@ func NewChannelLink(cfg ChannelLinkConfig,
}, },
} }
quiescenceReqs := make(
chan fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]], 1,
)
return &channelLink{ return &channelLink{
cfg: cfg, cfg: cfg,
channel: channel, channel: channel,
@ -491,6 +500,7 @@ func NewChannelLink(cfg ChannelLinkConfig,
outgoingCommitHooks: newHookMap(), outgoingCommitHooks: newHookMap(),
incomingCommitHooks: newHookMap(), incomingCommitHooks: newHookMap(),
quiescer: NewQuiescer(quiescerCfg), quiescer: NewQuiescer(quiescerCfg),
quiescenceReqs: quiescenceReqs,
ContextGuard: fn.NewContextGuard(), ContextGuard: fn.NewContextGuard(),
} }
} }
@ -745,12 +755,17 @@ func (l *channelLink) OnCommitOnce(direction LinkDirection, hook func()) {
// may be removed or reworked in the future as RPC initiated quiescence is a // may be removed or reworked in the future as RPC initiated quiescence is a
// holdover until we have downstream protocols that use it. // holdover until we have downstream protocols that use it.
func (l *channelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] { func (l *channelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
// TODO(proofofkeags): Implement req, out := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
c := make(chan fn.Result[lntypes.ChannelParty], 1) fn.Unit{},
)
c <- fn.Errf[lntypes.ChannelParty]("InitStfu not yet implemented") select {
case l.quiescenceReqs <- req:
case <-l.Quit:
req.Resolve(fn.Err[lntypes.ChannelParty](ErrLinkShuttingDown))
}
return c return out
} }
// isReestablished returns true if the link has successfully completed the // isReestablished returns true if the link has successfully completed the
@ -1498,6 +1513,22 @@ func (l *channelLink) htlcManager() {
) )
} }
case qReq := <-l.quiescenceReqs:
l.quiescer.InitStfu(qReq)
pendingOnLocal := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
if err := l.quiescer.SendOwedStfu(
pendingOnLocal + pendingOnRemote,
); err != nil {
l.stfuFailf("%s", err.Error())
qReq.Resolve(fn.Err[lntypes.ChannelParty](err))
}
case <-l.Quit: case <-l.Quit:
return return
} }

View file

@ -44,6 +44,8 @@ var (
) )
) )
type StfuReq = fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]]
// QuiescerCfg is a config structure used to initialize a quiescer giving it the // QuiescerCfg is a config structure used to initialize a quiescer giving it the
// appropriate functionality to interact with the channel state that the // appropriate functionality to interact with the channel state that the
// quiescer must syncrhonize with. // quiescer must syncrhonize with.
@ -84,6 +86,10 @@ type Quiescer struct {
// received tracks whether or not we have received Stfu from our peer. // received tracks whether or not we have received Stfu from our peer.
received bool received bool
// activeQuiescenceRequest is a possibly None Request that we should
// resolve when we complete quiescence.
activeQuiescenceReq fn.Option[StfuReq]
sync.RWMutex sync.RWMutex
} }
@ -135,6 +141,10 @@ func (q *Quiescer) recvStfu(msg lnwire.Stfu,
// does not necessarily mean they will get it, though. // does not necessarily mean they will get it, though.
q.remoteInit = msg.Initiator q.remoteInit = msg.Initiator
// Since we just received an Stfu, we may have a newly quiesced state.
// If so, we will try to resolve any outstanding StfuReqs.
q.tryResolveStfuReq()
return nil return nil
} }
@ -186,7 +196,7 @@ func (q *Quiescer) OweStfu() bool {
// Stfu when we have received but not yet sent an Stfu, or we are the initiator // Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu. // but have not yet sent an Stfu.
func (q *Quiescer) oweStfu() bool { func (q *Quiescer) oweStfu() bool {
return q.received && !q.sent return (q.received || q.localInit) && !q.sent
} }
// NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when // NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
@ -333,7 +343,60 @@ func (q *Quiescer) sendOwedStfu(numPendingLocalUpdates uint64) error {
if err == nil { if err == nil {
q.sent = true q.sent = true
// Since we just sent an Stfu, we may have a newly quiesced
// state. If so, we will try to resolve any outstanding
// StfuReqs.
q.tryResolveStfuReq()
} }
return err return err
} }
// TryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *Quiescer) TryResolveStfuReq() {
q.Lock()
defer q.Unlock()
q.tryResolveStfuReq()
}
// tryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *Quiescer) tryResolveStfuReq() {
q.activeQuiescenceReq.WhenSome(
func(req StfuReq) {
if q.isQuiescent() {
req.Resolve(q.quiescenceInitiator())
q.activeQuiescenceReq = fn.None[StfuReq]()
}
},
)
}
// InitStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't yet send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *Quiescer) InitStfu(req StfuReq) {
q.Lock()
defer q.Unlock()
q.initStfu(req)
}
// initStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't yet send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *Quiescer) initStfu(req StfuReq) {
if q.localInit {
req.Resolve(fn.Errf[lntypes.ChannelParty](
"quiescence already requested",
))
return
}
q.localInit = true
q.activeQuiescenceReq = fn.Some(req)
}

View file

@ -18,7 +18,9 @@ type quiescerTestHarness struct {
conn <-chan lnwire.Stfu conn <-chan lnwire.Stfu
} }
func initQuiescerTestHarness() *quiescerTestHarness { func initQuiescerTestHarness(
channelInitiator lntypes.ChannelParty) *quiescerTestHarness {
conn := make(chan lnwire.Stfu, 1) conn := make(chan lnwire.Stfu, 1)
harness := &quiescerTestHarness{ harness := &quiescerTestHarness{
pendingUpdates: lntypes.Dual[uint64]{}, pendingUpdates: lntypes.Dual[uint64]{},
@ -27,6 +29,7 @@ func initQuiescerTestHarness() *quiescerTestHarness {
harness.quiescer = NewQuiescer(QuiescerCfg{ harness.quiescer = NewQuiescer(QuiescerCfg{
chanID: cid, chanID: cid,
channelInitiator: channelInitiator,
sendMsg: func(msg lnwire.Stfu) error { sendMsg: func(msg lnwire.Stfu) error {
conn <- msg conn <- msg
return nil return nil
@ -41,7 +44,7 @@ func initQuiescerTestHarness() *quiescerTestHarness {
func TestQuiescerDoubleRecvInvalid(t *testing.T) { func TestQuiescerDoubleRecvInvalid(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{ msg := lnwire.Stfu{
ChanID: cid, ChanID: cid,
@ -60,7 +63,7 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) {
func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) { func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{ msg := lnwire.Stfu{
ChanID: cid, ChanID: cid,
@ -77,7 +80,7 @@ func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
func TestQuiescenceRemoteInit(t *testing.T) { func TestQuiescenceRemoteInit(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{ msg := lnwire.Stfu{
ChanID: cid, ChanID: cid,
@ -110,12 +113,61 @@ func TestQuiescenceRemoteInit(t *testing.T) {
} }
} }
func TestQuiescenceLocalInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(stfuReq)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case msg := <-harness.conn:
require.True(t, msg.Initiator)
default:
t.Fatalf("stfu not sent when expected")
}
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
select {
case party := <-stfuRes:
require.Equal(t, fn.Ok(lntypes.Local), party)
default:
t.Fatalf("quiescence request not resolved")
}
}
// TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote // TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote
// party when we have a receive first traversal of the quiescer's state graph. // party when we have a receive first traversal of the quiescer's state graph.
func TestQuiescenceInitiator(t *testing.T) { func TestQuiescenceInitiator(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() // Remote Initiated
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr()) require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Receive // Receive
@ -138,6 +190,48 @@ func TestQuiescenceInitiator(t *testing.T) {
t, harness.quiescer.QuiescenceInitiator(), t, harness.quiescer.QuiescenceInitiator(),
fn.Ok(lntypes.Remote), fn.Ok(lntypes.Remote),
) )
// Local Initiated
harness = initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req)
req2, res2 := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req2)
select {
case initiator := <-res2:
require.True(t, initiator.IsErr())
default:
t.Fatal("quiescence request not resolved")
}
require.NoError(
t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
msg = lnwire.Stfu{
ChanID: cid,
Initiator: false,
}
require.NoError(
t, harness.quiescer.recvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsOk())
select {
case initiator := <-res:
require.Equal(t, fn.Ok(lntypes.Local), initiator)
default:
t.Fatal("quiescence request not resolved")
}
} }
// TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel // TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel
@ -145,7 +239,7 @@ func TestQuiescenceInitiator(t *testing.T) {
func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) { func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanRecvUpdates()) require.True(t, harness.quiescer.CanRecvUpdates())
msg := lnwire.Stfu{ msg := lnwire.Stfu{
@ -165,7 +259,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) { func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanSendUpdates()) require.True(t, harness.quiescer.CanSendUpdates())
msg := lnwire.Stfu{ msg := lnwire.Stfu{
@ -188,7 +282,7 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) { func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{ msg := lnwire.Stfu{
ChanID: cid, ChanID: cid,
@ -210,7 +304,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) { func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
t.Parallel() t.Parallel()
harness := initQuiescerTestHarness() harness := initQuiescerTestHarness(lntypes.Local)
harness.pendingUpdates.SetForParty(lntypes.Local, 1) harness.pendingUpdates.SetForParty(lntypes.Local, 1)
@ -245,3 +339,44 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
).IsErr(), ).IsErr(),
) )
} }
// TestQuiescerTieBreaker ensures that if both parties attempt to claim the
// initiator role that the result of the negotiation breaks the tie using the
// channel initiator.
func TestQuiescerTieBreaker(t *testing.T) {
t.Parallel()
for _, initiator := range []lntypes.ChannelParty{
lntypes.Local, lntypes.Remote,
} {
harness := initQuiescerTestHarness(initiator)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(req)
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(
harness.pendingUpdates.Local,
),
)
select {
case party := <-res:
require.Equal(t, fn.Ok(initiator), party)
default:
t.Fatal("quiescence party unavailable")
}
}
}