watchtower: add ClientSessionFilterFn to session requests

In this commit, a new ClientSessionFilterFn parameter is added to the
DB's ListClientSession method which can be used to allow the caller to
specify a filter function for filtering sessions read from the DB.
Currently all filtering of sessions are done after the sessions have
been read from the DB, so adding this option should provide some
efficiency.
This commit is contained in:
Elle Mouton 2022-10-18 12:08:46 +02:00
parent 8a2999c789
commit 40ac82e439
No known key found for this signature in database
GPG Key ID: D7D916376026F177
5 changed files with 62 additions and 28 deletions

View File

@ -418,7 +418,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
return nil, err
}
sessions, err := db.ListClientSessions(&tower.ID, opts...)
sessions, err := db.ListClientSessions(&tower.ID, nil, opts...)
if err != nil {
return nil, err
}
@ -470,7 +470,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*ClientSession, error) {
dbSessions, err := db.ListClientSessions(forTower, opts...)
dbSessions, err := db.ListClientSessions(forTower, nil, opts...)
if err != nil {
return nil, err
}
@ -1289,7 +1289,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// Otherwise, the tower should no longer be used for future session
// negotiations and backups.
pubKey := msg.pubKey.SerializeCompressed()
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID, nil)
if err != nil {
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
"%v", pubKey, err)
@ -1320,7 +1320,7 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
if err != nil {
return nil, err
}
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
clientSessions, err := c.cfg.DB.ListClientSessions(nil, nil, opts...)
if err != nil {
return nil, err
}
@ -1361,7 +1361,9 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
return nil, err
}
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
towerSessions, err := c.cfg.DB.ListClientSessions(
&tower.ID, nil, opts...,
)
if err != nil {
return nil, err
}

View File

@ -60,7 +60,8 @@ type DB interface {
// ListClientSessions returns the set of all client sessions known to
// the db. An optional tower ID can be used to filter out any client
// sessions in the response that do not correspond to this tower.
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
ListClientSessions(*wtdb.TowerID, wtdb.ClientSessionFilterFn,
...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchSessionCommittedUpdates retrieves the current set of un-acked

View File

@ -138,6 +138,10 @@ var (
// range-index found for the given session ID to channel ID pair.
ErrNoRangeIndexFound = errors.New("no range index found for the " +
"given session-channel pair")
// ErrSessionFailedFilterFn indicates that a particular session did
// not pass the filter func provided by the caller.
ErrSessionFailedFilterFn = errors.New("session failed filter func")
)
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
@ -469,7 +473,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
towerSessions, err := c.listTowerSessions(
towerID, sessions, chanIDIndexBkt,
towersToSessionsIndex,
towersToSessionsIndex, nil,
WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
@ -960,7 +964,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@ -985,7 +990,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
// known to the db.
if id == nil {
clientSessions, err = c.listClientAllSessions(
sessions, chanIDIndexBkt, opts...,
sessions, chanIDIndexBkt, filterFn, opts...,
)
return err
}
@ -998,7 +1003,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
clientSessions, err = c.listTowerSessions(
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
opts...,
filterFn, opts...,
)
return err
}, func() {
@ -1013,7 +1018,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
// listClientAllSessions returns the set of all client sessions known to the db.
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
@ -1022,9 +1028,11 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessions, chanIDIndexBkt, k, opts...,
sessions, chanIDIndexBkt, k, filterFn, opts...,
)
if err != nil {
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
} else if err != nil {
return err
}
@ -1042,8 +1050,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
@ -1057,9 +1065,11 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, k, opts...,
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
)
if err != nil {
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
} else if err != nil {
return err
}
@ -1523,6 +1533,11 @@ func getClientSessionBody(sessions kvdb.RBucket,
return &session, nil
}
// ClientSessionFilterFn describes the signature of a callback function that can
// be used to filter the sessions that are returned in any of the DB methods
// that read sessions from the DB.
type ClientSessionFilterFn func(*ClientSession) bool
// PerMaxHeightCB describes the signature of a callback function that can be
// called for each channel that a session has updates for to communicate the
// maximum commitment height that the session has backed up for the channel.
@ -1533,6 +1548,10 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64)
// number of updates that the session has for the channel.
type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16)
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
// has not yet received an ACK for).
@ -1597,8 +1616,8 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
error) {
idBytes []byte, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (*ClientSession, error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
@ -1610,6 +1629,10 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
return nil, err
}
if filterFn != nil && !filterFn(session) {
return nil, ErrSessionFailedFilterFn
}
// Can't fail because client session body has already been read.
sessionBkt := sessionsBkt.NestedReadBucket(idBytes)

View File

@ -49,11 +49,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
}
func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper()
sessions, err := h.db.ListClientSessions(id, opts...)
sessions, err := h.db.ListClientSessions(id, filterFn, opts...)
require.NoError(h.t, err, "unable to list client sessions")
return sessions
@ -80,7 +81,7 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
require.ErrorIs(h.t, err, expErr)
require.NotZero(h.t, tower.ID, "tower id should never be 0")
for _, session := range h.listSessions(&tower.ID) {
for _, session := range h.listSessions(&tower.ID, nil) {
require.Equal(h.t, wtdb.CSessionActive, session.Status)
}
@ -123,7 +124,7 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
return
}
for _, session := range h.listSessions(&tower.ID) {
for _, session := range h.listSessions(&tower.ID, nil) {
require.Equal(h.t, wtdb.CSessionInactive,
session.Status, "expected status for session "+
"%v to be %v, got %v", session.ID,
@ -268,7 +269,7 @@ func testCreateClientSession(h *clientDBHarness) {
// First, assert that this session is not already present in the
// database.
_, ok := h.listSessions(nil)[session.ID]
_, ok := h.listSessions(nil, nil)[session.ID]
require.Falsef(h.t, ok, "session for id %x should not exist yet",
session.ID)
@ -296,7 +297,7 @@ func testCreateClientSession(h *clientDBHarness) {
h.insertSession(session, nil)
// Verify that the session now exists in the database.
_, ok = h.listSessions(nil)[session.ID]
_, ok = h.listSessions(nil, nil)[session.ID]
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
// Attempt to insert the session again, which should fail due to the
@ -344,7 +345,7 @@ func testFilterClientSessions(h *clientDBHarness) {
// We should see the expected sessions for each tower when filtering
// them.
for towerID, expectedSessions := range towerSessions {
sessions := h.listSessions(&towerID)
sessions := h.listSessions(&towerID, nil)
require.Len(h.t, sessions, len(expectedSessions))
for _, expectedSession := range expectedSessions {

View File

@ -83,7 +83,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)
towerSessions, err := m.listClientSessions(&towerID)
towerSessions, err := m.listClientSessions(&towerID, nil)
if err != nil {
return nil, err
}
@ -135,7 +135,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return nil
}
towerSessions, err := m.listClientSessions(&tower.ID)
towerSessions, err := m.listClientSessions(&tower.ID, nil)
if err != nil {
return err
}
@ -220,18 +220,20 @@ func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.listClientSessions(tower, opts...)
return m.listClientSessions(tower, filterFn, opts...)
}
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
@ -246,6 +248,11 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
if tower != nil && *tower != session.TowerID {
continue
}
if filterFn != nil && !filterFn(&session) {
continue
}
sessions[session.ID] = &session
if cfg.PerMaxHeight != nil {