watchtower+server: let manager Start & Stop the clients

In this commit, the `Start` and `Stop` methods are removed from the
`Client` interface and added to the new `Manager` instead. Callers now
only need to call Start or Stop on the Manager rather than on each
individual client.
Elle Mouton 2023-08-11 11:39:24 +02:00
parent ab0375e0c1
commit 2abc422aac
4 changed files with 183 additions and 181 deletions
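
To make the lifecycle change easier to follow before reading the diffs, here is a
minimal, self-contained sketch of the pattern this commit moves to: each client
keeps unexported start/stop methods, and a single manager whose sync.Once-guarded
Start/Stop fans the calls out to every registered client. The types below
(towerClient, manager) are simplified stand-ins for illustration only, not the
real wtclient API; the actual implementation is in the manager.go diff at the
bottom of this page.

package main

import (
        "fmt"
        "sync"
)

// towerClient stands in for wtclient.TowerClient: after this commit its
// lifecycle methods are unexported, so only the manager can drive them.
type towerClient struct {
        name string
}

func (c *towerClient) start() error {
        fmt.Println("starting client:", c.name)
        return nil
}

func (c *towerClient) stop() error {
        fmt.Println("stopping client:", c.name)
        return nil
}

// manager stands in for wtclient.Manager: one exported Start/Stop pair,
// guarded by sync.Once, fans out to every registered client.
type manager struct {
        started sync.Once
        stopped sync.Once

        mu      sync.Mutex
        clients []*towerClient
}

func (m *manager) Start() error {
        var returnErr error
        m.started.Do(func() {
                m.mu.Lock()
                defer m.mu.Unlock()

                // Abort on the first client that fails to start.
                for _, c := range m.clients {
                        if err := c.start(); err != nil {
                                returnErr = err
                                return
                        }
                }
        })

        return returnErr
}

func (m *manager) Stop() error {
        var returnErr error
        m.stopped.Do(func() {
                m.mu.Lock()
                defer m.mu.Unlock()

                // Best effort: keep stopping the remaining clients and
                // report the last error seen.
                for _, c := range m.clients {
                        if err := c.stop(); err != nil {
                                returnErr = err
                        }
                }
        })

        return returnErr
}

func main() {
        m := &manager{
                clients: []*towerClient{{name: "legacy"}, {name: "anchor"}},
        }

        // Callers now only touch the manager; repeated calls are no-ops
        // because of the sync.Once guards.
        if err := m.Start(); err != nil {
                panic(err)
        }
        defer func() {
                if err := m.Stop(); err != nil {
                        fmt.Println("stop error:", err)
                }
        }()
}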


@@ -1913,19 +1913,12 @@ func (s *server) Start() error {
}
cleanup = cleanup.add(s.htlcNotifier.Stop)
if s.towerClient != nil {
if err := s.towerClient.Start(); err != nil {
if s.towerClientMgr != nil {
if err := s.towerClientMgr.Start(); err != nil {
startErr = err
return
}
cleanup = cleanup.add(s.towerClient.Stop)
}
if s.anchorTowerClient != nil {
if err := s.anchorTowerClient.Start(); err != nil {
startErr = err
return
}
cleanup = cleanup.add(s.anchorTowerClient.Stop)
cleanup = cleanup.add(s.towerClientMgr.Stop)
}
if err := s.sweeper.Start(); err != nil {
@@ -2298,16 +2291,10 @@ func (s *server) Stop() error {
// client which will reliably flush all queued states to the
// tower. If this is halted for any reason, the force quit timer
// will kick in and abort to allow this method to return.
if s.towerClient != nil {
if err := s.towerClient.Stop(); err != nil {
if s.towerClientMgr != nil {
if err := s.towerClientMgr.Stop(); err != nil {
srvrLog.Warnf("Unable to shut down tower "+
"client: %v", err)
}
}
if s.anchorTowerClient != nil {
if err := s.anchorTowerClient.Stop(); err != nil {
srvrLog.Warnf("Unable to shut down anchor "+
"tower client: %v", err)
"client manager: %v", err)
}
}


@@ -134,15 +134,6 @@ type Client interface {
// successful unless the justice transaction would create dust outputs
// when trying to abide by the negotiated policy.
BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
// Start initializes the watchtower client, allowing it process requests
// to backup revoked channel states.
Start() error
// Stop attempts a graceful shutdown of the watchtower client. In doing
// so, it will attempt to flush the pipeline and deliver any queued
// states to the tower before exiting.
Stop() error
}
// BreachRetributionBuilder is a function that can be used to construct a
@@ -199,9 +190,6 @@ type towerClientCfg struct {
// non-blocking, reliable subsystem for backing up revoked states to a specified
// private tower.
type TowerClient struct {
started sync.Once
stopped sync.Once
cfg *towerClientCfg
log btclog.Logger
@@ -420,170 +408,158 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
return sessions, nil
}
// Start initializes the watchtower client by loading or negotiating an active
// start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error {
var returnErr error
c.started.Do(func() {
c.log.Infof("Watchtower client starting")
func (c *TowerClient) start() error {
c.log.Infof("Watchtower client starting")
// First, restart a session queue for any sessions that have
// committed but unacked state updates. This ensures that these
// sessions will be able to flush the committed updates after a
// restart.
fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates
for _, session := range c.candidateSessions {
committedUpdates, err := fetchCommittedUpdates(
&session.ID,
)
if err != nil {
returnErr = err
return
}
if len(committedUpdates) > 0 {
c.log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID,
len(committedUpdates))
c.initActiveQueue(session, committedUpdates)
}
}
chanSub, err := c.cfg.SubscribeChannelEvents()
if err != nil {
returnErr = err
return
}
// Iterate over the list of registered channels and check if
// any of them can be marked as closed.
for id := range c.chanInfos {
isClosed, closedHeight, err := c.isChannelClosed(id)
if err != nil {
returnErr = err
return
}
if !isClosed {
continue
}
_, err = c.cfg.DB.MarkChannelClosed(id, closedHeight)
if err != nil {
c.log.Errorf("could not mark channel(%s) as "+
"closed: %v", id, err)
continue
}
// Since the channel has been marked as closed, we can
// also remove it from the channel summaries map.
delete(c.chanInfos, id)
}
// Load all closable sessions.
closableSessions, err := c.cfg.DB.ListClosableSessions()
if err != nil {
returnErr = err
return
}
err = c.trackClosableSessions(closableSessions)
if err != nil {
returnErr = err
return
}
c.wg.Add(1)
go c.handleChannelCloses(chanSub)
// Subscribe to new block events.
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
nil,
// First, restart a session queue for any sessions that have
// committed but unacked state updates. This ensures that these
// sessions will be able to flush the committed updates after a
// restart.
fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates
for _, session := range c.candidateSessions {
committedUpdates, err := fetchCommittedUpdates(
&session.ID,
)
if err != nil {
returnErr = err
return
return err
}
c.wg.Add(1)
go c.handleClosableSessions(blockEvents)
if len(committedUpdates) > 0 {
c.log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID,
len(committedUpdates))
// Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts
// up.
err = c.negotiator.Start()
c.initActiveQueue(session, committedUpdates)
}
}
chanSub, err := c.cfg.SubscribeChannelEvents()
if err != nil {
return err
}
// Iterate over the list of registered channels and check if
// any of them can be marked as closed.
for id := range c.chanInfos {
isClosed, closedHeight, err := c.isChannelClosed(id)
if err != nil {
returnErr = err
return
return err
}
// Start the task pipeline to which new backup tasks will be
// submitted from active links.
err = c.pipeline.Start()
if !isClosed {
continue
}
_, err = c.cfg.DB.MarkChannelClosed(id, closedHeight)
if err != nil {
returnErr = err
return
c.log.Errorf("could not mark channel(%s) as "+
"closed: %v", id, err)
continue
}
c.wg.Add(1)
go c.backupDispatcher()
// Since the channel has been marked as closed, we can
// also remove it from the channel summaries map.
delete(c.chanInfos, id)
}
c.log.Infof("Watchtower client started successfully")
})
return returnErr
// Load all closable sessions.
closableSessions, err := c.cfg.DB.ListClosableSessions()
if err != nil {
return err
}
err = c.trackClosableSessions(closableSessions)
if err != nil {
return err
}
c.wg.Add(1)
go c.handleChannelCloses(chanSub)
// Subscribe to new block events.
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
nil,
)
if err != nil {
return err
}
c.wg.Add(1)
go c.handleClosableSessions(blockEvents)
// Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts
// up.
err = c.negotiator.Start()
if err != nil {
return err
}
// Start the task pipeline to which new backup tasks will be
// submitted from active links.
err = c.pipeline.Start()
if err != nil {
return err
}
c.wg.Add(1)
go c.backupDispatcher()
c.log.Infof("Watchtower client started successfully")
return nil
}
// Stop idempotently initiates a graceful shutdown of the watchtower client.
func (c *TowerClient) Stop() error {
// stop idempotently initiates a graceful shutdown of the watchtower client.
func (c *TowerClient) stop() error {
var returnErr error
c.stopped.Do(func() {
c.log.Debugf("Stopping watchtower client")
c.log.Debugf("Stopping watchtower client")
// 1. Stop the session negotiator.
err := c.negotiator.Stop()
// 1. Stop the session negotiator.
err := c.negotiator.Stop()
if err != nil {
returnErr = err
}
// 2. Stop the backup dispatcher and any other goroutines.
close(c.quit)
c.wg.Wait()
// 3. If there was a left over 'prevTask' from the backup
// dispatcher, replay that onto the pipeline.
if c.prevTask != nil {
err = c.pipeline.QueueBackupID(c.prevTask)
if err != nil {
returnErr = err
}
}
// 2. Stop the backup dispatcher and any other goroutines.
close(c.quit)
c.wg.Wait()
// 3. If there was a left over 'prevTask' from the backup
// dispatcher, replay that onto the pipeline.
if c.prevTask != nil {
err = c.pipeline.QueueBackupID(c.prevTask)
// 4. Shutdown all active session queues in parallel. These will
// exit once all unhandled updates have been replayed to the
// task pipeline.
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
return func() {
err := s.Stop(false)
if err != nil {
c.log.Errorf("could not stop session "+
"queue: %s: %v", s.ID(), err)
returnErr = err
}
}
// 4. Shutdown all active session queues in parallel. These will
// exit once all unhandled updates have been replayed to the
// task pipeline.
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
return func() {
err := s.Stop(false)
if err != nil {
c.log.Errorf("could not stop session "+
"queue: %s: %v", s.ID(), err)
returnErr = err
}
}
})
// 5. Shutdown the backup queue, which will prevent any further
// updates from being accepted.
if err = c.pipeline.Stop(); err != nil {
returnErr = err
}
c.log.Debugf("Client successfully stopped, stats: %s", c.stats)
})
// 5. Shutdown the backup queue, which will prevent any further
// updates from being accepted.
if err = c.pipeline.Stop(); err != nil {
returnErr = err
}
c.log.Debugf("Client successfully stopped, stats: %s", c.stats)
return returnErr
}


@@ -399,6 +399,7 @@ type testHarness struct {
cfg harnessCfg
signer *wtmock.MockSigner
capacity lnwire.MilliSatoshi
clientMgr *wtclient.Manager
clientDB *wtdb.ClientDB
clientCfg *wtclient.Config
clientPolicy wtpolicy.Policy
@@ -526,7 +527,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
h.startClient()
t.Cleanup(func() {
require.NoError(t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
require.NoError(t, h.clientDB.Close())
})
@@ -560,12 +561,12 @@ func (h *testHarness) startClient() {
Address: towerTCPAddr,
}
m, err := wtclient.NewManager(h.clientCfg)
h.clientMgr, err = wtclient.NewManager(h.clientCfg)
require.NoError(h.t, err)
h.client, err = m.NewClient(h.clientPolicy)
h.client, err = h.clientMgr.NewClient(h.clientPolicy)
require.NoError(h.t, err)
require.NoError(h.t, h.client.Start())
require.NoError(h.t, h.clientMgr.Start())
require.NoError(h.t, h.client.AddTower(towerAddr))
}
@@ -1127,7 +1128,7 @@ var clientTests = []clientTest{
)
// Stop the client, subsequent backups should fail.
h.client.Stop()
require.NoError(h.t, h.clientMgr.Stop())
// Advance the channel and try to back up the states. We
// expect ErrClientExiting to be returned from
@@ -1242,7 +1243,7 @@ var clientTests = []clientTest{
// Stop the client to abort the state updates it has
// queued.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
// Restart the server and allow it to ack the updates
// after the client retransmits the unacked update.
@@ -1437,7 +1438,7 @@ var clientTests = []clientTest{
h.server.waitForUpdates(nil, waitTime)
// Stop the client since it has queued backups.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
// Restart the server and allow it to ack session
// creation.
@@ -1487,7 +1488,7 @@ var clientTests = []clientTest{
h.server.waitForUpdates(nil, waitTime)
// Stop the client since it has queued backups.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
// Restart the server and allow it to ack session
// creation.
@@ -1541,7 +1542,7 @@ var clientTests = []clientTest{
h.server.waitForUpdates(hints[:numUpdates/2], waitTime)
// Stop the client, which should have no more backups.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
// Record the policy that the first half was stored
// under. We'll expect the second half to also be
@@ -1602,7 +1603,7 @@ var clientTests = []clientTest{
// Restart the client, so we can ensure the deduping is
// maintained across restarts.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
h.startClient()
// Try to back up the full range of retributions. Only
@@ -1882,7 +1883,7 @@ var clientTests = []clientTest{
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
// Restart the client.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
h.startClient()
// The session should now have been marked as closable.
@@ -2069,7 +2070,7 @@ var clientTests = []clientTest{
h.server.waitForUpdates(hints[:numUpdates/2], waitTime)
// Now stop the client and reset its database.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
db := newClientDB(h.t)
h.clientDB = db
@@ -2122,7 +2123,7 @@ var clientTests = []clientTest{
h.backupStates(chanID, 0, numUpdates/2, nil)
// Restart the Client. And also now start the server.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
h.server.start()
h.startClient()
@@ -2395,7 +2396,7 @@ var clientTests = []clientTest{
// Now restart the client. This ensures that the
// updates are no longer in the pending queue.
require.NoError(h.t, h.client.Stop())
require.NoError(h.t, h.clientMgr.Stop())
h.startClient()
// Now remove the tower.


@@ -100,6 +100,9 @@ type Config struct {
// required for each different commitment transaction type. The Manager acts as
// a tower client multiplexer.
type Manager struct {
started sync.Once
stopped sync.Once
cfg *Config
clients map[blob.Type]*TowerClient
@@ -154,3 +157,38 @@ func (m *Manager) NewClient(policy wtpolicy.Policy) (*TowerClient, error) {
return client, nil
}
// Start starts all the clients that have been registered with the Manager.
func (m *Manager) Start() error {
var returnErr error
m.started.Do(func() {
m.clientsMu.Lock()
defer m.clientsMu.Unlock()
for _, client := range m.clients {
if err := client.start(); err != nil {
returnErr = err
return
}
}
})
return returnErr
}
// Stop stops all the clients that the Manager is managing.
func (m *Manager) Stop() error {
var returnErr error
m.stopped.Do(func() {
m.clientsMu.Lock()
defer m.clientsMu.Unlock()
for _, client := range m.clients {
if err := client.stop(); err != nil {
returnErr = err
}
}
})
return returnErr
}
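
For reference, the call order that the updated test harness and server follow can
be condensed as below. This is an illustrative sketch only: the helper name
runTowerClients and the policies slice are invented here, and cfg is assumed to be
fully populated the same way the harness populates h.clientCfg; only NewManager,
NewClient, Start, and Stop are taken from the diffs above.

package example

import (
        "log"

        "github.com/lightningnetwork/lnd/watchtower/wtclient"
        "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)

// runTowerClients sketches the new lifecycle: build one Manager, register a
// client per policy (e.g. legacy vs anchor commitments) via NewClient, then
// drive the whole set with a single Start/Stop pair on the Manager.
func runTowerClients(cfg *wtclient.Config, policies []wtpolicy.Policy) error {
        mgr, err := wtclient.NewManager(cfg)
        if err != nil {
                return err
        }

        for _, policy := range policies {
                // The returned client is still used for BackupState and the
                // other Client methods, but it no longer exposes Start/Stop.
                if _, err := mgr.NewClient(policy); err != nil {
                        return err
                }
        }

        if err := mgr.Start(); err != nil {
                return err
        }
        defer func() {
                if err := mgr.Stop(); err != nil {
                        log.Printf("unable to stop tower client manager: %v", err)
                }
        }()

        // ... hand the per-policy clients to the links that back up revoked
        // states ...

        return nil
}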