mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 18:10:34 +01:00
c760700545
These race conditions originate from the mock database storing and returning pointers, rather than returning a copy. Observed on Travis: WARNING: DATA RACE Read at 0x00c0003222b8 by goroutine 149: github.com/lightningnetwork/lnd/watchtower/wtclient.(*sessionQueue).drainBackups() /home/runner/work/lnd/lnd/watchtower/wtclient/session_queue.go:288 +0xed github.com/lightningnetwork/lnd/watchtower/wtclient.(*sessionQueue).sessionManager() /home/runner/work/lnd/lnd/watchtower/wtclient/session_queue.go:281 +0x450 Previous write at 0x00c0003222b8 by goroutine 93: github.com/lightningnetwork/lnd/watchtower/wtclient.getClientSessions() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:365 +0x24f github.com/lightningnetwork/lnd/watchtower/wtclient.(*TowerClient).handleNewTower() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:1063 +0x23e github.com/lightningnetwork/lnd/watchtower/wtclient.(*TowerClient).backupDispatcher() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:784 +0x10b9
431 lines
12 KiB
Go
431 lines
12 KiB
Go
package wtmock
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/btcsuite/btcd/btcec"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
|
)
|
|
|
|
// towerPK is a fixed-size array holding a tower's serialized compressed
// public key, allowing the key to be used as the map key of the tower index.
type towerPK [33]byte
|
|
|
|
// ClientDB is a mock, in-memory database or testing the watchtower client
|
|
// behavior.
|
|
type ClientDB struct {
|
|
nextTowerID uint64 // to be used atomically
|
|
|
|
mu sync.Mutex
|
|
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
|
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
|
towerIndex map[towerPK]wtdb.TowerID
|
|
towers map[wtdb.TowerID]*wtdb.Tower
|
|
|
|
nextIndex uint32
|
|
indexes map[wtdb.TowerID]uint32
|
|
}
|
|
|
|
// NewClientDB initializes a new mock ClientDB.
|
|
func NewClientDB() *ClientDB {
|
|
return &ClientDB{
|
|
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
|
|
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
|
towerIndex: make(map[towerPK]wtdb.TowerID),
|
|
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
|
indexes: make(map[wtdb.TowerID]uint32),
|
|
}
|
|
}
|
|
|
|
// CreateTower initialize an address record used to communicate with a
|
|
// watchtower. Each Tower is assigned a unique ID, that is used to amortize
|
|
// storage costs of the public key when used by multiple sessions. If the tower
|
|
// already exists, the address is appended to the list of all addresses used to
|
|
// that tower previously and its corresponding sessions are marked as active.
|
|
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var towerPubKey towerPK
|
|
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
|
|
|
var tower *wtdb.Tower
|
|
towerID, ok := m.towerIndex[towerPubKey]
|
|
if ok {
|
|
tower = m.towers[towerID]
|
|
tower.AddAddress(lnAddr.Address)
|
|
|
|
towerSessions, err := m.listClientSessions(&towerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for id, session := range towerSessions {
|
|
session.Status = wtdb.CSessionActive
|
|
m.activeSessions[id] = *session
|
|
}
|
|
} else {
|
|
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
|
|
tower = &wtdb.Tower{
|
|
ID: wtdb.TowerID(towerID),
|
|
IdentityKey: lnAddr.IdentityKey,
|
|
Addresses: []net.Addr{lnAddr.Address},
|
|
}
|
|
}
|
|
|
|
m.towerIndex[towerPubKey] = towerID
|
|
m.towers[towerID] = tower
|
|
|
|
return copyTower(tower), nil
|
|
}
|
|
|
|
// RemoveTower modifies a tower's record within the database. If an address is
|
|
// provided, then _only_ the address record should be removed from the tower's
|
|
// persisted state. Otherwise, we'll attempt to mark the tower as inactive by
|
|
// marking all of its sessions inactive. If any of its sessions has unacked
|
|
// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
|
|
// any sessions at all, it'll be completely removed from the database.
|
|
//
|
|
// NOTE: An error is not returned if the tower doesn't exist.
|
|
func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
tower, err := m.loadTower(pubKey)
|
|
if err == wtdb.ErrTowerNotFound {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if addr != nil {
|
|
tower.RemoveAddress(addr)
|
|
m.towers[tower.ID] = tower
|
|
return nil
|
|
}
|
|
|
|
towerSessions, err := m.listClientSessions(&tower.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(towerSessions) == 0 {
|
|
var towerPK towerPK
|
|
copy(towerPK[:], pubKey.SerializeCompressed())
|
|
delete(m.towerIndex, towerPK)
|
|
delete(m.towers, tower.ID)
|
|
return nil
|
|
}
|
|
|
|
for id, session := range towerSessions {
|
|
if len(session.CommittedUpdates) > 0 {
|
|
return wtdb.ErrTowerUnackedUpdates
|
|
}
|
|
session.Status = wtdb.CSessionInactive
|
|
m.activeSessions[id] = *session
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadTower retrieves a tower by its public key.
|
|
func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.loadTower(pubKey)
|
|
}
|
|
|
|
// loadTower retrieves a tower by its public key.
|
|
//
|
|
// NOTE: This method requires the database's lock to be acquired.
|
|
func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
|
|
var towerPK towerPK
|
|
copy(towerPK[:], pubKey.SerializeCompressed())
|
|
|
|
towerID, ok := m.towerIndex[towerPK]
|
|
if !ok {
|
|
return nil, wtdb.ErrTowerNotFound
|
|
}
|
|
tower, ok := m.towers[towerID]
|
|
if !ok {
|
|
return nil, wtdb.ErrTowerNotFound
|
|
}
|
|
|
|
return copyTower(tower), nil
|
|
}
|
|
|
|
// LoadTowerByID retrieves a tower by its tower ID.
|
|
func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if tower, ok := m.towers[towerID]; ok {
|
|
return copyTower(tower), nil
|
|
}
|
|
|
|
return nil, wtdb.ErrTowerNotFound
|
|
}
|
|
|
|
// ListTowers retrieves the list of towers available within the database.
|
|
func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
towers := make([]*wtdb.Tower, 0, len(m.towers))
|
|
for _, tower := range m.towers {
|
|
towers = append(towers, copyTower(tower))
|
|
}
|
|
|
|
return towers, nil
|
|
}
|
|
|
|
// MarkBackupIneligible records that particular commit height is ineligible for
|
|
// backup. This allows the client to track which updates it should not attempt
|
|
// to retry after startup.
|
|
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
|
|
return nil
|
|
}
|
|
|
|
// ListClientSessions returns the set of all client sessions known to the db. An
|
|
// optional tower ID can be used to filter out any client sessions in the
|
|
// response that do not correspond to this tower.
|
|
func (m *ClientDB) ListClientSessions(
|
|
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.listClientSessions(tower)
|
|
}
|
|
|
|
// listClientSessions returns the set of all client sessions known to the db. An
|
|
// optional tower ID can be used to filter out any client sessions in the
|
|
// response that do not correspond to this tower.
|
|
func (m *ClientDB) listClientSessions(
|
|
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
|
|
|
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
|
for _, session := range m.activeSessions {
|
|
session := session
|
|
if tower != nil && *tower != session.TowerID {
|
|
continue
|
|
}
|
|
sessions[session.ID] = &session
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// CreateClientSession records a newly negotiated client session in the set of
|
|
// active sessions. The session can be identified by its SessionID.
|
|
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Ensure that we aren't overwriting an existing session.
|
|
if _, ok := m.activeSessions[session.ID]; ok {
|
|
return wtdb.ErrClientSessionAlreadyExists
|
|
}
|
|
|
|
// Ensure that a session key index has been reserved for this tower.
|
|
keyIndex, ok := m.indexes[session.TowerID]
|
|
if !ok {
|
|
return wtdb.ErrNoReservedKeyIndex
|
|
}
|
|
|
|
// Ensure that the session's index matches the reserved index.
|
|
if keyIndex != session.KeyIndex {
|
|
return wtdb.ErrIncorrectKeyIndex
|
|
}
|
|
|
|
// Remove the key index reservation for this tower. Once committed, this
|
|
// permits us to create another session with this tower.
|
|
delete(m.indexes, session.TowerID)
|
|
|
|
m.activeSessions[session.ID] = wtdb.ClientSession{
|
|
ID: session.ID,
|
|
ClientSessionBody: wtdb.ClientSessionBody{
|
|
SeqNum: session.SeqNum,
|
|
TowerLastApplied: session.TowerLastApplied,
|
|
TowerID: session.TowerID,
|
|
KeyIndex: session.KeyIndex,
|
|
Policy: session.Policy,
|
|
RewardPkScript: cloneBytes(session.RewardPkScript),
|
|
},
|
|
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
|
|
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
|
// particular tower id. The index is reserved for that tower until
|
|
// CreateClientSession is invoked for that tower and index, at which point a new
|
|
// index for that tower can be reserved. Multiple calls to this method before
|
|
// CreateClientSession is invoked should return the same index.
|
|
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if index, ok := m.indexes[towerID]; ok {
|
|
return index, nil
|
|
}
|
|
|
|
m.nextIndex++
|
|
index := m.nextIndex
|
|
m.indexes[towerID] = index
|
|
|
|
return index, nil
|
|
}
|
|
|
|
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
|
// seqNum). This allows the client to retransmit this update on startup.
|
|
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
|
update *wtdb.CommittedUpdate) (uint16, error) {
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Fail if session doesn't exist.
|
|
session, ok := m.activeSessions[*id]
|
|
if !ok {
|
|
return 0, wtdb.ErrClientSessionNotFound
|
|
}
|
|
|
|
// Check if an update has already been committed for this state.
|
|
for _, dbUpdate := range session.CommittedUpdates {
|
|
if dbUpdate.SeqNum == update.SeqNum {
|
|
// If the breach hint matches, we'll just return the
|
|
// last applied value so the client can retransmit.
|
|
if dbUpdate.Hint == update.Hint {
|
|
return session.TowerLastApplied, nil
|
|
}
|
|
|
|
// Otherwise, fail since the breach hint doesn't match.
|
|
return 0, wtdb.ErrUpdateAlreadyCommitted
|
|
}
|
|
}
|
|
|
|
// Sequence number must increment.
|
|
if update.SeqNum != session.SeqNum+1 {
|
|
return 0, wtdb.ErrCommitUnorderedUpdate
|
|
}
|
|
|
|
// Save the update and increment the sequence number.
|
|
session.CommittedUpdates = append(session.CommittedUpdates, *update)
|
|
session.SeqNum++
|
|
m.activeSessions[*id] = session
|
|
|
|
return session.TowerLastApplied, nil
|
|
}
|
|
|
|
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
|
// removes the update from the set of committed updates, and validates the
|
|
// lastApplied value returned from the tower.
|
|
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Fail if session doesn't exist.
|
|
session, ok := m.activeSessions[*id]
|
|
if !ok {
|
|
return wtdb.ErrClientSessionNotFound
|
|
}
|
|
|
|
// Ensure the returned last applied value does not exceed the highest
|
|
// allocated sequence number.
|
|
if lastApplied > session.SeqNum {
|
|
return wtdb.ErrUnallocatedLastApplied
|
|
}
|
|
|
|
// Ensure the last applied value isn't lower than a previous one sent by
|
|
// the tower.
|
|
if lastApplied < session.TowerLastApplied {
|
|
return wtdb.ErrLastAppliedReversion
|
|
}
|
|
|
|
// Retrieve the committed update, failing if none is found. We should
|
|
// only receive acks for state updates that we send.
|
|
updates := session.CommittedUpdates
|
|
for i, update := range updates {
|
|
if update.SeqNum != seqNum {
|
|
continue
|
|
}
|
|
|
|
// Remove the committed update from disk and mark the update as
|
|
// acked. The tower last applied value is also recorded to send
|
|
// along with the next update.
|
|
copy(updates[:i], updates[i+1:])
|
|
updates[len(updates)-1] = wtdb.CommittedUpdate{}
|
|
session.CommittedUpdates = updates[:len(updates)-1]
|
|
|
|
session.AckedUpdates[seqNum] = update.BackupID
|
|
session.TowerLastApplied = lastApplied
|
|
|
|
m.activeSessions[*id] = session
|
|
return nil
|
|
}
|
|
|
|
return wtdb.ErrCommittedUpdateNotFound
|
|
}
|
|
|
|
// FetchChanSummaries loads a mapping from all registered channels to their
|
|
// channel summaries.
|
|
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
|
|
for chanID, summary := range m.summaries {
|
|
summaries[chanID] = wtdb.ClientChanSummary{
|
|
SweepPkScript: cloneBytes(summary.SweepPkScript),
|
|
}
|
|
}
|
|
|
|
return summaries, nil
|
|
}
|
|
|
|
// RegisterChannel registers a channel for use within the client database. For
|
|
// now, all that is stored in the channel summary is the sweep pkscript that
|
|
// we'd like any tower sweeps to pay into. In the future, this will be extended
|
|
// to contain more info to allow the client efficiently request historical
|
|
// states to be backed up under the client's active policy.
|
|
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
|
sweepPkScript []byte) error {
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if _, ok := m.summaries[chanID]; ok {
|
|
return wtdb.ErrChannelAlreadyRegistered
|
|
}
|
|
|
|
m.summaries[chanID] = wtdb.ClientChanSummary{
|
|
SweepPkScript: cloneBytes(sweepPkScript),
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cloneBytes returns a copy of b backed by its own array. A nil input yields
// nil, while an empty non-nil input yields an empty non-nil slice.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}

	// Appending onto an explicitly allocated zero-length slice keeps the
	// result non-nil even when b is empty.
	dup := make([]byte, 0, len(b))

	return append(dup, b...)
}
|
|
|
|
func copyTower(tower *wtdb.Tower) *wtdb.Tower {
|
|
t := &wtdb.Tower{
|
|
ID: tower.ID,
|
|
IdentityKey: tower.IdentityKey,
|
|
Addresses: make([]net.Addr, len(tower.Addresses)),
|
|
}
|
|
copy(t.Addresses, tower.Addresses)
|
|
|
|
return t
|
|
}
|