diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go
index cdcd6d249..8ad05d67f 100644
--- a/discovery/gossiper_test.go
+++ b/discovery/gossiper_test.go
@@ -76,7 +76,9 @@ var (
 	rebroadcastInterval = time.Hour * 1000000
 )
 
+// TODO(elle): replace mockGraphSource with testify.Mock.
 type mockGraphSource struct {
+	t          *testing.T
 	bestHeight uint32
 
 	mu            sync.Mutex
@@ -85,15 +87,22 @@ type mockGraphSource struct {
 	edges         map[uint64][]models.ChannelEdgePolicy
 	zombies       map[uint64][][33]byte
 	chansToReject map[uint64]struct{}
+
+	updateEdgeCount     int
+	pauseGetChannelByID chan chan struct{}
 }
 
-func newMockRouter(height uint32) *mockGraphSource {
+func newMockRouter(t *testing.T, height uint32) *mockGraphSource {
 	return &mockGraphSource{
-		bestHeight:    height,
-		infos:         make(map[uint64]models.ChannelEdgeInfo),
-		edges:         make(map[uint64][]models.ChannelEdgePolicy),
-		zombies:       make(map[uint64][][33]byte),
-		chansToReject: make(map[uint64]struct{}),
+		t:          t,
+		bestHeight: height,
+		infos:      make(map[uint64]models.ChannelEdgeInfo),
+		edges: make(
+			map[uint64][]models.ChannelEdgePolicy,
+		),
+		zombies:             make(map[uint64][][33]byte),
+		chansToReject:       make(map[uint64]struct{}),
+		pauseGetChannelByID: make(chan chan struct{}, 1),
 	}
 }
 
@@ -155,7 +164,10 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
 	_ ...batch.SchedulerOption) error {
 
 	r.mu.Lock()
-	defer r.mu.Unlock()
+	defer func() {
+		r.updateEdgeCount++
+		r.mu.Unlock()
+	}()
 
 	if len(r.edges[edge.ChannelID]) == 0 {
 		r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2)
@@ -234,6 +246,18 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
 	*models.ChannelEdgePolicy,
 	*models.ChannelEdgePolicy, error) {
 
+	select {
+	// Check if a pause request channel has been loaded. If one has, then
+	// we wait for it to be closed before continuing.
+	case pauseChan := <-r.pauseGetChannelByID:
+		select {
+		case <-pauseChan:
+		case <-time.After(time.Second * 30):
+			r.t.Fatal("timeout waiting for pause channel")
+		}
+	default:
+	}
+
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
@@ -874,7 +898,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) (
 	// any p2p functionality, the peer send and switch send,
 	// broadcast functions won't be populated.
 	notifier := newMockNotifier()
-	router := newMockRouter(startHeight)
+	router := newMockRouter(t, startHeight)
 	chain := &lnmock.MockChain{}
 	t.Cleanup(func() {
 		chain.AssertExpectations(t)
@@ -3977,6 +4001,197 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) {
 	assertBroadcast(chanAnn2, true, true)
 }
 
+// TestRateLimitDeDup demonstrates a bug that currently exists in the handling
+// of channel updates. It shows that if two identical channel updates are
+// received in quick succession, then both of them might be counted towards
+// the rate limit, even though only one of them should be.
+//
+// NOTE: this will be fixed in an upcoming commit.
+func TestRateLimitDeDup(t *testing.T) {
+	t.Parallel()
+
+	// Create our test harness.
+	const blockHeight = 100
+	ctx, err := createTestCtx(t, blockHeight, false)
+	require.NoError(t, err, "can't create context")
+	ctx.gossiper.cfg.RebroadcastInterval = time.Hour
+
+	var findBaseByAliasCount atomic.Int32
+	ctx.gossiper.cfg.FindBaseByAlias = func(alias lnwire.ShortChannelID) (
+		lnwire.ShortChannelID, error) {
+
+		findBaseByAliasCount.Add(1)
+
+		return lnwire.ShortChannelID{}, fmt.Errorf("none")
+	}
+
+	getUpdateEdgeCount := func() int {
+		ctx.router.mu.Lock()
+		defer ctx.router.mu.Unlock()
+
+		return ctx.router.updateEdgeCount
+	}
+
+	// We set the burst to 2 here. The very first update should not count
+	// towards this _and_ any duplicates should also not count towards it.
+	ctx.gossiper.cfg.MaxChannelUpdateBurst = 2
+	ctx.gossiper.cfg.ChannelUpdateInterval = time.Minute
+
+	// The graph should start empty.
+	require.Empty(t, ctx.router.infos)
+	require.Empty(t, ctx.router.edges)
+
+	// We'll create a batch of signed announcements, including updates for
+	// both sides, for a channel and process them. They should all be
+	// forwarded as this is our first time learning about the channel.
+	batch, err := ctx.createRemoteAnnouncements(blockHeight)
+	require.NoError(t, err)
+
+	nodePeer1 := &mockPeer{
+		remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{},
+	}
+	select {
+	case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+		batch.chanAnn, nodePeer1,
+	):
+		require.NoError(t, err)
+	case <-time.After(time.Second):
+		t.Fatal("remote announcement not processed")
+	}
+
+	select {
+	case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+		batch.chanUpdAnn1, nodePeer1,
+	):
+		require.NoError(t, err)
+	case <-time.After(time.Second):
+		t.Fatal("remote announcement not processed")
+	}
+
+	nodePeer2 := &mockPeer{
+		remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{},
+	}
+	select {
+	case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+		batch.chanUpdAnn2, nodePeer2,
+	):
+		require.NoError(t, err)
+	case <-time.After(time.Second):
+		t.Fatal("remote announcement not processed")
+	}
+
+	timeout := time.After(2 * trickleDelay)
+	for i := 0; i < 3; i++ {
+		select {
+		case <-ctx.broadcastedMessage:
+		case <-timeout:
+			t.Fatal("expected announcement to be broadcast")
+		}
+	}
+
+	shortChanID := batch.chanAnn.ShortChannelID.ToUint64()
+	require.Contains(t, ctx.router.infos, shortChanID)
+	require.Contains(t, ctx.router.edges, shortChanID)
+
+	// Before we send any more updates, we want to let our test harness
+	// hang during GetChannelByID so that we can ensure that two threads
+	// are waiting for the chan.
+	pause := make(chan struct{})
+	ctx.router.pauseGetChannelByID <- pause
+
+	// Take note of how many times FindBaseByAlias has been called.
+	// It should be 2 since we have processed two channel updates.
+	require.EqualValues(t, 2, findBaseByAliasCount.Load())
+
+	// The same is expected for the UpdateEdge call.
+	require.EqualValues(t, 2, getUpdateEdgeCount())
+
+	update := *batch.chanUpdAnn1
+
+	// refreshUpdate is a helper that ensures the update is not seen as
+	// stale or as a keep-alive.
+	refreshUpdate := func() {
+		update.Timestamp++
+		update.BaseFee++
+		require.NoError(t, signUpdate(remoteKeyPriv1, &update))
+	}
+
+	refreshUpdate()
+
+	// OK, now we will send the same channel update twice in quick
+	// succession. We wait for both to have hit the FindBaseByAlias check
+	// before we un-pause the GetChannelByID call.
+	go func() {
+		ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
+	}()
+	go func() {
+		ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
+	}()
+
+	// We know that both are being processed once the count for
+	// FindBaseByAlias has increased by 2.
+	err = wait.NoError(func() error {
+		count := findBaseByAliasCount.Load()
+
+		if count != 4 {
+			return fmt.Errorf("expected 4 calls to "+
+				"FindBaseByAlias, got %v", count)
+		}
+
+		return nil
+	}, time.Second*5)
+	require.NoError(t, err)
+
+	// Now we can un-pause the thread that grabbed the mutex first.
+	close(pause)
+
+	// Currently, both updates make it to UpdateEdge.
+	err = wait.NoError(func() error {
+		count := getUpdateEdgeCount()
+		if count != 4 {
+			return fmt.Errorf("expected 4 calls to UpdateEdge, "+
+				"got %v", count)
+		}
+
+		return nil
+	}, time.Second*5)
+	require.NoError(t, err)
+
+	// We'll define a helper to assert whether or not an update was
+	// broadcast.
+	assertBroadcast := func(shouldBroadcast bool) {
+		t.Helper()
+
+		select {
+		case <-ctx.broadcastedMessage:
+			require.True(t, shouldBroadcast)
+		case <-time.After(2 * trickleDelay):
+			require.False(t, shouldBroadcast)
+		}
+	}
+
+	processUpdate := func(msg lnwire.Message, peer lnpeer.Peer) {
+		select {
+		case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+			msg, peer,
+		):
+			require.NoError(t, err)
+		case <-time.After(time.Second):
+			t.Fatal("remote announcement not processed")
+		}
+	}
+
+	// Show that the last update was broadcast.
+	assertBroadcast(true)
+
+	// We should be allowed to send another update now since only one of
+	// the above duplicates should count towards the rate limit. However,
+	// this is currently not the case, and so we will be rate limited
+	// early. This will be fixed in an upcoming commit.
+	refreshUpdate()
+	processUpdate(&update, nodePeer1)
+	assertBroadcast(false)
+}
+
 // TestRateLimitChannelUpdates ensures that we properly rate limit incoming
 // channel updates.
 func TestRateLimitChannelUpdates(t *testing.T) {
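For context on the synchronization used by the mock above: the pauseGetChannelByID hook is a small "channel of channels" gate that lets the test stall a concurrent caller at a precise point and release it later. The standalone sketch below is not part of the diff; the names pauseGate and maybeWait are hypothetical, not lnd APIs, and it only illustrates the gating pattern under the assumption that one pause request is loaded and then released by closing its channel.

package main

import (
	"fmt"
	"sync"
	"time"
)

// pauseGate mirrors the idea behind pauseGetChannelByID: a buffered channel
// of channels that lets a test stall the next caller of a mocked method
// until the test explicitly releases it.
type pauseGate struct {
	requests chan chan struct{}
}

func newPauseGate() *pauseGate {
	return &pauseGate{requests: make(chan chan struct{}, 1)}
}

// maybeWait is the check performed on entry to the mocked method: if a pause
// request has been loaded, block until that request channel is closed.
func (g *pauseGate) maybeWait() {
	select {
	case release := <-g.requests:
		<-release
	default:
	}
}

func main() {
	g := newPauseGate()

	// The test loads a pause request before firing the call it wants to
	// stall.
	release := make(chan struct{})
	g.requests <- release

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		g.maybeWait() // Blocks here until release is closed.
		fmt.Println("mocked call proceeded")
	}()

	// Give the goroutine time to reach maybeWait, then un-pause it.
	time.Sleep(100 * time.Millisecond)
	close(release)
	wg.Wait()
}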