htlcswitch: implement InitStfu link operation

This commit is contained in:
Keagan McClelland 2024-03-12 12:32:43 -07:00
parent bca1516429
commit 7255b7357c
No known key found for this signature in database
GPG key ID: FA7E65C951F12439
3 changed files with 244 additions and 15 deletions

View file

@ -396,6 +396,11 @@ type channelLink struct {
// respect to the quiescence protocol.
quiescer Quiescer
// quiescenceReqs is a queue of requests to quiesce this link. The
// members of the queue are send-only channels we should call back with
// the result.
quiescenceReqs chan StfuReq
// ContextGuard is a helper that encapsulates a wait group and quit
// channel and allows contexts that either block or cancel on those
// depending on the use case.
@ -481,6 +486,10 @@ func NewChannelLink(cfg ChannelLinkConfig,
},
}
quiescenceReqs := make(
chan fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]], 1,
)
return &channelLink{
cfg: cfg,
channel: channel,
@ -491,6 +500,7 @@ func NewChannelLink(cfg ChannelLinkConfig,
outgoingCommitHooks: newHookMap(),
incomingCommitHooks: newHookMap(),
quiescer: NewQuiescer(quiescerCfg),
quiescenceReqs: quiescenceReqs,
ContextGuard: fn.NewContextGuard(),
}
}
@ -745,12 +755,17 @@ func (l *channelLink) OnCommitOnce(direction LinkDirection, hook func()) {
// may be removed or reworked in the future as RPC initiated quiescence is a
// holdover until we have downstream protocols that use it.
func (l *channelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
// TODO(proofofkeags): Implement
c := make(chan fn.Result[lntypes.ChannelParty], 1)
req, out := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
c <- fn.Errf[lntypes.ChannelParty]("InitStfu not yet implemented")
select {
case l.quiescenceReqs <- req:
case <-l.Quit:
req.Resolve(fn.Err[lntypes.ChannelParty](ErrLinkShuttingDown))
}
return c
return out
}
// isReestablished returns true if the link has successfully completed the
@ -1498,6 +1513,22 @@ func (l *channelLink) htlcManager() {
)
}
case qReq := <-l.quiescenceReqs:
l.quiescer.InitStfu(qReq)
pendingOnLocal := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
if err := l.quiescer.SendOwedStfu(
pendingOnLocal + pendingOnRemote,
); err != nil {
l.stfuFailf("%s", err.Error())
qReq.Resolve(fn.Err[lntypes.ChannelParty](err))
}
case <-l.Quit:
return
}

View file

@ -44,6 +44,8 @@ var (
)
)
type StfuReq = fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]]
// QuiescerCfg is a config structure used to initialize a quiescer giving it the
// appropriate functionality to interact with the channel state that the
// quiescer must syncrhonize with.
@ -84,6 +86,10 @@ type Quiescer struct {
// received tracks whether or not we have received Stfu from our peer.
received bool
// activeQuiescenceRequest is a possibly None Request that we should
// resolve when we complete quiescence.
activeQuiescenceReq fn.Option[StfuReq]
sync.RWMutex
}
@ -135,6 +141,10 @@ func (q *Quiescer) recvStfu(msg lnwire.Stfu,
// does not necessarily mean they will get it, though.
q.remoteInit = msg.Initiator
// Since we just received an Stfu, we may have a newly quiesced state.
// If so, we will try to resolve any outstanding StfuReqs.
q.tryResolveStfuReq()
return nil
}
@ -186,7 +196,7 @@ func (q *Quiescer) OweStfu() bool {
// Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu.
func (q *Quiescer) oweStfu() bool {
return q.received && !q.sent
return (q.received || q.localInit) && !q.sent
}
// NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
@ -333,7 +343,60 @@ func (q *Quiescer) sendOwedStfu(numPendingLocalUpdates uint64) error {
if err == nil {
q.sent = true
// Since we just sent an Stfu, we may have a newly quiesced
// state. If so, we will try to resolve any outstanding
// StfuReqs.
q.tryResolveStfuReq()
}
return err
}
// TryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *Quiescer) TryResolveStfuReq() {
q.Lock()
defer q.Unlock()
q.tryResolveStfuReq()
}
// tryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *Quiescer) tryResolveStfuReq() {
q.activeQuiescenceReq.WhenSome(
func(req StfuReq) {
if q.isQuiescent() {
req.Resolve(q.quiescenceInitiator())
q.activeQuiescenceReq = fn.None[StfuReq]()
}
},
)
}
// InitStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't yet send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *Quiescer) InitStfu(req StfuReq) {
q.Lock()
defer q.Unlock()
q.initStfu(req)
}
// initStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't yet send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *Quiescer) initStfu(req StfuReq) {
if q.localInit {
req.Resolve(fn.Errf[lntypes.ChannelParty](
"quiescence already requested",
))
return
}
q.localInit = true
q.activeQuiescenceReq = fn.Some(req)
}

