diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index a5afb2def..04faa1231 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -132,6 +132,22 @@ type staleTowerMsg struct { errChan chan error } +// deactivateTowerMsg is an internal message we'll use within the TowerClient +// to signal that a tower should be marked as inactive. +type deactivateTowerMsg struct { + // id is the unique database identifier for the tower. + id wtdb.TowerID + + // pubKey is the identifying public key of the watchtower. + pubKey *btcec.PublicKey + + // errChan is the channel through which we'll send a response back to + // the caller when handling their request. + // + // NOTE: This channel must be buffered. + errChan chan error +} + // clientCfg holds the configuration values required by a client. type clientCfg struct { *Config @@ -165,8 +181,9 @@ type client struct { statTicker *time.Ticker stats *clientStats - newTowers chan *newTowerMsg - staleTowers chan *staleTowerMsg + newTowers chan *newTowerMsg + staleTowers chan *staleTowerMsg + deactivateTowers chan *deactivateTowerMsg wg sync.WaitGroup quit chan struct{} @@ -192,15 +209,16 @@ func newClient(cfg *clientCfg) (*client, error) { } c := &client{ - cfg: cfg, - log: plog, - pipeline: queue, - activeSessions: newSessionQueueSet(), - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(clientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - quit: make(chan struct{}), + cfg: cfg, + log: plog, + pipeline: queue, + activeSessions: newSessionQueueSet(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(clientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + deactivateTowers: make(chan *deactivateTowerMsg), + quit: make(chan struct{}), } candidateTowers := newTowerListIterator() @@ -514,8 +532,8 @@ func (c *client) nextSessionQueue() (*sessionQueue, error) { // stopAndRemoveSession stops the session with the given ID and removes it from // the in-memory active sessions set. -func (c *client) stopAndRemoveSession(id wtdb.SessionID) error { - return c.activeSessions.StopAndRemove(id) +func (c *client) stopAndRemoveSession(id wtdb.SessionID, final bool) error { + return c.activeSessions.StopAndRemove(id, final) } // deleteSessionFromTower dials the tower that we created the session with and @@ -694,6 +712,12 @@ func (c *client) backupDispatcher() { case msg := <-c.staleTowers: msg.errChan <- c.handleStaleTower(msg) + // A tower has been requested to be de-activated. We'll + // only allow this if the tower is not currently being + // used for session negotiation. + case msg := <-c.deactivateTowers: + msg.errChan <- c.handleDeactivateTower(msg) + case <-c.quit: return } @@ -779,6 +803,10 @@ func (c *client) backupDispatcher() { case msg := <-c.staleTowers: msg.errChan <- c.handleStaleTower(msg) + // A tower has been requested to be de-activated. + case msg := <-c.deactivateTowers: + msg.errChan <- c.handleDeactivateTower(msg) + case <-c.quit: return } @@ -1046,6 +1074,77 @@ func (c *client) initActiveQueue(s *ClientSession, return sq } +// deactivateTower sends a tower deactivation request to the backupDispatcher +// where it will be handled synchronously. The request should result in all the +// sessions that we have with the given tower being shutdown and removed from +// our in-memory set of active sessions. +func (c *client) deactivateTower(id wtdb.TowerID, + pubKey *btcec.PublicKey) error { + + errChan := make(chan error, 1) + + select { + case c.deactivateTowers <- &deactivateTowerMsg{ + id: id, + pubKey: pubKey, + errChan: errChan, + }: + case <-c.pipeline.quit: + return ErrClientExiting + } + + select { + case err := <-errChan: + return err + case <-c.pipeline.quit: + return ErrClientExiting + } +} + +// handleDeactivateTower handles a request to deactivate a tower. We will remove +// it from the in-memory candidate set, and we will also stop any active +// sessions we have with this tower. +func (c *client) handleDeactivateTower(msg *deactivateTowerMsg) error { + // Remove the tower from our in-memory candidate set so that it is not + // used for any new session negotiations. + err := c.candidateTowers.RemoveCandidate(msg.id, nil) + if err != nil { + return err + } + + pubKey := msg.pubKey.SerializeCompressed() + sessions, err := c.cfg.DB.ListClientSessions(&msg.id) + if err != nil { + return fmt.Errorf("unable to retrieve sessions for tower %x: "+ + "%v", pubKey, err) + } + + // Iterate over all the sessions we have for this tower and remove them + // from our candidate set and also from our set of started, active + // sessions. + for sessionID := range sessions { + delete(c.candidateSessions, sessionID) + + err = c.activeSessions.StopAndRemove(sessionID, false) + if err != nil { + return fmt.Errorf("could not stop session %s: %w", + sessionID, err) + } + } + + // If our active session queue corresponds to the stale tower, we'll + // proceed to negotiate a new one. + if c.sessionQueue != nil { + towerKey := c.sessionQueue.tower.IdentityKey + + if bytes.Equal(pubKey, towerKey.SerializeCompressed()) { + c.sessionQueue = nil + } + } + + return nil +} + // addTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and @@ -1152,7 +1251,7 @@ func (c *client) handleStaleTower(msg *staleTowerMsg) error { // Shutdown the session so that any pending updates are // replayed back onto the main task pipeline. - err = c.activeSessions.StopAndRemove(sessionID) + err = c.activeSessions.StopAndRemove(sessionID, true) if err != nil { c.log.Errorf("could not stop session %s: %w", sessionID, err) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 81674f0c0..0e5352f88 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2688,6 +2688,88 @@ var clientTests = []clientTest{ require.NoError(h.t, err) }, }, + { + name: "de-activate a tower", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 10 + chanIDInt = 0 + ) + + // Advance the channel with a few updates. + hints := h.advanceChannelN(chanIDInt, numUpdates) + + // Backup a few these updates and wait for them to + // arrive at the server. + h.backupStates(chanIDInt, 0, numUpdates/2, nil) + h.server.waitForUpdates(hints[:numUpdates/2], waitTime) + + // Lookup the tower and assert that it currently is + // seen as an active session candidate. + resp, err := h.clientMgr.LookupTower( + h.server.addr.IdentityKey, + ) + require.NoError(h.t, err) + tower, ok := resp[blob.TypeAltruistTaprootCommit] + require.True(h.t, ok) + require.True(h.t, tower.ActiveSessionCandidate) + + // Deactivate the tower. + err = h.clientMgr.DeactivateTower( + h.server.addr.IdentityKey, + ) + require.NoError(h.t, err) + + // Assert that it is no longer seen as an active + // session candidate. + resp, err = h.clientMgr.LookupTower( + h.server.addr.IdentityKey, + ) + require.NoError(h.t, err) + tower, ok = resp[blob.TypeAltruistTaprootCommit] + require.True(h.t, ok) + require.False(h.t, tower.ActiveSessionCandidate) + + // Add a new tower. + server2 := newServerHarness( + h.t, h.net, towerAddr2Str, nil, + ) + server2.start() + h.addTower(server2.addr) + + // Backup a few more states and assert that they appear + // on the second tower server. + h.backupStates( + chanIDInt, numUpdates/2, numUpdates-1, nil, + ) + server2.waitForUpdates( + hints[numUpdates/2:numUpdates-1], waitTime, + ) + + // Reactivate the first tower. + err = h.clientMgr.AddTower(h.server.addr) + require.NoError(h.t, err) + + // Deactivate the second tower. + err = h.clientMgr.DeactivateTower( + server2.addr.IdentityKey, + ) + require.NoError(h.t, err) + + // Backup the last backup and assert that it appears + // on the first tower. + h.backupStates(chanIDInt, numUpdates-1, numUpdates, nil) + h.server.waitForUpdates(hints[numUpdates-1:], waitTime) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 7ffc60c48..17a351c15 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -38,6 +38,11 @@ type ClientManager interface { // instead. RemoveTower(*btcec.PublicKey, net.Addr) error + // DeactivateTower sets the given tower's status to inactive so that it + // is not considered for session negotiation. Its sessions will also not + // be used while the tower is inactive. + DeactivateTower(pubKey *btcec.PublicKey) error + // Stats returns the in-memory statistics of the client since startup. Stats() ClientStats @@ -431,6 +436,56 @@ func (m *Manager) RemoveTower(key *btcec.PublicKey, addr net.Addr) error { return nil } +// DeactivateTower sets the given tower's status to inactive so that it is not +// considered for session negotiation. Its sessions will also not be used while +// the tower is inactive. +func (m *Manager) DeactivateTower(key *btcec.PublicKey) error { + // We'll load the tower in order to retrieve its ID within the database. + tower, err := m.cfg.DB.LoadTower(key) + if err != nil { + return err + } + + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + for _, client := range m.clients { + err := client.deactivateTower(tower.ID, tower.IdentityKey) + if err != nil { + return err + } + } + + // Finally, mark the tower as inactive in the DB. + err = m.cfg.DB.DeactivateTower(key) + if err != nil { + log.Errorf("Could not deactivate the tower. Re-activating. %v", + err) + + // If the persisted state update fails, re-add the address to + // our client's in-memory state. + tower, newTowerErr := NewTowerFromDBTower(tower) + if newTowerErr != nil { + log.Errorf("Could not create new in-memory tower: %v", + newTowerErr) + + return err + } + + for _, client := range m.clients { + addTowerErr := client.addTower(tower) + if addTowerErr != nil { + log.Errorf("Could not re-add tower: %v", + addTowerErr) + } + } + + return err + } + + return nil +} + // Stats returns the in-memory statistics of the clients managed by the Manager // since startup. func (m *Manager) Stats() ClientStats { @@ -850,7 +905,7 @@ func (m *Manager) handleClosableSessions( // Stop the session and remove it from the // in-memory set. err = client.stopAndRemoveSession( - item.sessionID, + item.sessionID, true, ) if err != nil { log.Errorf("could not remove "+ diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 3c1126a89..786410515 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -765,7 +765,7 @@ func (s *sessionQueueSet) AddAndStart(sessionQueue *sessionQueue) { // StopAndRemove stops the given session queue and removes it from the // sessionQueueSet. -func (s *sessionQueueSet) StopAndRemove(id wtdb.SessionID) error { +func (s *sessionQueueSet) StopAndRemove(id wtdb.SessionID, final bool) error { s.mu.Lock() defer s.mu.Unlock() @@ -776,7 +776,7 @@ func (s *sessionQueueSet) StopAndRemove(id wtdb.SessionID) error { delete(s.queues, id) - return queue.Stop(true) + return queue.Stop(final) } // Get fetches and returns the sessionQueue with the given ID.