mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
Merge pull request #6972 from ellemouton/wtclientTowerDb
watchtower: add towerID-to-sessionID index
This commit is contained in:
commit
707546e2f0
@ -102,6 +102,14 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019).
|
||||
* [The `tlv` package now allows decoding records larger than 65535 bytes. The
|
||||
caller is expected to know that doing so with untrusted input is
|
||||
unsafe.](https://github.com/lightningnetwork/lnd/pull/6779)
|
||||
|
||||
## Watchtowers
|
||||
|
||||
* [Create a towerID-to-sessionID index in the wtclient DB to improve the
|
||||
speed of listing sessions for a particular tower ID](
|
||||
https://github.com/lightningnetwork/lnd/pull/6972). This PR also ensures a
|
||||
closer coupling of Towers and Sessions and ensures that a session cannot be
|
||||
added if the tower it is referring to does not exist.
|
||||
|
||||
* [Create a helper function to wait for peer to come
|
||||
online](https://github.com/lightningnetwork/lnd/pull/6931).
|
||||
|
@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) {
|
||||
}
|
||||
plog := build.NewPrefixLog(prefix, log)
|
||||
|
||||
// Next, load all candidate sessions and towers from the database into
|
||||
// the client. We will use any of these session if their policies match
|
||||
// Next, load all candidate towers and sessions from the database into
|
||||
// the client. We will use any of these sessions if their policies match
|
||||
// the current policy of the client, otherwise they will be ignored and
|
||||
// new sessions will be requested.
|
||||
isAnchorClient := cfg.Policy.IsAnchorChannel()
|
||||
activeSessionFilter := genActiveSessionFilter(isAnchorClient)
|
||||
candidateSessions, err := getClientSessions(
|
||||
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
|
||||
candidateTowers := newTowerListIterator()
|
||||
perActiveTower := func(tower *wtdb.Tower) {
|
||||
// If the tower has already been marked as active, then there is
|
||||
// no need to add it to the iterator again.
|
||||
if candidateTowers.IsActive(tower.ID) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Using private watchtower %s, offering policy %s",
|
||||
tower, cfg.Policy)
|
||||
|
||||
// Add the tower to the set of candidate towers.
|
||||
candidateTowers.AddCandidate(tower)
|
||||
}
|
||||
candidateSessions, err := getTowerAndSessionCandidates(
|
||||
cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var candidateTowers []*wtdb.Tower
|
||||
for _, s := range candidateSessions {
|
||||
plog.Infof("Using private watchtower %s, offering policy %s",
|
||||
s.Tower, cfg.Policy)
|
||||
candidateTowers = append(candidateTowers, s.Tower)
|
||||
}
|
||||
|
||||
// Load the sweep pkscripts that have been generated for all previously
|
||||
// registered channels.
|
||||
chanSummaries, err := cfg.DB.FetchChanSummaries()
|
||||
@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) {
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: newTaskPipeline(plog),
|
||||
candidateTowers: newTowerListIterator(candidateTowers...),
|
||||
candidateTowers: candidateTowers,
|
||||
candidateSessions: candidateSessions,
|
||||
activeSessions: make(sessionQueueSet),
|
||||
summaries: chanSummaries,
|
||||
@ -349,13 +356,62 @@ func New(config *Config) (*TowerClient, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// getTowerAndSessionCandidates loads all the towers from the DB and then
|
||||
// fetches the sessions for each of tower. Sessions are only collected if they
|
||||
// pass the sessionFilter check. If a tower has a session that does pass the
|
||||
// sessionFilter check then the perActiveTower call-back will be called on that
|
||||
// tower.
|
||||
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
sessionFilter func(*wtdb.ClientSession) bool,
|
||||
perActiveTower func(tower *wtdb.Tower)) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
towers, err := db.ListTowers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
for _, tower := range towers {
|
||||
sessions, err := db.ListClientSessions(&tower.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
towerKeyDesc, err := keyRing.DeriveKey(
|
||||
keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: s.KeyIndex,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.SessionKeyECDH = keychain.NewPubKeyECDH(
|
||||
towerKeyDesc, keyRing,
|
||||
)
|
||||
|
||||
if !sessionFilter(s) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add the session to the set of candidate sessions.
|
||||
candidateSessions[s.ID] = s
|
||||
perActiveTower(tower)
|
||||
}
|
||||
}
|
||||
|
||||
return candidateSessions, nil
|
||||
}
|
||||
|
||||
// getClientSessions retrieves the client sessions for a particular tower if
|
||||
// specified, otherwise all client sessions for all towers are retrieved. An
|
||||
// optional filter can be provided to filter out any undesired client sessions.
|
||||
//
|
||||
// NOTE: This method should only be used when deserialization of a
|
||||
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
||||
// existing ListClientSessions method should be used.
|
||||
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||
// ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
passesFilter func(*wtdb.ClientSession) bool) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
@ -371,12 +427,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
// requests. This prevents us from having to store the private keys on
|
||||
// disk.
|
||||
for _, s := range sessions {
|
||||
tower, err := db.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: s.KeyIndex,
|
||||
|
@ -62,7 +62,8 @@ type DB interface {
|
||||
// still be able to accept state updates. An optional tower ID can be
|
||||
// used to filter out any client sessions in the response that do not
|
||||
// correspond to this tower.
|
||||
ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
ListClientSessions(*wtdb.TowerID) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to
|
||||
// their channel summaries.
|
||||
@ -96,8 +97,8 @@ type DB interface {
|
||||
AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
|
||||
}
|
||||
|
||||
// AuthDialer connects to a remote node using an authenticated transport, such as
|
||||
// brontide. The dialer argument is used to specify a resolver, which allows
|
||||
// AuthDialer connects to a remote node using an authenticated transport, such
|
||||
// as brontide. The dialer argument is used to specify a resolver, which allows
|
||||
// this method to be used over Tor or clear net connections.
|
||||
type AuthDialer func(localKey keychain.SingleKeyECDH,
|
||||
netAddr *lnwire.NetAddress,
|
||||
|
@ -48,6 +48,12 @@ var (
|
||||
// tower-pubkey -> tower-id.
|
||||
cTowerIndexBkt = []byte("client-tower-index-bucket")
|
||||
|
||||
// cTowerToSessionIndexBkt is a top-level bucket storing:
|
||||
// tower-id -> session-id -> 1
|
||||
cTowerToSessionIndexBkt = []byte(
|
||||
"client-tower-to-session-index-bucket",
|
||||
)
|
||||
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
@ -113,7 +119,8 @@ var (
|
||||
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
|
||||
// the watchtower database.
|
||||
func NewBoltBackendCreator(active bool, dbPath,
|
||||
dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) {
|
||||
dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend,
|
||||
error) {
|
||||
|
||||
// If the watchtower client isn't active, we return a function that
|
||||
// always returns a nil DB to make sure we don't create empty database
|
||||
@ -195,6 +202,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error {
|
||||
cSessionBkt,
|
||||
cTowerBkt,
|
||||
cTowerIndexBkt,
|
||||
cTowerToSessionIndexBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
@ -259,6 +267,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towerToSessionIndex := tx.ReadWriteBucket(
|
||||
cTowerToSessionIndexBkt,
|
||||
)
|
||||
if towerToSessionIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check if the tower index already knows of this pubkey.
|
||||
towerIDBytes := towerIndex.Get(towerPubKey[:])
|
||||
if len(towerIDBytes) == 8 {
|
||||
@ -278,27 +293,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
// If there are any client sessions that correspond to
|
||||
// this tower, we'll mark them as active to ensure we
|
||||
// load them upon restarts.
|
||||
//
|
||||
// TODO(wilmer): with an index of tower -> sessions we
|
||||
// can avoid the linear lookup.
|
||||
towerSessIndex := towerToSessionIndex.NestedReadBucket(
|
||||
tower.ID.Bytes(),
|
||||
)
|
||||
if towerSessIndex == nil {
|
||||
return ErrTowerNotFound
|
||||
}
|
||||
|
||||
sessions := tx.ReadWriteBucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
towerID := TowerIDFromBytes(towerIDBytes)
|
||||
towerSessions, err := listClientSessions(
|
||||
sessions, &towerID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range towerSessions {
|
||||
err := markSessionStatus(
|
||||
sessions, session, CSessionActive,
|
||||
|
||||
err = towerSessIndex.ForEach(func(k, _ []byte) error {
|
||||
session, err := getClientSessionBody(
|
||||
sessions, k,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return markSessionStatus(
|
||||
sessions, session, CSessionActive,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// No such tower exists, create a new tower id for our
|
||||
@ -320,6 +340,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a new bucket for this tower in the
|
||||
// tower-to-sessions index.
|
||||
_, err = towerToSessionIndex.CreateBucket(towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new or updated tower under its tower id.
|
||||
@ -348,11 +375,19 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
|
||||
if towerIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towersToSessionsIndex := tx.ReadWriteBucket(
|
||||
cTowerToSessionIndexBkt,
|
||||
)
|
||||
if towersToSessionsIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Don't return an error if the watchtower doesn't exist to act
|
||||
// as a NOP.
|
||||
pubKeyBytes := pubKey.SerializeCompressed()
|
||||
@ -380,15 +415,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
|
||||
// Otherwise, we should attempt to mark the tower's sessions as
|
||||
// inactive.
|
||||
//
|
||||
// TODO(wilmer): with an index of tower -> sessions we can avoid
|
||||
// the linear lookup.
|
||||
sessions := tx.ReadWriteBucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
towerID := TowerIDFromBytes(towerIDBytes)
|
||||
towerSessions, err := listClientSessions(sessions, &towerID)
|
||||
towerSessions, err := listTowerSessions(
|
||||
towerID, sessions, towers, towersToSessionsIndex,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -399,7 +433,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
if err := towerIndex.Delete(pubKeyBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return towers.Delete(towerIDBytes)
|
||||
|
||||
if err := towers.Delete(towerIDBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return towersToSessionsIndex.DeleteNestedBucket(
|
||||
towerIDBytes,
|
||||
)
|
||||
}
|
||||
|
||||
// We'll mark its sessions as inactive as long as they don't
|
||||
@ -573,14 +614,34 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.ReadBucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towerToSessionIndex := tx.ReadWriteBucket(
|
||||
cTowerToSessionIndexBkt,
|
||||
)
|
||||
if towerToSessionIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check that client session with this session id doesn't
|
||||
// already exist.
|
||||
existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:])
|
||||
existingSessionBytes := sessions.NestedReadWriteBucket(
|
||||
session.ID[:],
|
||||
)
|
||||
if existingSessionBytes != nil {
|
||||
return ErrClientSessionAlreadyExists
|
||||
}
|
||||
|
||||
// Ensure that a tower with the given ID actually exists in the
|
||||
// DB.
|
||||
towerID := session.TowerID
|
||||
if _, err := getTower(towers, towerID.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blobType := session.Policy.BlobType
|
||||
|
||||
// Check that this tower has a reserved key index.
|
||||
@ -609,6 +670,19 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new entry to the towerID-to-SessionID index.
|
||||
indexBkt := towerToSessionIndex.NestedReadWriteBucket(
|
||||
towerID.Bytes(),
|
||||
)
|
||||
if indexBkt == nil {
|
||||
return ErrTowerNotFound
|
||||
}
|
||||
|
||||
err = indexBkt.Put(session.ID[:], []byte{1})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, write the client session's body in the sessions
|
||||
// bucket.
|
||||
return putClientSessionBody(sessions, session)
|
||||
@ -662,15 +736,41 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) {
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
sessions := tx.ReadBucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.ReadBucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
clientSessions, err = listClientSessions(sessions, id)
|
||||
|
||||
// If no tower ID is specified, then fetch all the sessions
|
||||
// known to the db.
|
||||
if id == nil {
|
||||
clientSessions, err = listClientAllSessions(
|
||||
sessions, towers,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Otherwise, fetch the sessions for the given tower.
|
||||
towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt)
|
||||
if towerToSessionIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
clientSessions, err = listTowerSessions(
|
||||
*id, sessions, towers, towerToSessionIndex,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
clientSessions = nil
|
||||
@ -682,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
|
||||
return clientSessions, nil
|
||||
}
|
||||
|
||||
// listClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func listClientSessions(sessions kvdb.RBucket,
|
||||
id *TowerID) (map[SessionID]*ClientSession, error) {
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func listClientAllSessions(sessions,
|
||||
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@ -694,19 +792,45 @@ func listClientSessions(sessions kvdb.RBucket,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessions, k)
|
||||
session, err := getClientSession(sessions, towers, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter out any sessions that don't correspond to the given
|
||||
// tower if one was set.
|
||||
if id != nil && session.TowerID != *id {
|
||||
return nil
|
||||
clientSessions[session.ID] = session
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientSessions, nil
|
||||
}
|
||||
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
|
||||
error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
return nil, ErrTowerNotFound
|
||||
}
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := towerIndexBkt.ForEach(func(k, _ []byte) error {
|
||||
// We'll load the full client session since the client will need
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessionsBkt, towersBkt, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientSessions[session.ID] = session
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@ -951,7 +1075,9 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
|
||||
// If the commits sub-bucket doesn't exist, there can't possibly
|
||||
// be a corresponding committed update to remove.
|
||||
sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits)
|
||||
sessionCommits := sessionBkt.NestedReadWriteBucket(
|
||||
cSessionCommits,
|
||||
)
|
||||
if sessionCommits == nil {
|
||||
return ErrCommittedUpdateNotFound
|
||||
}
|
||||
@ -1004,8 +1130,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
|
||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||
// bucket corresponding to the serialized session id. This does not deserialize
|
||||
// the CommittedUpdates or AckUpdates associated with the session. If the caller
|
||||
// requires this info, use getClientSession.
|
||||
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
|
||||
// If the caller requires this info, use getClientSession.
|
||||
func getClientSessionBody(sessions kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
@ -1032,9 +1158,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates and AckUpdates in
|
||||
// addition to the ClientSession's body.
|
||||
func getClientSession(sessions kvdb.RBucket,
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func getClientSession(sessions, towers kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
@ -1042,6 +1168,12 @@ func getClientSession(sessions kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the tower associated with this session.
|
||||
tower, err := getTower(towers, session.TowerID.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
||||
if err != nil {
|
||||
@ -1054,6 +1186,7 @@ func getClientSession(sessions kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.Tower = tower
|
||||
session.CommittedUpdates = commitedUpdates
|
||||
session.AckedUpdates = ackedUpdates
|
||||
|
||||
|
@ -1,11 +1,9 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
@ -16,8 +14,12 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// pseudoAddr is a fake network address to be used for testing purposes.
|
||||
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
// clientDBInit is a closure used to initialize a wtclient.DB instance.
|
||||
type clientDBInit func(t *testing.T) wtclient.DB
|
||||
|
||||
@ -37,23 +39,22 @@ func newClientDBHarness(t *testing.T, init clientDBInit) *clientDBHarness {
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
|
||||
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
|
||||
expErr error) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.CreateClientSession(session)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create client session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
func (h *clientDBHarness) listSessions(
|
||||
id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions(id)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to list client sessions: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err, "unable to list client sessions")
|
||||
|
||||
return sessions
|
||||
}
|
||||
@ -64,13 +65,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
|
||||
h.t.Helper()
|
||||
|
||||
index, err := h.db.NextSessionKeyIndex(id, blobType)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to create next session key index: %v", err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
h.t.Fatalf("next key index should never be 0")
|
||||
}
|
||||
require.NoError(h.t, err, "unable to create next session key index")
|
||||
require.NotZero(h.t, index, "next key index should never be 0")
|
||||
|
||||
return index
|
||||
}
|
||||
@ -81,20 +77,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.CreateTower(lnAddr)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
|
||||
if tower.ID == 0 {
|
||||
h.t.Fatalf("tower id should never be 0")
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
require.NotZero(h.t, tower.ID, "tower id should never be 0")
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
if session.Status != wtdb.CSessionActive {
|
||||
h.t.Fatalf("expected status for session %v to be %v, "+
|
||||
"got %v", session.ID, wtdb.CSessionActive,
|
||||
session.Status)
|
||||
}
|
||||
require.Equal(h.t, wtdb.CSessionActive, session.Status)
|
||||
}
|
||||
|
||||
return tower
|
||||
@ -105,68 +92,64 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
if err := h.db.RemoveTower(pubKey, addr); err != expErr {
|
||||
h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err)
|
||||
}
|
||||
err := h.db.RemoveTower(pubKey, addr)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
if expErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pubKeyStr := pubKey.SerializeCompressed()
|
||||
|
||||
if addr != nil {
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if err != nil {
|
||||
h.t.Fatalf("expected tower %x to still exist",
|
||||
pubKey.SerializeCompressed())
|
||||
}
|
||||
require.NoErrorf(h.t, err, "expected tower %x to still exist",
|
||||
pubKeyStr)
|
||||
|
||||
removedAddr := addr.String()
|
||||
for _, towerAddr := range tower.Addresses {
|
||||
if towerAddr.String() == removedAddr {
|
||||
h.t.Fatalf("address %v not removed for tower %x",
|
||||
removedAddr, pubKey.SerializeCompressed())
|
||||
}
|
||||
require.NotEqualf(h.t, removedAddr, towerAddr,
|
||||
"address %v not removed for tower %x",
|
||||
removedAddr, pubKeyStr)
|
||||
}
|
||||
} else {
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if hasSessions && err != nil {
|
||||
h.t.Fatalf("expected tower %x with sessions to still "+
|
||||
"exist", pubKey.SerializeCompressed())
|
||||
}
|
||||
if !hasSessions && err == nil {
|
||||
h.t.Fatalf("expected tower %x with no sessions to not "+
|
||||
"exist", pubKey.SerializeCompressed())
|
||||
}
|
||||
if !hasSessions {
|
||||
if hasSessions {
|
||||
require.NoError(h.t, err, "expected tower %x with "+
|
||||
"sessions to still exist", pubKeyStr)
|
||||
} else {
|
||||
require.Errorf(h.t, err, "expected tower %x with no "+
|
||||
"sessions to not exist", pubKeyStr)
|
||||
return
|
||||
}
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
if session.Status != wtdb.CSessionInactive {
|
||||
h.t.Fatalf("expected status for session %v to "+
|
||||
"be %v, got %v", session.ID,
|
||||
wtdb.CSessionInactive, session.Status)
|
||||
}
|
||||
require.Equal(h.t, wtdb.CSessionInactive,
|
||||
session.Status, "expected status for session "+
|
||||
"%v to be %v, got %v", session.ID,
|
||||
wtdb.CSessionInactive, session.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, expErr error) *wtdb.Tower {
|
||||
func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey,
|
||||
expErr error) *wtdb.Tower {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower {
|
||||
func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID,
|
||||
expErr error) *wtdb.Tower {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTowerByID(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return tower
|
||||
}
|
||||
@ -175,9 +158,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC
|
||||
h.t.Helper()
|
||||
|
||||
summaries, err := h.db.FetchChanSummaries()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch chan summaries: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return summaries
|
||||
}
|
||||
@ -188,10 +169,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.RegisterChannel(chanID, sweepPkScript)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected register channel error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
@ -200,10 +178,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.CommitUpdate(id, update)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
@ -214,10 +189,22 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.AckUpdate(id, seqNum, lastApplied)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// newTower is a helper function that creates a new tower with a randomly
|
||||
// generated public key and inserts it into the client DB.
|
||||
func (h *clientDBHarness) newTower() *wtdb.Tower {
|
||||
h.t.Helper()
|
||||
|
||||
pk, err := randPubKey()
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Insert a random tower into the database.
|
||||
return h.createTower(&lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: pseudoAddr,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// testCreateClientSession asserts various conditions regarding the creation of
|
||||
@ -228,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
func testCreateClientSession(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistAnchorCommit
|
||||
|
||||
tower := h.newTower()
|
||||
|
||||
// Create a test client session to insert.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
TowerID: tower.ID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
@ -245,9 +234,9 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
if _, ok := h.listSessions(nil)[session.ID]; ok {
|
||||
h.t.Fatalf("session for id %x should not exist yet", session.ID)
|
||||
}
|
||||
_, ok := h.listSessions(nil)[session.ID]
|
||||
require.Falsef(h.t, ok, "session for id %x should not exist yet",
|
||||
session.ID)
|
||||
|
||||
// Attempting to insert the client session without reserving a session
|
||||
// key index should fail.
|
||||
@ -264,10 +253,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
// successfully created, it should return the same index to maintain
|
||||
// idempotency across restarts.
|
||||
keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
|
||||
if keyIndex != keyIndex2 {
|
||||
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
||||
"got %v", keyIndex, keyIndex2)
|
||||
}
|
||||
require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+
|
||||
"be idempotent: want: %v, got %v", keyIndex, keyIndex2)
|
||||
|
||||
// Now, set the client session's key index so that it is proper and
|
||||
// insert it. This should succeed.
|
||||
@ -275,9 +262,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
if _, ok := h.listSessions(nil)[session.ID]; !ok {
|
||||
h.t.Fatalf("session for id %x should exist now", session.ID)
|
||||
}
|
||||
_, ok = h.listSessions(nil)[session.ID]
|
||||
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
// session already existing.
|
||||
@ -286,9 +272,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
// Finally, assert that reserving another key index succeeds with a
|
||||
// different key index, now that the first one has been finalized.
|
||||
keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
|
||||
if keyIndex == keyIndex3 {
|
||||
h.t.Fatalf("key index still reserved after creating session")
|
||||
}
|
||||
require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+
|
||||
"reserved after creating session")
|
||||
}
|
||||
|
||||
// testFilterClientSessions asserts that we can correctly filter client sessions
|
||||
@ -300,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
|
||||
for i := 0; i < numSessions; i++ {
|
||||
towerID := wtdb.TowerID(1)
|
||||
if i == numSessions-1 {
|
||||
towerID = wtdb.TowerID(2)
|
||||
}
|
||||
keyIndex := h.nextKeyIndex(towerID, blobType)
|
||||
tower := h.newTower()
|
||||
keyIndex := h.nextKeyIndex(tower.ID, blobType)
|
||||
sessionID := wtdb.SessionID([33]byte{byte(i)})
|
||||
h.insertSession(&wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: towerID,
|
||||
TowerID: tower.ID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
@ -320,22 +302,21 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||
},
|
||||
ID: sessionID,
|
||||
}, nil)
|
||||
towerSessions[towerID] = append(towerSessions[towerID], sessionID)
|
||||
towerSessions[tower.ID] = append(
|
||||
towerSessions[tower.ID], sessionID,
|
||||
)
|
||||
}
|
||||
|
||||
// We should see the expected sessions for each tower when filtering
|
||||
// them.
|
||||
for towerID, expectedSessions := range towerSessions {
|
||||
sessions := h.listSessions(&towerID)
|
||||
if len(sessions) != len(expectedSessions) {
|
||||
h.t.Fatalf("expected %v sessions for tower %v, got %v",
|
||||
len(expectedSessions), towerID, len(sessions))
|
||||
}
|
||||
require.Len(h.t, sessions, len(expectedSessions))
|
||||
|
||||
for _, expectedSession := range expectedSessions {
|
||||
if _, ok := sessions[expectedSession]; !ok {
|
||||
h.t.Fatalf("expected session %v for tower %v",
|
||||
expectedSession, towerID)
|
||||
}
|
||||
_, ok := sessions[expectedSession]
|
||||
require.Truef(h.t, ok, "expected session %v for "+
|
||||
"tower %v", expectedSession, towerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -347,49 +328,31 @@ func testCreateTower(h *clientDBHarness) {
|
||||
// Test that loading a tower with an arbitrary tower id fails.
|
||||
h.loadTowerByID(20, wtdb.ErrTowerNotFound)
|
||||
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
|
||||
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
lnAddr := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: addr1,
|
||||
}
|
||||
|
||||
// Insert a random tower into the database.
|
||||
tower := h.createTower(lnAddr, nil)
|
||||
tower := h.newTower()
|
||||
require.Len(h.t, tower.LNAddrs(), 1)
|
||||
towerAddr := tower.LNAddrs()[0]
|
||||
|
||||
// Load the tower from the database and assert that it matches the tower
|
||||
// we created.
|
||||
tower2 := h.loadTowerByID(tower.ID, nil)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
tower2 = h.loadTower(pk, err)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
require.Equal(h.t, tower, tower2)
|
||||
|
||||
tower2 = h.loadTower(tower.IdentityKey, nil)
|
||||
require.Equal(h.t, tower, tower2)
|
||||
|
||||
// Insert the address again into the database. Since the address is the
|
||||
// same, this should result in an unmodified tower record.
|
||||
towerDupAddr := h.createTower(lnAddr, nil)
|
||||
if len(towerDupAddr.Addresses) != 1 {
|
||||
h.t.Fatalf("duplicate address should be deduped")
|
||||
}
|
||||
if !reflect.DeepEqual(tower, towerDupAddr) {
|
||||
h.t.Fatalf("mismatch towers, want: %v, got: %v",
|
||||
tower, towerDupAddr)
|
||||
}
|
||||
towerDupAddr := h.createTower(towerAddr, nil)
|
||||
require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+
|
||||
"should be deduped")
|
||||
|
||||
require.Equal(h.t, tower, towerDupAddr)
|
||||
|
||||
// Generate a new address for this tower.
|
||||
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
lnAddr2 := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
IdentityKey: tower.IdentityKey,
|
||||
Address: addr2,
|
||||
}
|
||||
|
||||
@ -400,26 +363,18 @@ func testCreateTower(h *clientDBHarness) {
|
||||
// Load the tower from the database, and assert that it matches the
|
||||
// tower returned from creation.
|
||||
towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
towerNewAddr2 = h.loadTower(pk, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
require.Equal(h.t, towerNewAddr, towerNewAddr2)
|
||||
|
||||
towerNewAddr2 = h.loadTower(tower.IdentityKey, nil)
|
||||
require.Equal(h.t, towerNewAddr, towerNewAddr2)
|
||||
|
||||
// Assert that there are now two addresses on the tower object.
|
||||
if len(towerNewAddr.Addresses) != 2 {
|
||||
h.t.Fatalf("new address should be added")
|
||||
}
|
||||
require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+
|
||||
"added")
|
||||
|
||||
// Finally, assert that the new address was prepended since it is deemed
|
||||
// fresher.
|
||||
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
|
||||
h.t.Fatalf("new address should be prepended")
|
||||
}
|
||||
require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:])
|
||||
}
|
||||
|
||||
// testRemoveTower asserts the behavior of removing Tower objects as a whole and
|
||||
@ -427,9 +382,7 @@ func testCreateTower(h *clientDBHarness) {
|
||||
func testRemoveTower(h *clientDBHarness) {
|
||||
// Generate a random public key we'll use for our tower.
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Removing a tower that does not exist within the database should
|
||||
// result in a NOP.
|
||||
@ -507,28 +460,23 @@ func testRemoveTower(h *clientDBHarness) {
|
||||
func testChanSummaries(h *clientDBHarness) {
|
||||
// First, assert that this channel is not already registered.
|
||||
var chanID lnwire.ChannelID
|
||||
if _, ok := h.fetchChanSummaries()[chanID]; ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
}
|
||||
_, ok := h.fetchChanSummaries()[chanID]
|
||||
require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
|
||||
// Generate a random sweep pkscript and register it for this channel.
|
||||
expPkScript := make([]byte, 22)
|
||||
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
|
||||
h.t.Fatalf("unable to generate pkscript: %v", err)
|
||||
}
|
||||
_, err := io.ReadFull(crand.Reader, expPkScript)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
h.registerChan(chanID, expPkScript, nil)
|
||||
|
||||
// Assert that the channel exists and that its sweep pkscript matches
|
||||
// the one we registered.
|
||||
summary, ok := h.fetchChanSummaries()[chanID]
|
||||
if !ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
|
||||
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
|
||||
expPkScript, summary.SweepPkScript)
|
||||
}
|
||||
require.Truef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
require.Equal(h.t, expPkScript, summary.SweepPkScript)
|
||||
|
||||
// Finally, assert that re-registering the same channel produces a
|
||||
// failure.
|
||||
@ -538,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) {
|
||||
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
|
||||
func testCommitUpdate(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
||||
tower := h.newTower()
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
TowerID: tower.ID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
@ -565,10 +515,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// succeed. The lastApplied value should be 0 since we have not received
|
||||
// an ack from the tower.
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.Zero(h.t, lastApplied)
|
||||
|
||||
// Assert that the committed update appears in the client session's
|
||||
// CommittedUpdates map when loaded from disk and that there are no
|
||||
@ -584,10 +531,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// the on-disk update's hint). The lastApplied value should remain
|
||||
// unchanged.
|
||||
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied2 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied2)
|
||||
}
|
||||
require.Equal(h.t, lastApplied, lastApplied2)
|
||||
|
||||
// Assert that the loaded ClientSession is the same as before.
|
||||
dbSession = h.listSessions(nil)[session.ID]
|
||||
@ -605,10 +549,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// which should succeed.
|
||||
update2.SeqNum = 2
|
||||
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied3 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied3)
|
||||
}
|
||||
require.Equal(h.t, lastApplied, lastApplied3)
|
||||
|
||||
// Check that both updates now appear as committed on the ClientSession
|
||||
// loaded from disk.
|
||||
@ -638,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
func testAckUpdate(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
||||
tower := h.newTower()
|
||||
|
||||
// Create a new session that the updates in this will be tied to.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
TowerID: tower.ID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
@ -668,10 +611,7 @@ func testAckUpdate(h *clientDBHarness) {
|
||||
// Commit to a random update at seqnum 1.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.Zero(h.t, lastApplied)
|
||||
|
||||
// Acking seqnum 1 should succeed.
|
||||
h.ackUpdate(&session.ID, 1, 1, nil)
|
||||
@ -699,10 +639,7 @@ func testAckUpdate(h *clientDBHarness) {
|
||||
// ack.
|
||||
update2 := randCommittedUpdate(h.t, 2)
|
||||
lastApplied = h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied != 1 {
|
||||
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.EqualValues(h.t, 1, lastApplied)
|
||||
|
||||
// Ack seqnum 2.
|
||||
h.ackUpdate(&session.ID, 2, 2, nil)
|
||||
@ -740,10 +677,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates = make([]wtdb.CommittedUpdate, 0)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
|
||||
t.Fatalf("committed updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
require.Equal(t, expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
|
||||
// checkAckedUpdates asserts that the AckedUpdates on a session match the
|
||||
@ -758,10 +692,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates = make(map[uint16]wtdb.BackupID)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
|
||||
t.Fatalf("acked updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.AckedUpdates)
|
||||
}
|
||||
require.Equal(t, expUpdates, session.AckedUpdates)
|
||||
}
|
||||
|
||||
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
|
||||
@ -779,14 +710,10 @@ func TestClientDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, t.TempDir(), "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@ -803,27 +730,19 @@ func TestClientDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
db.Close()
|
||||
|
||||
bdb, err = wtdb.NewBoltBackendCreator(
|
||||
true, path, "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err = wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to reopen db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@ -893,19 +812,16 @@ func TestClientDB(t *testing.T) {
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
var chanID lnwire.ChannelID
|
||||
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
|
||||
t.Fatalf("unable to generate chan id: %v", err)
|
||||
}
|
||||
_, err := io.ReadFull(crand.Reader, chanID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var hint blob.BreachHint
|
||||
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
|
||||
t.Fatalf("unable to generate breach hint: %v", err)
|
||||
}
|
||||
_, err = io.ReadFull(crand.Reader, hint[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
|
||||
t.Fatalf("unable to generate encrypted blob: %v", err)
|
||||
}
|
||||
_, err = io.ReadFull(crand.Reader, encBlob)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &wtdb.CommittedUpdate{
|
||||
SeqNum: seqNum,
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func randPubKey() (*btcec.PublicKey, error) {
|
||||
@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) {
|
||||
// Ensure encoding the object succeeds.
|
||||
var b bytes.Buffer
|
||||
err := obj.Encode(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode: %v", err)
|
||||
return false
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
var obj2 dbObject
|
||||
switch obj.(type) {
|
||||
@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) {
|
||||
|
||||
// Ensure decoding the object succeeds.
|
||||
err = obj2.Decode(bytes.NewReader(b.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode: %v", err)
|
||||
return false
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert the original and decoded object match.
|
||||
if !reflect.DeepEqual(obj, obj2) {
|
||||
t.Fatalf("encode/decode mismatch, want: %v, "+
|
||||
"got: %v", obj, obj2)
|
||||
return false
|
||||
}
|
||||
require.Equal(t, obj, obj2)
|
||||
|
||||
return true
|
||||
}
|
||||
@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) {
|
||||
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
|
||||
"Tower": func(v []reflect.Value, r *rand.Rand) {
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate pubkey: %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
addrs, err := randAddrs(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate addrs: %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
obj := wtdb.Tower{
|
||||
IdentityKey: pk,
|
||||
@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) {
|
||||
}
|
||||
|
||||
err := quick.Check(test.scenario, config)
|
||||
if err != nil {
|
||||
t.Fatalf("fuzz checks for msg=%s failed: %v",
|
||||
test.name, err)
|
||||
}
|
||||
require.NoError(h, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package wtdb
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
@ -26,6 +27,7 @@ func DisableLog() {
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
migration1.UseLogger(logger)
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
|
145
watchtower/wtdb/migration1/client_db.go
Normal file
145
watchtower/wtdb/migration1/client_db.go
Normal file
@ -0,0 +1,145 @@
|
||||
package migration1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAcks => seqnum -> encoded BackupID
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
|
||||
// the ClientSession.
|
||||
cSessionBody = []byte("client-session-body")
|
||||
|
||||
// cTowerIDToSessionIDIndexBkt is a top-level bucket storing:
|
||||
// tower-id -> session-id -> 1
|
||||
cTowerIDToSessionIDIndexBkt = []byte(
|
||||
"client-tower-to-session-index-bucket",
|
||||
)
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("db not initialized")
|
||||
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
|
||||
// ErrCorruptClientSession signals that the client session's on-disk
|
||||
// structure deviates from what is expected.
|
||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||
)
|
||||
|
||||
// MigrateTowerToSessionIndex constructs a new towerID-to-sessionID for the
|
||||
// watchtower client DB.
|
||||
func MigrateTowerToSessionIndex(tx kvdb.RwTx) error {
|
||||
log.Infof("Migrating the tower client db to add a " +
|
||||
"towerID-to-sessionID index")
|
||||
|
||||
// First, we collect all the entries we want to add to the index.
|
||||
entries, err := getIndexEntries(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then we create a new top-level bucket for the index.
|
||||
indexBkt, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we add all the collected entries to the index.
|
||||
for towerID, sessions := range entries {
|
||||
// Create a sub-bucket using the tower ID.
|
||||
towerBkt, err := indexBkt.CreateBucketIfNotExists(
|
||||
towerID.Bytes(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for sessionID := range sessions {
|
||||
err := addIndex(towerBkt, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIndex adds a new towerID-sessionID pair to the given bucket. The
|
||||
// session ID is used as a key within the bucket and a value of []byte{1} is
|
||||
// used for each session ID key.
|
||||
func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error {
|
||||
session := towerBkt.Get(sessionID[:])
|
||||
if session != nil {
|
||||
return fmt.Errorf("session %x duplicated", sessionID)
|
||||
}
|
||||
|
||||
return towerBkt.Put(sessionID[:], []byte{1})
|
||||
}
|
||||
|
||||
// getIndexEntries collects all the towerID-sessionID entries that need to be
|
||||
// added to the new index.
|
||||
func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) {
|
||||
sessions := tx.ReadBucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return nil, ErrUninitializedDB
|
||||
}
|
||||
|
||||
index := make(map[TowerID]map[SessionID]bool)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
session, err := getClientSession(sessions, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if index[session.TowerID] == nil {
|
||||
index[session.TowerID] = make(map[SessionID]bool)
|
||||
}
|
||||
|
||||
index[session.TowerID][session.ID] = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// getClientSession fetches the session with the given ID from the db.
|
||||
func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession,
|
||||
error) {
|
||||
|
||||
sessionBkt := sessions.NestedReadBucket(idBytes)
|
||||
if sessionBkt == nil {
|
||||
return nil, ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Should never have a sessionBkt without also having its body.
|
||||
sessionBody := sessionBkt.Get(cSessionBody)
|
||||
if sessionBody == nil {
|
||||
return nil, ErrCorruptClientSession
|
||||
}
|
||||
|
||||
var session ClientSession
|
||||
copy(session.ID[:], idBytes)
|
||||
|
||||
err := session.Decode(bytes.NewReader(sessionBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
155
watchtower/wtdb/migration1/client_db_test.go
Normal file
155
watchtower/wtdb/migration1/client_db_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package migration1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/migtest"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
s1 = &ClientSessionBody{
|
||||
TowerID: TowerID(1),
|
||||
}
|
||||
s2 = &ClientSessionBody{
|
||||
TowerID: TowerID(3),
|
||||
}
|
||||
s3 = &ClientSessionBody{
|
||||
TowerID: TowerID(6),
|
||||
}
|
||||
|
||||
// pre is the expected data in the DB before the migration.
|
||||
pre = map[string]interface{}{
|
||||
sessionIDString("1"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s1),
|
||||
},
|
||||
sessionIDString("2"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s3),
|
||||
},
|
||||
sessionIDString("3"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s1),
|
||||
},
|
||||
sessionIDString("4"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s1),
|
||||
},
|
||||
sessionIDString("5"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s2),
|
||||
},
|
||||
}
|
||||
|
||||
// preFailNoSessionBody should fail the migration due to there being a
|
||||
// session without an associated session body.
|
||||
preFailNoSessionBody = map[string]interface{}{
|
||||
sessionIDString("1"): map[string]interface{}{
|
||||
string(cSessionBody): clientSessionString(s1),
|
||||
},
|
||||
sessionIDString("2"): map[string]interface{}{},
|
||||
}
|
||||
|
||||
// post is the expected data after migration.
|
||||
post = map[string]interface{}{
|
||||
towerIDString(1): map[string]interface{}{
|
||||
sessionIDString("1"): string([]byte{1}),
|
||||
sessionIDString("3"): string([]byte{1}),
|
||||
sessionIDString("4"): string([]byte{1}),
|
||||
},
|
||||
towerIDString(3): map[string]interface{}{
|
||||
sessionIDString("5"): string([]byte{1}),
|
||||
},
|
||||
towerIDString(6): map[string]interface{}{
|
||||
sessionIDString("2"): string([]byte{1}),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestMigrateTowerToSessionIndex tests that the TestMigrateTowerToSessionIndex
|
||||
// function correctly adds a new towerID-to-sessionID index to the tower client
|
||||
// db.
|
||||
func TestMigrateTowerToSessionIndex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shouldFail bool
|
||||
pre map[string]interface{}
|
||||
post map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "migration ok",
|
||||
shouldFail: false,
|
||||
pre: pre,
|
||||
post: post,
|
||||
},
|
||||
{
|
||||
name: "fail due to corrupt db",
|
||||
shouldFail: true,
|
||||
pre: preFailNoSessionBody,
|
||||
post: nil,
|
||||
},
|
||||
{
|
||||
name: "no sessions",
|
||||
shouldFail: false,
|
||||
pre: nil,
|
||||
post: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Before the migration we have a sessions bucket.
|
||||
before := func(tx kvdb.RwTx) error {
|
||||
return migtest.RestoreDB(
|
||||
tx, cSessionBkt, test.pre,
|
||||
)
|
||||
}
|
||||
|
||||
// After the migration, we should have an untouched
|
||||
// sessions bucket and a new index bucket.
|
||||
after := func(tx kvdb.RwTx) error {
|
||||
if err := migtest.VerifyDB(
|
||||
tx, cSessionBkt, test.pre,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we expect our migration to fail, we don't
|
||||
// expect an index bucket.
|
||||
if test.shouldFail {
|
||||
return nil
|
||||
}
|
||||
|
||||
return migtest.VerifyDB(
|
||||
tx, cTowerIDToSessionIDIndexBkt,
|
||||
test.post,
|
||||
)
|
||||
}
|
||||
|
||||
migtest.ApplyMigration(
|
||||
t, before, after, MigrateTowerToSessionIndex,
|
||||
test.shouldFail,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func sessionIDString(id string) string {
|
||||
var sessID SessionID
|
||||
copy(sessID[:], id)
|
||||
return string(sessID[:])
|
||||
}
|
||||
|
||||
func clientSessionString(s *ClientSessionBody) string {
|
||||
var b bytes.Buffer
|
||||
err := s.Encode(&b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func towerIDString(id int) string {
|
||||
towerID := TowerID(id)
|
||||
return string(towerID.Bytes())
|
||||
}
|
241
watchtower/wtdb/migration1/codec.go
Normal file
241
watchtower/wtdb/migration1/codec.go
Normal file
@ -0,0 +1,241 @@
|
||||
package migration1
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
|
||||
const SessionIDSize = 33
|
||||
|
||||
// UnknownElementType is an alias for channeldb.UnknownElementType.
|
||||
type UnknownElementType = channeldb.UnknownElementType
|
||||
|
||||
// SessionID is created from the remote public key of a client, and serves as a
|
||||
// unique identifier and authentication for sending state updates.
|
||||
type SessionID [SessionIDSize]byte
|
||||
|
||||
// TowerID is a unique 64-bit identifier allocated to each unique watchtower.
|
||||
// This allows the client to conserve on-disk space by not needing to always
|
||||
// reference towers by their pubkey.
|
||||
type TowerID uint64
|
||||
|
||||
// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order.
|
||||
func (id TowerID) Bytes() []byte {
|
||||
var buf [8]byte
|
||||
binary.BigEndian.PutUint64(buf[:], uint64(id))
|
||||
return buf[:]
|
||||
}
|
||||
|
||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||
// session negotiation, and also records the tower and ephemeral secret used for
|
||||
// communicating with the tower.
|
||||
type ClientSession struct {
|
||||
// ID is the client's public key used when authenticating with the
|
||||
// tower.
|
||||
ID SessionID
|
||||
ClientSessionBody
|
||||
}
|
||||
|
||||
// CSessionStatus is a bit-field representing the possible statuses of
|
||||
// ClientSessions.
|
||||
type CSessionStatus uint8
|
||||
|
||||
type ClientSessionBody struct {
|
||||
// SeqNum is the next unallocated sequence number that can be sent to
|
||||
// the tower.
|
||||
SeqNum uint16
|
||||
|
||||
// TowerLastApplied the last last-applied the tower has echoed back.
|
||||
TowerLastApplied uint16
|
||||
|
||||
// TowerID is the unique, db-assigned identifier that references the
|
||||
// Tower with which the session is negotiated.
|
||||
TowerID TowerID
|
||||
|
||||
// KeyIndex is the index of key locator used to derive the client's
|
||||
// session key so that it can authenticate with the tower to update its
|
||||
// session. In order to rederive the private key, the key locator should
|
||||
// use the keychain.KeyFamilyTowerSession key family.
|
||||
KeyIndex uint32
|
||||
|
||||
// Policy holds the negotiated session parameters.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// Status indicates the current state of the ClientSession.
|
||||
Status CSessionStatus
|
||||
|
||||
// RewardPkScript is the pkscript that the tower's reward will be
|
||||
// deposited to if a sweep transaction confirms and the sessions
|
||||
// specifies a reward output.
|
||||
RewardPkScript []byte
|
||||
}
|
||||
|
||||
// Encode writes a ClientSessionBody to the passed io.Writer.
|
||||
func (s *ClientSessionBody) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
s.SeqNum,
|
||||
s.TowerLastApplied,
|
||||
uint64(s.TowerID),
|
||||
s.KeyIndex,
|
||||
uint8(s.Status),
|
||||
s.Policy,
|
||||
s.RewardPkScript,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a ClientSessionBody from the passed io.Reader.
|
||||
func (s *ClientSessionBody) Decode(r io.Reader) error {
|
||||
var (
|
||||
towerID uint64
|
||||
status uint8
|
||||
)
|
||||
err := ReadElements(r,
|
||||
&s.SeqNum,
|
||||
&s.TowerLastApplied,
|
||||
&towerID,
|
||||
&s.KeyIndex,
|
||||
&status,
|
||||
&s.Policy,
|
||||
&s.RewardPkScript,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TowerID = TowerID(towerID)
|
||||
s.Status = CSessionStatus(status)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteElements serializes a variadic list of elements into the given
|
||||
// io.Writer.
|
||||
func WriteElements(w io.Writer, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := WriteElement(w, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteElement serializes a single element into the provided io.Writer.
|
||||
func WriteElement(w io.Writer, element interface{}) error {
|
||||
err := channeldb.WriteElement(w, element)
|
||||
switch {
|
||||
// Known to channeldb codec.
|
||||
case err == nil:
|
||||
return nil
|
||||
|
||||
// Fail if error is not UnknownElementType.
|
||||
default:
|
||||
if _, ok := err.(UnknownElementType); !ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Process any wtdb-specific extensions to the codec.
|
||||
switch e := element.(type) {
|
||||
case SessionID:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case blob.BreachHint:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case wtpolicy.Policy:
|
||||
return channeldb.WriteElements(w,
|
||||
uint16(e.BlobType),
|
||||
e.MaxUpdates,
|
||||
e.RewardBase,
|
||||
e.RewardRate,
|
||||
uint64(e.SweepFeeRate),
|
||||
)
|
||||
|
||||
// Type is still unknown to wtdb extensions, fail.
|
||||
default:
|
||||
return channeldb.NewUnknownElementType(
|
||||
"WriteElement", element,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElements deserializes the provided io.Reader into a variadic list of
|
||||
// target elements.
|
||||
func ReadElements(r io.Reader, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := ReadElement(r, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElement deserializes a single element from the provided io.Reader.
|
||||
func ReadElement(r io.Reader, element interface{}) error {
|
||||
err := channeldb.ReadElement(r, element)
|
||||
switch {
|
||||
// Known to channeldb codec.
|
||||
case err == nil:
|
||||
return nil
|
||||
|
||||
// Fail if error is not UnknownElementType.
|
||||
default:
|
||||
if _, ok := err.(UnknownElementType); !ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Process any wtdb-specific extensions to the codec.
|
||||
switch e := element.(type) {
|
||||
case *SessionID:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *blob.BreachHint:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *wtpolicy.Policy:
|
||||
var (
|
||||
blobType uint16
|
||||
sweepFeeRate uint64
|
||||
)
|
||||
err := channeldb.ReadElements(r,
|
||||
&blobType,
|
||||
&e.MaxUpdates,
|
||||
&e.RewardBase,
|
||||
&e.RewardRate,
|
||||
&sweepFeeRate,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.BlobType = blob.Type(blobType)
|
||||
e.SweepFeeRate = chainfee.SatPerKWeight(sweepFeeRate)
|
||||
|
||||
// Type is still unknown to wtdb extensions, fail.
|
||||
default:
|
||||
return channeldb.NewUnknownElementType(
|
||||
"ReadElement", element,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
14
watchtower/wtdb/migration1/log.go
Normal file
14
watchtower/wtdb/migration1/log.go
Normal file
@ -0,0 +1,14 @@
|
||||
package migration1
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized as disabled. This means the package will
|
||||
// not perform any logging by default until a logger is set.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
@ -3,7 +3,6 @@ package wtdb_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.InsertSessionInfo(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert session error: %v, got : %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// getSession retrieves the session identified by id, asserting that the call
|
||||
@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID,
|
||||
h.t.Helper()
|
||||
|
||||
session, err := h.db.GetSessionInfo(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected get session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return session
|
||||
}
|
||||
@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.InsertStateUpdate(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.DeleteSession(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected deletion error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// queryMatches queries that database for the passed breach hint, returning all
|
||||
@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches, err := h.db.QueryMatches([]blob.BreachHint{hint})
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to query matches: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return matches
|
||||
}
|
||||
@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected 1 match, found: %d", len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, 1)
|
||||
|
||||
match := matches[0]
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
|
||||
}
|
||||
require.Equal(h.t, hint, match.Hint)
|
||||
|
||||
return match
|
||||
}
|
||||
@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
session2 := h.getSession(&id, nil)
|
||||
|
||||
if !reflect.DeepEqual(session, session2) {
|
||||
h.t.Fatalf("expected session: %v, got %v",
|
||||
session, session2)
|
||||
}
|
||||
require.Equal(h.t, session, session2)
|
||||
|
||||
h.insertSession(session, nil)
|
||||
|
||||
@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) {
|
||||
|
||||
// Query the db for matches on the chosen hint.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != numUpdates {
|
||||
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
|
||||
numUpdates, len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, numUpdates)
|
||||
|
||||
// Assert that the hints are what we asked for, and compute the set of
|
||||
// sessions returned.
|
||||
sessions := make(map[wtdb.SessionID]struct{})
|
||||
for _, match := range matches {
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("hint mismatch, want: %v, got: %v",
|
||||
hint, match.Hint)
|
||||
}
|
||||
require.Equal(h.t, hint, match.Hint)
|
||||
sessions[match.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Assert that the sessions returned match the session ids of the
|
||||
// sessions we initially created.
|
||||
for i := 0; i < numUpdates; i++ {
|
||||
if _, ok := sessions[*id(i)]; !ok {
|
||||
h.t.Fatalf("match for session %v not found", *id(i))
|
||||
}
|
||||
_, ok := sessions[*id(i)]
|
||||
require.Truef(h.t, ok, "match for session %v not found", *id(i))
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) {
|
||||
func testLookoutTip(h *towerDBHarness) {
|
||||
// Retrieve lookout tip on fresh db.
|
||||
epoch, err := h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Assert that the epoch is nil.
|
||||
if epoch != nil {
|
||||
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
|
||||
}
|
||||
require.Nil(h.t, epoch)
|
||||
|
||||
// Create a closure that inserts an epoch, retrieves it, and asserts
|
||||
// that the returned epoch matches what was inserted.
|
||||
setAndCheck := func(i int) {
|
||||
expEpoch := epochFromInt(1)
|
||||
err = h.db.SetLookoutTip(expEpoch)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to set lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
epoch, err = h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
if !reflect.DeepEqual(epoch, expEpoch) {
|
||||
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
|
||||
expEpoch, epoch)
|
||||
}
|
||||
require.Equal(h.t, expEpoch, epoch)
|
||||
}
|
||||
|
||||
// Set and assert the lookout tip.
|
||||
@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) {
|
||||
|
||||
// Assert that only one update is still present.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected one update, found: %d", len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, 1)
|
||||
|
||||
// Assert that the update belongs to the first session.
|
||||
if matches[0].ID != *id0 {
|
||||
h.t.Fatalf("expected match for %v, instead is for: %v",
|
||||
*id0, matches[0].ID)
|
||||
}
|
||||
require.Equal(h.t, *id0, matches[0].ID)
|
||||
|
||||
// Finally, remove the first session added.
|
||||
h.deleteSession(*id0, nil)
|
||||
@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) {
|
||||
|
||||
// No matches should exist for this hint.
|
||||
matches = h.queryMatches(hint)
|
||||
if len(matches) != 0 {
|
||||
h.t.Fatalf("expected zero updates, found: %d", len(matches))
|
||||
}
|
||||
require.Zero(h.t, len(matches))
|
||||
}
|
||||
|
||||
type stateUpdateTest struct {
|
||||
@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||
*expSession = *test.session
|
||||
}
|
||||
|
||||
if len(test.updates) != len(test.updateErrs) {
|
||||
h.t.Fatalf("malformed test case, num updates " +
|
||||
"should match num errors")
|
||||
}
|
||||
require.Lenf(h.t, test.updates, len(test.updateErrs),
|
||||
"malformed test case, num updates should match num "+
|
||||
"errors")
|
||||
|
||||
// Send any updates provided in the test.
|
||||
for i, update := range test.updates {
|
||||
@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||
expSession.ClientLastApplied = update.LastApplied
|
||||
|
||||
match := h.hasUpdate(update.Hint)
|
||||
if !reflect.DeepEqual(match.SessionInfo, expSession) {
|
||||
h.t.Fatalf("expected session: %v, got: %v",
|
||||
expSession, match.SessionInfo)
|
||||
}
|
||||
require.Equal(h.t, expSession, match.SessionInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
db.Close()
|
||||
|
||||
// Open the db again, ensuring we test a
|
||||
@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err = wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err = wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
|
@ -3,6 +3,7 @@ package wtdb
|
||||
import (
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
|
||||
)
|
||||
|
||||
// migration is a function which takes a prior outdated version of the database
|
||||
@ -24,7 +25,11 @@ var towerDBVersions = []version{}
|
||||
// clientDBVersions stores all versions and migrations of the client database.
|
||||
// This list will be used when opening the database to determine if any
|
||||
// migrations must be applied.
|
||||
var clientDBVersions = []version{}
|
||||
var clientDBVersions = []version{
|
||||
{
|
||||
migration: migration1.MigrateTowerToSessionIndex,
|
||||
},
|
||||
}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
func getLatestDBVersion(versions []version) uint32 {
|
||||
|
@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
|
||||
if tower != nil && *tower != session.TowerID {
|
||||
continue
|
||||
}
|
||||
session.Tower = m.towers[session.TowerID]
|
||||
sessions[session.ID] = &session
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user