mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
watchtower: always populate Tower in ClientSession
In this commit, we make sure to always populate the Tower member of a ClientSession. This is done for consistency.
This commit is contained in:
parent
e150bb83d1
commit
c60ecaccbf
@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// optional filter can be provided to filter out any undesired client sessions.
|
||||
//
|
||||
// NOTE: This method should only be used when deserialization of a
|
||||
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
||||
// existing ListClientSessions method should be used.
|
||||
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||
// ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
passesFilter func(*wtdb.ClientSession) bool) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
// requests. This prevents us from having to store the private keys on
|
||||
// disk.
|
||||
for _, s := range sessions {
|
||||
tower, err := db.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: s.KeyIndex,
|
||||
|
@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
}
|
||||
towerID := TowerIDFromBytes(towerIDBytes)
|
||||
towerSessions, err := listClientSessions(
|
||||
sessions, &towerID,
|
||||
sessions, towers, &towerID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
towerID := TowerIDFromBytes(towerIDBytes)
|
||||
towerSessions, err := listClientSessions(sessions, &towerID)
|
||||
towerSessions, err := listClientSessions(
|
||||
sessions, towers, &towerID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.ReadBucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
clientSessions, err = listClientSessions(sessions, id)
|
||||
clientSessions, err = listClientSessions(sessions, towers, id)
|
||||
return err
|
||||
}, func() {
|
||||
clientSessions = nil
|
||||
@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
// listClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func listClientSessions(sessions kvdb.RBucket,
|
||||
func listClientSessions(sessions, towers kvdb.RBucket,
|
||||
id *TowerID) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessions, k)
|
||||
session, err := getClientSession(sessions, towers, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
|
||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||
// bucket corresponding to the serialized session id. This does not deserialize
|
||||
// the CommittedUpdates or AckUpdates associated with the session. If the caller
|
||||
// requires this info, use getClientSession.
|
||||
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
|
||||
// If the caller requires this info, use getClientSession.
|
||||
func getClientSessionBody(sessions kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates and AckUpdates in
|
||||
// addition to the ClientSession's body.
|
||||
func getClientSession(sessions kvdb.RBucket,
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func getClientSession(sessions, towers kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the tower associated with this session.
|
||||
tower, err := getTower(towers, session.TowerID.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
||||
if err != nil {
|
||||
@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.Tower = tower
|
||||
session.CommittedUpdates = commitedUpdates
|
||||
session.AckedUpdates = ackedUpdates
|
||||
|
||||
|
@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
|
||||
if tower != nil && *tower != session.TowerID {
|
||||
continue
|
||||
}
|
||||
session.Tower = m.towers[session.TowerID]
|
||||
sessions[session.ID] = &session
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user