From b741132a81ab1412f005aefdc9467f31bf2e5754 Mon Sep 17 00:00:00 2001 From: Ononiwu Maureen Date: Fri, 15 Mar 2024 15:56:48 +0100 Subject: [PATCH] peer: Add `startPeer` test function Signed-off-by: Ononiwu Maureen --- peer/brontide_test.go | 56 ++++++++++--------------------------------- peer/test_utils.go | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 44 deletions(-) diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 3b0ae3ce2..d2f3fc7bc 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" @@ -1024,31 +1025,13 @@ func TestPeerCustomMessage(t *testing.T) { mockConn = params.mockConn alicePeer = params.peer receivedCustomChan = params.customChan + remoteKey = alicePeer.PubKey() ) - remoteKey := alicePeer.PubKey() - - // Set up the init sequence. - go func() { - // Read init message. - <-mockConn.writtenMessages - - // Write the init reply message. - initReplyMsg := lnwire.NewInitMessage( - lnwire.NewRawFeatureVector( - lnwire.DataLossProtectRequired, - ), - lnwire.NewRawFeatureVector(), - ) - var b bytes.Buffer - _, err := lnwire.WriteMessage(&b, initReplyMsg, 0) - require.NoError(t, err) - - mockConn.readMessages <- b.Bytes() - }() - - // Start the peer. - require.NoError(t, alicePeer.Start()) + // Start peer. + startPeerDone := startPeer(t, mockConn, alicePeer) + _, err := fn.RecvOrTimeout(startPeerDone, 2*timeout) + require.NoError(t, err) // Send a custom message. customMsg, err := lnwire.NewCustom( @@ -1330,33 +1313,18 @@ func TestStartupWriteMessageRace(t *testing.T) { // Send a message while starting the peer. As the peer starts up, it // should not trigger a data race between the sending of this message // and the sending of the channel reestablish message. - sendPingDone := make(chan struct{}) + var sendPingDone = make(chan struct{}) go func() { require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0))) close(sendPingDone) }() - // Handle init messages. - go func() { - // Read init message. - <-mockConn.writtenMessages - - // Write the init reply message. - initReplyMsg := lnwire.NewInitMessage( - lnwire.NewRawFeatureVector( - lnwire.DataLossProtectRequired, - ), - lnwire.NewRawFeatureVector(), - ) - var b bytes.Buffer - _, err = lnwire.WriteMessage(&b, initReplyMsg, 0) - require.NoError(t, err) - - mockConn.readMessages <- b.Bytes() - }() - // Start the peer. No data race should occur. - require.NoError(t, peer.Start()) + startPeerDone := startPeer(t, mockConn, peer) + + // Ensure startup is complete. + _, err = fn.RecvOrTimeout(startPeerDone, 2*timeout) + require.NoError(t, err) // Ensure messages were sent during startup. <-sendPingDone diff --git a/peer/test_utils.go b/peer/test_utils.go index 9a01e35ad..8667b04cd 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -18,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -724,3 +725,52 @@ func createTestPeer(t *testing.T) *peerTestCtx { chanStatusMgr: chanStatusMgr, } } + +// startPeer invokes the `Start` method on the specified peer and handles any +// initial startup messages for testing. +func startPeer(t *testing.T, mockConn *mockMessageConn, + peer *Brontide) <-chan struct{} { + + // Start the peer in a goroutine so that we can handle and test for + // startup messages. Successfully sending and receiving init message, + // indicates a successful startup. + done := make(chan struct{}) + go func() { + require.NoError(t, peer.Start()) + close(done) + }() + + // Receive the init message that should be the first message received on + // startup. + rawMsg, err := fn.RecvOrTimeout[[]byte]( + mockConn.writtenMessages, timeout, + ) + require.NoError(t, err) + + msgReader := bytes.NewReader(rawMsg) + nextMsg, err := lnwire.ReadMessage(msgReader, 0) + require.NoError(t, err) + + _, ok := nextMsg.(*lnwire.Init) + require.True(t, ok) + + // Write the reply for the init message to complete the startup. + initReplyMsg := lnwire.NewInitMessage( + lnwire.NewRawFeatureVector( + lnwire.DataLossProtectRequired, + lnwire.GossipQueriesOptional, + ), + lnwire.NewRawFeatureVector(), + ) + + var b bytes.Buffer + _, err = lnwire.WriteMessage(&b, initReplyMsg, 0) + require.NoError(t, err) + + ok = fn.SendOrQuit[[]byte, struct{}]( + mockConn.readMessages, b.Bytes(), make(chan struct{}), + ) + require.True(t, ok) + + return done +}