From 26e628c0feaf814a1d6288c45ee33613232ebab7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Mar 2023 11:07:31 +0200 Subject: [PATCH] watchtowers: handle closable sessions Add a routine to the tower client that informs towers of sessions they can delete and also deletes any info about the session from the client DB. --- lncfg/wtclient.go | 5 + sample-lnd.conf | 6 + server.go | 9 ++ watchtower/wtclient/client.go | 186 ++++++++++++++++++++++++++++- watchtower/wtclient/client_test.go | 132 +++++++++++++++++++- 5 files changed, 331 insertions(+), 7 deletions(-) diff --git a/lncfg/wtclient.go b/lncfg/wtclient.go index 8b9f03939..7d4331112 100644 --- a/lncfg/wtclient.go +++ b/lncfg/wtclient.go @@ -17,6 +17,11 @@ type WtClient struct { // SweepFeeRate specifies the fee rate in sat/byte to be used when // constructing justice transactions sent to the tower. SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."` + + // SessionCloseRange is the range over which to choose a random number + // of blocks to wait after the last channel of a session is closed + // before sending the DeleteSession message to the tower server. + SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."` } // Validate ensures the user has provided a valid configuration. diff --git a/sample-lnd.conf b/sample-lnd.conf index 3dfc76da8..f0edda984 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -997,6 +997,12 @@ litecoin.node=ltcd ; supported at this time, if none are provided the tower will not be enabled. ; wtclient.private-tower-uris= +; The range over which to choose a random number of blocks to wait after the +; last channel of a session is closed before sending the DeleteSession message +; to the tower server. The default is currently 288. Note that setting this to +; a lower value will result in faster session cleanup _but_ that this comes +; along with reduced privacy from the tower server. 
+; wtclient.session-close-range=10 [healthcheck] diff --git a/server.go b/server.go index 61b31ffa6..eef43550b 100644 --- a/server.go +++ b/server.go @@ -1497,6 +1497,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight() } + sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange) + if cfg.WtClient.SessionCloseRange != 0 { + sessionCloseRange = cfg.WtClient.SessionCloseRange + } + if err := policy.Validate(); err != nil { return nil, err } @@ -1516,6 +1521,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.towerClient, err = wtclient.New(&wtclient.Config{ FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { @@ -1546,6 +1553,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 464e93f16..e92b8b4cf 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,8 +2,10 @@ package wtclient import ( "bytes" + "crypto/rand" "errors" "fmt" + "math/big" "net" "sync" "time" @@ -12,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" @@ -43,6 +46,11 @@ const ( // client should abandon any pending updates or session negotiations // before terminating. DefaultForceQuitDelay = 10 * time.Second + + // DefaultSessionCloseRange is the range over which we will generate a + // random number of blocks to delay closing a session after its last + // channel has been closed. + DefaultSessionCloseRange = 288 ) // genSessionFilter constructs a filter that can be used to select sessions only @@ -159,6 +167,9 @@ type Config struct { FetchClosedChannel func(cid lnwire.ChannelID) ( *channeldb.ChannelCloseSummary, error) + // ChainNotifier can be used to subscribe to block notifications. + ChainNotifier chainntnfs.ChainNotifier + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -214,6 +225,11 @@ type Config struct { // watchtowers. If the exponential backoff produces a timeout greater // than this value, the backoff will be clamped to MaxBackoff. MaxBackoff time.Duration + + // SessionCloseRange is the range over which we will generate a random + // number of blocks to delay closing a session after its last channel + // has been closed. + SessionCloseRange uint32 } // newTowerMsg is an internal message we'll use within the TowerClient to signal @@ -590,9 +606,34 @@ func (c *TowerClient) Start() error { delete(c.summaries, id) } + // Load all closable sessions. + closableSessions, err := c.cfg.DB.ListClosableSessions() + if err != nil { + returnErr = err + return + } + + err = c.trackClosableSessions(closableSessions) + if err != nil { + returnErr = err + return + } + c.wg.Add(1) go c.handleChannelCloses(chanSub) + // Subscribe to new block events. 
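+	// These notifications drive the closable sessions handler started
+	// below, which on each new block checks whether any session in the
+	// closableSessionQueue has reached its randomized delete height.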
+ blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( + nil, + ) + if err != nil { + returnErr = err + return + } + + c.wg.Add(1) + go c.handleClosableSessions(blockEvents) + // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. @@ -876,7 +917,8 @@ func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { } // handleClosedChannel handles the closure of a single channel. It will mark the -// channel as closed in the DB. +// channel as closed in the DB, then it will handle all the sessions that are +// now closable due to the channel closure. func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, closeHeight uint32) error { @@ -890,18 +932,146 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, c.log.Debugf("Marking channel(%s) as closed", chanID) - _, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) + sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) if err != nil { return fmt.Errorf("could not mark channel(%s) as closed: %w", chanID, err) } + closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) + for _, sess := range sessions { + closableSessions[sess] = closeHeight + } + + c.log.Debugf("Tracking %d new closable sessions as a result of "+ + "closing channel %s", len(closableSessions), chanID) + + err = c.trackClosableSessions(closableSessions) + if err != nil { + return fmt.Errorf("could not track closable sessions: %w", err) + } + delete(c.summaries, chanID) delete(c.chanCommitHeights, chanID) return nil } +// handleClosableSessions listens for new block notifications. For each block, +// it checks the closableSessionQueue to see if there is a closable session with +// a delete-height smaller than or equal to the new block, if there is then the +// tower is informed that it can delete the session, and then we also delete it +// from our DB. +func (c *TowerClient) handleClosableSessions( + blocksChan *chainntnfs.BlockEpochEvent) { + + defer c.wg.Done() + + c.log.Debug("Starting closable sessions handler") + defer c.log.Debug("Stopping closable sessions handler") + + for { + select { + case newBlock := <-blocksChan.Epochs: + if newBlock == nil { + return + } + + height := uint32(newBlock.Height) + for { + select { + case <-c.quit: + return + default: + } + + // If there are no closable sessions that we + // need to handle, then we are done and can + // reevaluate when the next block comes. + item := c.closableSessionQueue.Top() + if item == nil { + break + } + + // If there is closable session but the delete + // height we have set for it is after the + // current block height, then our work is done. + if item.deleteHeight > height { + break + } + + // Otherwise, we pop this item from the heap + // and handle it. + c.closableSessionQueue.Pop() + + // Fetch the session from the DB so that we can + // extract the Tower info. 
+ sess, err := c.cfg.DB.GetClientSession( + item.sessionID, + ) + if err != nil { + c.log.Errorf("error calling "+ + "GetClientSession for "+ + "session %s: %v", + item.sessionID, err) + + continue + } + + err = c.deleteSessionFromTower(sess) + if err != nil { + c.log.Errorf("error deleting "+ + "session %s from tower: %v", + sess.ID, err) + + continue + } + + err = c.cfg.DB.DeleteSession(item.sessionID) + if err != nil { + c.log.Errorf("could not delete "+ + "session(%s) from DB: %w", + sess.ID, err) + + continue + } + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// trackClosableSessions takes in a map of session IDs to the earliest block +// height at which the session should be deleted. For each of the sessions, +// a random delay is added to the block height and the session is added to the +// closableSessionQueue. +func (c *TowerClient) trackClosableSessions( + sessions map[wtdb.SessionID]uint32) error { + + // For each closable session, add a random delay to its close + // height and add it to the closableSessionQueue. + for sID, blockHeight := range sessions { + delay, err := newRandomDelay(c.cfg.SessionCloseRange) + if err != nil { + return err + } + + deleteHeight := blockHeight + delay + + c.closableSessionQueue.Push(&sessionCloseItem{ + sessionID: sID, + deleteHeight: deleteHeight, + }) + } + + return nil +} + // deleteSessionFromTower dials the tower that we created the session with and // attempts to send the tower the DeleteSession message. func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { @@ -1671,3 +1841,15 @@ func (c *TowerClient) logMessage( preposition, peer.RemotePub().SerializeCompressed(), peer.RemoteAddr()) } + +func newRandomDelay(max uint32) (uint32, error) { + var maxDelay big.Int + maxDelay.SetUint64(uint64(max)) + + randDelay, err := rand.Int(rand.Reader, &maxDelay) + if err != nil { + return 0, err + } + + return uint32(randDelay.Uint64()), nil +} diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 29c4e7a53..2657e691b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" @@ -396,6 +397,9 @@ type testHarness struct { server *wtserver.Server net *mockNet + blockEvents *mockBlockSub + height int32 + channelEvents *mockSubscription sendUpdatesOn bool @@ -458,6 +462,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { serverDB: serverDB, serverCfg: serverCfg, net: mockNet, + blockEvents: newMockBlockSub(t), channelEvents: newMockSubscription(t), channels: make(map[lnwire.ChannelID]*mockChannel), closedChannels: make(map[lnwire.ChannelID]uint32), @@ -487,6 +492,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return h.channelEvents, nil }, FetchClosedChannel: fetchChannel, + ChainNotifier: h.blockEvents, Dial: mockNet.Dial, DB: clientDB, AuthDial: mockNet.AuthDial, @@ -495,11 +501,12 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NewAddress: func() ([]byte, error) { return addrScript, nil }, - ReadTimeout: timeout, - WriteTimeout: timeout, - MinBackoff: time.Millisecond, - MaxBackoff: time.Second, - ForceQuitDelay: 10 * time.Second, + ReadTimeout: timeout, + WriteTimeout: 
timeout, + MinBackoff: time.Millisecond, + MaxBackoff: time.Second, + ForceQuitDelay: 10 * time.Second, + SessionCloseRange: 1, } if !cfg.noServerStart { @@ -518,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return h } +// mine mimics the mining of new blocks by sending new block notifications. +func (h *testHarness) mine(numBlocks int) { + h.t.Helper() + + for i := 0; i < numBlocks; i++ { + h.height++ + h.blockEvents.sendNewBlock(h.height) + } +} + // startServer creates a new server using the harness's current serverCfg and // starts it after pointing the mockNet's callback to the new server. func (h *testHarness) startServer() { @@ -909,6 +926,44 @@ func (m *mockSubscription) Updates() <-chan interface{} { return m.updates } +// mockBlockSub mocks out the ChainNotifier. +type mockBlockSub struct { + t *testing.T + events chan *chainntnfs.BlockEpoch + + chainntnfs.ChainNotifier +} + +// newMockBlockSub creates a new mockBlockSub. +func newMockBlockSub(t *testing.T) *mockBlockSub { + t.Helper() + + return &mockBlockSub{ + t: t, + events: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn returns a channel that can be used to listen for new +// blocks. +func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.events, + }, nil +} + +// sendNewBlock will send a new block on the notification channel. +func (m *mockBlockSub) sendNewBlock(height int32) { + select { + case m.events <- &chainntnfs.BlockEpoch{Height: height}: + + case <-time.After(waitTime): + m.t.Fatalf("timed out sending block: %d", height) + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) @@ -1891,6 +1946,73 @@ var clientTests = []clientTest{ return h.isSessionClosable(sessionIDs[0]) }, waitTime) require.NoError(h.t, err) + + // Now we will mine a few blocks. This will cause the + // necessary session-close-range to be exceeded meaning + // that the client should send the DeleteSession message + // to the server. We will assert that both the client + // and server have deleted the appropriate sessions and + // channel info. + + // Before we mine blocks, assert that the client + // currently has 3 closable sessions. + closableSess, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + require.Len(h.t, closableSess, 3) + + // Assert that the server is also aware of all of these + // sessions. + for sid := range closableSess { + _, err := h.serverDB.GetSessionInfo(&sid) + require.NoError(h.t, err) + } + + // Also make a note of the total number of sessions the + // client has. + sessions, err := h.clientDB.ListClientSessions(nil, nil) + require.NoError(h.t, err) + require.Len(h.t, sessions, 4) + + h.mine(3) + + // The client should no longer have any closable + // sessions and the total list of client sessions should + // no longer include the three that it previously had + // marked as closable. The server should also no longer + // have these sessions in its DB. 
+ err = wait.Predicate(func() bool { + sess, err := h.clientDB.ListClientSessions( + nil, nil, + ) + require.NoError(h.t, err) + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + if len(sess) != 1 || len(cs) != 0 { + return false + } + + for sid := range closableSess { + _, ok := sess[sid] + if ok { + return false + } + + _, err := h.serverDB.GetSessionInfo( + &sid, + ) + if !errors.Is( + err, wtdb.ErrSessionNotFound, + ) { + return false + } + } + + return true + + }, waitTime) + require.NoError(h.t, err) }, }, }
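For reference, the following standalone sketch illustrates the delete-height calculation this patch introduces: the block height at which a session's last channel was closed, plus a random delay drawn from [0, SessionCloseRange). The newRandomDelay helper mirrors the one added to watchtower/wtclient/client.go; the close height and range values used here are illustrative only, not taken from the patch.

package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

// newRandomDelay mirrors the helper added in this patch: it draws a
// uniformly random delay in [0, max) blocks using crypto/rand.
func newRandomDelay(max uint32) (uint32, error) {
	var maxDelay big.Int
	maxDelay.SetUint64(uint64(max))

	randDelay, err := rand.Int(rand.Reader, &maxDelay)
	if err != nil {
		return 0, err
	}

	return uint32(randDelay.Uint64()), nil
}

func main() {
	// Illustrative values only: the height at which the session's last
	// channel closed, and the default session-close-range of 288 blocks.
	const closeHeight uint32 = 800000
	const sessionCloseRange uint32 = 288

	delay, err := newRandomDelay(sessionCloseRange)
	if err != nil {
		panic(err)
	}

	// The client only sends DeleteSession to the tower (and removes the
	// session from its own DB) once a block at or beyond this height has
	// been seen.
	fmt.Printf("session deletable at height %d\n", closeHeight+delay)
}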