Merge pull request #6972 from ellemouton/wtclientTowerDb

watchtower: add towerID-to-sessionID index
This commit is contained in:
Oliver Gugger 2022-10-13 13:09:12 +02:00 committed by GitHub
commit 707546e2f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 999 additions and 409 deletions

View File

@ -102,6 +102,14 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019).
* [The `tlv` package now allows decoding records larger than 65535 bytes. The
caller is expected to know that doing so with untrusted input is
unsafe.](https://github.com/lightningnetwork/lnd/pull/6779)
## Watchtowers
* [Create a towerID-to-sessionID index in the wtclient DB to improve the
speed of listing sessions for a particular tower ID](
https://github.com/lightningnetwork/lnd/pull/6972). This PR also ensures a
closer coupling of Towers and Sessions and ensures that a session cannot be
added if the tower it is referring to does not exist.
* [Create a helper function to wait for peer to come
online](https://github.com/lightningnetwork/lnd/pull/6931).

View File

@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) {
}
plog := build.NewPrefixLog(prefix, log)
// Next, load all candidate sessions and towers from the database into
// the client. We will use any of these session if their policies match
// Next, load all candidate towers and sessions from the database into
// the client. We will use any of these sessions if their policies match
// the current policy of the client, otherwise they will be ignored and
// new sessions will be requested.
isAnchorClient := cfg.Policy.IsAnchorChannel()
activeSessionFilter := genActiveSessionFilter(isAnchorClient)
candidateSessions, err := getClientSessions(
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
candidateTowers := newTowerListIterator()
perActiveTower := func(tower *wtdb.Tower) {
// If the tower has already been marked as active, then there is
// no need to add it to the iterator again.
if candidateTowers.IsActive(tower.ID) {
return
}
log.Infof("Using private watchtower %s, offering policy %s",
tower, cfg.Policy)
// Add the tower to the set of candidate towers.
candidateTowers.AddCandidate(tower)
}
candidateSessions, err := getTowerAndSessionCandidates(
cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
)
if err != nil {
return nil, err
}
var candidateTowers []*wtdb.Tower
for _, s := range candidateSessions {
plog.Infof("Using private watchtower %s, offering policy %s",
s.Tower, cfg.Policy)
candidateTowers = append(candidateTowers, s.Tower)
}
// Load the sweep pkscripts that have been generated for all previously
// registered channels.
chanSummaries, err := cfg.DB.FetchChanSummaries()
@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) {
cfg: cfg,
log: plog,
pipeline: newTaskPipeline(plog),
candidateTowers: newTowerListIterator(candidateTowers...),
candidateTowers: candidateTowers,
candidateSessions: candidateSessions,
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
@ -349,13 +356,62 @@ func New(config *Config) (*TowerClient, error) {
return c, nil
}
// getTowerAndSessionCandidates loads all the towers from the DB and then
// fetches the sessions for each of tower. Sessions are only collected if they
// pass the sessionFilter check. If a tower has a session that does pass the
// sessionFilter check then the perActiveTower call-back will be called on that
// tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter func(*wtdb.ClientSession) bool,
perActiveTower func(tower *wtdb.Tower)) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
towers, err := db.ListTowers()
if err != nil {
return nil, err
}
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, tower := range towers {
sessions, err := db.ListClientSessions(&tower.ID)
if err != nil {
return nil, err
}
for _, s := range sessions {
towerKeyDesc, err := keyRing.DeriveKey(
keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex,
},
)
if err != nil {
return nil, err
}
s.SessionKeyECDH = keychain.NewPubKeyECDH(
towerKeyDesc, keyRing,
)
if !sessionFilter(s) {
continue
}
// Add the session to the set of candidate sessions.
candidateSessions[s.ID] = s
perActiveTower(tower)
}
}
return candidateSessions, nil
}
// getClientSessions retrieves the client sessions for a particular tower if
// specified, otherwise all client sessions for all towers are retrieved. An
// optional filter can be provided to filter out any undesired client sessions.
//
// NOTE: This method should only be used when deserialization of a
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
// existing ListClientSessions method should be used.
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
@ -371,12 +427,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// requests. This prevents us from having to store the private keys on
// disk.
for _, s := range sessions {
tower, err := db.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
s.Tower = tower
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex,

View File

@ -62,7 +62,8 @@ type DB interface {
// still be able to accept state updates. An optional tower ID can be
// used to filter out any client sessions in the response that do not
// correspond to this tower.
ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error)
ListClientSessions(*wtdb.TowerID) (
map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries.
@ -96,8 +97,8 @@ type DB interface {
AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
}
// AuthDialer connects to a remote node using an authenticated transport, such as
// brontide. The dialer argument is used to specify a resolver, which allows
// AuthDialer connects to a remote node using an authenticated transport, such
// as brontide. The dialer argument is used to specify a resolver, which allows
// this method to be used over Tor or clear net connections.
type AuthDialer func(localKey keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress,

View File

@ -48,6 +48,12 @@ var (
// tower-pubkey -> tower-id.
cTowerIndexBkt = []byte("client-tower-index-bucket")
// cTowerToSessionIndexBkt is a top-level bucket storing:
// tower-id -> session-id -> 1
cTowerToSessionIndexBkt = []byte(
"client-tower-to-session-index-bucket",
)
// ErrTowerNotFound signals that the target tower was not found in the
// database.
ErrTowerNotFound = errors.New("tower not found")
@ -113,7 +119,8 @@ var (
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
// the watchtower database.
func NewBoltBackendCreator(active bool, dbPath,
dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) {
dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend,
error) {
// If the watchtower client isn't active, we return a function that
// always returns a nil DB to make sure we don't create empty database
@ -195,6 +202,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error {
cSessionBkt,
cTowerBkt,
cTowerIndexBkt,
cTowerToSessionIndexBkt,
}
for _, bucket := range buckets {
@ -259,6 +267,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
return ErrUninitializedDB
}
towerToSessionIndex := tx.ReadWriteBucket(
cTowerToSessionIndexBkt,
)
if towerToSessionIndex == nil {
return ErrUninitializedDB
}
// Check if the tower index already knows of this pubkey.
towerIDBytes := towerIndex.Get(towerPubKey[:])
if len(towerIDBytes) == 8 {
@ -278,27 +293,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// If there are any client sessions that correspond to
// this tower, we'll mark them as active to ensure we
// load them upon restarts.
//
// TODO(wilmer): with an index of tower -> sessions we
// can avoid the linear lookup.
towerSessIndex := towerToSessionIndex.NestedReadBucket(
tower.ID.Bytes(),
)
if towerSessIndex == nil {
return ErrTowerNotFound
}
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(
sessions, &towerID,
)
if err != nil {
return err
}
for _, session := range towerSessions {
err := markSessionStatus(
sessions, session, CSessionActive,
err = towerSessIndex.ForEach(func(k, _ []byte) error {
session, err := getClientSessionBody(
sessions, k,
)
if err != nil {
return err
}
return markSessionStatus(
sessions, session, CSessionActive,
)
})
if err != nil {
return err
}
} else {
// No such tower exists, create a new tower id for our
@ -320,6 +340,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
if err != nil {
return err
}
// Create a new bucket for this tower in the
// tower-to-sessions index.
_, err = towerToSessionIndex.CreateBucket(towerIDBytes)
if err != nil {
return err
}
}
// Store the new or updated tower under its tower id.
@ -348,11 +375,19 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
if towers == nil {
return ErrUninitializedDB
}
towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
if towerIndex == nil {
return ErrUninitializedDB
}
towersToSessionsIndex := tx.ReadWriteBucket(
cTowerToSessionIndexBkt,
)
if towersToSessionsIndex == nil {
return ErrUninitializedDB
}
// Don't return an error if the watchtower doesn't exist to act
// as a NOP.
pubKeyBytes := pubKey.SerializeCompressed()
@ -380,15 +415,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// Otherwise, we should attempt to mark the tower's sessions as
// inactive.
//
// TODO(wilmer): with an index of tower -> sessions we can avoid
// the linear lookup.
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(sessions, &towerID)
towerSessions, err := listTowerSessions(
towerID, sessions, towers, towersToSessionsIndex,
)
if err != nil {
return err
}
@ -399,7 +433,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
if err := towerIndex.Delete(pubKeyBytes); err != nil {
return err
}
return towers.Delete(towerIDBytes)
if err := towers.Delete(towerIDBytes); err != nil {
return err
}
return towersToSessionsIndex.DeleteNestedBucket(
towerIDBytes,
)
}
// We'll mark its sessions as inactive as long as they don't
@ -573,14 +614,34 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
towerToSessionIndex := tx.ReadWriteBucket(
cTowerToSessionIndexBkt,
)
if towerToSessionIndex == nil {
return ErrUninitializedDB
}
// Check that client session with this session id doesn't
// already exist.
existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:])
existingSessionBytes := sessions.NestedReadWriteBucket(
session.ID[:],
)
if existingSessionBytes != nil {
return ErrClientSessionAlreadyExists
}
// Ensure that a tower with the given ID actually exists in the
// DB.
towerID := session.TowerID
if _, err := getTower(towers, towerID.Bytes()); err != nil {
return err
}
blobType := session.Policy.BlobType
// Check that this tower has a reserved key index.
@ -609,6 +670,19 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
}
}
// Add the new entry to the towerID-to-SessionID index.
indexBkt := towerToSessionIndex.NestedReadWriteBucket(
towerID.Bytes(),
)
if indexBkt == nil {
return ErrTowerNotFound
}
err = indexBkt.Put(session.ID[:], []byte{1})
if err != nil {
return err
}
// Finally, write the client session's body in the sessions
// bucket.
return putClientSessionBody(sessions, session)
@ -662,15 +736,41 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) {
func (c *ClientDB) ListClientSessions(id *TowerID) (
map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
var err error
clientSessions, err = listClientSessions(sessions, id)
// If no tower ID is specified, then fetch all the sessions
// known to the db.
if id == nil {
clientSessions, err = listClientAllSessions(
sessions, towers,
)
return err
}
// Otherwise, fetch the sessions for the given tower.
towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt)
if towerToSessionIndex == nil {
return ErrUninitializedDB
}
clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex,
)
return err
}, func() {
clientSessions = nil
@ -682,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
return clientSessions, nil
}
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func listClientSessions(sessions kvdb.RBucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
// listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions,
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
@ -694,19 +792,45 @@ func listClientSessions(sessions kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessions, k)
session, err := getClientSession(sessions, towers, k)
if err != nil {
return err
}
// Filter out any sessions that don't correspond to the given
// tower if one was set.
if id != nil && session.TowerID != *id {
return nil
clientSessions[session.ID] = session
return nil
})
if err != nil {
return nil, err
}
return clientSessions, nil
}
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
return nil, ErrTowerNotFound
}
clientSessions := make(map[SessionID]*ClientSession)
err := towerIndexBkt.ForEach(func(k, _ []byte) error {
// We'll load the full client session since the client will need
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessionsBkt, towersBkt, k)
if err != nil {
return err
}
clientSessions[session.ID] = session
return nil
})
if err != nil {
@ -951,7 +1075,9 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// If the commits sub-bucket doesn't exist, there can't possibly
// be a corresponding committed update to remove.
sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits)
sessionCommits := sessionBkt.NestedReadWriteBucket(
cSessionCommits,
)
if sessionCommits == nil {
return ErrCommittedUpdateNotFound
}
@ -1004,8 +1130,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller
// requires this info, use getClientSession.
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
// If the caller requires this info, use getClientSession.
func getClientSessionBody(sessions kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
@ -1032,9 +1158,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in
// addition to the ClientSession's body.
func getClientSession(sessions kvdb.RBucket,
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes)
@ -1042,6 +1168,12 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
// Fetch the tower associated with this session.
tower, err := getTower(towers, session.TowerID.Bytes())
if err != nil {
return nil, err
}
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil {
@ -1054,6 +1186,7 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
session.Tower = tower
session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates

View File

@ -1,11 +1,9 @@
package wtdb_test
import (
"bytes"
crand "crypto/rand"
"io"
"net"
"reflect"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
@ -16,8 +14,12 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
// pseudoAddr is a fake network address to be used for testing purposes.
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
// clientDBInit is a closure used to initialize a wtclient.DB instance.
type clientDBInit func(t *testing.T) wtclient.DB
@ -37,23 +39,22 @@ func newClientDBHarness(t *testing.T, init clientDBInit) *clientDBHarness {
return h
}
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
expErr error) {
h.t.Helper()
err := h.db.CreateClientSession(session)
if err != expErr {
h.t.Fatalf("expected create client session error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
func (h *clientDBHarness) listSessions(
id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper()
sessions, err := h.db.ListClientSessions(id)
if err != nil {
h.t.Fatalf("unable to list client sessions: %v", err)
}
require.NoError(h.t, err, "unable to list client sessions")
return sessions
}
@ -64,13 +65,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
h.t.Helper()
index, err := h.db.NextSessionKeyIndex(id, blobType)
if err != nil {
h.t.Fatalf("unable to create next session key index: %v", err)
}
if index == 0 {
h.t.Fatalf("next key index should never be 0")
}
require.NoError(h.t, err, "unable to create next session key index")
require.NotZero(h.t, index, "next key index should never be 0")
return index
}
@ -81,20 +77,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
h.t.Helper()
tower, err := h.db.CreateTower(lnAddr)
if err != expErr {
h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
}
if tower.ID == 0 {
h.t.Fatalf("tower id should never be 0")
}
require.ErrorIs(h.t, err, expErr)
require.NotZero(h.t, tower.ID, "tower id should never be 0")
for _, session := range h.listSessions(&tower.ID) {
if session.Status != wtdb.CSessionActive {
h.t.Fatalf("expected status for session %v to be %v, "+
"got %v", session.ID, wtdb.CSessionActive,
session.Status)
}
require.Equal(h.t, wtdb.CSessionActive, session.Status)
}
return tower
@ -105,68 +92,64 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
h.t.Helper()
if err := h.db.RemoveTower(pubKey, addr); err != expErr {
h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err)
}
err := h.db.RemoveTower(pubKey, addr)
require.ErrorIs(h.t, err, expErr)
if expErr != nil {
return
}
pubKeyStr := pubKey.SerializeCompressed()
if addr != nil {
tower, err := h.db.LoadTower(pubKey)
if err != nil {
h.t.Fatalf("expected tower %x to still exist",
pubKey.SerializeCompressed())
}
require.NoErrorf(h.t, err, "expected tower %x to still exist",
pubKeyStr)
removedAddr := addr.String()
for _, towerAddr := range tower.Addresses {
if towerAddr.String() == removedAddr {
h.t.Fatalf("address %v not removed for tower %x",
removedAddr, pubKey.SerializeCompressed())
}
require.NotEqualf(h.t, removedAddr, towerAddr,
"address %v not removed for tower %x",
removedAddr, pubKeyStr)
}
} else {
tower, err := h.db.LoadTower(pubKey)
if hasSessions && err != nil {
h.t.Fatalf("expected tower %x with sessions to still "+
"exist", pubKey.SerializeCompressed())
}
if !hasSessions && err == nil {
h.t.Fatalf("expected tower %x with no sessions to not "+
"exist", pubKey.SerializeCompressed())
}
if !hasSessions {
if hasSessions {
require.NoError(h.t, err, "expected tower %x with "+
"sessions to still exist", pubKeyStr)
} else {
require.Errorf(h.t, err, "expected tower %x with no "+
"sessions to not exist", pubKeyStr)
return
}
for _, session := range h.listSessions(&tower.ID) {
if session.Status != wtdb.CSessionInactive {
h.t.Fatalf("expected status for session %v to "+
"be %v, got %v", session.ID,
wtdb.CSessionInactive, session.Status)
}
require.Equal(h.t, wtdb.CSessionInactive,
session.Status, "expected status for session "+
"%v to be %v, got %v", session.ID,
wtdb.CSessionInactive, session.Status)
}
}
}
func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, expErr error) *wtdb.Tower {
func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey,
expErr error) *wtdb.Tower {
h.t.Helper()
tower, err := h.db.LoadTower(pubKey)
if err != expErr {
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return tower
}
func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower {
func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID,
expErr error) *wtdb.Tower {
h.t.Helper()
tower, err := h.db.LoadTowerByID(id)
if err != expErr {
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return tower
}
@ -175,9 +158,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC
h.t.Helper()
summaries, err := h.db.FetchChanSummaries()
if err != nil {
h.t.Fatalf("unable to fetch chan summaries: %v", err)
}
require.NoError(h.t, err)
return summaries
}
@ -188,10 +169,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
h.t.Helper()
err := h.db.RegisterChannel(chanID, sweepPkScript)
if err != expErr {
h.t.Fatalf("expected register channel error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
@ -200,10 +178,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
h.t.Helper()
lastApplied, err := h.db.CommitUpdate(id, update)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return lastApplied
}
@ -214,10 +189,22 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
h.t.Helper()
err := h.db.AckUpdate(id, seqNum, lastApplied)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// newTower is a helper function that creates a new tower with a randomly
// generated public key and inserts it into the client DB.
func (h *clientDBHarness) newTower() *wtdb.Tower {
h.t.Helper()
pk, err := randPubKey()
require.NoError(h.t, err)
// Insert a random tower into the database.
return h.createTower(&lnwire.NetAddress{
IdentityKey: pk,
Address: pseudoAddr,
}, nil)
}
// testCreateClientSession asserts various conditions regarding the creation of
@ -228,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
func testCreateClientSession(h *clientDBHarness) {
const blobType = blob.TypeAltruistAnchorCommit
tower := h.newTower()
// Create a test client session to insert.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -245,9 +234,9 @@ func testCreateClientSession(h *clientDBHarness) {
// First, assert that this session is not already present in the
// database.
if _, ok := h.listSessions(nil)[session.ID]; ok {
h.t.Fatalf("session for id %x should not exist yet", session.ID)
}
_, ok := h.listSessions(nil)[session.ID]
require.Falsef(h.t, ok, "session for id %x should not exist yet",
session.ID)
// Attempting to insert the client session without reserving a session
// key index should fail.
@ -264,10 +253,8 @@ func testCreateClientSession(h *clientDBHarness) {
// successfully created, it should return the same index to maintain
// idempotency across restarts.
keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
if keyIndex != keyIndex2 {
h.t.Fatalf("next key index should be idempotent: want: %v, "+
"got %v", keyIndex, keyIndex2)
}
require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+
"be idempotent: want: %v, got %v", keyIndex, keyIndex2)
// Now, set the client session's key index so that it is proper and
// insert it. This should succeed.
@ -275,9 +262,8 @@ func testCreateClientSession(h *clientDBHarness) {
h.insertSession(session, nil)
// Verify that the session now exists in the database.
if _, ok := h.listSessions(nil)[session.ID]; !ok {
h.t.Fatalf("session for id %x should exist now", session.ID)
}
_, ok = h.listSessions(nil)[session.ID]
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
// Attempt to insert the session again, which should fail due to the
// session already existing.
@ -286,9 +272,8 @@ func testCreateClientSession(h *clientDBHarness) {
// Finally, assert that reserving another key index succeeds with a
// different key index, now that the first one has been finalized.
keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
if keyIndex == keyIndex3 {
h.t.Fatalf("key index still reserved after creating session")
}
require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+
"reserved after creating session")
}
// testFilterClientSessions asserts that we can correctly filter client sessions
@ -300,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
for i := 0; i < numSessions; i++ {
towerID := wtdb.TowerID(1)
if i == numSessions-1 {
towerID = wtdb.TowerID(2)
}
keyIndex := h.nextKeyIndex(towerID, blobType)
tower := h.newTower()
keyIndex := h.nextKeyIndex(tower.ID, blobType)
sessionID := wtdb.SessionID([33]byte{byte(i)})
h.insertSession(&wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: towerID,
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -320,22 +302,21 @@ func testFilterClientSessions(h *clientDBHarness) {
},
ID: sessionID,
}, nil)
towerSessions[towerID] = append(towerSessions[towerID], sessionID)
towerSessions[tower.ID] = append(
towerSessions[tower.ID], sessionID,
)
}
// We should see the expected sessions for each tower when filtering
// them.
for towerID, expectedSessions := range towerSessions {
sessions := h.listSessions(&towerID)
if len(sessions) != len(expectedSessions) {
h.t.Fatalf("expected %v sessions for tower %v, got %v",
len(expectedSessions), towerID, len(sessions))
}
require.Len(h.t, sessions, len(expectedSessions))
for _, expectedSession := range expectedSessions {
if _, ok := sessions[expectedSession]; !ok {
h.t.Fatalf("expected session %v for tower %v",
expectedSession, towerID)
}
_, ok := sessions[expectedSession]
require.Truef(h.t, ok, "expected session %v for "+
"tower %v", expectedSession, towerID)
}
}
}
@ -347,49 +328,31 @@ func testCreateTower(h *clientDBHarness) {
// Test that loading a tower with an arbitrary tower id fails.
h.loadTowerByID(20, wtdb.ErrTowerNotFound)
pk, err := randPubKey()
if err != nil {
h.t.Fatalf("unable to generate pubkey: %v", err)
}
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr := &lnwire.NetAddress{
IdentityKey: pk,
Address: addr1,
}
// Insert a random tower into the database.
tower := h.createTower(lnAddr, nil)
tower := h.newTower()
require.Len(h.t, tower.LNAddrs(), 1)
towerAddr := tower.LNAddrs()[0]
// Load the tower from the database and assert that it matches the tower
// we created.
tower2 := h.loadTowerByID(tower.ID, nil)
if !reflect.DeepEqual(tower, tower2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
tower, tower2)
}
tower2 = h.loadTower(pk, err)
if !reflect.DeepEqual(tower, tower2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
tower, tower2)
}
require.Equal(h.t, tower, tower2)
tower2 = h.loadTower(tower.IdentityKey, nil)
require.Equal(h.t, tower, tower2)
// Insert the address again into the database. Since the address is the
// same, this should result in an unmodified tower record.
towerDupAddr := h.createTower(lnAddr, nil)
if len(towerDupAddr.Addresses) != 1 {
h.t.Fatalf("duplicate address should be deduped")
}
if !reflect.DeepEqual(tower, towerDupAddr) {
h.t.Fatalf("mismatch towers, want: %v, got: %v",
tower, towerDupAddr)
}
towerDupAddr := h.createTower(towerAddr, nil)
require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+
"should be deduped")
require.Equal(h.t, tower, towerDupAddr)
// Generate a new address for this tower.
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr2 := &lnwire.NetAddress{
IdentityKey: pk,
IdentityKey: tower.IdentityKey,
Address: addr2,
}
@ -400,26 +363,18 @@ func testCreateTower(h *clientDBHarness) {
// Load the tower from the database, and assert that it matches the
// tower returned from creation.
towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
towerNewAddr, towerNewAddr2)
}
towerNewAddr2 = h.loadTower(pk, nil)
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
towerNewAddr, towerNewAddr2)
}
require.Equal(h.t, towerNewAddr, towerNewAddr2)
towerNewAddr2 = h.loadTower(tower.IdentityKey, nil)
require.Equal(h.t, towerNewAddr, towerNewAddr2)
// Assert that there are now two addresses on the tower object.
if len(towerNewAddr.Addresses) != 2 {
h.t.Fatalf("new address should be added")
}
require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+
"added")
// Finally, assert that the new address was prepended since it is deemed
// fresher.
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
h.t.Fatalf("new address should be prepended")
}
require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:])
}
// testRemoveTower asserts the behavior of removing Tower objects as a whole and
@ -427,9 +382,7 @@ func testCreateTower(h *clientDBHarness) {
func testRemoveTower(h *clientDBHarness) {
// Generate a random public key we'll use for our tower.
pk, err := randPubKey()
if err != nil {
h.t.Fatalf("unable to generate pubkey: %v", err)
}
require.NoError(h.t, err)
// Removing a tower that does not exist within the database should
// result in a NOP.
@ -507,28 +460,23 @@ func testRemoveTower(h *clientDBHarness) {
func testChanSummaries(h *clientDBHarness) {
// First, assert that this channel is not already registered.
var chanID lnwire.ChannelID
if _, ok := h.fetchChanSummaries()[chanID]; ok {
h.t.Fatalf("pkscript for channel %x should not exist yet",
chanID)
}
_, ok := h.fetchChanSummaries()[chanID]
require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet",
chanID)
// Generate a random sweep pkscript and register it for this channel.
expPkScript := make([]byte, 22)
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
h.t.Fatalf("unable to generate pkscript: %v", err)
}
_, err := io.ReadFull(crand.Reader, expPkScript)
require.NoError(h.t, err)
h.registerChan(chanID, expPkScript, nil)
// Assert that the channel exists and that its sweep pkscript matches
// the one we registered.
summary, ok := h.fetchChanSummaries()[chanID]
if !ok {
h.t.Fatalf("pkscript for channel %x should not exist yet",
chanID)
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
expPkScript, summary.SweepPkScript)
}
require.Truef(h.t, ok, "pkscript for channel %x should not exist yet",
chanID)
require.Equal(h.t, expPkScript, summary.SweepPkScript)
// Finally, assert that re-registering the same channel produces a
// failure.
@ -538,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) {
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
func testCommitUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
tower := h.newTower()
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -565,10 +515,7 @@ func testCommitUpdate(h *clientDBHarness) {
// succeed. The lastApplied value should be 0 since we have not received
// an ack from the tower.
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
require.Zero(h.t, lastApplied)
// Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no
@ -584,10 +531,7 @@ func testCommitUpdate(h *clientDBHarness) {
// the on-disk update's hint). The lastApplied value should remain
// unchanged.
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
if lastApplied2 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied2)
}
require.Equal(h.t, lastApplied, lastApplied2)
// Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions(nil)[session.ID]
@ -605,10 +549,7 @@ func testCommitUpdate(h *clientDBHarness) {
// which should succeed.
update2.SeqNum = 2
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
if lastApplied3 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied3)
}
require.Equal(h.t, lastApplied, lastApplied3)
// Check that both updates now appear as committed on the ClientSession
// loaded from disk.
@ -638,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) {
func testAckUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
tower := h.newTower()
// Create a new session that the updates in this will be tied to.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -668,10 +611,7 @@ func testAckUpdate(h *clientDBHarness) {
// Commit to a random update at seqnum 1.
update1 := randCommittedUpdate(h.t, 1)
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
require.Zero(h.t, lastApplied)
// Acking seqnum 1 should succeed.
h.ackUpdate(&session.ID, 1, 1, nil)
@ -699,10 +639,7 @@ func testAckUpdate(h *clientDBHarness) {
// ack.
update2 := randCommittedUpdate(h.t, 2)
lastApplied = h.commitUpdate(&session.ID, update2, nil)
if lastApplied != 1 {
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
lastApplied)
}
require.EqualValues(h.t, 1, lastApplied)
// Ack seqnum 2.
h.ackUpdate(&session.ID, 2, 2, nil)
@ -740,10 +677,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make([]wtdb.CommittedUpdate, 0)
}
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
t.Fatalf("committed updates mismatch, want: %v, got: %v",
expUpdates, session.CommittedUpdates)
}
require.Equal(t, expUpdates, session.CommittedUpdates)
}
// checkAckedUpdates asserts that the AckedUpdates on a session match the
@ -758,10 +692,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make(map[uint16]wtdb.BackupID)
}
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
t.Fatalf("acked updates mismatch, want: %v, got: %v",
expUpdates, session.AckedUpdates)
}
require.Equal(t, expUpdates, session.AckedUpdates)
}
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
@ -779,14 +710,10 @@ func TestClientDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@ -803,27 +730,19 @@ func TestClientDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db.Close()
bdb, err = wtdb.NewBoltBackendCreator(
true, path, "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err = wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to reopen db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@ -893,19 +812,16 @@ func TestClientDB(t *testing.T) {
// randCommittedUpdate generates a random committed update.
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
var chanID lnwire.ChannelID
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
t.Fatalf("unable to generate chan id: %v", err)
}
_, err := io.ReadFull(crand.Reader, chanID[:])
require.NoError(t, err)
var hint blob.BreachHint
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
t.Fatalf("unable to generate breach hint: %v", err)
}
_, err = io.ReadFull(crand.Reader, hint[:])
require.NoError(t, err)
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
t.Fatalf("unable to generate encrypted blob: %v", err)
}
_, err = io.ReadFull(crand.Reader, encBlob)
require.NoError(t, err)
return &wtdb.CommittedUpdate{
SeqNum: seqNum,

View File

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/stretchr/testify/require"
)
func randPubKey() (*btcec.PublicKey, error) {
@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) {
// Ensure encoding the object succeeds.
var b bytes.Buffer
err := obj.Encode(&b)
if err != nil {
t.Fatalf("unable to encode: %v", err)
return false
}
require.NoError(t, err)
var obj2 dbObject
switch obj.(type) {
@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) {
// Ensure decoding the object succeeds.
err = obj2.Decode(bytes.NewReader(b.Bytes()))
if err != nil {
t.Fatalf("unable to decode: %v", err)
return false
}
require.NoError(t, err)
// Assert the original and decoded object match.
if !reflect.DeepEqual(obj, obj2) {
t.Fatalf("encode/decode mismatch, want: %v, "+
"got: %v", obj, obj2)
return false
}
require.Equal(t, obj, obj2)
return true
}
@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) {
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
"Tower": func(v []reflect.Value, r *rand.Rand) {
pk, err := randPubKey()
if err != nil {
t.Fatalf("unable to generate pubkey: %v", err)
return
}
require.NoError(t, err)
addrs, err := randAddrs(r)
if err != nil {
t.Fatalf("unable to generate addrs: %v", err)
return
}
require.NoError(t, err)
obj := wtdb.Tower{
IdentityKey: pk,
@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) {
}
err := quick.Check(test.scenario, config)
if err != nil {
t.Fatalf("fuzz checks for msg=%s failed: %v",
test.name, err)
}
require.NoError(h, err)
})
}
}

View File

@ -3,6 +3,7 @@ package wtdb
import (
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
)
// log is a logger that is initialized with no output filters. This
@ -26,6 +27,7 @@ func DisableLog() {
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
migration1.UseLogger(logger)
}
// logClosure is used to provide a closure over expensive logging operations so

View File

@ -0,0 +1,145 @@
package migration1
import (
"bytes"
"errors"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAcks => seqnum -> encoded BackupID
cSessionBkt = []byte("client-session-bucket")
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
// the ClientSession.
cSessionBody = []byte("client-session-body")
// cTowerIDToSessionIDIndexBkt is a top-level bucket storing:
// tower-id -> session-id -> 1
cTowerIDToSessionIDIndexBkt = []byte(
"client-tower-to-session-index-bucket",
)
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("db not initialized")
// ErrClientSessionNotFound signals that the requested client session
// was not found in the database.
ErrClientSessionNotFound = errors.New("client session not found")
// ErrCorruptClientSession signals that the client session's on-disk
// structure deviates from what is expected.
ErrCorruptClientSession = errors.New("client session corrupted")
)
// MigrateTowerToSessionIndex constructs a new towerID-to-sessionID for the
// watchtower client DB.
func MigrateTowerToSessionIndex(tx kvdb.RwTx) error {
log.Infof("Migrating the tower client db to add a " +
"towerID-to-sessionID index")
// First, we collect all the entries we want to add to the index.
entries, err := getIndexEntries(tx)
if err != nil {
return err
}
// Then we create a new top-level bucket for the index.
indexBkt, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt)
if err != nil {
return err
}
// Finally, we add all the collected entries to the index.
for towerID, sessions := range entries {
// Create a sub-bucket using the tower ID.
towerBkt, err := indexBkt.CreateBucketIfNotExists(
towerID.Bytes(),
)
if err != nil {
return err
}
for sessionID := range sessions {
err := addIndex(towerBkt, sessionID)
if err != nil {
return err
}
}
}
return nil
}
// addIndex adds a new towerID-sessionID pair to the given bucket. The
// session ID is used as a key within the bucket and a value of []byte{1} is
// used for each session ID key.
func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error {
session := towerBkt.Get(sessionID[:])
if session != nil {
return fmt.Errorf("session %x duplicated", sessionID)
}
return towerBkt.Put(sessionID[:], []byte{1})
}
// getIndexEntries collects all the towerID-sessionID entries that need to be
// added to the new index.
func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return nil, ErrUninitializedDB
}
index := make(map[TowerID]map[SessionID]bool)
err := sessions.ForEach(func(k, _ []byte) error {
session, err := getClientSession(sessions, k)
if err != nil {
return err
}
if index[session.TowerID] == nil {
index[session.TowerID] = make(map[SessionID]bool)
}
index[session.TowerID][session.ID] = true
return nil
})
if err != nil {
return nil, err
}
return index, nil
}
// getClientSession fetches the session with the given ID from the db.
func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession,
error) {
sessionBkt := sessions.NestedReadBucket(idBytes)
if sessionBkt == nil {
return nil, ErrClientSessionNotFound
}
// Should never have a sessionBkt without also having its body.
sessionBody := sessionBkt.Get(cSessionBody)
if sessionBody == nil {
return nil, ErrCorruptClientSession
}
var session ClientSession
copy(session.ID[:], idBytes)
err := session.Decode(bytes.NewReader(sessionBody))
if err != nil {
return nil, err
}
return &session, nil
}

