mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
watchtower: add ClientSessionFilterFn to session requests
In this commit, a new ClientSessionFilterFn parameter is added to the DB's ListClientSession method which can be used to allow the caller to specify a filter function for filtering sessions read from the DB. Currently all filtering of sessions are done after the sessions have been read from the DB, so adding this option should provide some efficiency.
This commit is contained in:
parent
8a2999c789
commit
40ac82e439
@ -418,7 +418,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessions, err := db.ListClientSessions(&tower.ID, opts...)
|
||||
sessions, err := db.ListClientSessions(&tower.ID, nil, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -470,7 +470,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*ClientSession, error) {
|
||||
|
||||
dbSessions, err := db.ListClientSessions(forTower, opts...)
|
||||
dbSessions, err := db.ListClientSessions(forTower, nil, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1289,7 +1289,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||
// Otherwise, the tower should no longer be used for future session
|
||||
// negotiations and backups.
|
||||
pubKey := msg.pubKey.SerializeCompressed()
|
||||
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
|
||||
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
||||
"%v", pubKey, err)
|
||||
@ -1320,7 +1320,7 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil, nil, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1361,7 +1361,9 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(
|
||||
&tower.ID, nil, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -60,7 +60,8 @@ type DB interface {
|
||||
// ListClientSessions returns the set of all client sessions known to
|
||||
// the db. An optional tower ID can be used to filter out any client
|
||||
// sessions in the response that do not correspond to this tower.
|
||||
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
|
||||
ListClientSessions(*wtdb.TowerID, wtdb.ClientSessionFilterFn,
|
||||
...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// FetchSessionCommittedUpdates retrieves the current set of un-acked
|
||||
|
@ -138,6 +138,10 @@ var (
|
||||
// range-index found for the given session ID to channel ID pair.
|
||||
ErrNoRangeIndexFound = errors.New("no range index found for the " +
|
||||
"given session-channel pair")
|
||||
|
||||
// ErrSessionFailedFilterFn indicates that a particular session did
|
||||
// not pass the filter func provided by the caller.
|
||||
ErrSessionFailedFilterFn = errors.New("session failed filter func")
|
||||
)
|
||||
|
||||
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
|
||||
@ -469,7 +473,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
|
||||
towerSessions, err := c.listTowerSessions(
|
||||
towerID, sessions, chanIDIndexBkt,
|
||||
towersToSessionsIndex,
|
||||
towersToSessionsIndex, nil,
|
||||
WithPerCommittedUpdate(perCommittedUpdate),
|
||||
)
|
||||
if err != nil {
|
||||
@ -960,7 +964,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
@ -985,7 +990,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
// known to the db.
|
||||
if id == nil {
|
||||
clientSessions, err = c.listClientAllSessions(
|
||||
sessions, chanIDIndexBkt, opts...,
|
||||
sessions, chanIDIndexBkt, filterFn, opts...,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@ -998,7 +1003,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
clientSessions, err = c.listTowerSessions(
|
||||
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
|
||||
opts...,
|
||||
filterFn, opts...,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
@ -1013,7 +1018,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@ -1022,9 +1028,11 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessions, chanIDIndexBkt, k, opts...,
|
||||
sessions, chanIDIndexBkt, k, filterFn, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1042,8 +1050,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
@ -1057,9 +1065,11 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, k, opts...,
|
||||
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1523,6 +1533,11 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ClientSessionFilterFn describes the signature of a callback function that can
|
||||
// be used to filter the sessions that are returned in any of the DB methods
|
||||
// that read sessions from the DB.
|
||||
type ClientSessionFilterFn func(*ClientSession) bool
|
||||
|
||||
// PerMaxHeightCB describes the signature of a callback function that can be
|
||||
// called for each channel that a session has updates for to communicate the
|
||||
// maximum commitment height that the session has backed up for the channel.
|
||||
@ -1533,6 +1548,10 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64)
|
||||
// number of updates that the session has for the channel.
|
||||
type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16)
|
||||
|
||||
// PerAckedUpdateCB describes the signature of a callback function that can be
|
||||
// called for each of a session's acked updates.
|
||||
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
|
||||
|
||||
// PerCommittedUpdateCB describes the signature of a callback function that can
|
||||
// be called for each of a session's committed updates (updates that the client
|
||||
// has not yet received an ACK for).
|
||||
@ -1597,8 +1616,8 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
|
||||
error) {
|
||||
idBytes []byte, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
|
||||
cfg := NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
@ -1610,6 +1629,10 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filterFn != nil && !filterFn(session) {
|
||||
return nil, ErrSessionFailedFilterFn
|
||||
}
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessionsBkt.NestedReadBucket(idBytes)
|
||||
|
||||
|
@ -49,11 +49,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions(id, opts...)
|
||||
sessions, err := h.db.ListClientSessions(id, filterFn, opts...)
|
||||
require.NoError(h.t, err, "unable to list client sessions")
|
||||
|
||||
return sessions
|
||||
@ -80,7 +81,7 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
require.NotZero(h.t, tower.ID, "tower id should never be 0")
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
for _, session := range h.listSessions(&tower.ID, nil) {
|
||||
require.Equal(h.t, wtdb.CSessionActive, session.Status)
|
||||
}
|
||||
|
||||
@ -123,7 +124,7 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||
return
|
||||
}
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
for _, session := range h.listSessions(&tower.ID, nil) {
|
||||
require.Equal(h.t, wtdb.CSessionInactive,
|
||||
session.Status, "expected status for session "+
|
||||
"%v to be %v, got %v", session.ID,
|
||||
@ -268,7 +269,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
_, ok := h.listSessions(nil)[session.ID]
|
||||
_, ok := h.listSessions(nil, nil)[session.ID]
|
||||
require.Falsef(h.t, ok, "session for id %x should not exist yet",
|
||||
session.ID)
|
||||
|
||||
@ -296,7 +297,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
_, ok = h.listSessions(nil)[session.ID]
|
||||
_, ok = h.listSessions(nil, nil)[session.ID]
|
||||
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
@ -344,7 +345,7 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||
// We should see the expected sessions for each tower when filtering
|
||||
// them.
|
||||
for towerID, expectedSessions := range towerSessions {
|
||||
sessions := h.listSessions(&towerID)
|
||||
sessions := h.listSessions(&towerID, nil)
|
||||
require.Len(h.t, sessions, len(expectedSessions))
|
||||
|
||||
for _, expectedSession := range expectedSessions {
|
||||
|
@ -83,7 +83,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
tower = m.towers[towerID]
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
|
||||
towerSessions, err := m.listClientSessions(&towerID)
|
||||
towerSessions, err := m.listClientSessions(&towerID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -135,7 +135,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
towerSessions, err := m.listClientSessions(&tower.ID)
|
||||
towerSessions, err := m.listClientSessions(&tower.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -220,18 +220,20 @@ func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.listClientSessions(tower, opts...)
|
||||
return m.listClientSessions(tower, filterFn, opts...)
|
||||
}
|
||||
|
||||
// listClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
@ -246,6 +248,11 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
if tower != nil && *tower != session.TowerID {
|
||||
continue
|
||||
}
|
||||
|
||||
if filterFn != nil && !filterFn(&session) {
|
||||
continue
|
||||
}
|
||||
|
||||
sessions[session.ID] = &session
|
||||
|
||||
if cfg.PerMaxHeight != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user