wtclient: refactor existing candidate session filtering into method

This commit is contained in:
Wilmer Paulino 2020-05-11 15:23:43 -07:00
parent 8b09ac07d3
commit 01ab551b22
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F

View File

@ -38,6 +38,14 @@ const (
DefaultForceQuitDelay = 10 * time.Second DefaultForceQuitDelay = 10 * time.Second
) )
var (
// activeSessionFilter is a filter that ignored any sessions which are
// not active.
activeSessionFilter = func(s *wtdb.ClientSession) bool {
return s.Status == wtdb.CSessionActive
}
)
// RegisteredTower encompasses information about a registered watchtower with // RegisteredTower encompasses information about a registered watchtower with
// the client. // the client.
type RegisteredTower struct { type RegisteredTower struct {
@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) {
// the client. We will use any of these session if their policies match // the client. We will use any of these session if their policies match
// the current policy of the client, otherwise they will be ignored and // the current policy of the client, otherwise they will be ignored and
// new sessions will be requested. // new sessions will be requested.
sessions, err := cfg.DB.ListClientSessions(nil) candidateSessions, err := getClientSessions(
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower)
for _, s := range sessions {
// Candidate sessions must be in an active state.
if s.Status != wtdb.CSessionActive {
continue
}
// Reload the tower from disk using the tower ID contained in
// each candidate session. We will also rederive any session
// keys needed to be able to communicate with the towers and
// authenticate session requests. This prevents us from having
// to store the private keys on disk.
tower, ok := sessionTowers[s.TowerID]
if !ok {
var err error
tower, err = cfg.DB.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
}
s.Tower = tower
sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex)
if err != nil {
return nil, err
}
s.SessionPrivKey = sessionKey
candidateSessions[s.ID] = s
sessionTowers[tower.ID] = tower
}
var candidateTowers []*wtdb.Tower var candidateTowers []*wtdb.Tower
for _, tower := range sessionTowers { for _, s := range candidateSessions {
log.Infof("Using private watchtower %s, offering policy %s", log.Infof("Using private watchtower %s, offering policy %s",
tower, cfg.Policy) s.Tower, cfg.Policy)
candidateTowers = append(candidateTowers, tower) candidateTowers = append(candidateTowers, s.Tower)
} }
// Load the sweep pkscripts that have been generated for all previously // Load the sweep pkscripts that have been generated for all previously
@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) {
return c, nil return c, nil
} }
// getClientSessions retrieves the client sessions for a particular tower if
// specified, otherwise all client sessions for all towers are retrieved. An
// optional filter can be provided to filter out any undesired client sessions.
//
// NOTE: This method should only be used when deserialization of a
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
// existing ListClientSessions method should be used.
func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions, err := db.ListClientSessions(forTower)
if err != nil {
return nil, err
}
// Reload the tower from disk using the tower ID contained in each
// candidate session. We will also rederive any session keys needed to
// be able to communicate with the towers and authenticate session
// requests. This prevents us from having to store the private keys on
// disk.
for _, s := range sessions {
tower, err := db.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
s.Tower = tower
sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex)
if err != nil {
return nil, err
}
s.SessionPrivKey = sessionKey
// If an optional filter was provided, use it to filter out any
// undesired sessions.
if passesFilter != nil && !passesFilter(s) {
delete(sessions, s.ID)
}
}
return sessions, nil
}
// buildHighestCommitHeights inspects the full set of candidate client sessions // buildHighestCommitHeights inspects the full set of candidate client sessions
// loaded from disk, and determines the highest known commit height for each // loaded from disk, and determines the highest known commit height for each
// channel. This allows the client to reject backups that it has already // channel. This allows the client to reject backups that it has already