From 199e83d3f2d22b230f2ee21b92237e173d56c5dc Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Wed, 14 Aug 2024 13:56:31 -0400 Subject: [PATCH 1/7] channeldb: add PutClosedScid and IsClosedScid This commit adds the ability to store closed channels by scid in the database. This will allow the gossiper to ignore channel announcements for closed channels without having to do any expensive validation. --- channeldb/error.go | 4 +++ channeldb/graph.go | 56 +++++++++++++++++++++++++++++++++++++++++ channeldb/graph_test.go | 25 ++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/channeldb/error.go b/channeldb/error.go index 859af9746..629cd93c6 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -43,6 +43,10 @@ var ( // created. ErrMetaNotFound = fmt.Errorf("unable to locate meta information") + // ErrClosedScidsNotFound is returned when the closed scid bucket + // hasn't been created. + ErrClosedScidsNotFound = fmt.Errorf("closed scid bucket doesn't exist") + // ErrGraphNotFound is returned when at least one of the components of // graph doesn't exist. ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") diff --git a/channeldb/graph.go b/channeldb/graph.go index 464398b40..86cbe9aa9 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -153,6 +153,14 @@ var ( // case we'll remove all entries from the prune log with a block height // that no longer exists. pruneLogBucket = []byte("prune-log") + + // closedScidBucket is a top-level bucket that stores scids for + // channels that we know to be closed. This is used so that we don't + // need to perform expensive validation checks if we receive a channel + // announcement for the channel again. + // + // maps: scid -> []byte{} + closedScidBucket = []byte("closed-scid") ) const ( @@ -318,6 +326,7 @@ var graphTopLevelBuckets = [][]byte{ nodeBucket, edgeBucket, graphMetaBucket, + closedScidBucket, } // Wipe completely deletes all saved state within all used buckets within the @@ -3884,6 +3893,53 @@ func (c *ChannelGraph) NumZombies() (uint64, error) { return numZombies, nil } +// PutClosedScid stores a SCID for a closed channel in the database. This is so +// that we can ignore channel announcements that we know to be closed without +// having to validate them and fetch a block. +func (c *ChannelGraph) PutClosedScid(scid lnwire.ShortChannelID) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + closedScids, err := tx.CreateTopLevelBucket(closedScidBucket) + if err != nil { + return err + } + + var k [8]byte + byteOrder.PutUint64(k[:], scid.ToUint64()) + + return closedScids.Put(k[:], []byte{}) + }, func() {}) +} + +// IsClosedScid checks whether a channel identified by the passed in scid is +// closed. This helps avoid having to perform expensive validation checks. +// TODO: Add an LRU cache to cut down on disc reads. 
+func (c *ChannelGraph) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) { + var isClosed bool + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + closedScids := tx.ReadBucket(closedScidBucket) + if closedScids == nil { + return ErrClosedScidsNotFound + } + + var k [8]byte + byteOrder.PutUint64(k[:], scid.ToUint64()) + + if closedScids.Get(k[:]) != nil { + isClosed = true + return nil + } + + return nil + }, func() { + isClosed = false + }) + if err != nil { + return false, err + } + + return isClosed, nil +} + func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // nolint:dupl updateIndex kvdb.RwBucket, node *LightningNode) error { diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b05f3daaa..89197a0a8 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -4037,3 +4037,28 @@ func TestGraphLoading(t *testing.T) { graphReloaded.graphCache.nodeFeatures, ) } + +// TestClosedScid tests that we can correctly insert a SCID into the index of +// closed short channel ids. +func TestClosedScid(t *testing.T) { + t.Parallel() + + graph, err := MakeTestGraph(t) + require.Nil(t, err) + + scid := lnwire.ShortChannelID{} + + // The scid should not exist in the closedScidBucket. + exists, err := graph.IsClosedScid(scid) + require.Nil(t, err) + require.False(t, exists) + + // After we call PutClosedScid, the call to IsClosedScid should return + // true. + err = graph.PutClosedScid(scid) + require.Nil(t, err) + + exists, err = graph.IsClosedScid(scid) + require.Nil(t, err) + require.True(t, exists) +} From 0173e4c44da4bea471441a2c26d0db59e8f9e1e5 Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Wed, 14 Aug 2024 13:59:28 -0400 Subject: [PATCH 2/7] discovery: add banman for channel announcements This commit introduces a ban manager that marks peers as banned if they send too many invalid channel announcements to us. Expired entries are purged after a certain period of time (currently 48 hours). --- discovery/ban.go | 252 ++++++++++++++++++++++++++++++++++++++++++ discovery/ban_test.go | 60 ++++++++++ 2 files changed, 312 insertions(+) create mode 100644 discovery/ban.go create mode 100644 discovery/ban_test.go diff --git a/discovery/ban.go b/discovery/ban.go new file mode 100644 index 000000000..cd70d7c38 --- /dev/null +++ b/discovery/ban.go @@ -0,0 +1,252 @@ +package discovery + +import ( + "errors" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightninglabs/neutrino/cache" + "github.com/lightninglabs/neutrino/cache/lru" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + // maxBannedPeers limits the maximum number of banned pubkeys that + // we'll store. + // TODO(eugene): tune. + maxBannedPeers = 10_000 + + // banThreshold is the point at which non-channel peers will be banned. + // TODO(eugene): tune. + banThreshold = 100 + + // banTime is the amount of time that the non-channel peer will be + // banned for. Channel announcements from channel peers will be dropped + // if it's not one of our channels. + // TODO(eugene): tune. + banTime = time.Hour * 48 + + // resetDelta is the time after a peer's last ban update that we'll + // reset its ban score. + // TODO(eugene): tune. + resetDelta = time.Hour * 48 + + // purgeInterval is how often we'll remove entries from the + // peerBanIndex and allow peers to be un-banned. This interval is also + // used to reset ban scores of peers that aren't banned. 
+ purgeInterval = time.Minute * 10 +) + +var ErrPeerBanned = errors.New("peer has bypassed ban threshold - banning") + +// ClosedChannelTracker handles closed channels being gossiped to us. +type ClosedChannelTracker interface { + // GraphCloser is used to mark channels as closed and to check whether + // certain channels are closed. + GraphCloser + + // IsChannelPeer checks whether we have a channel with a peer. + IsChannelPeer(*btcec.PublicKey) (bool, error) +} + +// GraphCloser handles tracking closed channels by their scid. +type GraphCloser interface { + // PutClosedScid marks a channel as closed so that we won't validate + // channel announcements for it again. + PutClosedScid(lnwire.ShortChannelID) error + + // IsClosedScid checks if a short channel id is closed. + IsClosedScid(lnwire.ShortChannelID) (bool, error) +} + +// NodeInfoInquirier handles queries relating to specific nodes and channels +// they may have with us. +type NodeInfoInquirer interface { + // FetchOpenChannels returns the set of channels that we have with the + // peer identified by the passed-in public key. + FetchOpenChannels(*btcec.PublicKey) ([]*channeldb.OpenChannel, error) +} + +// ScidCloserMan helps the gossiper handle closed channels that are in the +// ChannelGraph. +type ScidCloserMan struct { + graph GraphCloser + channelDB NodeInfoInquirer +} + +// NewScidCloserMan creates a new ScidCloserMan. +func NewScidCloserMan(graph GraphCloser, + channelDB NodeInfoInquirer) *ScidCloserMan { + + return &ScidCloserMan{ + graph: graph, + channelDB: channelDB, + } +} + +// PutClosedScid marks scid as closed so the gossiper can ignore this channel +// in the future. +func (s *ScidCloserMan) PutClosedScid(scid lnwire.ShortChannelID) error { + return s.graph.PutClosedScid(scid) +} + +// IsClosedScid checks whether scid is closed so that the gossiper can ignore +// it. +func (s *ScidCloserMan) IsClosedScid(scid lnwire.ShortChannelID) (bool, + error) { + + return s.graph.IsClosedScid(scid) +} + +// IsChannelPeer checks whether we have a channel with the peer. +func (s *ScidCloserMan) IsChannelPeer(peerKey *btcec.PublicKey) (bool, error) { + chans, err := s.channelDB.FetchOpenChannels(peerKey) + if err != nil { + return false, err + } + + return len(chans) > 0, nil +} + +// A compile-time constraint to ensure ScidCloserMan implements +// ClosedChannelTracker. +var _ ClosedChannelTracker = (*ScidCloserMan)(nil) + +// cachedBanInfo is used to track a peer's ban score and if it is banned. +type cachedBanInfo struct { + score uint64 + lastUpdate time.Time +} + +// Size returns the "size" of an entry. +func (c *cachedBanInfo) Size() (uint64, error) { + return 1, nil +} + +// isBanned returns true if the ban score is greater than the ban threshold. +func (c *cachedBanInfo) isBanned() bool { + return c.score >= banThreshold +} + +// banman is responsible for banning peers that are misbehaving. The banman is +// in-memory and will be reset upon restart of LND. If a node's pubkey is in +// the peerBanIndex, it has a ban score. Ban scores start at 1 and are +// incremented by 1 for each instance of misbehavior. It uses an LRU cache to +// cut down on memory usage in case there are many banned peers and to protect +// against DoS. +type banman struct { + // peerBanIndex tracks our peers' ban scores and if they are banned and + // for how long. The ban score is incremented when our peer gives us + // gossip messages that are invalid. 
+ peerBanIndex *lru.Cache[[33]byte, *cachedBanInfo] + + wg sync.WaitGroup + quit chan struct{} +} + +// newBanman creates a new banman with the default maxBannedPeers. +func newBanman() *banman { + return &banman{ + peerBanIndex: lru.NewCache[[33]byte, *cachedBanInfo]( + maxBannedPeers, + ), + quit: make(chan struct{}), + } +} + +// start kicks off the banman by calling purgeExpiredBans. +func (b *banman) start() { + b.wg.Add(1) + go b.purgeExpiredBans() +} + +// stop halts the banman. +func (b *banman) stop() { + close(b.quit) + b.wg.Wait() +} + +// purgeOldEntries removes ban entries if their ban has expired. +func (b *banman) purgeExpiredBans() { + defer b.wg.Done() + + purgeTicker := time.NewTicker(purgeInterval) + defer purgeTicker.Stop() + + for { + select { + case <-purgeTicker.C: + b.purgeBanEntries() + + case <-b.quit: + return + } + } +} + +// purgeBanEntries does two things: +// - removes peers from our ban list whose ban timer is up +// - removes peers whose ban scores have expired. +func (b *banman) purgeBanEntries() { + keysToRemove := make([][33]byte, 0) + + sweepEntries := func(pubkey [33]byte, banInfo *cachedBanInfo) bool { + if banInfo.isBanned() { + // If the peer is banned, check if the ban timer has + // expired. + if banInfo.lastUpdate.Add(banTime).Before(time.Now()) { + keysToRemove = append(keysToRemove, pubkey) + } + + return true + } + + if banInfo.lastUpdate.Add(resetDelta).Before(time.Now()) { + // Remove non-banned peers whose ban scores have + // expired. + keysToRemove = append(keysToRemove, pubkey) + } + + return true + } + + b.peerBanIndex.Range(sweepEntries) + + for _, key := range keysToRemove { + b.peerBanIndex.Delete(key) + } +} + +// isBanned checks whether the peer identified by the pubkey is banned. +func (b *banman) isBanned(pubkey [33]byte) bool { + banInfo, err := b.peerBanIndex.Get(pubkey) + switch { + case errors.Is(err, cache.ErrElementNotFound): + return false + + default: + return banInfo.isBanned() + } +} + +// incrementBanScore increments a peer's ban score. +func (b *banman) incrementBanScore(pubkey [33]byte) { + banInfo, err := b.peerBanIndex.Get(pubkey) + switch { + case errors.Is(err, cache.ErrElementNotFound): + cachedInfo := &cachedBanInfo{ + score: 1, + lastUpdate: time.Now(), + } + _, _ = b.peerBanIndex.Put(pubkey, cachedInfo) + default: + cachedInfo := &cachedBanInfo{ + score: banInfo.score + 1, + lastUpdate: time.Now(), + } + + _, _ = b.peerBanIndex.Put(pubkey, cachedInfo) + } +} diff --git a/discovery/ban_test.go b/discovery/ban_test.go new file mode 100644 index 000000000..e4149028b --- /dev/null +++ b/discovery/ban_test.go @@ -0,0 +1,60 @@ +package discovery + +import ( + "testing" + "time" + + "github.com/lightninglabs/neutrino/cache" + "github.com/stretchr/testify/require" +) + +// TestPurgeBanEntries tests that we properly purge ban entries on a timer. +func TestPurgeBanEntries(t *testing.T) { + t.Parallel() + + b := newBanman() + + // Ban a peer by repeatedly incrementing its ban score. + peer1 := [33]byte{0x00} + + for i := 0; i < banThreshold; i++ { + b.incrementBanScore(peer1) + } + + // Assert that the peer is now banned. + require.True(t, b.isBanned(peer1)) + + // A call to purgeBanEntries should not remove the peer from the index. + b.purgeBanEntries() + require.True(t, b.isBanned(peer1)) + + // Now set the peer's last update time to two banTimes in the past so + // that we can assert that purgeBanEntries does remove it from the + // index. 
+ banInfo, err := b.peerBanIndex.Get(peer1) + require.NoError(t, err) + + banInfo.lastUpdate = time.Now().Add(-2 * banTime) + + b.purgeBanEntries() + _, err = b.peerBanIndex.Get(peer1) + require.ErrorIs(t, err, cache.ErrElementNotFound) + + // Increment the peer's ban score again but don't get it banned. + b.incrementBanScore(peer1) + require.False(t, b.isBanned(peer1)) + + // Assert that purgeBanEntries does nothing. + b.purgeBanEntries() + banInfo, err = b.peerBanIndex.Get(peer1) + require.Nil(t, err) + + // Set its lastUpdate time to 2 resetDelta's in the past so that + // purgeBanEntries removes it. + banInfo.lastUpdate = time.Now().Add(-2 * resetDelta) + + b.purgeBanEntries() + + _, err = b.peerBanIndex.Get(peer1) + require.ErrorIs(t, err, cache.ErrElementNotFound) +} From 99b86ba462d6bec42659444d09136a5b65ed908f Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Wed, 14 Aug 2024 14:03:39 -0400 Subject: [PATCH 3/7] multi: extend lnpeer.Peer interface with Disconnect function This will be used in the gossiper to disconnect from peers if their ban score passes the ban threshold. --- discovery/gossiper_test.go | 80 +++++++++++++++++++++---------- discovery/mock_test.go | 12 +++-- discovery/reliable_sender_test.go | 5 +- funding/manager_test.go | 2 + htlcswitch/link_test.go | 2 + htlcswitch/mock.go | 2 + lnpeer/mock_peer.go | 2 + lnpeer/peer.go | 3 ++ 8 files changed, 78 insertions(+), 30 deletions(-) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 7cfc7bce8..5bc84d390 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -765,7 +765,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { peerChan chan<- lnpeer.Peer) { pk, _ := btcec.ParsePubKey(target[:]) - peerChan <- &mockPeer{pk, nil, nil} + peerChan <- &mockPeer{pk, nil, nil, atomic.Bool{}} }, NotifyWhenOffline: func(_ [33]byte) <-chan struct{} { c := make(chan struct{}) @@ -843,7 +843,7 @@ func TestProcessAnnouncement(t *testing.T) { } } - nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}} // First, we'll craft a valid remote channel announcement and send it to // the gossiper so that it can be processed. @@ -953,7 +953,7 @@ func TestPrematureAnnouncement(t *testing.T) { _, err = createNodeAnnouncement(remoteKeyPriv1, timestamp) require.NoError(t, err, "can't create node announcement") - nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}} // Pretending that we receive the valid channel announcement from // remote side, but block height of this announcement is greater than @@ -990,7 +990,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { pk, _ := btcec.ParsePubKey(target[:]) select { - case peerChan <- &mockPeer{pk, sentMsgs, ctx.gossiper.quit}: + case peerChan <- &mockPeer{ + pk, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + }: case <-ctx.gossiper.quit: } } @@ -1000,7 +1002,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, sentMsgs, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + } // Recreate lightning network topology. Initialize router with channel // between two nodes. 
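The repeated literal changes in the surrounding hunks all stem from the gossiper's test double gaining a fourth, zero-value field. For orientation, this is the shape of the updated mockPeer as defined in discovery/mock_test.go later in this patch (reproduced here only as a reading aid, not as new API):

package discovery

import (
	"sync/atomic"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/lightningnetwork/lnd/lnwire"
)

// mockPeer now records whether Disconnect was called so tests can
// assert on it.
type mockPeer struct {
	pk           *btcec.PublicKey
	sentMsgs     chan lnwire.Message
	quit         chan struct{}
	disconnected atomic.Bool
}

// Disconnect satisfies the method added to lnpeer.Peer in this series;
// the mock simply flips the flag.
func (p *mockPeer) Disconnect(err error) {
	p.disconnected.Store(true)
}

Positional literals therefore pick up a trailing atomic.Bool{}, which is the only change in most of these hunks.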
@@ -1162,7 +1166,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { pk, _ := btcec.ParsePubKey(target[:]) select { - case peerChan <- &mockPeer{pk, sentMsgs, ctx.gossiper.quit}: + case peerChan <- &mockPeer{ + pk, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + }: case <-ctx.gossiper.quit: } } @@ -1172,7 +1178,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, sentMsgs, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + } // Pretending that we receive local channel announcement from funding // manager, thereby kick off the announcement exchange process, in @@ -1344,7 +1352,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // Set up a channel to intercept the messages sent to the remote peer. sentToPeer := make(chan lnwire.Message, 1) - remotePeer := &mockPeer{remoteKey, sentToPeer, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + } // Since the reliable send to the remote peer of the local channel proof // requires a notification when the peer comes online, we'll capture the @@ -1578,7 +1588,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // Set up a channel we can use to inspect messages sent by the // gossiper to the remote peer. sentToPeer := make(chan lnwire.Message, 1) - remotePeer := &mockPeer{remoteKey, sentToPeer, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + } // Override NotifyWhenOnline to return the remote peer which we expect // meesages to be sent to. @@ -1772,7 +1784,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { ca, err := createRemoteChannelAnnouncement(0) require.NoError(t, err, "can't create remote channel announcement") - nodePeer := &mockPeer{bitcoinKeyPub2, nil, nil} + nodePeer := &mockPeer{bitcoinKeyPub2, nil, nil, atomic.Bool{}} announcements.AddMsgs(networkMsg{ msg: ca, peer: nodePeer, @@ -2058,7 +2070,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // process it. remoteChanAnn, err := createRemoteChannelAnnouncement(startingHeight - 1) require.NoError(t, err, "unable to create remote channel announcement") - peer := &mockPeer{pubKey, nil, nil} + peer := &mockPeer{pubKey, nil, nil, atomic.Bool{}} select { case err := <-ctx.gossiper.ProcessRemoteAnnouncement(remoteChanAnn, peer): @@ -2373,7 +2385,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // Set up a channel that we can use to inspect the messages sent // directly from the gossiper. sentMsgs := make(chan lnwire.Message, 10) - remotePeer := &mockPeer{remoteKey, sentMsgs, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + } // Override NotifyWhenOnline to return the remote peer which we expect // messages to be sent to. 
@@ -2561,7 +2575,9 @@ func TestExtraDataChannelAnnouncementValidation(t *testing.T) { ctx, err := createTestCtx(t, 0) require.NoError(t, err, "can't create context") - remotePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + remotePeer := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } // We'll now create an announcement that contains an extra set of bytes // that we don't know of ourselves, but should still include in the @@ -2592,7 +2608,9 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { ctx, err := createTestCtx(t, 0) require.NoError(t, err, "can't create context") - remotePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + remotePeer := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } // In this scenario, we'll create two announcements, one regular // channel announcement, and another channel update announcement, that @@ -2643,7 +2661,9 @@ func TestExtraDataNodeAnnouncementValidation(t *testing.T) { ctx, err := createTestCtx(t, 0) require.NoError(t, err, "can't create context") - remotePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + remotePeer := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } timestamp := testTimestamp // We'll create a node announcement that includes a set of opaque data @@ -2716,7 +2736,7 @@ func TestRetransmit(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, nil, nil} + remotePeer := &mockPeer{remoteKey, nil, nil, atomic.Bool{}} // Process a local channel announcement, channel update and node // announcement. No messages should be broadcasted yet, since no proof @@ -2822,7 +2842,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, nil, nil} + remotePeer := &mockPeer{remoteKey, nil, nil, atomic.Bool{}} // Process the remote node announcement. select { @@ -2906,7 +2926,7 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { chanUpdateHeight := uint32(0) timestamp := uint32(123456) - nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}} // In this scenario, we'll test whether the message flags field in a // channel update is properly handled. @@ -3013,7 +3033,9 @@ func TestSendChannelUpdateReliably(t *testing.T) { // Set up a channel we can use to inspect messages sent by the // gossiper to the remote peer. 
sentToPeer := make(chan lnwire.Message, 1) - remotePeer := &mockPeer{remoteKey, sentToPeer, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + } // Since we first wait to be notified of the peer before attempting to // send the message, we'll overwrite NotifyWhenOnline and @@ -3367,7 +3389,9 @@ func TestPropagateChanPolicyUpdate(t *testing.T) { remoteKey := remoteKeyPriv1.PubKey() sentMsgs := make(chan lnwire.Message, 10) - remotePeer := &mockPeer{remoteKey, sentMsgs, ctx.gossiper.quit} + remotePeer := &mockPeer{ + remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + } // The forced code path for sending the private ChannelUpdate to the // remote peer will be hit, forcing it to request a notification that @@ -3715,7 +3739,9 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { t.Helper() - nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + nodePeer := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } var errChan chan error if isRemote { errChan = ctx.gossiper.ProcessRemoteAnnouncement( @@ -3791,7 +3817,9 @@ func TestRateLimitChannelUpdates(t *testing.T) { batch, err := createRemoteAnnouncements(blockHeight) require.NoError(t, err) - nodePeer1 := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil} + nodePeer1 := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } select { case err := <-ctx.gossiper.ProcessRemoteAnnouncement( batch.chanAnn, nodePeer1, @@ -3810,7 +3838,9 @@ func TestRateLimitChannelUpdates(t *testing.T) { t.Fatal("remote announcement not processed") } - nodePeer2 := &mockPeer{remoteKeyPriv2.PubKey(), nil, nil} + nodePeer2 := &mockPeer{ + remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{}, + } select { case err := <-ctx.gossiper.ProcessRemoteAnnouncement( batch.chanUpdAnn2, nodePeer2, @@ -3929,7 +3959,7 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, nil, nil} + remotePeer := &mockPeer{remoteKey, nil, nil, atomic.Bool{}} // Try to let the remote peer tell us about the channel we are part of. select { @@ -4075,7 +4105,7 @@ func TestRejectCacheChannelAnn(t *testing.T) { remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") - remotePeer := &mockPeer{remoteKey, nil, nil} + remotePeer := &mockPeer{remoteKey, nil, nil, atomic.Bool{}} // Before sending over the announcement, we'll modify it such that we // know it will always fail. diff --git a/discovery/mock_test.go b/discovery/mock_test.go index 4f5c5d4e4..f34a62514 100644 --- a/discovery/mock_test.go +++ b/discovery/mock_test.go @@ -4,6 +4,7 @@ import ( "errors" "net" "sync" + "sync/atomic" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" @@ -14,9 +15,10 @@ import ( // mockPeer implements the lnpeer.Peer interface and is used to test the // gossiper's interaction with peers. type mockPeer struct { - pk *btcec.PublicKey - sentMsgs chan lnwire.Message - quit chan struct{} + pk *btcec.PublicKey + sentMsgs chan lnwire.Message + quit chan struct{} + disconnected atomic.Bool } var _ lnpeer.Peer = (*mockPeer)(nil) @@ -74,6 +76,10 @@ func (p *mockPeer) RemovePendingChannel(_ lnwire.ChannelID) error { return nil } +func (p *mockPeer) Disconnect(err error) { + p.disconnected.Store(true) +} + // mockMessageStore is an in-memory implementation of the MessageStore interface // used for the gossiper's unit tests. 
type mockMessageStore struct { diff --git a/discovery/reliable_sender_test.go b/discovery/reliable_sender_test.go index d1e69b11f..19fdaa1ca 100644 --- a/discovery/reliable_sender_test.go +++ b/discovery/reliable_sender_test.go @@ -2,6 +2,7 @@ package discovery import ( "fmt" + "sync/atomic" "testing" "time" @@ -74,7 +75,7 @@ func TestReliableSenderFlow(t *testing.T) { // Create a mock peer to send the messages to. pubKey := randPubKey(t) msgsSent := make(chan lnwire.Message) - peer := &mockPeer{pubKey, msgsSent, reliableSender.quit} + peer := &mockPeer{pubKey, msgsSent, reliableSender.quit, atomic.Bool{}} // Override NotifyWhenOnline and NotifyWhenOffline to provide the // notification channels so that we can control when notifications get @@ -193,7 +194,7 @@ func TestReliableSenderStaleMessages(t *testing.T) { // Create a mock peer to send the messages to. pubKey := randPubKey(t) msgsSent := make(chan lnwire.Message) - peer := &mockPeer{pubKey, msgsSent, reliableSender.quit} + peer := &mockPeer{pubKey, msgsSent, reliableSender.quit, atomic.Bool{}} // Override NotifyWhenOnline to provide the notification channel so that // we can control when notifications get dispatched. diff --git a/funding/manager_test.go b/funding/manager_test.go index 9db175ec3..c4c8b4f36 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -283,6 +283,8 @@ type testNode struct { var _ lnpeer.Peer = (*testNode)(nil) +func (n *testNode) Disconnect(err error) {} + func (n *testNode) IdentityKey() *btcec.PublicKey { return n.addr.IdentityKey } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 8343d5a1c..60010d6ff 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -2072,6 +2072,8 @@ func (m *mockPeer) QuitSignal() <-chan struct{} { return m.quit } +func (m *mockPeer) Disconnect(err error) {} + var _ lnpeer.Peer = (*mockPeer)(nil) func (m *mockPeer) SendMessage(sync bool, msgs ...lnwire.Message) error { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6a9628f94..c328bc533 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -684,6 +684,8 @@ func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector { return nil } +func (s *mockServer) Disconnect(err error) {} + func (s *mockServer) Stop() error { if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { return nil diff --git a/lnpeer/mock_peer.go b/lnpeer/mock_peer.go index 908c9a0c9..7fed3e4a0 100644 --- a/lnpeer/mock_peer.go +++ b/lnpeer/mock_peer.go @@ -79,3 +79,5 @@ func (m *MockPeer) RemoteFeatures() *lnwire.FeatureVector { args := m.Called() return args.Get(0).(*lnwire.FeatureVector) } + +func (m *MockPeer) Disconnect(err error) {} diff --git a/lnpeer/peer.go b/lnpeer/peer.go index f7d2e971b..cb6bc9867 100644 --- a/lnpeer/peer.go +++ b/lnpeer/peer.go @@ -74,4 +74,7 @@ type Peer interface { // by the remote peer. This allows sub-systems that use this interface // to gate their behavior off the set of negotiated feature bits. RemoteFeatures() *lnwire.FeatureVector + + // Disconnect halts communication with the peer. 
+ Disconnect(error) } From 8e0d7774b2af51665fc2b305c5df4037671d7185 Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Wed, 14 Aug 2024 14:07:10 -0400 Subject: [PATCH 4/7] discovery: clean up scid variable usage --- discovery/gossiper.go | 50 ++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 80fd576a2..61c123d55 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -2399,8 +2399,10 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, ann *lnwire.ChannelAnnouncement, ops []batch.SchedulerOption) ([]networkMsg, bool) { + scid := ann.ShortChannelID + log.Debugf("Processing ChannelAnnouncement: peer=%v, short_chan_id=%v", - nMsg.peer, ann.ShortChannelID.ToUint64()) + nMsg.peer, scid.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. @@ -2411,7 +2413,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, log.Errorf(err.Error()) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2423,13 +2425,12 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If this is a remote ChannelAnnouncement with an alias SCID, we'll // reject the announcement. Since the router accepts alias SCIDs, // not erroring out would be a DoS vector. - if nMsg.isRemote && d.cfg.IsAlias(ann.ShortChannelID) { - err := fmt.Errorf("ignoring remote alias channel=%v", - ann.ShortChannelID) + if nMsg.isRemote && d.cfg.IsAlias(scid) { + err := fmt.Errorf("ignoring remote alias channel=%v", scid) log.Errorf(err.Error()) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2441,11 +2442,10 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the advertised inclusionary block is beyond our knowledge of the // chain tip, then we'll ignore it for now. d.Lock() - if nMsg.isRemote && d.isPremature(ann.ShortChannelID, 0, nMsg) { + if nMsg.isRemote && d.isPremature(scid, 0, nMsg) { log.Warnf("Announcement for chan_id=(%v), is premature: "+ "advertises height %v, only height %v is known", - ann.ShortChannelID.ToUint64(), - ann.ShortChannelID.BlockHeight, d.bestHeight) + scid.ToUint64(), scid.BlockHeight, d.bestHeight) d.Unlock() nMsg.err <- nil return nil, false @@ -2454,7 +2454,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // At this point, we'll now ask the router if this is a zombie/known // edge. If so we can skip all the processing below. 
- if d.cfg.Graph.IsKnownEdge(ann.ShortChannelID) { + if d.cfg.Graph.IsKnownEdge(scid) { nMsg.err <- nil return nil, true } @@ -2468,7 +2468,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, "%v", err) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2499,7 +2499,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } edge := &models.ChannelEdgeInfo{ - ChannelID: ann.ShortChannelID.ToUint64(), + ChannelID: scid.ToUint64(), ChainHash: ann.ChainHash, NodeKey1Bytes: ann.NodeID1, NodeKey2Bytes: ann.NodeID2, @@ -2522,8 +2522,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } } - log.Debugf("Adding edge for short_chan_id: %v", - ann.ShortChannelID.ToUint64()) + log.Debugf("Adding edge for short_chan_id: %v", scid.ToUint64()) // We will add the edge to the channel router. If the nodes present in // this channel are not present in the database, a partial node will be @@ -2533,13 +2532,13 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // channel ID. We do this to ensure no other goroutine has read the // database and is now making decisions based on this DB state, before // it writes to the DB. - d.channelMtx.Lock(ann.ShortChannelID.ToUint64()) + d.channelMtx.Lock(scid.ToUint64()) err := d.cfg.Graph.AddEdge(edge, ops...) if err != nil { log.Debugf("Graph rejected edge for short_chan_id(%v): %v", - ann.ShortChannelID.ToUint64(), err) + scid.ToUint64(), err) - defer d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) + defer d.channelMtx.Unlock(scid.ToUint64()) // If the edge was rejected due to already being known, then it // may be the case that this new message has a fresh channel @@ -2550,7 +2549,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, anns, rErr := d.processRejectedEdge(ann, proof) if rErr != nil { key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) cr := &cachedReject{} @@ -2575,7 +2574,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } else { // Otherwise, this is just a regular rejected edge. key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2586,17 +2585,15 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } // If err is nil, release the lock immediately. - d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) + d.channelMtx.Unlock(scid.ToUint64()) - log.Debugf("Finish adding edge for short_chan_id: %v", - ann.ShortChannelID.ToUint64()) + log.Debugf("Finish adding edge for short_chan_id: %v", scid.ToUint64()) // If we earlier received any ChannelUpdates for this channel, we can // now process them, as the channel is added to the graph. - shortChanID := ann.ShortChannelID.ToUint64() var channelUpdates []*processedNetworkMsg - earlyChanUpdates, err := d.prematureChannelUpdates.Get(shortChanID) + earlyChanUpdates, err := d.prematureChannelUpdates.Get(scid.ToUint64()) if err == nil { // There was actually an entry in the map, so we'll accumulate // it. We don't worry about deletion, since it'll eventually @@ -2629,8 +2626,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // shuts down. 
case *lnwire.ChannelUpdate: log.Debugf("Reprocessing ChannelUpdate for "+ - "shortChanID=%v", - msg.ShortChannelID.ToUint64()) + "shortChanID=%v", scid.ToUint64()) select { case d.networkMsgs <- updMsg: @@ -2664,7 +2660,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil log.Debugf("Processed ChannelAnnouncement: peer=%v, short_chan_id=%v", - nMsg.peer, ann.ShortChannelID.ToUint64()) + nMsg.peer, scid.ToUint64()) return announcements, true } From 9380292a5a41697640c2284186c82dad6f7b004f Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Wed, 14 Aug 2024 14:08:01 -0400 Subject: [PATCH 5/7] graph: export NewErrf and ErrorCode for upcoming gossiper unit tests --- graph/builder.go | 24 ++++++++++++------------ graph/builder_test.go | 2 +- graph/errors.go | 28 ++++++++++++++-------------- graph/validation_barrier.go | 4 ++-- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/graph/builder.go b/graph/builder.go index 6523b492b..82a36eb36 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -681,7 +681,7 @@ func (b *Builder) handleNetworkUpdate(vb *ValidationBarrier, update.err <- err case IsError(err, ErrParentValidationFailed): - update.err <- newErrf(ErrIgnored, err.Error()) //nolint + update.err <- NewErrf(ErrIgnored, err.Error()) //nolint default: log.Warnf("unexpected error during validation "+ @@ -1053,7 +1053,7 @@ func (b *Builder) assertNodeAnnFreshness(node route.Vertex, "existence of node: %v", err) } if !exists { - return newErrf(ErrIgnored, "Ignoring node announcement"+ + return NewErrf(ErrIgnored, "Ignoring node announcement"+ " for node not found in channel graph (%x)", node[:]) } @@ -1063,7 +1063,7 @@ func (b *Builder) assertNodeAnnFreshness(node route.Vertex, // if not then we won't accept the new data as it would override newer // data. 
if !lastUpdate.Before(msgTimestamp) { - return newErrf(ErrOutdated, "Ignoring outdated "+ + return NewErrf(ErrOutdated, "Ignoring outdated "+ "announcement for %x", node[:]) } @@ -1193,11 +1193,11 @@ func (b *Builder) processUpdate(msg interface{}, "existence: %v", err) } if isZombie { - return newErrf(ErrIgnored, "ignoring msg for zombie "+ + return NewErrf(ErrIgnored, "ignoring msg for zombie "+ "chan_id=%v", msg.ChannelID) } if exists { - return newErrf(ErrIgnored, "ignoring msg for known "+ + return NewErrf(ErrIgnored, "ignoring msg for known "+ "chan_id=%v", msg.ChannelID) } @@ -1259,7 +1259,7 @@ func (b *Builder) processUpdate(msg interface{}, default: } - return newErrf(ErrNoFundingTransaction, "unable to "+ + return NewErrf(ErrNoFundingTransaction, "unable to "+ "locate funding tx: %v", err) } @@ -1294,7 +1294,7 @@ func (b *Builder) processUpdate(msg interface{}, return err } - return newErrf(ErrInvalidFundingOutput, "output "+ + return NewErrf(ErrInvalidFundingOutput, "output "+ "failed validation: %w", err) } @@ -1313,7 +1313,7 @@ func (b *Builder) processUpdate(msg interface{}, } } - return newErrf(ErrChannelSpent, "unable to fetch utxo "+ + return NewErrf(ErrChannelSpent, "unable to fetch utxo "+ "for chan_id=%v, chan_point=%v: %v", msg.ChannelID, fundingPoint, err) } @@ -1378,7 +1378,7 @@ func (b *Builder) processUpdate(msg interface{}, b.cfg.ChannelPruneExpiry if isZombie && isStaleUpdate { - return newErrf(ErrIgnored, "ignoring stale update "+ + return NewErrf(ErrIgnored, "ignoring stale update "+ "(flags=%v|%v) for zombie chan_id=%v", msg.MessageFlags, msg.ChannelFlags, msg.ChannelID) @@ -1387,7 +1387,7 @@ func (b *Builder) processUpdate(msg interface{}, // If the channel doesn't exist in our database, we cannot // apply the updated policy. if !exists { - return newErrf(ErrIgnored, "ignoring update "+ + return NewErrf(ErrIgnored, "ignoring update "+ "(flags=%v|%v) for unknown chan_id=%v", msg.MessageFlags, msg.ChannelFlags, msg.ChannelID) @@ -1405,7 +1405,7 @@ func (b *Builder) processUpdate(msg interface{}, // Ignore outdated message. if !edge1Timestamp.Before(msg.LastUpdate) { - return newErrf(ErrOutdated, "Ignoring "+ + return NewErrf(ErrOutdated, "Ignoring "+ "outdated update (flags=%v|%v) for "+ "known chan_id=%v", msg.MessageFlags, msg.ChannelFlags, msg.ChannelID) @@ -1417,7 +1417,7 @@ func (b *Builder) processUpdate(msg interface{}, // Ignore outdated message. if !edge2Timestamp.Before(msg.LastUpdate) { - return newErrf(ErrOutdated, "Ignoring "+ + return NewErrf(ErrOutdated, "Ignoring "+ "outdated update (flags=%v|%v) for "+ "known chan_id=%v", msg.MessageFlags, msg.ChannelFlags, msg.ChannelID) diff --git a/graph/builder_test.go b/graph/builder_test.go index 600bd8634..f6c5dcf9c 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -1275,7 +1275,7 @@ func newChannelEdgeInfo(t *testing.T, ctx *testCtx, fundingHeight uint32, } func assertChanChainRejection(t *testing.T, ctx *testCtx, - edge *models.ChannelEdgeInfo, failCode errorCode) { + edge *models.ChannelEdgeInfo, failCode ErrorCode) { t.Helper() diff --git a/graph/errors.go b/graph/errors.go index c0d6b8904..0a1d6fd24 100644 --- a/graph/errors.go +++ b/graph/errors.go @@ -2,14 +2,14 @@ package graph import "github.com/go-errors/errors" -// errorCode is used to represent the various errors that can occur within this +// ErrorCode is used to represent the various errors that can occur within this // package. 
-type errorCode uint8 +type ErrorCode uint8 const ( // ErrOutdated is returned when the routing update already have // been applied, or a newer update is already known. - ErrOutdated errorCode = iota + ErrOutdated ErrorCode = iota // ErrIgnored is returned when the update have been ignored because // this update can't bring us something new, or because a node @@ -39,27 +39,27 @@ const ( ErrParentValidationFailed ) -// graphError is a structure that represent the error inside the graph package, +// Error is a structure that represent the error inside the graph package, // this structure carries additional information about error code in order to // be able distinguish errors outside of the current package. -type graphError struct { +type Error struct { err *errors.Error - code errorCode + code ErrorCode } // Error represents errors as the string // NOTE: Part of the error interface. -func (e *graphError) Error() string { +func (e *Error) Error() string { return e.err.Error() } -// A compile time check to ensure graphError implements the error interface. -var _ error = (*graphError)(nil) +// A compile time check to ensure Error implements the error interface. +var _ error = (*Error)(nil) -// newErrf creates a graphError by the given error formatted description and +// NewErrf creates a Error by the given error formatted description and // its corresponding error code. -func newErrf(code errorCode, format string, a ...interface{}) *graphError { - return &graphError{ +func NewErrf(code ErrorCode, format string, a ...interface{}) *Error { + return &Error{ code: code, err: errors.Errorf(format, a...), } @@ -67,8 +67,8 @@ func newErrf(code errorCode, format string, a ...interface{}) *graphError { // IsError is a helper function which is needed to have ability to check that // returned error has specific error code. -func IsError(e interface{}, codes ...errorCode) bool { - err, ok := e.(*graphError) +func IsError(e interface{}, codes ...ErrorCode) bool { + err, ok := e.(*Error) if !ok { return false } diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index 2f3c8c02c..731852d75 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -238,12 +238,12 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { // is closed, or the set of jobs exits. select { case <-v.quit: - return newErrf(ErrVBarrierShuttingDown, + return NewErrf(ErrVBarrierShuttingDown, "validation barrier shutting down") case <-signals.deny: log.Debugf("Signal deny for %s", jobDesc) - return newErrf(ErrParentValidationFailed, + return NewErrf(ErrParentValidationFailed, "parent validation failed") case <-signals.allow: From 013452cff0788289aae3aa296242c698c9beff9d Mon Sep 17 00:00:00 2001 From: Eugene Siegel Date: Fri, 16 Aug 2024 14:16:20 -0400 Subject: [PATCH 6/7] discovery: implement ChannelAnnouncement banning This commit hooks up the banman to the gossiper: - peers that are banned and don't have a channel with us will get disconnected until they are unbanned. - peers that are banned and have a channel with us won't get disconnected, but we will ignore their channel announcements until they are no longer banned. Note that this only disables gossip of announcements to us and still allows us to open channels to them. 
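A minimal, self-contained sketch of the flow described above: bump the sender's ban score when a channel announcement fails validation (bad funding output, spent, or already-closed channel), then disconnect only peers that are both banned and channel-less. The peer struct, map-based score store, and handleInvalidChanAnn helper here are simplified stand-ins for illustration only; the banman, ShouldDisconnect, and handleChanAnnouncement changes in the diff below are the actual implementation.

package main

import (
	"errors"
	"fmt"
)

// errPeerBanned mirrors discovery.ErrPeerBanned from this series.
var errPeerBanned = errors.New("peer has bypassed ban threshold - banning")

// banThreshold matches the constant introduced in discovery/ban.go.
const banThreshold = 100

// peer is a stripped-down stand-in for lnpeer.Peer with the Disconnect
// method added earlier in this series.
type peer struct {
	pubkey       [33]byte
	isChanPeer   bool
	disconnected bool
}

func (p *peer) Disconnect(err error) { p.disconnected = true }

// banScores stands in for the banman's LRU cache keyed by pubkey.
var banScores = make(map[[33]byte]uint64)

// handleInvalidChanAnn sketches what the gossiper does when a remote
// ChannelAnnouncement fails validation: increment the ban score, then
// disconnect only if the sender is banned AND has no channel with us.
func handleInvalidChanAnn(p *peer) {
	banScores[p.pubkey]++

	banned := banScores[p.pubkey] >= banThreshold
	if banned && !p.isChanPeer {
		// Non-channel peers are disconnected outright.
		p.Disconnect(errPeerBanned)
		return
	}
	// Channel peers stay connected; their channel announcements are
	// simply ignored until the ban expires.
}

func main() {
	nonChanPeer := &peer{pubkey: [33]byte{1}}
	chanPeer := &peer{pubkey: [33]byte{2}, isChanPeer: true}

	for i := 0; i < banThreshold; i++ {
		handleInvalidChanAnn(nonChanPeer)
		handleInvalidChanAnn(chanPeer)
	}

	fmt.Println(nonChanPeer.disconnected) // true
	fmt.Println(chanPeer.disconnected)    // false
}

The two unit tests added at the end of this series (TestChanAnnBanningNonChanPeer and TestChanAnnBanningChanPeer) exercise exactly this asymmetry against the real gossiper.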
--- discovery/gossiper.go | 168 ++++++++++++++++++++++++++++- discovery/gossiper_test.go | 212 +++++++++++++++++++++++++++++++------ discovery/mock_test.go | 37 +++++++ discovery/sync_manager.go | 18 +++- discovery/syncer.go | 42 ++++---- discovery/syncer_test.go | 6 +- server.go | 57 +++++++++- 7 files changed, 480 insertions(+), 60 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 61c123d55..84fae767f 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -256,6 +256,11 @@ type Config struct { // here? AnnSigner lnwallet.MessageSigner + // ScidCloser is an instance of ClosedChannelTracker that helps the + // gossiper cut down on spam channel announcements for already closed + // channels. + ScidCloser ClosedChannelTracker + // NumActiveSyncers is the number of peers for which we should have // active syncers with. After reaching NumActiveSyncers, any future // gossip syncers will be passive. @@ -434,6 +439,9 @@ type AuthenticatedGossiper struct { // ChannelAnnouncement for the channel is received. prematureChannelUpdates *lru.Cache[uint64, *cachedNetworkMsg] + // banman tracks our peer's ban status. + banman *banman + // networkMsgs is a channel that carries new network broadcasted // message from outside the gossiper service to be processed by the // networkHandler. @@ -512,6 +520,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper maxRejectedUpdates, ), chanUpdateRateLimiter: make(map[uint64][2]*rate.Limiter), + banman: newBanman(), } gossiper.syncMgr = newSyncManager(&SyncManagerCfg{ @@ -606,6 +615,8 @@ func (d *AuthenticatedGossiper) start() error { d.syncMgr.Start() + d.banman.start() + // Start receiving blocks in its dedicated goroutine. d.wg.Add(2) go d.syncBlockHeight() @@ -762,6 +773,8 @@ func (d *AuthenticatedGossiper) stop() { d.syncMgr.Stop() + d.banman.stop() + close(d.quit) d.wg.Wait() @@ -2459,6 +2472,51 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, true } + // Check if the channel is already closed in which case we can ignore + // it. + closed, err := d.cfg.ScidCloser.IsClosedScid(scid) + if err != nil { + log.Errorf("failed to check if scid %v is closed: %v", scid, + err) + nMsg.err <- err + + return nil, false + } + + if closed { + err = fmt.Errorf("ignoring closed channel %v", scid) + log.Error(err) + + // If this is an announcement from us, we'll just ignore it. + if !nMsg.isRemote { + nMsg.err <- err + return nil, false + } + + // Increment the peer's ban score if they are sending closed + // channel announcements. + d.banman.incrementBanScore(nMsg.peer.PubKey()) + + // If the peer is banned and not a channel peer, we'll + // disconnect them. + shouldDc, dcErr := d.ShouldDisconnect(nMsg.peer.IdentityKey()) + if dcErr != nil { + log.Errorf("failed to check if we should disconnect "+ + "peer: %v", dcErr) + nMsg.err <- dcErr + + return nil, false + } + + if shouldDc { + nMsg.peer.Disconnect(ErrPeerBanned) + } + + nMsg.err <- err + + return nil, false + } + // If this is a remote channel announcement, then we'll validate all // the signatures within the proof as it should be well formed. var proof *models.ChannelAuthProof @@ -2533,7 +2591,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // database and is now making decisions based on this DB state, before // it writes to the DB. d.channelMtx.Lock(scid.ToUint64()) - err := d.cfg.Graph.AddEdge(edge, ops...) + err = d.cfg.Graph.AddEdge(edge, ops...) 
if err != nil { log.Debugf("Graph rejected edge for short_chan_id(%v): %v", scid.ToUint64(), err) @@ -2543,7 +2601,8 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the edge was rejected due to already being known, then it // may be the case that this new message has a fresh channel // proof, so we'll check. - if graph.IsError(err, graph.ErrIgnored) { + switch { + case graph.IsError(err, graph.ErrIgnored): // Attempt to process the rejected message to see if we // get any new announcements. anns, rErr := d.processRejectedEdge(ann, proof) @@ -2571,7 +2630,55 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil return anns, true - } else { + + case graph.IsError( + err, graph.ErrNoFundingTransaction, + graph.ErrInvalidFundingOutput, + ): + key := newRejectCacheKey( + scid.ToUint64(), + sourceToPub(nMsg.source), + ) + _, _ = d.recentRejects.Put(key, &cachedReject{}) + + // Increment the peer's ban score. We check isRemote + // so we don't actually ban the peer in case of a local + // bug. + if nMsg.isRemote { + d.banman.incrementBanScore(nMsg.peer.PubKey()) + } + + case graph.IsError(err, graph.ErrChannelSpent): + key := newRejectCacheKey( + scid.ToUint64(), + sourceToPub(nMsg.source), + ) + _, _ = d.recentRejects.Put(key, &cachedReject{}) + + // Since this channel has already been closed, we'll + // add it to the graph's closed channel index such that + // we won't attempt to do expensive validation checks + // on it again. + // TODO: Populate the ScidCloser by using closed + // channel notifications. + dbErr := d.cfg.ScidCloser.PutClosedScid(scid) + if dbErr != nil { + log.Errorf("failed to mark scid(%v) as "+ + "closed: %v", scid, dbErr) + + nMsg.err <- dbErr + + return nil, false + } + + // Increment the peer's ban score. We check isRemote + // so we don't accidentally ban ourselves in case of a + // bug. + if nMsg.isRemote { + d.banman.incrementBanScore(nMsg.peer.PubKey()) + } + + default: // Otherwise, this is just a regular rejected edge. key := newRejectCacheKey( scid.ToUint64(), @@ -2580,7 +2687,29 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, _, _ = d.recentRejects.Put(key, &cachedReject{}) } + if !nMsg.isRemote { + log.Errorf("failed to add edge for local channel: %v", + err) + nMsg.err <- err + + return nil, false + } + + shouldDc, dcErr := d.ShouldDisconnect(nMsg.peer.IdentityKey()) + if dcErr != nil { + log.Errorf("failed to check if we should disconnect "+ + "peer: %v", dcErr) + nMsg.err <- dcErr + + return nil, false + } + + if shouldDc { + nMsg.peer.Disconnect(ErrPeerBanned) + } + nMsg.err <- err + return nil, false } @@ -3385,3 +3514,36 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, nMsg.err <- nil return announcements, true } + +// isBanned returns true if the peer identified by pubkey is banned for sending +// invalid channel announcements. +func (d *AuthenticatedGossiper) isBanned(pubkey [33]byte) bool { + return d.banman.isBanned(pubkey) +} + +// ShouldDisconnect returns true if we should disconnect the peer identified by +// pubkey. +func (d *AuthenticatedGossiper) ShouldDisconnect(pubkey *btcec.PublicKey) ( + bool, error) { + + pubkeySer := pubkey.SerializeCompressed() + + var pubkeyBytes [33]byte + copy(pubkeyBytes[:], pubkeySer) + + // If the public key is banned, check whether or not this is a channel + // peer. 
+ if d.isBanned(pubkeyBytes) { + isChanPeer, err := d.cfg.ScidCloser.IsChannelPeer(pubkey) + if err != nil { + return false, err + } + + // We should only disconnect non-channel peers. + if !isChanPeer { + return true, nil + } + } + + return false, nil +} diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 5bc84d390..c7cb149cf 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -90,12 +91,13 @@ func makeTestDB(t *testing.T) (*channeldb.DB, error) { type mockGraphSource struct { bestHeight uint32 - mu sync.Mutex - nodes []channeldb.LightningNode - infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]models.ChannelEdgePolicy - zombies map[uint64][][33]byte - chansToReject map[uint64]struct{} + mu sync.Mutex + nodes []channeldb.LightningNode + infos map[uint64]models.ChannelEdgeInfo + edges map[uint64][]models.ChannelEdgePolicy + zombies map[uint64][][33]byte + chansToReject map[uint64]struct{} + addEdgeErrCode fn.Option[graph.ErrorCode] } func newMockRouter(height uint32) *mockGraphSource { @@ -126,6 +128,12 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo, r.mu.Lock() defer r.mu.Unlock() + if r.addEdgeErrCode.IsSome() { + return graph.NewErrf( + r.addEdgeErrCode.UnsafeFromSome(), "received error", + ) + } + if _, ok := r.infos[info.ChannelID]; ok { return errors.New("info already exist") } @@ -138,6 +146,14 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo, return nil } +func (r *mockGraphSource) resetAddEdgeErrCode() { + r.addEdgeErrCode = fn.None[graph.ErrorCode]() +} + +func (r *mockGraphSource) setAddEdgeErrCode(code graph.ErrorCode) { + r.addEdgeErrCode = fn.Some[graph.ErrorCode](code) +} + func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.mu.Lock() defer r.mu.Unlock() @@ -707,7 +723,9 @@ type testCtx struct { broadcastedMessage chan msgWithSenders } -func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { +func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( + *testCtx, error) { + // Next we'll initialize an instance of the channel router with mock // versions of the chain and channel notifier. 
As we don't need to test // any p2p functionality, the peer send and switch send, @@ -803,6 +821,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { FindBaseByAlias: findBaseByAlias, GetAlias: getAlias, FindChannel: mockFindChannel, + ScidCloser: newMockScidCloser(isChanPeer), }, selfKeyDesc) if err := gossiper.Start(); err != nil { @@ -831,7 +850,7 @@ func TestProcessAnnouncement(t *testing.T) { t.Parallel() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") assertSenderExistence := func(sender *btcec.PublicKey, msg msgWithSenders) { @@ -947,7 +966,7 @@ func TestPrematureAnnouncement(t *testing.T) { timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") _, err = createNodeAnnouncement(remoteKeyPriv1, timestamp) @@ -978,7 +997,7 @@ func TestPrematureAnnouncement(t *testing.T) { func TestSignatureAnnouncementLocalFirst(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent @@ -1154,7 +1173,7 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { func TestOrphanSignatureAnnouncement(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent @@ -1341,7 +1360,7 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -1576,7 +1595,7 @@ out: func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2016,7 +2035,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { timestamp = 123456 ) - ctx, err := createTestCtx(t, startingHeight) + ctx, err := createTestCtx(t, startingHeight, false) require.NoError(t, err, "can't create context") // We'll start off by processing a channel announcement without a proof @@ -2115,7 +2134,7 @@ func TestRejectZombieEdge(t *testing.T) { // We'll start by creating our test context with a batch of // announcements. - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") batch, err := createRemoteAnnouncements(0) @@ -2216,7 +2235,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // We'll start by creating our test context with a batch of // announcements. 
- ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") batch, err := createRemoteAnnouncements(0) @@ -2373,7 +2392,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2572,7 +2591,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { func TestExtraDataChannelAnnouncementValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2605,7 +2624,7 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { t.Parallel() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2658,7 +2677,7 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { func TestExtraDataNodeAnnouncementValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2728,7 +2747,7 @@ func assertProcessAnnouncement(t *testing.T, result chan error) { func TestRetransmit(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2834,7 +2853,7 @@ func TestRetransmit(t *testing.T) { func TestNodeAnnouncementNoChannels(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") batch, err := createRemoteAnnouncements(0) @@ -2919,7 +2938,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") processRemoteAnnouncement := ctx.gossiper.ProcessRemoteAnnouncement @@ -3018,7 +3037,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // We'll start by creating our test context and a batch of // announcements. - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "unable to create test context") batch, err := createLocalAnnouncements(0) @@ -3372,7 +3391,7 @@ func TestPropagateChanPolicyUpdate(t *testing.T) { // First, we'll make out test context and add 3 random channels to the // graph. startingHeight := uint32(10) - ctx, err := createTestCtx(t, startingHeight) + ctx, err := createTestCtx(t, startingHeight, false) require.NoError(t, err, "unable to create test context") const numChannels = 3 @@ -3553,7 +3572,7 @@ func TestProcessChannelAnnouncementOptionalMsgFields(t *testing.T) { // We'll start by creating our test context and a set of test channel // announcements. 
-	ctx, err := createTestCtx(t, 0)
+	ctx, err := createTestCtx(t, 0, false)
 	require.NoError(t, err, "unable to create test context")
 
 	chanAnn1 := createAnnouncementWithoutProof(
@@ -3614,7 +3633,7 @@ func assertMessage(t *testing.T, expected, got lnwire.Message) {
 func TestSplitAnnouncementsCorrectSubBatches(t *testing.T) {
 	// Create our test harness.
 	const blockHeight = 100
-	ctx, err := createTestCtx(t, blockHeight)
+	ctx, err := createTestCtx(t, blockHeight, false)
 	require.NoError(t, err, "can't create context")
 
 	const subBatchSize = 10
@@ -3726,7 +3745,7 @@ func (m *SyncManager) markGraphSyncing() {
 func TestBroadcastAnnsAfterGraphSynced(t *testing.T) {
 	t.Parallel()
 
-	ctx, err := createTestCtx(t, 10)
+	ctx, err := createTestCtx(t, 10, false)
 	require.NoError(t, err, "can't create context")
 
 	// We'll mark the graph as not synced. This should prevent us from
@@ -3801,7 +3820,7 @@ func TestRateLimitChannelUpdates(t *testing.T) {
 	// Create our test harness.
 	const blockHeight = 100
-	ctx, err := createTestCtx(t, blockHeight)
+	ctx, err := createTestCtx(t, blockHeight, false)
 	require.NoError(t, err, "can't create context")
 	ctx.gossiper.cfg.RebroadcastInterval = time.Hour
 	ctx.gossiper.cfg.MaxChannelUpdateBurst = 5
@@ -3951,7 +3970,7 @@ func TestRateLimitChannelUpdates(t *testing.T) {
 func TestIgnoreOwnAnnouncement(t *testing.T) {
 	t.Parallel()
 
-	ctx, err := createTestCtx(t, proofMatureDelta)
+	ctx, err := createTestCtx(t, proofMatureDelta, false)
 	require.NoError(t, err, "can't create context")
 
 	batch, err := createLocalAnnouncements(0)
@@ -4095,7 +4114,7 @@ func TestIgnoreOwnAnnouncement(t *testing.T) {
 func TestRejectCacheChannelAnn(t *testing.T) {
 	t.Parallel()
 
-	ctx, err := createTestCtx(t, proofMatureDelta)
+	ctx, err := createTestCtx(t, proofMatureDelta, false)
 	require.NoError(t, err, "can't create context")
 
 	// First, we create a channel announcement to send over to our test
@@ -4169,3 +4188,134 @@ func TestFutureMsgCacheEviction(t *testing.T) {
 	require.NoError(t, err)
 	require.EqualValues(t, 2, item.height, "should be the second item")
 }
+
+// TestChanAnnBanningNonChanPeer asserts that non-channel peers who send bogus
+// channel announcements are banned properly.
+func TestChanAnnBanningNonChanPeer(t *testing.T) {
+	t.Parallel()
+
+	ctx, err := createTestCtx(t, 1000, false)
+	require.NoError(t, err, "can't create context")
+
+	nodePeer1 := &mockPeer{
+		remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{},
+	}
+	nodePeer2 := &mockPeer{
+		remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{},
+	}
+
+	ctx.router.setAddEdgeErrCode(graph.ErrInvalidFundingOutput)
+
+	// Loop 100 times to get nodePeer banned.
+	for i := 0; i < 100; i++ {
+		// Craft a valid channel announcement for a channel we don't
+		// have. We will ensure that it fails validation by modifying
+		// the router.
+		ca, err := createRemoteChannelAnnouncement(uint32(i))
+		require.NoError(t, err, "can't create channel announcement")
+
+		select {
+		case err = <-ctx.gossiper.ProcessRemoteAnnouncement(
+			ca, nodePeer1,
+		):
+			require.True(
+				t, graph.IsError(
+					err, graph.ErrInvalidFundingOutput,
+				),
+			)
+
+		case <-time.After(2 * time.Second):
+			t.Fatalf("remote announcement not processed")
+		}
+	}
+
+	// The peer should be banned now.
+	require.True(t, ctx.gossiper.isBanned(nodePeer1.PubKey()))
+
+	// Assert that nodePeer has been disconnected.
+	require.True(t, nodePeer1.disconnected.Load())
+
+	ca, err := createRemoteChannelAnnouncement(101)
+	require.NoError(t, err, "can't create channel announcement")
+
+	// Set the error to ErrChannelSpent so that we can test that the
+	// gossiper ignores closed channels.
+	ctx.router.setAddEdgeErrCode(graph.ErrChannelSpent)
+
+	select {
+	case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2):
+		require.True(t, graph.IsError(err, graph.ErrChannelSpent))
+
+	case <-time.After(2 * time.Second):
+		t.Fatalf("remote announcement not processed")
+	}
+
+	// Check that the announcement's scid is marked as closed.
+	isClosed, err := ctx.gossiper.cfg.ScidCloser.IsClosedScid(
+		ca.ShortChannelID,
+	)
+	require.Nil(t, err)
+	require.True(t, isClosed)
+
+	// Remove the scid from the reject cache.
+	key := newRejectCacheKey(
+		ca.ShortChannelID.ToUint64(),
+		sourceToPub(nodePeer2.IdentityKey()),
+	)
+
+	ctx.gossiper.recentRejects.Delete(key)
+
+	// Reset the AddEdge error and pass the same announcement again. An
+	// error should be returned even though AddEdge won't fail.
+	ctx.router.resetAddEdgeErrCode()
+
+	select {
+	case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2):
+		require.NotNil(t, err)
+
+	case <-time.After(2 * time.Second):
+		t.Fatalf("remote announcement not processed")
+	}
+}
+
+// TestChanAnnBanningChanPeer asserts that channel peers that are banned don't
+// get disconnected.
+func TestChanAnnBanningChanPeer(t *testing.T) {
+	t.Parallel()
+
+	ctx, err := createTestCtx(t, 1000, true)
+	require.NoError(t, err, "can't create context")
+
+	nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}}
+
+	ctx.router.setAddEdgeErrCode(graph.ErrInvalidFundingOutput)
+
+	// Loop 100 times to get nodePeer banned.
+	for i := 0; i < 100; i++ {
+		// Craft a valid channel announcement for a channel we don't
+		// have. We will ensure that it fails validation by modifying
+		// the router.
+		ca, err := createRemoteChannelAnnouncement(uint32(i))
+		require.NoError(t, err, "can't create channel announcement")
+
+		select {
+		case err = <-ctx.gossiper.ProcessRemoteAnnouncement(
+			ca, nodePeer,
+		):
+			require.True(
+				t, graph.IsError(
+					err, graph.ErrInvalidFundingOutput,
+				),
+			)
+
+		case <-time.After(2 * time.Second):
+			t.Fatalf("remote announcement not processed")
+		}
+	}
+
+	// The peer should be banned now.
+	require.True(t, ctx.gossiper.isBanned(nodePeer.PubKey()))
+
+	// Assert that the peer wasn't disconnected.
+	require.False(t, nodePeer.disconnected.Load())
+}
diff --git a/discovery/mock_test.go b/discovery/mock_test.go
index f34a62514..6bd93c29b 100644
--- a/discovery/mock_test.go
+++ b/discovery/mock_test.go
@@ -161,3 +161,40 @@ func (s *mockMessageStore) MessagesForPeer(pubKey [33]byte) ([]lnwire.Message, e
 
 	return msgs, nil
 }
+
+type mockScidCloser struct {
+	m           map[lnwire.ShortChannelID]struct{}
+	channelPeer bool
+
+	sync.Mutex
+}
+
+func newMockScidCloser(channelPeer bool) *mockScidCloser {
+	return &mockScidCloser{
+		m:           make(map[lnwire.ShortChannelID]struct{}),
+		channelPeer: channelPeer,
+	}
+}
+
+func (m *mockScidCloser) PutClosedScid(scid lnwire.ShortChannelID) error {
+	m.Lock()
+	m.m[scid] = struct{}{}
+	m.Unlock()
+
+	return nil
+}
+
+func (m *mockScidCloser) IsClosedScid(scid lnwire.ShortChannelID) (bool,
+	error) {
+
+	m.Lock()
+	defer m.Unlock()
+
+	_, ok := m.m[scid]
+
+	return ok, nil
+}
+
+func (m *mockScidCloser) IsChannelPeer(pubkey *btcec.PublicKey) (bool, error) {
+	return m.channelPeer, nil
+}
diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go
index d3a017256..70d28784b 100644
--- a/discovery/sync_manager.go
+++ b/discovery/sync_manager.go
@@ -22,6 +22,9 @@ const (
 	// force a historical sync to ensure we have as much of the public
 	// network as possible.
 	DefaultHistoricalSyncInterval = time.Hour
+
+	// filterSemaSize is the capacity of gossipFilterSema.
+	filterSemaSize = 5
 )
 
 var (
@@ -161,12 +164,22 @@ type SyncManager struct {
 	// duration of the connection.
 	pinnedActiveSyncers map[route.Vertex]*GossipSyncer
 
+	// gossipFilterSema contains semaphores for the gossip timestamp
+	// queries.
+	gossipFilterSema chan struct{}
+
 	wg   sync.WaitGroup
 	quit chan struct{}
 }
 
 // newSyncManager constructs a new SyncManager backed by the given config.
 func newSyncManager(cfg *SyncManagerCfg) *SyncManager {
+
+	filterSema := make(chan struct{}, filterSemaSize)
+	for i := 0; i < filterSemaSize; i++ {
+		filterSema <- struct{}{}
+	}
+
 	return &SyncManager{
 		cfg:        *cfg,
 		newSyncers: make(chan *newSyncer),
@@ -178,7 +191,8 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager {
 		pinnedActiveSyncers: make(
 			map[route.Vertex]*GossipSyncer, len(cfg.PinnedSyncers),
 		),
-		quit: make(chan struct{}),
+		gossipFilterSema: filterSema,
+		quit:             make(chan struct{}),
 	}
 }
 
@@ -507,7 +521,7 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer {
 		maxQueryChanRangeReplies: maxQueryChanRangeReplies,
 		noTimestampQueryOption:   m.cfg.NoTimestampQueries,
 		isStillZombieChannel:     m.cfg.IsStillZombieChannel,
-	})
+	}, m.gossipFilterSema)
 
 	// Gossip syncers are initialized by default in a PassiveSync type
 	// and chansSynced state so that they can reply to any peer queries or
diff --git a/discovery/syncer.go b/discovery/syncer.go
index 512c9f631..b6adb447a 100644
--- a/discovery/syncer.go
+++ b/discovery/syncer.go
@@ -181,9 +181,6 @@ const (
 	// requestBatchSize is the maximum number of channels we will query the
 	// remote peer for in a QueryShortChanIDs message.
 	requestBatchSize = 500
-
-	// filterSemaSize is the capacity of gossipFilterSema.
-	filterSemaSize = 5
 )
 
 var (
@@ -400,9 +397,11 @@ type GossipSyncer struct {
 	// GossipSyncer reaches its terminal chansSynced state.
 	syncedSignal chan struct{}
 
-	sync.Mutex
+	// syncerSema is used to more finely control the syncer's ability to
+	// respond to gossip timestamp range messages.
+	syncerSema chan struct{}
 
-	gossipFilterSema chan struct{}
+	sync.Mutex
 
 	quit chan struct{}
 	wg   sync.WaitGroup
@@ -410,7 +409,7 @@ type GossipSyncer struct {
 }
 
 // newGossipSyncer returns a new instance of the GossipSyncer populated using
 // the passed config.
-func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer {
+func newGossipSyncer(cfg gossipSyncerCfg, sema chan struct{}) *GossipSyncer {
 	// If no parameter was specified for max undelayed query replies, set it
 	// to the default of 5 queries.
 	if cfg.maxUndelayedQueryReplies <= 0 {
@@ -432,11 +431,6 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer {
 		interval, cfg.maxUndelayedQueryReplies,
 	)
 
-	filterSema := make(chan struct{}, filterSemaSize)
-	for i := 0; i < filterSemaSize; i++ {
-		filterSema <- struct{}{}
-	}
-
 	return &GossipSyncer{
 		cfg:         cfg,
 		rateLimiter: rateLimiter,
@@ -444,7 +438,7 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer {
 		historicalSyncReqs: make(chan *historicalSyncReq),
 		gossipMsgs:         make(chan lnwire.Message, 100),
 		queryMsgs:          make(chan lnwire.Message, 100),
-		gossipFilterSema:   filterSema,
+		syncerSema:         sema,
 		quit:               make(chan struct{}),
 	}
 }
@@ -1332,12 +1326,25 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er
 		return nil
 	}
 
+	select {
+	case <-g.syncerSema:
+	case <-g.quit:
+		return ErrGossipSyncerExiting
+	}
+
+	// We don't return the semaphore in a defer here because, if the
+	// goroutine below is launched, the semaphore must only be returned
+	// once that goroutine exits.
+	returnSema := func() {
+		g.syncerSema <- struct{}{}
+	}
+
 	// Now that the remote peer has applied their filter, we'll query the
 	// database for all the messages that are beyond this filter.
 	newUpdatestoSend, err := g.cfg.channelSeries.UpdatesInHorizon(
 		g.cfg.chainHash, startTime, endTime,
 	)
 	if err != nil {
+		returnSema()
 		return err
 	}
 
@@ -1347,22 +1354,15 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er
 
 	// If we don't have any to send, then we can return early.
 	if len(newUpdatestoSend) == 0 {
+		returnSema()
 		return nil
 	}
 
-	select {
-	case <-g.gossipFilterSema:
-	case <-g.quit:
-		return ErrGossipSyncerExiting
-	}
-
 	// We'll conclude by launching a goroutine to send out any updates.
 	g.wg.Add(1)
 	go func() {
 		defer g.wg.Done()
-		defer func() {
-			g.gossipFilterSema <- struct{}{}
-		}()
+		defer returnSema()
 
 		for _, msg := range newUpdatestoSend {
 			err := g.cfg.sendToPeerSync(msg)
diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go
index 15e2442e1..bb6aec590 100644
--- a/discovery/syncer_test.go
+++ b/discovery/syncer_test.go
@@ -211,7 +211,11 @@ func newTestSyncer(hID lnwire.ShortChannelID,
 		markGraphSynced:          func() {},
 		maxQueryChanRangeReplies: maxQueryChanRangeReplies,
 	}
-	syncer := newGossipSyncer(cfg)
+
+	syncerSema := make(chan struct{}, 1)
+	syncerSema <- struct{}{}
+
+	syncer := newGossipSyncer(cfg, syncerSema)
 
 	return msgChan, syncer, cfg.channelSeries.(*mockChannelGraphTimeSeries)
 }
diff --git a/server.go b/server.go
index 26fdce6ea..75fd02b97 100644
--- a/server.go
+++ b/server.go
@@ -1026,6 +1026,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 		return nil, err
 	}
 
+	scidCloserMan := discovery.NewScidCloserMan(s.graphDB, s.chanStateDB)
+
 	s.authGossiper = discovery.New(discovery.Config{
 		Graph:    s.graphBuilder,
 		Notifier: s.cc.ChainNotifier,
@@ -1063,6 +1065,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 		GetAlias:             s.aliasMgr.GetPeerAlias,
 		FindChannel:          s.findChannel,
 		IsStillZombieChannel: s.graphBuilder.IsZombieChannel,
+		ScidCloser:           scidCloserMan,
 	}, nodeKeyDesc)
 
 	//nolint:lll
@@ -3639,11 +3642,34 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 	}
 
 	nodePub := conn.(*brontide.Conn).RemotePub()
-	pubStr := string(nodePub.SerializeCompressed())
+	pubSer := nodePub.SerializeCompressed()
+	pubStr := string(pubSer)
+
+	var pubBytes [33]byte
+	copy(pubBytes[:], pubSer)
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	// If the remote node's public key is banned, drop the connection.
+	shouldDc, dcErr := s.authGossiper.ShouldDisconnect(nodePub)
+	if dcErr != nil {
+		srvrLog.Errorf("Unable to check if we should disconnect "+
+			"peer: %v", dcErr)
+		conn.Close()
+
+		return
+	}
+
+	if shouldDc {
+		srvrLog.Debugf("Dropping connection for %v since they are "+
+			"banned.", pubSer)
+
+		conn.Close()
+
+		return
+	}
+
 	// If we already have an outbound connection to this peer, then ignore
 	// this new connection.
 	if p, ok := s.outboundPeers[pubStr]; ok {
@@ -3726,11 +3752,38 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 	}
 
 	nodePub := conn.(*brontide.Conn).RemotePub()
-	pubStr := string(nodePub.SerializeCompressed())
+	pubSer := nodePub.SerializeCompressed()
+	pubStr := string(pubSer)
+
+	var pubBytes [33]byte
+	copy(pubBytes[:], pubSer)
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	// If the remote node's public key is banned, drop the connection.
+	shouldDc, dcErr := s.authGossiper.ShouldDisconnect(nodePub)
+	if dcErr != nil {
+		srvrLog.Errorf("Unable to check if we should disconnect "+
+			"peer: %v", dcErr)
+		conn.Close()
+
+		return
+	}
+
+	if shouldDc {
+		srvrLog.Debugf("Dropping connection for %v since they are "+
+			"banned.", pubSer)
+
+		if connReq != nil {
+			s.connMgr.Remove(connReq.ID())
+		}
+
+		conn.Close()
+
+		return
+	}
+
 	// If we already have an inbound connection to this peer, then ignore
 	// this new connection.
 	if p, ok := s.inboundPeers[pubStr]; ok {

From 95acc780136ff1e26dce9566de8afcb76c95466a Mon Sep 17 00:00:00 2001
From: Eugene Siegel
Date: Fri, 16 Aug 2024 14:34:32 -0400
Subject: [PATCH 7/7] release-notes: update for 0.18.3

---
 docs/release-notes/release-notes-0.18.3.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/release-notes/release-notes-0.18.3.md b/docs/release-notes/release-notes-0.18.3.md
index 1e8279137..713b7d5eb 100644
--- a/docs/release-notes/release-notes-0.18.3.md
+++ b/docs/release-notes/release-notes-0.18.3.md
@@ -78,6 +78,11 @@ blinded path expiry.
 # New Features
 
 ## Functional Enhancements
+
+* LND will now [temporarily ban peers](https://github.com/lightningnetwork/lnd/pull/9009)
+that send too many invalid `ChannelAnnouncement` messages. This is only done
+for LND nodes that validate `ChannelAnnouncement` messages.
+
 ## RPC Additions
 
 * The [SendPaymentRequest](https://github.com/lightningnetwork/lnd/pull/8734)
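
A note on the gossip filter changes earlier in this series: the semaphore that
used to live in each GossipSyncer now lives in the SyncManager and is handed to
every syncer, so at most filterSemaSize (5) ApplyGossipFilter responses can
query the database concurrently across all peers, rather than 5 per peer. The
following is a minimal, self-contained sketch of that bounded-concurrency
pattern only; it is not lnd code, and the `querier` type, `handleFilter`, and
`semaSize` names are made up for illustration.

package main

import (
	"fmt"
	"sync"
	"time"
)

// semaSize mirrors the shared cap on concurrent filter responses.
const semaSize = 5

type querier struct {
	sema chan struct{} // buffered channel used as a counting semaphore
	quit chan struct{}
	wg   sync.WaitGroup
}

func newQuerier() *querier {
	// Fill the semaphore up front; each token permits one in-flight query.
	sema := make(chan struct{}, semaSize)
	for i := 0; i < semaSize; i++ {
		sema <- struct{}{}
	}

	return &querier{sema: sema, quit: make(chan struct{})}
}

// handleFilter takes a token before doing the expensive work and returns it
// from the goroutine that finishes the work, so no more than semaSize
// responses run at once.
func (q *querier) handleFilter(id int) error {
	select {
	case <-q.sema:
	case <-q.quit:
		return fmt.Errorf("exiting")
	}

	returnSema := func() { q.sema <- struct{}{} }

	q.wg.Add(1)
	go func() {
		defer q.wg.Done()
		defer returnSema()

		// Stand-in for querying the update horizon and sending the
		// matching updates to the peer.
		time.Sleep(10 * time.Millisecond)
		fmt.Println("served filter", id)
	}()

	return nil
}

func main() {
	q := newQuerier()
	for i := 0; i < 20; i++ {
		_ = q.handleFilter(i)
	}
	q.wg.Wait()
}

Moving the semaphore up to the SyncManager is what turns a per-peer limit into
a node-wide one: a burst of gossip timestamp range messages from many peers can
no longer fan out into an unbounded number of concurrent database scans.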