watchtower: always populate Tower in ClientSession

In this commit, we make sure to always populate the Tower member of a
ClientSession. This is done for consistency.
This commit is contained in:
Elle Mouton 2022-10-04 15:18:40 +02:00
parent e150bb83d1
commit c60ecaccbf
No known key found for this signature in database
GPG Key ID: D7D916376026F177
3 changed files with 28 additions and 18 deletions

View File

@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) {
// optional filter can be provided to filter out any undesired client sessions.
//
// NOTE: This method should only be used when deserialization of a
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
// existing ListClientSessions method should be used.
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// requests. This prevents us from having to store the private keys on
// disk.
for _, s := range sessions {
tower, err := db.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
s.Tower = tower
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex,

View File

@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(
sessions, &towerID,
sessions, towers, &towerID,
)
if err != nil {
return err
@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(sessions, &towerID)
towerSessions, err := listClientSessions(
sessions, towers, &towerID,
)
if err != nil {
return err
}
@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
if sessions == nil {
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
var err error
clientSessions, err = listClientSessions(sessions, id)
clientSessions, err = listClientSessions(sessions, towers, id)
return err
}, func() {
clientSessions = nil
@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func listClientSessions(sessions kvdb.RBucket,
func listClientSessions(sessions, towers kvdb.RBucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessions, k)
session, err := getClientSession(sessions, towers, k)
if err != nil {
return err
}
@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller
// requires this info, use getClientSession.
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
// If the caller requires this info, use getClientSession.
func getClientSessionBody(sessions kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in
// addition to the ClientSession's body.
func getClientSession(sessions kvdb.RBucket,
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes)
@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
// Fetch the tower associated with this session.
tower, err := getTower(towers, session.TowerID.Bytes())
if err != nil {
return nil, err
}
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil {
@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
session.Tower = tower
session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates

View File

@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
if tower != nil && *tower != session.TowerID {
continue
}
session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session
}