multi: migrate towers to use RangeIndex for AckedUpdates

In this commit, a migration is done that takes all the AckedUpdates of
all sessions and stores them in the RangeIndex pattern instead and
deletes the session's old AckedUpdates bucket. All the logic in the code
is also updates in order to write and read from this new structure.
This commit is contained in:
Elle Mouton 2022-12-23 11:14:01 +02:00
parent 50ad10666c
commit c3a2368f46
No known key found for this signature in database
GPG Key ID: D7D916376026F177
9 changed files with 524 additions and 120 deletions

View File

@ -341,27 +341,27 @@ func constructFunctionalOptions(includeSessions bool) (
var (
opts []wtdb.ClientSessionListOption
ackCounts = make(map[wtdb.SessionID]uint16)
committedUpdateCounts = make(map[wtdb.SessionID]uint16)
ackCounts = make(map[wtdb.SessionID]uint16)
)
if !includeSessions {
return opts, ackCounts, committedUpdateCounts
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
_ wtdb.BackupID) {
perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID,
numUpdates uint16) {
ackCounts[s.ID]++
ackCounts[s.ID] += numUpdates
}
perCommittedUpdate := func(s *wtdb.ClientSession,
_ *wtdb.CommittedUpdate) {
u *wtdb.CommittedUpdate) {
committedUpdateCounts[s.ID]++
}
opts = []wtdb.ClientSessionListOption{
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerNumAckedUpdates(perNumAckedUpdates),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
}
@ -438,7 +438,8 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
// marshallTower converts a client registered watchtower into its corresponding
// RPC type.
func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool,
ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower {
ackCounts map[wtdb.SessionID]uint16,
pendingCounts map[wtdb.SessionID]uint16) *Tower {
rpcAddrs := make([]string, 0, len(tower.Addresses))
for _, addr := range tower.Addresses {

View File

@ -4,6 +4,8 @@ import (
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/watchtower/lookout"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
)
@ -30,4 +32,6 @@ func UseLogger(logger btclog.Logger) {
log = logger
lookout.UseLogger(logger)
wtserver.UseLogger(logger)
wtclient.UseLogger(logger)
wtdb.UseLogger(logger)
}

View File

@ -314,7 +314,9 @@ func New(config *Config) (*TowerClient, error) {
// determine the highest known commit height for each channel. This
// allows the client to reject backups that it has already processed for
// its active policy.
perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) {
perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID,
commitHeight uint64) {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
@ -324,22 +326,22 @@ func New(config *Config) (*TowerClient, error) {
// Take the highest commit height found in the session's acked
// updates.
height, ok := c.chanCommitHeights[id.ChanID]
if !ok || id.CommitHeight > height {
c.chanCommitHeights[id.ChanID] = id.CommitHeight
height, ok := c.chanCommitHeights[chanID]
if !ok || commitHeight > height {
c.chanCommitHeights[chanID] = commitHeight
}
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
id wtdb.BackupID) {
perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID,
height uint64) {
perUpdate(s.Policy, id)
perUpdate(s.Policy, chanID, height)
}
perCommittedUpdate := func(s *wtdb.ClientSession,
u *wtdb.CommittedUpdate) {
perUpdate(s.Policy, u.BackupID)
perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight)
}
// Load all candidate sessions and towers from the database into the
@ -366,7 +368,7 @@ func New(config *Config) (*TowerClient, error) {
candidateSessions, err := getTowerAndSessionCandidates(
cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerMaxHeight(perMaxHeight),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {

View File

@ -68,6 +68,14 @@ type DB interface {
FetchSessionCommittedUpdates(id *wtdb.SessionID) (
[]wtdb.CommittedUpdate, error)
// IsAcked returns true if the given backup has been backed up using
// the given session.
IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, error)
// NumAckedUpdates returns the number of backups that have been
// successfully backed up using the given session.
NumAckedUpdates(id *wtdb.SessionID) (uint64, error)
// FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries.
FetchChanSummaries() (wtdb.ChannelSummaries, error)

View File

@ -36,7 +36,7 @@ var (
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAcks => seqnum -> encoded BackupID
// => cSessionAckRangeIndex => db-chan-id => start -> end
cSessionBkt = []byte("client-session-bucket")
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
@ -47,9 +47,9 @@ var (
// seqnum -> encoded CommittedUpdate.
cSessionCommits = []byte("client-session-commits")
// cSessionAcks is a sub-bucket of cSessionBkt storing:
// seqnum -> encoded BackupID.
cSessionAcks = []byte("client-session-acks")
// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing
// chan-id => start -> end
cSessionAckRangeIndex = []byte("client-session-ack-range-index")
// cChanIDIndexBkt is a top-level bucket storing:
// db-assigned-id -> channel-ID
@ -422,6 +422,11 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
// Don't return an error if the watchtower doesn't exist to act
// as a NOP.
pubKeyBytes := pubKey.SerializeCompressed()
@ -463,7 +468,8 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
}
towerSessions, err := c.listTowerSessions(
towerID, sessions, towersToSessionsIndex,
towerID, sessions, chanIDIndexBkt,
towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
@ -763,6 +769,149 @@ func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) {
return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize))
}
// getRangeIndex checks the ClientDB's in-memory range index map to see if it
// has an entry for the given session and channel ID. If it does, this is
// returned, otherwise the range index is loaded from the DB. An optional db
// transaction parameter may be provided. If one is provided then it will be
// used to query the DB for the range index, otherwise, a new transaction will
// be created and used.
func (c *ClientDB) getRangeIndex(tx kvdb.RTx, sID SessionID,
chanID lnwire.ChannelID) (*RangeIndex, error) {
c.ackedRangeIndexMu.Lock()
defer c.ackedRangeIndexMu.Unlock()
if _, ok := c.ackedRangeIndex[sID]; !ok {
c.ackedRangeIndex[sID] = make(map[lnwire.ChannelID]*RangeIndex)
}
// If the in-memory range-index map already includes an entry for this
// session ID and channel ID pair, then return it.
if index, ok := c.ackedRangeIndex[sID][chanID]; ok {
return index, nil
}
// readRangeIndexFromBkt is a helper that is used to read in a
// RangeIndex structure from the passed in bucket and store it in the
// ackedRangeIndex map.
readRangeIndexFromBkt := func(rangesBkt kvdb.RBucket) (*RangeIndex,
error) {
// Create a new in-memory RangeIndex by reading in ranges from
// the DB.
rangeIndex, err := readRangeIndex(rangesBkt)
if err != nil {
return nil, err
}
c.ackedRangeIndex[sID][chanID] = rangeIndex
return rangeIndex, nil
}
// If a DB transaction is provided then use it to fetch the ranges
// bucket from the DB.
if tx != nil {
rangesBkt, err := getRangesReadBucket(tx, sID, chanID)
if err != nil {
return nil, err
}
return readRangeIndexFromBkt(rangesBkt)
}
// No DB transaction was provided. So create and use a new one.
var index *RangeIndex
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
rangesBkt, err := getRangesReadBucket(tx, sID, chanID)
if err != nil {
return err
}
index, err = readRangeIndexFromBkt(rangesBkt)
return err
}, func() {})
if err != nil {
return nil, err
}
return index, nil
}
// getRangesReadBucket gets the range index bucket where the range index for the
// given session-channel pair is stored. If any sub-buckets along the way do not
// exist, then an error is returned. If the sub-buckets should be created
// instead, then use getRangesWriteBucket.
func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) (
kvdb.RBucket, error) {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return nil, ErrUninitializedDB
}
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return nil, ErrUninitializedDB
}
sessionBkt := sessions.NestedReadBucket(sID[:])
if sessionsBkt == nil {
return nil, ErrNoRangeIndexFound
}
// Get the DB representation of the channel-ID.
_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
if err != nil {
return nil, err
}
sessionAckRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
if sessionAckRanges == nil {
return nil, ErrNoRangeIndexFound
}
return sessionAckRanges.NestedReadBucket(dbChanIDBytes), nil
}
// getRangesWriteBucket gets the range index bucket where the range index for
// the given session-channel pair is stored. If any sub-buckets along the way do
// not exist, then they are created.
func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID,
chanID lnwire.ChannelID) (kvdb.RwBucket, error) {
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return nil, ErrUninitializedDB
}
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return nil, ErrUninitializedDB
}
sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:])
if err != nil {
return nil, err
}
// Get the DB representation of the channel-ID.
_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
if err != nil {
return nil, err
}
sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists(
cSessionAckRangeIndex,
)
if err != nil {
return nil, err
}
return sessionAckRanges.CreateBucketIfNotExists(dbChanIDBytes)
}
// createSessionKeyIndexKey returns the identifier used in the
// session-key-index index, created as tower-id||blob-type.
//
@ -825,13 +974,18 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
var err error
// If no tower ID is specified, then fetch all the sessions
// known to the db.
if id == nil {
clientSessions, err = c.listClientAllSessions(
sessions, opts...,
sessions, chanIDIndexBkt, opts...,
)
return err
}
@ -843,7 +997,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
}
clientSessions, err = c.listTowerSessions(
*id, sessions, towerToSessionIndex, opts...,
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
opts...,
)
return err
}, func() {
@ -857,7 +1012,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
}
// listClientAllSessions returns the set of all client sessions known to the db.
func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
@ -866,7 +1021,9 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(sessions, k, opts...)
session, err := c.getClientSession(
sessions, chanIDIndexBkt, k, opts...,
)
if err != nil {
return err
}
@ -884,7 +1041,7 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt,
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
@ -899,7 +1056,9 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(sessionsBkt, k, opts...)
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, k, opts...,
)
if err != nil {
return err
}
@ -944,6 +1103,73 @@ func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) (
return committedUpdates, nil
}
// IsAcked returns true if the given backup has been backed up using the given
// session.
func (c *ClientDB) IsAcked(id *SessionID, backupID *BackupID) (bool, error) {
index, err := c.getRangeIndex(nil, *id, backupID.ChanID)
if errors.Is(err, ErrNoRangeIndexFound) {
return false, nil
} else if err != nil {
return false, err
}
return index.IsInIndex(backupID.CommitHeight), nil
}
// NumAckedUpdates returns the number of backups that have been successfully
// backed up using the given session.
func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
var numAcked uint64
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
if chanIDIndexBkt == nil {
return ErrUninitializedDB
}
sessionBkt := sessions.NestedReadBucket(id[:])
if sessionsBkt == nil {
return nil
}
sessionAckRanges := sessionBkt.NestedReadBucket(
cSessionAckRangeIndex,
)
if sessionAckRanges == nil {
return nil
}
// Iterate over the channel ID's in the sessionAckRanges
// bucket.
return sessionAckRanges.ForEach(func(dbChanID, _ []byte) error {
// Get the range index for the session-channel pair.
chanIDBytes := chanIDIndexBkt.Get(dbChanID)
var chanID lnwire.ChannelID
copy(chanID[:], chanIDBytes)
index, err := c.getRangeIndex(tx, *id, chanID)
if err != nil {
return err
}
numAcked += index.NumInSet()
return nil
})
}, func() {
numAcked = 0
})
if err != nil {
return 0, err
}
return numAcked, nil
}
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
@ -1174,6 +1400,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
return ErrUninitializedDB
}
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
if chanDetailsBkt == nil {
return ErrUninitializedDB
}
// We'll only load the ClientSession body for performance, since
// we primarily need to inspect its SeqNum and TowerLastApplied
// fields. The CommittedUpdates and AckedUpdates will be
@ -1242,25 +1473,24 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
return err
}
// Ensure that the session acks sub-bucket is initialized, so we
// can insert an entry.
sessionAcks, err := sessionBkt.CreateBucketIfNotExists(
cSessionAcks,
)
chanID := committedUpdate.BackupID.ChanID
height := committedUpdate.BackupID.CommitHeight
// Get the ranges write bucket before getting the range index to
// ensure that the session acks sub-bucket is initialized, so
// that we can insert an entry.
rangesBkt, err := getRangesWriteBucket(tx, *id, chanID)
if err != nil {
return err
}
// The session acks only need to track the backup id of the
// update, so we can discard the blob and hint.
var b bytes.Buffer
err = committedUpdate.BackupID.Encode(&b)
// Get the range index for the given session-channel pair.
index, err := c.getRangeIndex(tx, *id, chanID)
if err != nil {
return err
}
// Finally, insert the ack into the sessionAcks sub-bucket.
return sessionAcks.Put(seqNumBuf[:], b.Bytes())
return index.Add(height, rangesBkt)
}, func() {})
}
@ -1293,9 +1523,15 @@ func getClientSessionBody(sessions kvdb.RBucket,
return &session, nil
}
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
// PerMaxHeightCB describes the signature of a callback function that can be
// called for each channel that a session has updates for to communicate the
// maximum commitment height that the session has backed up for the channel.
type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64)
// PerNumAckedUpdatesCB describes the signature of a callback function that can
// be called for each channel that a session has updates for to communicate the
// number of updates that the session has for the channel.
type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16)
// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
@ -1310,9 +1546,15 @@ type ClientSessionListOption func(cfg *ClientSessionListCfg)
// ClientSessionListCfg defines various query parameters that will be used when
// querying the DB for client sessions.
type ClientSessionListCfg struct {
// PerAckedUpdate will, if set, be called for each of the session's
// acked updates.
PerAckedUpdate PerAckedUpdateCB
// PerNumAckedUpdates will, if set, be called for each of the session's
// channels to communicate the number of updates stored for that
// channel.
PerNumAckedUpdates PerNumAckedUpdatesCB
// PerMaxHeight will, if set, be called for each of the session's
// channels to communicate the highest commit height of updates stored
// for that channel.
PerMaxHeight PerMaxHeightCB
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
@ -1324,11 +1566,22 @@ func NewClientSessionCfg() *ClientSessionListCfg {
return &ClientSessionListCfg{}
}
// WithPerAckedUpdate constructs a functional option that will set a call-back
// function to be called for each of a client's acked updates.
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
// WithPerMaxHeight constructs a functional option that will set a call-back
// function to be called for each of a session's channels to communicate the
// maximum commitment height that the session has stored for the channel.
func WithPerMaxHeight(cb PerMaxHeightCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerAckedUpdate = cb
cfg.PerMaxHeight = cb
}
}
// WithPerNumAckedUpdates constructs a functional option that will set a
// call-back function to be called for each of a session's channels to
// communicate the number of updates that the session has stored for the
// channel.
func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerNumAckedUpdates = cb
}
}
@ -1343,21 +1596,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte,
opts ...ClientSessionListOption) (*ClientSession, error) {
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, err := getClientSessionBody(sessions, idBytes)
session, err := getClientSessionBody(sessionsBkt, idBytes)
if err != nil {
return nil, err
}
// Can't fail because client session body has already been read.
sessionBkt := sessions.NestedReadBucket(idBytes)
sessionBkt := sessionsBkt.NestedReadBucket(idBytes)
// Pass the session's committed (un-acked) updates through the call-back
// if one is provided.
@ -1370,7 +1624,10 @@ func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte,
// Pass the session's acked updates through the call-back if one is
// provided.
err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate)
err = c.filterClientSessionAcks(
sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight,
cfg.PerNumAckedUpdates,
)
if err != nil {
return nil, err
}
@ -1419,35 +1676,43 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
// filterClientSessionAcks retrieves all acked updates for the session
// identified by the serialized session id and passes them to the provided
// call back if one is provided.
func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerAckedUpdateCB) error {
func (c *ClientDB) filterClientSessionAcks(sessionBkt,
chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB,
perNumAckedUpdates PerNumAckedUpdatesCB) error {
if cb == nil {
if perMaxCb == nil && perNumAckedUpdates == nil {
return nil
}
sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
if sessionAcks == nil {
sessionAcksRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
if sessionAcksRanges == nil {
return nil
}
err := sessionAcks.ForEach(func(k, v []byte) error {
seqNum := byteOrder.Uint16(k)
return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error {
rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID)
if rangeBkt == nil {
return nil
}
var backupID BackupID
err := backupID.Decode(bytes.NewReader(v))
index, err := readRangeIndex(rangeBkt)
if err != nil {
return err
}
cb(s, seqNum, backupID)
chanIDBytes := chanIDIndexBkt.Get(dbChanID)
var chanID lnwire.ChannelID
copy(chanID[:], chanIDBytes)
if perMaxCb != nil {
perMaxCb(s, chanID, index.MaxHeight())
}
if perNumAckedUpdates != nil {
perNumAckedUpdates(s, chanID, uint16(index.NumInSet()))
}
return nil
})
if err != nil {
return err
}
return nil
}
// filterClientSessionCommits retrieves all committed updates for the session

View File

@ -221,6 +221,26 @@ func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID,
return updates
}
func (h *clientDBHarness) isAcked(id *wtdb.SessionID, backupID *wtdb.BackupID,
expErr error) bool {
h.t.Helper()
isAcked, err := h.db.IsAcked(id, backupID)
require.ErrorIs(h.t, err, expErr)
return isAcked
}
func (h *clientDBHarness) numAcked(id *wtdb.SessionID, expErr error) uint64 {
h.t.Helper()
numAcked, err := h.db.NumAckedUpdates(id)
require.ErrorIs(h.t, err, expErr)
return numAcked
}
// testCreateClientSession asserts various conditions regarding the creation of
// a new ClientSession. The test asserts:
// - client sessions can only be created if a session key index is reserved.
@ -453,6 +473,7 @@ func testRemoveTower(h *clientDBHarness) {
}
h.insertSession(session, nil)
update := randCommittedUpdate(h.t, 1)
h.registerChan(update.BackupID.ChanID, nil, nil)
h.commitUpdate(&session.ID, update, nil)
// We should not be able to fully remove it from the database since
@ -583,16 +604,6 @@ func testCommitUpdate(h *clientDBHarness) {
}, nil)
}
func perAckedUpdate(updates map[uint16]wtdb.BackupID) func(
_ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) {
return func(_ *wtdb.ClientSession, seq uint16,
id wtdb.BackupID) {
updates[seq] = id
}
}
// testAckUpdate asserts the behavior of AckUpdate.
func testAckUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
@ -628,6 +639,8 @@ func testAckUpdate(h *clientDBHarness) {
// Commit to a random update at seqnum 1.
update1 := randCommittedUpdate(h.t, 1)
h.registerChan(update1.BackupID.ChanID, nil, nil)
lastApplied := h.commitUpdate(&session.ID, update1, nil)
require.Zero(h.t, lastApplied)
@ -654,6 +667,7 @@ func testAckUpdate(h *clientDBHarness) {
// value is 1, since this was what was provided in the last successful
// ack.
update2 := randCommittedUpdate(h.t, 2)
h.registerChan(update2.BackupID.ChanID, nil, nil)
lastApplied = h.commitUpdate(&session.ID, update2, nil)
require.EqualValues(h.t, 1, lastApplied)
@ -681,13 +695,16 @@ func (h *clientDBHarness) assertUpdates(id wtdb.SessionID,
expectedPending []wtdb.CommittedUpdate,
expectedAcked map[uint16]wtdb.BackupID) {
ackedUpdates := make(map[uint16]wtdb.BackupID)
_ = h.listSessions(
nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)),
)
committedUpates := h.fetchSessionCommittedUpdates(&id, nil)
checkCommittedUpdates(h.t, committedUpates, expectedPending)
checkAckedUpdates(h.t, ackedUpdates, expectedAcked)
committedUpdates := h.fetchSessionCommittedUpdates(&id, nil)
checkCommittedUpdates(h.t, committedUpdates, expectedPending)
// Check acked updates.
numAcked := h.numAcked(&id, nil)
require.EqualValues(h.t, len(expectedAcked), numAcked)
for _, backupID := range expectedAcked {
isAcked := h.isAcked(&id, &backupID, nil)
require.True(h.t, isAcked)
}
}
// checkCommittedUpdates asserts that the CommittedUpdates on session match the
@ -707,21 +724,6 @@ func checkCommittedUpdates(t *testing.T, actualUpdates,
require.Equal(t, expUpdates, actualUpdates)
}
// checkAckedUpdates asserts that the AckedUpdates on a session match the
// expUpdates provided.
func checkAckedUpdates(t *testing.T, actualUpdates,
expUpdates map[uint16]wtdb.BackupID) {
// We promote nil expUpdates to an initialized map since the database
// should never return a nil map. This promotion is done purely out of
// convenience for the testing framework.
if expUpdates == nil {
expUpdates = make(map[uint16]wtdb.BackupID)
}
require.Equal(t, expUpdates, actualUpdates)
}
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
// and the mock implementation. This ensures that all databases function
// identically, especially in the negative paths.

