mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-13 11:09:23 +01:00
watchtowers: handle closable sessions
Add a routine to the tower client that informs towers of sessions they can delete and also deletes any info about the session from the client DB.
This commit is contained in:
parent
8478b56ce6
commit
26e628c0fe
5 changed files with 331 additions and 7 deletions
|
@ -17,6 +17,11 @@ type WtClient struct {
|
|||
// SweepFeeRate specifies the fee rate in sat/byte to be used when
|
||||
// constructing justice transactions sent to the tower.
|
||||
SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."`
|
||||
|
||||
// SessionCloseRange is the range over which to choose a random number
|
||||
// of blocks to wait after the last channel of a session is closed
|
||||
// before sending the DeleteSession message to the tower server.
|
||||
SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."`
|
||||
}
|
||||
|
||||
// Validate ensures the user has provided a valid configuration.
|
||||
|
|
|
@ -997,6 +997,12 @@ litecoin.node=ltcd
|
|||
; supported at this time, if none are provided the tower will not be enabled.
|
||||
; wtclient.private-tower-uris=
|
||||
|
||||
; The range over which to choose a random number of blocks to wait after the
|
||||
; last channel of a session is closed before sending the DeleteSession message
|
||||
; to the tower server. The default is currently 288. Note that setting this to
|
||||
; a lower value will result in faster session cleanup _but_ that this comes
|
||||
; along with reduced privacy from the tower server.
|
||||
; wtclient.session-close-range=10
|
||||
|
||||
[healthcheck]
|
||||
|
||||
|
|
|
@ -1497,6 +1497,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight()
|
||||
}
|
||||
|
||||
sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange)
|
||||
if cfg.WtClient.SessionCloseRange != 0 {
|
||||
sessionCloseRange = cfg.WtClient.SessionCloseRange
|
||||
}
|
||||
|
||||
if err := policy.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1516,6 +1521,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
|
||||
s.towerClient, err = wtclient.New(&wtclient.Config{
|
||||
FetchClosedChannel: fetchClosedChannel,
|
||||
SessionCloseRange: sessionCloseRange,
|
||||
ChainNotifier: s.cc.ChainNotifier,
|
||||
SubscribeChannelEvents: func() (subscribe.Subscription,
|
||||
error) {
|
||||
|
||||
|
@ -1546,6 +1553,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
|
||||
s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
|
||||
FetchClosedChannel: fetchClosedChannel,
|
||||
SessionCloseRange: sessionCloseRange,
|
||||
ChainNotifier: s.cc.ChainNotifier,
|
||||
SubscribeChannelEvents: func() (subscribe.Subscription,
|
||||
error) {
|
||||
|
||||
|
|
|
@ -2,8 +2,10 @@ package wtclient
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -12,6 +14,7 @@ import (
|
|||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
|
@ -43,6 +46,11 @@ const (
|
|||
// client should abandon any pending updates or session negotiations
|
||||
// before terminating.
|
||||
DefaultForceQuitDelay = 10 * time.Second
|
||||
|
||||
// DefaultSessionCloseRange is the range over which we will generate a
|
||||
// random number of blocks to delay closing a session after its last
|
||||
// channel has been closed.
|
||||
DefaultSessionCloseRange = 288
|
||||
)
|
||||
|
||||
// genSessionFilter constructs a filter that can be used to select sessions only
|
||||
|
@ -159,6 +167,9 @@ type Config struct {
|
|||
FetchClosedChannel func(cid lnwire.ChannelID) (
|
||||
*channeldb.ChannelCloseSummary, error)
|
||||
|
||||
// ChainNotifier can be used to subscribe to block notifications.
|
||||
ChainNotifier chainntnfs.ChainNotifier
|
||||
|
||||
// NewAddress generates a new on-chain sweep pkscript.
|
||||
NewAddress func() ([]byte, error)
|
||||
|
||||
|
@ -214,6 +225,11 @@ type Config struct {
|
|||
// watchtowers. If the exponential backoff produces a timeout greater
|
||||
// than this value, the backoff will be clamped to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
|
||||
// SessionCloseRange is the range over which we will generate a random
|
||||
// number of blocks to delay closing a session after its last channel
|
||||
// has been closed.
|
||||
SessionCloseRange uint32
|
||||
}
|
||||
|
||||
// newTowerMsg is an internal message we'll use within the TowerClient to signal
|
||||
|
@ -590,9 +606,34 @@ func (c *TowerClient) Start() error {
|
|||
delete(c.summaries, id)
|
||||
}
|
||||
|
||||
// Load all closable sessions.
|
||||
closableSessions, err := c.cfg.DB.ListClosableSessions()
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
err = c.trackClosableSessions(closableSessions)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleChannelCloses(chanSub)
|
||||
|
||||
// Subscribe to new block events.
|
||||
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleClosableSessions(blockEvents)
|
||||
|
||||
// Now start the session negotiator, which will allow us to
|
||||
// request new session as soon as the backupDispatcher starts
|
||||
// up.
|
||||
|
@ -876,7 +917,8 @@ func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) {
|
|||
}
|
||||
|
||||
// handleClosedChannel handles the closure of a single channel. It will mark the
|
||||
// channel as closed in the DB.
|
||||
// channel as closed in the DB, then it will handle all the sessions that are
|
||||
// now closable due to the channel closure.
|
||||
func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
|
||||
closeHeight uint32) error {
|
||||
|
||||
|
@ -890,18 +932,146 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
|
|||
|
||||
c.log.Debugf("Marking channel(%s) as closed", chanID)
|
||||
|
||||
_, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
|
||||
sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not mark channel(%s) as closed: %w",
|
||||
chanID, err)
|
||||
}
|
||||
|
||||
closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
closableSessions[sess] = closeHeight
|
||||
}
|
||||
|
||||
c.log.Debugf("Tracking %d new closable sessions as a result of "+
|
||||
"closing channel %s", len(closableSessions), chanID)
|
||||
|
||||
err = c.trackClosableSessions(closableSessions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not track closable sessions: %w", err)
|
||||
}
|
||||
|
||||
delete(c.summaries, chanID)
|
||||
delete(c.chanCommitHeights, chanID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleClosableSessions listens for new block notifications. For each block,
|
||||
// it checks the closableSessionQueue to see if there is a closable session with
|
||||
// a delete-height smaller than or equal to the new block, if there is then the
|
||||
// tower is informed that it can delete the session, and then we also delete it
|
||||
// from our DB.
|
||||
func (c *TowerClient) handleClosableSessions(
|
||||
blocksChan *chainntnfs.BlockEpochEvent) {
|
||||
|
||||
defer c.wg.Done()
|
||||
|
||||
c.log.Debug("Starting closable sessions handler")
|
||||
defer c.log.Debug("Stopping closable sessions handler")
|
||||
|
||||
for {
|
||||
select {
|
||||
case newBlock := <-blocksChan.Epochs:
|
||||
if newBlock == nil {
|
||||
return
|
||||
}
|
||||
|
||||
height := uint32(newBlock.Height)
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// If there are no closable sessions that we
|
||||
// need to handle, then we are done and can
|
||||
// reevaluate when the next block comes.
|
||||
item := c.closableSessionQueue.Top()
|
||||
if item == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// If there is closable session but the delete
|
||||
// height we have set for it is after the
|
||||
// current block height, then our work is done.
|
||||
if item.deleteHeight > height {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise, we pop this item from the heap
|
||||
// and handle it.
|
||||
c.closableSessionQueue.Pop()
|
||||
|
||||
// Fetch the session from the DB so that we can
|
||||
// extract the Tower info.
|
||||
sess, err := c.cfg.DB.GetClientSession(
|
||||
item.sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
c.log.Errorf("error calling "+
|
||||
"GetClientSession for "+
|
||||
"session %s: %v",
|
||||
item.sessionID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.deleteSessionFromTower(sess)
|
||||
if err != nil {
|
||||
c.log.Errorf("error deleting "+
|
||||
"session %s from tower: %v",
|
||||
sess.ID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.cfg.DB.DeleteSession(item.sessionID)
|
||||
if err != nil {
|
||||
c.log.Errorf("could not delete "+
|
||||
"session(%s) from DB: %w",
|
||||
sess.ID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
case <-c.forceQuit:
|
||||
return
|
||||
|
||||
case <-c.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackClosableSessions takes in a map of session IDs to the earliest block
|
||||
// height at which the session should be deleted. For each of the sessions,
|
||||
// a random delay is added to the block height and the session is added to the
|
||||
// closableSessionQueue.
|
||||
func (c *TowerClient) trackClosableSessions(
|
||||
sessions map[wtdb.SessionID]uint32) error {
|
||||
|
||||
// For each closable session, add a random delay to its close
|
||||
// height and add it to the closableSessionQueue.
|
||||
for sID, blockHeight := range sessions {
|
||||
delay, err := newRandomDelay(c.cfg.SessionCloseRange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deleteHeight := blockHeight + delay
|
||||
|
||||
c.closableSessionQueue.Push(&sessionCloseItem{
|
||||
sessionID: sID,
|
||||
deleteHeight: deleteHeight,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteSessionFromTower dials the tower that we created the session with and
|
||||
// attempts to send the tower the DeleteSession message.
|
||||
func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error {
|
||||
|
@ -1671,3 +1841,15 @@ func (c *TowerClient) logMessage(
|
|||
preposition, peer.RemotePub().SerializeCompressed(),
|
||||
peer.RemoteAddr())
|
||||
}
|
||||
|
||||
func newRandomDelay(max uint32) (uint32, error) {
|
||||
var maxDelay big.Int
|
||||
maxDelay.SetUint64(uint64(max))
|
||||
|
||||
randDelay, err := rand.Int(rand.Reader, &maxDelay)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint32(randDelay.Uint64()), nil
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
|
@ -396,6 +397,9 @@ type testHarness struct {
|
|||
server *wtserver.Server
|
||||
net *mockNet
|
||||
|
||||
blockEvents *mockBlockSub
|
||||
height int32
|
||||
|
||||
channelEvents *mockSubscription
|
||||
sendUpdatesOn bool
|
||||
|
||||
|
@ -458,6 +462,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
net: mockNet,
|
||||
blockEvents: newMockBlockSub(t),
|
||||
channelEvents: newMockSubscription(t),
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
closedChannels: make(map[lnwire.ChannelID]uint32),
|
||||
|
@ -487,6 +492,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
return h.channelEvents, nil
|
||||
},
|
||||
FetchClosedChannel: fetchChannel,
|
||||
ChainNotifier: h.blockEvents,
|
||||
Dial: mockNet.Dial,
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
|
@ -495,11 +501,12 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
SessionCloseRange: 1,
|
||||
}
|
||||
|
||||
if !cfg.noServerStart {
|
||||
|
@ -518,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
return h
|
||||
}
|
||||
|
||||
// mine mimics the mining of new blocks by sending new block notifications.
|
||||
func (h *testHarness) mine(numBlocks int) {
|
||||
h.t.Helper()
|
||||
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
h.height++
|
||||
h.blockEvents.sendNewBlock(h.height)
|
||||
}
|
||||
}
|
||||
|
||||
// startServer creates a new server using the harness's current serverCfg and
|
||||
// starts it after pointing the mockNet's callback to the new server.
|
||||
func (h *testHarness) startServer() {
|
||||
|
@ -909,6 +926,44 @@ func (m *mockSubscription) Updates() <-chan interface{} {
|
|||
return m.updates
|
||||
}
|
||||
|
||||
// mockBlockSub mocks out the ChainNotifier.
|
||||
type mockBlockSub struct {
|
||||
t *testing.T
|
||||
events chan *chainntnfs.BlockEpoch
|
||||
|
||||
chainntnfs.ChainNotifier
|
||||
}
|
||||
|
||||
// newMockBlockSub creates a new mockBlockSub.
|
||||
func newMockBlockSub(t *testing.T) *mockBlockSub {
|
||||
t.Helper()
|
||||
|
||||
return &mockBlockSub{
|
||||
t: t,
|
||||
events: make(chan *chainntnfs.BlockEpoch),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterBlockEpochNtfn returns a channel that can be used to listen for new
|
||||
// blocks.
|
||||
func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) (
|
||||
*chainntnfs.BlockEpochEvent, error) {
|
||||
|
||||
return &chainntnfs.BlockEpochEvent{
|
||||
Epochs: m.events,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sendNewBlock will send a new block on the notification channel.
|
||||
func (m *mockBlockSub) sendNewBlock(height int32) {
|
||||
select {
|
||||
case m.events <- &chainntnfs.BlockEpoch{Height: height}:
|
||||
|
||||
case <-time.After(waitTime):
|
||||
m.t.Fatalf("timed out sending block: %d", height)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
localBalance = lnwire.MilliSatoshi(100000000)
|
||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||
|
@ -1891,6 +1946,73 @@ var clientTests = []clientTest{
|
|||
return h.isSessionClosable(sessionIDs[0])
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Now we will mine a few blocks. This will cause the
|
||||
// necessary session-close-range to be exceeded meaning
|
||||
// that the client should send the DeleteSession message
|
||||
// to the server. We will assert that both the client
|
||||
// and server have deleted the appropriate sessions and
|
||||
// channel info.
|
||||
|
||||
// Before we mine blocks, assert that the client
|
||||
// currently has 3 closable sessions.
|
||||
closableSess, err := h.clientDB.ListClosableSessions()
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, closableSess, 3)
|
||||
|
||||
// Assert that the server is also aware of all of these
|
||||
// sessions.
|
||||
for sid := range closableSess {
|
||||
_, err := h.serverDB.GetSessionInfo(&sid)
|
||||
require.NoError(h.t, err)
|
||||
}
|
||||
|
||||
// Also make a note of the total number of sessions the
|
||||
// client has.
|
||||
sessions, err := h.clientDB.ListClientSessions(nil, nil)
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, sessions, 4)
|
||||
|
||||
h.mine(3)
|
||||
|
||||
// The client should no longer have any closable
|
||||
// sessions and the total list of client sessions should
|
||||
// no longer include the three that it previously had
|
||||
// marked as closable. The server should also no longer
|
||||
// have these sessions in its DB.
|
||||
err = wait.Predicate(func() bool {
|
||||
sess, err := h.clientDB.ListClientSessions(
|
||||
nil, nil,
|
||||
)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
cs, err := h.clientDB.ListClosableSessions()
|
||||
require.NoError(h.t, err)
|
||||
|
||||
if len(sess) != 1 || len(cs) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for sid := range closableSess {
|
||||
_, ok := sess[sid]
|
||||
if ok {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := h.serverDB.GetSessionInfo(
|
||||
&sid,
|
||||
)
|
||||
if !errors.Is(
|
||||
err, wtdb.ErrSessionNotFound,
|
||||
) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue