diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index f003dd4b8..5bdcce5b3 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -102,6 +102,14 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019). * [The `tlv` package now allows decoding records larger than 65535 bytes. The caller is expected to know that doing so with untrusted input is unsafe.](https://github.com/lightningnetwork/lnd/pull/6779) + +## Watchtowers + +* [Create a towerID-to-sessionID index in the wtclient DB to improve the + speed of listing sessions for a particular tower ID]( + https://github.com/lightningnetwork/lnd/pull/6972). This PR also ensures a + closer coupling of Towers and Sessions and ensures that a session cannot be + added if the tower it is referring to does not exist. * [Create a helper function to wait for peer to come online](https://github.com/lightningnetwork/lnd/pull/6931). diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 720bdcca7..3d23f0b82 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate sessions and towers from the database into - // the client. We will use any of these session if their policies match + // Next, load all candidate towers and sessions from the database into + // the client. We will use any of these sessions if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) - candidateSessions, err := getClientSessions( - cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + candidateTowers := newTowerListIterator() + perActiveTower := func(tower *wtdb.Tower) { + // If the tower has already been marked as active, then there is + // no need to add it to the iterator again. + if candidateTowers.IsActive(tower.ID) { + return + } + + log.Infof("Using private watchtower %s, offering policy %s", + tower, cfg.Policy) + + // Add the tower to the set of candidate towers. + candidateTowers.AddCandidate(tower) + } + candidateSessions, err := getTowerAndSessionCandidates( + cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, ) if err != nil { return nil, err } - var candidateTowers []*wtdb.Tower - for _, s := range candidateSessions { - plog.Infof("Using private watchtower %s, offering policy %s", - s.Tower, cfg.Policy) - candidateTowers = append(candidateTowers, s.Tower) - } - // Load the sweep pkscripts that have been generated for all previously // registered channels. chanSummaries, err := cfg.DB.FetchChanSummaries() @@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: newTaskPipeline(plog), - candidateTowers: newTowerListIterator(candidateTowers...), + candidateTowers: candidateTowers, candidateSessions: candidateSessions, activeSessions: make(sessionQueueSet), summaries: chanSummaries, @@ -349,13 +356,62 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getTowerAndSessionCandidates loads all the towers from the DB and then +// fetches the sessions for each of tower. Sessions are only collected if they +// pass the sessionFilter check. If a tower has a session that does pass the +// sessionFilter check then the perActiveTower call-back will be called on that +// tower. +func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, + sessionFilter func(*wtdb.ClientSession) bool, + perActiveTower func(tower *wtdb.Tower)) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + towers, err := db.ListTowers() + if err != nil { + return nil, err + } + + candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, tower := range towers { + sessions, err := db.ListClientSessions(&tower.ID) + if err != nil { + return nil, err + } + + for _, s := range sessions { + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + s.SessionKeyECDH = keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + if !sessionFilter(s) { + continue + } + + // Add the session to the set of candidate sessions. + candidateSessions[s.ID] = s + perActiveTower(tower) + } + } + + return candidateSessions, nil +} + // getClientSessions retrieves the client sessions for a particular tower if // specified, otherwise all client sessions for all towers are retrieved. An // optional filter can be provided to filter out any undesired client sessions. // // NOTE: This method should only be used when deserialization of a -// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the -// existing ListClientSessions method should be used. +// ClientSession's SessionPrivKey field is desired, otherwise, the existing +// ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, passesFilter func(*wtdb.ClientSession) bool) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { @@ -371,12 +427,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // requests. This prevents us from having to store the private keys on // disk. for _, s := range sessions { - tower, err := db.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - s.Tower = tower - towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 69f367293..dbf2faf71 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -62,7 +62,8 @@ type DB interface { // still be able to accept state updates. An optional tower ID can be // used to filter out any client sessions in the response that do not // correspond to this tower. - ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) + ListClientSessions(*wtdb.TowerID) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. @@ -96,8 +97,8 @@ type DB interface { AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error } -// AuthDialer connects to a remote node using an authenticated transport, such as -// brontide. The dialer argument is used to specify a resolver, which allows +// AuthDialer connects to a remote node using an authenticated transport, such +// as brontide. The dialer argument is used to specify a resolver, which allows // this method to be used over Tor or clear net connections. type AuthDialer func(localKey keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 91df574d2..3cb5a8c70 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -48,6 +48,12 @@ var ( // tower-pubkey -> tower-id. cTowerIndexBkt = []byte("client-tower-index-bucket") + // cTowerToSessionIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerToSessionIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -113,7 +119,8 @@ var ( // NewBoltBackendCreator returns a function that creates a new bbolt backend for // the watchtower database. func NewBoltBackendCreator(active bool, dbPath, - dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) { + dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, + error) { // If the watchtower client isn't active, we return a function that // always returns a nil DB to make sure we don't create empty database @@ -195,6 +202,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cSessionBkt, cTowerBkt, cTowerIndexBkt, + cTowerToSessionIndexBkt, } for _, bucket := range buckets { @@ -259,6 +267,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { return ErrUninitializedDB } + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check if the tower index already knows of this pubkey. towerIDBytes := towerIndex.Get(towerPubKey[:]) if len(towerIDBytes) == 8 { @@ -278,27 +293,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { // If there are any client sessions that correspond to // this tower, we'll mark them as active to ensure we // load them upon restarts. - // - // TODO(wilmer): with an index of tower -> sessions we - // can avoid the linear lookup. + towerSessIndex := towerToSessionIndex.NestedReadBucket( + tower.ID.Bytes(), + ) + if towerSessIndex == nil { + return ErrTowerNotFound + } + sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } - towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions( - sessions, &towerID, - ) - if err != nil { - return err - } - for _, session := range towerSessions { - err := markSessionStatus( - sessions, session, CSessionActive, + + err = towerSessIndex.ForEach(func(k, _ []byte) error { + session, err := getClientSessionBody( + sessions, k, ) if err != nil { return err } + + return markSessionStatus( + sessions, session, CSessionActive, + ) + }) + if err != nil { + return err } } else { // No such tower exists, create a new tower id for our @@ -320,6 +340,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { if err != nil { return err } + + // Create a new bucket for this tower in the + // tower-to-sessions index. + _, err = towerToSessionIndex.CreateBucket(towerIDBytes) + if err != nil { + return err + } } // Store the new or updated tower under its tower id. @@ -348,11 +375,19 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if towers == nil { return ErrUninitializedDB } + towerIndex := tx.ReadWriteBucket(cTowerIndexBkt) if towerIndex == nil { return ErrUninitializedDB } + towersToSessionsIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towersToSessionsIndex == nil { + return ErrUninitializedDB + } + // Don't return an error if the watchtower doesn't exist to act // as a NOP. pubKeyBytes := pubKey.SerializeCompressed() @@ -380,15 +415,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { // Otherwise, we should attempt to mark the tower's sessions as // inactive. - // - // TODO(wilmer): with an index of tower -> sessions we can avoid - // the linear lookup. sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions(sessions, &towerID) + towerSessions, err := listTowerSessions( + towerID, sessions, towers, towersToSessionsIndex, + ) if err != nil { return err } @@ -399,7 +433,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if err := towerIndex.Delete(pubKeyBytes); err != nil { return err } - return towers.Delete(towerIDBytes) + + if err := towers.Delete(towerIDBytes); err != nil { + return err + } + + return towersToSessionsIndex.DeleteNestedBucket( + towerIDBytes, + ) } // We'll mark its sessions as inactive as long as they don't @@ -573,14 +614,34 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrUninitializedDB } + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check that client session with this session id doesn't // already exist. - existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:]) + existingSessionBytes := sessions.NestedReadWriteBucket( + session.ID[:], + ) if existingSessionBytes != nil { return ErrClientSessionAlreadyExists } + // Ensure that a tower with the given ID actually exists in the + // DB. towerID := session.TowerID + if _, err := getTower(towers, towerID.Bytes()); err != nil { + return err + } + blobType := session.Policy.BlobType // Check that this tower has a reserved key index. @@ -609,6 +670,19 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { } } + // Add the new entry to the towerID-to-SessionID index. + indexBkt := towerToSessionIndex.NestedReadWriteBucket( + towerID.Bytes(), + ) + if indexBkt == nil { + return ErrTowerNotFound + } + + err = indexBkt.Put(session.ID[:], []byte{1}) + if err != nil { + return err + } + // Finally, write the client session's body in the sessions // bucket. return putClientSessionBody(sessions, session) @@ -662,15 +736,41 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) { +func (c *ClientDB) ListClientSessions(id *TowerID) ( + map[SessionID]*ClientSession, error) { + var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { sessions := tx.ReadBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } + + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + var err error - clientSessions, err = listClientSessions(sessions, id) + + // If no tower ID is specified, then fetch all the sessions + // known to the db. + if id == nil { + clientSessions, err = listClientAllSessions( + sessions, towers, + ) + return err + } + + // Otherwise, fetch the sessions for the given tower. + towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + + clientSessions, err = listTowerSessions( + *id, sessions, towers, towerToSessionIndex, + ) return err }, func() { clientSessions = nil @@ -682,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession return clientSessions, nil } -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func listClientSessions(sessions kvdb.RBucket, - id *TowerID) (map[SessionID]*ClientSession, error) { +// listClientAllSessions returns the set of all client sessions known to the db. +func listClientAllSessions(sessions, + towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -694,19 +792,45 @@ func listClientSessions(sessions kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, k) + session, err := getClientSession(sessions, towers, k) if err != nil { return err } - // Filter out any sessions that don't correspond to the given - // tower if one was set. - if id != nil && session.TowerID != *id { - return nil + clientSessions[session.ID] = session + + return nil + }) + if err != nil { + return nil, err + } + + return clientSessions, nil +} + +// listTowerSessions returns the set of all client sessions known to the db +// that are associated with the given tower id. +func listTowerSessions(id TowerID, sessionsBkt, towersBkt, + towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, + error) { + + towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) + if towerIndexBkt == nil { + return nil, ErrTowerNotFound + } + + clientSessions := make(map[SessionID]*ClientSession) + err := towerIndexBkt.ForEach(func(k, _ []byte) error { + // We'll load the full client session since the client will need + // the CommittedUpdates and AckedUpdates on startup to resume + // committed updates and compute the highest known commit height + // for each channel. + session, err := getClientSession(sessionsBkt, towersBkt, k) + if err != nil { + return err } clientSessions[session.ID] = session - return nil }) if err != nil { @@ -951,7 +1075,9 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // If the commits sub-bucket doesn't exist, there can't possibly // be a corresponding committed update to remove. - sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits) + sessionCommits := sessionBkt.NestedReadWriteBucket( + cSessionCommits, + ) if sessionCommits == nil { return ErrCommittedUpdateNotFound } @@ -1004,8 +1130,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // getClientSessionBody loads the body of a ClientSession from the sessions // bucket corresponding to the serialized session id. This does not deserialize -// the CommittedUpdates or AckUpdates associated with the session. If the caller -// requires this info, use getClientSession. +// the CommittedUpdates, AckUpdates or the Tower associated with the session. +// If the caller requires this info, use getClientSession. func getClientSessionBody(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, error) { @@ -1032,9 +1158,9 @@ func getClientSessionBody(sessions kvdb.RBucket, } // getClientSession loads the full ClientSession associated with the serialized -// session id. This method populates the CommittedUpdates and AckUpdates in -// addition to the ClientSession's body. -func getClientSession(sessions kvdb.RBucket, +// session id. This method populates the CommittedUpdates, AckUpdates and Tower +// in addition to the ClientSession's body. +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte) (*ClientSession, error) { session, err := getClientSessionBody(sessions, idBytes) @@ -1042,6 +1168,12 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + // Fetch the tower associated with this session. + tower, err := getTower(towers, session.TowerID.Bytes()) + if err != nil { + return nil, err + } + // Fetch the committed updates for this session. commitedUpdates, err := getClientSessionCommits(sessions, idBytes) if err != nil { @@ -1054,6 +1186,7 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + session.Tower = tower session.CommittedUpdates = commitedUpdates session.AckedUpdates = ackedUpdates diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index f694b746f..d4f1699c9 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -1,11 +1,9 @@ package wtdb_test import ( - "bytes" crand "crypto/rand" "io" "net" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -16,8 +14,12 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) +// pseudoAddr is a fake network address to be used for testing purposes. +var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} + // clientDBInit is a closure used to initialize a wtclient.DB instance. type clientDBInit func(t *testing.T) wtclient.DB @@ -37,23 +39,22 @@ func newClientDBHarness(t *testing.T, init clientDBInit) *clientDBHarness { return h } -func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) { +func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, + expErr error) { + h.t.Helper() err := h.db.CreateClientSession(session) - if err != expErr { - h.t.Fatalf("expected create client session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } -func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { +func (h *clientDBHarness) listSessions( + id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { + h.t.Helper() sessions, err := h.db.ListClientSessions(id) - if err != nil { - h.t.Fatalf("unable to list client sessions: %v", err) - } + require.NoError(h.t, err, "unable to list client sessions") return sessions } @@ -64,13 +65,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, h.t.Helper() index, err := h.db.NextSessionKeyIndex(id, blobType) - if err != nil { - h.t.Fatalf("unable to create next session key index: %v", err) - } - - if index == 0 { - h.t.Fatalf("next key index should never be 0") - } + require.NoError(h.t, err, "unable to create next session key index") + require.NotZero(h.t, index, "next key index should never be 0") return index } @@ -81,20 +77,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, h.t.Helper() tower, err := h.db.CreateTower(lnAddr) - if err != expErr { - h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err) - } - - if tower.ID == 0 { - h.t.Fatalf("tower id should never be 0") - } + require.ErrorIs(h.t, err, expErr) + require.NotZero(h.t, tower.ID, "tower id should never be 0") for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionActive { - h.t.Fatalf("expected status for session %v to be %v, "+ - "got %v", session.ID, wtdb.CSessionActive, - session.Status) - } + require.Equal(h.t, wtdb.CSessionActive, session.Status) } return tower @@ -105,68 +92,64 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, h.t.Helper() - if err := h.db.RemoveTower(pubKey, addr); err != expErr { - h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err) - } + err := h.db.RemoveTower(pubKey, addr) + require.ErrorIs(h.t, err, expErr) + if expErr != nil { return } + pubKeyStr := pubKey.SerializeCompressed() + if addr != nil { tower, err := h.db.LoadTower(pubKey) - if err != nil { - h.t.Fatalf("expected tower %x to still exist", - pubKey.SerializeCompressed()) - } + require.NoErrorf(h.t, err, "expected tower %x to still exist", + pubKeyStr) removedAddr := addr.String() for _, towerAddr := range tower.Addresses { - if towerAddr.String() == removedAddr { - h.t.Fatalf("address %v not removed for tower %x", - removedAddr, pubKey.SerializeCompressed()) - } + require.NotEqualf(h.t, removedAddr, towerAddr, + "address %v not removed for tower %x", + removedAddr, pubKeyStr) } } else { tower, err := h.db.LoadTower(pubKey) - if hasSessions && err != nil { - h.t.Fatalf("expected tower %x with sessions to still "+ - "exist", pubKey.SerializeCompressed()) - } - if !hasSessions && err == nil { - h.t.Fatalf("expected tower %x with no sessions to not "+ - "exist", pubKey.SerializeCompressed()) - } - if !hasSessions { + if hasSessions { + require.NoError(h.t, err, "expected tower %x with "+ + "sessions to still exist", pubKeyStr) + } else { + require.Errorf(h.t, err, "expected tower %x with no "+ + "sessions to not exist", pubKeyStr) return } + for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionInactive { - h.t.Fatalf("expected status for session %v to "+ - "be %v, got %v", session.ID, - wtdb.CSessionInactive, session.Status) - } + require.Equal(h.t, wtdb.CSessionInactive, + session.Status, "expected status for session "+ + "%v to be %v, got %v", session.ID, + wtdb.CSessionInactive, session.Status) } } } -func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, expErr error) *wtdb.Tower { +func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, + expErr error) *wtdb.Tower { + h.t.Helper() tower, err := h.db.LoadTower(pubKey) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) - } + require.ErrorIs(h.t, err, expErr) return tower } -func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower { +func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, + expErr error) *wtdb.Tower { + h.t.Helper() tower, err := h.db.LoadTowerByID(id) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) - } + require.ErrorIs(h.t, err, expErr) return tower } @@ -175,9 +158,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC h.t.Helper() summaries, err := h.db.FetchChanSummaries() - if err != nil { - h.t.Fatalf("unable to fetch chan summaries: %v", err) - } + require.NoError(h.t, err) return summaries } @@ -188,10 +169,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, h.t.Helper() err := h.db.RegisterChannel(chanID, sweepPkScript) - if err != expErr { - h.t.Fatalf("expected register channel error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, @@ -200,10 +178,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, h.t.Helper() lastApplied, err := h.db.CommitUpdate(id, update) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -214,10 +189,22 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, h.t.Helper() err := h.db.AckUpdate(id, seqNum, lastApplied) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) +} + +// newTower is a helper function that creates a new tower with a randomly +// generated public key and inserts it into the client DB. +func (h *clientDBHarness) newTower() *wtdb.Tower { + h.t.Helper() + + pk, err := randPubKey() + require.NoError(h.t, err) + + // Insert a random tower into the database. + return h.createTower(&lnwire.NetAddress{ + IdentityKey: pk, + Address: pseudoAddr, + }, nil) } // testCreateClientSession asserts various conditions regarding the creation of @@ -228,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, func testCreateClientSession(h *clientDBHarness) { const blobType = blob.TypeAltruistAnchorCommit + tower := h.newTower() + // Create a test client session to insert. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -245,9 +234,9 @@ func testCreateClientSession(h *clientDBHarness) { // First, assert that this session is not already present in the // database. - if _, ok := h.listSessions(nil)[session.ID]; ok { - h.t.Fatalf("session for id %x should not exist yet", session.ID) - } + _, ok := h.listSessions(nil)[session.ID] + require.Falsef(h.t, ok, "session for id %x should not exist yet", + session.ID) // Attempting to insert the client session without reserving a session // key index should fail. @@ -264,10 +253,8 @@ func testCreateClientSession(h *clientDBHarness) { // successfully created, it should return the same index to maintain // idempotency across restarts. keyIndex2 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex != keyIndex2 { - h.t.Fatalf("next key index should be idempotent: want: %v, "+ - "got %v", keyIndex, keyIndex2) - } + require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+ + "be idempotent: want: %v, got %v", keyIndex, keyIndex2) // Now, set the client session's key index so that it is proper and // insert it. This should succeed. @@ -275,9 +262,8 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, nil) // Verify that the session now exists in the database. - if _, ok := h.listSessions(nil)[session.ID]; !ok { - h.t.Fatalf("session for id %x should exist now", session.ID) - } + _, ok = h.listSessions(nil)[session.ID] + require.Truef(h.t, ok, "session for id %x should exist now", session.ID) // Attempt to insert the session again, which should fail due to the // session already existing. @@ -286,9 +272,8 @@ func testCreateClientSession(h *clientDBHarness) { // Finally, assert that reserving another key index succeeds with a // different key index, now that the first one has been finalized. keyIndex3 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex == keyIndex3 { - h.t.Fatalf("key index still reserved after creating session") - } + require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+ + "reserved after creating session") } // testFilterClientSessions asserts that we can correctly filter client sessions @@ -300,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) for i := 0; i < numSessions; i++ { - towerID := wtdb.TowerID(1) - if i == numSessions-1 { - towerID = wtdb.TowerID(2) - } - keyIndex := h.nextKeyIndex(towerID, blobType) + tower := h.newTower() + keyIndex := h.nextKeyIndex(tower.ID, blobType) sessionID := wtdb.SessionID([33]byte{byte(i)}) h.insertSession(&wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: towerID, + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -320,22 +302,21 @@ func testFilterClientSessions(h *clientDBHarness) { }, ID: sessionID, }, nil) - towerSessions[towerID] = append(towerSessions[towerID], sessionID) + towerSessions[tower.ID] = append( + towerSessions[tower.ID], sessionID, + ) } // We should see the expected sessions for each tower when filtering // them. for towerID, expectedSessions := range towerSessions { sessions := h.listSessions(&towerID) - if len(sessions) != len(expectedSessions) { - h.t.Fatalf("expected %v sessions for tower %v, got %v", - len(expectedSessions), towerID, len(sessions)) - } + require.Len(h.t, sessions, len(expectedSessions)) + for _, expectedSession := range expectedSessions { - if _, ok := sessions[expectedSession]; !ok { - h.t.Fatalf("expected session %v for tower %v", - expectedSession, towerID) - } + _, ok := sessions[expectedSession] + require.Truef(h.t, ok, "expected session %v for "+ + "tower %v", expectedSession, towerID) } } } @@ -347,49 +328,31 @@ func testCreateTower(h *clientDBHarness) { // Test that loading a tower with an arbitrary tower id fails. h.loadTowerByID(20, wtdb.ErrTowerNotFound) - pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } - - addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} - lnAddr := &lnwire.NetAddress{ - IdentityKey: pk, - Address: addr1, - } - - // Insert a random tower into the database. - tower := h.createTower(lnAddr, nil) + tower := h.newTower() + require.Len(h.t, tower.LNAddrs(), 1) + towerAddr := tower.LNAddrs()[0] // Load the tower from the database and assert that it matches the tower // we created. tower2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } - tower2 = h.loadTower(pk, err) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } + require.Equal(h.t, tower, tower2) + + tower2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, tower, tower2) // Insert the address again into the database. Since the address is the // same, this should result in an unmodified tower record. - towerDupAddr := h.createTower(lnAddr, nil) - if len(towerDupAddr.Addresses) != 1 { - h.t.Fatalf("duplicate address should be deduped") - } - if !reflect.DeepEqual(tower, towerDupAddr) { - h.t.Fatalf("mismatch towers, want: %v, got: %v", - tower, towerDupAddr) - } + towerDupAddr := h.createTower(towerAddr, nil) + require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+ + "should be deduped") + + require.Equal(h.t, tower, towerDupAddr) // Generate a new address for this tower. addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} lnAddr2 := &lnwire.NetAddress{ - IdentityKey: pk, + IdentityKey: tower.IdentityKey, Address: addr2, } @@ -400,26 +363,18 @@ func testCreateTower(h *clientDBHarness) { // Load the tower from the database, and assert that it matches the // tower returned from creation. towerNewAddr2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } - towerNewAddr2 = h.loadTower(pk, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } + require.Equal(h.t, towerNewAddr, towerNewAddr2) + + towerNewAddr2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, towerNewAddr, towerNewAddr2) // Assert that there are now two addresses on the tower object. - if len(towerNewAddr.Addresses) != 2 { - h.t.Fatalf("new address should be added") - } + require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+ + "added") // Finally, assert that the new address was prepended since it is deemed // fresher. - if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) { - h.t.Fatalf("new address should be prepended") - } + require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:]) } // testRemoveTower asserts the behavior of removing Tower objects as a whole and @@ -427,9 +382,7 @@ func testCreateTower(h *clientDBHarness) { func testRemoveTower(h *clientDBHarness) { // Generate a random public key we'll use for our tower. pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } + require.NoError(h.t, err) // Removing a tower that does not exist within the database should // result in a NOP. @@ -507,28 +460,23 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - if _, ok := h.fetchChanSummaries()[chanID]; ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } + _, ok := h.fetchChanSummaries()[chanID] + require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) // Generate a random sweep pkscript and register it for this channel. expPkScript := make([]byte, 22) - if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil { - h.t.Fatalf("unable to generate pkscript: %v", err) - } + _, err := io.ReadFull(crand.Reader, expPkScript) + require.NoError(h.t, err) + h.registerChan(chanID, expPkScript, nil) // Assert that the channel exists and that its sweep pkscript matches // the one we registered. summary, ok := h.fetchChanSummaries()[chanID] - if !ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 { - h.t.Fatalf("pkscript mismatch, want: %x, got: %x", - expPkScript, summary.SweepPkScript) - } + require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) + require.Equal(h.t, expPkScript, summary.SweepPkScript) // Finally, assert that re-registering the same channel produces a // failure. @@ -538,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) { // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can func testCommitUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + + tower := h.newTower() session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -565,10 +515,7 @@ func testCommitUpdate(h *clientDBHarness) { // succeed. The lastApplied value should be 0 since we have not received // an ack from the tower. lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no @@ -584,10 +531,7 @@ func testCommitUpdate(h *clientDBHarness) { // the on-disk update's hint). The lastApplied value should remain // unchanged. lastApplied2 := h.commitUpdate(&session.ID, update1, nil) - if lastApplied2 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied2) - } + require.Equal(h.t, lastApplied, lastApplied2) // Assert that the loaded ClientSession is the same as before. dbSession = h.listSessions(nil)[session.ID] @@ -605,10 +549,7 @@ func testCommitUpdate(h *clientDBHarness) { // which should succeed. update2.SeqNum = 2 lastApplied3 := h.commitUpdate(&session.ID, update2, nil) - if lastApplied3 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied3) - } + require.Equal(h.t, lastApplied, lastApplied3) // Check that both updates now appear as committed on the ClientSession // loaded from disk. @@ -638,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) { func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + tower := h.newTower() + // Create a new session that the updates in this will be tied to. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -668,10 +611,7 @@ func testAckUpdate(h *clientDBHarness) { // Commit to a random update at seqnum 1. update1 := randCommittedUpdate(h.t, 1) lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Acking seqnum 1 should succeed. h.ackUpdate(&session.ID, 1, 1, nil) @@ -699,10 +639,7 @@ func testAckUpdate(h *clientDBHarness) { // ack. update2 := randCommittedUpdate(h.t, 2) lastApplied = h.commitUpdate(&session.ID, update2, nil) - if lastApplied != 1 { - h.t.Fatalf("last applied mismatch, want: 1, got: %v", - lastApplied) - } + require.EqualValues(h.t, 1, lastApplied) // Ack seqnum 2. h.ackUpdate(&session.ID, 2, 2, nil) @@ -740,10 +677,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make([]wtdb.CommittedUpdate, 0) } - if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) { - t.Fatalf("committed updates mismatch, want: %v, got: %v", - expUpdates, session.CommittedUpdates) - } + require.Equal(t, expUpdates, session.CommittedUpdates) } // checkAckedUpdates asserts that the AckedUpdates on a session match the @@ -758,10 +692,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make(map[uint16]wtdb.BackupID) } - if !reflect.DeepEqual(session.AckedUpdates, expUpdates) { - t.Fatalf("acked updates mismatch, want: %v, got: %v", - expUpdates, session.AckedUpdates) - } + require.Equal(t, expUpdates, session.AckedUpdates) } // TestClientDB asserts the behavior of a fresh client db, a reopened client db, @@ -779,14 +710,10 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, t.TempDir(), "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -803,27 +730,19 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() bdb, err = wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to reopen db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -893,19 +812,16 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { var chanID lnwire.ChannelID - if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - } + _, err := io.ReadFull(crand.Reader, chanID[:]) + require.NoError(t, err) var hint blob.BreachHint - if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { - t.Fatalf("unable to generate breach hint: %v", err) - } + _, err = io.ReadFull(crand.Reader, hint[:]) + require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) - if _, err := io.ReadFull(crand.Reader, encBlob); err != nil { - t.Fatalf("unable to generate encrypted blob: %v", err) - } + _, err = io.ReadFull(crand.Reader, encBlob) + require.NoError(t, err) return &wtdb.CommittedUpdate{ SeqNum: seqNum, diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 7842b13bc..c2628b86a 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" ) func randPubKey() (*btcec.PublicKey, error) { @@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) { // Ensure encoding the object succeeds. var b bytes.Buffer err := obj.Encode(&b) - if err != nil { - t.Fatalf("unable to encode: %v", err) - return false - } + require.NoError(t, err) var obj2 dbObject switch obj.(type) { @@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) { // Ensure decoding the object succeeds. err = obj2.Decode(bytes.NewReader(b.Bytes())) - if err != nil { - t.Fatalf("unable to decode: %v", err) - return false - } + require.NoError(t, err) // Assert the original and decoded object match. - if !reflect.DeepEqual(obj, obj2) { - t.Fatalf("encode/decode mismatch, want: %v, "+ - "got: %v", obj, obj2) - return false - } + require.Equal(t, obj, obj2) return true } @@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) { customTypeGen := map[string]func([]reflect.Value, *rand.Rand){ "Tower": func(v []reflect.Value, r *rand.Rand) { pk, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - return - } + require.NoError(t, err) addrs, err := randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addrs: %v", err) - return - } + require.NoError(t, err) obj := wtdb.Tower{ IdentityKey: pk, @@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) { } err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%s failed: %v", - test.name, err) - } + require.NoError(h, err) }) } } diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 0e14ea996..6ddb6c35f 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // log is a logger that is initialized with no output filters. This @@ -26,6 +27,7 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + migration1.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration1/client_db.go b/watchtower/wtdb/migration1/client_db.go new file mode 100644 index 000000000..d09ef6ef7 --- /dev/null +++ b/watchtower/wtdb/migration1/client_db.go @@ -0,0 +1,145 @@ +package migration1 + +import ( + "bytes" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // cTowerIDToSessionIDIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerIDToSessionIDIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") +) + +// MigrateTowerToSessionIndex constructs a new towerID-to-sessionID for the +// watchtower client DB. +func MigrateTowerToSessionIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client db to add a " + + "towerID-to-sessionID index") + + // First, we collect all the entries we want to add to the index. + entries, err := getIndexEntries(tx) + if err != nil { + return err + } + + // Then we create a new top-level bucket for the index. + indexBkt, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt) + if err != nil { + return err + } + + // Finally, we add all the collected entries to the index. + for towerID, sessions := range entries { + // Create a sub-bucket using the tower ID. + towerBkt, err := indexBkt.CreateBucketIfNotExists( + towerID.Bytes(), + ) + if err != nil { + return err + } + + for sessionID := range sessions { + err := addIndex(towerBkt, sessionID) + if err != nil { + return err + } + } + } + + return nil +} + +// addIndex adds a new towerID-sessionID pair to the given bucket. The +// session ID is used as a key within the bucket and a value of []byte{1} is +// used for each session ID key. +func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error { + session := towerBkt.Get(sessionID[:]) + if session != nil { + return fmt.Errorf("session %x duplicated", sessionID) + } + + return towerBkt.Put(sessionID[:], []byte{1}) +} + +// getIndexEntries collects all the towerID-sessionID entries that need to be +// added to the new index. +func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return nil, ErrUninitializedDB + } + + index := make(map[TowerID]map[SessionID]bool) + err := sessions.ForEach(func(k, _ []byte) error { + session, err := getClientSession(sessions, k) + if err != nil { + return err + } + + if index[session.TowerID] == nil { + index[session.TowerID] = make(map[SessionID]bool) + } + + index[session.TowerID][session.ID] = true + return nil + }) + if err != nil { + return nil, err + } + + return index, nil +} + +// getClientSession fetches the session with the given ID from the db. +func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, + error) { + + sessionBkt := sessions.NestedReadBucket(idBytes) + if sessionBkt == nil { + return nil, ErrClientSessionNotFound + } + + // Should never have a sessionBkt without also having its body. + sessionBody := sessionBkt.Get(cSessionBody) + if sessionBody == nil { + return nil, ErrCorruptClientSession + } + + var session ClientSession + copy(session.ID[:], idBytes) + + err := session.Decode(bytes.NewReader(sessionBody)) + if err != nil { + return nil, err + } + + return &session, nil +} diff --git a/watchtower/wtdb/migration1/client_db_test.go b/watchtower/wtdb/migration1/client_db_test.go new file mode 100644 index 000000000..acae177ad --- /dev/null +++ b/watchtower/wtdb/migration1/client_db_test.go @@ -0,0 +1,155 @@ +package migration1 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + s1 = &ClientSessionBody{ + TowerID: TowerID(1), + } + s2 = &ClientSessionBody{ + TowerID: TowerID(3), + } + s3 = &ClientSessionBody{ + TowerID: TowerID(6), + } + + // pre is the expected data in the DB before the migration. + pre = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s3), + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("4"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("5"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s2), + }, + } + + // preFailNoSessionBody should fail the migration due to there being a + // session without an associated session body. + preFailNoSessionBody = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{}, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + towerIDString(1): map[string]interface{}{ + sessionIDString("1"): string([]byte{1}), + sessionIDString("3"): string([]byte{1}), + sessionIDString("4"): string([]byte{1}), + }, + towerIDString(3): map[string]interface{}{ + sessionIDString("5"): string([]byte{1}), + }, + towerIDString(6): map[string]interface{}{ + sessionIDString("2"): string([]byte{1}), + }, + } +) + +// TestMigrateTowerToSessionIndex tests that the TestMigrateTowerToSessionIndex +// function correctly adds a new towerID-to-sessionID index to the tower client +// db. +func TestMigrateTowerToSessionIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailNoSessionBody, + post: nil, + }, + { + name: "no sessions", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a sessions bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, cSessionBkt, test.pre, + ) + } + + // After the migration, we should have an untouched + // sessions bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, cSessionBkt, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, cTowerIDToSessionIDIndexBkt, + test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateTowerToSessionIndex, + test.shouldFail, + ) + }) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return string(sessID[:]) +} + +func clientSessionString(s *ClientSessionBody) string { + var b bytes.Buffer + err := s.Encode(&b) + if err != nil { + panic(err) + } + + return b.String() +} + +func towerIDString(id int) string { + towerID := TowerID(id) + return string(towerID.Bytes()) +} diff --git a/watchtower/wtdb/migration1/codec.go b/watchtower/wtdb/migration1/codec.go new file mode 100644 index 000000000..8c5a2299c --- /dev/null +++ b/watchtower/wtdb/migration1/codec.go @@ -0,0 +1,241 @@ +package migration1 + +import ( + "encoding/binary" + "io" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// UnknownElementType is an alias for channeldb.UnknownElementType. +type UnknownElementType = channeldb.UnknownElementType + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// TowerID is a unique 64-bit identifier allocated to each unique watchtower. +// This allows the client to conserve on-disk space by not needing to always +// reference towers by their pubkey. +type TowerID uint64 + +// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order. +func (id TowerID) Bytes() []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(id)) + return buf[:] +} + +// ClientSession encapsulates a SessionInfo returned from a successful +// session negotiation, and also records the tower and ephemeral secret used for +// communicating with the tower. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID SessionID + ClientSessionBody +} + +// CSessionStatus is a bit-field representing the possible statuses of +// ClientSessions. +type CSessionStatus uint8 + +type ClientSessionBody struct { + // SeqNum is the next unallocated sequence number that can be sent to + // the tower. + SeqNum uint16 + + // TowerLastApplied the last last-applied the tower has echoed back. + TowerLastApplied uint16 + + // TowerID is the unique, db-assigned identifier that references the + // Tower with which the session is negotiated. + TowerID TowerID + + // KeyIndex is the index of key locator used to derive the client's + // session key so that it can authenticate with the tower to update its + // session. In order to rederive the private key, the key locator should + // use the keychain.KeyFamilyTowerSession key family. + KeyIndex uint32 + + // Policy holds the negotiated session parameters. + Policy wtpolicy.Policy + + // Status indicates the current state of the ClientSession. + Status CSessionStatus + + // RewardPkScript is the pkscript that the tower's reward will be + // deposited to if a sweep transaction confirms and the sessions + // specifies a reward output. + RewardPkScript []byte +} + +// Encode writes a ClientSessionBody to the passed io.Writer. +func (s *ClientSessionBody) Encode(w io.Writer) error { + return WriteElements(w, + s.SeqNum, + s.TowerLastApplied, + uint64(s.TowerID), + s.KeyIndex, + uint8(s.Status), + s.Policy, + s.RewardPkScript, + ) +} + +// Decode reads a ClientSessionBody from the passed io.Reader. +func (s *ClientSessionBody) Decode(r io.Reader) error { + var ( + towerID uint64 + status uint8 + ) + err := ReadElements(r, + &s.SeqNum, + &s.TowerLastApplied, + &towerID, + &s.KeyIndex, + &status, + &s.Policy, + &s.RewardPkScript, + ) + if err != nil { + return err + } + + s.TowerID = TowerID(towerID) + s.Status = CSessionStatus(status) + + return nil +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + err := channeldb.WriteElement(w, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case SessionID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case blob.BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wtpolicy.Policy: + return channeldb.WriteElements(w, + uint16(e.BlobType), + e.MaxUpdates, + e.RewardBase, + e.RewardRate, + uint64(e.SweepFeeRate), + ) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "WriteElement", element, + ) + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + err := channeldb.ReadElement(r, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case *SessionID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *blob.BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wtpolicy.Policy: + var ( + blobType uint16 + sweepFeeRate uint64 + ) + err := channeldb.ReadElements(r, + &blobType, + &e.MaxUpdates, + &e.RewardBase, + &e.RewardRate, + &sweepFeeRate, + ) + if err != nil { + return err + } + + e.BlobType = blob.Type(blobType) + e.SweepFeeRate = chainfee.SatPerKWeight(sweepFeeRate) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "ReadElement", element, + ) + } + + return nil +} diff --git a/watchtower/wtdb/migration1/log.go b/watchtower/wtdb/migration1/log.go new file mode 100644 index 000000000..1dc105280 --- /dev/null +++ b/watchtower/wtdb/migration1/log.go @@ -0,0 +1,14 @@ +package migration1 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/tower_db_test.go b/watchtower/wtdb/tower_db_test.go index 177dbd233..9459f34d3 100644 --- a/watchtower/wtdb/tower_db_test.go +++ b/watchtower/wtdb/tower_db_test.go @@ -3,7 +3,6 @@ package wtdb_test import ( "bytes" "encoding/binary" - "reflect" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -14,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) var ( @@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) { h.t.Helper() err := h.db.InsertSessionInfo(s) - if err != expErr { - h.t.Fatalf("expected insert session error: %v, got : %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // getSession retrieves the session identified by id, asserting that the call @@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID, h.t.Helper() session, err := h.db.GetSessionInfo(id) - if err != expErr { - h.t.Fatalf("expected get session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return session } @@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate, h.t.Helper() lastApplied, err := h.db.InsertStateUpdate(s) - if err != expErr { - h.t.Fatalf("expected insert update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) { h.t.Helper() err := h.db.DeleteSession(id) - if err != expErr { - h.t.Fatalf("expected deletion error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // queryMatches queries that database for the passed breach hint, returning all @@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match { h.t.Helper() matches, err := h.db.QueryMatches([]blob.BreachHint{hint}) - if err != nil { - h.t.Fatalf("unable to query matches: %v", err) - } + require.NoError(h.t, err) return matches } @@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match { h.t.Helper() matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected 1 match, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) match := matches[0] - if match.Hint != hint { - h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) return match } @@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) { h.insertSession(session, nil) session2 := h.getSession(&id, nil) - - if !reflect.DeepEqual(session, session2) { - h.t.Fatalf("expected session: %v, got %v", - session, session2) - } + require.Equal(h.t, session, session2) h.insertSession(session, nil) @@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) { // Query the db for matches on the chosen hint. matches := h.queryMatches(hint) - if len(matches) != numUpdates { - h.t.Fatalf("num updates mismatch, want: %d, got: %d", - numUpdates, len(matches)) - } + require.Len(h.t, matches, numUpdates) // Assert that the hints are what we asked for, and compute the set of // sessions returned. sessions := make(map[wtdb.SessionID]struct{}) for _, match := range matches { - if match.Hint != hint { - h.t.Fatalf("hint mismatch, want: %v, got: %v", - hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) sessions[match.ID] = struct{}{} } // Assert that the sessions returned match the session ids of the // sessions we initially created. for i := 0; i < numUpdates; i++ { - if _, ok := sessions[*id(i)]; !ok { - h.t.Fatalf("match for session %v not found", *id(i)) - } + _, ok := sessions[*id(i)] + require.Truef(h.t, ok, "match for session %v not found", *id(i)) } } @@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) { func testLookoutTip(h *towerDBHarness) { // Retrieve lookout tip on fresh db. epoch, err := h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) // Assert that the epoch is nil. - if epoch != nil { - h.t.Fatalf("lookout tip should not be set, found: %v", epoch) - } + require.Nil(h.t, epoch) // Create a closure that inserts an epoch, retrieves it, and asserts // that the returned epoch matches what was inserted. setAndCheck := func(i int) { expEpoch := epochFromInt(1) err = h.db.SetLookoutTip(expEpoch) - if err != nil { - h.t.Fatalf("unable to set lookout tip: %v", err) - } + require.NoError(h.t, err) epoch, err = h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) - if !reflect.DeepEqual(epoch, expEpoch) { - h.t.Fatalf("lookout tip mismatch, want: %v, got: %v", - expEpoch, epoch) - } + require.Equal(h.t, expEpoch, epoch) } // Set and assert the lookout tip. @@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) { // Assert that only one update is still present. matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected one update, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) // Assert that the update belongs to the first session. - if matches[0].ID != *id0 { - h.t.Fatalf("expected match for %v, instead is for: %v", - *id0, matches[0].ID) - } + require.Equal(h.t, *id0, matches[0].ID) // Finally, remove the first session added. h.deleteSession(*id0, nil) @@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) { // No matches should exist for this hint. matches = h.queryMatches(hint) - if len(matches) != 0 { - h.t.Fatalf("expected zero updates, found: %d", len(matches)) - } + require.Zero(h.t, len(matches)) } type stateUpdateTest struct { @@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { *expSession = *test.session } - if len(test.updates) != len(test.updateErrs) { - h.t.Fatalf("malformed test case, num updates " + - "should match num errors") - } + require.Lenf(h.t, test.updates, len(test.updateErrs), + "malformed test case, num updates should match num "+ + "errors") // Send any updates provided in the test. for i, update := range test.updates { @@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { expSession.ClientLastApplied = update.LastApplied match := h.hasUpdate(update.Hint) - if !reflect.DeepEqual(match.SessionInfo, expSession) { - h.t.Fatalf("expected session: %v, got: %v", - expSession, match.SessionInfo) - } + require.Equal(h.t, expSession, match.SessionInfo) } } } @@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() // Open the db again, ensuring we test a @@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) { bdb, err = wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index 229b8a9dd..4785b0ae2 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // migration is a function which takes a prior outdated version of the database @@ -24,7 +25,11 @@ var towerDBVersions = []version{} // clientDBVersions stores all versions and migrations of the client database. // This list will be used when opening the database to determine if any // migrations must be applied. -var clientDBVersions = []version{} +var clientDBVersions = []version{ + { + migration: migration1.MigrateTowerToSessionIndex, + }, +} // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 { diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 28dafd04c..2a3825e87 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions( if tower != nil && *tower != session.TowerID { continue } + session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session }