View File

@ -0,0 +1,155 @@
package migration1
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/channeldb/migtest"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
s1 = &ClientSessionBody{
TowerID: TowerID(1),
}
s2 = &ClientSessionBody{
TowerID: TowerID(3),
}
s3 = &ClientSessionBody{
TowerID: TowerID(6),
}
// pre is the expected data in the DB before the migration.
pre = map[string]interface{}{
sessionIDString("1"): map[string]interface{}{
string(cSessionBody): clientSessionString(s1),
},
sessionIDString("2"): map[string]interface{}{
string(cSessionBody): clientSessionString(s3),
},
sessionIDString("3"): map[string]interface{}{
string(cSessionBody): clientSessionString(s1),
},
sessionIDString("4"): map[string]interface{}{
string(cSessionBody): clientSessionString(s1),
},
sessionIDString("5"): map[string]interface{}{
string(cSessionBody): clientSessionString(s2),
},
}
// preFailNoSessionBody should fail the migration due to there being a
// session without an associated session body.
preFailNoSessionBody = map[string]interface{}{
sessionIDString("1"): map[string]interface{}{
string(cSessionBody): clientSessionString(s1),
},
sessionIDString("2"): map[string]interface{}{},
}
// post is the expected data after migration.
post = map[string]interface{}{
towerIDString(1): map[string]interface{}{
sessionIDString("1"): string([]byte{1}),
sessionIDString("3"): string([]byte{1}),
sessionIDString("4"): string([]byte{1}),
},
towerIDString(3): map[string]interface{}{
sessionIDString("5"): string([]byte{1}),
},
towerIDString(6): map[string]interface{}{
sessionIDString("2"): string([]byte{1}),
},
}
)
// TestMigrateTowerToSessionIndex tests that the TestMigrateTowerToSessionIndex
// function correctly adds a new towerID-to-sessionID index to the tower client
// db.
func TestMigrateTowerToSessionIndex(t *testing.T) {
tests := []struct {
name string
shouldFail bool
pre map[string]interface{}
post map[string]interface{}
}{
{
name: "migration ok",
shouldFail: false,
pre: pre,
post: post,
},
{
name: "fail due to corrupt db",
shouldFail: true,
pre: preFailNoSessionBody,
post: nil,
},
{
name: "no sessions",
shouldFail: false,
pre: nil,
post: nil,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
// Before the migration we have a sessions bucket.
before := func(tx kvdb.RwTx) error {
return migtest.RestoreDB(
tx, cSessionBkt, test.pre,
)
}
// After the migration, we should have an untouched
// sessions bucket and a new index bucket.
after := func(tx kvdb.RwTx) error {
if err := migtest.VerifyDB(
tx, cSessionBkt, test.pre,
); err != nil {
return err
}
// If we expect our migration to fail, we don't
// expect an index bucket.
if test.shouldFail {
return nil
}
return migtest.VerifyDB(
tx, cTowerIDToSessionIDIndexBkt,
test.post,
)
}
migtest.ApplyMigration(
t, before, after, MigrateTowerToSessionIndex,
test.shouldFail,
)
})
}
}
func sessionIDString(id string) string {
var sessID SessionID
copy(sessID[:], id)
return string(sessID[:])
}
func clientSessionString(s *ClientSessionBody) string {
var b bytes.Buffer
err := s.Encode(&b)
if err != nil {
panic(err)
}
return b.String()
}
func towerIDString(id int) string {
towerID := TowerID(id)
return string(towerID.Bytes())
}