View File

@ -6,6 +6,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration2"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
)
// log is a logger that is initialized with no output filters. This
@ -32,6 +33,7 @@ func UseLogger(logger btclog.Logger) {
migration1.UseLogger(logger)
migration2.UseLogger(logger)
migration3.UseLogger(logger)
migration4.UseLogger(logger)
}
// logClosure is used to provide a closure over expensive logging operations so

View File

@ -8,6 +8,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration2"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
)
// txMigration is a function which takes a prior outdated version of the
@ -49,6 +50,11 @@ var clientDBVersions = []version{
{
txMigration: migration3.MigrateChannelIDIndex,
},
{
dbMigration: migration4.MigrateAckedUpdates(
migration4.DefaultSessionsPerTx,
),
},
}
// getLatestDBVersion returns the last known database version.

View File

@ -1,6 +1,7 @@
package wtmock
import (
"encoding/binary"
"net"
"sync"
"sync/atomic"
@ -11,6 +12,8 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
var byteOrder = binary.BigEndian
type towerPK [33]byte
type keyIndexKey struct {
@ -18,18 +21,23 @@ type keyIndexKey struct {
blobType blob.Type
}
type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
// ClientDB is a mock, in-memory database or testing the watchtower client
// behavior.
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates rangeIndexArrayMap
persistedAckedUpdates rangeIndexKVStore
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
nextIndex uint32
indexes map[keyIndexKey]uint32
@ -39,14 +47,21 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
summaries: make(
map[lnwire.ChannelID]wtdb.ClientChanSummary,
),
activeSessions: make(
map[wtdb.SessionID]wtdb.ClientSession,
),
ackedUpdates: make(rangeIndexArrayMap),
persistedAckedUpdates: make(rangeIndexKVStore),
committedUpdates: make(
map[wtdb.SessionID][]wtdb.CommittedUpdate,
),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
}
}
@ -233,9 +248,20 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
}
sessions[session.ID] = &session
if cfg.PerAckedUpdate != nil {
for seq, id := range m.ackedUpdates[session.ID] {
cfg.PerAckedUpdate(&session, seq, id)
if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(
&session, chanID, index.MaxHeight(),
)
}
}
if cfg.PerNumAckedUpdates != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerNumAckedUpdates(
&session, chanID,
uint16(index.NumInSet()),
)
}
}
@ -266,6 +292,37 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
return updates, nil
}
// IsAcked returns true if the given backup has been backed up using the given
// session.
func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool,
error) {
m.mu.Lock()
defer m.mu.Unlock()
index, ok := m.ackedUpdates[*id][backupID.ChanID]
if !ok {
return false, nil
}
return index.IsInIndex(backupID.CommitHeight), nil
}
// NumAckedUpdates returns the number of backups that have been successfully
// backed up using the given session.
func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) {
m.mu.Lock()
defer m.mu.Unlock()
var numAcked uint64
for _, index := range m.ackedUpdates[*id] {
numAcked += index.NumInSet()
}
return numAcked, nil
}
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
@ -311,7 +368,10 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
RewardPkScript: cloneBytes(session.RewardPkScript),
},
}
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex)
m.persistedAckedUpdates[session.ID] = make(
map[lnwire.ChannelID]*mockKVStore,
)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil
@ -443,7 +503,25 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
updates[len(updates)-1] = wtdb.CommittedUpdate{}
m.committedUpdates[session.ID] = updates[:len(updates)-1]
m.ackedUpdates[*id][seqNum] = update.BackupID
chanID := update.BackupID.ChanID
if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok {
index, err := wtdb.NewRangeIndex(nil)
if err != nil {
return err
}
m.ackedUpdates[*id][chanID] = index
m.persistedAckedUpdates[*id][chanID] = newMockKVStore()
}
err := m.ackedUpdates[*id][chanID].Add(
update.BackupID.CommitHeight,
m.persistedAckedUpdates[*id][chanID],
)
if err != nil {
return err
}
session.TowerLastApplied = lastApplied
m.activeSessions[*id] = session
@ -512,3 +590,39 @@ func copyTower(tower *wtdb.Tower) *wtdb.Tower {
return t
}
type mockKVStore struct {
kv map[uint64]uint64
err error
}
func newMockKVStore() *mockKVStore {
return &mockKVStore{
kv: make(map[uint64]uint64),
}
}
func (m *mockKVStore) Put(key, value []byte) error {
if m.err != nil {
return m.err
}
k := byteOrder.Uint64(key)
v := byteOrder.Uint64(value)
m.kv[k] = v
return nil
}
func (m *mockKVStore) Delete(key []byte) error {
if m.err != nil {
return m.err
}
k := byteOrder.Uint64(key)
delete(m.kv, k)
return nil
}