mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-24 06:47:44 +01:00
Merge pull request #7069 from ellemouton/deleteSessions
watchtower: start using the DeleteSession message
This commit is contained in:
commit
c4c1f1ac92
31 changed files with 3222 additions and 157 deletions
|
@ -9,6 +9,9 @@
|
|||
|
||||
* [Allow caller to filter sessions at the time of reading them from
|
||||
disk](https://github.com/lightningnetwork/lnd/pull/7059)
|
||||
* [Clean up sessions once all channels for which they have updates for are
|
||||
closed. Also start sending the `DeleteSession` message to the
|
||||
tower.](https://github.com/lightningnetwork/lnd/pull/7069)
|
||||
|
||||
## Misc
|
||||
|
||||
|
|
|
@ -515,4 +515,8 @@ var allTestCases = []*lntest.TestCase{
|
|||
Name: "lookup htlc resolution",
|
||||
TestFunc: testLookupHtlcResolution,
|
||||
},
|
||||
{
|
||||
Name: "watchtower session management",
|
||||
TestFunc: testWatchtowerSessionManagement,
|
||||
},
|
||||
}
|
||||
|
|
172
itest/lnd_watchtower_test.go
Normal file
172
itest/lnd_watchtower_test.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package itest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/lightningnetwork/lnd/funding"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/wtclientrpc"
|
||||
"github.com/lightningnetwork/lnd/lntest"
|
||||
"github.com/lightningnetwork/lnd/lntest/node"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testWatchtowerSessionManagement tests that session deletion is done
|
||||
// correctly.
|
||||
func testWatchtowerSessionManagement(ht *lntest.HarnessTest) {
|
||||
const (
|
||||
chanAmt = funding.MaxBtcFundingAmount
|
||||
paymentAmt = 10_000
|
||||
numInvoices = 5
|
||||
maxUpdates = numInvoices * 2
|
||||
externalIP = "1.2.3.4"
|
||||
sessionCloseRange = 1
|
||||
)
|
||||
|
||||
// Set up Wallis the watchtower who will be used by Dave to watch over
|
||||
// his channel commitment transactions.
|
||||
wallis := ht.NewNode("Wallis", []string{
|
||||
"--watchtower.active",
|
||||
"--watchtower.externalip=" + externalIP,
|
||||
})
|
||||
|
||||
wallisInfo := wallis.RPC.GetInfoWatchtower()
|
||||
|
||||
// Assert that Wallis has one listener and it is 0.0.0.0:9911 or
|
||||
// [::]:9911. Since no listener is explicitly specified, one of these
|
||||
// should be the default depending on whether the host supports IPv6 or
|
||||
// not.
|
||||
require.Len(ht, wallisInfo.Listeners, 1)
|
||||
listener := wallisInfo.Listeners[0]
|
||||
require.True(ht, listener == "0.0.0.0:9911" || listener == "[::]:9911")
|
||||
|
||||
// Assert the Wallis's URIs properly display the chosen external IP.
|
||||
require.Len(ht, wallisInfo.Uris, 1)
|
||||
require.Contains(ht, wallisInfo.Uris[0], externalIP)
|
||||
|
||||
// Dave will be the tower client.
|
||||
daveArgs := []string{
|
||||
"--wtclient.active",
|
||||
fmt.Sprintf("--wtclient.max-updates=%d", maxUpdates),
|
||||
fmt.Sprintf(
|
||||
"--wtclient.session-close-range=%d", sessionCloseRange,
|
||||
),
|
||||
}
|
||||
dave := ht.NewNode("Dave", daveArgs)
|
||||
|
||||
addTowerReq := &wtclientrpc.AddTowerRequest{
|
||||
Pubkey: wallisInfo.Pubkey,
|
||||
Address: listener,
|
||||
}
|
||||
dave.RPC.AddTower(addTowerReq)
|
||||
|
||||
// Assert that there exists a session between Dave and Wallis.
|
||||
err := wait.NoError(func() error {
|
||||
info := dave.RPC.GetTowerInfo(&wtclientrpc.GetTowerInfoRequest{
|
||||
Pubkey: wallisInfo.Pubkey,
|
||||
IncludeSessions: true,
|
||||
})
|
||||
|
||||
var numSessions uint32
|
||||
for _, sessionType := range info.SessionInfo {
|
||||
numSessions += sessionType.NumSessions
|
||||
}
|
||||
if numSessions > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("expected a non-zero number of sessions")
|
||||
}, defaultTimeout)
|
||||
require.NoError(ht, err)
|
||||
|
||||
// Before we make a channel, we'll load up Dave with some coins sent
|
||||
// directly from the miner.
|
||||
ht.FundCoins(btcutil.SatoshiPerBitcoin, dave)
|
||||
|
||||
// Connect Dave and Alice.
|
||||
ht.ConnectNodes(dave, ht.Alice)
|
||||
|
||||
// Open a channel between Dave and Alice.
|
||||
params := lntest.OpenChannelParams{
|
||||
Amt: chanAmt,
|
||||
}
|
||||
chanPoint := ht.OpenChannel(dave, ht.Alice, params)
|
||||
|
||||
// Since there are 2 updates made for every payment and the maximum
|
||||
// number of updates per session has been set to 10, make 5 payments
|
||||
// between the pair so that the session is exhausted.
|
||||
alicePayReqs, _, _ := ht.CreatePayReqs(
|
||||
ht.Alice, paymentAmt, numInvoices,
|
||||
)
|
||||
|
||||
send := func(node *node.HarnessNode, payReq string) {
|
||||
stream := node.RPC.SendPayment(&routerrpc.SendPaymentRequest{
|
||||
PaymentRequest: payReq,
|
||||
TimeoutSeconds: 60,
|
||||
FeeLimitMsat: noFeeLimitMsat,
|
||||
})
|
||||
|
||||
ht.AssertPaymentStatusFromStream(
|
||||
stream, lnrpc.Payment_SUCCEEDED,
|
||||
)
|
||||
}
|
||||
|
||||
for i := 0; i < numInvoices; i++ {
|
||||
send(dave, alicePayReqs[i])
|
||||
}
|
||||
|
||||
// assertNumBackups is a closure that asserts that Dave has a certain
|
||||
// number of backups backed up to the tower. If mineOnFail is true,
|
||||
// then a block will be mined each time the assertion fails.
|
||||
assertNumBackups := func(expected int, mineOnFail bool) {
|
||||
err = wait.NoError(func() error {
|
||||
info := dave.RPC.GetTowerInfo(
|
||||
&wtclientrpc.GetTowerInfoRequest{
|
||||
Pubkey: wallisInfo.Pubkey,
|
||||
IncludeSessions: true,
|
||||
},
|
||||
)
|
||||
|
||||
var numBackups uint32
|
||||
for _, sessionType := range info.SessionInfo {
|
||||
for _, session := range sessionType.Sessions {
|
||||
numBackups += session.NumBackups
|
||||
}
|
||||
}
|
||||
|
||||
if numBackups == uint32(expected) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mineOnFail {
|
||||
ht.Miner.MineBlocksSlow(1)
|
||||
}
|
||||
|
||||
return fmt.Errorf("expected %d backups, got %d",
|
||||
expected, numBackups)
|
||||
}, defaultTimeout)
|
||||
require.NoError(ht, err)
|
||||
}
|
||||
|
||||
// Assert that one of the sessions now has 10 backups.
|
||||
assertNumBackups(10, false)
|
||||
|
||||
// Now close the channel and wait for the close transaction to appear
|
||||
// in the mempool so that it is included in a block when we mine.
|
||||
ht.CloseChannelAssertPending(dave, chanPoint, false)
|
||||
|
||||
// Mine enough blocks to surpass the session-close-range. This should
|
||||
// trigger the session to be deleted.
|
||||
ht.MineBlocksAndAssertNumTxes(sessionCloseRange+6, 1)
|
||||
|
||||
// Wait for the session to be deleted. We know it has been deleted once
|
||||
// the number of backups is back to zero. We check for number of backups
|
||||
// instead of number of sessions because it is expected that the client
|
||||
// would immediately negotiate another session after deleting the
|
||||
// exhausted one. This time we set the "mineOnFail" parameter to true to
|
||||
// ensure that the session deleting logic is run.
|
||||
assertNumBackups(0, true)
|
||||
}
|
|
@ -17,6 +17,15 @@ type WtClient struct {
|
|||
// SweepFeeRate specifies the fee rate in sat/byte to be used when
|
||||
// constructing justice transactions sent to the tower.
|
||||
SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."`
|
||||
|
||||
// SessionCloseRange is the range over which to choose a random number
|
||||
// of blocks to wait after the last channel of a session is closed
|
||||
// before sending the DeleteSession message to the tower server.
|
||||
SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."`
|
||||
|
||||
// MaxUpdates is the maximum number of updates to be backed up in a
|
||||
// single tower sessions.
|
||||
MaxUpdates uint16 `long:"max-updates" description:"The maximum number of updates to be backed up in a single session."`
|
||||
}
|
||||
|
||||
// Validate ensures the user has provided a valid configuration.
|
||||
|
|
|
@ -309,6 +309,7 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
|
|||
}
|
||||
|
||||
t.SessionInfo = append(t.SessionInfo, rpcTower.SessionInfo...)
|
||||
t.Sessions = append(t.Sessions, rpcTower.Sessions...)
|
||||
}
|
||||
|
||||
towers := make([]*Tower, 0, len(rpcTowers))
|
||||
|
@ -365,6 +366,9 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context,
|
|||
rpcTower.SessionInfo = append(
|
||||
rpcTower.SessionInfo, rpcLegacyTower.SessionInfo...,
|
||||
)
|
||||
rpcTower.Sessions = append(
|
||||
rpcTower.Sessions, rpcLegacyTower.Sessions...,
|
||||
)
|
||||
|
||||
return rpcTower, nil
|
||||
}
|
||||
|
|
|
@ -24,6 +24,20 @@ func (h *HarnessRPC) GetInfoWatchtower() *watchtowerrpc.GetInfoResponse {
|
|||
return info
|
||||
}
|
||||
|
||||
// GetTowerInfo makes an RPC call to the watchtower client of the given node and
|
||||
// asserts.
|
||||
func (h *HarnessRPC) GetTowerInfo(
|
||||
req *wtclientrpc.GetTowerInfoRequest) *wtclientrpc.Tower {
|
||||
|
||||
ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
info, err := h.WatchtowerClient.GetTowerInfo(ctxt, req)
|
||||
h.NoError(err, "GetTowerInfo from WatchtowerClient")
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// AddTower makes a RPC call to the WatchtowerClient of the given node and
|
||||
// asserts.
|
||||
func (h *HarnessRPC) AddTower(
|
||||
|
|
|
@ -997,6 +997,15 @@ litecoin.node=ltcd
|
|||
; supported at this time, if none are provided the tower will not be enabled.
|
||||
; wtclient.private-tower-uris=
|
||||
|
||||
; The range over which to choose a random number of blocks to wait after the
|
||||
; last channel of a session is closed before sending the DeleteSession message
|
||||
; to the tower server. The default is currently 288. Note that setting this to
|
||||
; a lower value will result in faster session cleanup _but_ that this comes
|
||||
; along with reduced privacy from the tower server.
|
||||
; wtclient.session-close-range=10
|
||||
|
||||
; The maximum number of updates to include in a tower session.
|
||||
; wtclient.max-updates=1024
|
||||
|
||||
[healthcheck]
|
||||
|
||||
|
|
29
server.go
29
server.go
|
@ -1497,6 +1497,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight()
|
||||
}
|
||||
|
||||
if cfg.WtClient.MaxUpdates != 0 {
|
||||
policy.MaxUpdates = cfg.WtClient.MaxUpdates
|
||||
}
|
||||
|
||||
sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange)
|
||||
if cfg.WtClient.SessionCloseRange != 0 {
|
||||
sessionCloseRange = cfg.WtClient.SessionCloseRange
|
||||
}
|
||||
|
||||
if err := policy.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1512,7 +1521,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
)
|
||||
}
|
||||
|
||||
fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID
|
||||
|
||||
s.towerClient, err = wtclient.New(&wtclient.Config{
|
||||
FetchClosedChannel: fetchClosedChannel,
|
||||
SessionCloseRange: sessionCloseRange,
|
||||
ChainNotifier: s.cc.ChainNotifier,
|
||||
SubscribeChannelEvents: func() (subscribe.Subscription,
|
||||
error) {
|
||||
|
||||
return s.channelNotifier.
|
||||
SubscribeChannelEvents()
|
||||
},
|
||||
Signer: cc.Wallet.Cfg.Signer,
|
||||
NewAddress: newSweepPkScriptGen(cc.Wallet),
|
||||
SecretKeyRing: s.cc.KeyRing,
|
||||
|
@ -1536,6 +1556,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
|||
blob.Type(blob.FlagAnchorChannel)
|
||||
|
||||
s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
|
||||
FetchClosedChannel: fetchClosedChannel,
|
||||
SessionCloseRange: sessionCloseRange,
|
||||
ChainNotifier: s.cc.ChainNotifier,
|
||||
SubscribeChannelEvents: func() (subscribe.Subscription,
|
||||
error) {
|
||||
|
||||
return s.channelNotifier.
|
||||
SubscribeChannelEvents()
|
||||
},
|
||||
Signer: cc.Wallet.Cfg.Signer,
|
||||
NewAddress: newSweepPkScriptGen(cc.Wallet),
|
||||
SecretKeyRing: s.cc.KeyRing,
|
||||
|
|
|
@ -69,6 +69,12 @@ type AddressIterator interface {
|
|||
// Reset clears the iterators state, and makes the address at the front
|
||||
// of the list the next item to be returned.
|
||||
Reset()
|
||||
|
||||
// Copy constructs a new AddressIterator that has the same addresses
|
||||
// as this iterator.
|
||||
//
|
||||
// NOTE that the address locks are not expected to be copied.
|
||||
Copy() AddressIterator
|
||||
}
|
||||
|
||||
// A compile-time check to ensure that addressIterator implements the
|
||||
|
@ -324,6 +330,33 @@ func (a *addressIterator) GetAll() []net.Addr {
|
|||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
return a.getAllUnsafe()
|
||||
}
|
||||
|
||||
// Copy constructs a new AddressIterator that has the same addresses
|
||||
// as this iterator.
|
||||
//
|
||||
// NOTE that the address locks will not be copied.
|
||||
func (a *addressIterator) Copy() AddressIterator {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
addrs := a.getAllUnsafe()
|
||||
|
||||
// Since newAddressIterator will only ever return an error if it is
|
||||
// initialised with zero addresses, we can ignore the error here since
|
||||
// we are initialising it with the set of addresses of this
|
||||
// addressIterator which is by definition a non-empty list.
|
||||
iter, _ := newAddressIterator(addrs...)
|
||||
|
||||
return iter
|
||||
}
|
||||
|
||||
// getAllUnsafe returns a copy of all the addresses in the iterator.
|
||||
//
|
||||
// NOTE: this method is not thread safe and so must only be called once the
|
||||
// addressIterator mutex is already being held.
|
||||
func (a *addressIterator) getAllUnsafe() []net.Addr {
|
||||
var addrs []net.Addr
|
||||
cursor := a.addrList.Front()
|
||||
|
||||
|
|
|
@ -97,6 +97,11 @@ func TestAddrIterator(t *testing.T) {
|
|||
addrList := iter.GetAll()
|
||||
require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3})
|
||||
|
||||
// Also check that an iterator constructed via the Copy method, also
|
||||
// contains all the expected addresses.
|
||||
newIterAddrs := iter.Copy().GetAll()
|
||||
require.ElementsMatch(t, newIterAddrs, []net.Addr{addr1, addr2, addr3})
|
||||
|
||||
// Let's now remove addr3.
|
||||
err = iter.Remove(addr3)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -29,6 +29,10 @@ type TowerCandidateIterator interface {
|
|||
// candidates available as long as they remain in the set.
|
||||
Reset() error
|
||||
|
||||
// GetTower gets the tower with the given ID from the iterator. If no
|
||||
// such tower is found then ErrTowerNotInIterator is returned.
|
||||
GetTower(id wtdb.TowerID) (*Tower, error)
|
||||
|
||||
// Next returns the next candidate tower. The iterator is not required
|
||||
// to return results in any particular order. If no more candidates are
|
||||
// available, ErrTowerCandidatesExhausted is returned.
|
||||
|
@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetTower gets the tower with the given ID from the iterator. If no such tower
|
||||
// is found then ErrTowerNotInIterator is returned.
|
||||
func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
tower, ok := t.candidates[id]
|
||||
if !ok {
|
||||
return nil, ErrTowerNotInIterator
|
||||
}
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// Next returns the next candidate tower. This iterator will always return
|
||||
// candidates in the order given when the iterator was instantiated. If no more
|
||||
// candidates are available, ErrTowerCandidatesExhausted is returned.
|
||||
|
|
|
@ -52,14 +52,10 @@ func randTower(t *testing.T) *Tower {
|
|||
func copyTower(t *testing.T, tower *Tower) *Tower {
|
||||
t.Helper()
|
||||
|
||||
addrs := tower.Addresses.GetAll()
|
||||
addrIterator, err := newAddressIterator(addrs...)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &Tower{
|
||||
ID: tower.ID,
|
||||
IdentityKey: tower.IdentityKey,
|
||||
Addresses: addrIterator,
|
||||
Addresses: tower.Addresses.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,9 +79,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) {
|
|||
|
||||
tower, err := i.Next()
|
||||
require.NoError(t, err)
|
||||
require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey))
|
||||
require.Equal(t, tower.ID, c.ID)
|
||||
require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll())
|
||||
assertTowersEqual(t, c, tower)
|
||||
}
|
||||
|
||||
func assertTowersEqual(t *testing.T, expected, actual *Tower) {
|
||||
t.Helper()
|
||||
|
||||
require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey))
|
||||
require.Equal(t, expected.ID, actual.ID)
|
||||
require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll())
|
||||
}
|
||||
|
||||
// TestTowerCandidateIterator asserts the internal state of a
|
||||
|
@ -155,4 +157,16 @@ func TestTowerCandidateIterator(t *testing.T) {
|
|||
towerIterator.AddCandidate(secondTower)
|
||||
assertActiveCandidate(t, towerIterator, secondTower, true)
|
||||
assertNextCandidate(t, towerIterator, secondTower)
|
||||
|
||||
// Assert that the GetTower correctly returns the tower too.
|
||||
tower, err := towerIterator.GetTower(secondTower.ID)
|
||||
require.NoError(t, err)
|
||||
assertTowersEqual(t, secondTower, tower)
|
||||
|
||||
// Now remove the tower and assert that GetTower returns expected error.
|
||||
err = towerIterator.RemoveCandidate(secondTower.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = towerIterator.GetTower(secondTower.ID)
|
||||
require.ErrorIs(t, err, ErrTowerNotInIterator)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,10 @@ package wtclient
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -11,11 +14,14 @@ import (
|
|||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/subscribe"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
|
@ -40,6 +46,11 @@ const (
|
|||
// client should abandon any pending updates or session negotiations
|
||||
// before terminating.
|
||||
DefaultForceQuitDelay = 10 * time.Second
|
||||
|
||||
// DefaultSessionCloseRange is the range over which we will generate a
|
||||
// random number of blocks to delay closing a session after its last
|
||||
// channel has been closed.
|
||||
DefaultSessionCloseRange = 288
|
||||
)
|
||||
|
||||
// genSessionFilter constructs a filter that can be used to select sessions only
|
||||
|
@ -146,6 +157,19 @@ type Config struct {
|
|||
// transaction.
|
||||
Signer input.Signer
|
||||
|
||||
// SubscribeChannelEvents can be used to subscribe to channel event
|
||||
// notifications.
|
||||
SubscribeChannelEvents func() (subscribe.Subscription, error)
|
||||
|
||||
// FetchClosedChannel can be used to fetch the info about a closed
|
||||
// channel. If the channel is not found or not yet closed then
|
||||
// channeldb.ErrClosedChannelNotFound will be returned.
|
||||
FetchClosedChannel func(cid lnwire.ChannelID) (
|
||||
*channeldb.ChannelCloseSummary, error)
|
||||
|
||||
// ChainNotifier can be used to subscribe to block notifications.
|
||||
ChainNotifier chainntnfs.ChainNotifier
|
||||
|
||||
// NewAddress generates a new on-chain sweep pkscript.
|
||||
NewAddress func() ([]byte, error)
|
||||
|
||||
|
@ -201,6 +225,11 @@ type Config struct {
|
|||
// watchtowers. If the exponential backoff produces a timeout greater
|
||||
// than this value, the backoff will be clamped to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
|
||||
// SessionCloseRange is the range over which we will generate a random
|
||||
// number of blocks to delay closing a session after its last channel
|
||||
// has been closed.
|
||||
SessionCloseRange uint32
|
||||
}
|
||||
|
||||
// newTowerMsg is an internal message we'll use within the TowerClient to signal
|
||||
|
@ -258,6 +287,8 @@ type TowerClient struct {
|
|||
sessionQueue *sessionQueue
|
||||
prevTask *backupTask
|
||||
|
||||
closableSessionQueue *sessionCloseMinHeap
|
||||
|
||||
backupMu sync.Mutex
|
||||
summaries wtdb.ChannelSummaries
|
||||
chanCommitHeights map[lnwire.ChannelID]uint64
|
||||
|
@ -269,6 +300,7 @@ type TowerClient struct {
|
|||
staleTowers chan *staleTowerMsg
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
forceQuit chan struct{}
|
||||
}
|
||||
|
||||
|
@ -308,17 +340,19 @@ func New(config *Config) (*TowerClient, error) {
|
|||
}
|
||||
|
||||
c := &TowerClient{
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: newTaskPipeline(plog),
|
||||
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
|
||||
activeSessions: make(sessionQueueSet),
|
||||
summaries: chanSummaries,
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
stats: new(ClientStats),
|
||||
newTowers: make(chan *newTowerMsg),
|
||||
staleTowers: make(chan *staleTowerMsg),
|
||||
forceQuit: make(chan struct{}),
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: newTaskPipeline(plog),
|
||||
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
|
||||
activeSessions: make(sessionQueueSet),
|
||||
summaries: chanSummaries,
|
||||
closableSessionQueue: newSessionCloseMinHeap(),
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
stats: new(ClientStats),
|
||||
newTowers: make(chan *newTowerMsg),
|
||||
staleTowers: make(chan *staleTowerMsg),
|
||||
forceQuit: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
// perUpdate is a callback function that will be used to inspect the
|
||||
|
@ -364,7 +398,7 @@ func New(config *Config) (*TowerClient, error) {
|
|||
return
|
||||
}
|
||||
|
||||
log.Infof("Using private watchtower %s, offering policy %s",
|
||||
c.log.Infof("Using private watchtower %s, offering policy %s",
|
||||
tower, cfg.Policy)
|
||||
|
||||
// Add the tower to the set of candidate towers.
|
||||
|
@ -435,27 +469,19 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
|||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
towerKeyDesc, err := keyRing.DeriveKey(
|
||||
keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: s.KeyIndex,
|
||||
},
|
||||
if !sessionFilter(s) {
|
||||
continue
|
||||
}
|
||||
|
||||
cs, err := NewClientSessionFromDBSession(
|
||||
s, tower, keyRing,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionKeyECDH := keychain.NewPubKeyECDH(
|
||||
towerKeyDesc, keyRing,
|
||||
)
|
||||
|
||||
// Add the session to the set of candidate sessions.
|
||||
candidateSessions[s.ID] = &ClientSession{
|
||||
ID: s.ID,
|
||||
ClientSessionBody: s.ClientSessionBody,
|
||||
Tower: tower,
|
||||
SessionKeyECDH: sessionKeyECDH,
|
||||
}
|
||||
candidateSessions[s.ID] = cs
|
||||
|
||||
perActiveTower(tower)
|
||||
}
|
||||
|
@ -548,10 +574,70 @@ func (c *TowerClient) Start() error {
|
|||
}
|
||||
}
|
||||
|
||||
chanSub, err := c.cfg.SubscribeChannelEvents()
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate over the list of registered channels and check if
|
||||
// any of them can be marked as closed.
|
||||
for id := range c.summaries {
|
||||
isClosed, closedHeight, err := c.isChannelClosed(id)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
if !isClosed {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = c.cfg.DB.MarkChannelClosed(id, closedHeight)
|
||||
if err != nil {
|
||||
c.log.Errorf("could not mark channel(%s) as "+
|
||||
"closed: %v", id, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Since the channel has been marked as closed, we can
|
||||
// also remove it from the channel summaries map.
|
||||
delete(c.summaries, id)
|
||||
}
|
||||
|
||||
// Load all closable sessions.
|
||||
closableSessions, err := c.cfg.DB.ListClosableSessions()
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
err = c.trackClosableSessions(closableSessions)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleChannelCloses(chanSub)
|
||||
|
||||
// Subscribe to new block events.
|
||||
blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleClosableSessions(blockEvents)
|
||||
|
||||
// Now start the session negotiator, which will allow us to
|
||||
// request new session as soon as the backupDispatcher starts
|
||||
// up.
|
||||
err := c.negotiator.Start()
|
||||
err = c.negotiator.Start()
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
|
@ -599,6 +685,7 @@ func (c *TowerClient) Stop() error {
|
|||
// dispatcher to exit. The backup queue will signal it's
|
||||
// completion to the dispatcher, which releases the wait group
|
||||
// after all tasks have been assigned to session queues.
|
||||
close(c.quit)
|
||||
c.wg.Wait()
|
||||
|
||||
// 4. Since all valid tasks have been assigned to session
|
||||
|
@ -780,6 +867,335 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
|
|||
return c.getOrInitActiveQueue(candidateSession, updates), nil
|
||||
}
|
||||
|
||||
// handleChannelCloses listens for channel close events and marks channels as
|
||||
// closed in the DB.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) {
|
||||
defer c.wg.Done()
|
||||
|
||||
c.log.Debugf("Starting channel close handler")
|
||||
defer c.log.Debugf("Stopping channel close handler")
|
||||
|
||||
for {
|
||||
select {
|
||||
case update, ok := <-chanSub.Updates():
|
||||
if !ok {
|
||||
c.log.Debugf("Channel notifier has exited")
|
||||
return
|
||||
}
|
||||
|
||||
// We only care about channel-close events.
|
||||
event, ok := update.(channelnotifier.ClosedChannelEvent)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
chanID := lnwire.NewChanIDFromOutPoint(
|
||||
&event.CloseSummary.ChanPoint,
|
||||
)
|
||||
|
||||
c.log.Debugf("Received ClosedChannelEvent for "+
|
||||
"channel: %s", chanID)
|
||||
|
||||
err := c.handleClosedChannel(
|
||||
chanID, event.CloseSummary.CloseHeight,
|
||||
)
|
||||
if err != nil {
|
||||
c.log.Errorf("Could not handle channel close "+
|
||||
"event for channel(%s): %v", chanID,
|
||||
err)
|
||||
}
|
||||
|
||||
case <-c.forceQuit:
|
||||
return
|
||||
|
||||
case <-c.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleClosedChannel handles the closure of a single channel. It will mark the
|
||||
// channel as closed in the DB, then it will handle all the sessions that are
|
||||
// now closable due to the channel closure.
|
||||
func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
|
||||
closeHeight uint32) error {
|
||||
|
||||
c.backupMu.Lock()
|
||||
defer c.backupMu.Unlock()
|
||||
|
||||
// We only care about channels registered with the tower client.
|
||||
if _, ok := c.summaries[chanID]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.log.Debugf("Marking channel(%s) as closed", chanID)
|
||||
|
||||
sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not mark channel(%s) as closed: %w",
|
||||
chanID, err)
|
||||
}
|
||||
|
||||
closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
closableSessions[sess] = closeHeight
|
||||
}
|
||||
|
||||
c.log.Debugf("Tracking %d new closable sessions as a result of "+
|
||||
"closing channel %s", len(closableSessions), chanID)
|
||||
|
||||
err = c.trackClosableSessions(closableSessions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not track closable sessions: %w", err)
|
||||
}
|
||||
|
||||
delete(c.summaries, chanID)
|
||||
delete(c.chanCommitHeights, chanID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleClosableSessions listens for new block notifications. For each block,
|
||||
// it checks the closableSessionQueue to see if there is a closable session with
|
||||
// a delete-height smaller than or equal to the new block, if there is then the
|
||||
// tower is informed that it can delete the session, and then we also delete it
|
||||
// from our DB.
|
||||
func (c *TowerClient) handleClosableSessions(
|
||||
blocksChan *chainntnfs.BlockEpochEvent) {
|
||||
|
||||
defer c.wg.Done()
|
||||
|
||||
c.log.Debug("Starting closable sessions handler")
|
||||
defer c.log.Debug("Stopping closable sessions handler")
|
||||
|
||||
for {
|
||||
select {
|
||||
case newBlock := <-blocksChan.Epochs:
|
||||
if newBlock == nil {
|
||||
return
|
||||
}
|
||||
|
||||
height := uint32(newBlock.Height)
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// If there are no closable sessions that we
|
||||
// need to handle, then we are done and can
|
||||
// reevaluate when the next block comes.
|
||||
item := c.closableSessionQueue.Top()
|
||||
if item == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// If there is closable session but the delete
|
||||
// height we have set for it is after the
|
||||
// current block height, then our work is done.
|
||||
if item.deleteHeight > height {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise, we pop this item from the heap
|
||||
// and handle it.
|
||||
c.closableSessionQueue.Pop()
|
||||
|
||||
// Fetch the session from the DB so that we can
|
||||
// extract the Tower info.
|
||||
sess, err := c.cfg.DB.GetClientSession(
|
||||
item.sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
c.log.Errorf("error calling "+
|
||||
"GetClientSession for "+
|
||||
"session %s: %v",
|
||||
item.sessionID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.deleteSessionFromTower(sess)
|
||||
if err != nil {
|
||||
c.log.Errorf("error deleting "+
|
||||
"session %s from tower: %v",
|
||||
sess.ID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.cfg.DB.DeleteSession(item.sessionID)
|
||||
if err != nil {
|
||||
c.log.Errorf("could not delete "+
|
||||
"session(%s) from DB: %w",
|
||||
sess.ID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
case <-c.forceQuit:
|
||||
return
|
||||
|
||||
case <-c.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackClosableSessions takes in a map of session IDs to the earliest block
|
||||
// height at which the session should be deleted. For each of the sessions,
|
||||
// a random delay is added to the block height and the session is added to the
|
||||
// closableSessionQueue.
|
||||
func (c *TowerClient) trackClosableSessions(
|
||||
sessions map[wtdb.SessionID]uint32) error {
|
||||
|
||||
// For each closable session, add a random delay to its close
|
||||
// height and add it to the closableSessionQueue.
|
||||
for sID, blockHeight := range sessions {
|
||||
delay, err := newRandomDelay(c.cfg.SessionCloseRange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deleteHeight := blockHeight + delay
|
||||
|
||||
c.closableSessionQueue.Push(&sessionCloseItem{
|
||||
sessionID: sID,
|
||||
deleteHeight: deleteHeight,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteSessionFromTower dials the tower that we created the session with and
|
||||
// attempts to send the tower the DeleteSession message.
|
||||
func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error {
|
||||
// First, we check if we have already loaded this tower in our
|
||||
// candidate towers iterator.
|
||||
tower, err := c.candidateTowers.GetTower(sess.TowerID)
|
||||
if errors.Is(err, ErrTowerNotInIterator) {
|
||||
// If not, then we attempt to load it from the DB.
|
||||
dbTower, err := c.cfg.DB.LoadTowerByID(sess.TowerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tower, err = NewTowerFromDBTower(dbTower)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, err := NewClientSessionFromDBSession(
|
||||
sess, tower, c.cfg.SecretKeyRing,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
|
||||
c.cfg.ChainHash,
|
||||
)
|
||||
|
||||
var (
|
||||
conn wtserver.Peer
|
||||
|
||||
// addrIterator is a copy of the tower's address iterator.
|
||||
// We use this copy so that iterating through the addresses does
|
||||
// not affect any other threads using this iterator.
|
||||
addrIterator = tower.Addresses.Copy()
|
||||
towerAddr = addrIterator.Peek()
|
||||
)
|
||||
// Attempt to dial the tower with its available addresses.
|
||||
for {
|
||||
conn, err = c.dial(
|
||||
session.SessionKeyECDH, &lnwire.NetAddress{
|
||||
IdentityKey: tower.IdentityKey,
|
||||
Address: towerAddr,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// If there are more addrs available, immediately try
|
||||
// those.
|
||||
nextAddr, iteratorErr := addrIterator.Next()
|
||||
if iteratorErr == nil {
|
||||
towerAddr = nextAddr
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, if we have exhausted the address list,
|
||||
// exit.
|
||||
addrIterator.Reset()
|
||||
|
||||
return fmt.Errorf("failed to dial tower(%x) at any "+
|
||||
"available addresses",
|
||||
tower.IdentityKey.SerializeCompressed())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send Init to tower.
|
||||
err = c.sendMessage(conn, localInit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive Init from tower.
|
||||
remoteMsg, err := c.readMessage(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteInit, ok := remoteMsg.(*wtwire.Init)
|
||||
if !ok {
|
||||
return fmt.Errorf("watchtower %s responded with %T to Init",
|
||||
towerAddr, remoteMsg)
|
||||
}
|
||||
|
||||
// Validate Init.
|
||||
err = localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send DeleteSession to tower.
|
||||
err = c.sendMessage(conn, &wtwire.DeleteSession{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive DeleteSessionReply from tower.
|
||||
remoteMsg, err = c.readMessage(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deleteSessionReply, ok := remoteMsg.(*wtwire.DeleteSessionReply)
|
||||
if !ok {
|
||||
return fmt.Errorf("watchtower %s responded with %T to "+
|
||||
"DeleteSession", towerAddr, remoteMsg)
|
||||
}
|
||||
|
||||
switch deleteSessionReply.Code {
|
||||
case wtwire.CodeOK, wtwire.DeleteSessionCodeNotFound:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("received error code %v in "+
|
||||
"DeleteSessionReply when attempting to delete "+
|
||||
"session from tower", deleteSessionReply.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// backupDispatcher processes events coming from the taskPipeline and is
|
||||
// responsible for detecting when the client needs to renegotiate a session to
|
||||
// fulfill continuing demand. The event loop exits after all tasks have been
|
||||
|
@ -1153,6 +1569,22 @@ func (c *TowerClient) initActiveQueue(s *ClientSession,
|
|||
return sq
|
||||
}
|
||||
|
||||
// isChanClosed can be used to check if the channel with the given ID has been
|
||||
// closed. If it has been, the block height in which its closing transaction was
|
||||
// mined will also be returned.
|
||||
func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32,
|
||||
error) {
|
||||
|
||||
chanSum, err := c.cfg.FetchClosedChannel(id)
|
||||
if errors.Is(err, channeldb.ErrClosedChannelNotFound) {
|
||||
return false, 0, nil
|
||||
} else if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
return true, chanSum.CloseHeight, nil
|
||||
}
|
||||
|
||||
// AddTower adds a new watchtower reachable at the given address and considers
|
||||
// it for new sessions. If the watchtower already exists, then any new addresses
|
||||
// included will be considered when dialing it for session negotiations and
|
||||
|
@ -1409,3 +1841,15 @@ func (c *TowerClient) logMessage(
|
|||
preposition, peer.RemotePub().SerializeCompressed(),
|
||||
peer.RemoteAddr())
|
||||
}
|
||||
|
||||
func newRandomDelay(max uint32) (uint32, error) {
|
||||
var maxDelay big.Int
|
||||
maxDelay.SetUint64(uint64(max))
|
||||
|
||||
randDelay, err := rand.Int(rand.Reader, &maxDelay)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint32(randDelay.Uint64()), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package wtclient_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -15,12 +16,15 @@ import (
|
|||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/subscribe"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||
|
@ -393,8 +397,15 @@ type testHarness struct {
|
|||
server *wtserver.Server
|
||||
net *mockNet
|
||||
|
||||
mu sync.Mutex
|
||||
channels map[lnwire.ChannelID]*mockChannel
|
||||
blockEvents *mockBlockSub
|
||||
height int32
|
||||
|
||||
channelEvents *mockSubscription
|
||||
sendUpdatesOn bool
|
||||
|
||||
mu sync.Mutex
|
||||
channels map[lnwire.ChannelID]*mockChannel
|
||||
closedChannels map[lnwire.ChannelID]uint32
|
||||
|
||||
quit chan struct{}
|
||||
}
|
||||
|
@ -441,41 +452,63 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
mockNet := newMockNet()
|
||||
clientDB := wtmock.NewClientDB()
|
||||
|
||||
clientCfg := &wtclient.Config{
|
||||
Signer: signer,
|
||||
Dial: mockNet.Dial,
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
SecretKeyRing: wtmock.NewSecretKeyRing(),
|
||||
Policy: cfg.policy,
|
||||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
}
|
||||
|
||||
h := &testHarness{
|
||||
t: t,
|
||||
cfg: cfg,
|
||||
signer: signer,
|
||||
capacity: cfg.localBalance + cfg.remoteBalance,
|
||||
clientDB: clientDB,
|
||||
clientCfg: clientCfg,
|
||||
serverAddr: towerAddr,
|
||||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
net: mockNet,
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
quit: make(chan struct{}),
|
||||
t: t,
|
||||
cfg: cfg,
|
||||
signer: signer,
|
||||
capacity: cfg.localBalance + cfg.remoteBalance,
|
||||
clientDB: clientDB,
|
||||
serverAddr: towerAddr,
|
||||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
net: mockNet,
|
||||
blockEvents: newMockBlockSub(t),
|
||||
channelEvents: newMockSubscription(t),
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
closedChannels: make(map[lnwire.ChannelID]uint32),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
close(h.quit)
|
||||
})
|
||||
|
||||
fetchChannel := func(id lnwire.ChannelID) (
|
||||
*channeldb.ChannelCloseSummary, error) {
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
height, ok := h.closedChannels[id]
|
||||
if !ok {
|
||||
return nil, channeldb.ErrClosedChannelNotFound
|
||||
}
|
||||
|
||||
return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil
|
||||
}
|
||||
|
||||
h.clientCfg = &wtclient.Config{
|
||||
Signer: signer,
|
||||
SubscribeChannelEvents: func() (subscribe.Subscription, error) {
|
||||
return h.channelEvents, nil
|
||||
},
|
||||
FetchClosedChannel: fetchChannel,
|
||||
ChainNotifier: h.blockEvents,
|
||||
Dial: mockNet.Dial,
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
SecretKeyRing: wtmock.NewSecretKeyRing(),
|
||||
Policy: cfg.policy,
|
||||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
SessionCloseRange: 1,
|
||||
}
|
||||
|
||||
if !cfg.noServerStart {
|
||||
h.startServer()
|
||||
t.Cleanup(h.stopServer)
|
||||
|
@ -492,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||
return h
|
||||
}
|
||||
|
||||
// mine mimics the mining of new blocks by sending new block notifications.
|
||||
func (h *testHarness) mine(numBlocks int) {
|
||||
h.t.Helper()
|
||||
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
h.height++
|
||||
h.blockEvents.sendNewBlock(h.height)
|
||||
}
|
||||
}
|
||||
|
||||
// startServer creates a new server using the harness's current serverCfg and
|
||||
// starts it after pointing the mockNet's callback to the new server.
|
||||
func (h *testHarness) startServer() {
|
||||
|
@ -576,6 +619,41 @@ func (h *testHarness) channel(id uint64) *mockChannel {
|
|||
return c
|
||||
}
|
||||
|
||||
// closeChannel marks a channel as closed.
|
||||
//
|
||||
// NOTE: The method fails if a channel for id does not exist.
|
||||
func (h *testHarness) closeChannel(id uint64, height uint32) {
|
||||
h.t.Helper()
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
|
||||
_, ok := h.channels[chanID]
|
||||
require.Truef(h.t, ok, "unable to fetch channel %d", id)
|
||||
|
||||
h.closedChannels[chanID] = height
|
||||
delete(h.channels, chanID)
|
||||
|
||||
chanPointHash, err := chainhash.NewHash(chanID[:])
|
||||
require.NoError(h.t, err)
|
||||
|
||||
if !h.sendUpdatesOn {
|
||||
return
|
||||
}
|
||||
|
||||
h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{
|
||||
CloseSummary: &channeldb.ChannelCloseSummary{
|
||||
ChanPoint: wire.OutPoint{
|
||||
Hash: *chanPointHash,
|
||||
Index: 0,
|
||||
},
|
||||
CloseHeight: height,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// registerChannel registers the channel identified by id with the client.
|
||||
func (h *testHarness) registerChannel(id uint64) {
|
||||
h.t.Helper()
|
||||
|
@ -624,7 +702,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
|
|||
err := h.client.BackupState(
|
||||
&chanID, retribution, channeldb.SingleFunderBit,
|
||||
)
|
||||
require.ErrorIs(h.t, expErr, err)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// sendPayments instructs the channel identified by id to send amt to the remote
|
||||
|
@ -770,11 +848,132 @@ func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
|
|||
require.NoError(h.t, err)
|
||||
}
|
||||
|
||||
// relevantSessions returns a list of session IDs that have acked updates for
|
||||
// the given channel ID.
|
||||
func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID {
|
||||
h.t.Helper()
|
||||
|
||||
var (
|
||||
sessionIDs []wtdb.SessionID
|
||||
cID = chanIDFromInt(chanID)
|
||||
)
|
||||
|
||||
collectSessions := wtdb.WithPerNumAckedUpdates(
|
||||
func(session *wtdb.ClientSession, id lnwire.ChannelID,
|
||||
_ uint16) {
|
||||
|
||||
if !bytes.Equal(id[:], cID[:]) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionIDs = append(sessionIDs, session.ID)
|
||||
},
|
||||
)
|
||||
|
||||
_, err := h.clientDB.ListClientSessions(nil, nil, collectSessions)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return sessionIDs
|
||||
}
|
||||
|
||||
// isSessionClosable returns true if the given session has been marked as
|
||||
// closable in the DB.
|
||||
func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool {
|
||||
h.t.Helper()
|
||||
|
||||
cs, err := h.clientDB.ListClosableSessions()
|
||||
require.NoError(h.t, err)
|
||||
|
||||
_, ok := cs[id]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// mockSubscription is a mock subscription client that blocks on sends into the
|
||||
// updates channel.
|
||||
type mockSubscription struct {
|
||||
t *testing.T
|
||||
updates chan interface{}
|
||||
|
||||
// Embed the subscription interface in this mock so that we satisfy it.
|
||||
subscribe.Subscription
|
||||
}
|
||||
|
||||
// newMockSubscription creates a mock subscription.
|
||||
func newMockSubscription(t *testing.T) *mockSubscription {
|
||||
t.Helper()
|
||||
|
||||
return &mockSubscription{
|
||||
t: t,
|
||||
updates: make(chan interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// sendUpdate sends an update into our updates channel, mocking the dispatch of
|
||||
// an update from a subscription server. This call will fail the test if the
|
||||
// update is not consumed within our timeout.
|
||||
func (m *mockSubscription) sendUpdate(update interface{}) {
|
||||
select {
|
||||
case m.updates <- update:
|
||||
|
||||
case <-time.After(waitTime):
|
||||
m.t.Fatalf("update: %v timeout", update)
|
||||
}
|
||||
}
|
||||
|
||||
// Updates returns the updates channel for the mock.
|
||||
func (m *mockSubscription) Updates() <-chan interface{} {
|
||||
return m.updates
|
||||
}
|
||||
|
||||
// mockBlockSub mocks out the ChainNotifier.
|
||||
type mockBlockSub struct {
|
||||
t *testing.T
|
||||
events chan *chainntnfs.BlockEpoch
|
||||
|
||||
chainntnfs.ChainNotifier
|
||||
}
|
||||
|
||||
// newMockBlockSub creates a new mockBlockSub.
|
||||
func newMockBlockSub(t *testing.T) *mockBlockSub {
|
||||
t.Helper()
|
||||
|
||||
return &mockBlockSub{
|
||||
t: t,
|
||||
events: make(chan *chainntnfs.BlockEpoch),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterBlockEpochNtfn returns a channel that can be used to listen for new
|
||||
// blocks.
|
||||
func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) (
|
||||
*chainntnfs.BlockEpochEvent, error) {
|
||||
|
||||
return &chainntnfs.BlockEpochEvent{
|
||||
Epochs: m.events,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sendNewBlock will send a new block on the notification channel.
|
||||
func (m *mockBlockSub) sendNewBlock(height int32) {
|
||||
select {
|
||||
case m.events <- &chainntnfs.BlockEpoch{Height: height}:
|
||||
|
||||
case <-time.After(waitTime):
|
||||
m.t.Fatalf("timed out sending block: %d", height)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
localBalance = lnwire.MilliSatoshi(100000000)
|
||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||
)
|
||||
|
||||
var defaultTxPolicy = wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
}
|
||||
|
||||
type clientTest struct {
|
||||
name string
|
||||
cfg harnessCfg
|
||||
|
@ -791,10 +990,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 20000,
|
||||
},
|
||||
noRegisterChan0: true,
|
||||
|
@ -825,10 +1021,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 20000,
|
||||
},
|
||||
},
|
||||
|
@ -860,10 +1053,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -927,10 +1117,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 20000,
|
||||
},
|
||||
},
|
||||
|
@ -1006,10 +1193,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -1062,10 +1246,7 @@ var clientTests = []clientTest{
|
|||
localBalance: 100000001, // ensure (% amt != 0)
|
||||
remoteBalance: 200000001, // ensure (% amt != 0)
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 1000,
|
||||
},
|
||||
},
|
||||
|
@ -1106,10 +1287,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -1156,10 +1334,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
noAckCreateSession: true,
|
||||
|
@ -1212,10 +1387,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
noAckCreateSession: true,
|
||||
|
@ -1274,10 +1446,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 10,
|
||||
},
|
||||
},
|
||||
|
@ -1333,10 +1502,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -1381,10 +1547,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -1489,10 +1652,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
|
@ -1557,10 +1717,7 @@ var clientTests = []clientTest{
|
|||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
noServerStart: true,
|
||||
|
@ -1654,6 +1811,209 @@ var clientTests = []clientTest{
|
|||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
},
|
||||
}, {
|
||||
name: "assert that sessions are correctly marked as closable",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const numUpdates = 5
|
||||
|
||||
// In this test we assert that a channel is correctly
|
||||
// marked as closed and that sessions are also correctly
|
||||
// marked as closable.
|
||||
|
||||
// We start with the sendUpdatesOn parameter set to
|
||||
// false so that we can test that channels are correctly
|
||||
// evaluated at startup.
|
||||
h.sendUpdatesOn = false
|
||||
|
||||
// Advance channel 0 to create all states and back them
|
||||
// all up. This will saturate the session with updates
|
||||
// for channel 0 which means that the session should be
|
||||
// considered closable when channel 0 is closed.
|
||||
hints := h.advanceChannelN(0, numUpdates)
|
||||
h.backupStates(0, 0, numUpdates, nil)
|
||||
h.waitServerUpdates(hints, waitTime)
|
||||
|
||||
// We expect only 1 session to have updates for this
|
||||
// channel.
|
||||
sessionIDs := h.relevantSessions(0)
|
||||
require.Len(h.t, sessionIDs, 1)
|
||||
|
||||
// Since channel 0 is still open, the session should not
|
||||
// yet be closable.
|
||||
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
|
||||
|
||||
// Close the channel.
|
||||
h.closeChannel(0, 1)
|
||||
|
||||
// Since updates are currently not being sent, we expect
|
||||
// the session to still not be marked as closable.
|
||||
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
|
||||
|
||||
// Restart the client.
|
||||
h.client.ForceQuit()
|
||||
h.startClient()
|
||||
|
||||
// The session should now have been marked as closable.
|
||||
err := wait.Predicate(func() bool {
|
||||
return h.isSessionClosable(sessionIDs[0])
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Now we set sendUpdatesOn to true and do the same with
|
||||
// a new channel. A restart should now not be necessary
|
||||
// anymore.
|
||||
h.sendUpdatesOn = true
|
||||
|
||||
h.makeChannel(
|
||||
1, h.cfg.localBalance, h.cfg.remoteBalance,
|
||||
)
|
||||
h.registerChannel(1)
|
||||
|
||||
hints = h.advanceChannelN(1, numUpdates)
|
||||
h.backupStates(1, 0, numUpdates, nil)
|
||||
h.waitServerUpdates(hints, waitTime)
|
||||
|
||||
// Determine the ID of the session of interest.
|
||||
sessionIDs = h.relevantSessions(1)
|
||||
|
||||
// We expect only 1 session to have updates for this
|
||||
// channel.
|
||||
require.Len(h.t, sessionIDs, 1)
|
||||
|
||||
// Assert that the session is not yet closable since
|
||||
// the channel is still open.
|
||||
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
|
||||
|
||||
// Now close the channel.
|
||||
h.closeChannel(1, 1)
|
||||
|
||||
// Since the updates have been turned on, the session
|
||||
// should now show up as closable.
|
||||
err = wait.Predicate(func() bool {
|
||||
return h.isSessionClosable(sessionIDs[0])
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Now we test that a session must be exhausted with all
|
||||
// channels closed before it is seen as closable.
|
||||
h.makeChannel(
|
||||
2, h.cfg.localBalance, h.cfg.remoteBalance,
|
||||
)
|
||||
h.registerChannel(2)
|
||||
|
||||
// Fill up only half of the session updates.
|
||||
hints = h.advanceChannelN(2, numUpdates)
|
||||
h.backupStates(2, 0, numUpdates/2, nil)
|
||||
h.waitServerUpdates(hints[:numUpdates/2], waitTime)
|
||||
|
||||
// Determine the ID of the session of interest.
|
||||
sessionIDs = h.relevantSessions(2)
|
||||
|
||||
// We expect only 1 session to have updates for this
|
||||
// channel.
|
||||
require.Len(h.t, sessionIDs, 1)
|
||||
|
||||
// Now close the channel.
|
||||
h.closeChannel(2, 1)
|
||||
|
||||
// The session should _not_ be closable due to it not
|
||||
// being exhausted yet.
|
||||
require.False(h.t, h.isSessionClosable(sessionIDs[0]))
|
||||
|
||||
// Create a new channel.
|
||||
h.makeChannel(
|
||||
3, h.cfg.localBalance, h.cfg.remoteBalance,
|
||||
)
|
||||
h.registerChannel(3)
|
||||
|
||||
hints = h.advanceChannelN(3, numUpdates)
|
||||
h.backupStates(3, 0, numUpdates, nil)
|
||||
h.waitServerUpdates(hints, waitTime)
|
||||
|
||||
// Close it.
|
||||
h.closeChannel(3, 1)
|
||||
|
||||
// Now the session should be closable.
|
||||
err = wait.Predicate(func() bool {
|
||||
return h.isSessionClosable(sessionIDs[0])
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Now we will mine a few blocks. This will cause the
|
||||
// necessary session-close-range to be exceeded meaning
|
||||
// that the client should send the DeleteSession message
|
||||
// to the server. We will assert that both the client
|
||||
// and server have deleted the appropriate sessions and
|
||||
// channel info.
|
||||
|
||||
// Before we mine blocks, assert that the client
|
||||
// currently has 3 closable sessions.
|
||||
closableSess, err := h.clientDB.ListClosableSessions()
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, closableSess, 3)
|
||||
|
||||
// Assert that the server is also aware of all of these
|
||||
// sessions.
|
||||
for sid := range closableSess {
|
||||
_, err := h.serverDB.GetSessionInfo(&sid)
|
||||
require.NoError(h.t, err)
|
||||
}
|
||||
|
||||
// Also make a note of the total number of sessions the
|
||||
// client has.
|
||||
sessions, err := h.clientDB.ListClientSessions(nil, nil)
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, sessions, 4)
|
||||
|
||||
h.mine(3)
|
||||
|
||||
// The client should no longer have any closable
|
||||
// sessions and the total list of client sessions should
|
||||
// no longer include the three that it previously had
|
||||
// marked as closable. The server should also no longer
|
||||
// have these sessions in its DB.
|
||||
err = wait.Predicate(func() bool {
|
||||
sess, err := h.clientDB.ListClientSessions(
|
||||
nil, nil,
|
||||
)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
cs, err := h.clientDB.ListClosableSessions()
|
||||
require.NoError(h.t, err)
|
||||
|
||||
if len(sess) != 1 || len(cs) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for sid := range closableSess {
|
||||
_, ok := sess[sid]
|
||||
if ok {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := h.serverDB.GetSessionInfo(
|
||||
&sid,
|
||||
)
|
||||
if !errors.Is(
|
||||
err, wtdb.ErrSessionNotFound,
|
||||
) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package wtclient
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientExiting signals that the watchtower client is shutting down.
|
||||
|
@ -11,6 +13,10 @@ var (
|
|||
ErrTowerCandidatesExhausted = errors.New("exhausted all tower " +
|
||||
"candidates")
|
||||
|
||||
// ErrTowerNotInIterator is returned when a requested tower was not
|
||||
// found in the iterator.
|
||||
ErrTowerNotInIterator = errors.New("tower not in iterator")
|
||||
|
||||
// ErrPermanentTowerFailure signals that the tower has reported that it
|
||||
// has permanently failed or the client believes this has happened based
|
||||
// on the tower's behavior.
|
||||
|
|
|
@ -64,6 +64,11 @@ type DB interface {
|
|||
...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// GetClientSession loads the ClientSession with the given ID from the
|
||||
// DB.
|
||||
GetClientSession(wtdb.SessionID,
|
||||
...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error)
|
||||
|
||||
// FetchSessionCommittedUpdates retrieves the current set of un-acked
|
||||
// updates of the given session.
|
||||
FetchSessionCommittedUpdates(id *wtdb.SessionID) (
|
||||
|
@ -78,9 +83,29 @@ type DB interface {
|
|||
NumAckedUpdates(id *wtdb.SessionID) (uint64, error)
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to
|
||||
// their channel summaries.
|
||||
// their channel summaries. Only the channels that have not yet been
|
||||
// marked as closed will be loaded.
|
||||
FetchChanSummaries() (wtdb.ChannelSummaries, error)
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting
|
||||
// its closed-height as the given block height. It returns a list of
|
||||
// session IDs for sessions that are now considered closable due to the
|
||||
// close of this channel. The details for this channel will be deleted
|
||||
// from the DB if there are no more sessions in the DB that contain
|
||||
// updates for this channel.
|
||||
MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) (
|
||||
[]wtdb.SessionID, error)
|
||||
|
||||
// ListClosableSessions fetches and returns the IDs for all sessions
|
||||
// marked as closable.
|
||||
ListClosableSessions() (map[wtdb.SessionID]uint32, error)
|
||||
|
||||
// DeleteSession can be called when a session should be deleted from the
|
||||
// DB. All references to the session will also be deleted from the DB.
|
||||
// A session will only be deleted if it was previously marked as
|
||||
// closable.
|
||||
DeleteSession(id wtdb.SessionID) error
|
||||
|
||||
// RegisterChannel registers a channel for use within the client
|
||||
// database. For now, all that is stored in the channel summary is the
|
||||
// sweep pkscript that we'd like any tower sweeps to pay into. In the
|
||||
|
@ -174,3 +199,30 @@ type ClientSession struct {
|
|||
// key used to connect to the watchtower.
|
||||
SessionKeyECDH keychain.SingleKeyECDH
|
||||
}
|
||||
|
||||
// NewClientSessionFromDBSession converts a wtdb.ClientSession to a
|
||||
// ClientSession.
|
||||
func NewClientSessionFromDBSession(s *wtdb.ClientSession, tower *Tower,
|
||||
keyRing ECDHKeyRing) (*ClientSession, error) {
|
||||
|
||||
towerKeyDesc, err := keyRing.DeriveKey(
|
||||
keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: s.KeyIndex,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionKeyECDH := keychain.NewPubKeyECDH(
|
||||
towerKeyDesc, keyRing,
|
||||
)
|
||||
|
||||
return &ClientSession{
|
||||
ID: s.ID,
|
||||
ClientSessionBody: s.ClientSessionBody,
|
||||
Tower: tower,
|
||||
SessionKeyECDH: sessionKeyECDH,
|
||||
}, nil
|
||||
}
|
||||
|
|
95
watchtower/wtclient/sess_close_min_heap.go
Normal file
95
watchtower/wtclient/sess_close_min_heap.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package wtclient
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lightningnetwork/lnd/queue"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
// sessionCloseMinHeap is a thread-safe min-heap implementation that stores
|
||||
// sessionCloseItem items and prioritises the item with the lowest block height.
|
||||
type sessionCloseMinHeap struct {
|
||||
queue queue.PriorityQueue
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// newSessionCloseMinHeap constructs a new sessionCloseMineHeap.
|
||||
func newSessionCloseMinHeap() *sessionCloseMinHeap {
|
||||
return &sessionCloseMinHeap{}
|
||||
}
|
||||
|
||||
// Len returns the length of the queue.
|
||||
func (h *sessionCloseMinHeap) Len() int {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
return h.queue.Len()
|
||||
}
|
||||
|
||||
// Empty returns true if the queue is empty.
|
||||
func (h *sessionCloseMinHeap) Empty() bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
return h.queue.Empty()
|
||||
}
|
||||
|
||||
// Push adds an item to the priority queue.
|
||||
func (h *sessionCloseMinHeap) Push(item *sessionCloseItem) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.queue.Push(item)
|
||||
}
|
||||
|
||||
// Pop removes the top most item from the queue.
|
||||
func (h *sessionCloseMinHeap) Pop() *sessionCloseItem {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.queue.Empty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
item := h.queue.Pop()
|
||||
|
||||
return item.(*sessionCloseItem) //nolint:forcetypeassert
|
||||
}
|
||||
|
||||
// Top returns the top most item from the queue without removing it.
|
||||
func (h *sessionCloseMinHeap) Top() *sessionCloseItem {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.queue.Empty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
item := h.queue.Top()
|
||||
|
||||
return item.(*sessionCloseItem) //nolint:forcetypeassert
|
||||
}
|
||||
|
||||
// sessionCloseItem represents a session that is ready to be deleted.
|
||||
type sessionCloseItem struct {
|
||||
// sessionID is the ID of the session in question.
|
||||
sessionID wtdb.SessionID
|
||||
|
||||
// deleteHeight is the block height after which we can delete the
|
||||
// session.
|
||||
deleteHeight uint32
|
||||
}
|
||||
|
||||
// Less returns true if the current item's delete height is less than the
|
||||
// other sessionCloseItem's delete height. This results in lower block heights
|
||||
// being popped first from the heap.
|
||||
//
|
||||
// NOTE: this is part of the queue.PriorityQueueItem interface.
|
||||
func (s *sessionCloseItem) Less(other queue.PriorityQueueItem) bool {
|
||||
o := other.(*sessionCloseItem).deleteHeight //nolint:forcetypeassert
|
||||
|
||||
return s.deleteHeight < o
|
||||
}
|
||||
|
||||
var _ queue.PriorityQueueItem = (*sessionCloseItem)(nil)
|
52
watchtower/wtclient/sess_close_min_heap_test.go
Normal file
52
watchtower/wtclient/sess_close_min_heap_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package wtclient
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSessionCloseMinHeap asserts that the sessionCloseMinHeap behaves as
|
||||
// expected.
|
||||
func TestSessionCloseMinHeap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
heap := newSessionCloseMinHeap()
|
||||
require.Nil(t, heap.Pop())
|
||||
require.Nil(t, heap.Top())
|
||||
require.True(t, heap.Empty())
|
||||
require.Zero(t, heap.Len())
|
||||
|
||||
// Add an item with height 10.
|
||||
item1 := &sessionCloseItem{
|
||||
sessionID: [33]byte{1, 2, 3},
|
||||
deleteHeight: 10,
|
||||
}
|
||||
|
||||
heap.Push(item1)
|
||||
require.Equal(t, item1, heap.Top())
|
||||
require.False(t, heap.Empty())
|
||||
require.EqualValues(t, 1, heap.Len())
|
||||
|
||||
// Add a bunch more items with heights 1, 2, 6, 11, 6, 30, 9.
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 1})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 2})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 6})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 11})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 6})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 30})
|
||||
heap.Push(&sessionCloseItem{deleteHeight: 9})
|
||||
|
||||
// Now pop from the queue and assert that the items are returned in
|
||||
// ascending order.
|
||||
require.EqualValues(t, 1, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 2, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 6, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 6, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 9, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 10, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 11, heap.Pop().deleteHeight)
|
||||
require.EqualValues(t, 30, heap.Pop().deleteHeight)
|
||||
require.Nil(t, heap.Pop())
|
||||
require.Zero(t, heap.Len())
|
||||
}
|
|
@ -23,22 +23,39 @@ var (
|
|||
// cChanDetailsBkt is a top-level bucket storing:
|
||||
// channel-id => cChannelSummary -> encoded ClientChanSummary.
|
||||
// => cChanDBID -> db-assigned-id
|
||||
// => cChanSessions => db-session-id -> 1
|
||||
// => cChanClosedHeight -> block-height
|
||||
cChanDetailsBkt = []byte("client-channel-detail-bucket")
|
||||
|
||||
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
|
||||
// db-session-id -> 1
|
||||
cChanSessions = []byte("client-channel-sessions")
|
||||
|
||||
// cChanDBID is a key used in the cChanDetailsBkt to store the
|
||||
// db-assigned-id of a channel.
|
||||
cChanDBID = []byte("client-channel-db-id")
|
||||
|
||||
// cChanClosedHeight is a key used in the cChanDetailsBkt to store the
|
||||
// block height at which the channel's closing transaction was mined in.
|
||||
// If this there is no associated value for this key, then the channel
|
||||
// has not yet been marked as closed.
|
||||
cChanClosedHeight = []byte("client-channel-closed-height")
|
||||
|
||||
// cChannelSummary is a key used in cChanDetailsBkt to store the encoded
|
||||
// body of ClientChanSummary.
|
||||
cChannelSummary = []byte("client-channel-summary")
|
||||
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionDBID -> db-assigned-id
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAckRangeIndex => db-chan-id => start -> end
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cSessionDBID is a key used in the cSessionBkt to store the
|
||||
// db-assigned-id of a session.
|
||||
cSessionDBID = []byte("client-session-db-id")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
|
||||
// the ClientSession.
|
||||
cSessionBody = []byte("client-session-body")
|
||||
|
@ -55,6 +72,10 @@ var (
|
|||
// db-assigned-id -> channel-ID
|
||||
cChanIDIndexBkt = []byte("client-channel-id-index")
|
||||
|
||||
// cSessionIDIndexBkt is a top-level bucket storing:
|
||||
// db-assigned-id -> session-id
|
||||
cSessionIDIndexBkt = []byte("client-session-id-index")
|
||||
|
||||
// cTowerBkt is a top-level bucket storing:
|
||||
// tower-id -> encoded Tower.
|
||||
cTowerBkt = []byte("client-tower-bucket")
|
||||
|
@ -69,6 +90,10 @@ var (
|
|||
"client-tower-to-session-index-bucket",
|
||||
)
|
||||
|
||||
// cClosableSessionsBkt is a top-level bucket storing:
|
||||
// db-session-id -> last-channel-close-height
|
||||
cClosableSessionsBkt = []byte("client-closable-sessions-bucket")
|
||||
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
|
@ -142,6 +167,23 @@ var (
|
|||
// ErrSessionFailedFilterFn indicates that a particular session did
|
||||
// not pass the filter func provided by the caller.
|
||||
ErrSessionFailedFilterFn = errors.New("session failed filter func")
|
||||
|
||||
// ErrSessionNotClosable is returned when a session is not found in the
|
||||
// closable list.
|
||||
ErrSessionNotClosable = errors.New("session is not closable")
|
||||
|
||||
// errSessionHasOpenChannels is an error used to indicate that a
|
||||
// session has updates for channels that are still open.
|
||||
errSessionHasOpenChannels = errors.New("session has open channels")
|
||||
|
||||
// errSessionHasUnackedUpdates is an error used to indicate that a
|
||||
// session has un-acked updates.
|
||||
errSessionHasUnackedUpdates = errors.New("session has un-acked updates")
|
||||
|
||||
// errChannelHasMoreSessions is an error used to indicate that a channel
|
||||
// has updates in other non-closed sessions.
|
||||
errChannelHasMoreSessions = errors.New("channel has updates in " +
|
||||
"other sessions")
|
||||
)
|
||||
|
||||
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
|
||||
|
@ -241,6 +283,8 @@ func initClientDBBuckets(tx kvdb.RwTx) error {
|
|||
cTowerIndexBkt,
|
||||
cTowerToSessionIndexBkt,
|
||||
cChanIDIndexBkt,
|
||||
cSessionIDIndexBkt,
|
||||
cClosableSessionsBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
|
@ -723,24 +767,58 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Add the new entry to the towerID-to-SessionID index.
|
||||
indexBkt := towerToSessionIndex.NestedReadWriteBucket(
|
||||
towerID.Bytes(),
|
||||
)
|
||||
if indexBkt == nil {
|
||||
return ErrTowerNotFound
|
||||
// Get the session-ID index bucket.
|
||||
dbIDIndex := tx.ReadWriteBucket(cSessionIDIndexBkt)
|
||||
if dbIDIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
err = indexBkt.Put(session.ID[:], []byte{1})
|
||||
// Get a new, unique, ID for this session from the session-ID
|
||||
// index bucket.
|
||||
nextSeq, err := dbIDIndex.NextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the new entry to the dbID-to-SessionID index.
|
||||
newIndex, err := writeBigSize(nextSeq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = dbIDIndex.Put(newIndex, session.ID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Also add the db-assigned-id to the session bucket under the
|
||||
// cSessionDBID key.
|
||||
sessionBkt, err := sessions.CreateBucket(session.ID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = sessionBkt.Put(cSessionDBID, newIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(elle): migrate the towerID-to-SessionID to use the
|
||||
// new db-assigned sessionID's rather.
|
||||
|
||||
// Add the new entry to the towerID-to-SessionID index.
|
||||
towerSessions := towerToSessionIndex.NestedReadWriteBucket(
|
||||
towerID.Bytes(),
|
||||
)
|
||||
if towerSessions == nil {
|
||||
return ErrTowerNotFound
|
||||
}
|
||||
|
||||
err = towerSessions.Put(session.ID[:], []byte{1})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, write the client session's body in the sessions
|
||||
// bucket.
|
||||
return putClientSessionBody(sessionBkt, session)
|
||||
|
@ -960,6 +1038,37 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
|||
return byteOrder.Uint32(keyIndexBytes), nil
|
||||
}
|
||||
|
||||
// GetClientSession loads the ClientSession with the given ID from the DB.
|
||||
func (c *ClientDB) GetClientSession(id SessionID,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
|
||||
var sess *ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
sessionsBkt := tx.ReadBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sess = session
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
return sess, err
|
||||
}
|
||||
|
||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
|
@ -974,20 +1083,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
|||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.ReadBucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// If no tower ID is specified, then fetch all the sessions
|
||||
// known to the db.
|
||||
var err error
|
||||
if id == nil {
|
||||
clientSessions, err = c.listClientAllSessions(
|
||||
sessions, chanIDIndexBkt, filterFn, opts...,
|
||||
|
@ -1181,7 +1284,8 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
|
|||
}
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries.
|
||||
// channel summaries. Only the channels that have not yet been marked as closed
|
||||
// will be loaded.
|
||||
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
||||
var summaries map[lnwire.ChannelID]ClientChanSummary
|
||||
|
||||
|
@ -1197,6 +1301,13 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
|||
return ErrCorruptChanDetails
|
||||
}
|
||||
|
||||
// If this channel has already been marked as closed,
|
||||
// then its summary does not need to be loaded.
|
||||
closedHeight := chanDetails.Get(cChanClosedHeight)
|
||||
if len(closedHeight) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var chanID lnwire.ChannelID
|
||||
copy(chanID[:], k)
|
||||
|
||||
|
@ -1292,6 +1403,420 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
|
|||
return nil
|
||||
}
|
||||
|
||||
// ListClosableSessions fetches and returns the IDs for all sessions marked as
|
||||
// closable.
|
||||
func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) {
|
||||
sessions := make(map[SessionID]uint32)
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
csBkt := tx.ReadBucket(cClosableSessionsBkt)
|
||||
if csBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt)
|
||||
if sessIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return csBkt.ForEach(func(dbIDBytes, heightBytes []byte) error {
|
||||
dbID, err := readBigSize(dbIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessID, err := getRealSessionID(sessIDIndexBkt, dbID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessions[*sessID] = byteOrder.Uint32(heightBytes)
|
||||
|
||||
return nil
|
||||
})
|
||||
}, func() {
|
||||
sessions = make(map[SessionID]uint32)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// DeleteSession can be called when a session should be deleted from the DB.
|
||||
// All references to the session will also be deleted from the DB. Note that a
|
||||
// session will only be deleted if was previously marked as closable.
|
||||
func (c *ClientDB) DeleteSession(id SessionID) error {
|
||||
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
|
||||
sessionsBkt := tx.ReadWriteBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
closableBkt := tx.ReadWriteBucket(cClosableSessionsBkt)
|
||||
if closableBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessIDIndexBkt := tx.ReadWriteBucket(cSessionIDIndexBkt)
|
||||
if sessIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadWriteBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towerToSessBkt := tx.ReadWriteBucket(cTowerToSessionIndexBkt)
|
||||
if towerToSessBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Get the sub-bucket for this session ID. If it does not exist
|
||||
// then the session has already been deleted and so our work is
|
||||
// done.
|
||||
sessionBkt := sessionsBkt.NestedReadBucket(id[:])
|
||||
if sessionBkt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, dbIDBytes, err := getDBSessionID(sessionsBkt, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First we check if the session has actually been marked as
|
||||
// closable.
|
||||
if closableBkt.Get(dbIDBytes) == nil {
|
||||
return ErrSessionNotClosable
|
||||
}
|
||||
|
||||
sess, err := getClientSessionBody(sessionsBkt, id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete from the tower-to-sessionID index.
|
||||
towerIndexBkt := towerToSessBkt.NestedReadWriteBucket(
|
||||
sess.TowerID.Bytes(),
|
||||
)
|
||||
if towerIndexBkt == nil {
|
||||
return fmt.Errorf("no entry in the tower-to-session "+
|
||||
"index found for tower ID %v", sess.TowerID)
|
||||
}
|
||||
|
||||
err = towerIndexBkt.Delete(id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete entry from session ID index.
|
||||
err = sessIDIndexBkt.Delete(dbIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the entry from the closable sessions index.
|
||||
err = closableBkt.Delete(dbIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the acked updates range index for the session. This is
|
||||
// used to get the list of channels that the session has updates
|
||||
// for.
|
||||
ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
|
||||
if ackRanges == nil {
|
||||
// A session would only be considered closable if it
|
||||
// was exhausted. Meaning that it should not be the
|
||||
// case that it has no acked-updates.
|
||||
return fmt.Errorf("cannot delete session %s since it "+
|
||||
"is not yet exhausted", id)
|
||||
}
|
||||
|
||||
// For each of the channels, delete the session ID entry.
|
||||
err = ackRanges.ForEach(func(chanDBID, _ []byte) error {
|
||||
chanDBIDInt, err := readBigSize(chanDBID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanID, err := getRealChannelID(
|
||||
chanIDIndexBkt, chanDBIDInt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(
|
||||
chanID[:],
|
||||
)
|
||||
if chanDetails == nil {
|
||||
return ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
chanSessions := chanDetails.NestedReadWriteBucket(
|
||||
cChanSessions,
|
||||
)
|
||||
if chanSessions == nil {
|
||||
return fmt.Errorf("no session list found for "+
|
||||
"channel %s", chanID)
|
||||
}
|
||||
|
||||
// Check that this session was actually listed in the
|
||||
// session list for this channel.
|
||||
if len(chanSessions.Get(dbIDBytes)) == 0 {
|
||||
return fmt.Errorf("session %s not found in "+
|
||||
"the session list for channel %s", id,
|
||||
chanID)
|
||||
}
|
||||
|
||||
// If it was, then delete it.
|
||||
err = chanSessions.Delete(dbIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this was the last session for this channel, we can
|
||||
// now delete the channel details for this channel
|
||||
// completely.
|
||||
err = chanSessions.ForEach(func(_, _ []byte) error {
|
||||
return errChannelHasMoreSessions
|
||||
})
|
||||
if errors.Is(err, errChannelHasMoreSessions) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the channel's entry from the channel-id-index.
|
||||
dbID := chanDetails.Get(cChanDBID)
|
||||
err = chanIDIndexBkt.Delete(dbID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the channel details.
|
||||
return chanDetailsBkt.DeleteNestedBucket(chanID[:])
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the actual session.
|
||||
return sessionsBkt.DeleteNestedBucket(id[:])
|
||||
}, func() {})
|
||||
}
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting its
|
||||
// closed-height as the given block height. It returns a list of session IDs for
|
||||
// sessions that are now considered closable due to the close of this channel.
|
||||
// The details for this channel will be deleted from the DB if there are no more
|
||||
// sessions in the DB that contain updates for this channel.
|
||||
func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
|
||||
blockHeight uint32) ([]SessionID, error) {
|
||||
|
||||
var closableSessions []SessionID
|
||||
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
|
||||
sessionsBkt := tx.ReadBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt)
|
||||
if closableSessBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt)
|
||||
if sessIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:])
|
||||
if chanDetails == nil {
|
||||
return ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
// If there are no sessions for this channel, the channel
|
||||
// details can be deleted.
|
||||
chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions)
|
||||
if chanSessIDsBkt == nil {
|
||||
return chanDetailsBkt.DeleteNestedBucket(chanID[:])
|
||||
}
|
||||
|
||||
// Otherwise, mark the channel as closed.
|
||||
var height [4]byte
|
||||
byteOrder.PutUint32(height[:], blockHeight)
|
||||
|
||||
err := chanDetails.Put(cChanClosedHeight, height[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now iterate through all the sessions of the channel to check
|
||||
// if any of them are closeable.
|
||||
return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error {
|
||||
sessDBIDInt, err := readBigSize(sessDBID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the session-ID index to get the real session ID.
|
||||
sID, err := getRealSessionID(
|
||||
sessIDIndexBkt, sessDBIDInt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isClosable, err := isSessionClosable(
|
||||
sessionsBkt, chanDetailsBkt, chanIDIndexBkt,
|
||||
sID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isClosable {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add session to "closableSessions" list and add the
|
||||
// block height that this last channel was closed in.
|
||||
// This will be used in future to determine when we
|
||||
// should delete the session.
|
||||
var height [4]byte
|
||||
byteOrder.PutUint32(height[:], blockHeight)
|
||||
err = closableSessBkt.Put(sessDBID, height[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
closableSessions = append(closableSessions, *sID)
|
||||
|
||||
return nil
|
||||
})
|
||||
}, func() {
|
||||
closableSessions = nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return closableSessions, nil
|
||||
}
|
||||
|
||||
// isSessionClosable returns true if a session is considered closable. A session
|
||||
// is considered closable only if all the following points are true:
|
||||
// 1) It has no un-acked updates.
|
||||
// 2) It is exhausted (ie it can't accept any more updates)
|
||||
// 3) All the channels that it has acked updates for are closed.
|
||||
func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
id *SessionID) (bool, error) {
|
||||
|
||||
sessBkt := sessionsBkt.NestedReadBucket(id[:])
|
||||
if sessBkt == nil {
|
||||
return false, ErrSessionNotFound
|
||||
}
|
||||
|
||||
commitsBkt := sessBkt.NestedReadBucket(cSessionCommits)
|
||||
if commitsBkt == nil {
|
||||
// If the session has no cSessionCommits bucket then we can be
|
||||
// sure that no updates have ever been committed to the session
|
||||
// and so it is not yet exhausted.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If the session has any un-acked updates, then it is not yet closable.
|
||||
err := commitsBkt.ForEach(func(_, _ []byte) error {
|
||||
return errSessionHasUnackedUpdates
|
||||
})
|
||||
if errors.Is(err, errSessionHasUnackedUpdates) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
session, err := getClientSessionBody(sessionsBkt, id[:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// We have already checked that the session has no more committed
|
||||
// updates. So now we can check if the session is exhausted.
|
||||
if session.SeqNum < session.Policy.MaxUpdates {
|
||||
// If the session is not yet exhausted, it is not yet closable.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If the session has no acked-updates, then something is wrong since
|
||||
// the above check ensures that this session has been exhausted meaning
|
||||
// that it should have MaxUpdates acked updates.
|
||||
ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex)
|
||||
if ackedRangeBkt == nil {
|
||||
return false, fmt.Errorf("no acked-updates found for "+
|
||||
"exhausted session %s", id)
|
||||
}
|
||||
|
||||
// Iterate over each of the channels that the session has acked-updates
|
||||
// for. If any of those channels are not closed, then the session is
|
||||
// not yet closable.
|
||||
err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error {
|
||||
dbChanIDInt, err := readBigSize(dbChanID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the channel details bucket for the channel.
|
||||
chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:])
|
||||
if chanDetails == nil {
|
||||
return fmt.Errorf("no channel details found for "+
|
||||
"channel %s referenced by session %s", chanID,
|
||||
id)
|
||||
}
|
||||
|
||||
// If a closed height has been set, then the channel is closed.
|
||||
closedHeight := chanDetails.Get(cChanClosedHeight)
|
||||
if len(closedHeight) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, the channel is not yet closed meaning that the
|
||||
// session is not yet closable. We break the ForEach by
|
||||
// returning an error to indicate this.
|
||||
return errSessionHasOpenChannels
|
||||
})
|
||||
if errors.Is(err, errSessionHasOpenChannels) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (c *ClientDB) CommitUpdate(id *SessionID,
|
||||
|
@ -1410,7 +1935,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
|||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
@ -1494,6 +2019,23 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
|||
return err
|
||||
}
|
||||
|
||||
dbSessionID, _, err := getDBSessionID(sessions, *id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(
|
||||
committedUpdate.BackupID.ChanID[:],
|
||||
)
|
||||
if chanDetails == nil {
|
||||
return ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
err = putChannelToSessionMapping(chanDetails, dbSessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the range index for the given session-channel pair.
|
||||
index, err := c.getRangeIndex(tx, *id, chanID)
|
||||
if err != nil {
|
||||
|
@ -1504,6 +2046,26 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
|||
}, func() {})
|
||||
}
|
||||
|
||||
// putChannelToSessionMapping adds the given session ID to a channel's
|
||||
// cChanSessions bucket.
|
||||
func putChannelToSessionMapping(chanDetails kvdb.RwBucket,
|
||||
dbSessID uint64) error {
|
||||
|
||||
chanSessIDsBkt, err := chanDetails.CreateBucketIfNotExists(
|
||||
cChanSessions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b, err := writeBigSize(dbSessID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return chanSessIDsBkt.Put(b, []byte{1})
|
||||
}
|
||||
|
||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||
// bucket corresponding to the serialized session id. This does not deserialize
|
||||
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
|
||||
|
@ -1882,6 +2444,68 @@ func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID lnwire.ChannelID) (uint64,
|
|||
return id, idBytes, nil
|
||||
}
|
||||
|
||||
// getDBSessionID returns the db-assigned session ID for the given real session
|
||||
// ID. It returns both the uint64 and byte representation.
|
||||
func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64,
|
||||
[]byte, error) {
|
||||
|
||||
sessionBkt := sessionsBkt.NestedReadBucket(sessionID[:])
|
||||
if sessionBkt == nil {
|
||||
return 0, nil, ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
idBytes := sessionBkt.Get(cSessionDBID)
|
||||
if len(idBytes) == 0 {
|
||||
return 0, nil, fmt.Errorf("no db-assigned ID found for "+
|
||||
"session ID %s", sessionID)
|
||||
}
|
||||
|
||||
id, err := readBigSize(idBytes)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return id, idBytes, nil
|
||||
}
|
||||
|
||||
func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID,
|
||||
error) {
|
||||
|
||||
dbIDBytes, err := writeBigSize(dbID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessIDBytes := sessIDIndexBkt.Get(dbIDBytes)
|
||||
if len(sessIDBytes) != SessionIDSize {
|
||||
return nil, fmt.Errorf("session ID not found")
|
||||
}
|
||||
|
||||
var sessID SessionID
|
||||
copy(sessID[:], sessIDBytes)
|
||||
|
||||
return &sessID, nil
|
||||
}
|
||||
|
||||
func getRealChannelID(chanIDIndexBkt kvdb.RBucket,
|
||||
dbID uint64) (*lnwire.ChannelID, error) {
|
||||
|
||||
dbIDBytes, err := writeBigSize(dbID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chanIDBytes := chanIDIndexBkt.Get(dbIDBytes)
|
||||
if len(chanIDBytes) != 32 { //nolint:gomnd
|
||||
return nil, fmt.Errorf("channel ID not found")
|
||||
}
|
||||
|
||||
var chanIDS lnwire.ChannelID
|
||||
copy(chanIDS[:], chanIDBytes)
|
||||
|
||||
return &chanIDS, nil
|
||||
}
|
||||
|
||||
// writeBigSize will encode the given uint64 as a BigSize byte slice.
|
||||
func writeBigSize(i uint64) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
|
|
|
@ -3,6 +3,7 @@ package wtdb_test
|
|||
import (
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
|
@ -17,6 +18,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
||||
// pseudoAddr is a fake network address to be used for testing purposes.
|
||||
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
|
@ -193,6 +196,35 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
|||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID,
|
||||
blockHeight uint32, expErr error) []wtdb.SessionID {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
closableSessions, err := h.db.MarkChannelClosed(id, blockHeight)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return closableSessions
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listClosableSessions(
|
||||
expErr error) map[wtdb.SessionID]uint32 {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
closableSessions, err := h.db.ListClosableSessions()
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return closableSessions
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.DeleteSession(id)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// newTower is a helper function that creates a new tower with a randomly
|
||||
// generated public key and inserts it into the client DB.
|
||||
func (h *clientDBHarness) newTower() *wtdb.Tower {
|
||||
|
@ -605,6 +637,118 @@ func testCommitUpdate(h *clientDBHarness) {
|
|||
}, nil)
|
||||
}
|
||||
|
||||
// testMarkChannelClosed asserts the behaviour of MarkChannelClosed.
|
||||
func testMarkChannelClosed(h *clientDBHarness) {
|
||||
tower := h.newTower()
|
||||
|
||||
// Create channel 1.
|
||||
chanID1 := randChannelID(h.t)
|
||||
|
||||
// Since we have not yet registered the channel, we expect an error
|
||||
// when attempting to mark it as closed.
|
||||
h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered)
|
||||
|
||||
// Now register the channel.
|
||||
h.registerChan(chanID1, nil, nil)
|
||||
|
||||
// Since there are still no sessions that would have updates for the
|
||||
// channel, marking it as closed now should succeed.
|
||||
h.markChannelClosed(chanID1, 1, nil)
|
||||
|
||||
// Register channel 2.
|
||||
chanID2 := randChannelID(h.t)
|
||||
h.registerChan(chanID2, nil, nil)
|
||||
|
||||
// Create session1 with MaxUpdates set to 5.
|
||||
session1 := h.randSession(h.t, tower.ID, 5)
|
||||
h.insertSession(session1, nil)
|
||||
|
||||
// Add an update for channel 2 in session 1 and ack it too.
|
||||
update := randCommittedUpdateForChannel(h.t, chanID2, 1)
|
||||
lastApplied := h.commitUpdate(&session1.ID, update, nil)
|
||||
require.Zero(h.t, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 1, 1, nil)
|
||||
|
||||
// Marking channel 2 now should not result in any closable sessions
|
||||
// since session 1 is not yet exhausted.
|
||||
sl := h.markChannelClosed(chanID2, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Create channel 3 and 4.
|
||||
chanID3 := randChannelID(h.t)
|
||||
h.registerChan(chanID3, nil, nil)
|
||||
|
||||
chanID4 := randChannelID(h.t)
|
||||
h.registerChan(chanID4, nil, nil)
|
||||
|
||||
// Add an update for channel 4 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID4, 2)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 1, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 2, 2, nil)
|
||||
|
||||
// Add an update for channel 3 in session 1. But dont ack it yet.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID2, 3)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 2, lastApplied)
|
||||
|
||||
// Mark channel 4 as closed & assert that session 1 is not seen as
|
||||
// closable since it still has committed updates.
|
||||
sl = h.markChannelClosed(chanID4, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Now ack the update we added above.
|
||||
h.ackUpdate(&session1.ID, 3, 3, nil)
|
||||
|
||||
// Mark channel 3 as closed & assert that session 1 is still not seen as
|
||||
// closable since it is not yet exhausted.
|
||||
sl = h.markChannelClosed(chanID3, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Create channel 5 and 6.
|
||||
chanID5 := randChannelID(h.t)
|
||||
h.registerChan(chanID5, nil, nil)
|
||||
|
||||
chanID6 := randChannelID(h.t)
|
||||
h.registerChan(chanID6, nil, nil)
|
||||
|
||||
// Add an update for channel 5 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID5, 4)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 3, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 4, 4, nil)
|
||||
|
||||
// Add an update for channel 6 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID6, 5)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 4, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 5, 5, nil)
|
||||
|
||||
// The session is no exhausted.
|
||||
// If we now close channel 5, session 1 should still not be closable
|
||||
// since it has an update for channel 6 which is still open.
|
||||
sl = h.markChannelClosed(chanID5, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
require.Empty(h.t, h.listClosableSessions(nil))
|
||||
|
||||
// Also check that attempting to delete the session will fail since it
|
||||
// is not yet considered closable.
|
||||
h.deleteSession(session1.ID, wtdb.ErrSessionNotClosable)
|
||||
|
||||
// Finally, if we close channel 6, session 1 _should_ be in the closable
|
||||
// list.
|
||||
sl = h.markChannelClosed(chanID6, 100, nil)
|
||||
require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID})
|
||||
slMap := h.listClosableSessions(nil)
|
||||
require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{
|
||||
session1.ID: 100,
|
||||
}, 0)
|
||||
|
||||
// Assert that we now can delete the session.
|
||||
h.deleteSession(session1.ID, nil)
|
||||
require.Empty(h.t, h.listClosableSessions(nil))
|
||||
}
|
||||
|
||||
// testAckUpdate asserts the behavior of AckUpdate.
|
||||
func testAckUpdate(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
@ -821,6 +965,10 @@ func TestClientDB(t *testing.T) {
|
|||
name: "ack update",
|
||||
run: testAckUpdate,
|
||||
},
|
||||
{
|
||||
name: "mark channel closed",
|
||||
run: testMarkChannelClosed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
|
@ -841,12 +989,32 @@ func TestClientDB(t *testing.T) {
|
|||
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
t.Helper()
|
||||
|
||||
chanID := randChannelID(t)
|
||||
|
||||
return randCommittedUpdateForChannel(t, chanID, seqNum)
|
||||
}
|
||||
|
||||
func randChannelID(t *testing.T) lnwire.ChannelID {
|
||||
t.Helper()
|
||||
|
||||
var chanID lnwire.ChannelID
|
||||
_, err := io.ReadFull(crand.Reader, chanID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
return chanID
|
||||
}
|
||||
|
||||
// randCommittedUpdateForChannel generates a random committed update for the
|
||||
// given channel ID.
|
||||
func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID,
|
||||
seqNum uint16) *wtdb.CommittedUpdate {
|
||||
|
||||
t.Helper()
|
||||
|
||||
var hint blob.BreachHint
|
||||
_, err = io.ReadFull(crand.Reader, hint[:])
|
||||
_, err := io.ReadFull(crand.Reader, hint[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
|
@ -865,3 +1033,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) randSession(t *testing.T,
|
||||
towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession {
|
||||
|
||||
t.Helper()
|
||||
|
||||
var id wtdb.SessionID
|
||||
rand.Read(id[:])
|
||||
|
||||
return &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: towerID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
},
|
||||
MaxUpdates: maxUpdates,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
KeyIndex: h.nextKeyIndex(towerID, blobType),
|
||||
},
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
|
@ -36,6 +38,8 @@ func UseLogger(logger btclog.Logger) {
|
|||
migration3.UseLogger(logger)
|
||||
migration4.UseLogger(logger)
|
||||
migration5.UseLogger(logger)
|
||||
migration6.UseLogger(logger)
|
||||
migration7.UseLogger(logger)
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
|
|
114
watchtower/wtdb/migration6/client_db.go
Normal file
114
watchtower/wtdb/migration6/client_db.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package migration6
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionDBID -> db-assigned-id
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAcks => seqnum -> encoded BackupID
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cSessionDBID is a key used in the cSessionBkt to store the
|
||||
// db-assigned-id of a session.
|
||||
cSessionDBID = []byte("client-session-db-id")
|
||||
|
||||
// cSessionIDIndexBkt is a top-level bucket storing:
|
||||
// db-assigned-id -> session-id
|
||||
cSessionIDIndexBkt = []byte("client-session-id-index")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
|
||||
// the ClientSession.
|
||||
cSessionBody = []byte("client-session-body")
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("db not initialized")
|
||||
|
||||
// ErrCorruptClientSession signals that the client session's on-disk
|
||||
// structure deviates from what is expected.
|
||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// MigrateSessionIDIndex adds a new session ID index to the tower client db.
|
||||
// This index is a mapping from db-assigned ID (a uint64 encoded using BigSize)
|
||||
// to real session ID (33 bytes). This mapping will allow us to persist session
|
||||
// pointers with fewer bytes in the future.
|
||||
func MigrateSessionIDIndex(tx kvdb.RwTx) error {
|
||||
log.Infof("Migrating the tower client db to add a new session ID " +
|
||||
"index which stores a mapping from db-assigned ID to real " +
|
||||
"session ID")
|
||||
|
||||
// Create a new top-level bucket for the index.
|
||||
indexBkt, err := tx.CreateTopLevelBucket(cSessionIDIndexBkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the existing top-level sessions bucket.
|
||||
sessionsBkt := tx.ReadWriteBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Iterate over the sessions bucket where each key is a session-ID.
|
||||
return sessionsBkt.ForEach(func(sessionID, _ []byte) error {
|
||||
// Ask the DB for a new, unique, id for the index bucket.
|
||||
nextSeq, err := indexBkt.NextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newIndex, err := writeBigSize(nextSeq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the new db-assigned-ID to real-session-ID pair to the
|
||||
// new index bucket.
|
||||
err = indexBkt.Put(newIndex, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the sub-bucket for this specific session ID.
|
||||
sessionBkt := sessionsBkt.NestedReadWriteBucket(sessionID)
|
||||
if sessionBkt == nil {
|
||||
return ErrCorruptClientSession
|
||||
}
|
||||
|
||||
// Here we ensure that the session bucket includes a session
|
||||
// body. The only reason we do this is so that we can simulate
|
||||
// a migration fail in a test to ensure that a migration fail
|
||||
// results in an untouched db.
|
||||
sessionBodyBytes := sessionBkt.Get(cSessionBody)
|
||||
if sessionBodyBytes == nil {
|
||||
return ErrCorruptClientSession
|
||||
}
|
||||
|
||||
// Add the db-assigned ID of the session to the session under
|
||||
// the cSessionDBID key.
|
||||
return sessionBkt.Put(cSessionDBID, newIndex)
|
||||
})
|
||||
}
|
||||
|
||||
// writeBigSize will encode the given uint64 as a BigSize byte slice.
|
||||
func writeBigSize(i uint64) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
err := tlv.WriteVarInt(&b, i, &[8]byte{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
147
watchtower/wtdb/migration6/client_db_test.go
Normal file
147
watchtower/wtdb/migration6/client_db_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package migration6
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/migtest"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
// pre is the expected data in the sessions bucket before the migration.
|
||||
pre = map[string]interface{}{
|
||||
sessionIDToString(100): map[string]interface{}{
|
||||
string(cSessionBody): string([]byte{1, 2, 3}),
|
||||
},
|
||||
sessionIDToString(222): map[string]interface{}{
|
||||
string(cSessionBody): string([]byte{4, 5, 6}),
|
||||
},
|
||||
}
|
||||
|
||||
// preFailCorruptDB should fail the migration due to no session body
|
||||
// being found for a given session ID.
|
||||
preFailCorruptDB = map[string]interface{}{
|
||||
sessionIDToString(100): "",
|
||||
}
|
||||
|
||||
// post is the expected session index after migration.
|
||||
postIndex = map[string]interface{}{
|
||||
indexToString(1): sessionIDToString(100),
|
||||
indexToString(2): sessionIDToString(222),
|
||||
}
|
||||
|
||||
// postSessions is the expected data in the sessions bucket after the
|
||||
// migration.
|
||||
postSessions = map[string]interface{}{
|
||||
sessionIDToString(100): map[string]interface{}{
|
||||
string(cSessionBody): string([]byte{1, 2, 3}),
|
||||
string(cSessionDBID): indexToString(1),
|
||||
},
|
||||
sessionIDToString(222): map[string]interface{}{
|
||||
string(cSessionBody): string([]byte{4, 5, 6}),
|
||||
string(cSessionDBID): indexToString(2),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestMigrateSessionIDIndex tests that the MigrateSessionIDIndex function
|
||||
// correctly adds a new session-id index to the DB and also correctly updates
|
||||
// the existing session bucket.
|
||||
func TestMigrateSessionIDIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
shouldFail bool
|
||||
pre map[string]interface{}
|
||||
postSessions map[string]interface{}
|
||||
postIndex map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "migration ok",
|
||||
shouldFail: false,
|
||||
pre: pre,
|
||||
postSessions: postSessions,
|
||||
postIndex: postIndex,
|
||||
},
|
||||
{
|
||||
name: "fail due to corrupt db",
|
||||
shouldFail: true,
|
||||
pre: preFailCorruptDB,
|
||||
},
|
||||
{
|
||||
name: "no channel details",
|
||||
shouldFail: false,
|
||||
pre: nil,
|
||||
postSessions: nil,
|
||||
postIndex: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Before the migration we have a details bucket.
|
||||
before := func(tx kvdb.RwTx) error {
|
||||
return migtest.RestoreDB(
|
||||
tx, cSessionBkt, test.pre,
|
||||
)
|
||||
}
|
||||
|
||||
// After the migration, we should have an untouched
|
||||
// summary bucket and a new index bucket.
|
||||
after := func(tx kvdb.RwTx) error {
|
||||
// If the migration fails, the details bucket
|
||||
// should be untouched.
|
||||
if test.shouldFail {
|
||||
if err := migtest.VerifyDB(
|
||||
tx, cSessionBkt, test.pre,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Else, we expect an updated summary bucket
|
||||
// and a new index bucket.
|
||||
err := migtest.VerifyDB(
|
||||
tx, cSessionBkt, test.postSessions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migtest.VerifyDB(
|
||||
tx, cSessionIDIndexBkt, test.postIndex,
|
||||
)
|
||||
}
|
||||
|
||||
migtest.ApplyMigration(
|
||||
t, before, after, MigrateSessionIDIndex,
|
||||
test.shouldFail,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func indexToString(id uint64) string {
|
||||
var newIndex bytes.Buffer
|
||||
err := tlv.WriteVarInt(&newIndex, id, &[8]byte{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return newIndex.String()
|
||||
}
|
||||
|
||||
func sessionIDToString(id uint64) string {
|
||||
var chanID SessionID
|
||||
byteOrder.PutUint64(chanID[:], id)
|
||||
return chanID.String()
|
||||
}
|
17
watchtower/wtdb/migration6/codec.go
Normal file
17
watchtower/wtdb/migration6/codec.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package migration6
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
|
||||
const SessionIDSize = 33
|
||||
|
||||
// SessionID is created from the remote public key of a client, and serves as a
|
||||
// unique identifier and authentication for sending state updates.
|
||||
type SessionID [SessionIDSize]byte
|
||||
|
||||
// String returns a hex encoding of the session id.
|
||||
func (s SessionID) String() string {
|
||||
return hex.EncodeToString(s[:])
|
||||
}
|
14
watchtower/wtdb/migration6/log.go
Normal file
14
watchtower/wtdb/migration6/log.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package migration6
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized as disabled. This means the package will
|
||||
// not perform any logging by default until a logger is set.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
202
watchtower/wtdb/migration7/client_db.go
Normal file
202
watchtower/wtdb/migration7/client_db.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package migration7
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionDBID -> db-assigned-id
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAckRangeIndex => chan-id => acked-index-range
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cChanDetailsBkt is a top-level bucket storing:
|
||||
// channel-id => cChannelSummary -> encoded ClientChanSummary.
|
||||
// => cChanDBID -> db-assigned-id
|
||||
// => cChanSessions => db-session-id -> 1
|
||||
cChanDetailsBkt = []byte("client-channel-detail-bucket")
|
||||
|
||||
// cChannelSummary is a sub-bucket of cChanDetailsBkt which stores the
|
||||
// encoded body of ClientChanSummary.
|
||||
cChannelSummary = []byte("client-channel-summary")
|
||||
|
||||
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
|
||||
// session-id -> 1
|
||||
cChanSessions = []byte("client-channel-sessions")
|
||||
|
||||
// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing:
|
||||
// chan-id => start -> end
|
||||
cSessionAckRangeIndex = []byte("client-session-ack-range-index")
|
||||
|
||||
// cSessionDBID is a key used in the cSessionBkt to store the
|
||||
// db-assigned-d of a session.
|
||||
cSessionDBID = []byte("client-session-db-id")
|
||||
|
||||
// cChanIDIndexBkt is a top-level bucket storing:
|
||||
// db-assigned-id -> channel-ID
|
||||
cChanIDIndexBkt = []byte("client-channel-id-index")
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("db not initialized")
|
||||
|
||||
// ErrCorruptClientSession signals that the client session's on-disk
|
||||
// structure deviates from what is expected.
|
||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||
|
||||
// byteOrder is the default endianness used when serializing integers.
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// MigrateChannelToSessionIndex migrates the tower client DB to add an index
|
||||
// from channel-to-session. This will make it easier in future to check which
|
||||
// sessions have updates for which channels.
|
||||
func MigrateChannelToSessionIndex(tx kvdb.RwTx) error {
|
||||
log.Infof("Migrating the tower client DB to build a new " +
|
||||
"channel-to-session index")
|
||||
|
||||
sessionsBkt := tx.ReadBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDsBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// First gather all the new channel-to-session pairs that we want to
|
||||
// add.
|
||||
index, err := collectIndex(sessionsBkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then persist those pairs to the db.
|
||||
return persistIndex(chanDetailsBkt, chanIDsBkt, index)
|
||||
}
|
||||
|
||||
// collectIndex iterates through all the sessions and uses the keys in the
|
||||
// cSessionAckRangeIndex bucket to collect all the channels that the session
|
||||
// has updates for. The function returns a map from channel ID to session ID
|
||||
// (using the db-assigned IDs for both).
|
||||
func collectIndex(sessionsBkt kvdb.RBucket) (map[uint64]map[uint64]bool,
|
||||
error) {
|
||||
|
||||
index := make(map[uint64]map[uint64]bool)
|
||||
err := sessionsBkt.ForEach(func(sessID, _ []byte) error {
|
||||
sessionBkt := sessionsBkt.NestedReadBucket(sessID)
|
||||
if sessionBkt == nil {
|
||||
return ErrCorruptClientSession
|
||||
}
|
||||
|
||||
ackedRanges := sessionBkt.NestedReadBucket(
|
||||
cSessionAckRangeIndex,
|
||||
)
|
||||
if ackedRanges == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessDBIDBytes := sessionBkt.Get(cSessionDBID)
|
||||
if sessDBIDBytes == nil {
|
||||
return ErrCorruptClientSession
|
||||
}
|
||||
|
||||
sessDBID, err := readUint64(sessDBIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ackedRanges.ForEach(func(dbChanIDBytes, _ []byte) error {
|
||||
dbChanID, err := readUint64(dbChanIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := index[dbChanID]; !ok {
|
||||
index[dbChanID] = make(map[uint64]bool)
|
||||
}
|
||||
|
||||
index[dbChanID][sessDBID] = true
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// persistIndex adds the channel-to-session mapping in each channel's details
|
||||
// bucket.
|
||||
func persistIndex(chanDetailsBkt kvdb.RwBucket, chanIDsBkt kvdb.RBucket,
|
||||
index map[uint64]map[uint64]bool) error {
|
||||
|
||||
for dbChanID, sessIDs := range index {
|
||||
dbChanIDBytes, err := writeUint64(dbChanID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
realChanID := chanIDsBkt.Get(dbChanIDBytes)
|
||||
|
||||
chanBkt := chanDetailsBkt.NestedReadWriteBucket(realChanID)
|
||||
if chanBkt == nil {
|
||||
return fmt.Errorf("channel not found")
|
||||
}
|
||||
|
||||
sessIDsBkt, err := chanBkt.CreateBucket(cChanSessions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for id := range sessIDs {
|
||||
sessID, err := writeUint64(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = sessIDsBkt.Put(sessID, []byte{1})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeUint64(i uint64) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
err := tlv.WriteVarInt(&b, i, &[8]byte{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
func readUint64(b []byte) (uint64, error) {
|
||||
r := bytes.NewReader(b)
|
||||
i, err := tlv.ReadVarInt(r, &[8]byte{})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
191
watchtower/wtdb/migration7/client_db_test.go
Normal file
191
watchtower/wtdb/migration7/client_db_test.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
package migration7
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/migtest"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
// preDetails is the expected data of the channel details bucket before
|
||||
// the migration.
|
||||
preDetails = map[string]interface{}{
|
||||
channelIDString(100): map[string]interface{}{
|
||||
string(cChannelSummary): string([]byte{1, 2, 3}),
|
||||
},
|
||||
channelIDString(222): map[string]interface{}{
|
||||
string(cChannelSummary): string([]byte{4, 5, 6}),
|
||||
},
|
||||
}
|
||||
|
||||
// preFailCorruptDB should fail the migration due to no channel summary
|
||||
// being found for a given channel ID.
|
||||
preFailCorruptDB = map[string]interface{}{
|
||||
channelIDString(30): map[string]interface{}{},
|
||||
}
|
||||
|
||||
// channelIDIndex is the data in the channelID index that is used to
|
||||
// find the mapping between the db-assigned channel ID and the real
|
||||
// channel ID.
|
||||
channelIDIndex = map[string]interface{}{
|
||||
uint64ToStr(10): channelIDString(100),
|
||||
uint64ToStr(20): channelIDString(222),
|
||||
}
|
||||
|
||||
// sessions is the expected data in the sessions bucket before and
|
||||
// after the migration.
|
||||
sessions = map[string]interface{}{
|
||||
sessionIDString("1"): map[string]interface{}{
|
||||
string(cSessionAckRangeIndex): map[string]interface{}{
|
||||
uint64ToStr(10): map[string]interface{}{
|
||||
uint64ToStr(30): uint64ToStr(32),
|
||||
uint64ToStr(34): uint64ToStr(34),
|
||||
},
|
||||
uint64ToStr(20): map[string]interface{}{
|
||||
uint64ToStr(30): uint64ToStr(30),
|
||||
},
|
||||
},
|
||||
string(cSessionDBID): uint64ToStr(66),
|
||||
},
|
||||
sessionIDString("2"): map[string]interface{}{
|
||||
string(cSessionAckRangeIndex): map[string]interface{}{
|
||||
uint64ToStr(10): map[string]interface{}{
|
||||
uint64ToStr(33): uint64ToStr(33),
|
||||
},
|
||||
},
|
||||
string(cSessionDBID): uint64ToStr(77),
|
||||
},
|
||||
}
|
||||
|
||||
// postDetails is the expected data in the channel details bucket after
|
||||
// the migration.
|
||||
postDetails = map[string]interface{}{
|
||||
channelIDString(100): map[string]interface{}{
|
||||
string(cChannelSummary): string([]byte{1, 2, 3}),
|
||||
string(cChanSessions): map[string]interface{}{
|
||||
uint64ToStr(66): string([]byte{1}),
|
||||
uint64ToStr(77): string([]byte{1}),
|
||||
},
|
||||
},
|
||||
channelIDString(222): map[string]interface{}{
|
||||
string(cChannelSummary): string([]byte{4, 5, 6}),
|
||||
string(cChanSessions): map[string]interface{}{
|
||||
uint64ToStr(66): string([]byte{1}),
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex
|
||||
// function correctly builds the new channel-to-sessionID index to the tower
|
||||
// client DB.
|
||||
func TestMigrateChannelToSessionIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
shouldFail bool
|
||||
preDetails map[string]interface{}
|
||||
preSessions map[string]interface{}
|
||||
preChanIndex map[string]interface{}
|
||||
postDetails map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "migration ok",
|
||||
shouldFail: false,
|
||||
preDetails: preDetails,
|
||||
preSessions: sessions,
|
||||
preChanIndex: channelIDIndex,
|
||||
postDetails: postDetails,
|
||||
},
|
||||
{
|
||||
name: "fail due to corrupt db",
|
||||
shouldFail: true,
|
||||
preDetails: preFailCorruptDB,
|
||||
preSessions: sessions,
|
||||
},
|
||||
{
|
||||
name: "no sessions",
|
||||
shouldFail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Before the migration we have a channel details
|
||||
// bucket, a sessions bucket, a session ID index bucket
|
||||
// and a channel ID index bucket.
|
||||
before := func(tx kvdb.RwTx) error {
|
||||
err := migtest.RestoreDB(
|
||||
tx, cChanDetailsBkt, test.preDetails,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migtest.RestoreDB(
|
||||
tx, cSessionBkt, test.preSessions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migtest.RestoreDB(
|
||||
tx, cChanIDIndexBkt, test.preChanIndex,
|
||||
)
|
||||
}
|
||||
|
||||
after := func(tx kvdb.RwTx) error {
|
||||
// If the migration fails, the details bucket
|
||||
// should be untouched.
|
||||
if test.shouldFail {
|
||||
if err := migtest.VerifyDB(
|
||||
tx, cChanDetailsBkt,
|
||||
test.preDetails,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Else, we expect an updated details bucket
|
||||
// and a new index bucket.
|
||||
return migtest.VerifyDB(
|
||||
tx, cChanDetailsBkt, test.postDetails,
|
||||
)
|
||||
}
|
||||
|
||||
migtest.ApplyMigration(
|
||||
t, before, after, MigrateChannelToSessionIndex,
|
||||
test.shouldFail,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func sessionIDString(id string) string {
|
||||
var sessID SessionID
|
||||
copy(sessID[:], id)
|
||||
return sessID.String()
|
||||
}
|
||||
|
||||
func channelIDString(id uint64) string {
|
||||
var chanID ChannelID
|
||||
byteOrder.PutUint64(chanID[:], id)
|
||||
return string(chanID[:])
|
||||
}
|
||||
|
||||
func uint64ToStr(id uint64) string {
|
||||
b, err := writeUint64(id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
29
watchtower/wtdb/migration7/codec.go
Normal file
29
watchtower/wtdb/migration7/codec.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package migration7
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
|
||||
const SessionIDSize = 33
|
||||
|
||||
// SessionID is created from the remote public key of a client, and serves as a
|
||||
// unique identifier and authentication for sending state updates.
|
||||
type SessionID [SessionIDSize]byte
|
||||
|
||||
// String returns a hex encoding of the session id.
|
||||
func (s SessionID) String() string {
|
||||
return hex.EncodeToString(s[:])
|
||||
}
|
||||
|
||||
// ChannelID is a series of 32-bytes that uniquely identifies all channels
|
||||
// within the network. The ChannelID is computed using the outpoint of the
|
||||
// funding transaction (the txid, and output index). Given a funding output the
|
||||
// ChannelID can be calculated by XOR'ing the big-endian serialization of the
|
||||
// txid and the big-endian serialization of the output index, truncated to
|
||||
// 2 bytes.
|
||||
type ChannelID [32]byte
|
||||
|
||||
// String returns the string representation of the ChannelID. This is just the
|
||||
// hex string encoding of the ChannelID itself.
|
||||
func (c ChannelID) String() string {
|
||||
return hex.EncodeToString(c[:])
|
||||
}
|
14
watchtower/wtdb/migration7/log.go
Normal file
14
watchtower/wtdb/migration7/log.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package migration7
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized as disabled. This means the package will
|
||||
// not perform any logging by default until a logger is set.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
|
||||
)
|
||||
|
||||
// txMigration is a function which takes a prior outdated version of the
|
||||
|
@ -59,6 +61,12 @@ var clientDBVersions = []version{
|
|||
{
|
||||
txMigration: migration5.MigrateCompleteTowerToSessionIndex,
|
||||
},
|
||||
{
|
||||
txMigration: migration6.MigrateSessionIDIndex,
|
||||
},
|
||||
{
|
||||
txMigration: migration7.MigrateChannelToSessionIndex,
|
||||
},
|
||||
}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
|
|
|
@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
|
|||
|
||||
type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
|
||||
|
||||
type channel struct {
|
||||
summary *wtdb.ClientChanSummary
|
||||
closedHeight uint32
|
||||
sessions map[wtdb.SessionID]bool
|
||||
}
|
||||
|
||||
// ClientDB is a mock, in-memory database or testing the watchtower client
|
||||
// behavior.
|
||||
type ClientDB struct {
|
||||
nextTowerID uint64 // to be used atomically
|
||||
|
||||
mu sync.Mutex
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
channels map[lnwire.ChannelID]*channel
|
||||
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
||||
ackedUpdates rangeIndexArrayMap
|
||||
persistedAckedUpdates rangeIndexKVStore
|
||||
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
closableSessions map[wtdb.SessionID]uint32
|
||||
|
||||
nextIndex uint32
|
||||
indexes map[keyIndexKey]uint32
|
||||
|
@ -47,9 +54,7 @@ type ClientDB struct {
|
|||
// NewClientDB initializes a new mock ClientDB.
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
summaries: make(
|
||||
map[lnwire.ChannelID]wtdb.ClientChanSummary,
|
||||
),
|
||||
channels: make(map[lnwire.ChannelID]*channel),
|
||||
activeSessions: make(
|
||||
map[wtdb.SessionID]wtdb.ClientSession,
|
||||
),
|
||||
|
@ -58,10 +63,11 @@ func NewClientDB() *ClientDB {
|
|||
committedUpdates: make(
|
||||
map[wtdb.SessionID][]wtdb.CommittedUpdate,
|
||||
),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
closableSessions: make(map[wtdb.SessionID]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
|
|||
continue
|
||||
}
|
||||
|
||||
// Add sessionID to channel.
|
||||
channel, ok := m.channels[update.BackupID.ChanID]
|
||||
if !ok {
|
||||
return wtdb.ErrChannelNotRegistered
|
||||
}
|
||||
channel.sessions[*id] = true
|
||||
|
||||
// Remove the committed update from disk and mark the update as
|
||||
// acked. The tower last applied value is also recorded to send
|
||||
// along with the next update.
|
||||
|
@ -538,22 +551,192 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
|
|||
return wtdb.ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
// ListClosableSessions fetches and returns the IDs for all sessions marked as
|
||||
// closable.
|
||||
func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions))
|
||||
for id, height := range m.closableSessions {
|
||||
cs[id] = height
|
||||
}
|
||||
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries.
|
||||
// channel summaries. Only the channels that have not yet been marked as closed
|
||||
// will be loaded.
|
||||
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
|
||||
for chanID, summary := range m.summaries {
|
||||
for chanID, channel := range m.channels {
|
||||
// Don't load the channel if it has been marked as closed.
|
||||
if channel.closedHeight > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(summary.SweepPkScript),
|
||||
SweepPkScript: cloneBytes(
|
||||
channel.summary.SweepPkScript,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting
|
||||
// its closed-height as the given block height. It returns a list of
|
||||
// session IDs for sessions that are now considered closable due to the
|
||||
// close of this channel.
|
||||
func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
|
||||
blockHeight uint32) ([]wtdb.SessionID, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
channel, ok := m.channels[chanID]
|
||||
if !ok {
|
||||
return nil, wtdb.ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
// If there are no sessions for this channel, the channel details can be
|
||||
// deleted.
|
||||
if len(channel.sessions) == 0 {
|
||||
delete(m.channels, chanID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Mark the channel as closed.
|
||||
channel.closedHeight = blockHeight
|
||||
|
||||
// Now iterate through all the sessions of the channel to check if any
|
||||
// of them are closeable.
|
||||
var closableSessions []wtdb.SessionID
|
||||
for sessID := range channel.sessions {
|
||||
isClosable, err := m.isSessionClosable(sessID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isClosable {
|
||||
continue
|
||||
}
|
||||
|
||||
closableSessions = append(closableSessions, sessID)
|
||||
|
||||
// Add session to "closableSessions" list and add the block
|
||||
// height that this last channel was closed in. This will be
|
||||
// used in future to determine when we should delete the
|
||||
// session.
|
||||
m.closableSessions[sessID] = blockHeight
|
||||
}
|
||||
|
||||
return closableSessions, nil
|
||||
}
|
||||
|
||||
// isSessionClosable returns true if a session is considered closable. A session
|
||||
// is considered closable only if:
|
||||
// 1) It has no un-acked updates
|
||||
// 2) It is exhausted (ie it cant accept any more updates)
|
||||
// 3) All the channels that it has acked-updates for are closed.
|
||||
func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) {
|
||||
// The session is not closable if it has un-acked updates.
|
||||
if len(m.committedUpdates[id]) > 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
sess, ok := m.activeSessions[id]
|
||||
if !ok {
|
||||
return false, wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// The session is not closable if it is not yet exhausted.
|
||||
if sess.SeqNum != sess.Policy.MaxUpdates {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Iterate over each of the channels that the session has acked-updates
|
||||
// for. If any of those channels are not closed, then the session is
|
||||
// not yet closable.
|
||||
for chanID := range m.ackedUpdates[id] {
|
||||
channel, ok := m.channels[chanID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Channel is not yet closed, and so we can not yet delete the
|
||||
// session.
|
||||
if channel.closedHeight == 0 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetClientSession loads the ClientSession with the given ID from the DB.
|
||||
func (m *ClientDB) GetClientSession(id wtdb.SessionID,
|
||||
opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) {
|
||||
|
||||
cfg := wtdb.NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
o(cfg)
|
||||
}
|
||||
|
||||
session, ok := m.activeSessions[id]
|
||||
if !ok {
|
||||
return nil, wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
if cfg.PerMaxHeight != nil {
|
||||
for chanID, index := range m.ackedUpdates[session.ID] {
|
||||
cfg.PerMaxHeight(&session, chanID, index.MaxHeight())
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.PerCommittedUpdate != nil {
|
||||
for _, update := range m.committedUpdates[session.ID] {
|
||||
update := update
|
||||
cfg.PerCommittedUpdate(&session, &update)
|
||||
}
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// DeleteSession can be called when a session should be deleted from the DB.
|
||||
// All references to the session will also be deleted from the DB. Note that a
|
||||
// session will only be deleted if it is considered closable.
|
||||
func (m *ClientDB) DeleteSession(id wtdb.SessionID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.closableSessions[id]
|
||||
if !ok {
|
||||
return wtdb.ErrSessionNotClosable
|
||||
}
|
||||
|
||||
// For each of the channels, delete the session ID entry.
|
||||
for chanID := range m.ackedUpdates[id] {
|
||||
c, ok := m.channels[chanID]
|
||||
if !ok {
|
||||
return wtdb.ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
delete(c.sessions, id)
|
||||
}
|
||||
|
||||
delete(m.closableSessions, id)
|
||||
delete(m.activeSessions, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterChannel registers a channel for use within the client database. For
|
||||
// now, all that is stored in the channel summary is the sweep pkscript that
|
||||
// we'd like any tower sweeps to pay into. In the future, this will be extended
|
||||
|
@ -565,12 +748,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.summaries[chanID]; ok {
|
||||
if _, ok := m.channels[chanID]; ok {
|
||||
return wtdb.ErrChannelAlreadyRegistered
|
||||
}
|
||||
|
||||
m.summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(sweepPkScript),
|
||||
m.channels[chanID] = &channel{
|
||||
summary: &wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(sweepPkScript),
|
||||
},
|
||||
sessions: make(map[wtdb.SessionID]bool),
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Reference in a new issue