diff --git a/discovery/gossiper.go b/discovery/gossiper.go index be7be8516..5eac87d91 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -430,8 +430,7 @@ type AuthenticatedGossiper struct { // the chan networkMsgs once the block height has reached. The cached // map format is, // {msgID1: msg1, msgID2: msg2, ...} - futureMsgs *lru.Cache[uint64, *cachedFutureMsg] - futureMsgID atomic.Uint64 + futureMsgs *futureMsgCache // chanPolicyUpdates is a channel that requests to update the // forwarding policy of a set of channels is sent over. @@ -484,13 +483,11 @@ type AuthenticatedGossiper struct { // passed configuration parameters. func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper { gossiper := &AuthenticatedGossiper{ - selfKey: selfKeyDesc.PubKey, - selfKeyLoc: selfKeyDesc.KeyLocator, - cfg: &cfg, - networkMsgs: make(chan *networkMsg), - futureMsgs: lru.NewCache[uint64, *cachedFutureMsg]( - maxFutureMessages, - ), + selfKey: selfKeyDesc.PubKey, + selfKeyLoc: selfKeyDesc.KeyLocator, + cfg: &cfg, + networkMsgs: make(chan *networkMsg), + futureMsgs: newFutureMsgCache(maxFutureMessages), quit: make(chan struct{}), chanPolicyUpdates: make(chan *chanPolicyUpdateRequest), prematureChannelUpdates: lru.NewCache[uint64, *cachedNetworkMsg]( //nolint: lll @@ -638,7 +635,32 @@ func (d *AuthenticatedGossiper) syncBlockHeight() { } } -// cachedFutureMsg is a future message that's saved to the `futureMsgs` cache. +// futureMsgCache embeds a `lru.Cache` with a message counter that's served as +// the unique ID when saving the message. +type futureMsgCache struct { + *lru.Cache[uint64, *cachedFutureMsg] + + // msgID is a monotonically increased integer. + msgID atomic.Uint64 +} + +// nextMsgID returns a unique message ID. +func (f *futureMsgCache) nextMsgID() uint64 { + return f.msgID.Add(1) +} + +// newFutureMsgCache creates a new future message cache with the underlying lru +// cache being initialized with the specified capacity. +func newFutureMsgCache(capacity uint64) *futureMsgCache { + // Create a new cache. + cache := lru.NewCache[uint64, *cachedFutureMsg](capacity) + + return &futureMsgCache{ + Cache: cache, + } +} + +// cachedFutureMsg is a future message that's saved to the `futureMsgCache`. type cachedFutureMsg struct { // msg is the network message. msg *networkMsg @@ -1933,7 +1955,7 @@ func (d *AuthenticatedGossiper) isPremature(chanID lnwire.ShortChannelID, } // Increment the msg ID and add it to the cache. - nextMsgID := d.futureMsgID.Add(1) + nextMsgID := d.futureMsgs.nextMsgID() _, err := d.futureMsgs.Put(nextMsgID, cachedMsg) if err != nil { log.Errorf("Adding future message got error: %v", err) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 37823469c..cc2fdef05 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -20,6 +20,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" + "github.com/lightninglabs/neutrino/cache" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -4098,3 +4099,38 @@ func TestRejectCacheChannelAnn(t *testing.T) { t.Fatal("did not process remote announcement") } } + +// TestFutureMsgCacheEviction checks that when the cache's capacity is reached, +// saving one more item will evict the oldest item. +func TestFutureMsgCacheEviction(t *testing.T) { + t.Parallel() + + // Create a future message cache with size 1. + c := newFutureMsgCache(1) + + // Send two messages to the cache, which ends in the first message + // being evicted. + // + // Put the first item. + id := c.nextMsgID() + evicted, err := c.Put(id, &cachedFutureMsg{height: uint32(id)}) + require.NoError(t, err) + require.False(t, evicted, "should not be evicted") + + // Put the second item. + id = c.nextMsgID() + evicted, err = c.Put(id, &cachedFutureMsg{height: uint32(id)}) + require.NoError(t, err) + require.True(t, evicted, "should be evicted") + + // The first item should have been evicted. + // + // NOTE: msg ID starts at 1, not 0. + _, err = c.Get(1) + require.ErrorIs(t, err, cache.ErrElementNotFound) + + // The second item should be found. + item, err := c.Get(2) + require.NoError(t, err) + require.EqualValues(t, 2, item.height, "should be the second item") +}