View File

@ -0,0 +1,241 @@
package migration1
import (
"encoding/binary"
"io"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
const SessionIDSize = 33
// UnknownElementType is an alias for channeldb.UnknownElementType.
type UnknownElementType = channeldb.UnknownElementType
// SessionID is created from the remote public key of a client, and serves as a
// unique identifier and authentication for sending state updates.
type SessionID [SessionIDSize]byte
// TowerID is a unique 64-bit identifier allocated to each unique watchtower.
// This allows the client to conserve on-disk space by not needing to always
// reference towers by their pubkey.
type TowerID uint64
// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order.
func (id TowerID) Bytes() []byte {
var buf [8]byte
binary.BigEndian.PutUint64(buf[:], uint64(id))
return buf[:]
}
// ClientSession encapsulates a SessionInfo returned from a successful
// session negotiation, and also records the tower and ephemeral secret used for
// communicating with the tower.
type ClientSession struct {
// ID is the client's public key used when authenticating with the
// tower.
ID SessionID
ClientSessionBody
}
// CSessionStatus is a bit-field representing the possible statuses of
// ClientSessions.
type CSessionStatus uint8
type ClientSessionBody struct {
// SeqNum is the next unallocated sequence number that can be sent to
// the tower.
SeqNum uint16
// TowerLastApplied the last last-applied the tower has echoed back.
TowerLastApplied uint16
// TowerID is the unique, db-assigned identifier that references the
// Tower with which the session is negotiated.
TowerID TowerID
// KeyIndex is the index of key locator used to derive the client's
// session key so that it can authenticate with the tower to update its
// session. In order to rederive the private key, the key locator should
// use the keychain.KeyFamilyTowerSession key family.
KeyIndex uint32
// Policy holds the negotiated session parameters.
Policy wtpolicy.Policy
// Status indicates the current state of the ClientSession.
Status CSessionStatus
// RewardPkScript is the pkscript that the tower's reward will be
// deposited to if a sweep transaction confirms and the sessions
// specifies a reward output.
RewardPkScript []byte
}
// Encode writes a ClientSessionBody to the passed io.Writer.
func (s *ClientSessionBody) Encode(w io.Writer) error {
return WriteElements(w,
s.SeqNum,
s.TowerLastApplied,
uint64(s.TowerID),
s.KeyIndex,
uint8(s.Status),
s.Policy,
s.RewardPkScript,
)
}
// Decode reads a ClientSessionBody from the passed io.Reader.
func (s *ClientSessionBody) Decode(r io.Reader) error {
var (
towerID uint64
status uint8
)
err := ReadElements(r,
&s.SeqNum,
&s.TowerLastApplied,
&towerID,
&s.KeyIndex,
&status,
&s.Policy,
&s.RewardPkScript,
)
if err != nil {
return err
}
s.TowerID = TowerID(towerID)
s.Status = CSessionStatus(status)
return nil
}
// WriteElements serializes a variadic list of elements into the given
// io.Writer.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
if err := WriteElement(w, element); err != nil {
return err
}
}
return nil
}
// WriteElement serializes a single element into the provided io.Writer.
func WriteElement(w io.Writer, element interface{}) error {
err := channeldb.WriteElement(w, element)
switch {
// Known to channeldb codec.
case err == nil:
return nil
// Fail if error is not UnknownElementType.
default:
if _, ok := err.(UnknownElementType); !ok {
return err
}
}
// Process any wtdb-specific extensions to the codec.
switch e := element.(type) {
case SessionID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case blob.BreachHint:
if _, err := w.Write(e[:]); err != nil {
return err
}
case wtpolicy.Policy:
return channeldb.WriteElements(w,
uint16(e.BlobType),
e.MaxUpdates,
e.RewardBase,
e.RewardRate,
uint64(e.SweepFeeRate),
)
// Type is still unknown to wtdb extensions, fail.
default:
return channeldb.NewUnknownElementType(
"WriteElement", element,
)
}
return nil
}
// ReadElements deserializes the provided io.Reader into a variadic list of
// target elements.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
if err := ReadElement(r, element); err != nil {
return err
}
}
return nil
}
// ReadElement deserializes a single element from the provided io.Reader.
func ReadElement(r io.Reader, element interface{}) error {
err := channeldb.ReadElement(r, element)
switch {
// Known to channeldb codec.
case err == nil:
return nil
// Fail if error is not UnknownElementType.
default:
if _, ok := err.(UnknownElementType); !ok {
return err
}
}
// Process any wtdb-specific extensions to the codec.
switch e := element.(type) {
case *SessionID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *blob.BreachHint:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wtpolicy.Policy:
var (
blobType uint16
sweepFeeRate uint64
)
err := channeldb.ReadElements(r,
&blobType,
&e.MaxUpdates,
&e.RewardBase,
&e.RewardRate,
&sweepFeeRate,
)
if err != nil {
return err
}
e.BlobType = blob.Type(blobType)
e.SweepFeeRate = chainfee.SatPerKWeight(sweepFeeRate)
// Type is still unknown to wtdb extensions, fail.
default:
return channeldb.NewUnknownElementType(
"ReadElement", element,
)
}
return nil
}

