mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-15 03:51:23 +01:00
watchtower: extend client databse with CRUD operations for towers
These operations are currently unused, but will be integrated into the TowerClient at a later point as future preparation for the WatchtowerClient RPC subserver, which will allow users to add, remove, and list the watchtowers currntly in use.
This commit is contained in:
parent
56d66c80a1
commit
1d73a6564f
7 changed files with 501 additions and 19 deletions
|
@ -217,7 +217,7 @@ func New(config *Config) (*TowerClient, error) {
|
||||||
// requests. This prevents us from having to store the private keys on
|
// requests. This prevents us from having to store the private keys on
|
||||||
// disk.
|
// disk.
|
||||||
for _, s := range sessions {
|
for _, s := range sessions {
|
||||||
tower, err := cfg.DB.LoadTower(s.TowerID)
|
tower, err := cfg.DB.LoadTowerByID(s.TowerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,31 @@ type DB interface {
|
||||||
// CreateTower initialize an address record used to communicate with a
|
// CreateTower initialize an address record used to communicate with a
|
||||||
// watchtower. Each Tower is assigned a unique ID, that is used to
|
// watchtower. Each Tower is assigned a unique ID, that is used to
|
||||||
// amortize storage costs of the public key when used by multiple
|
// amortize storage costs of the public key when used by multiple
|
||||||
// sessions.
|
// sessions. If the tower already exists, the address is appended to the
|
||||||
|
// list of all addresses used to that tower previously and its
|
||||||
|
// corresponding sessions are marked as active.
|
||||||
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
||||||
|
|
||||||
// LoadTower retrieves a tower by its tower ID.
|
// RemoveTower modifies a tower's record within the database. If an
|
||||||
LoadTower(wtdb.TowerID) (*wtdb.Tower, error)
|
// address is provided, then _only_ the address record should be removed
|
||||||
|
// from the tower's persisted state. Otherwise, we'll attempt to mark
|
||||||
|
// the tower as inactive by marking all of its sessions inactive. If any
|
||||||
|
// of its sessions has unacked updates, then ErrTowerUnackedUpdates is
|
||||||
|
// returned. If the tower doesn't have any sessions at all, it'll be
|
||||||
|
// completely removed from the database.
|
||||||
|
//
|
||||||
|
// NOTE: An error is not returned if the tower doesn't exist.
|
||||||
|
RemoveTower(*btcec.PublicKey, net.Addr) error
|
||||||
|
|
||||||
|
// LoadTower retrieves a tower by its public key.
|
||||||
|
LoadTower(*btcec.PublicKey) (*wtdb.Tower, error)
|
||||||
|
|
||||||
|
// LoadTowerByID retrieves a tower by its tower ID.
|
||||||
|
LoadTowerByID(wtdb.TowerID) (*wtdb.Tower, error)
|
||||||
|
|
||||||
|
// ListTowers retrieves the list of towers available within the
|
||||||
|
// database.
|
||||||
|
ListTowers() ([]*wtdb.Tower, error)
|
||||||
|
|
||||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||||
// particular tower id. The index is reserved for that tower until
|
// particular tower id. The index is reserved for that tower until
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/coreos/bbolt"
|
"github.com/coreos/bbolt"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
)
|
)
|
||||||
|
@ -55,6 +56,11 @@ var (
|
||||||
// database.
|
// database.
|
||||||
ErrTowerNotFound = errors.New("tower not found")
|
ErrTowerNotFound = errors.New("tower not found")
|
||||||
|
|
||||||
|
// ErrTowerUnackedUpdates is an error returned when we attempt to mark a
|
||||||
|
// tower's sessions as inactive, but one of its sessions has unacked
|
||||||
|
// updates.
|
||||||
|
ErrTowerUnackedUpdates = errors.New("tower has unacked updates")
|
||||||
|
|
||||||
// ErrCorruptClientSession signals that the client session's on-disk
|
// ErrCorruptClientSession signals that the client session's on-disk
|
||||||
// structure deviates from what is expected.
|
// structure deviates from what is expected.
|
||||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||||
|
@ -199,9 +205,11 @@ func (c *ClientDB) Close() error {
|
||||||
return c.db.Close()
|
return c.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTower initializes a database entry with the given lightning address. If
|
// CreateTower initialize an address record used to communicate with a
|
||||||
// the tower exists, the address is append to the list of all addresses used to
|
// watchtower. Each Tower is assigned a unique ID, that is used to amortize
|
||||||
// that tower previously.
|
// storage costs of the public key when used by multiple sessions. If the tower
|
||||||
|
// already exists, the address is appended to the list of all addresses used to
|
||||||
|
// that tower previously and its corresponding sessions are marked as active.
|
||||||
func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||||
var towerPubKey [33]byte
|
var towerPubKey [33]byte
|
||||||
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
||||||
|
@ -233,6 +241,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||||
// address is a duplicate, this will result in no
|
// address is a duplicate, this will result in no
|
||||||
// change.
|
// change.
|
||||||
tower.AddAddress(lnAddr.Address)
|
tower.AddAddress(lnAddr.Address)
|
||||||
|
|
||||||
|
// If there are any client sessions that correspond to
|
||||||
|
// this tower, we'll mark them as active to ensure we
|
||||||
|
// load them upon restarts.
|
||||||
|
//
|
||||||
|
// TODO(wilmer): with an index of tower -> sessions we
|
||||||
|
// can avoid the linear lookup.
|
||||||
|
sessions := tx.Bucket(cSessionBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
towerID := TowerIDFromBytes(towerIDBytes)
|
||||||
|
towerSessions, err := listClientSessions(
|
||||||
|
sessions, &towerID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, session := range towerSessions {
|
||||||
|
err := markSessionStatus(
|
||||||
|
sessions, session, CSessionActive,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// No such tower exists, create a new tower id for our
|
// No such tower exists, create a new tower id for our
|
||||||
// new tower. The error is unhandled since NextSequence
|
// new tower. The error is unhandled since NextSequence
|
||||||
|
@ -265,8 +299,89 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||||
return tower, nil
|
return tower, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadTower retrieves a tower by its tower ID.
|
// RemoveTower modifies a tower's record within the database. If an address is
|
||||||
func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) {
|
// provided, then _only_ the address record should be removed from the tower's
|
||||||
|
// persisted state. Otherwise, we'll attempt to mark the tower as inactive by
|
||||||
|
// marking all of its sessions inactive. If any of its sessions has unacked
|
||||||
|
// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
|
||||||
|
// any sessions at all, it'll be completely removed from the database.
|
||||||
|
//
|
||||||
|
// NOTE: An error is not returned if the tower doesn't exist.
|
||||||
|
func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||||
|
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
towers := tx.Bucket(cTowerBkt)
|
||||||
|
if towers == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
towerIndex := tx.Bucket(cTowerIndexBkt)
|
||||||
|
if towerIndex == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't return an error if the watchtower doesn't exist to act
|
||||||
|
// as a NOP.
|
||||||
|
pubKeyBytes := pubKey.SerializeCompressed()
|
||||||
|
towerIDBytes := towerIndex.Get(pubKeyBytes)
|
||||||
|
if towerIDBytes == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If an address is provided, then we should _only_ remove the
|
||||||
|
// address record from the database.
|
||||||
|
if addr != nil {
|
||||||
|
tower, err := getTower(towers, towerIDBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tower.RemoveAddress(addr)
|
||||||
|
return putTower(towers, tower)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we should attempt to mark the tower's sessions as
|
||||||
|
// inactive.
|
||||||
|
//
|
||||||
|
// TODO(wilmer): with an index of tower -> sessions we can avoid
|
||||||
|
// the linear lookup.
|
||||||
|
sessions := tx.Bucket(cSessionBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
towerID := TowerIDFromBytes(towerIDBytes)
|
||||||
|
towerSessions, err := listClientSessions(sessions, &towerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it doesn't have any, we can completely remove it from the
|
||||||
|
// database.
|
||||||
|
if len(towerSessions) == 0 {
|
||||||
|
if err := towerIndex.Delete(pubKeyBytes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return towers.Delete(towerIDBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll mark its sessions as inactive as long as they don't
|
||||||
|
// have any pending updates to ensure we don't load them upon
|
||||||
|
// restarts.
|
||||||
|
for _, session := range towerSessions {
|
||||||
|
if len(session.CommittedUpdates) > 0 {
|
||||||
|
return ErrTowerUnackedUpdates
|
||||||
|
}
|
||||||
|
err := markSessionStatus(
|
||||||
|
sessions, session, CSessionInactive,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTowerByID retrieves a tower by its tower ID.
|
||||||
|
func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
|
||||||
var tower *Tower
|
var tower *Tower
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||||
towers := tx.Bucket(cTowerBkt)
|
towers := tx.Bucket(cTowerBkt)
|
||||||
|
@ -285,6 +400,60 @@ func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) {
|
||||||
return tower, nil
|
return tower, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadTower retrieves a tower by its public key.
|
||||||
|
func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
|
||||||
|
var tower *Tower
|
||||||
|
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
towers := tx.Bucket(cTowerBkt)
|
||||||
|
if towers == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
towerIndex := tx.Bucket(cTowerIndexBkt)
|
||||||
|
if towerIndex == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
towerIDBytes := towerIndex.Get(pubKey.SerializeCompressed())
|
||||||
|
if towerIDBytes == nil {
|
||||||
|
return ErrTowerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
tower, err = getTower(towers, towerIDBytes)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tower, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTowers retrieves the list of towers available within the database.
|
||||||
|
func (c *ClientDB) ListTowers() ([]*Tower, error) {
|
||||||
|
var towers []*Tower
|
||||||
|
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
towerBucket := tx.Bucket(cTowerBkt)
|
||||||
|
if towerBucket == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
return towerBucket.ForEach(func(towerIDBytes, _ []byte) error {
|
||||||
|
tower, err := getTower(towerBucket, towerIDBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
towers = append(towers, tower)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return towers, nil
|
||||||
|
}
|
||||||
|
|
||||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||||
// particular tower id. The index is reserved for that tower until
|
// particular tower id. The index is reserved for that tower until
|
||||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||||
|
@ -871,6 +1040,15 @@ func putClientSessionBody(sessions *bbolt.Bucket,
|
||||||
return sessionBkt.Put(cSessionBody, b.Bytes())
|
return sessionBkt.Put(cSessionBody, b.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// markSessionStatus updates the persisted state of the session to the new
|
||||||
|
// status.
|
||||||
|
func markSessionStatus(sessions *bbolt.Bucket, session *ClientSession,
|
||||||
|
status CSessionStatus) error {
|
||||||
|
|
||||||
|
session.Status = status
|
||||||
|
return putClientSessionBody(sessions, session)
|
||||||
|
}
|
||||||
|
|
||||||
// getChanSummary loads a ClientChanSummary for the passed chanID.
|
// getChanSummary loads a ClientChanSummary for the passed chanID.
|
||||||
func getChanSummary(chanSummaries *bbolt.Bucket,
|
func getChanSummary(chanSummaries *bbolt.Bucket,
|
||||||
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
|
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||||
|
@ -89,13 +90,81 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||||
h.t.Fatalf("tower id should never be 0")
|
h.t.Fatalf("tower id should never be 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, session := range h.listSessions(&tower.ID) {
|
||||||
|
if session.Status != wtdb.CSessionActive {
|
||||||
|
h.t.Fatalf("expected status for session %v to be %v, "+
|
||||||
|
"got %v", session.ID, wtdb.CSessionActive,
|
||||||
|
session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return tower
|
return tower
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *clientDBHarness) loadTower(id wtdb.TowerID, expErr error) *wtdb.Tower {
|
func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||||
|
hasSessions bool, expErr error) {
|
||||||
|
|
||||||
h.t.Helper()
|
h.t.Helper()
|
||||||
|
|
||||||
tower, err := h.db.LoadTower(id)
|
if err := h.db.RemoveTower(pubKey, addr); err != expErr {
|
||||||
|
h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err)
|
||||||
|
}
|
||||||
|
if expErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr != nil {
|
||||||
|
tower, err := h.db.LoadTower(pubKey)
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("expected tower %x to still exist",
|
||||||
|
pubKey.SerializeCompressed())
|
||||||
|
}
|
||||||
|
|
||||||
|
removedAddr := addr.String()
|
||||||
|
for _, towerAddr := range tower.Addresses {
|
||||||
|
if towerAddr.String() == removedAddr {
|
||||||
|
h.t.Fatalf("address %v not removed for tower %x",
|
||||||
|
removedAddr, pubKey.SerializeCompressed())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tower, err := h.db.LoadTower(pubKey)
|
||||||
|
if hasSessions && err != nil {
|
||||||
|
h.t.Fatalf("expected tower %x with sessions to still "+
|
||||||
|
"exist", pubKey.SerializeCompressed())
|
||||||
|
}
|
||||||
|
if !hasSessions && err == nil {
|
||||||
|
h.t.Fatalf("expected tower %x with no sessions to not "+
|
||||||
|
"exist", pubKey.SerializeCompressed())
|
||||||
|
}
|
||||||
|
if !hasSessions {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, session := range h.listSessions(&tower.ID) {
|
||||||
|
if session.Status != wtdb.CSessionInactive {
|
||||||
|
h.t.Fatalf("expected status for session %v to "+
|
||||||
|
"be %v, got %v", session.ID,
|
||||||
|
wtdb.CSessionInactive, session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, expErr error) *wtdb.Tower {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
tower, err := h.db.LoadTower(pubKey)
|
||||||
|
if err != expErr {
|
||||||
|
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tower
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
tower, err := h.db.LoadTowerByID(id)
|
||||||
if err != expErr {
|
if err != expErr {
|
||||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||||
}
|
}
|
||||||
|
@ -268,7 +337,7 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||||
// known addresses for the tower.
|
// known addresses for the tower.
|
||||||
func testCreateTower(h *clientDBHarness) {
|
func testCreateTower(h *clientDBHarness) {
|
||||||
// Test that loading a tower with an arbitrary tower id fails.
|
// Test that loading a tower with an arbitrary tower id fails.
|
||||||
h.loadTower(20, wtdb.ErrTowerNotFound)
|
h.loadTowerByID(20, wtdb.ErrTowerNotFound)
|
||||||
|
|
||||||
pk, err := randPubKey()
|
pk, err := randPubKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -286,7 +355,12 @@ func testCreateTower(h *clientDBHarness) {
|
||||||
|
|
||||||
// Load the tower from the database and assert that it matches the tower
|
// Load the tower from the database and assert that it matches the tower
|
||||||
// we created.
|
// we created.
|
||||||
tower2 := h.loadTower(tower.ID, nil)
|
tower2 := h.loadTowerByID(tower.ID, nil)
|
||||||
|
if !reflect.DeepEqual(tower, tower2) {
|
||||||
|
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||||
|
tower, tower2)
|
||||||
|
}
|
||||||
|
tower2 = h.loadTower(pk, err)
|
||||||
if !reflect.DeepEqual(tower, tower2) {
|
if !reflect.DeepEqual(tower, tower2) {
|
||||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||||
tower, tower2)
|
tower, tower2)
|
||||||
|
@ -317,7 +391,12 @@ func testCreateTower(h *clientDBHarness) {
|
||||||
|
|
||||||
// Load the tower from the database, and assert that it matches the
|
// Load the tower from the database, and assert that it matches the
|
||||||
// tower returned from creation.
|
// tower returned from creation.
|
||||||
towerNewAddr2 := h.loadTower(tower.ID, nil)
|
towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
|
||||||
|
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||||
|
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||||
|
towerNewAddr, towerNewAddr2)
|
||||||
|
}
|
||||||
|
towerNewAddr2 = h.loadTower(pk, nil)
|
||||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||||
towerNewAddr, towerNewAddr2)
|
towerNewAddr, towerNewAddr2)
|
||||||
|
@ -335,6 +414,82 @@ func testCreateTower(h *clientDBHarness) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testRemoveTower asserts the behavior of removing Tower objects as a whole and
|
||||||
|
// removing addresses from Tower objects within the database.
|
||||||
|
func testRemoveTower(h *clientDBHarness) {
|
||||||
|
// Generate a random public key we'll use for our tower.
|
||||||
|
pk, err := randPubKey()
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removing a tower that does not exist within the database should
|
||||||
|
// result in a NOP.
|
||||||
|
h.removeTower(pk, nil, false, nil)
|
||||||
|
|
||||||
|
// We'll create a tower with two addresses.
|
||||||
|
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||||
|
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
|
||||||
|
h.createTower(&lnwire.NetAddress{
|
||||||
|
IdentityKey: pk,
|
||||||
|
Address: addr1,
|
||||||
|
}, nil)
|
||||||
|
h.createTower(&lnwire.NetAddress{
|
||||||
|
IdentityKey: pk,
|
||||||
|
Address: addr2,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// We'll then remove the second address. We should now only see the
|
||||||
|
// first.
|
||||||
|
h.removeTower(pk, addr2, false, nil)
|
||||||
|
|
||||||
|
// We'll then remove the first address. We should now see that the tower
|
||||||
|
// has no addresses left.
|
||||||
|
h.removeTower(pk, addr1, false, nil)
|
||||||
|
|
||||||
|
// Removing the tower as a whole from the database should succeed since
|
||||||
|
// there aren't any active sessions for it.
|
||||||
|
h.removeTower(pk, nil, false, nil)
|
||||||
|
|
||||||
|
// We'll then recreate the tower, but this time we'll create a session
|
||||||
|
// for it.
|
||||||
|
tower := h.createTower(&lnwire.NetAddress{
|
||||||
|
IdentityKey: pk,
|
||||||
|
Address: addr1,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
session := &wtdb.ClientSession{
|
||||||
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
|
TowerID: tower.ID,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 100,
|
||||||
|
},
|
||||||
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
|
KeyIndex: h.nextKeyIndex(tower.ID, nil),
|
||||||
|
},
|
||||||
|
ID: wtdb.SessionID([33]byte{0x01}),
|
||||||
|
}
|
||||||
|
h.insertSession(session, nil)
|
||||||
|
update := randCommittedUpdate(h.t, 1)
|
||||||
|
h.commitUpdate(&session.ID, update, nil)
|
||||||
|
|
||||||
|
// We should not be able to fully remove it from the database since
|
||||||
|
// there's a session and it has unacked updates.
|
||||||
|
h.removeTower(pk, nil, true, wtdb.ErrTowerUnackedUpdates)
|
||||||
|
|
||||||
|
// Removing the tower after all sessions no longer have unacked updates
|
||||||
|
// should result in the sessions becoming inactive.
|
||||||
|
h.ackUpdate(&session.ID, 1, 1, nil)
|
||||||
|
h.removeTower(pk, nil, true, nil)
|
||||||
|
|
||||||
|
// Creating the tower again should mark all of the sessions active once
|
||||||
|
// again.
|
||||||
|
h.createTower(&lnwire.NetAddress{
|
||||||
|
IdentityKey: pk,
|
||||||
|
Address: addr1,
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// testChanSummaries tests the process of a registering a channel and its
|
// testChanSummaries tests the process of a registering a channel and its
|
||||||
// associated sweep pkscript.
|
// associated sweep pkscript.
|
||||||
func testChanSummaries(h *clientDBHarness) {
|
func testChanSummaries(h *clientDBHarness) {
|
||||||
|
@ -673,6 +828,10 @@ func TestClientDB(t *testing.T) {
|
||||||
name: "create tower",
|
name: "create tower",
|
||||||
run: testCreateTower,
|
run: testCreateTower,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "remove tower",
|
||||||
|
run: testRemoveTower,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "chan summaries",
|
name: "chan summaries",
|
||||||
run: testChanSummaries,
|
run: testChanSummaries,
|
||||||
|
|
|
@ -18,6 +18,10 @@ const (
|
||||||
// CSessionActive indicates that the ClientSession is active and can be
|
// CSessionActive indicates that the ClientSession is active and can be
|
||||||
// used for backups.
|
// used for backups.
|
||||||
CSessionActive CSessionStatus = 0
|
CSessionActive CSessionStatus = 0
|
||||||
|
|
||||||
|
// CSessionInactive indicates that the ClientSession is inactive and
|
||||||
|
// cannot be used for backups.
|
||||||
|
CSessionInactive CSessionStatus = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||||
|
|
|
@ -62,6 +62,19 @@ func (t *Tower) AddAddress(addr net.Addr) {
|
||||||
t.Addresses = append([]net.Addr{addr}, t.Addresses...)
|
t.Addresses = append([]net.Addr{addr}, t.Addresses...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveAddress removes the given address from the tower's in-memory list of
|
||||||
|
// addresses. If the address doesn't exist, then this will act as a NOP.
|
||||||
|
func (t *Tower) RemoveAddress(addr net.Addr) {
|
||||||
|
addrStr := addr.String()
|
||||||
|
for i, address := range t.Addresses {
|
||||||
|
if address.String() != addrStr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Addresses = append(t.Addresses[:i], t.Addresses[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
|
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
|
||||||
// addresses. This can be used to have a client try multiple addresses for the
|
// addresses. This can be used to have a client try multiple addresses for the
|
||||||
// same Tower.
|
// same Tower.
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
)
|
)
|
||||||
|
@ -37,9 +38,11 @@ func NewClientDB() *ClientDB {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTower initializes a database entry with the given lightning address. If
|
// CreateTower initialize an address record used to communicate with a
|
||||||
// the tower exists, the address is append to the list of all addresses used to
|
// watchtower. Each Tower is assigned a unique ID, that is used to amortize
|
||||||
// that tower previously.
|
// storage costs of the public key when used by multiple sessions. If the tower
|
||||||
|
// already exists, the address is appended to the list of all addresses used to
|
||||||
|
// that tower previously and its corresponding sessions are marked as active.
|
||||||
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
@ -52,6 +55,15 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||||
if ok {
|
if ok {
|
||||||
tower = m.towers[towerID]
|
tower = m.towers[towerID]
|
||||||
tower.AddAddress(lnAddr.Address)
|
tower.AddAddress(lnAddr.Address)
|
||||||
|
|
||||||
|
towerSessions, err := m.listClientSessions(&towerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for id, session := range towerSessions {
|
||||||
|
session.Status = wtdb.CSessionActive
|
||||||
|
m.activeSessions[id] = session
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
|
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
|
||||||
tower = &wtdb.Tower{
|
tower = &wtdb.Tower{
|
||||||
|
@ -67,8 +79,83 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||||
return copyTower(tower), nil
|
return copyTower(tower), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadTower retrieves a tower by its tower ID.
|
// RemoveTower modifies a tower's record within the database. If an address is
|
||||||
func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) {
|
// provided, then _only_ the address record should be removed from the tower's
|
||||||
|
// persisted state. Otherwise, we'll attempt to mark the tower as inactive by
|
||||||
|
// marking all of its sessions inactive. If any of its sessions has unacked
|
||||||
|
// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
|
||||||
|
// any sessions at all, it'll be completely removed from the database.
|
||||||
|
//
|
||||||
|
// NOTE: An error is not returned if the tower doesn't exist.
|
||||||
|
func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
tower, err := m.loadTower(pubKey)
|
||||||
|
if err == wtdb.ErrTowerNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr != nil {
|
||||||
|
tower.RemoveAddress(addr)
|
||||||
|
m.towers[tower.ID] = tower
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
towerSessions, err := m.listClientSessions(&tower.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(towerSessions) == 0 {
|
||||||
|
var towerPK towerPK
|
||||||
|
copy(towerPK[:], pubKey.SerializeCompressed())
|
||||||
|
delete(m.towerIndex, towerPK)
|
||||||
|
delete(m.towers, tower.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, session := range towerSessions {
|
||||||
|
if len(session.CommittedUpdates) > 0 {
|
||||||
|
return wtdb.ErrTowerUnackedUpdates
|
||||||
|
}
|
||||||
|
session.Status = wtdb.CSessionInactive
|
||||||
|
m.activeSessions[id] = session
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTower retrieves a tower by its public key.
|
||||||
|
func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.loadTower(pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTower retrieves a tower by its public key.
|
||||||
|
//
|
||||||
|
// NOTE: This method requires the database's lock to be acquired.
|
||||||
|
func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
|
||||||
|
var towerPK towerPK
|
||||||
|
copy(towerPK[:], pubKey.SerializeCompressed())
|
||||||
|
|
||||||
|
towerID, ok := m.towerIndex[towerPK]
|
||||||
|
if !ok {
|
||||||
|
return nil, wtdb.ErrTowerNotFound
|
||||||
|
}
|
||||||
|
tower, ok := m.towers[towerID]
|
||||||
|
if !ok {
|
||||||
|
return nil, wtdb.ErrTowerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return copyTower(tower), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTowerByID retrieves a tower by its tower ID.
|
||||||
|
func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
@ -79,6 +166,19 @@ func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) {
|
||||||
return nil, wtdb.ErrTowerNotFound
|
return nil, wtdb.ErrTowerNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListTowers retrieves the list of towers available within the database.
|
||||||
|
func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
towers := make([]*wtdb.Tower, 0, len(m.towers))
|
||||||
|
for _, tower := range m.towers {
|
||||||
|
towers = append(towers, copyTower(tower))
|
||||||
|
}
|
||||||
|
|
||||||
|
return towers, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarkBackupIneligible records that particular commit height is ineligible for
|
// MarkBackupIneligible records that particular commit height is ineligible for
|
||||||
// backup. This allows the client to track which updates it should not attempt
|
// backup. This allows the client to track which updates it should not attempt
|
||||||
// to retry after startup.
|
// to retry after startup.
|
||||||
|
@ -94,6 +194,14 @@ func (m *ClientDB) ListClientSessions(
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
return m.listClientSessions(tower)
|
||||||
|
}
|
||||||
|
|
||||||
|
// listClientSessions returns the set of all client sessions known to the db. An
|
||||||
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
|
// response that do not correspond to this tower.
|
||||||
|
func (m *ClientDB) listClientSessions(
|
||||||
|
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||||
for _, session := range m.activeSessions {
|
for _, session := range m.activeSessions {
|
||||||
|
|
Loading…
Add table
Reference in a new issue