Merge pull request #7981 from ellemouton/handleRogueUpdates

watchtower: handle rogue updates
This commit is contained in:
Olaoluwa Osuntokun 2023-09-18 13:56:52 -07:00 committed by GitHub
commit 9f4a8836db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 538 additions and 1046 deletions

View file

@ -71,10 +71,14 @@ fails](https://github.com/lightningnetwork/lnd/pull/7876).
retried](https://github.com/lightningnetwork/lnd/pull/7927) with an
exponential back off.
* In the watchtower client, we [now explicitly
handle](https://github.com/lightningnetwork/lnd/pull/7981) the scenario where
a channel is closed while we still have an in-memory update for it.
* `lnd` [now properly handles a case where an erroneous force close attempt
would impeded start up](https://github.com/lightningnetwork/lnd/pull/7985).
# New Features
## Functional Enhancements

View file

@ -390,6 +390,10 @@ func constructFunctionalOptions(includeSessions,
return opts, ackCounts, committedUpdateCounts
}
perNumRogueUpdates := func(s *wtdb.ClientSession, numUpdates uint16) {
ackCounts[s.ID] += numUpdates
}
perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID,
numUpdates uint16) {
@ -405,6 +409,7 @@ func constructFunctionalOptions(includeSessions,
opts = []wtdb.ClientSessionListOption{
wtdb.WithPerNumAckedUpdates(perNumAckedUpdates),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
wtdb.WithPerRogueUpdateCount(perNumRogueUpdates),
}
if excludeExhaustedSessions {

View file

@ -977,6 +977,19 @@ func (c *TowerClient) handleClosableSessions(
// and handle it.
c.closableSessionQueue.Pop()
// Stop the session and remove it from the
// in-memory set.
err := c.activeSessions.StopAndRemove(
item.sessionID,
)
if err != nil {
c.log.Errorf("could not remove "+
"session(%s) from in-memory "+
"set: %v", item.sessionID, err)
return
}
// Fetch the session from the DB so that we can
// extract the Tower info.
sess, err := c.cfg.DB.GetClientSession(

View file

@ -21,6 +21,7 @@ import (
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
@ -72,7 +73,7 @@ var (
addrScript, _ = txscript.PayToAddrScript(addr)
waitTime = 5 * time.Second
waitTime = 15 * time.Second
defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
@ -398,7 +399,7 @@ type testHarness struct {
cfg harnessCfg
signer *wtmock.MockSigner
capacity lnwire.MilliSatoshi
clientDB *wtmock.ClientDB
clientDB *wtdb.ClientDB
clientCfg *wtclient.Config
client wtclient.Client
server *serverHarness
@ -426,10 +427,26 @@ type harnessCfg struct {
noServerStart bool
}
func newClientDB(t *testing.T) *wtdb.ClientDB {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}
// Construct the ClientDB.
dir := t.TempDir()
bdb, err := wtdb.NewBoltBackendCreator(true, dir, "wtclient.db")(dbCfg)
require.NoError(t, err)
clientDB, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)
return clientDB
}
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
signer := wtmock.NewMockSigner()
mockNet := newMockNet()
clientDB := wtmock.NewClientDB()
clientDB := newClientDB(t)
server := newServerHarness(
t, mockNet, towerAddrStr, func(serverCfg *wtserver.Config) {
@ -509,6 +526,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
h.startClient()
t.Cleanup(func() {
require.NoError(t, h.client.Stop())
require.NoError(t, h.clientDB.Close())
})
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
@ -1342,7 +1360,7 @@ var clientTests = []clientTest{
// Wait for all the updates to be populated in the
// server's database.
h.server.waitForUpdates(hints, 3*time.Second)
h.server.waitForUpdates(hints, waitTime)
},
},
{
@ -2053,7 +2071,7 @@ var clientTests = []clientTest{
// Now stop the client and reset its database.
require.NoError(h.t, h.client.Stop())
db := wtmock.NewClientDB()
db := newClientDB(h.t)
h.clientDB = db
h.clientCfg.DB = db
@ -2398,6 +2416,140 @@ var clientTests = []clientTest{
server2.waitForUpdates(hints[numUpdates/2:], waitTime)
},
},
{
// This test shows that if a channel is closed while an update
// for that channel still exists in an in-memory queue
// somewhere then it is handled correctly by treating it as a
// rogue update.
name: "channel closed while update is un-acked",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 10
chanIDInt = 0
)
h.sendUpdatesOn = true
// Advance the channel with a few updates.
hints := h.advanceChannelN(chanIDInt, numUpdates)
// Backup a few these updates and wait for them to
// arrive at the server. Note that we back up enough
// updates to saturate the session so that the session
// is considered closable when the channel is deleted.
h.backupStates(chanIDInt, 0, numUpdates/2, nil)
h.server.waitForUpdates(hints[:numUpdates/2], waitTime)
// Now, restart the server in a state where it will not
// ack updates. This will allow us to wait for an
// update to be un-acked and persisted.
h.server.restart(func(cfg *wtserver.Config) {
cfg.NoAckUpdates = true
})
// Backup a few more of the update. These should remain
// in the client as un-acked.
h.backupStates(
chanIDInt, numUpdates/2, numUpdates-1, nil,
)
// Wait for the tasks to be bound to sessions.
fetchSessions := h.clientDB.FetchSessionCommittedUpdates
err := wait.Predicate(func() bool {
sessions, err := h.clientDB.ListClientSessions(
nil,
)
require.NoError(h.t, err)
var updates []wtdb.CommittedUpdate
for id := range sessions {
updates, err = fetchSessions(&id)
require.NoError(h.t, err)
if len(updates) != numUpdates-1 {
return true
}
}
return false
}, waitTime)
require.NoError(h.t, err)
// Now we close this channel while the update for it has
// not yet been acked.
h.closeChannel(chanIDInt, 1)
// Closable sessions should now be one.
err = wait.Predicate(func() bool {
cs, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
return len(cs) == 1
}, waitTime)
require.NoError(h.t, err)
// Now, restart the server and allow it to ack updates
// again.
h.server.restart(func(cfg *wtserver.Config) {
cfg.NoAckUpdates = false
})
// Mine a few blocks so that the session close range is
// surpassed.
h.mine(3)
// Wait for there to be no more closable sessions on the
// client side.
err = wait.Predicate(func() bool {
cs, err := h.clientDB.ListClosableSessions()
require.NoError(h.t, err)
return len(cs) == 0
}, waitTime)
require.NoError(h.t, err)
// Wait for channel to be "unregistered".
chanID := chanIDFromInt(chanIDInt)
err = wait.Predicate(func() bool {
err := h.client.BackupState(&chanID, 0)
return errors.Is(
err, wtclient.ErrUnregisteredChannel,
)
}, waitTime)
require.NoError(h.t, err)
// Show that the committed update for the closed channel
// is cleared from the DB.
err = wait.Predicate(func() bool {
sessions, err := h.clientDB.ListClientSessions(
nil,
)
require.NoError(h.t, err)
var updates []wtdb.CommittedUpdate
for id := range sessions {
updates, err = fetchSessions(&id)
require.NoError(h.t, err)
if len(updates) != 0 {
return false
}
}
return true
}, waitTime)
require.NoError(h.t, err)
},
},
}
// TestClient executes the client test suite, asserting the ability to backup

View file

@ -18,51 +18,13 @@ const (
waitTime = time.Second * 2
)
type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID]
// TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected.
func TestDiskOverflowQueue(t *testing.T) {
t.Parallel()
dbs := []struct {
name string
init initQueue
}{
{
name: "kvdb",
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
})
return db.GetDBQueue([]byte("test-namespace"))
},
},
{
name: "mock",
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
db := wtmock.NewClientDB()
return db.GetDBQueue([]byte("test-namespace"))
},
},
}
tests := []struct {
name string
run func(*testing.T, initQueue)
run func(*testing.T, wtdb.Queue[*wtdb.BackupID])
}{
{
name: "overflow to disk",
@ -78,29 +40,43 @@ func TestDiskOverflowQueue(t *testing.T) {
},
}
for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
initDB := func() wtdb.Queue[*wtdb.BackupID] {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.run(t, db.init)
})
}
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, db.Close())
})
return db.GetDBQueue([]byte("test-namespace"))
}
for _, test := range tests {
test := test
t.Run(test.name, func(tt *testing.T) {
tt.Parallel()
test.run(tt, initDB())
})
}
}
// testOverflowToDisk is a basic test that ensures that the queue correctly
// overflows items to disk and then correctly reloads them.
func testOverflowToDisk(t *testing.T, initQueue initQueue) {
func testOverflowToDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) {
// Generate some backup IDs that we want to add to the queue.
tasks := genBackupIDs(10)
// Init the DB.
db := initQueue(t)
// New mock logger.
log := newMockLogger(t.Logf)
@ -146,7 +122,9 @@ func testOverflowToDisk(t *testing.T, initQueue initQueue) {
// testRestartWithSmallerBufferSize tests that if the queue is restarted with
// a smaller in-memory buffer size that it was initially started with, then
// tasks are still loaded in the correct order.
func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
func testRestartWithSmallerBufferSize(t *testing.T,
db wtdb.Queue[*wtdb.BackupID]) {
const (
firstMaxInMemItems = 5
secondMaxInMemItems = 2
@ -155,9 +133,6 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
// Generate some backup IDs that we want to add to the queue.
tasks := genBackupIDs(10)
// Create a db.
db := newQueue(t)
// New mock logger.
log := newMockLogger(t.Logf)
@ -223,14 +198,11 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
// testStartStopQueue is a stress test that pushes a large number of tasks
// through the queue while also restarting the queue a couple of times
// throughout.
func testStartStopQueue(t *testing.T, newQueue initQueue) {
func testStartStopQueue(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) {
// Generate a lot of backup IDs that we want to add to the
// queue one after the other.
tasks := genBackupIDs(200_000)
// Construct the ClientDB.
db := newQueue(t)
// New mock logger.
log := newMockLogger(t.Logf)

View file

@ -50,6 +50,7 @@ var (
// => cSessionDBID -> db-assigned-id
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAckRangeIndex => db-chan-id => start -> end
// => cSessionRogueUpdateCount -> count
cSessionBkt = []byte("client-session-bucket")
// cSessionDBID is a key used in the cSessionBkt to store the
@ -68,6 +69,12 @@ var (
// chan-id => start -> end
cSessionAckRangeIndex = []byte("client-session-ack-range-index")
// cSessionRogueUpdateCount is a key in the cSessionBkt bucket storing
// the number of rogue updates that were backed up using the session.
// Rogue updates are updates for channels that have been closed already
// at the time of the back-up.
cSessionRogueUpdateCount = []byte("client-session-rogue-update-count")
// cChanIDIndexBkt is a top-level bucket storing:
// db-assigned-id -> channel-ID
cChanIDIndexBkt = []byte("client-channel-id-index")
@ -980,29 +987,8 @@ func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) (
// getRangesWriteBucket gets the range index bucket where the range index for
// the given session-channel pair is stored. If any sub-buckets along the way do
// not exist, then they are created.
func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID,
chanID lnwire.ChannelID) (kvdb.RwBucket, error) {
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return nil, ErrUninitializedDB
}
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return nil, ErrUninitializedDB
}
sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:])
if err != nil {
return nil, err
}
// Get the DB representation of the channel-ID.
_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
if err != nil {
return nil, err
}
func getRangesWriteBucket(sessionBkt kvdb.RwBucket, dbChanIDBytes []byte) (
kvdb.RwBucket, error) {
sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists(
cSessionAckRangeIndex,
@ -1263,10 +1249,23 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
}
sessionBkt := sessions.NestedReadBucket(id[:])
if sessionsBkt == nil {
if sessionBkt == nil {
return nil
}
// First, account for any rogue updates.
rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount)
if len(rogueCountBytes) != 0 {
rogueCount, err := readBigSize(rogueCountBytes)
if err != nil {
return err
}
numAcked += rogueCount
}
// Then, check if the session-ack-ranges contains any entries
// to account for.
sessionAckRanges := sessionBkt.NestedReadBucket(
cSessionAckRangeIndex,
)
@ -1546,14 +1545,37 @@ func (c *ClientDB) DeleteSession(id SessionID) error {
return err
}
// Get the acked updates range index for the session. This is
// used to get the list of channels that the session has updates
// for.
ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
// There is a small chance that the session only contains rogue
// updates. In that case, there will be no ack-ranges index but
// the rogue update count will be equal the MaxUpdates.
rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount)
if len(rogueCountBytes) != 0 {
rogueCount, err := readBigSize(rogueCountBytes)
if err != nil {
return err
}
maxUpdates := sess.ClientSessionBody.Policy.MaxUpdates
if rogueCount == uint64(maxUpdates) {
// Do a sanity check to ensure that the acked
// ranges bucket does not exist in this case.
if ackRanges != nil {
return fmt.Errorf("acked updates "+
"exist for session with a "+
"max-updates(%d) rogue count",
rogueCount)
}
return sessionsBkt.DeleteNestedBucket(id[:])
}
}
// A session would only be considered closable if it was
// exhausted. Meaning that it should not be the case that it has
// no acked-updates.
if ackRanges == nil {
// A session would only be considered closable if it
// was exhausted. Meaning that it should not be the
// case that it has no acked-updates.
return fmt.Errorf("cannot delete session %s since it "+
"is not yet exhausted", id)
}
@ -1784,6 +1806,22 @@ func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket,
return false, nil
}
// Either the acked-update bucket should exist _or_ the rogue update
// count must be equal to the session's MaxUpdates value, otherwise
// something is wrong because the above check ensures that the session
// has been exhausted.
rogueCountBytes := sessBkt.Get(cSessionRogueUpdateCount)
if len(rogueCountBytes) != 0 {
rogueCount, err := readBigSize(rogueCountBytes)
if err != nil {
return false, err
}
if rogueCount == uint64(session.Policy.MaxUpdates) {
return true, nil
}
}
// If the session has no acked-updates, then something is wrong since
// the above check ensures that this session has been exhausted meaning
// that it should have MaxUpdates acked updates.
@ -2026,18 +2064,92 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
return err
}
chanID := committedUpdate.BackupID.ChanID
height := committedUpdate.BackupID.CommitHeight
// Get the ranges write bucket before getting the range index to
// ensure that the session acks sub-bucket is initialized, so
// that we can insert an entry.
rangesBkt, err := getRangesWriteBucket(tx, *id, chanID)
dbSessionID, dbSessIDBytes, err := getDBSessionID(sessions, *id)
if err != nil {
return err
}
dbSessionID, _, err := getDBSessionID(sessions, *id)
chanID := committedUpdate.BackupID.ChanID
height := committedUpdate.BackupID.CommitHeight
// Get the DB representation of the channel-ID. There is a
// chance that the channel corresponding to this update has been
// closed and that the details for this channel no longer exist
// in the tower client DB. In that case, we consider this a
// rogue update and all we do is make sure to keep track of the
// number of rogue updates for this session.
_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
if errors.Is(err, ErrChannelNotRegistered) {
var (
count uint64
err error
)
rogueCountBytes := sessionBkt.Get(
cSessionRogueUpdateCount,
)
if len(rogueCountBytes) != 0 {
count, err = readBigSize(rogueCountBytes)
if err != nil {
return err
}
}
rogueCount := count + 1
countBytes, err := writeBigSize(rogueCount)
if err != nil {
return err
}
err = sessionBkt.Put(
cSessionRogueUpdateCount, countBytes,
)
if err != nil {
return err
}
// In the rare chance that this session only has rogue
// updates, we check here if the count is equal to the
// MaxUpdate of the session. If it is, then we mark the
// session as closable.
if rogueCount != uint64(session.Policy.MaxUpdates) {
return nil
}
// Before we mark the session as closable, we do a
// sanity check to ensure that this session has no
// acked-update index.
sessionAckRanges := sessionBkt.NestedReadBucket(
cSessionAckRangeIndex,
)
if sessionAckRanges != nil {
return fmt.Errorf("session(%s) has an "+
"acked ranges index but has a rogue "+
"count indicating saturation",
session.ID)
}
closableSessBkt := tx.ReadWriteBucket(
cClosableSessionsBkt,
)
if closableSessBkt == nil {
return ErrUninitializedDB
}
var height [4]byte
byteOrder.PutUint32(height[:], 0)
return closableSessBkt.Put(dbSessIDBytes, height[:])
} else if err != nil {
return err
}
// Get the ranges write bucket before getting the range index to
// ensure that the session acks sub-bucket is initialized, so
// that we can insert an entry.
rangesBkt, err := getRangesWriteBucket(
sessionBkt, dbChanIDBytes,
)
if err != nil {
return err
}
@ -2173,6 +2285,11 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64)
// number of updates that the session has for the channel.
type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16)
// PerRogueUpdateCountCB describes the signature of a callback function that can
// be called for each session with the number of rogue updates that the session
// has.
type PerRogueUpdateCountCB func(*ClientSession, uint16)
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
@ -2195,6 +2312,10 @@ type ClientSessionListCfg struct {
// channel.
PerNumAckedUpdates PerNumAckedUpdatesCB
// PerRogueUpdateCount will, if set, be called with the number of rogue
// updates that the session has backed up.
PerRogueUpdateCount PerRogueUpdateCountCB
// PerMaxHeight will, if set, be called for each of the session's
// channels to communicate the highest commit height of updates stored
// for that channel.
@ -2242,6 +2363,15 @@ func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption {
}
}
// WithPerRogueUpdateCount constructs a functional option that will set a
// call-back function to be called with the number of rogue updates that the
// session has backed up.
func WithPerRogueUpdateCount(cb PerRogueUpdateCountCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerRogueUpdateCount = cb
}
}
// WithPerCommittedUpdate constructs a functional option that will set a
// call-back function to be called for each of a client's un-acked updates.
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
@ -2310,7 +2440,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
// provided.
err = c.filterClientSessionAcks(
sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight,
cfg.PerNumAckedUpdates,
cfg.PerNumAckedUpdates, cfg.PerRogueUpdateCount,
)
if err != nil {
return nil, err
@ -2368,7 +2498,24 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
// call back if one is provided.
func (c *ClientDB) filterClientSessionAcks(sessionBkt,
chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB,
perNumAckedUpdates PerNumAckedUpdatesCB) error {
perNumAckedUpdates PerNumAckedUpdatesCB,
perRogueUpdateCount PerRogueUpdateCountCB) error {
if perRogueUpdateCount != nil {
var (
count uint64
err error
)
rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount)
if len(rogueCountBytes) != 0 {
count, err = readBigSize(rogueCountBytes)
if err != nil {
return err
}
}
perRogueUpdateCount(s, uint16(count))
}
if perMaxCb == nil && perNumAckedUpdates == nil {
return nil

View file

@ -13,7 +13,6 @@ import (
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
@ -676,6 +675,98 @@ func testCommitUpdate(h *clientDBHarness) {
h.assertUpdates(session.ID, []wtdb.CommittedUpdate{}, nil)
}
// testRogueUpdates asserts that rogue updates (updates for channels that are
// backed up after the channel has been closed and the channel details deleted
// from the DB) are handled correctly.
func testRogueUpdates(h *clientDBHarness) {
const maxUpdates = 5
tower := h.newTower()
// Create and insert a new session.
session1 := h.randSession(h.t, tower.ID, maxUpdates)
h.insertSession(session1, nil)
// Create a new channel and register it.
chanID1 := randChannelID(h.t)
h.registerChan(chanID1, nil, nil)
// Num acked updates should be 0.
require.Zero(h.t, h.numAcked(&session1.ID, nil))
// Commit and ACK enough updates for this channel to fill the session.
for i := 1; i <= maxUpdates; i++ {
update := randCommittedUpdateForChanWithHeight(
h.t, chanID1, uint16(i), uint64(i),
)
lastApplied := h.commitUpdate(&session1.ID, update, nil)
h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil)
}
// Num acked updates should now be 5.
require.EqualValues(h.t, maxUpdates, h.numAcked(&session1.ID, nil))
// Commit one more update for the channel but this time do not ACK it.
// This update will be put in a new session since the previous one has
// been exhausted.
session2 := h.randSession(h.t, tower.ID, maxUpdates)
sess2Seq := 1
h.insertSession(session2, nil)
update := randCommittedUpdateForChanWithHeight(
h.t, chanID1, uint16(sess2Seq), uint64(maxUpdates+1),
)
lastApplied := h.commitUpdate(&session2.ID, update, nil)
// Session 2 should not have any acked updates yet.
require.Zero(h.t, h.numAcked(&session2.ID, nil))
// There should currently be no closable sessions.
require.Empty(h.t, h.listClosableSessions(nil))
// Now mark the channel as closed.
h.markChannelClosed(chanID1, 1, nil)
// Assert that session 1 is now seen as closable.
closableSessionsMap := h.listClosableSessions(nil)
require.Len(h.t, closableSessionsMap, 1)
_, ok := closableSessionsMap[session1.ID]
require.True(h.t, ok)
// Delete session 1.
h.deleteSession(session1.ID, nil)
// Now try to ACK the update for the channel. This should succeed and
// the update should be considered a rogue update.
h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil)
// Show that the number of acked updates is now 1.
require.EqualValues(h.t, 1, h.numAcked(&session2.ID, nil))
// We also want to test the extreme case where all the updates for a
// particular session are rogue updates. In this case, the session
// should be seen as closable if it is saturated.
// First show that the session is not yet considered closable.
require.Empty(h.t, h.listClosableSessions(nil))
// Then, let's continue adding rogue updates for the closed channel to
// session 2.
for i := maxUpdates + 2; i <= maxUpdates*2; i++ {
sess2Seq++
update := randCommittedUpdateForChanWithHeight(
h.t, chanID1, uint16(sess2Seq), uint64(i),
)
lastApplied := h.commitUpdate(&session2.ID, update, nil)
h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil)
}
// At this point, session 2 is saturated with rogue updates. Assert that
// it is now closable.
closableSessionsMap = h.listClosableSessions(nil)
require.Len(h.t, closableSessionsMap, 1)
}
// testMarkChannelClosed asserts the behaviour of MarkChannelClosed.
func testMarkChannelClosed(h *clientDBHarness) {
tower := h.newTower()
@ -763,7 +854,7 @@ func testMarkChannelClosed(h *clientDBHarness) {
require.EqualValues(h.t, 4, lastApplied)
h.ackUpdate(&session1.ID, 5, 5, nil)
// The session is no exhausted.
// The session is now exhausted.
// If we now close channel 5, session 1 should still not be closable
// since it has an update for channel 6 which is still open.
sl = h.markChannelClosed(chanID5, 1, nil)
@ -964,12 +1055,6 @@ func TestClientDB(t *testing.T) {
return db
},
},
{
name: "mock",
init: func(t *testing.T) wtclient.DB {
return wtmock.NewClientDB()
},
},
}
tests := []struct {
@ -1008,6 +1093,10 @@ func TestClientDB(t *testing.T) {
name: "mark channel closed",
run: testMarkChannelClosed,
},
{
name: "rogue updates",
run: testRogueUpdates,
},
}
for _, database := range dbs {
@ -1073,6 +1162,34 @@ func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID,
}
}
// randCommittedUpdateForChanWithHeight generates a random committed update for
// the given channel ID using the given commit height.
func randCommittedUpdateForChanWithHeight(t *testing.T, chanID lnwire.ChannelID,
seqNum uint16, height uint64) *wtdb.CommittedUpdate {
t.Helper()
var hint blob.BreachHint
_, err := io.ReadFull(crand.Reader, hint[:])
require.NoError(t, err)
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
_, err = io.ReadFull(crand.Reader, encBlob)
require.NoError(t, err)
return &wtdb.CommittedUpdate{
SeqNum: seqNum,
CommittedUpdateBody: wtdb.CommittedUpdateBody{
BackupID: wtdb.BackupID{
ChanID: chanID,
CommitHeight: height,
},
Hint: hint,
EncryptedBlob: encBlob,
},
}
}
func (h *clientDBHarness) randSession(t *testing.T,
towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession {

View file

@ -4,9 +4,7 @@ import (
"testing"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/stretchr/testify/require"
)
@ -15,53 +13,24 @@ import (
func TestDiskQueue(t *testing.T) {
t.Parallel()
dbs := []struct {
name string
init clientDBInit
}{
{
name: "bbolt",
init: func(t *testing.T) wtclient.DB {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}
// Construct the ClientDB.
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
return db
},
},
{
name: "mock",
init: func(t *testing.T) wtclient.DB {
return wtmock.NewClientDB()
},
},
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}
for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
// Construct the ClientDB.
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)
testQueue(t, db.init(t))
})
}
}
db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
func testQueue(t *testing.T, db wtclient.DB) {
namespace := []byte("test-namespace")
queue := db.GetDBQueue(namespace)

View file

@ -1,887 +0,0 @@
package wtmock
import (
"encoding/binary"
"net"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
var byteOrder = binary.BigEndian
type towerPK [33]byte
type keyIndexKey struct {
towerID wtdb.TowerID
blobType blob.Type
}
type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
type channel struct {
summary *wtdb.ClientChanSummary
closedHeight uint32
sessions map[wtdb.SessionID]bool
}
// ClientDB is a mock, in-memory database or testing the watchtower client
// behavior.
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
channels map[lnwire.ChannelID]*channel
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates rangeIndexArrayMap
persistedAckedUpdates rangeIndexKVStore
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
closableSessions map[wtdb.SessionID]uint32
nextIndex uint32
indexes map[keyIndexKey]uint32
legacyIndexes map[wtdb.TowerID]uint32
queues map[string]wtdb.Queue[*wtdb.BackupID]
}
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
channels: make(map[lnwire.ChannelID]*channel),
activeSessions: make(
map[wtdb.SessionID]wtdb.ClientSession,
),
ackedUpdates: make(rangeIndexArrayMap),
persistedAckedUpdates: make(rangeIndexKVStore),
committedUpdates: make(
map[wtdb.SessionID][]wtdb.CommittedUpdate,
),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
closableSessions: make(map[wtdb.SessionID]uint32),
queues: make(map[string]wtdb.Queue[*wtdb.BackupID]),
}
}
// CreateTower initialize an address record used to communicate with a
// watchtower. Each Tower is assigned a unique ID, that is used to amortize
// storage costs of the public key when used by multiple sessions. If the tower
// already exists, the address is appended to the list of all addresses used to
// that tower previously and its corresponding sessions are marked as active.
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
var towerPubKey towerPK
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
var tower *wtdb.Tower
towerID, ok := m.towerIndex[towerPubKey]
if ok {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)
towerSessions, err := m.listClientSessions(&towerID)
if err != nil {
return nil, err
}
for id, session := range towerSessions {
session.Status = wtdb.CSessionActive
m.activeSessions[id] = *session
}
} else {
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
tower = &wtdb.Tower{
ID: towerID,
IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address},
}
}
m.towerIndex[towerPubKey] = towerID
m.towers[towerID] = tower
return copyTower(tower), nil
}
// RemoveTower modifies a tower's record within the database. If an address is
// provided, then _only_ the address record should be removed from the tower's
// persisted state. Otherwise, we'll attempt to mark the tower as inactive by
// marking all of its sessions inactive. If any of its sessions has unacked
// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
// any sessions at all, it'll be completely removed from the database.
//
// NOTE: An error is not returned if the tower doesn't exist.
func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
m.mu.Lock()
defer m.mu.Unlock()
tower, err := m.loadTower(pubKey)
if err == wtdb.ErrTowerNotFound {
return nil
}
if err != nil {
return err
}
if addr != nil {
tower.RemoveAddress(addr)
if len(tower.Addresses) == 0 {
return wtdb.ErrLastTowerAddr
}
m.towers[tower.ID] = tower
return nil
}
towerSessions, err := m.listClientSessions(&tower.ID)
if err != nil {
return err
}
if len(towerSessions) == 0 {
var towerPK towerPK
copy(towerPK[:], pubKey.SerializeCompressed())
delete(m.towerIndex, towerPK)
delete(m.towers, tower.ID)
return nil
}
for id, session := range towerSessions {
if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates
}
session.Status = wtdb.CSessionInactive
m.activeSessions[id] = *session
}
return nil
}
// LoadTower retrieves a tower by its public key.
func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.loadTower(pubKey)
}
// loadTower retrieves a tower by its public key.
//
// NOTE: This method requires the database's lock to be acquired.
func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
var towerPK towerPK
copy(towerPK[:], pubKey.SerializeCompressed())
towerID, ok := m.towerIndex[towerPK]
if !ok {
return nil, wtdb.ErrTowerNotFound
}
tower, ok := m.towers[towerID]
if !ok {
return nil, wtdb.ErrTowerNotFound
}
return copyTower(tower), nil
}
// LoadTowerByID retrieves a tower by its tower ID.
func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
if tower, ok := m.towers[towerID]; ok {
return copyTower(tower), nil
}
return nil, wtdb.ErrTowerNotFound
}
// ListTowers retrieves the list of towers available within the database.
func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
towers := make([]*wtdb.Tower, 0, len(m.towers))
for _, tower := range m.towers {
towers = append(towers, copyTower(tower))
}
return towers, nil
}
// MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt
// to retry after startup.
func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
return nil
}
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.listClientSessions(tower, opts...)
}
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions {
session := session
if tower != nil && *tower != session.TowerID {
continue
}
if cfg.PreEvaluateFilterFn != nil &&
!cfg.PreEvaluateFilterFn(&session) {
continue
}
if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(
&session, chanID, index.MaxHeight(),
)
}
}
if cfg.PerNumAckedUpdates != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerNumAckedUpdates(
&session, chanID,
uint16(index.NumInSet()),
)
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
if cfg.PostEvaluateFilterFn != nil &&
!cfg.PostEvaluateFilterFn(&session) {
continue
}
sessions[session.ID] = &session
}
return sessions, nil
}
// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
[]wtdb.CommittedUpdate, error) {
m.mu.Lock()
defer m.mu.Unlock()
updates, ok := m.committedUpdates[*id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
return updates, nil
}
// IsAcked returns true if the given backup has been backed up using the given
// session.
func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool,
error) {
m.mu.Lock()
defer m.mu.Unlock()
index, ok := m.ackedUpdates[*id][backupID.ChanID]
if !ok {
return false, nil
}
return index.IsInIndex(backupID.CommitHeight), nil
}
// NumAckedUpdates returns the number of backups that have been successfully
// backed up using the given session.
func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) {
m.mu.Lock()
defer m.mu.Unlock()
var numAcked uint64
for _, index := range m.ackedUpdates[*id] {
numAcked += index.NumInSet()
}
return numAcked, nil
}
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
m.mu.Lock()
defer m.mu.Unlock()
// Ensure that we aren't overwriting an existing session.
if _, ok := m.activeSessions[session.ID]; ok {
return wtdb.ErrClientSessionAlreadyExists
}
key := keyIndexKey{
towerID: session.TowerID,
blobType: session.Policy.BlobType,
}
// Ensure that a session key index has been reserved for this tower.
keyIndex, err := m.getSessionKeyIndex(key)
if err != nil {
return err
}
// Ensure that the session's index matches the reserved index.
if keyIndex != session.KeyIndex {
return wtdb.ErrIncorrectKeyIndex
}
// Remove the key index reservation for this tower. Once committed, this
// permits us to create another session with this tower.
delete(m.indexes, key)
if key.blobType == blob.TypeAltruistCommit {
delete(m.legacyIndexes, key.towerID)
}
m.activeSessions[session.ID] = wtdb.ClientSession{
ID: session.ID,
ClientSessionBody: wtdb.ClientSessionBody{
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
TowerID: session.TowerID,
KeyIndex: session.KeyIndex,
Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript),
},
}
m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex)
m.persistedAckedUpdates[session.ID] = make(
map[lnwire.ChannelID]*mockKVStore,
)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil
}
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
// CreateClientSession is invoked for that tower and index, at which point a new
// index for that tower can be reserved. Multiple calls to this method before
// CreateClientSession is invoked should return the same index unless forceNext
// is set to true.
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, blobType blob.Type,
forceNext bool) (uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
key := keyIndexKey{
towerID: towerID,
blobType: blobType,
}
if !forceNext {
if index, err := m.getSessionKeyIndex(key); err == nil {
return index, nil
}
}
// By default, we use the next available bucket sequence as the key
// index. But if forceNext is true, then it is assumed that some data
// loss occurred and so the sequence is incremented a by a jump of 1000
// so that we can arrive at a brand new key index quicker.
nextIndex := m.nextIndex + 1
if forceNext {
nextIndex = m.nextIndex + 1000
}
m.nextIndex = nextIndex
m.indexes[key] = nextIndex
return nextIndex, nil
}
func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) {
if index, ok := m.indexes[key]; ok {
return index, nil
}
if key.blobType == blob.TypeAltruistCommit {
if index, ok := m.legacyIndexes[key.towerID]; ok {
return index, nil
}
}
return 0, wtdb.ErrNoReservedKeyIndex
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
update *wtdb.CommittedUpdate) (uint16, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return 0, wtdb.ErrClientSessionNotFound
}
// Check if an update has already been committed for this state.
for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the
// last applied value so the client can retransmit.
if dbUpdate.Hint == update.Hint {
return session.TowerLastApplied, nil
}
// Otherwise, fail since the breach hint doesn't match.
return 0, wtdb.ErrUpdateAlreadyCommitted
}
}
// Sequence number must increment.
if update.SeqNum != session.SeqNum+1 {
return 0, wtdb.ErrCommitUnorderedUpdate
}
// Save the update and increment the sequence number.
m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++
m.activeSessions[*id] = session
return session.TowerLastApplied, nil
}
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return wtdb.ErrClientSessionNotFound
}
// Ensure the returned last applied value does not exceed the highest
// allocated sequence number.
if lastApplied > session.SeqNum {
return wtdb.ErrUnallocatedLastApplied
}
// Ensure the last applied value isn't lower than a previous one sent by
// the tower.
if lastApplied < session.TowerLastApplied {
return wtdb.ErrLastAppliedReversion
}
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
}
// Add sessionID to channel.
channel, ok := m.channels[update.BackupID.ChanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}
channel.sessions[*id] = true
// Remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send
// along with the next update.
copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{}
m.committedUpdates[session.ID] = updates[:len(updates)-1]
chanID := update.BackupID.ChanID
if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok {
index, err := wtdb.NewRangeIndex(nil)
if err != nil {
return err
}
m.ackedUpdates[*id][chanID] = index
m.persistedAckedUpdates[*id][chanID] = newMockKVStore()
}
err := m.ackedUpdates[*id][chanID].Add(
update.BackupID.CommitHeight,
m.persistedAckedUpdates[*id][chanID],
)
if err != nil {
return err
}
session.TowerLastApplied = lastApplied
m.activeSessions[*id] = session
return nil
}
return wtdb.ErrCommittedUpdateNotFound
}
// GetDBQueue returns a BackupID Queue instance under the given name space.
func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] {
m.mu.Lock()
defer m.mu.Unlock()
if q, ok := m.queues[string(namespace)]; ok {
return q
}
q := NewQueueDB[*wtdb.BackupID]()
m.queues[string(namespace)] = q
return q
}
// DeleteCommittedUpdate deletes the committed update with the given sequence
// number from the given session.
func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID,
seqNum uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return wtdb.ErrClientSessionNotFound
}
// Retrieve the committed update, failing if none is found.
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
}
// Remove the committed update from "disk".
updates = append(updates[:i], updates[i+1:]...)
m.committedUpdates[session.ID] = updates
return nil
}
return wtdb.ErrCommittedUpdateNotFound
}
// ListClosableSessions fetches and returns the IDs for all sessions marked as
// closable.
func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions))
for id, height := range m.closableSessions {
cs[id] = height
}
return cs, nil
}
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries. Only the channels that have not yet been marked as closed
// will be loaded.
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
m.mu.Lock()
defer m.mu.Unlock()
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
for chanID, channel := range m.channels {
// Don't load the channel if it has been marked as closed.
if channel.closedHeight > 0 {
continue
}
summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(
channel.summary.SweepPkScript,
),
}
}
return summaries, nil
}
// MarkChannelClosed will mark a registered channel as closed by setting
// its closed-height as the given block height. It returns a list of
// session IDs for sessions that are now considered closable due to the
// close of this channel.
func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
blockHeight uint32) ([]wtdb.SessionID, error) {
m.mu.Lock()
defer m.mu.Unlock()
channel, ok := m.channels[chanID]
if !ok {
return nil, wtdb.ErrChannelNotRegistered
}
// If there are no sessions for this channel, the channel details can be
// deleted.
if len(channel.sessions) == 0 {
delete(m.channels, chanID)
return nil, nil
}
// Mark the channel as closed.
channel.closedHeight = blockHeight
// Now iterate through all the sessions of the channel to check if any
// of them are closeable.
var closableSessions []wtdb.SessionID
for sessID := range channel.sessions {
isClosable, err := m.isSessionClosable(sessID)
if err != nil {
return nil, err
}
if !isClosable {
continue
}
closableSessions = append(closableSessions, sessID)
// Add session to "closableSessions" list and add the block
// height that this last channel was closed in. This will be
// used in future to determine when we should delete the
// session.
m.closableSessions[sessID] = blockHeight
}
return closableSessions, nil
}
// isSessionClosable returns true if a session is considered closable. A session
// is considered closable only if:
// 1) It has no un-acked updates
// 2) It is exhausted (ie it cant accept any more updates)
// 3) All the channels that it has acked-updates for are closed.
func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) {
// The session is not closable if it has un-acked updates.
if len(m.committedUpdates[id]) > 0 {
return false, nil
}
sess, ok := m.activeSessions[id]
if !ok {
return false, wtdb.ErrClientSessionNotFound
}
// The session is not closable if it is not yet exhausted.
if sess.SeqNum != sess.Policy.MaxUpdates {
return false, nil
}
// Iterate over each of the channels that the session has acked-updates
// for. If any of those channels are not closed, then the session is
// not yet closable.
for chanID := range m.ackedUpdates[id] {
channel, ok := m.channels[chanID]
if !ok {
continue
}
// Channel is not yet closed, and so we can not yet delete the
// session.
if channel.closedHeight == 0 {
return false, nil
}
}
return true, nil
}
// GetClientSession loads the ClientSession with the given ID from the DB.
func (m *ClientDB) GetClientSession(id wtdb.SessionID,
opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) {
cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, ok := m.activeSessions[id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(&session, chanID, index.MaxHeight())
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
return &session, nil
}
// DeleteSession can be called when a session should be deleted from the DB.
// All references to the session will also be deleted from the DB. Note that a
// session will only be deleted if it is considered closable.
func (m *ClientDB) DeleteSession(id wtdb.SessionID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.closableSessions[id]
if !ok {
return wtdb.ErrSessionNotClosable
}
// For each of the channels, delete the session ID entry.
for chanID := range m.ackedUpdates[id] {
c, ok := m.channels[chanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}
delete(c.sessions, id)
}
delete(m.closableSessions, id)
delete(m.activeSessions, id)
return nil
}
// RegisterChannel registers a channel for use within the client database. For
// now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
// to contain more info to allow the client efficiently request historical
// states to be backed up under the client's active policy.
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.channels[chanID]; ok {
return wtdb.ErrChannelAlreadyRegistered
}
m.channels[chanID] = &channel{
summary: &wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(sweepPkScript),
},
sessions: make(map[wtdb.SessionID]bool),
}
return nil
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
bb := make([]byte, len(b))
copy(bb, b)
return bb
}
func copyTower(tower *wtdb.Tower) *wtdb.Tower {
t := &wtdb.Tower{
ID: tower.ID,
IdentityKey: tower.IdentityKey,
Addresses: make([]net.Addr, len(tower.Addresses)),
}
copy(t.Addresses, tower.Addresses)
return t
}
type mockKVStore struct {
kv map[uint64]uint64
err error
}
func newMockKVStore() *mockKVStore {
return &mockKVStore{
kv: make(map[uint64]uint64),
}
}
func (m *mockKVStore) Put(key, value []byte) error {
if m.err != nil {
return m.err
}
k := byteOrder.Uint64(key)
v := byteOrder.Uint64(value)
m.kv[k] = v
return nil
}
func (m *mockKVStore) Delete(key []byte) error {
if m.err != nil {
return m.err
}
k := byteOrder.Uint64(key)
delete(m.kv, k)
return nil
}