View file

@ -18,7 +18,9 @@ type quiescerTestHarness struct {
conn <-chan lnwire.Stfu
}
func initQuiescerTestHarness() *quiescerTestHarness {
func initQuiescerTestHarness(
channelInitiator lntypes.ChannelParty) *quiescerTestHarness {
conn := make(chan lnwire.Stfu, 1)
harness := &quiescerTestHarness{
pendingUpdates: lntypes.Dual[uint64]{},
@ -26,7 +28,8 @@ func initQuiescerTestHarness() *quiescerTestHarness {
}
harness.quiescer = NewQuiescer(QuiescerCfg{
chanID: cid,
chanID: cid,
channelInitiator: channelInitiator,
sendMsg: func(msg lnwire.Stfu) error {
conn <- msg
return nil
@ -41,7 +44,7 @@ func initQuiescerTestHarness() *quiescerTestHarness {
func TestQuiescerDoubleRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@ -60,7 +63,7 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) {
func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@ -77,7 +80,7 @@ func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
func TestQuiescenceRemoteInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@ -110,12 +113,61 @@ func TestQuiescenceRemoteInit(t *testing.T) {
}
}
func TestQuiescenceLocalInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(stfuReq)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case msg := <-harness.conn:
require.True(t, msg.Initiator)
default:
t.Fatalf("stfu not sent when expected")
}
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
select {
case party := <-stfuRes:
require.Equal(t, fn.Ok(lntypes.Local), party)
default:
t.Fatalf("quiescence request not resolved")
}
}
// TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote
// party when we have a receive first traversal of the quiescer's state graph.
func TestQuiescenceInitiator(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
// Remote Initiated
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Receive
@ -138,6 +190,48 @@ func TestQuiescenceInitiator(t *testing.T) {
t, harness.quiescer.QuiescenceInitiator(),
fn.Ok(lntypes.Remote),
)
// Local Initiated
harness = initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req)
req2, res2 := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req2)
select {
case initiator := <-res2:
require.True(t, initiator.IsErr())
default:
t.Fatal("quiescence request not resolved")
}
require.NoError(
t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
msg = lnwire.Stfu{
ChanID: cid,
Initiator: false,
}
require.NoError(
t, harness.quiescer.recvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsOk())
select {
case initiator := <-res:
require.Equal(t, fn.Ok(lntypes.Local), initiator)
default:
t.Fatal("quiescence request not resolved")
}
}
// TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel
@ -145,7 +239,7 @@ func TestQuiescenceInitiator(t *testing.T) {
func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanRecvUpdates())
msg := lnwire.Stfu{
@ -165,7 +259,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanSendUpdates())
msg := lnwire.Stfu{
@ -188,7 +282,7 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@ -210,7 +304,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
@ -245,3 +339,44 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
).IsErr(),
)
}
// TestQuiescerTieBreaker ensures that if both parties attempt to claim the
// initiator role that the result of the negotiation breaks the tie using the
// channel initiator.
func TestQuiescerTieBreaker(t *testing.T) {
t.Parallel()
for _, initiator := range []lntypes.ChannelParty{
lntypes.Local, lntypes.Remote,
} {
harness := initQuiescerTestHarness(initiator)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(req)
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(
harness.pendingUpdates.Local,
),
)
select {
case party := <-res:
require.Equal(t, fn.Ok(initiator), party)
default:
t.Fatal("quiescence party unavailable")
}
}
}