View File

@ -0,0 +1,14 @@
package migration1
import (
"github.com/btcsuite/btclog"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}

View File

@ -3,7 +3,6 @@ package wtdb_test
import (
"bytes"
"encoding/binary"
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
@ -14,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
var (
@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
h.t.Helper()
err := h.db.InsertSessionInfo(s)
if err != expErr {
h.t.Fatalf("expected insert session error: %v, got : %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// getSession retrieves the session identified by id, asserting that the call
@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID,
h.t.Helper()
session, err := h.db.GetSessionInfo(id)
if err != expErr {
h.t.Fatalf("expected get session error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return session
}
@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
h.t.Helper()
lastApplied, err := h.db.InsertStateUpdate(s)
if err != expErr {
h.t.Fatalf("expected insert update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return lastApplied
}
@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
h.t.Helper()
err := h.db.DeleteSession(id)
if err != expErr {
h.t.Fatalf("expected deletion error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// queryMatches queries that database for the passed breach hint, returning all
@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match {
h.t.Helper()
matches, err := h.db.QueryMatches([]blob.BreachHint{hint})
if err != nil {
h.t.Fatalf("unable to query matches: %v", err)
}
require.NoError(h.t, err)
return matches
}
@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match {
h.t.Helper()
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected 1 match, found: %d", len(matches))
}
require.Len(h.t, matches, 1)
match := matches[0]
if match.Hint != hint {
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
}
require.Equal(h.t, hint, match.Hint)
return match
}
@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) {
h.insertSession(session, nil)
session2 := h.getSession(&id, nil)
if !reflect.DeepEqual(session, session2) {
h.t.Fatalf("expected session: %v, got %v",
session, session2)
}
require.Equal(h.t, session, session2)
h.insertSession(session, nil)
@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) {
// Query the db for matches on the chosen hint.
matches := h.queryMatches(hint)
if len(matches) != numUpdates {
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
numUpdates, len(matches))
}
require.Len(h.t, matches, numUpdates)
// Assert that the hints are what we asked for, and compute the set of
// sessions returned.
sessions := make(map[wtdb.SessionID]struct{})
for _, match := range matches {
if match.Hint != hint {
h.t.Fatalf("hint mismatch, want: %v, got: %v",
hint, match.Hint)
}
require.Equal(h.t, hint, match.Hint)
sessions[match.ID] = struct{}{}
}
// Assert that the sessions returned match the session ids of the
// sessions we initially created.
for i := 0; i < numUpdates; i++ {
if _, ok := sessions[*id(i)]; !ok {
h.t.Fatalf("match for session %v not found", *id(i))
}
_, ok := sessions[*id(i)]
require.Truef(h.t, ok, "match for session %v not found", *id(i))
}
}
@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) {
func testLookoutTip(h *towerDBHarness) {
// Retrieve lookout tip on fresh db.
epoch, err := h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
require.NoError(h.t, err)
// Assert that the epoch is nil.
if epoch != nil {
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
}
require.Nil(h.t, epoch)
// Create a closure that inserts an epoch, retrieves it, and asserts
// that the returned epoch matches what was inserted.
setAndCheck := func(i int) {
expEpoch := epochFromInt(1)
err = h.db.SetLookoutTip(expEpoch)
if err != nil {
h.t.Fatalf("unable to set lookout tip: %v", err)
}
require.NoError(h.t, err)
epoch, err = h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
require.NoError(h.t, err)
if !reflect.DeepEqual(epoch, expEpoch) {
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
expEpoch, epoch)
}
require.Equal(h.t, expEpoch, epoch)
}
// Set and assert the lookout tip.
@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) {
// Assert that only one update is still present.
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected one update, found: %d", len(matches))
}
require.Len(h.t, matches, 1)
// Assert that the update belongs to the first session.
if matches[0].ID != *id0 {
h.t.Fatalf("expected match for %v, instead is for: %v",
*id0, matches[0].ID)
}
require.Equal(h.t, *id0, matches[0].ID)
// Finally, remove the first session added.
h.deleteSession(*id0, nil)
@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) {
// No matches should exist for this hint.
matches = h.queryMatches(hint)
if len(matches) != 0 {
h.t.Fatalf("expected zero updates, found: %d", len(matches))
}
require.Zero(h.t, len(matches))
}
type stateUpdateTest struct {
@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
*expSession = *test.session
}
if len(test.updates) != len(test.updateErrs) {
h.t.Fatalf("malformed test case, num updates " +
"should match num errors")
}
require.Lenf(h.t, test.updates, len(test.updateErrs),
"malformed test case, num updates should match num "+
"errors")
// Send any updates provided in the test.
for i, update := range test.updates {
@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
expSession.ClientLastApplied = update.LastApplied
match := h.hasUpdate(update.Hint)
if !reflect.DeepEqual(match.SessionInfo, expSession) {
h.t.Fatalf("expected session: %v, got: %v",
expSession, match.SessionInfo)
}
require.Equal(h.t, expSession, match.SessionInfo)
}
}
}
@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db.Close()
// Open the db again, ensuring we test a
@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) {
bdb, err = wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err = wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()

View File

@ -3,6 +3,7 @@ package wtdb
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
)
// migration is a function which takes a prior outdated version of the database
@ -24,7 +25,11 @@ var towerDBVersions = []version{}
// clientDBVersions stores all versions and migrations of the client database.
// This list will be used when opening the database to determine if any
// migrations must be applied.
var clientDBVersions = []version{}
var clientDBVersions = []version{
{
migration: migration1.MigrateTowerToSessionIndex,
},
}
// getLatestDBVersion returns the last known database version.
func getLatestDBVersion(versions []version) uint32 {

View File

@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
if tower != nil && *tower != session.TowerID {
continue
}
session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session
}