Merge pull request #7069 from ellemouton/deleteSessions

watchtower: start using the DeleteSession message
Oliver Gugger 2023-03-20 18:45:36 +01:00 committed by GitHub
commit c4c1f1ac92
31 changed files with 3222 additions and 157 deletions


@ -9,6 +9,9 @@
* [Allow caller to filter sessions at the time of reading them from
disk](https://github.com/lightningnetwork/lnd/pull/7059)
* [Clean up sessions once all channels for which they have updates are
closed. Also start sending the `DeleteSession` message to the
tower.](https://github.com/lightningnetwork/lnd/pull/7069)
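
For context, the cleanup delay and the session size are controlled by two new
wtclient options. A hedged lnd.conf sketch (the values mirror the sample
config shown later in this diff and are illustrative, not a recommendation):

    ; Delay deleting a session by a random number of blocks in this range.
    wtclient.session-close-range=288
    ; Cap the number of updates backed up in a single session.
    wtclient.max-updates=1024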
## Misc


@ -515,4 +515,8 @@ var allTestCases = []*lntest.TestCase{
Name: "lookup htlc resolution",
TestFunc: testLookupHtlcResolution,
},
{
Name: "watchtower session management",
TestFunc: testWatchtowerSessionManagement,
},
}


@ -0,0 +1,172 @@
package itest
import (
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnrpc/wtclientrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/lntest/node"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/stretchr/testify/require"
)
// testWatchtowerSessionManagement tests that session deletion is done
// correctly.
func testWatchtowerSessionManagement(ht *lntest.HarnessTest) {
const (
chanAmt = funding.MaxBtcFundingAmount
paymentAmt = 10_000
numInvoices = 5
maxUpdates = numInvoices * 2
externalIP = "1.2.3.4"
sessionCloseRange = 1
)
// Set up Wallis the watchtower who will be used by Dave to watch over
// his channel commitment transactions.
wallis := ht.NewNode("Wallis", []string{
"--watchtower.active",
"--watchtower.externalip=" + externalIP,
})
wallisInfo := wallis.RPC.GetInfoWatchtower()
// Assert that Wallis has one listener and it is 0.0.0.0:9911 or
// [::]:9911. Since no listener is explicitly specified, one of these
// should be the default depending on whether the host supports IPv6 or
// not.
require.Len(ht, wallisInfo.Listeners, 1)
listener := wallisInfo.Listeners[0]
require.True(ht, listener == "0.0.0.0:9911" || listener == "[::]:9911")
// Assert that Wallis's URIs properly display the chosen external IP.
require.Len(ht, wallisInfo.Uris, 1)
require.Contains(ht, wallisInfo.Uris[0], externalIP)
// Dave will be the tower client.
daveArgs := []string{
"--wtclient.active",
fmt.Sprintf("--wtclient.max-updates=%d", maxUpdates),
fmt.Sprintf(
"--wtclient.session-close-range=%d", sessionCloseRange,
),
}
dave := ht.NewNode("Dave", daveArgs)
addTowerReq := &wtclientrpc.AddTowerRequest{
Pubkey: wallisInfo.Pubkey,
Address: listener,
}
dave.RPC.AddTower(addTowerReq)
// Assert that there exists a session between Dave and Wallis.
err := wait.NoError(func() error {
info := dave.RPC.GetTowerInfo(&wtclientrpc.GetTowerInfoRequest{
Pubkey: wallisInfo.Pubkey,
IncludeSessions: true,
})
var numSessions uint32
for _, sessionType := range info.SessionInfo {
numSessions += sessionType.NumSessions
}
if numSessions > 0 {
return nil
}
return fmt.Errorf("expected a non-zero number of sessions")
}, defaultTimeout)
require.NoError(ht, err)
// Before we make a channel, we'll load up Dave with some coins sent
// directly from the miner.
ht.FundCoins(btcutil.SatoshiPerBitcoin, dave)
// Connect Dave and Alice.
ht.ConnectNodes(dave, ht.Alice)
// Open a channel between Dave and Alice.
params := lntest.OpenChannelParams{
Amt: chanAmt,
}
chanPoint := ht.OpenChannel(dave, ht.Alice, params)
// Since there are 2 updates made for every payment and the maximum
// number of updates per session has been set to 10, make 5 payments
// between the pair so that the session is exhausted.
alicePayReqs, _, _ := ht.CreatePayReqs(
ht.Alice, paymentAmt, numInvoices,
)
send := func(node *node.HarnessNode, payReq string) {
stream := node.RPC.SendPayment(&routerrpc.SendPaymentRequest{
PaymentRequest: payReq,
TimeoutSeconds: 60,
FeeLimitMsat: noFeeLimitMsat,
})
ht.AssertPaymentStatusFromStream(
stream, lnrpc.Payment_SUCCEEDED,
)
}
for i := 0; i < numInvoices; i++ {
send(dave, alicePayReqs[i])
}
// assertNumBackups is a closure that asserts that Dave has a certain
// number of backups stored with the tower. If mineOnFail is true,
// then a block will be mined each time the assertion fails.
assertNumBackups := func(expected int, mineOnFail bool) {
err = wait.NoError(func() error {
info := dave.RPC.GetTowerInfo(
&wtclientrpc.GetTowerInfoRequest{
Pubkey: wallisInfo.Pubkey,
IncludeSessions: true,
},
)
var numBackups uint32
for _, sessionType := range info.SessionInfo {
for _, session := range sessionType.Sessions {
numBackups += session.NumBackups
}
}
if numBackups == uint32(expected) {
return nil
}
if mineOnFail {
ht.Miner.MineBlocksSlow(1)
}
return fmt.Errorf("expected %d backups, got %d",
expected, numBackups)
}, defaultTimeout)
require.NoError(ht, err)
}
// Assert that one of the sessions now has 10 backups.
assertNumBackups(10, false)
// Now close the channel and wait for the close transaction to appear
// in the mempool so that it is included in a block when we mine.
ht.CloseChannelAssertPending(dave, chanPoint, false)
// Mine enough blocks to surpass the session-close-range. This should
// trigger the session to be deleted.
ht.MineBlocksAndAssertNumTxes(sessionCloseRange+6, 1)
// Wait for the session to be deleted. We know it has been deleted once
// the number of backups is back to zero. We check for number of backups
// instead of number of sessions because it is expected that the client
// would immediately negotiate another session after deleting the
// exhausted one. This time we set the "mineOnFail" parameter to true to
// ensure that the session deleting logic is run.
assertNumBackups(0, true)
}


@ -17,6 +17,15 @@ type WtClient struct {
// SweepFeeRate specifies the fee rate in sat/byte to be used when
// constructing justice transactions sent to the tower.
SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."`
// SessionCloseRange is the range over which to choose a random number
// of blocks to wait after the last channel of a session is closed
// before sending the DeleteSession message to the tower server.
SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."`
// MaxUpdates is the maximum number of updates to be backed up in a
// single tower session.
MaxUpdates uint16 `long:"max-updates" description:"The maximum number of updates to be backed up in a single session."`
}
// Validate ensures the user has provided a valid configuration.
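
The itest above exercises these options via startup flags. A minimal
command-line sketch, mirroring daveArgs in the test (the values are the
test's, not defaults):

    lnd --wtclient.active --wtclient.max-updates=10 \
        --wtclient.session-close-range=1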


@ -309,6 +309,7 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
}
t.SessionInfo = append(t.SessionInfo, rpcTower.SessionInfo...)
t.Sessions = append(t.Sessions, rpcTower.Sessions...)
}
towers := make([]*Tower, 0, len(rpcTowers))
@ -365,6 +366,9 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context,
rpcTower.SessionInfo = append(
rpcTower.SessionInfo, rpcLegacyTower.SessionInfo...,
)
rpcTower.Sessions = append(
rpcTower.Sessions, rpcLegacyTower.Sessions...,
)
return rpcTower, nil
}


@ -24,6 +24,20 @@ func (h *HarnessRPC) GetInfoWatchtower() *watchtowerrpc.GetInfoResponse {
return info
}
// GetTowerInfo makes an RPC call to the watchtower client of the given node and
// asserts.
func (h *HarnessRPC) GetTowerInfo(
req *wtclientrpc.GetTowerInfoRequest) *wtclientrpc.Tower {
ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout)
defer cancel()
info, err := h.WatchtowerClient.GetTowerInfo(ctxt, req)
h.NoError(err, "GetTowerInfo from WatchtowerClient")
return info
}
// AddTower makes an RPC call to the WatchtowerClient of the given node and
// asserts.
func (h *HarnessRPC) AddTower(


@ -997,6 +997,15 @@ litecoin.node=ltcd
; supported at this time, if none are provided the tower will not be enabled.
; wtclient.private-tower-uris=
; The range over which to choose a random number of blocks to wait after the
; last channel of a session is closed before sending the DeleteSession message
; to the tower server. The default is currently 288. Note that setting this
; to a lower value will result in faster session cleanup, but it comes at the
; cost of reduced privacy from the tower server.
; wtclient.session-close-range=10
; The maximum number of updates to include in a tower session.
; wtclient.max-updates=1024
[healthcheck]


@ -1497,6 +1497,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight()
}
if cfg.WtClient.MaxUpdates != 0 {
policy.MaxUpdates = cfg.WtClient.MaxUpdates
}
sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange)
if cfg.WtClient.SessionCloseRange != 0 {
sessionCloseRange = cfg.WtClient.SessionCloseRange
}
if err := policy.Validate(); err != nil {
return nil, err
}
@ -1512,7 +1521,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
)
}
fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID
s.towerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {
return s.channelNotifier.
SubscribeChannelEvents()
},
Signer: cc.Wallet.Cfg.Signer,
NewAddress: newSweepPkScriptGen(cc.Wallet),
SecretKeyRing: s.cc.KeyRing,
@ -1536,6 +1556,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
blob.Type(blob.FlagAnchorChannel)
s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {
return s.channelNotifier.
SubscribeChannelEvents()
},
Signer: cc.Wallet.Cfg.Signer,
NewAddress: newSweepPkScriptGen(cc.Wallet),
SecretKeyRing: s.cc.KeyRing,


@ -69,6 +69,12 @@ type AddressIterator interface {
// Reset clears the iterator's state, and makes the address at the front
// of the list the next item to be returned.
Reset()
// Copy constructs a new AddressIterator that has the same addresses
// as this iterator.
//
// NOTE that the address locks are not expected to be copied.
Copy() AddressIterator
}
// A compile-time check to ensure that addressIterator implements the
@ -324,6 +330,33 @@ func (a *addressIterator) GetAll() []net.Addr {
a.mu.Lock()
defer a.mu.Unlock()
return a.getAllUnsafe()
}
// Copy constructs a new AddressIterator that has the same addresses
// as this iterator.
//
// NOTE that the address locks will not be copied.
func (a *addressIterator) Copy() AddressIterator {
a.mu.Lock()
defer a.mu.Unlock()
addrs := a.getAllUnsafe()
// Since newAddressIterator will only ever return an error if it is
// initialised with zero addresses, we can ignore the error here since
// we are initialising it with the set of addresses of this
// addressIterator which is by definition a non-empty list.
iter, _ := newAddressIterator(addrs...)
return iter
}
// getAllUnsafe returns a copy of all the addresses in the iterator.
//
// NOTE: this method is not thread safe and so must only be called while the
// addressIterator mutex is held.
func (a *addressIterator) getAllUnsafe() []net.Addr {
var addrs []net.Addr
cursor := a.addrList.Front()


@ -97,6 +97,11 @@ func TestAddrIterator(t *testing.T) {
addrList := iter.GetAll()
require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3})
// Also check that an iterator constructed via the Copy method, also
// contains all the expected addresses.
newIterAddrs := iter.Copy().GetAll()
require.ElementsMatch(t, newIterAddrs, []net.Addr{addr1, addr2, addr3})
// Let's now remove addr3.
err = iter.Remove(addr3)
require.NoError(t, err)


@ -29,6 +29,10 @@ type TowerCandidateIterator interface {
// candidates available as long as they remain in the set.
Reset() error
// GetTower gets the tower with the given ID from the iterator. If no
// such tower is found then ErrTowerNotInIterator is returned.
GetTower(id wtdb.TowerID) (*Tower, error)
// Next returns the next candidate tower. The iterator is not required
// to return results in any particular order. If no more candidates are
// available, ErrTowerCandidatesExhausted is returned.
@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error {
return nil
}
// GetTower gets the tower with the given ID from the iterator. If no such tower
// is found then ErrTowerNotInIterator is returned.
func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) {
t.mu.Lock()
defer t.mu.Unlock()
tower, ok := t.candidates[id]
if !ok {
return nil, ErrTowerNotInIterator
}
return tower, nil
}
// Next returns the next candidate tower. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrTowerCandidatesExhausted is returned.


@ -52,14 +52,10 @@ func randTower(t *testing.T) *Tower {
func copyTower(t *testing.T, tower *Tower) *Tower {
t.Helper()
addrs := tower.Addresses.GetAll()
addrIterator, err := newAddressIterator(addrs...)
require.NoError(t, err)
return &Tower{
ID: tower.ID,
IdentityKey: tower.IdentityKey,
Addresses: addrIterator,
Addresses: tower.Addresses.Copy(),
}
}
@ -83,9 +79,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) {
tower, err := i.Next()
require.NoError(t, err)
require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey))
require.Equal(t, tower.ID, c.ID)
require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll())
assertTowersEqual(t, c, tower)
}
func assertTowersEqual(t *testing.T, expected, actual *Tower) {
t.Helper()
require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey))
require.Equal(t, expected.ID, actual.ID)
require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll())
}
// TestTowerCandidateIterator asserts the internal state of a
@ -155,4 +157,16 @@ func TestTowerCandidateIterator(t *testing.T) {
towerIterator.AddCandidate(secondTower)
assertActiveCandidate(t, towerIterator, secondTower, true)
assertNextCandidate(t, towerIterator, secondTower)
// Assert that GetTower correctly returns the tower too.
tower, err := towerIterator.GetTower(secondTower.ID)
require.NoError(t, err)
assertTowersEqual(t, secondTower, tower)
// Now remove the tower and assert that GetTower returns the expected error.
err = towerIterator.RemoveCandidate(secondTower.ID, nil)
require.NoError(t, err)
_, err = towerIterator.GetTower(secondTower.ID)
require.ErrorIs(t, err, ErrTowerNotInIterator)
}


@ -2,7 +2,10 @@ package wtclient
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net"
"sync"
"time"
@ -11,11 +14,14 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/subscribe"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
@ -40,6 +46,11 @@ const (
// client should abandon any pending updates or session negotiations
// before terminating.
DefaultForceQuitDelay = 10 * time.Second
// DefaultSessionCloseRange is the range over which we will generate a
// random number of blocks to delay closing a session after its last
// channel has been closed.
DefaultSessionCloseRange = 288
)
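
As a worked example of this constant: at one block per ten minutes, 288 blocks
is roughly two days, so a session whose last channel closed at height h is
deleted at a uniformly random height in [h, h+288). A hedged sketch using
newRandomDelay, defined at the end of this file (queueForDeletion is a
hypothetical helper):

// queueForDeletion returns the block height at which the given session may
// be deleted from the tower.
func queueForDeletion(closeHeight uint32) (uint32, error) {
	delay, err := newRandomDelay(DefaultSessionCloseRange)
	if err != nil {
		return 0, err
	}
	return closeHeight + delay, nil
}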
// genSessionFilter constructs a filter that can be used to select sessions only
@ -146,6 +157,19 @@ type Config struct {
// transaction.
Signer input.Signer
// SubscribeChannelEvents can be used to subscribe to channel event
// notifications.
SubscribeChannelEvents func() (subscribe.Subscription, error)
// FetchClosedChannel can be used to fetch the info about a closed
// channel. If the channel is not found or not yet closed then
// channeldb.ErrClosedChannelNotFound will be returned.
FetchClosedChannel func(cid lnwire.ChannelID) (
*channeldb.ChannelCloseSummary, error)
// ChainNotifier can be used to subscribe to block notifications.
ChainNotifier chainntnfs.ChainNotifier
// NewAddress generates a new on-chain sweep pkscript.
NewAddress func() ([]byte, error)
@ -201,6 +225,11 @@ type Config struct {
// watchtowers. If the exponential backoff produces a timeout greater
// than this value, the backoff will be clamped to MaxBackoff.
MaxBackoff time.Duration
// SessionCloseRange is the range over which we will generate a random
// number of blocks to delay closing a session after its last channel
// has been closed.
SessionCloseRange uint32
}
// newTowerMsg is an internal message we'll use within the TowerClient to signal
@ -258,6 +287,8 @@ type TowerClient struct {
sessionQueue *sessionQueue
prevTask *backupTask
closableSessionQueue *sessionCloseMinHeap
backupMu sync.Mutex
summaries wtdb.ChannelSummaries
chanCommitHeights map[lnwire.ChannelID]uint64
@ -269,6 +300,7 @@ type TowerClient struct {
staleTowers chan *staleTowerMsg
wg sync.WaitGroup
quit chan struct{}
forceQuit chan struct{}
}
@ -314,11 +346,13 @@ func New(config *Config) (*TowerClient, error) {
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
closableSessionQueue: newSessionCloseMinHeap(),
statTicker: time.NewTicker(DefaultStatInterval),
stats: new(ClientStats),
newTowers: make(chan *newTowerMsg),
staleTowers: make(chan *staleTowerMsg),
forceQuit: make(chan struct{}),
quit: make(chan struct{}),
}
// perUpdate is a callback function that will be used to inspect the
@ -364,7 +398,7 @@ func New(config *Config) (*TowerClient, error) {
return
}
log.Infof("Using private watchtower %s, offering policy %s",
c.log.Infof("Using private watchtower %s, offering policy %s",
tower, cfg.Policy)
// Add the tower to the set of candidate towers.
@ -435,27 +469,19 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
}
for _, s := range sessions {
towerKeyDesc, err := keyRing.DeriveKey(
keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex,
},
if !sessionFilter(s) {
continue
}
cs, err := NewClientSessionFromDBSession(
s, tower, keyRing,
)
if err != nil {
return nil, err
}
sessionKeyECDH := keychain.NewPubKeyECDH(
towerKeyDesc, keyRing,
)
// Add the session to the set of candidate sessions.
candidateSessions[s.ID] = &ClientSession{
ID: s.ID,
ClientSessionBody: s.ClientSessionBody,
Tower: tower,
SessionKeyECDH: sessionKeyECDH,
}
candidateSessions[s.ID] = cs
perActiveTower(tower)
}
@ -548,10 +574,70 @@ func (c *TowerClient) Start() error {
}
}
chanSub, err := c.cfg.SubscribeChannelEvents()
if err != nil {
returnErr = err
return
}
// Iterate over the list of registered channels and check if
// any of them can be marked as closed.
for id := range c.summaries {
isClosed, closedHeight, err := c.isChannelClosed(id)
if err != nil {
returnErr = err
return
}
if !isClosed {
continue
}
_, err = c.cfg.DB.MarkChannelClosed(id, closedHeight)
if err != nil {
c.log.Errorf("could not mark channel(%s) as "+
"closed: %v", id, err)
continue
}
// Since the channel has been marked as closed, we can
// also remove it from the channel summaries map.
delete(c.summaries, id)
}
// Load all closable sessions.
closableSessions, err := c.cfg.DB.ListClosableSessions()
if err != nil {
returnErr = err
return
}
err = c.trackClosableSessions(closableSessions)
if err != nil {
returnErr = err
return
}
c.wg.Add(1)
go c.handleChannelCloses(chanSub)
// Subscribe to new block events.
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
nil,
)
if err != nil {
returnErr = err
return
}
c.wg.Add(1)
go c.handleClosableSessions(blockEvents)
// Now start the session negotiator, which will allow us to
// request new sessions as soon as the backupDispatcher starts
// up.
err := c.negotiator.Start()
err = c.negotiator.Start()
if err != nil {
returnErr = err
return
@ -599,6 +685,7 @@ func (c *TowerClient) Stop() error {
// dispatcher to exit. The backup queue will signal its
// completion to the dispatcher, which releases the wait group
// after all tasks have been assigned to session queues.
close(c.quit)
c.wg.Wait()
// 4. Since all valid tasks have been assigned to session
@ -780,6 +867,335 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
return c.getOrInitActiveQueue(candidateSession, updates), nil
}
// handleChannelCloses listens for channel close events and marks channels as
// closed in the DB.
//
// NOTE: This method MUST be run as a goroutine.
func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) {
defer c.wg.Done()
c.log.Debugf("Starting channel close handler")
defer c.log.Debugf("Stopping channel close handler")
for {
select {
case update, ok := <-chanSub.Updates():
if !ok {
c.log.Debugf("Channel notifier has exited")
return
}
// We only care about channel-close events.
event, ok := update.(channelnotifier.ClosedChannelEvent)
if !ok {
continue
}
chanID := lnwire.NewChanIDFromOutPoint(
&event.CloseSummary.ChanPoint,
)
c.log.Debugf("Received ClosedChannelEvent for "+
"channel: %s", chanID)
err := c.handleClosedChannel(
chanID, event.CloseSummary.CloseHeight,
)
if err != nil {
c.log.Errorf("Could not handle channel close "+
"event for channel(%s): %v", chanID,
err)
}
case <-c.forceQuit:
return
case <-c.quit:
return
}
}
}
// handleClosedChannel handles the closure of a single channel. It will mark the
// channel as closed in the DB, then it will handle all the sessions that are
// now closable due to the channel closure.
func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
closeHeight uint32) error {
c.backupMu.Lock()
defer c.backupMu.Unlock()
// We only care about channels registered with the tower client.
if _, ok := c.summaries[chanID]; !ok {
return nil
}
c.log.Debugf("Marking channel(%s) as closed", chanID)
sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
if err != nil {
return fmt.Errorf("could not mark channel(%s) as closed: %w",
chanID, err)
}
closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
for _, sess := range sessions {
closableSessions[sess] = closeHeight
}
c.log.Debugf("Tracking %d new closable sessions as a result of "+
"closing channel %s", len(closableSessions), chanID)
err = c.trackClosableSessions(closableSessions)
if err != nil {
return fmt.Errorf("could not track closable sessions: %w", err)
}
delete(c.summaries, chanID)
delete(c.chanCommitHeights, chanID)
return nil
}
// handleClosableSessions listens for new block notifications. For each block,
// it checks the closableSessionQueue to see if there is a closable session with
// a delete-height smaller than or equal to the new block height. If there is,
// then the tower is informed that it can delete the session, and we then also
// delete it from our DB.
func (c *TowerClient) handleClosableSessions(
blocksChan *chainntnfs.BlockEpochEvent) {
defer c.wg.Done()
c.log.Debug("Starting closable sessions handler")
defer c.log.Debug("Stopping closable sessions handler")
for {
select {
case newBlock := <-blocksChan.Epochs:
if newBlock == nil {
return
}
height := uint32(newBlock.Height)
for {
select {
case <-c.quit:
return
default:
}
// If there are no closable sessions that we
// need to handle, then we are done and can
// reevaluate when the next block comes.
item := c.closableSessionQueue.Top()
if item == nil {
break
}
// If there is a closable session but the delete
// height we have set for it is after the
// current block height, then our work is done.
if item.deleteHeight > height {
break
}
// Otherwise, we pop this item from the heap
// and handle it.
c.closableSessionQueue.Pop()
// Fetch the session from the DB so that we can
// extract the Tower info.
sess, err := c.cfg.DB.GetClientSession(
item.sessionID,
)
if err != nil {
c.log.Errorf("error calling "+
"GetClientSession for "+
"session %s: %v",
item.sessionID, err)
continue
}
err = c.deleteSessionFromTower(sess)
if err != nil {
c.log.Errorf("error deleting "+
"session %s from tower: %v",
sess.ID, err)
continue
}
err = c.cfg.DB.DeleteSession(item.sessionID)
if err != nil {
c.log.Errorf("could not delete "+
"session(%s) from DB: %w",
sess.ID, err)
continue
}
}
case <-c.forceQuit:
return
case <-c.quit:
return
}
}
}
// trackClosableSessions takes in a map of session IDs to the earliest block
// height at which the session should be deleted. For each of the sessions,
// a random delay is added to the block height and the session is added to the
// closableSessionQueue.
func (c *TowerClient) trackClosableSessions(
sessions map[wtdb.SessionID]uint32) error {
// For each closable session, add a random delay to its close
// height and add it to the closableSessionQueue.
for sID, blockHeight := range sessions {
delay, err := newRandomDelay(c.cfg.SessionCloseRange)
if err != nil {
return err
}
deleteHeight := blockHeight + delay
c.closableSessionQueue.Push(&sessionCloseItem{
sessionID: sID,
deleteHeight: deleteHeight,
})
}
return nil
}
// deleteSessionFromTower dials the tower that we created the session with and
// attempts to send the tower the DeleteSession message.
func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error {
// First, we check if we have already loaded this tower in our
// candidate towers iterator.
tower, err := c.candidateTowers.GetTower(sess.TowerID)
if errors.Is(err, ErrTowerNotInIterator) {
// If not, then we attempt to load it from the DB.
dbTower, err := c.cfg.DB.LoadTowerByID(sess.TowerID)
if err != nil {
return err
}
tower, err = NewTowerFromDBTower(dbTower)
if err != nil {
return err
}
} else if err != nil {
return err
}
session, err := NewClientSessionFromDBSession(
sess, tower, c.cfg.SecretKeyRing,
)
if err != nil {
return err
}
localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
c.cfg.ChainHash,
)
var (
conn wtserver.Peer
// addrIterator is a copy of the tower's address iterator.
// We use this copy so that iterating through the addresses does
// not affect any other threads using this iterator.
addrIterator = tower.Addresses.Copy()
towerAddr = addrIterator.Peek()
)
// Attempt to dial the tower with its available addresses.
for {
conn, err = c.dial(
session.SessionKeyECDH, &lnwire.NetAddress{
IdentityKey: tower.IdentityKey,
Address: towerAddr,
},
)
if err != nil {
// If there are more addrs available, immediately try
// those.
nextAddr, iteratorErr := addrIterator.Next()
if iteratorErr == nil {
towerAddr = nextAddr
continue
}
// Otherwise, if we have exhausted the address list,
// exit.
addrIterator.Reset()
return fmt.Errorf("failed to dial tower(%x) at any "+
"available addresses",
tower.IdentityKey.SerializeCompressed())
}
break
}
defer conn.Close()
// Send Init to tower.
err = c.sendMessage(conn, localInit)
if err != nil {
return err
}
// Receive Init from tower.
remoteMsg, err := c.readMessage(conn)
if err != nil {
return err
}
remoteInit, ok := remoteMsg.(*wtwire.Init)
if !ok {
return fmt.Errorf("watchtower %s responded with %T to Init",
towerAddr, remoteMsg)
}
// Validate Init.
err = localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
if err != nil {
return err
}
// Send DeleteSession to tower.
err = c.sendMessage(conn, &wtwire.DeleteSession{})
if err != nil {
return err
}
// Receive DeleteSessionReply from tower.
remoteMsg, err = c.readMessage(conn)
if err != nil {
return err
}
deleteSessionReply, ok := remoteMsg.(*wtwire.DeleteSessionReply)
if !ok {
return fmt.Errorf("watchtower %s responded with %T to "+
"DeleteSession", towerAddr, remoteMsg)
}
switch deleteSessionReply.Code {
case wtwire.CodeOK, wtwire.DeleteSessionCodeNotFound:
return nil
default:
return fmt.Errorf("received error code %v in "+
"DeleteSessionReply when attempting to delete "+
"session from tower", deleteSessionReply.Code)
}
}
// backupDispatcher processes events coming from the taskPipeline and is
// responsible for detecting when the client needs to renegotiate a session to
// fulfill continuing demand. The event loop exits after all tasks have been
@ -1153,6 +1569,22 @@ func (c *TowerClient) initActiveQueue(s *ClientSession,
return sq
}
// isChannelClosed can be used to check if the channel with the given ID has
// been closed. If it has been, the block height at which its closing
// transaction was mined will also be returned.
func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32,
error) {
chanSum, err := c.cfg.FetchClosedChannel(id)
if errors.Is(err, channeldb.ErrClosedChannelNotFound) {
return false, 0, nil
} else if err != nil {
return false, 0, err
}
return true, chanSum.CloseHeight, nil
}
// AddTower adds a new watchtower reachable at the given address and considers
// it for new sessions. If the watchtower already exists, then any new addresses
// included will be considered when dialing it for session negotiations and
@ -1409,3 +1841,15 @@ func (c *TowerClient) logMessage(
preposition, peer.RemotePub().SerializeCompressed(),
peer.RemoteAddr())
}
func newRandomDelay(max uint32) (uint32, error) {
var maxDelay big.Int
maxDelay.SetUint64(uint64(max))
randDelay, err := rand.Int(rand.Reader, &maxDelay)
if err != nil {
return 0, err
}
return uint32(randDelay.Uint64()), nil
}


@ -1,6 +1,7 @@
package wtclient_test
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@ -15,12 +16,15 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/subscribe"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
@ -393,8 +397,15 @@ type testHarness struct {
server *wtserver.Server
net *mockNet
blockEvents *mockBlockSub
height int32
channelEvents *mockSubscription
sendUpdatesOn bool
mu sync.Mutex
channels map[lnwire.ChannelID]*mockChannel
closedChannels map[lnwire.ChannelID]uint32
quit chan struct{}
}
@ -441,8 +452,47 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
mockNet := newMockNet()
clientDB := wtmock.NewClientDB()
clientCfg := &wtclient.Config{
h := &testHarness{
t: t,
cfg: cfg,
signer: signer,
capacity: cfg.localBalance + cfg.remoteBalance,
clientDB: clientDB,
serverAddr: towerAddr,
serverDB: serverDB,
serverCfg: serverCfg,
net: mockNet,
blockEvents: newMockBlockSub(t),
channelEvents: newMockSubscription(t),
channels: make(map[lnwire.ChannelID]*mockChannel),
closedChannels: make(map[lnwire.ChannelID]uint32),
quit: make(chan struct{}),
}
t.Cleanup(func() {
close(h.quit)
})
fetchChannel := func(id lnwire.ChannelID) (
*channeldb.ChannelCloseSummary, error) {
h.mu.Lock()
defer h.mu.Unlock()
height, ok := h.closedChannels[id]
if !ok {
return nil, channeldb.ErrClosedChannelNotFound
}
return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil
}
h.clientCfg = &wtclient.Config{
Signer: signer,
SubscribeChannelEvents: func() (subscribe.Subscription, error) {
return h.channelEvents, nil
},
FetchClosedChannel: fetchChannel,
ChainNotifier: h.blockEvents,
Dial: mockNet.Dial,
DB: clientDB,
AuthDial: mockNet.AuthDial,
@ -456,26 +506,9 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
MinBackoff: time.Millisecond,
MaxBackoff: time.Second,
ForceQuitDelay: 10 * time.Second,
SessionCloseRange: 1,
}
h := &testHarness{
t: t,
cfg: cfg,
signer: signer,
capacity: cfg.localBalance + cfg.remoteBalance,
clientDB: clientDB,
clientCfg: clientCfg,
serverAddr: towerAddr,
serverDB: serverDB,
serverCfg: serverCfg,
net: mockNet,
channels: make(map[lnwire.ChannelID]*mockChannel),
quit: make(chan struct{}),
}
t.Cleanup(func() {
close(h.quit)
})
if !cfg.noServerStart {
h.startServer()
t.Cleanup(h.stopServer)
@ -492,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
return h
}
// mine mimics the mining of new blocks by sending new block notifications.
func (h *testHarness) mine(numBlocks int) {
h.t.Helper()
for i := 0; i < numBlocks; i++ {
h.height++
h.blockEvents.sendNewBlock(h.height)
}
}
// startServer creates a new server using the harness's current serverCfg and
// starts it after pointing the mockNet's callback to the new server.
func (h *testHarness) startServer() {
@ -576,6 +619,41 @@ func (h *testHarness) channel(id uint64) *mockChannel {
return c
}
// closeChannel marks a channel as closed.
//
// NOTE: The method fails if a channel with the given id does not exist.
func (h *testHarness) closeChannel(id uint64, height uint32) {
h.t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
chanID := chanIDFromInt(id)
_, ok := h.channels[chanID]
require.Truef(h.t, ok, "unable to fetch channel %d", id)
h.closedChannels[chanID] = height
delete(h.channels, chanID)
chanPointHash, err := chainhash.NewHash(chanID[:])
require.NoError(h.t, err)
if !h.sendUpdatesOn {
return
}
h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{
CloseSummary: &channeldb.ChannelCloseSummary{
ChanPoint: wire.OutPoint{
Hash: *chanPointHash,
Index: 0,
},
CloseHeight: height,
},
})
}
// registerChannel registers the channel identified by id with the client.
func (h *testHarness) registerChannel(id uint64) {
h.t.Helper()
@ -624,7 +702,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
err := h.client.BackupState(
&chanID, retribution, channeldb.SingleFunderBit,
)
require.ErrorIs(h.t, expErr, err)
require.ErrorIs(h.t, err, expErr)
}
// sendPayments instructs the channel identified by id to send amt to the remote
@ -770,11 +848,132 @@ func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
require.NoError(h.t, err)
}
// relevantSessions returns a list of session IDs that have acked updates for
// the given channel ID.
func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID {
h.t.Helper()
var (
sessionIDs []wtdb.SessionID
cID = chanIDFromInt(chanID)
)
collectSessions := wtdb.WithPerNumAckedUpdates(
func(session *wtdb.ClientSession, id lnwire.ChannelID,
_ uint16) {
if !bytes.Equal(id[:], cID[:]) {
return
}
sessionIDs = append(sessionIDs, session.ID)
},
)
_, err := h.clientDB.ListClientSessions(nil, nil, collectSessions)
require.NoError(h.t, err)
return sessionIDs
}
// isSessionClosable returns true if the given session has been marked as
// closable in the DB.
func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool {
h.t.Helper()
cs, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
_, ok := cs[id]
return ok
}
// mockSubscription is a mock subscription client that blocks on sends into the
// updates channel.
type mockSubscription struct {
t *testing.T
updates chan interface{}
// Embed the subscription interface in this mock so that we satisfy it.
subscribe.Subscription
}
// newMockSubscription creates a mock subscription.
func newMockSubscription(t *testing.T) *mockSubscription {
t.Helper()
return &mockSubscription{
t: t,
updates: make(chan interface{}),
}
}
// sendUpdate sends an update into our updates channel, mocking the dispatch of
// an update from a subscription server. This call will fail the test if the
// update is not consumed within our timeout.
func (m *mockSubscription) sendUpdate(update interface{}) {
select {
case m.updates <- update:
case <-time.After(waitTime):
m.t.Fatalf("update: %v timeout", update)
}
}
// Updates returns the updates channel for the mock.
func (m *mockSubscription) Updates() <-chan interface{} {
return m.updates
}
// mockBlockSub mocks out the ChainNotifier.
type mockBlockSub struct {
t *testing.T
events chan *chainntnfs.BlockEpoch
chainntnfs.ChainNotifier
}
// newMockBlockSub creates a new mockBlockSub.
func newMockBlockSub(t *testing.T) *mockBlockSub {
t.Helper()
return &mockBlockSub{
t: t,
events: make(chan *chainntnfs.BlockEpoch),
}
}
// RegisterBlockEpochNtfn returns a channel that can be used to listen for new
// blocks.
func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) (
*chainntnfs.BlockEpochEvent, error) {
return &chainntnfs.BlockEpochEvent{
Epochs: m.events,
}, nil
}
// sendNewBlock will send a new block on the notification channel.
func (m *mockBlockSub) sendNewBlock(height int32) {
select {
case m.events <- &chainntnfs.BlockEpoch{Height: height}:
case <-time.After(waitTime):
m.t.Fatalf("timed out sending block: %d", height)
}
}
const (
localBalance = lnwire.MilliSatoshi(100000000)
remoteBalance = lnwire.MilliSatoshi(200000000)
)
var defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
}
type clientTest struct {
name string
cfg harnessCfg
@ -791,10 +990,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 20000,
},
noRegisterChan0: true,
@ -825,10 +1021,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 20000,
},
},
@ -860,10 +1053,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -927,10 +1117,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 20000,
},
},
@ -1006,10 +1193,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1062,10 +1246,7 @@ var clientTests = []clientTest{
localBalance: 100000001, // ensure (% amt != 0)
remoteBalance: 200000001, // ensure (% amt != 0)
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 1000,
},
},
@ -1106,10 +1287,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1156,10 +1334,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
noAckCreateSession: true,
@ -1212,10 +1387,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
noAckCreateSession: true,
@ -1274,10 +1446,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 10,
},
},
@ -1333,10 +1502,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1381,10 +1547,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1489,10 +1652,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1557,10 +1717,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
noServerStart: true,
@ -1654,6 +1811,209 @@ var clientTests = []clientTest{
}, waitTime)
require.NoError(h.t, err)
},
}, {
name: "assert that sessions are correctly marked as closable",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
fn: func(h *testHarness) {
const numUpdates = 5
// In this test we assert that a channel is correctly
// marked as closed and that sessions are also correctly
// marked as closable.
// We start with the sendUpdatesOn parameter set to
// false so that we can test that channels are correctly
// evaluated at startup.
h.sendUpdatesOn = false
// Advance channel 0 to create all states and back them
// all up. This will saturate the session with updates
// for channel 0 which means that the session should be
// considered closable when channel 0 is closed.
hints := h.advanceChannelN(0, numUpdates)
h.backupStates(0, 0, numUpdates, nil)
h.waitServerUpdates(hints, waitTime)
// We expect only 1 session to have updates for this
// channel.
sessionIDs := h.relevantSessions(0)
require.Len(h.t, sessionIDs, 1)
// Since channel 0 is still open, the session should not
// yet be closable.
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
// Close the channel.
h.closeChannel(0, 1)
// Since updates are currently not being sent, we expect
// the session to still not be marked as closable.
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
// Restart the client.
h.client.ForceQuit()
h.startClient()
// The session should now have been marked as closable.
err := wait.Predicate(func() bool {
return h.isSessionClosable(sessionIDs[0])
}, waitTime)
require.NoError(h.t, err)
// Now we set sendUpdatesOn to true and do the same with
// a new channel. A restart should no longer be necessary.
h.sendUpdatesOn = true
h.makeChannel(
1, h.cfg.localBalance, h.cfg.remoteBalance,
)
h.registerChannel(1)
hints = h.advanceChannelN(1, numUpdates)
h.backupStates(1, 0, numUpdates, nil)
h.waitServerUpdates(hints, waitTime)
// Determine the ID of the session of interest.
sessionIDs = h.relevantSessions(1)
// We expect only 1 session to have updates for this
// channel.
require.Len(h.t, sessionIDs, 1)
// Assert that the session is not yet closable since
// the channel is still open.
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
// Now close the channel.
h.closeChannel(1, 1)
// Since the updates have been turned on, the session
// should now show up as closable.
err = wait.Predicate(func() bool {
return h.isSessionClosable(sessionIDs[0])
}, waitTime)
require.NoError(h.t, err)
// Now we test that a session must be exhausted with all
// channels closed before it is seen as closable.
h.makeChannel(
2, h.cfg.localBalance, h.cfg.remoteBalance,
)
h.registerChannel(2)
// Fill up only half of the session updates.
hints = h.advanceChannelN(2, numUpdates)
h.backupStates(2, 0, numUpdates/2, nil)
h.waitServerUpdates(hints[:numUpdates/2], waitTime)
// Determine the ID of the session of interest.
sessionIDs = h.relevantSessions(2)
// We expect only 1 session to have updates for this
// channel.
require.Len(h.t, sessionIDs, 1)
// Now close the channel.
h.closeChannel(2, 1)
// The session should _not_ be closable since it has
// not yet been exhausted.
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
// Create a new channel.
h.makeChannel(
3, h.cfg.localBalance, h.cfg.remoteBalance,
)
h.registerChannel(3)
hints = h.advanceChannelN(3, numUpdates)
h.backupStates(3, 0, numUpdates, nil)
h.waitServerUpdates(hints, waitTime)
// Close it.
h.closeChannel(3, 1)
// Now the session should be closable.
err = wait.Predicate(func() bool {
return h.isSessionClosable(sessionIDs[0])
}, waitTime)
require.NoError(h.t, err)
// Now we will mine a few blocks. This will cause the
// session-close-range to be exceeded, meaning that the
// client should send the DeleteSession message to the
// server. We will assert that both the client and server
// have deleted the appropriate sessions and channel info.
// Before we mine blocks, assert that the client
// currently has 3 closable sessions.
closableSess, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
require.Len(h.t, closableSess, 3)
// Assert that the server is also aware of all of these
// sessions.
for sid := range closableSess {
_, err := h.serverDB.GetSessionInfo(&sid)
require.NoError(h.t, err)
}
// Also make a note of the total number of sessions the
// client has.
sessions, err := h.clientDB.ListClientSessions(nil, nil)
require.NoError(h.t, err)
require.Len(h.t, sessions, 4)
h.mine(3)
// The client should no longer have any closable
// sessions and the total list of client sessions should
// no longer include the three that it previously had
// marked as closable. The server should also no longer
// have these sessions in its DB.
err = wait.Predicate(func() bool {
sess, err := h.clientDB.ListClientSessions(
nil, nil,
)
require.NoError(h.t, err)
cs, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
if len(sess) != 1 || len(cs) != 0 {
return false
}
for sid := range closableSess {
_, ok := sess[sid]
if ok {
return false
}
_, err := h.serverDB.GetSessionInfo(
&sid,
)
if !errors.Is(
err, wtdb.ErrSessionNotFound,
) {
return false
}
}
return true
}, waitTime)
require.NoError(h.t, err)
},
},
}


@ -1,6 +1,8 @@
package wtclient
import "errors"
import (
"errors"
)
var (
// ErrClientExiting signals that the watchtower client is shutting down.
@ -11,6 +13,10 @@ var (
ErrTowerCandidatesExhausted = errors.New("exhausted all tower " +
"candidates")
// ErrTowerNotInIterator is returned when a requested tower was not
// found in the iterator.
ErrTowerNotInIterator = errors.New("tower not in iterator")
// ErrPermanentTowerFailure signals that the tower has reported that it
// has permanently failed or the client believes this has happened based
// on the tower's behavior.


@ -64,6 +64,11 @@ type DB interface {
...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error)
// GetClientSession loads the ClientSession with the given ID from the
// DB.
GetClientSession(wtdb.SessionID,
...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error)
// FetchSessionCommittedUpdates retrieves the current set of un-acked
// updates of the given session.
FetchSessionCommittedUpdates(id *wtdb.SessionID) (
@ -78,9 +83,29 @@ type DB interface {
NumAckedUpdates(id *wtdb.SessionID) (uint64, error)
// FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries.
// their channel summaries. Only the channels that have not yet been
// marked as closed will be loaded.
FetchChanSummaries() (wtdb.ChannelSummaries, error)
// MarkChannelClosed will mark a registered channel as closed by setting
// its closed-height as the given block height. It returns a list of
// session IDs for sessions that are now considered closable due to the
// close of this channel. The details for this channel will be deleted
// from the DB if there are no more sessions in the DB that contain
// updates for this channel.
MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) (
[]wtdb.SessionID, error)
// ListClosableSessions fetches and returns the IDs for all sessions
// marked as closable.
ListClosableSessions() (map[wtdb.SessionID]uint32, error)
// DeleteSession can be called when a session should be deleted from the
// DB. All references to the session will also be deleted from the DB.
// A session will only be deleted if it was previously marked as
// closable.
DeleteSession(id wtdb.SessionID) error
// RegisterChannel registers a channel for use within the client
// database. For now, all that is stored in the channel summary is the
// sweep pkscript that we'd like any tower sweeps to pay into. In the
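
Taken together, these methods compose into the client's cleanup flow. A
minimal sketch (markAndPrune is a hypothetical helper, not lnd's actual
control flow, which additionally delays deletion by a random number of blocks
and informs the tower first):

// markAndPrune marks a channel as closed and deletes any sessions that
// became closable as a result.
func markAndPrune(db DB, chanID lnwire.ChannelID, height uint32) error {
	// MarkChannelClosed reports which sessions became closable.
	closable, err := db.MarkChannelClosed(chanID, height)
	if err != nil {
		return err
	}
	for _, sid := range closable {
		// DeleteSession only succeeds for sessions that were
		// previously marked as closable.
		if err := db.DeleteSession(sid); err != nil {
			return err
		}
	}
	return nil
}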
@ -174,3 +199,30 @@ type ClientSession struct {
// key used to connect to the watchtower.
SessionKeyECDH keychain.SingleKeyECDH
}
// NewClientSessionFromDBSession converts a wtdb.ClientSession to a
// ClientSession.
func NewClientSessionFromDBSession(s *wtdb.ClientSession, tower *Tower,
keyRing ECDHKeyRing) (*ClientSession, error) {
towerKeyDesc, err := keyRing.DeriveKey(
keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex,
},
)
if err != nil {
return nil, err
}
sessionKeyECDH := keychain.NewPubKeyECDH(
towerKeyDesc, keyRing,
)
return &ClientSession{
ID: s.ID,
ClientSessionBody: s.ClientSessionBody,
Tower: tower,
SessionKeyECDH: sessionKeyECDH,
}, nil
}


@ -0,0 +1,95 @@
package wtclient
import (
"sync"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
// sessionCloseMinHeap is a thread-safe min-heap implementation that stores
// sessionCloseItem items and prioritises the item with the lowest block height.
type sessionCloseMinHeap struct {
queue queue.PriorityQueue
mu sync.Mutex
}
// newSessionCloseMinHeap constructs a new sessionCloseMinHeap.
func newSessionCloseMinHeap() *sessionCloseMinHeap {
return &sessionCloseMinHeap{}
}
// Len returns the length of the queue.
func (h *sessionCloseMinHeap) Len() int {
h.mu.Lock()
defer h.mu.Unlock()
return h.queue.Len()
}
// Empty returns true if the queue is empty.
func (h *sessionCloseMinHeap) Empty() bool {
h.mu.Lock()
defer h.mu.Unlock()
return h.queue.Empty()
}
// Push adds an item to the priority queue.
func (h *sessionCloseMinHeap) Push(item *sessionCloseItem) {
h.mu.Lock()
defer h.mu.Unlock()
h.queue.Push(item)
}
// Pop removes the topmost item from the queue.
func (h *sessionCloseMinHeap) Pop() *sessionCloseItem {
h.mu.Lock()
defer h.mu.Unlock()
if h.queue.Empty() {
return nil
}
item := h.queue.Pop()
return item.(*sessionCloseItem) //nolint:forcetypeassert
}
// Top returns the topmost item from the queue without removing it.
func (h *sessionCloseMinHeap) Top() *sessionCloseItem {
h.mu.Lock()
defer h.mu.Unlock()
if h.queue.Empty() {
return nil
}
item := h.queue.Top()
return item.(*sessionCloseItem) //nolint:forcetypeassert
}
// sessionCloseItem represents a session that is ready to be deleted.
type sessionCloseItem struct {
// sessionID is the ID of the session in question.
sessionID wtdb.SessionID
// deleteHeight is the block height after which we can delete the
// session.
deleteHeight uint32
}
// Less returns true if the current item's delete height is less than the
// other sessionCloseItem's delete height. This results in lower block heights
// being popped first from the heap.
//
// NOTE: this is part of the queue.PriorityQueueItem interface.
func (s *sessionCloseItem) Less(other queue.PriorityQueueItem) bool {
o := other.(*sessionCloseItem).deleteHeight //nolint:forcetypeassert
return s.deleteHeight < o
}
var _ queue.PriorityQueueItem = (*sessionCloseItem)(nil)
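
A quick usage sketch (the unit test in the next file exercises the same
behaviour; assumes the fmt import): items pop in ascending delete-height
order.

	heap := newSessionCloseMinHeap()
	heap.Push(&sessionCloseItem{deleteHeight: 30})
	heap.Push(&sessionCloseItem{deleteHeight: 5})
	// The lowest delete height surfaces first.
	fmt.Println(heap.Pop().deleteHeight) // prints 5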


@ -0,0 +1,52 @@
package wtclient
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestSessionCloseMinHeap asserts that the sessionCloseMinHeap behaves as
// expected.
func TestSessionCloseMinHeap(t *testing.T) {
t.Parallel()
heap := newSessionCloseMinHeap()
require.Nil(t, heap.Pop())
require.Nil(t, heap.Top())
require.True(t, heap.Empty())
require.Zero(t, heap.Len())
// Add an item with height 10.
item1 := &sessionCloseItem{
sessionID: [33]byte{1, 2, 3},
deleteHeight: 10,
}
heap.Push(item1)
require.Equal(t, item1, heap.Top())
require.False(t, heap.Empty())
require.EqualValues(t, 1, heap.Len())
// Add a bunch more items with heights 1, 2, 6, 11, 6, 30, 9.
heap.Push(&sessionCloseItem{deleteHeight: 1})
heap.Push(&sessionCloseItem{deleteHeight: 2})
heap.Push(&sessionCloseItem{deleteHeight: 6})
heap.Push(&sessionCloseItem{deleteHeight: 11})
heap.Push(&sessionCloseItem{deleteHeight: 6})
heap.Push(&sessionCloseItem{deleteHeight: 30})
heap.Push(&sessionCloseItem{deleteHeight: 9})
// Now pop from the queue and assert that the items are returned in
// ascending order.
require.EqualValues(t, 1, heap.Pop().deleteHeight)
require.EqualValues(t, 2, heap.Pop().deleteHeight)
require.EqualValues(t, 6, heap.Pop().deleteHeight)
require.EqualValues(t, 6, heap.Pop().deleteHeight)
require.EqualValues(t, 9, heap.Pop().deleteHeight)
require.EqualValues(t, 10, heap.Pop().deleteHeight)
require.EqualValues(t, 11, heap.Pop().deleteHeight)
require.EqualValues(t, 30, heap.Pop().deleteHeight)
require.Nil(t, heap.Pop())
require.Zero(t, heap.Len())
}


@ -23,22 +23,39 @@ var (
// cChanDetailsBkt is a top-level bucket storing:
// channel-id => cChannelSummary -> encoded ClientChanSummary.
// => cChanDBID -> db-assigned-id
// => cChanSessions => db-session-id -> 1
// => cChanClosedHeight -> block-height
cChanDetailsBkt = []byte("client-channel-detail-bucket")
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
// db-session-id -> 1
cChanSessions = []byte("client-channel-sessions")
// cChanDBID is a key used in the cChanDetailsBkt to store the
// db-assigned-id of a channel.
cChanDBID = []byte("client-channel-db-id")
// cChanClosedHeight is a key used in the cChanDetailsBkt to store the
// block height at which the channel's closing transaction was mined. If
// there is no associated value for this key, then the channel has not
// yet been marked as closed.
cChanClosedHeight = []byte("client-channel-closed-height")
// cChannelSummary is a key used in cChanDetailsBkt to store the encoded
// body of ClientChanSummary.
cChannelSummary = []byte("client-channel-summary")
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionDBID -> db-assigned-id
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAckRangeIndex => db-chan-id => start -> end
cSessionBkt = []byte("client-session-bucket")
// cSessionDBID is a key used in the cSessionBkt to store the
// db-assigned-id of a session.
cSessionDBID = []byte("client-session-db-id")
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
// the ClientSession.
cSessionBody = []byte("client-session-body")
@ -55,6 +72,10 @@ var (
// db-assigned-id -> channel-ID
cChanIDIndexBkt = []byte("client-channel-id-index")
// cSessionIDIndexBkt is a top-level bucket storing:
// db-assigned-id -> session-id
cSessionIDIndexBkt = []byte("client-session-id-index")
// cTowerBkt is a top-level bucket storing:
// tower-id -> encoded Tower.
cTowerBkt = []byte("client-tower-bucket")
@ -69,6 +90,10 @@ var (
"client-tower-to-session-index-bucket",
)
// cClosableSessionsBkt is a top-level bucket storing:
// db-session-id -> last-channel-close-height
cClosableSessionsBkt = []byte("client-closable-sessions-bucket")
// ErrTowerNotFound signals that the target tower was not found in the
// database.
ErrTowerNotFound = errors.New("tower not found")
@ -142,6 +167,23 @@ var (
// ErrSessionFailedFilterFn indicates that a particular session did
// not pass the filter func provided by the caller.
ErrSessionFailedFilterFn = errors.New("session failed filter func")
// ErrSessionNotClosable is returned when a session is not found in the
// closable list.
ErrSessionNotClosable = errors.New("session is not closable")
// errSessionHasOpenChannels is an error used to indicate that a
// session has updates for channels that are still open.
errSessionHasOpenChannels = errors.New("session has open channels")
// errSessionHasUnackedUpdates is an error used to indicate that a
// session has un-acked updates.
errSessionHasUnackedUpdates = errors.New("session has un-acked updates")
// errChannelHasMoreSessions is an error used to indicate that a channel
// has updates in other non-closed sessions.
errChannelHasMoreSessions = errors.New("channel has updates in " +
"other sessions")
)
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
@ -241,6 +283,8 @@ func initClientDBBuckets(tx kvdb.RwTx) error {
cTowerIndexBkt,
cTowerToSessionIndexBkt,
cChanIDIndexBkt,
cSessionIDIndexBkt,
cClosableSessionsBkt,
}
for _, bucket := range buckets {
@ -723,24 +767,58 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
}
}
// Add the new entry to the towerID-to-SessionID index.
indexBkt := towerToSessionIndex.NestedReadWriteBucket(
towerID.Bytes(),
)
if indexBkt == nil {
return ErrTowerNotFound
// Get the session-ID index bucket.
dbIDIndex := tx.ReadWriteBucket(cSessionIDIndexBkt)
if dbIDIndex == nil {
return ErrUninitializedDB
}
err = indexBkt.Put(session.ID[:], []byte{1})
// Get a new, unique ID for this session from the session-ID
// index bucket.
nextSeq, err := dbIDIndex.NextSequence()
if err != nil {
return err
}
// Add the new entry to the dbID-to-SessionID index.
newIndex, err := writeBigSize(nextSeq)
if err != nil {
return err
}
err = dbIDIndex.Put(newIndex, session.ID[:])
if err != nil {
return err
}
// Also add the db-assigned-id to the session bucket under the
// cSessionDBID key.
sessionBkt, err := sessions.CreateBucket(session.ID[:])
if err != nil {
return err
}
err = sessionBkt.Put(cSessionDBID, newIndex)
if err != nil {
return err
}
// TODO(elle): migrate the towerID-to-SessionID index to use the
// new db-assigned session IDs instead.
// Add the new entry to the towerID-to-SessionID index.
towerSessions := towerToSessionIndex.NestedReadWriteBucket(
towerID.Bytes(),
)
if towerSessions == nil {
return ErrTowerNotFound
}
err = towerSessions.Put(session.ID[:], []byte{1})
if err != nil {
return err
}
// Finally, write the client session's body in the sessions
// bucket.
return putClientSessionBody(sessionBkt, session)
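// After a successful CreateClientSession, the new session's bucket therefore
// contains the following entries (a sketch based on the bucket layout
// documented at the top of this file):
//
//	session-id => cSessionDBID -> db-assigned-id (BigSize)
//	           => cSessionBody -> encoded ClientSessionBody
//
// while the cSessionIDIndexBkt index gains the reverse entry:
//
//	db-assigned-id (BigSize) -> session-id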
@ -960,6 +1038,37 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
return byteOrder.Uint32(keyIndexBytes), nil
}
// GetClientSession loads the ClientSession with the given ID from the DB.
func (c *ClientDB) GetClientSession(id SessionID,
opts ...ClientSessionListOption) (*ClientSession, error) {
var sess *ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
sessionsBkt := tx.ReadBucket(cSessionBkt)
if sessionsBkt == nil {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
)
if err != nil {
return err
}
sess = session
return nil
}, func() {})
return sess, err
}
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
@ -974,20 +1083,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
var err error
// If no tower ID is specified, then fetch all the sessions
// known to the db.
var err error
if id == nil {
clientSessions, err = c.listClientAllSessions(
sessions, chanIDIndexBkt, filterFn, opts...,
@ -1181,7 +1284,8 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
}
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
// channel summaries. Only the channels that have not yet been marked as closed
// will be loaded.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
var summaries map[lnwire.ChannelID]ClientChanSummary
@ -1197,6 +1301,13 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
return ErrCorruptChanDetails
}
// If this channel has already been marked as closed,
// then its summary does not need to be loaded.
closedHeight := chanDetails.Get(cChanClosedHeight)
if len(closedHeight) > 0 {
return nil
}
var chanID lnwire.ChannelID
copy(chanID[:], k)
@ -1292,6 +1403,420 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
return nil
}
// ListClosableSessions fetches and returns the IDs for all sessions marked as
// closable.
func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) {
sessions := make(map[SessionID]uint32)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
csBkt := tx.ReadBucket(cClosableSessionsBkt)
if csBkt == nil {
return ErrUninitializedDB
}
sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt)
if sessIDIndexBkt == nil {
return ErrUninitializedDB
}
return csBkt.ForEach(func(dbIDBytes, heightBytes []byte) error {
dbID, err := readBigSize(dbIDBytes)
if err != nil {
return err
}
sessID, err := getRealSessionID(sessIDIndexBkt, dbID)
if err != nil {
return err
}
sessions[*sessID] = byteOrder.Uint32(heightBytes)
return nil
})
}, func() {
sessions = make(map[SessionID]uint32)
})
if err != nil {
return nil, err
}
return sessions, nil
}
// DeleteSession can be called when a session should be deleted from the DB.
// All references to the session will also be deleted from the DB. Note that a
// session will only be deleted if it was previously marked as closable.
func (c *ClientDB) DeleteSession(id SessionID) error {
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessionsBkt := tx.ReadWriteBucket(cSessionBkt)
if sessionsBkt == nil {
return ErrUninitializedDB
}
closableBkt := tx.ReadWriteBucket(cClosableSessionsBkt)
if closableBkt == nil {
return ErrUninitializedDB
}
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return ErrUninitializedDB
}
sessIDIndexBkt := tx.ReadWriteBucket(cSessionIDIndexBkt)
if sessIDIndexBkt == nil {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadWriteBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
towerToSessBkt := tx.ReadWriteBucket(cTowerToSessionIndexBkt)
if towerToSessBkt == nil {
return ErrUninitializedDB
}
// Get the sub-bucket for this session ID. If it does not exist
// then the session has already been deleted and so our work is
// done.
sessionBkt := sessionsBkt.NestedReadBucket(id[:])
if sessionBkt == nil {
return nil
}
_, dbIDBytes, err := getDBSessionID(sessionsBkt, id)
if err != nil {
return err
}
// First we check if the session has actually been marked as
// closable.
if closableBkt.Get(dbIDBytes) == nil {
return ErrSessionNotClosable
}
sess, err := getClientSessionBody(sessionsBkt, id[:])
if err != nil {
return err
}
// Delete from the tower-to-sessionID index.
towerIndexBkt := towerToSessBkt.NestedReadWriteBucket(
sess.TowerID.Bytes(),
)
if towerIndexBkt == nil {
return fmt.Errorf("no entry in the tower-to-session "+
"index found for tower ID %v", sess.TowerID)
}
err = towerIndexBkt.Delete(id[:])
if err != nil {
return err
}
// Delete entry from session ID index.
err = sessIDIndexBkt.Delete(dbIDBytes)
if err != nil {
return err
}
// Delete the entry from the closable sessions index.
err = closableBkt.Delete(dbIDBytes)
if err != nil {
return err
}
// Get the acked updates range index for the session. This is
// used to get the list of channels that the session has updates
// for.
ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
if ackRanges == nil {
// A session is only considered closable once it is
// exhausted, meaning it must have at least some
// acked updates by this point.
return fmt.Errorf("cannot delete session %s since it "+
"is not yet exhausted", id)
}
// For each of the channels, delete the session ID entry.
err = ackRanges.ForEach(func(chanDBID, _ []byte) error {
chanDBIDInt, err := readBigSize(chanDBID)
if err != nil {
return err
}
chanID, err := getRealChannelID(
chanIDIndexBkt, chanDBIDInt,
)
if err != nil {
return err
}
chanDetails := chanDetailsBkt.NestedReadWriteBucket(
chanID[:],
)
if chanDetails == nil {
return ErrChannelNotRegistered
}
chanSessions := chanDetails.NestedReadWriteBucket(
cChanSessions,
)
if chanSessions == nil {
return fmt.Errorf("no session list found for "+
"channel %s", chanID)
}
// Check that this session was actually listed in the
// session list for this channel.
if len(chanSessions.Get(dbIDBytes)) == 0 {
return fmt.Errorf("session %s not found in "+
"the session list for channel %s", id,
chanID)
}
// If it was, then delete it.
err = chanSessions.Delete(dbIDBytes)
if err != nil {
return err
}
// If this was the last session for this channel, we can
// now delete the channel details for this channel
// completely.
err = chanSessions.ForEach(func(_, _ []byte) error {
return errChannelHasMoreSessions
})
if errors.Is(err, errChannelHasMoreSessions) {
return nil
} else if err != nil {
return err
}
// Delete the channel's entry from the channel-id-index.
dbID := chanDetails.Get(cChanDBID)
err = chanIDIndexBkt.Delete(dbID)
if err != nil {
return err
}
// Delete the channel details.
return chanDetailsBkt.DeleteNestedBucket(chanID[:])
})
if err != nil {
return err
}
// Delete the actual session.
return sessionsBkt.DeleteNestedBucket(id[:])
}, func() {})
}
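// DeleteSession is the final step of the session cleanup flow, together with
// ListClosableSessions (above) and MarkChannelClosed (below). A minimal
// sketch of how a caller might wire the three up once a channel close
// confirms at closeHeight (not the actual lnd call sites; in practice the
// client defers the delete by the configured session-close-range):
//
//	closable, err := db.MarkChannelClosed(chanID, closeHeight)
//	if err != nil {
//		return err
//	}
//	for _, sessID := range closable {
//		if err := db.DeleteSession(sessID); err != nil {
//			return err
//		}
//	}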
// MarkChannelClosed will mark a registered channel as closed by setting its
// closed-height to the given block height. It returns a list of session IDs for
// sessions that are now considered closable due to the close of this channel.
// The details for this channel will be deleted from the DB if there are no more
// sessions in the DB that contain updates for this channel.
func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
blockHeight uint32) ([]SessionID, error) {
var closableSessions []SessionID
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessionsBkt := tx.ReadBucket(cSessionBkt)
if sessionsBkt == nil {
return ErrUninitializedDB
}
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return ErrUninitializedDB
}
closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt)
if closableSessBkt == nil {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt)
if sessIDIndexBkt == nil {
return ErrUninitializedDB
}
chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:])
if chanDetails == nil {
return ErrChannelNotRegistered
}
// If there are no sessions for this channel, the channel
// details can be deleted.
chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions)
if chanSessIDsBkt == nil {
return chanDetailsBkt.DeleteNestedBucket(chanID[:])
}
// Otherwise, mark the channel as closed.
var height [4]byte
byteOrder.PutUint32(height[:], blockHeight)
err := chanDetails.Put(cChanClosedHeight, height[:])
if err != nil {
return err
}
// Now iterate through all the sessions of the channel to check
// if any of them are closable.
return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error {
sessDBIDInt, err := readBigSize(sessDBID)
if err != nil {
return err
}
// Use the session-ID index to get the real session ID.
sID, err := getRealSessionID(
sessIDIndexBkt, sessDBIDInt,
)
if err != nil {
return err
}
isClosable, err := isSessionClosable(
sessionsBkt, chanDetailsBkt, chanIDIndexBkt,
sID,
)
if err != nil {
return err
}
if !isClosable {
return nil
}
// Add session to "closableSessions" list and add the
// block height that this last channel was closed in.
// This will be used in future to determine when we
// should delete the session.
var height [4]byte
byteOrder.PutUint32(height[:], blockHeight)
err = closableSessBkt.Put(sessDBID, height[:])
if err != nil {
return err
}
closableSessions = append(closableSessions, *sID)
return nil
})
}, func() {
closableSessions = nil
})
if err != nil {
return nil, err
}
return closableSessions, nil
}
// isSessionClosable returns true if a session is considered closable. A session
// is considered closable only if all the following points are true:
// 1) It has no un-acked updates.
// 2) It is exhausted (i.e., it can't accept any more updates).
// 3) All the channels that it has acked updates for are closed.
func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket,
id *SessionID) (bool, error) {
sessBkt := sessionsBkt.NestedReadBucket(id[:])
if sessBkt == nil {
return false, ErrSessionNotFound
}
commitsBkt := sessBkt.NestedReadBucket(cSessionCommits)
if commitsBkt == nil {
// If the session has no cSessionCommits bucket then we can be
// sure that no updates have ever been committed to the session
// and so it is not yet exhausted.
return false, nil
}
// If the session has any un-acked updates, then it is not yet closable.
err := commitsBkt.ForEach(func(_, _ []byte) error {
return errSessionHasUnackedUpdates
})
if errors.Is(err, errSessionHasUnackedUpdates) {
return false, nil
} else if err != nil {
return false, err
}
session, err := getClientSessionBody(sessionsBkt, id[:])
if err != nil {
return false, err
}
// We have already checked that the session has no more committed
// updates. So now we can check if the session is exhausted.
if session.SeqNum < session.Policy.MaxUpdates {
// If the session is not yet exhausted, it is not yet closable.
return false, nil
}
// If the session has no acked-updates, then something is wrong since
// the above check ensures that this session has been exhausted meaning
// that it should have MaxUpdates acked updates.
ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex)
if ackedRangeBkt == nil {
return false, fmt.Errorf("no acked-updates found for "+
"exhausted session %s", id)
}
// Iterate over each of the channels that the session has acked-updates
// for. If any of those channels are not closed, then the session is
// not yet closable.
err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error {
dbChanIDInt, err := readBigSize(dbChanID)
if err != nil {
return err
}
chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt)
if err != nil {
return err
}
// Get the channel details bucket for the channel.
chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:])
if chanDetails == nil {
return fmt.Errorf("no channel details found for "+
"channel %s referenced by session %s", chanID,
id)
}
// If a closed height has been set, then the channel is closed.
closedHeight := chanDetails.Get(cChanClosedHeight)
if len(closedHeight) > 0 {
return nil
}
// Otherwise, the channel is not yet closed meaning that the
// session is not yet closable. We break the ForEach by
// returning an error to indicate this.
return errSessionHasOpenChannels
})
if errors.Is(err, errSessionHasOpenChannels) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (c *ClientDB) CommitUpdate(id *SessionID,
@ -1410,7 +1935,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
return ErrUninitializedDB
}
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return ErrUninitializedDB
}
@ -1494,6 +2019,23 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
return err
}
dbSessionID, _, err := getDBSessionID(sessions, *id)
if err != nil {
return err
}
chanDetails := chanDetailsBkt.NestedReadWriteBucket(
committedUpdate.BackupID.ChanID[:],
)
if chanDetails == nil {
return ErrChannelNotRegistered
}
err = putChannelToSessionMapping(chanDetails, dbSessionID)
if err != nil {
return err
}
// Get the range index for the given session-channel pair.
index, err := c.getRangeIndex(tx, *id, chanID)
if err != nil {
@ -1504,6 +2046,26 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
}, func() {})
}
// putChannelToSessionMapping adds the given session ID to a channel's
// cChanSessions bucket.
func putChannelToSessionMapping(chanDetails kvdb.RwBucket,
dbSessID uint64) error {
chanSessIDsBkt, err := chanDetails.CreateBucketIfNotExists(
cChanSessions,
)
if err != nil {
return err
}
b, err := writeBigSize(dbSessID)
if err != nil {
return err
}
return chanSessIDsBkt.Put(b, []byte{1})
}
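// After this mapping is written, the channel's details bucket contains an
// entry of the form documented at the top of this file:
//
//	channel-id => cChanSessions => db-session-id -> 1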
// getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
@ -1882,6 +2444,68 @@ func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID lnwire.ChannelID) (uint64,
return id, idBytes, nil
}
// getDBSessionID returns the db-assigned session ID for the given real session
// ID. It returns both the uint64 and byte representation.
func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64,
[]byte, error) {
sessionBkt := sessionsBkt.NestedReadBucket(sessionID[:])
if sessionBkt == nil {
return 0, nil, ErrClientSessionNotFound
}
idBytes := sessionBkt.Get(cSessionDBID)
if len(idBytes) == 0 {
return 0, nil, fmt.Errorf("no db-assigned ID found for "+
"session ID %s", sessionID)
}
id, err := readBigSize(idBytes)
if err != nil {
return 0, nil, err
}
return id, idBytes, nil
}
// getRealSessionID uses the session-ID index to map a db-assigned session
// ID back to its real 33-byte session ID.
func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID,
error) {
dbIDBytes, err := writeBigSize(dbID)
if err != nil {
return nil, err
}
sessIDBytes := sessIDIndexBkt.Get(dbIDBytes)
if len(sessIDBytes) != SessionIDSize {
return nil, fmt.Errorf("session ID not found")
}
var sessID SessionID
copy(sessID[:], sessIDBytes)
return &sessID, nil
}
// getRealChannelID uses the channel-ID index to map a db-assigned channel
// ID back to its real 32-byte channel ID.
func getRealChannelID(chanIDIndexBkt kvdb.RBucket,
dbID uint64) (*lnwire.ChannelID, error) {
dbIDBytes, err := writeBigSize(dbID)
if err != nil {
return nil, err
}
chanIDBytes := chanIDIndexBkt.Get(dbIDBytes)
if len(chanIDBytes) != 32 { //nolint:gomnd
return nil, fmt.Errorf("channel ID not found")
}
var chanID lnwire.ChannelID
copy(chanID[:], chanIDBytes)
return &chanID, nil
}
// writeBigSize will encode the given uint64 as a BigSize byte slice.
func writeBigSize(i uint64) ([]byte, error) {
var b bytes.Buffer

View file

@ -3,6 +3,7 @@ package wtdb_test
import (
crand "crypto/rand"
"io"
"math/rand"
"net"
"testing"
@ -17,6 +18,8 @@ import (
"github.com/stretchr/testify/require"
)
const blobType = blob.TypeAltruistCommit
// pseudoAddr is a fake network address to be used for testing purposes.
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
@ -193,6 +196,35 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
require.ErrorIs(h.t, err, expErr)
}
func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID,
blockHeight uint32, expErr error) []wtdb.SessionID {
h.t.Helper()
closableSessions, err := h.db.MarkChannelClosed(id, blockHeight)
require.ErrorIs(h.t, err, expErr)
return closableSessions
}
func (h *clientDBHarness) listClosableSessions(
expErr error) map[wtdb.SessionID]uint32 {
h.t.Helper()
closableSessions, err := h.db.ListClosableSessions()
require.ErrorIs(h.t, err, expErr)
return closableSessions
}
func (h *clientDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
h.t.Helper()
err := h.db.DeleteSession(id)
require.ErrorIs(h.t, err, expErr)
}
// newTower is a helper function that creates a new tower with a randomly
// generated public key and inserts it into the client DB.
func (h *clientDBHarness) newTower() *wtdb.Tower {
@ -605,6 +637,118 @@ func testCommitUpdate(h *clientDBHarness) {
}, nil)
}
// testMarkChannelClosed asserts the behaviour of MarkChannelClosed.
func testMarkChannelClosed(h *clientDBHarness) {
tower := h.newTower()
// Create channel 1.
chanID1 := randChannelID(h.t)
// Since we have not yet registered the channel, we expect an error
// when attempting to mark it as closed.
h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered)
// Now register the channel.
h.registerChan(chanID1, nil, nil)
// Since there are still no sessions that would have updates for the
// channel, marking it as closed now should succeed.
h.markChannelClosed(chanID1, 1, nil)
// Register channel 2.
chanID2 := randChannelID(h.t)
h.registerChan(chanID2, nil, nil)
// Create session1 with MaxUpdates set to 5.
session1 := h.randSession(h.t, tower.ID, 5)
h.insertSession(session1, nil)
// Add an update for channel 2 in session 1 and ack it too.
update := randCommittedUpdateForChannel(h.t, chanID2, 1)
lastApplied := h.commitUpdate(&session1.ID, update, nil)
require.Zero(h.t, lastApplied)
h.ackUpdate(&session1.ID, 1, 1, nil)
// Marking channel 2 as closed now should not result in any closable sessions
// since session 1 is not yet exhausted.
sl := h.markChannelClosed(chanID2, 1, nil)
require.Empty(h.t, sl)
// Create channel 3 and 4.
chanID3 := randChannelID(h.t)
h.registerChan(chanID3, nil, nil)
chanID4 := randChannelID(h.t)
h.registerChan(chanID4, nil, nil)
// Add an update for channel 4 and ack it.
update = randCommittedUpdateForChannel(h.t, chanID4, 2)
lastApplied = h.commitUpdate(&session1.ID, update, nil)
require.EqualValues(h.t, 1, lastApplied)
h.ackUpdate(&session1.ID, 2, 2, nil)
// Add an update for channel 3 in session 1. But don't ack it yet.
update = randCommittedUpdateForChannel(h.t, chanID3, 3)
lastApplied = h.commitUpdate(&session1.ID, update, nil)
require.EqualValues(h.t, 2, lastApplied)
// Mark channel 4 as closed & assert that session 1 is not seen as
// closable since it still has committed updates.
sl = h.markChannelClosed(chanID4, 1, nil)
require.Empty(h.t, sl)
// Now ack the update we added above.
h.ackUpdate(&session1.ID, 3, 3, nil)
// Mark channel 3 as closed & assert that session 1 is still not seen as
// closable since it is not yet exhausted.
sl = h.markChannelClosed(chanID3, 1, nil)
require.Empty(h.t, sl)
// Create channel 5 and 6.
chanID5 := randChannelID(h.t)
h.registerChan(chanID5, nil, nil)
chanID6 := randChannelID(h.t)
h.registerChan(chanID6, nil, nil)
// Add an update for channel 5 and ack it.
update = randCommittedUpdateForChannel(h.t, chanID5, 4)
lastApplied = h.commitUpdate(&session1.ID, update, nil)
require.EqualValues(h.t, 3, lastApplied)
h.ackUpdate(&session1.ID, 4, 4, nil)
// Add an update for channel 6 and ack it.
update = randCommittedUpdateForChannel(h.t, chanID6, 5)
lastApplied = h.commitUpdate(&session1.ID, update, nil)
require.EqualValues(h.t, 4, lastApplied)
h.ackUpdate(&session1.ID, 5, 5, nil)
// The session is now exhausted.
// If we now close channel 5, session 1 should still not be closable
// since it has an update for channel 6 which is still open.
sl = h.markChannelClosed(chanID5, 1, nil)
require.Empty(h.t, sl)
require.Empty(h.t, h.listClosableSessions(nil))
// Also check that attempting to delete the session will fail since it
// is not yet considered closable.
h.deleteSession(session1.ID, wtdb.ErrSessionNotClosable)
// Finally, if we close channel 6, session 1 _should_ be in the closable
// list.
sl = h.markChannelClosed(chanID6, 100, nil)
require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID})
slMap := h.listClosableSessions(nil)
require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{
session1.ID: 100,
}, 0)
// Assert that we now can delete the session.
h.deleteSession(session1.ID, nil)
require.Empty(h.t, h.listClosableSessions(nil))
}
// testAckUpdate asserts the behavior of AckUpdate.
func testAckUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
@ -821,6 +965,10 @@ func TestClientDB(t *testing.T) {
name: "ack update",
run: testAckUpdate,
},
{
name: "mark channel closed",
run: testMarkChannelClosed,
},
}
for _, database := range dbs {
@ -841,12 +989,32 @@ func TestClientDB(t *testing.T) {
// randCommittedUpdate generates a random committed update.
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
t.Helper()
chanID := randChannelID(t)
return randCommittedUpdateForChannel(t, chanID, seqNum)
}
func randChannelID(t *testing.T) lnwire.ChannelID {
t.Helper()
var chanID lnwire.ChannelID
_, err := io.ReadFull(crand.Reader, chanID[:])
require.NoError(t, err)
return chanID
}
// randCommittedUpdateForChannel generates a random committed update for the
// given channel ID.
func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID,
seqNum uint16) *wtdb.CommittedUpdate {
t.Helper()
var hint blob.BreachHint
_, err = io.ReadFull(crand.Reader, hint[:])
_, err := io.ReadFull(crand.Reader, hint[:])
require.NoError(t, err)
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
@ -865,3 +1033,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
},
}
}
func (h *clientDBHarness) randSession(t *testing.T,
towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession {
t.Helper()
var id wtdb.SessionID
rand.Read(id[:])
return &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: towerID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
},
MaxUpdates: maxUpdates,
},
RewardPkScript: []byte{0x01, 0x02, 0x03},
KeyIndex: h.nextKeyIndex(towerID, blobType),
},
ID: id,
}
}

View file

@ -8,6 +8,8 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
)
// log is a logger that is initialized with no output filters. This
@ -36,6 +38,8 @@ func UseLogger(logger btclog.Logger) {
migration3.UseLogger(logger)
migration4.UseLogger(logger)
migration5.UseLogger(logger)
migration6.UseLogger(logger)
migration7.UseLogger(logger)
}
// logClosure is used to provide a closure over expensive logging operations so

View file

@ -0,0 +1,114 @@
package migration6
import (
"bytes"
"encoding/binary"
"errors"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionDBID -> db-assigned-id
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAcks => seqnum -> encoded BackupID
cSessionBkt = []byte("client-session-bucket")
// cSessionDBID is a key used in the cSessionBkt to store the
// db-assigned-id of a session.
cSessionDBID = []byte("client-session-db-id")
// cSessionIDIndexBkt is a top-level bucket storing:
// db-assigned-id -> session-id
cSessionIDIndexBkt = []byte("client-session-id-index")
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
// the ClientSession.
cSessionBody = []byte("client-session-body")
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("db not initialized")
// ErrCorruptClientSession signals that the client session's on-disk
// structure deviates from what is expected.
ErrCorruptClientSession = errors.New("client session corrupted")
byteOrder = binary.BigEndian
)
// MigrateSessionIDIndex adds a new session ID index to the tower client db.
// This index is a mapping from db-assigned ID (a uint64 encoded using BigSize)
// to real session ID (33 bytes). This mapping will allow us to persist session
// pointers with fewer bytes in the future.
func MigrateSessionIDIndex(tx kvdb.RwTx) error {
log.Infof("Migrating the tower client db to add a new session ID " +
"index which stores a mapping from db-assigned ID to real " +
"session ID")
// Create a new top-level bucket for the index.
indexBkt, err := tx.CreateTopLevelBucket(cSessionIDIndexBkt)
if err != nil {
return err
}
// Get the existing top-level sessions bucket.
sessionsBkt := tx.ReadWriteBucket(cSessionBkt)
if sessionsBkt == nil {
return ErrUninitializedDB
}
// Iterate over the sessions bucket where each key is a session-ID.
return sessionsBkt.ForEach(func(sessionID, _ []byte) error {
// Ask the DB for a new, unique id for the index bucket.
nextSeq, err := indexBkt.NextSequence()
if err != nil {
return err
}
newIndex, err := writeBigSize(nextSeq)
if err != nil {
return err
}
// Add the new db-assigned-ID to real-session-ID pair to the
// new index bucket.
err = indexBkt.Put(newIndex, sessionID)
if err != nil {
return err
}
// Get the sub-bucket for this specific session ID.
sessionBkt := sessionsBkt.NestedReadWriteBucket(sessionID)
if sessionBkt == nil {
return ErrCorruptClientSession
}
// Here we ensure that the session bucket includes a session
// body. The only reason we do this is so that we can simulate
// a migration failure in a test and ensure that a failed
// migration leaves the db untouched.
sessionBodyBytes := sessionBkt.Get(cSessionBody)
if sessionBodyBytes == nil {
return ErrCorruptClientSession
}
// Add the db-assigned ID of the session to the session under
// the cSessionDBID key.
return sessionBkt.Put(cSessionDBID, newIndex)
})
}
// writeBigSize will encode the given uint64 as a BigSize byte slice.
func writeBigSize(i uint64) ([]byte, error) {
var b bytes.Buffer
err := tlv.WriteVarInt(&b, i, &[8]byte{})
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
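// readBigSize is the inverse helper, shown here purely for illustration
// since this migration never needs to map an index entry back to a uint64;
// it mirrors the readBigSize helper used by the main wtdb package.
func readBigSize(b []byte) (uint64, error) {
	r := bytes.NewReader(b)
	i, err := tlv.ReadVarInt(r, &[8]byte{})
	if err != nil {
		return 0, err
	}
	return i, nil
}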

View file

@ -0,0 +1,147 @@
package migration6
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/channeldb/migtest"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// pre is the expected data in the sessions bucket before the migration.
pre = map[string]interface{}{
sessionIDToString(100): map[string]interface{}{
string(cSessionBody): string([]byte{1, 2, 3}),
},
sessionIDToString(222): map[string]interface{}{
string(cSessionBody): string([]byte{4, 5, 6}),
},
}
// preFailCorruptDB should fail the migration due to no session body
// being found for a given session ID.
preFailCorruptDB = map[string]interface{}{
sessionIDToString(100): "",
}
// postIndex is the expected session index after the migration.
postIndex = map[string]interface{}{
indexToString(1): sessionIDToString(100),
indexToString(2): sessionIDToString(222),
}
// postSessions is the expected data in the sessions bucket after the
// migration.
postSessions = map[string]interface{}{
sessionIDToString(100): map[string]interface{}{
string(cSessionBody): string([]byte{1, 2, 3}),
string(cSessionDBID): indexToString(1),
},
sessionIDToString(222): map[string]interface{}{
string(cSessionBody): string([]byte{4, 5, 6}),
string(cSessionDBID): indexToString(2),
},
}
)
// TestMigrateSessionIDIndex tests that the MigrateSessionIDIndex function
// correctly adds a new session-id index to the DB and also correctly updates
// the existing session bucket.
func TestMigrateSessionIDIndex(t *testing.T) {
t.Parallel()
tests := []struct {
name string
shouldFail bool
pre map[string]interface{}
postSessions map[string]interface{}
postIndex map[string]interface{}
}{
{
name: "migration ok",
shouldFail: false,
pre: pre,
postSessions: postSessions,
postIndex: postIndex,
},
{
name: "fail due to corrupt db",
shouldFail: true,
pre: preFailCorruptDB,
},
{
name: "no channel details",
shouldFail: false,
pre: nil,
postSessions: nil,
postIndex: nil,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// Before the migration we have a sessions bucket.
before := func(tx kvdb.RwTx) error {
return migtest.RestoreDB(
tx, cSessionBkt, test.pre,
)
}
// After the migration, we should have an updated
// sessions bucket and a new index bucket.
after := func(tx kvdb.RwTx) error {
// If the migration fails, the sessions bucket
// should be untouched.
if test.shouldFail {
if err := migtest.VerifyDB(
tx, cSessionBkt, test.pre,
); err != nil {
return err
}
return nil
}
// Else, we expect an updated sessions bucket
// and a new index bucket.
err := migtest.VerifyDB(
tx, cSessionBkt, test.postSessions,
)
if err != nil {
return err
}
return migtest.VerifyDB(
tx, cSessionIDIndexBkt, test.postIndex,
)
}
migtest.ApplyMigration(
t, before, after, MigrateSessionIDIndex,
test.shouldFail,
)
})
}
}
func indexToString(id uint64) string {
var newIndex bytes.Buffer
err := tlv.WriteVarInt(&newIndex, id, &[8]byte{})
if err != nil {
panic(err)
}
return newIndex.String()
}
func sessionIDToString(id uint64) string {
var sessID SessionID
byteOrder.PutUint64(sessID[:], id)
return sessID.String()
}

View file

@ -0,0 +1,17 @@
package migration6
import (
"encoding/hex"
)
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
const SessionIDSize = 33
// SessionID is created from the remote public key of a client, and serves as a
// unique identifier and authentication for sending state updates.
type SessionID [SessionIDSize]byte
// String returns a hex encoding of the session id.
func (s SessionID) String() string {
return hex.EncodeToString(s[:])
}

View file

@ -0,0 +1,14 @@
package migration6
import (
"github.com/btcsuite/btclog"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}

View file

@ -0,0 +1,202 @@
package migration7
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionDBID -> db-assigned-id
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAckRangeIndex => chan-id => acked-index-range
cSessionBkt = []byte("client-session-bucket")
// cChanDetailsBkt is a top-level bucket storing:
// channel-id => cChannelSummary -> encoded ClientChanSummary.
// => cChanDBID -> db-assigned-id
// => cChanSessions => db-session-id -> 1
cChanDetailsBkt = []byte("client-channel-detail-bucket")
// cChannelSummary is a sub-bucket of cChanDetailsBkt which stores the
// encoded body of ClientChanSummary.
cChannelSummary = []byte("client-channel-summary")
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
// db-session-id -> 1
cChanSessions = []byte("client-channel-sessions")
// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing:
// chan-id => start -> end
cSessionAckRangeIndex = []byte("client-session-ack-range-index")
// cSessionDBID is a key used in the cSessionBkt to store the
// db-assigned-id of a session.
cSessionDBID = []byte("client-session-db-id")
// cChanIDIndexBkt is a top-level bucket storing:
// db-assigned-id -> channel-ID
cChanIDIndexBkt = []byte("client-channel-id-index")
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("db not initialized")
// ErrCorruptClientSession signals that the client session's on-disk
// structure deviates from what is expected.
ErrCorruptClientSession = errors.New("client session corrupted")
// byteOrder is the default endianness used when serializing integers.
byteOrder = binary.BigEndian
)
// MigrateChannelToSessionIndex migrates the tower client DB to add an index
// from channel-to-session. This will make it easier in future to check which
// sessions have updates for which channels.
func MigrateChannelToSessionIndex(tx kvdb.RwTx) error {
log.Infof("Migrating the tower client DB to build a new " +
"channel-to-session index")
sessionsBkt := tx.ReadBucket(cSessionBkt)
if sessionsBkt == nil {
return ErrUninitializedDB
}
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return ErrUninitializedDB
}
chanIDsBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDsBkt == nil {
return ErrUninitializedDB
}
// First gather all the new channel-to-session pairs that we want to
// add.
index, err := collectIndex(sessionsBkt)
if err != nil {
return err
}
// Then persist those pairs to the db.
return persistIndex(chanDetailsBkt, chanIDsBkt, index)
}
// collectIndex iterates through all the sessions and uses the keys in the
// cSessionAckRangeIndex bucket to collect all the channels that the session
// has updates for. The function returns a map from channel ID to the set
// of session IDs that have updates for that channel (using the db-assigned
// IDs for both).
func collectIndex(sessionsBkt kvdb.RBucket) (map[uint64]map[uint64]bool,
error) {
index := make(map[uint64]map[uint64]bool)
err := sessionsBkt.ForEach(func(sessID, _ []byte) error {
sessionBkt := sessionsBkt.NestedReadBucket(sessID)
if sessionBkt == nil {
return ErrCorruptClientSession
}
ackedRanges := sessionBkt.NestedReadBucket(
cSessionAckRangeIndex,
)
if ackedRanges == nil {
return nil
}
sessDBIDBytes := sessionBkt.Get(cSessionDBID)
if sessDBIDBytes == nil {
return ErrCorruptClientSession
}
sessDBID, err := readUint64(sessDBIDBytes)
if err != nil {
return err
}
return ackedRanges.ForEach(func(dbChanIDBytes, _ []byte) error {
dbChanID, err := readUint64(dbChanIDBytes)
if err != nil {
return err
}
if _, ok := index[dbChanID]; !ok {
index[dbChanID] = make(map[uint64]bool)
}
index[dbChanID][sessDBID] = true
return nil
})
})
if err != nil {
return nil, err
}
return index, nil
}
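// To make the shape of the index concrete: with the data used in the
// migration test below (db-session-IDs 66 and 77, where both sessions have
// acked ranges for db-channel-ID 10 but only session 66 has them for 20),
// collectIndex would return:
//
//	map[uint64]map[uint64]bool{
//		10: {66: true, 77: true},
//		20: {66: true},
//	}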
// persistIndex adds the channel-to-session mapping in each channel's details
// bucket.
func persistIndex(chanDetailsBkt kvdb.RwBucket, chanIDsBkt kvdb.RBucket,
index map[uint64]map[uint64]bool) error {
for dbChanID, sessIDs := range index {
dbChanIDBytes, err := writeUint64(dbChanID)
if err != nil {
return err
}
realChanID := chanIDsBkt.Get(dbChanIDBytes)
chanBkt := chanDetailsBkt.NestedReadWriteBucket(realChanID)
if chanBkt == nil {
return fmt.Errorf("channel not found")
}
sessIDsBkt, err := chanBkt.CreateBucket(cChanSessions)
if err != nil {
return err
}
for id := range sessIDs {
sessID, err := writeUint64(id)
if err != nil {
return err
}
err = sessIDsBkt.Put(sessID, []byte{1})
if err != nil {
return err
}
}
}
return nil
}
func writeUint64(i uint64) ([]byte, error) {
var b bytes.Buffer
err := tlv.WriteVarInt(&b, i, &[8]byte{})
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
func readUint64(b []byte) (uint64, error) {
r := bytes.NewReader(b)
i, err := tlv.ReadVarInt(r, &[8]byte{})
if err != nil {
return 0, err
}
return i, nil
}

View file

@ -0,0 +1,191 @@
package migration7
import (
"testing"
"github.com/lightningnetwork/lnd/channeldb/migtest"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// preDetails is the expected data of the channel details bucket before
// the migration.
preDetails = map[string]interface{}{
channelIDString(100): map[string]interface{}{
string(cChannelSummary): string([]byte{1, 2, 3}),
},
channelIDString(222): map[string]interface{}{
string(cChannelSummary): string([]byte{4, 5, 6}),
},
}
// preFailCorruptDB should fail the migration due to no channel summary
// being found for a given channel ID.
preFailCorruptDB = map[string]interface{}{
channelIDString(30): map[string]interface{}{},
}
// channelIDIndex is the data in the channelID index that is used to
// find the mapping between the db-assigned channel ID and the real
// channel ID.
channelIDIndex = map[string]interface{}{
uint64ToStr(10): channelIDString(100),
uint64ToStr(20): channelIDString(222),
}
// sessions is the expected data in the sessions bucket before and
// after the migration.
sessions = map[string]interface{}{
sessionIDString("1"): map[string]interface{}{
string(cSessionAckRangeIndex): map[string]interface{}{
uint64ToStr(10): map[string]interface{}{
uint64ToStr(30): uint64ToStr(32),
uint64ToStr(34): uint64ToStr(34),
},
uint64ToStr(20): map[string]interface{}{
uint64ToStr(30): uint64ToStr(30),
},
},
string(cSessionDBID): uint64ToStr(66),
},
sessionIDString("2"): map[string]interface{}{
string(cSessionAckRangeIndex): map[string]interface{}{
uint64ToStr(10): map[string]interface{}{
uint64ToStr(33): uint64ToStr(33),
},
},
string(cSessionDBID): uint64ToStr(77),
},
}
// postDetails is the expected data in the channel details bucket after
// the migration.
postDetails = map[string]interface{}{
channelIDString(100): map[string]interface{}{
string(cChannelSummary): string([]byte{1, 2, 3}),
string(cChanSessions): map[string]interface{}{
uint64ToStr(66): string([]byte{1}),
uint64ToStr(77): string([]byte{1}),
},
},
channelIDString(222): map[string]interface{}{
string(cChannelSummary): string([]byte{4, 5, 6}),
string(cChanSessions): map[string]interface{}{
uint64ToStr(66): string([]byte{1}),
},
},
}
)
// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex
// function correctly builds the new channel-to-sessionID index to the tower
// client DB.
func TestMigrateChannelToSessionIndex(t *testing.T) {
t.Parallel()
tests := []struct {
name string
shouldFail bool
preDetails map[string]interface{}
preSessions map[string]interface{}
preChanIndex map[string]interface{}
postDetails map[string]interface{}
}{
{
name: "migration ok",
shouldFail: false,
preDetails: preDetails,
preSessions: sessions,
preChanIndex: channelIDIndex,
postDetails: postDetails,
},
{
name: "fail due to corrupt db",
shouldFail: true,
preDetails: preFailCorruptDB,
preSessions: sessions,
},
{
name: "no sessions",
shouldFail: false,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// Before the migration we have a channel details
// bucket, a sessions bucket and a channel ID index
// bucket.
before := func(tx kvdb.RwTx) error {
err := migtest.RestoreDB(
tx, cChanDetailsBkt, test.preDetails,
)
if err != nil {
return err
}
err = migtest.RestoreDB(
tx, cSessionBkt, test.preSessions,
)
if err != nil {
return err
}
return migtest.RestoreDB(
tx, cChanIDIndexBkt, test.preChanIndex,
)
}
after := func(tx kvdb.RwTx) error {
// If the migration fails, the details bucket
// should be untouched.
if test.shouldFail {
if err := migtest.VerifyDB(
tx, cChanDetailsBkt,
test.preDetails,
); err != nil {
return err
}
return nil
}
// Else, we expect an updated channel details
// bucket.
return migtest.VerifyDB(
tx, cChanDetailsBkt, test.postDetails,
)
}
migtest.ApplyMigration(
t, before, after, MigrateChannelToSessionIndex,
test.shouldFail,
)
})
}
}
func sessionIDString(id string) string {
var sessID SessionID
copy(sessID[:], id)
return sessID.String()
}
func channelIDString(id uint64) string {
var chanID ChannelID
byteOrder.PutUint64(chanID[:], id)
return string(chanID[:])
}
func uint64ToStr(id uint64) string {
b, err := writeUint64(id)
if err != nil {
panic(err)
}
return string(b)
}

View file

@ -0,0 +1,29 @@
package migration7
import "encoding/hex"
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
const SessionIDSize = 33
// SessionID is created from the remote public key of a client, and serves as a
// unique identifier and authentication for sending state updates.
type SessionID [SessionIDSize]byte
// String returns a hex encoding of the session id.
func (s SessionID) String() string {
return hex.EncodeToString(s[:])
}
// ChannelID is a series of 32-bytes that uniquely identifies all channels
// within the network. The ChannelID is computed using the outpoint of the
// funding transaction (the txid, and output index). Given a funding output the
// ChannelID can be calculated by XOR'ing the big-endian serialization of the
// txid and the big-endian serialization of the output index, truncated to
// 2 bytes.
type ChannelID [32]byte
// String returns the string representation of the ChannelID. This is just the
// hex string encoding of the ChannelID itself.
func (c ChannelID) String() string {
return hex.EncodeToString(c[:])
}
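// A minimal sketch of the derivation described above (not the actual lnwire
// implementation): the big-endian output index, truncated to 2 bytes, is
// XOR'd into the final two bytes of the big-endian txid.
//
//	func chanIDFromOutpoint(txid [32]byte, outputIndex uint16) ChannelID {
//		var cid ChannelID
//		copy(cid[:], txid[:])
//
//		cid[30] ^= byte(outputIndex >> 8)
//		cid[31] ^= byte(outputIndex)
//
//		return cid
//	}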

View file

@ -0,0 +1,14 @@
package migration7
import (
"github.com/btcsuite/btclog"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}

View file

@ -10,6 +10,8 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
)
// txMigration is a function which takes a prior outdated version of the
@ -59,6 +61,12 @@ var clientDBVersions = []version{
{
txMigration: migration5.MigrateCompleteTowerToSessionIndex,
},
{
txMigration: migration6.MigrateSessionIDIndex,
},
{
txMigration: migration7.MigrateChannelToSessionIndex,
},
}
// getLatestDBVersion returns the last known database version.

View file

@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
type channel struct {
summary *wtdb.ClientChanSummary
closedHeight uint32
sessions map[wtdb.SessionID]bool
}
// ClientDB is a mock, in-memory database for testing the watchtower client
// behavior.
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
channels map[lnwire.ChannelID]*channel
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates rangeIndexArrayMap
persistedAckedUpdates rangeIndexKVStore
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
closableSessions map[wtdb.SessionID]uint32
nextIndex uint32
indexes map[keyIndexKey]uint32
@ -47,9 +54,7 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
summaries: make(
map[lnwire.ChannelID]wtdb.ClientChanSummary,
),
channels: make(map[lnwire.ChannelID]*channel),
activeSessions: make(
map[wtdb.SessionID]wtdb.ClientSession,
),
@ -62,6 +67,7 @@ func NewClientDB() *ClientDB {
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
closableSessions: make(map[wtdb.SessionID]uint32),
}
}
@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
continue
}
// Add sessionID to channel.
channel, ok := m.channels[update.BackupID.ChanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}
channel.sessions[*id] = true
// Remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send
// along with the next update.
@ -538,22 +551,192 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
return wtdb.ErrCommittedUpdateNotFound
}
// ListClosableSessions fetches and returns the IDs for all sessions marked as
// closable.
func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions))
for id, height := range m.closableSessions {
cs[id] = height
}
return cs, nil
}
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
// channel summaries. Only the channels that have not yet been marked as closed
// will be loaded.
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
m.mu.Lock()
defer m.mu.Unlock()
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
for chanID, summary := range m.summaries {
for chanID, channel := range m.channels {
// Don't load the channel if it has been marked as closed.
if channel.closedHeight > 0 {
continue
}
summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(summary.SweepPkScript),
SweepPkScript: cloneBytes(
channel.summary.SweepPkScript,
),
}
}
return summaries, nil
}
// MarkChannelClosed will mark a registered channel as closed by setting
// its closed-height to the given block height. It returns a list of
// session IDs for sessions that are now considered closable due to the
// close of this channel.
func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
blockHeight uint32) ([]wtdb.SessionID, error) {
m.mu.Lock()
defer m.mu.Unlock()
channel, ok := m.channels[chanID]
if !ok {
return nil, wtdb.ErrChannelNotRegistered
}
// If there are no sessions for this channel, the channel details can be
// deleted.
if len(channel.sessions) == 0 {
delete(m.channels, chanID)
return nil, nil
}
// Mark the channel as closed.
channel.closedHeight = blockHeight
// Now iterate through all the sessions of the channel to check if any
// of them are closable.
var closableSessions []wtdb.SessionID
for sessID := range channel.sessions {
isClosable, err := m.isSessionClosable(sessID)
if err != nil {
return nil, err
}
if !isClosable {
continue
}
closableSessions = append(closableSessions, sessID)
// Add session to "closableSessions" list and add the block
// height that this last channel was closed in. This will be
// used in future to determine when we should delete the
// session.
m.closableSessions[sessID] = blockHeight
}
return closableSessions, nil
}
// isSessionClosable returns true if a session is considered closable. A session
// is considered closable only if:
// 1) It has no un-acked updates.
// 2) It is exhausted (i.e., it can't accept any more updates).
// 3) All the channels that it has acked-updates for are closed.
func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) {
// The session is not closable if it has un-acked updates.
if len(m.committedUpdates[id]) > 0 {
return false, nil
}
sess, ok := m.activeSessions[id]
if !ok {
return false, wtdb.ErrClientSessionNotFound
}
// The session is not closable if it is not yet exhausted.
if sess.SeqNum != sess.Policy.MaxUpdates {
return false, nil
}
// Iterate over each of the channels that the session has acked-updates
// for. If any of those channels are not closed, then the session is
// not yet closable.
for chanID := range m.ackedUpdates[id] {
channel, ok := m.channels[chanID]
if !ok {
continue
}
// Channel is not yet closed, and so we cannot yet delete the
// session.
if channel.closedHeight == 0 {
return false, nil
}
}
return true, nil
}
// GetClientSession loads the ClientSession with the given ID from the DB.
func (m *ClientDB) GetClientSession(id wtdb.SessionID,
opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) {
cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, ok := m.activeSessions[id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(&session, chanID, index.MaxHeight())
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
return &session, nil
}
// DeleteSession can be called when a session should be deleted from the DB.
// All references to the session will also be deleted from the DB. Note that a
// session will only be deleted if it is considered closable.
func (m *ClientDB) DeleteSession(id wtdb.SessionID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.closableSessions[id]
if !ok {
return wtdb.ErrSessionNotClosable
}
// For each of the channels, delete the session ID entry.
for chanID := range m.ackedUpdates[id] {
c, ok := m.channels[chanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}
delete(c.sessions, id)
}
delete(m.closableSessions, id)
delete(m.activeSessions, id)
return nil
}
// RegisterChannel registers a channel for use within the client database. For
// now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
@ -565,12 +748,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.summaries[chanID]; ok {
if _, ok := m.channels[chanID]; ok {
return wtdb.ErrChannelAlreadyRegistered
}
m.summaries[chanID] = wtdb.ClientChanSummary{
m.channels[chanID] = &channel{
summary: &wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(sweepPkScript),
},
sessions: make(map[wtdb.SessionID]bool),
}
return nil