watchtowers: handle closable sessions

Add a routine to the tower client that informs towers of sessions they
can delete, and that also deletes any info about those sessions from the
client DB.
Elle Mouton 2023-03-20 11:07:31 +02:00
parent 8478b56ce6
commit 26e628c0fe
5 changed files with 331 additions and 7 deletions

lncfg/wtclient.go

@@ -17,6 +17,11 @@ type WtClient struct {
// SweepFeeRate specifies the fee rate in sat/byte to be used when
// constructing justice transactions sent to the tower.
SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."`
// SessionCloseRange is the range over which to choose a random number
// of blocks to wait after the last channel of a session is closed
// before sending the DeleteSession message to the tower server.
SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."`
}
// Validate ensures the user has provided a valid configuration.
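For reference, the long tag above means this option can also be supplied on the command line rather than only via lnd.conf; for example (the value 144 here is arbitrary):

lnd --wtclient.session-close-range=144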

sample-lnd.conf

@@ -997,6 +997,12 @@ litecoin.node=ltcd
; supported at this time, if none are provided the tower will not be enabled.
; wtclient.private-tower-uris=
; The range over which to choose a random number of blocks to wait after the
; last channel of a session is closed before sending the DeleteSession message
; to the tower server. The default is currently 288. Note that setting this to
; a lower value will result in faster session cleanup, but this comes at the
; cost of reduced privacy from the tower server.
; wtclient.session-close-range=10
[healthcheck]

server.go

@@ -1497,6 +1497,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight()
}
sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange)
if cfg.WtClient.SessionCloseRange != 0 {
sessionCloseRange = cfg.WtClient.SessionCloseRange
}
if err := policy.Validate(); err != nil {
return nil, err
}
@@ -1516,6 +1521,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.towerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {
@@ -1546,6 +1553,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {

watchtower/wtclient/client.go

@@ -2,8 +2,10 @@ package wtclient
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net"
"sync"
"time"
@@ -12,6 +14,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/input"
@@ -43,6 +46,11 @@ const (
// client should abandon any pending updates or session negotiations
// before terminating.
DefaultForceQuitDelay = 10 * time.Second
// DefaultSessionCloseRange is the range over which we will generate a
// random number of blocks to delay closing a session after its last
// channel has been closed.
DefaultSessionCloseRange = 288
)
// genSessionFilter constructs a filter that can be used to select sessions only
@@ -159,6 +167,9 @@ type Config struct {
FetchClosedChannel func(cid lnwire.ChannelID) (
*channeldb.ChannelCloseSummary, error)
// ChainNotifier can be used to subscribe to block notifications.
ChainNotifier chainntnfs.ChainNotifier
// NewAddress generates a new on-chain sweep pkscript.
NewAddress func() ([]byte, error)
@@ -214,6 +225,11 @@ type Config struct {
// watchtowers. If the exponential backoff produces a timeout greater
// than this value, the backoff will be clamped to MaxBackoff.
MaxBackoff time.Duration
// SessionCloseRange is the range over which we will generate a random
// number of blocks to delay closing a session after its last channel
// has been closed.
SessionCloseRange uint32
}
// newTowerMsg is an internal message we'll use within the TowerClient to signal
@@ -590,9 +606,34 @@ func (c *TowerClient) Start() error {
delete(c.summaries, id)
}
// Load all closable sessions.
closableSessions, err := c.cfg.DB.ListClosableSessions()
if err != nil {
returnErr = err
return
}
err = c.trackClosableSessions(closableSessions)
if err != nil {
returnErr = err
return
}
c.wg.Add(1)
go c.handleChannelCloses(chanSub)
// Subscribe to new block events.
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
nil,
)
if err != nil {
returnErr = err
return
}
c.wg.Add(1)
go c.handleClosableSessions(blockEvents)
// Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts
// up.
@@ -876,7 +917,8 @@ func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) {
}
// handleClosedChannel handles the closure of a single channel. It will mark the
// channel as closed in the DB, then it will handle all the sessions that are
// now closable due to the channel closure.
func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
closeHeight uint32) error {
@@ -890,18 +932,146 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
c.log.Debugf("Marking channel(%s) as closed", chanID)
sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
if err != nil {
return fmt.Errorf("could not mark channel(%s) as closed: %w",
chanID, err)
}
closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
for _, sess := range sessions {
closableSessions[sess] = closeHeight
}
c.log.Debugf("Tracking %d new closable sessions as a result of "+
"closing channel %s", len(closableSessions), chanID)
err = c.trackClosableSessions(closableSessions)
if err != nil {
return fmt.Errorf("could not track closable sessions: %w", err)
}
delete(c.summaries, chanID)
delete(c.chanCommitHeights, chanID)
return nil
}
// handleClosableSessions listens for new block notifications. For each block,
// it checks the closableSessionQueue to see if there is a closable session
// with a delete-height less than or equal to the new block height. If there
// is, the tower is informed that it can delete the session, and we then also
// delete it from our DB.
func (c *TowerClient) handleClosableSessions(
blocksChan *chainntnfs.BlockEpochEvent) {
defer c.wg.Done()
c.log.Debug("Starting closable sessions handler")
defer c.log.Debug("Stopping closable sessions handler")
for {
select {
case newBlock := <-blocksChan.Epochs:
if newBlock == nil {
return
}
height := uint32(newBlock.Height)
for {
select {
case <-c.quit:
return
default:
}
// If there are no closable sessions that we
// need to handle, then we are done and can
// reevaluate when the next block comes.
item := c.closableSessionQueue.Top()
if item == nil {
break
}
// If there is a closable session but the
// delete height we have set for it is after
// the current block height, then our work is done.
if item.deleteHeight > height {
break
}
// Otherwise, we pop this item from the heap
// and handle it.
c.closableSessionQueue.Pop()
// Fetch the session from the DB so that we can
// extract the Tower info.
sess, err := c.cfg.DB.GetClientSession(
item.sessionID,
)
if err != nil {
c.log.Errorf("error calling "+
"GetClientSession for "+
"session %s: %v",
item.sessionID, err)
continue
}
err = c.deleteSessionFromTower(sess)
if err != nil {
c.log.Errorf("error deleting "+
"session %s from tower: %v",
sess.ID, err)
continue
}
err = c.cfg.DB.DeleteSession(item.sessionID)
if err != nil {
c.log.Errorf("could not delete "+
"session(%s) from DB: %v",
sess.ID, err)
continue
}
}
case <-c.forceQuit:
return
case <-c.quit:
return
}
}
}
// trackClosableSessions takes in a map of session IDs to the earliest block
// height at which the session should be deleted. For each of the sessions,
// a random delay is added to the block height and the session is added to the
// closableSessionQueue.
func (c *TowerClient) trackClosableSessions(
sessions map[wtdb.SessionID]uint32) error {
// For each closable session, add a random delay to its close
// height and add it to the closableSessionQueue.
for sID, blockHeight := range sessions {
delay, err := newRandomDelay(c.cfg.SessionCloseRange)
if err != nil {
return err
}
deleteHeight := blockHeight + delay
c.closableSessionQueue.Push(&sessionCloseItem{
sessionID: sID,
deleteHeight: deleteHeight,
})
}
return nil
}
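Note: the closableSessionQueue used above is defined elsewhere in this commit and does not appear in this diff. For context, here is a minimal sketch of the Push/Top/Pop surface the client relies on: a mutex-guarded min-heap ordered by deleteHeight. The sessionCloseItem fields mirror the usage above, while the sessionCloseMinHeap and closeItems names are illustrative; the commit's actual implementation may differ.

package wtclient

import (
	"container/heap"
	"sync"

	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

// sessionCloseItem pairs a session ID with the earliest block height at
// which the session may be deleted from the tower and the client DB.
type sessionCloseItem struct {
	sessionID    wtdb.SessionID
	deleteHeight uint32
}

// closeItems implements heap.Interface as a min-heap on deleteHeight.
type closeItems []*sessionCloseItem

func (h closeItems) Len() int           { return len(h) }
func (h closeItems) Less(i, j int) bool { return h[i].deleteHeight < h[j].deleteHeight }
func (h closeItems) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *closeItems) Push(x interface{}) {
	*h = append(*h, x.(*sessionCloseItem))
}

func (h *closeItems) Pop() interface{} {
	old := *h
	item := old[len(old)-1]
	*h = old[:len(old)-1]
	return item
}

// sessionCloseMinHeap guards the heap with a mutex so that Push (called
// from the channel-close handler) and Top/Pop (called from the block
// handler) are safe to use concurrently.
type sessionCloseMinHeap struct {
	mu    sync.Mutex
	items closeItems
}

// Push adds a new item to the heap.
func (q *sessionCloseMinHeap) Push(item *sessionCloseItem) {
	q.mu.Lock()
	defer q.mu.Unlock()

	heap.Push(&q.items, item)
}

// Top returns the item with the smallest deleteHeight without removing
// it, or nil if the heap is empty.
func (q *sessionCloseMinHeap) Top() *sessionCloseItem {
	q.mu.Lock()
	defer q.mu.Unlock()

	if len(q.items) == 0 {
		return nil
	}

	return q.items[0]
}

// Pop removes and returns the item with the smallest deleteHeight, or
// nil if the heap is empty.
func (q *sessionCloseMinHeap) Pop() *sessionCloseItem {
	q.mu.Lock()
	defer q.mu.Unlock()

	if len(q.items) == 0 {
		return nil
	}

	return heap.Pop(&q.items).(*sessionCloseItem)
}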
// deleteSessionFromTower dials the tower that we created the session with and
// attempts to send the tower the DeleteSession message.
func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error {
@@ -1671,3 +1841,15 @@ func (c *TowerClient) logMessage(
preposition, peer.RemotePub().SerializeCompressed(),
peer.RemoteAddr())
}
// newRandomDelay returns a cryptographically random number in the range
// [0, max). Note that rand.Int panics if max is zero, so callers must pass
// max >= 1.
func newRandomDelay(max uint32) (uint32, error) {
var maxDelay big.Int
maxDelay.SetUint64(uint64(max))
randDelay, err := rand.Int(rand.Reader, &maxDelay)
if err != nil {
return 0, err
}
return uint32(randDelay.Uint64()), nil
}
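A hypothetical bounds test, not part of this commit, can make the behaviour of newRandomDelay concrete. It assumes Go's testing package and the stretchr/testify require package already used by the client tests. Since crypto/rand's rand.Int panics when its max argument is not positive, the config treats 0 as falling back to the default and documents 1 as the no-delay setting.

// TestNewRandomDelayBounds checks that the generated delay always falls in
// [0, max), and that a range of 1 always yields a zero delay.
func TestNewRandomDelayBounds(t *testing.T) {
	for i := 0; i < 1000; i++ {
		delay, err := newRandomDelay(DefaultSessionCloseRange)
		require.NoError(t, err)
		require.Less(t, delay, uint32(DefaultSessionCloseRange))
	}

	delay, err := newRandomDelay(1)
	require.NoError(t, err)
	require.Zero(t, delay)
}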

watchtower/wtclient/client_test.go

@@ -16,6 +16,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/input"
@@ -396,6 +397,9 @@ type testHarness struct {
server *wtserver.Server
net *mockNet
blockEvents *mockBlockSub
height int32
channelEvents *mockSubscription
sendUpdatesOn bool
@@ -458,6 +462,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
serverDB: serverDB,
serverCfg: serverCfg,
net: mockNet,
blockEvents: newMockBlockSub(t),
channelEvents: newMockSubscription(t),
channels: make(map[lnwire.ChannelID]*mockChannel),
closedChannels: make(map[lnwire.ChannelID]uint32),
@@ -487,6 +492,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
return h.channelEvents, nil
},
FetchClosedChannel: fetchChannel,
ChainNotifier: h.blockEvents,
Dial: mockNet.Dial,
DB: clientDB,
AuthDial: mockNet.AuthDial,
@@ -495,11 +501,12 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
NewAddress: func() ([]byte, error) {
return addrScript, nil
},
ReadTimeout: timeout,
WriteTimeout: timeout,
MinBackoff: time.Millisecond,
MaxBackoff: time.Second,
ForceQuitDelay: 10 * time.Second,
SessionCloseRange: 1,
}
if !cfg.noServerStart {
@@ -518,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
return h
}
// mine mimics the mining of new blocks by sending new block notifications.
func (h *testHarness) mine(numBlocks int) {
h.t.Helper()
for i := 0; i < numBlocks; i++ {
h.height++
h.blockEvents.sendNewBlock(h.height)
}
}
// startServer creates a new server using the harness's current serverCfg and
// starts it after pointing the mockNet's callback to the new server.
func (h *testHarness) startServer() {
@@ -909,6 +926,44 @@ func (m *mockSubscription) Updates() <-chan interface{} {
return m.updates
}
// mockBlockSub mocks out the ChainNotifier.
type mockBlockSub struct {
t *testing.T
events chan *chainntnfs.BlockEpoch
chainntnfs.ChainNotifier
}
// newMockBlockSub creates a new mockBlockSub.
func newMockBlockSub(t *testing.T) *mockBlockSub {
t.Helper()
return &mockBlockSub{
t: t,
events: make(chan *chainntnfs.BlockEpoch),
}
}
// RegisterBlockEpochNtfn returns a channel that can be used to listen for new
// blocks.
func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) (
*chainntnfs.BlockEpochEvent, error) {
return &chainntnfs.BlockEpochEvent{
Epochs: m.events,
}, nil
}
// sendNewBlock will send a new block on the notification channel.
func (m *mockBlockSub) sendNewBlock(height int32) {
select {
case m.events <- &chainntnfs.BlockEpoch{Height: height}:
case <-time.After(waitTime):
m.t.Fatalf("timed out sending block: %d", height)
}
}
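Because mockBlockSub embeds the chainntnfs.ChainNotifier interface, it satisfies the full interface while implementing only the single method the client calls; invoking any other method would panic on the nil embedded value. A compile-time assertion (hypothetical, not in the commit) makes the relationship explicit:

// Compile-time check that mockBlockSub satisfies chainntnfs.ChainNotifier.
var _ chainntnfs.ChainNotifier = (*mockBlockSub)(nil)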
const (
localBalance = lnwire.MilliSatoshi(100000000)
remoteBalance = lnwire.MilliSatoshi(200000000)
@@ -1891,6 +1946,73 @@ var clientTests = []clientTest{
return h.isSessionClosable(sessionIDs[0])
}, waitTime)
require.NoError(h.t, err)
// Now we will mine a few blocks. This will cause the
// necessary session-close-range to be exceeded, meaning
// that the client should send the DeleteSession message
// to the server. We will assert that both the client
// and server have deleted the appropriate sessions and
// channel info.
// Before we mine blocks, assert that the client
// currently has 3 closable sessions.
closableSess, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
require.Len(h.t, closableSess, 3)
// Assert that the server is also aware of all of these
// sessions.
for sid := range closableSess {
_, err := h.serverDB.GetSessionInfo(&sid)
require.NoError(h.t, err)
}
// Also make a note of the total number of sessions the
// client has.
sessions, err := h.clientDB.ListClientSessions(nil, nil)
require.NoError(h.t, err)
require.Len(h.t, sessions, 4)
h.mine(3)
// The client should no longer have any closable
// sessions and the total list of client sessions should
// no longer include the three that it previously had
// marked as closable. The server should also no longer
// have these sessions in its DB.
err = wait.Predicate(func() bool {
sess, err := h.clientDB.ListClientSessions(
nil, nil,
)
require.NoError(h.t, err)
cs, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
if len(sess) != 1 || len(cs) != 0 {
return false
}
for sid := range closableSess {
_, ok := sess[sid]
if ok {
return false
}
_, err := h.serverDB.GetSessionInfo(
&sid,
)
if !errors.Is(
err, wtdb.ErrSessionNotFound,
) {
return false
}
}
return true
}, waitTime)
require.NoError(h.t, err)
},
},
}