diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 61c123d55..84fae767f 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -256,6 +256,11 @@ type Config struct { // here? AnnSigner lnwallet.MessageSigner + // ScidCloser is an instance of ClosedChannelTracker that helps the + // gossiper cut down on spam channel announcements for already closed + // channels. + ScidCloser ClosedChannelTracker + // NumActiveSyncers is the number of peers for which we should have // active syncers with. After reaching NumActiveSyncers, any future // gossip syncers will be passive. @@ -434,6 +439,9 @@ type AuthenticatedGossiper struct { // ChannelAnnouncement for the channel is received. prematureChannelUpdates *lru.Cache[uint64, *cachedNetworkMsg] + // banman tracks our peer's ban status. + banman *banman + // networkMsgs is a channel that carries new network broadcasted // message from outside the gossiper service to be processed by the // networkHandler. @@ -512,6 +520,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper maxRejectedUpdates, ), chanUpdateRateLimiter: make(map[uint64][2]*rate.Limiter), + banman: newBanman(), } gossiper.syncMgr = newSyncManager(&SyncManagerCfg{ @@ -606,6 +615,8 @@ func (d *AuthenticatedGossiper) start() error { d.syncMgr.Start() + d.banman.start() + // Start receiving blocks in its dedicated goroutine. d.wg.Add(2) go d.syncBlockHeight() @@ -762,6 +773,8 @@ func (d *AuthenticatedGossiper) stop() { d.syncMgr.Stop() + d.banman.stop() + close(d.quit) d.wg.Wait() @@ -2459,6 +2472,51 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, true } + // Check if the channel is already closed in which case we can ignore + // it. + closed, err := d.cfg.ScidCloser.IsClosedScid(scid) + if err != nil { + log.Errorf("failed to check if scid %v is closed: %v", scid, + err) + nMsg.err <- err + + return nil, false + } + + if closed { + err = fmt.Errorf("ignoring closed channel %v", scid) + log.Error(err) + + // If this is an announcement from us, we'll just ignore it. + if !nMsg.isRemote { + nMsg.err <- err + return nil, false + } + + // Increment the peer's ban score if they are sending closed + // channel announcements. + d.banman.incrementBanScore(nMsg.peer.PubKey()) + + // If the peer is banned and not a channel peer, we'll + // disconnect them. + shouldDc, dcErr := d.ShouldDisconnect(nMsg.peer.IdentityKey()) + if dcErr != nil { + log.Errorf("failed to check if we should disconnect "+ + "peer: %v", dcErr) + nMsg.err <- dcErr + + return nil, false + } + + if shouldDc { + nMsg.peer.Disconnect(ErrPeerBanned) + } + + nMsg.err <- err + + return nil, false + } + // If this is a remote channel announcement, then we'll validate all // the signatures within the proof as it should be well formed. var proof *models.ChannelAuthProof @@ -2533,7 +2591,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // database and is now making decisions based on this DB state, before // it writes to the DB. d.channelMtx.Lock(scid.ToUint64()) - err := d.cfg.Graph.AddEdge(edge, ops...) + err = d.cfg.Graph.AddEdge(edge, ops...) if err != nil { log.Debugf("Graph rejected edge for short_chan_id(%v): %v", scid.ToUint64(), err) @@ -2543,7 +2601,8 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the edge was rejected due to already being known, then it // may be the case that this new message has a fresh channel // proof, so we'll check. - if graph.IsError(err, graph.ErrIgnored) { + switch { + case graph.IsError(err, graph.ErrIgnored): // Attempt to process the rejected message to see if we // get any new announcements. anns, rErr := d.processRejectedEdge(ann, proof) @@ -2571,7 +2630,55 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil return anns, true - } else { + + case graph.IsError( + err, graph.ErrNoFundingTransaction, + graph.ErrInvalidFundingOutput, + ): + key := newRejectCacheKey( + scid.ToUint64(), + sourceToPub(nMsg.source), + ) + _, _ = d.recentRejects.Put(key, &cachedReject{}) + + // Increment the peer's ban score. We check isRemote + // so we don't actually ban the peer in case of a local + // bug. + if nMsg.isRemote { + d.banman.incrementBanScore(nMsg.peer.PubKey()) + } + + case graph.IsError(err, graph.ErrChannelSpent): + key := newRejectCacheKey( + scid.ToUint64(), + sourceToPub(nMsg.source), + ) + _, _ = d.recentRejects.Put(key, &cachedReject{}) + + // Since this channel has already been closed, we'll + // add it to the graph's closed channel index such that + // we won't attempt to do expensive validation checks + // on it again. + // TODO: Populate the ScidCloser by using closed + // channel notifications. + dbErr := d.cfg.ScidCloser.PutClosedScid(scid) + if dbErr != nil { + log.Errorf("failed to mark scid(%v) as "+ + "closed: %v", scid, dbErr) + + nMsg.err <- dbErr + + return nil, false + } + + // Increment the peer's ban score. We check isRemote + // so we don't accidentally ban ourselves in case of a + // bug. + if nMsg.isRemote { + d.banman.incrementBanScore(nMsg.peer.PubKey()) + } + + default: // Otherwise, this is just a regular rejected edge. key := newRejectCacheKey( scid.ToUint64(), @@ -2580,7 +2687,29 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, _, _ = d.recentRejects.Put(key, &cachedReject{}) } + if !nMsg.isRemote { + log.Errorf("failed to add edge for local channel: %v", + err) + nMsg.err <- err + + return nil, false + } + + shouldDc, dcErr := d.ShouldDisconnect(nMsg.peer.IdentityKey()) + if dcErr != nil { + log.Errorf("failed to check if we should disconnect "+ + "peer: %v", dcErr) + nMsg.err <- dcErr + + return nil, false + } + + if shouldDc { + nMsg.peer.Disconnect(ErrPeerBanned) + } + nMsg.err <- err + return nil, false } @@ -3385,3 +3514,36 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, nMsg.err <- nil return announcements, true } + +// isBanned returns true if the peer identified by pubkey is banned for sending +// invalid channel announcements. +func (d *AuthenticatedGossiper) isBanned(pubkey [33]byte) bool { + return d.banman.isBanned(pubkey) +} + +// ShouldDisconnect returns true if we should disconnect the peer identified by +// pubkey. +func (d *AuthenticatedGossiper) ShouldDisconnect(pubkey *btcec.PublicKey) ( + bool, error) { + + pubkeySer := pubkey.SerializeCompressed() + + var pubkeyBytes [33]byte + copy(pubkeyBytes[:], pubkeySer) + + // If the public key is banned, check whether or not this is a channel + // peer. + if d.isBanned(pubkeyBytes) { + isChanPeer, err := d.cfg.ScidCloser.IsChannelPeer(pubkey) + if err != nil { + return false, err + } + + // We should only disconnect non-channel peers. + if !isChanPeer { + return true, nil + } + } + + return false, nil +} diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 5bc84d390..c7cb149cf 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -90,12 +91,13 @@ func makeTestDB(t *testing.T) (*channeldb.DB, error) { type mockGraphSource struct { bestHeight uint32 - mu sync.Mutex - nodes []channeldb.LightningNode - infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]models.ChannelEdgePolicy - zombies map[uint64][][33]byte - chansToReject map[uint64]struct{} + mu sync.Mutex + nodes []channeldb.LightningNode + infos map[uint64]models.ChannelEdgeInfo + edges map[uint64][]models.ChannelEdgePolicy + zombies map[uint64][][33]byte + chansToReject map[uint64]struct{} + addEdgeErrCode fn.Option[graph.ErrorCode] } func newMockRouter(height uint32) *mockGraphSource { @@ -126,6 +128,12 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo, r.mu.Lock() defer r.mu.Unlock() + if r.addEdgeErrCode.IsSome() { + return graph.NewErrf( + r.addEdgeErrCode.UnsafeFromSome(), "received error", + ) + } + if _, ok := r.infos[info.ChannelID]; ok { return errors.New("info already exist") } @@ -138,6 +146,14 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo, return nil } +func (r *mockGraphSource) resetAddEdgeErrCode() { + r.addEdgeErrCode = fn.None[graph.ErrorCode]() +} + +func (r *mockGraphSource) setAddEdgeErrCode(code graph.ErrorCode) { + r.addEdgeErrCode = fn.Some[graph.ErrorCode](code) +} + func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.mu.Lock() defer r.mu.Unlock() @@ -707,7 +723,9 @@ type testCtx struct { broadcastedMessage chan msgWithSenders } -func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { +func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( + *testCtx, error) { + // Next we'll initialize an instance of the channel router with mock // versions of the chain and channel notifier. As we don't need to test // any p2p functionality, the peer send and switch send, @@ -803,6 +821,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { FindBaseByAlias: findBaseByAlias, GetAlias: getAlias, FindChannel: mockFindChannel, + ScidCloser: newMockScidCloser(isChanPeer), }, selfKeyDesc) if err := gossiper.Start(); err != nil { @@ -831,7 +850,7 @@ func TestProcessAnnouncement(t *testing.T) { t.Parallel() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") assertSenderExistence := func(sender *btcec.PublicKey, msg msgWithSenders) { @@ -947,7 +966,7 @@ func TestPrematureAnnouncement(t *testing.T) { timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") _, err = createNodeAnnouncement(remoteKeyPriv1, timestamp) @@ -978,7 +997,7 @@ func TestPrematureAnnouncement(t *testing.T) { func TestSignatureAnnouncementLocalFirst(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent @@ -1154,7 +1173,7 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { func TestOrphanSignatureAnnouncement(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent @@ -1341,7 +1360,7 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -1576,7 +1595,7 @@ out: func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2016,7 +2035,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { timestamp = 123456 ) - ctx, err := createTestCtx(t, startingHeight) + ctx, err := createTestCtx(t, startingHeight, false) require.NoError(t, err, "can't create context") // We'll start off by processing a channel announcement without a proof @@ -2115,7 +2134,7 @@ func TestRejectZombieEdge(t *testing.T) { // We'll start by creating our test context with a batch of // announcements. - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") batch, err := createRemoteAnnouncements(0) @@ -2216,7 +2235,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // We'll start by creating our test context with a batch of // announcements. - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") batch, err := createRemoteAnnouncements(0) @@ -2373,7 +2392,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2572,7 +2591,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { func TestExtraDataChannelAnnouncementValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2605,7 +2624,7 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { t.Parallel() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2658,7 +2677,7 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { func TestExtraDataNodeAnnouncementValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2728,7 +2747,7 @@ func assertProcessAnnouncement(t *testing.T, result chan error) { func TestRetransmit(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -2834,7 +2853,7 @@ func TestRetransmit(t *testing.T) { func TestNodeAnnouncementNoChannels(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") batch, err := createRemoteAnnouncements(0) @@ -2919,7 +2938,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") processRemoteAnnouncement := ctx.gossiper.ProcessRemoteAnnouncement @@ -3018,7 +3037,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // We'll start by creating our test context and a batch of // announcements. - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "unable to create test context") batch, err := createLocalAnnouncements(0) @@ -3372,7 +3391,7 @@ func TestPropagateChanPolicyUpdate(t *testing.T) { // First, we'll make out test context and add 3 random channels to the // graph. startingHeight := uint32(10) - ctx, err := createTestCtx(t, startingHeight) + ctx, err := createTestCtx(t, startingHeight, false) require.NoError(t, err, "unable to create test context") const numChannels = 3 @@ -3553,7 +3572,7 @@ func TestProcessChannelAnnouncementOptionalMsgFields(t *testing.T) { // We'll start by creating our test context and a set of test channel // announcements. - ctx, err := createTestCtx(t, 0) + ctx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") chanAnn1 := createAnnouncementWithoutProof( @@ -3614,7 +3633,7 @@ func assertMessage(t *testing.T, expected, got lnwire.Message) { func TestSplitAnnouncementsCorrectSubBatches(t *testing.T) { // Create our test harness. const blockHeight = 100 - ctx, err := createTestCtx(t, blockHeight) + ctx, err := createTestCtx(t, blockHeight, false) require.NoError(t, err, "can't create context") const subBatchSize = 10 @@ -3726,7 +3745,7 @@ func (m *SyncManager) markGraphSyncing() { func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, 10) + ctx, err := createTestCtx(t, 10, false) require.NoError(t, err, "can't create context") // We'll mark the graph as not synced. This should prevent us from @@ -3801,7 +3820,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { // Create our test harness. const blockHeight = 100 - ctx, err := createTestCtx(t, blockHeight) + ctx, err := createTestCtx(t, blockHeight, false) require.NoError(t, err, "can't create context") ctx.gossiper.cfg.RebroadcastInterval = time.Hour ctx.gossiper.cfg.MaxChannelUpdateBurst = 5 @@ -3951,7 +3970,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { func TestIgnoreOwnAnnouncement(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") batch, err := createLocalAnnouncements(0) @@ -4095,7 +4114,7 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { func TestRejectCacheChannelAnn(t *testing.T) { t.Parallel() - ctx, err := createTestCtx(t, proofMatureDelta) + ctx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // First, we create a channel announcement to send over to our test @@ -4169,3 +4188,134 @@ func TestFutureMsgCacheEviction(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 2, item.height, "should be the second item") } + +// TestChanAnnBanningNonChanPeer asserts that non-channel peers who send bogus +// channel announcements are banned properly. +func TestChanAnnBanningNonChanPeer(t *testing.T) { + t.Parallel() + + ctx, err := createTestCtx(t, 1000, false) + require.NoError(t, err, "can't create context") + + nodePeer1 := &mockPeer{ + remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, + } + nodePeer2 := &mockPeer{ + remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{}, + } + + ctx.router.setAddEdgeErrCode(graph.ErrInvalidFundingOutput) + + // Loop 100 times to get nodePeer banned. + for i := 0; i < 100; i++ { + // Craft a valid channel announcement for a channel we don't + // have. We will ensure that it fails validation by modifying + // the router. + ca, err := createRemoteChannelAnnouncement(uint32(i)) + require.NoError(t, err, "can't create channel announcement") + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement( + ca, nodePeer1, + ): + require.True( + t, graph.IsError( + err, graph.ErrInvalidFundingOutput, + ), + ) + + case <-time.After(2 * time.Second): + t.Fatalf("remote announcement not processed") + } + } + + // The peer should be banned now. + require.True(t, ctx.gossiper.isBanned(nodePeer1.PubKey())) + + // Assert that nodePeer has been disconnected. + require.True(t, nodePeer1.disconnected.Load()) + + ca, err := createRemoteChannelAnnouncement(101) + require.NoError(t, err, "can't create channel announcement") + + // Set the error to ErrChannelSpent so that we can test that the + // gossiper ignores closed channels. + ctx.router.setAddEdgeErrCode(graph.ErrChannelSpent) + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2): + require.True(t, graph.IsError(err, graph.ErrChannelSpent)) + + case <-time.After(2 * time.Second): + t.Fatalf("remote announcement not processed") + } + + // Check that the announcement's scid is marked as closed. + isClosed, err := ctx.gossiper.cfg.ScidCloser.IsClosedScid( + ca.ShortChannelID, + ) + require.Nil(t, err) + require.True(t, isClosed) + + // Remove the scid from the reject cache. + key := newRejectCacheKey( + ca.ShortChannelID.ToUint64(), + sourceToPub(nodePeer2.IdentityKey()), + ) + + ctx.gossiper.recentRejects.Delete(key) + + // Reset the AddEdge error and pass the same announcement again. An + // error should be returned even though AddEdge won't fail. + ctx.router.resetAddEdgeErrCode() + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2): + require.NotNil(t, err) + + case <-time.After(2 * time.Second): + t.Fatalf("remote announcement not processed") + } +} + +// TestChanAnnBanningChanPeer asserts that channel peers that are banned don't +// get disconnected. +func TestChanAnnBanningChanPeer(t *testing.T) { + t.Parallel() + + ctx, err := createTestCtx(t, 1000, true) + require.NoError(t, err, "can't create context") + + nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}} + + ctx.router.setAddEdgeErrCode(graph.ErrInvalidFundingOutput) + + // Loop 100 times to get nodePeer banned. + for i := 0; i < 100; i++ { + // Craft a valid channel announcement for a channel we don't + // have. We will ensure that it fails validation by modifying + // the router. + ca, err := createRemoteChannelAnnouncement(uint32(i)) + require.NoError(t, err, "can't create channel announcement") + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement( + ca, nodePeer, + ): + require.True( + t, graph.IsError( + err, graph.ErrInvalidFundingOutput, + ), + ) + + case <-time.After(2 * time.Second): + t.Fatalf("remote announcement not processed") + } + } + + // The peer should be banned now. + require.True(t, ctx.gossiper.isBanned(nodePeer.PubKey())) + + // Assert that the peer wasn't disconnected. + require.False(t, nodePeer.disconnected.Load()) +} diff --git a/discovery/mock_test.go b/discovery/mock_test.go index f34a62514..6bd93c29b 100644 --- a/discovery/mock_test.go +++ b/discovery/mock_test.go @@ -161,3 +161,40 @@ func (s *mockMessageStore) MessagesForPeer(pubKey [33]byte) ([]lnwire.Message, e return msgs, nil } + +type mockScidCloser struct { + m map[lnwire.ShortChannelID]struct{} + channelPeer bool + + sync.Mutex +} + +func newMockScidCloser(channelPeer bool) *mockScidCloser { + return &mockScidCloser{ + m: make(map[lnwire.ShortChannelID]struct{}), + channelPeer: channelPeer, + } +} + +func (m *mockScidCloser) PutClosedScid(scid lnwire.ShortChannelID) error { + m.Lock() + m.m[scid] = struct{}{} + m.Unlock() + + return nil +} + +func (m *mockScidCloser) IsClosedScid(scid lnwire.ShortChannelID) (bool, + error) { + + m.Lock() + defer m.Unlock() + + _, ok := m.m[scid] + + return ok, nil +} + +func (m *mockScidCloser) IsChannelPeer(pubkey *btcec.PublicKey) (bool, error) { + return m.channelPeer, nil +} diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index d3a017256..70d28784b 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -22,6 +22,9 @@ const ( // force a historical sync to ensure we have as much of the public // network as possible. DefaultHistoricalSyncInterval = time.Hour + + // filterSemaSize is the capacity of gossipFilterSema. + filterSemaSize = 5 ) var ( @@ -161,12 +164,22 @@ type SyncManager struct { // duration of the connection. pinnedActiveSyncers map[route.Vertex]*GossipSyncer + // gossipFilterSema contains semaphores for the gossip timestamp + // queries. + gossipFilterSema chan struct{} + wg sync.WaitGroup quit chan struct{} } // newSyncManager constructs a new SyncManager backed by the given config. func newSyncManager(cfg *SyncManagerCfg) *SyncManager { + + filterSema := make(chan struct{}, filterSemaSize) + for i := 0; i < filterSemaSize; i++ { + filterSema <- struct{}{} + } + return &SyncManager{ cfg: *cfg, newSyncers: make(chan *newSyncer), @@ -178,7 +191,8 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager { pinnedActiveSyncers: make( map[route.Vertex]*GossipSyncer, len(cfg.PinnedSyncers), ), - quit: make(chan struct{}), + gossipFilterSema: filterSema, + quit: make(chan struct{}), } } @@ -507,7 +521,7 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer { maxQueryChanRangeReplies: maxQueryChanRangeReplies, noTimestampQueryOption: m.cfg.NoTimestampQueries, isStillZombieChannel: m.cfg.IsStillZombieChannel, - }) + }, m.gossipFilterSema) // Gossip syncers are initialized by default in a PassiveSync type // and chansSynced state so that they can reply to any peer queries or diff --git a/discovery/syncer.go b/discovery/syncer.go index 512c9f631..b6adb447a 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -181,9 +181,6 @@ const ( // requestBatchSize is the maximum number of channels we will query the // remote peer for in a QueryShortChanIDs message. requestBatchSize = 500 - - // filterSemaSize is the capacity of gossipFilterSema. - filterSemaSize = 5 ) var ( @@ -400,9 +397,11 @@ type GossipSyncer struct { // GossipSyncer reaches its terminal chansSynced state. syncedSignal chan struct{} - sync.Mutex + // syncerSema is used to more finely control the syncer's ability to + // respond to gossip timestamp range messages. + syncerSema chan struct{} - gossipFilterSema chan struct{} + sync.Mutex quit chan struct{} wg sync.WaitGroup @@ -410,7 +409,7 @@ type GossipSyncer struct { // newGossipSyncer returns a new instance of the GossipSyncer populated using // the passed config. -func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { +func newGossipSyncer(cfg gossipSyncerCfg, sema chan struct{}) *GossipSyncer { // If no parameter was specified for max undelayed query replies, set it // to the default of 5 queries. if cfg.maxUndelayedQueryReplies <= 0 { @@ -432,11 +431,6 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { interval, cfg.maxUndelayedQueryReplies, ) - filterSema := make(chan struct{}, filterSemaSize) - for i := 0; i < filterSemaSize; i++ { - filterSema <- struct{}{} - } - return &GossipSyncer{ cfg: cfg, rateLimiter: rateLimiter, @@ -444,7 +438,7 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { historicalSyncReqs: make(chan *historicalSyncReq), gossipMsgs: make(chan lnwire.Message, 100), queryMsgs: make(chan lnwire.Message, 100), - gossipFilterSema: filterSema, + syncerSema: sema, quit: make(chan struct{}), } } @@ -1332,12 +1326,25 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er return nil } + select { + case <-g.syncerSema: + case <-g.quit: + return ErrGossipSyncerExiting + } + + // We don't put this in a defer because if the goroutine is launched, + // it needs to be called when the goroutine is stopped. + returnSema := func() { + g.syncerSema <- struct{}{} + } + // Now that the remote peer has applied their filter, we'll query the // database for all the messages that are beyond this filter. newUpdatestoSend, err := g.cfg.channelSeries.UpdatesInHorizon( g.cfg.chainHash, startTime, endTime, ) if err != nil { + returnSema() return err } @@ -1347,22 +1354,15 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er // If we don't have any to send, then we can return early. if len(newUpdatestoSend) == 0 { + returnSema() return nil } - select { - case <-g.gossipFilterSema: - case <-g.quit: - return ErrGossipSyncerExiting - } - // We'll conclude by launching a goroutine to send out any updates. g.wg.Add(1) go func() { defer g.wg.Done() - defer func() { - g.gossipFilterSema <- struct{}{} - }() + defer returnSema() for _, msg := range newUpdatestoSend { err := g.cfg.sendToPeerSync(msg) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 15e2442e1..bb6aec590 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -211,7 +211,11 @@ func newTestSyncer(hID lnwire.ShortChannelID, markGraphSynced: func() {}, maxQueryChanRangeReplies: maxQueryChanRangeReplies, } - syncer := newGossipSyncer(cfg) + + syncerSema := make(chan struct{}, 1) + syncerSema <- struct{}{} + + syncer := newGossipSyncer(cfg, syncerSema) return msgChan, syncer, cfg.channelSeries.(*mockChannelGraphTimeSeries) } diff --git a/server.go b/server.go index 26fdce6ea..75fd02b97 100644 --- a/server.go +++ b/server.go @@ -1026,6 +1026,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } + scidCloserMan := discovery.NewScidCloserMan(s.graphDB, s.chanStateDB) + s.authGossiper = discovery.New(discovery.Config{ Graph: s.graphBuilder, Notifier: s.cc.ChainNotifier, @@ -1063,6 +1065,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, GetAlias: s.aliasMgr.GetPeerAlias, FindChannel: s.findChannel, IsStillZombieChannel: s.graphBuilder.IsZombieChannel, + ScidCloser: scidCloserMan, }, nodeKeyDesc) //nolint:lll @@ -3639,11 +3642,34 @@ func (s *server) InboundPeerConnected(conn net.Conn) { } nodePub := conn.(*brontide.Conn).RemotePub() - pubStr := string(nodePub.SerializeCompressed()) + pubSer := nodePub.SerializeCompressed() + pubStr := string(pubSer) + + var pubBytes [33]byte + copy(pubBytes[:], pubSer) s.mu.Lock() defer s.mu.Unlock() + // If the remote node's public key is banned, drop the connection. + shouldDc, dcErr := s.authGossiper.ShouldDisconnect(nodePub) + if dcErr != nil { + srvrLog.Errorf("Unable to check if we should disconnect "+ + "peer: %v", dcErr) + conn.Close() + + return + } + + if shouldDc { + srvrLog.Debugf("Dropping connection for %v since they are "+ + "banned.", pubSer) + + conn.Close() + + return + } + // If we already have an outbound connection to this peer, then ignore // this new connection. if p, ok := s.outboundPeers[pubStr]; ok { @@ -3726,11 +3752,38 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) } nodePub := conn.(*brontide.Conn).RemotePub() - pubStr := string(nodePub.SerializeCompressed()) + pubSer := nodePub.SerializeCompressed() + pubStr := string(pubSer) + + var pubBytes [33]byte + copy(pubBytes[:], pubSer) s.mu.Lock() defer s.mu.Unlock() + // If the remote node's public key is banned, drop the connection. + shouldDc, dcErr := s.authGossiper.ShouldDisconnect(nodePub) + if dcErr != nil { + srvrLog.Errorf("Unable to check if we should disconnect "+ + "peer: %v", dcErr) + conn.Close() + + return + } + + if shouldDc { + srvrLog.Debugf("Dropping connection for %v since they are "+ + "banned.", pubSer) + + if connReq != nil { + s.connMgr.Remove(connReq.ID()) + } + + conn.Close() + + return + } + // If we already have an inbound connection to this peer, then ignore // this new connection. if p, ok := s.inboundPeers[pubStr]; ok {