From 5906ca2537ecf4640f9d3a02bf0492c76afaf8c8 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 9 Apr 2024 17:48:56 -0700 Subject: [PATCH] htlcswitch: add test for deferred processing remote adds when quiescent --- htlcswitch/link_test.go | 89 +++++++++++++++++++++++++++++++++++++++++ htlcswitch/mock.go | 12 +++++- 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index c72a25538..725972340 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -7540,3 +7540,92 @@ func TestLinkFlushHooksCalled(t *testing.T) { ctx.receiveRevAndAckAliceToBob() assertHookCalled(true) } + +// TestLinkQuiescenceExitHopProcessingDeferred ensures that we do not send back +// htlc resolution messages in the case where the link is quiescent AND we are +// the exit hop. This is needed because we handle exit hop processing in the +// link instead of the switch and we process htlc resolutions when we receive +// a RevokeAndAck. Because of this we need to ensure that we hold off on +// processing the remote adds when we are quiescent. Later, when the channel +// update traffic is allowed to resume, we will need to verify that the actions +// we didn't run during the initial RevokeAndAck are run. +func TestLinkQuiescenceExitHopProcessingDeferred(t *testing.T) { + t.Parallel() + + // Initialize two channel state machines for testing. + alice, bob, err := createMirroredChannel( + t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin, + ) + require.NoError(t, err) + + // Build a single edge network to test channel quiescence. + network := newTwoHopNetwork( + t, alice.channel, bob.channel, testStartingHeight, + ) + aliceLink := network.aliceChannelLink + bobLink := network.bobChannelLink + + // Generate an invoice for Bob so that Alice can pay him. + htlcID := uint64(0) + htlc, invoice := generateHtlcAndInvoice(t, htlcID) + err = network.bobServer.registry.AddInvoice( + nil, *invoice, htlc.PaymentHash, + ) + require.NoError(t, err) + + // Establish a payment circuit for Alice + circuit := &PaymentCircuit{ + Incoming: CircuitKey{ + HtlcID: htlcID, + }, + PaymentHash: htlc.PaymentHash, + } + circuitMap := network.aliceServer.htlcSwitch.circuits + _, err = circuitMap.CommitCircuits(circuit) + require.NoError(t, err) + + // Add a switch packet to Alice's switch so that she can initialize the + // payment attempt. + err = aliceLink.handleSwitchPacket(&htlcPacket{ + incomingHTLCID: htlcID, + htlc: htlc, + circuit: circuit, + }) + require.NoError(t, err) + + // give alice enough time to fire the update_add + // TODO(proofofkeags): make this not depend on a flakey sleep. + <-time.After(time.Millisecond) + + // bob initiates stfu which he can do immediately since he doesn't have + // local updates + <-bobLink.InitStfu() + + // wait for other possible messages to play out + <-time.After(1 * time.Second) + + ensureNoUpdateAfterStfu := func(t *testing.T, trace []lnwire.Message) { + stfuReceived := false + for _, msg := range trace { + if msg.MsgType() == lnwire.MsgStfu { + stfuReceived = true + continue + } + + if stfuReceived && msg.MsgType().IsChannelUpdate() { + t.Fatalf("channel update after stfu: %v", + msg.MsgType()) + } + } + } + + network.aliceServer.protocolTraceMtx.Lock() + ensureNoUpdateAfterStfu(t, network.aliceServer.protocolTrace) + network.aliceServer.protocolTraceMtx.Unlock() + + network.bobServer.protocolTraceMtx.Lock() + ensureNoUpdateAfterStfu(t, network.bobServer.protocolTrace) + network.bobServer.protocolTraceMtx.Unlock() + + // TODO(proofofkeags): make sure these actions are run on resume. +} diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 37bf4c6ef..0a3364ae2 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -153,8 +153,10 @@ type mockServer struct { t testing.TB - name string - messages chan lnwire.Message + name string + messages chan lnwire.Message + protocolTraceMtx sync.Mutex + protocolTrace []lnwire.Message id [33]byte htlcSwitch *Switch @@ -289,6 +291,10 @@ func (s *mockServer) Start() error { for { select { case msg := <-s.messages: + s.protocolTraceMtx.Lock() + s.protocolTrace = append(s.protocolTrace, msg) + s.protocolTraceMtx.Unlock() + var shouldSkip bool for _, interceptor := range s.interceptorFuncs { @@ -627,6 +633,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error { targetChan = msg.ChanID case *lnwire.UpdateFee: targetChan = msg.ChanID + case *lnwire.Stfu: + targetChan = msg.ChanID default: return fmt.Errorf("unknown message type: %T", msg) }