diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 2e64e6122..5a18095b8 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2145,6 +2145,12 @@ type ClientSessionListCfg struct { // function, then it will not be passed to any of the other call backs // and won't be included in the return list. PreEvaluateFilterFn ClientSessionFilterFn + + // PostEvaluateFilterFn will be run _after_ all the other call-back + // functions in ClientSessionListCfg. If a session fails this filter + // function then all it means is that it won't be included in the list + // of sessions to return. + PostEvaluateFilterFn ClientSessionFilterFn } // NewClientSessionCfg constructs a new ClientSessionListCfg. @@ -2189,6 +2195,19 @@ func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption { } } +// WithPostEvalFilterFn constructs a functional option that will set a call-back +// function that will be used to determine if a session should be included in +// the returned list. This differs from WithPreEvalFilterFn since that call-back +// is used to determine if the session should be evaluated at all (and thus +// run against the other ClientSessionListCfg call-backs) whereas the session +// will only reach the PostEvalFilterFn call-back once it has already been +// evaluated by all the other call-backs. +func WithPostEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PostEvaluateFilterFn = fn + } +} + // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. @@ -2232,6 +2251,12 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, return nil, err } + if cfg.PostEvaluateFilterFn != nil && + !cfg.PostEvaluateFilterFn(session) { + + return nil, ErrSessionFailedFilterFn + } + return session, nil } diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 22b21a563..8d43a0e45 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -260,8 +260,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, continue } - sessions[session.ID] = &session - if cfg.PerMaxHeight != nil { for chanID, index := range m.ackedUpdates[session.ID] { cfg.PerMaxHeight( @@ -285,6 +283,14 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, cfg.PerCommittedUpdate(&session, &update) } } + + if cfg.PostEvaluateFilterFn != nil && + !cfg.PostEvaluateFilterFn(&session) { + + continue + } + + sessions[session.ID] = &session } return sessions, nil