mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-23 14:40:30 +01:00
watchtower: add MarkChannelClosed db method
This commit adds a `MarkChannelClosed` method to the tower client DB. This function can be called when a channel is closed and it will check the channel's associated sessions to see if any of them are "closable". Any closable sessions are added to a new `cClosableSessionsBkt` bucket so that they can be evaluated in future. Note that only the logic for this function is added in this commit and it is not yet called.
This commit is contained in:
parent
a3050ed213
commit
571966440c
4 changed files with 553 additions and 14 deletions
|
@ -86,6 +86,15 @@ type DB interface {
|
|||
// their channel summaries.
|
||||
FetchChanSummaries() (wtdb.ChannelSummaries, error)
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting
|
||||
// its closed-height as the given block height. It returns a list of
|
||||
// session IDs for sessions that are now considered closable due to the
|
||||
// close of this channel. The details for this channel will be deleted
|
||||
// from the DB if there are no more sessions in the DB that contain
|
||||
// updates for this channel.
|
||||
MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) (
|
||||
[]wtdb.SessionID, error)
|
||||
|
||||
// RegisterChannel registers a channel for use within the client
|
||||
// database. For now, all that is stored in the channel summary is the
|
||||
// sweep pkscript that we'd like any tower sweeps to pay into. In the
|
||||
|
|
|
@ -24,6 +24,7 @@ var (
|
|||
// channel-id => cChannelSummary -> encoded ClientChanSummary.
|
||||
// => cChanDBID -> db-assigned-id
|
||||
// => cChanSessions => db-session-id -> 1
|
||||
// => cChanClosedHeight -> block-height
|
||||
cChanDetailsBkt = []byte("client-channel-detail-bucket")
|
||||
|
||||
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
|
||||
|
@ -34,6 +35,12 @@ var (
|
|||
// db-assigned-id of a channel.
|
||||
cChanDBID = []byte("client-channel-db-id")
|
||||
|
||||
// cChanClosedHeight is a key used in the cChanDetailsBkt to store the
|
||||
// block height at which the channel's closing transaction was mined in.
|
||||
// If this there is no associated value for this key, then the channel
|
||||
// has not yet been marked as closed.
|
||||
cChanClosedHeight = []byte("client-channel-closed-height")
|
||||
|
||||
// cChannelSummary is a key used in cChanDetailsBkt to store the encoded
|
||||
// body of ClientChanSummary.
|
||||
cChannelSummary = []byte("client-channel-summary")
|
||||
|
@ -83,6 +90,10 @@ var (
|
|||
"client-tower-to-session-index-bucket",
|
||||
)
|
||||
|
||||
// cClosableSessionsBkt is a top-level bucket storing:
|
||||
// db-session-id -> last-channel-close-height
|
||||
cClosableSessionsBkt = []byte("client-closable-sessions-bucket")
|
||||
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
|
@ -156,6 +167,14 @@ var (
|
|||
// ErrSessionFailedFilterFn indicates that a particular session did
|
||||
// not pass the filter func provided by the caller.
|
||||
ErrSessionFailedFilterFn = errors.New("session failed filter func")
|
||||
|
||||
// errSessionHasOpenChannels is an error used to indicate that a
|
||||
// session has updates for channels that are still open.
|
||||
errSessionHasOpenChannels = errors.New("session has open channels")
|
||||
|
||||
// errSessionHasUnackedUpdates is an error used to indicate that a
|
||||
// session has un-acked updates.
|
||||
errSessionHasUnackedUpdates = errors.New("session has un-acked updates")
|
||||
)
|
||||
|
||||
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
|
||||
|
@ -256,6 +275,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error {
|
|||
cTowerToSessionIndexBkt,
|
||||
cChanIDIndexBkt,
|
||||
cSessionIDIndexBkt,
|
||||
cClosableSessionsBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
|
@ -1365,6 +1385,209 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
|
|||
return nil
|
||||
}
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting its
|
||||
// closed-height as the given block height. It returns a list of session IDs for
|
||||
// sessions that are now considered closable due to the close of this channel.
|
||||
// The details for this channel will be deleted from the DB if there are no more
|
||||
// sessions in the DB that contain updates for this channel.
|
||||
func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
|
||||
blockHeight uint32) ([]SessionID, error) {
|
||||
|
||||
var closableSessions []SessionID
|
||||
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
|
||||
sessionsBkt := tx.ReadBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt)
|
||||
if closableSessBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt)
|
||||
if sessIDIndexBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:])
|
||||
if chanDetails == nil {
|
||||
return ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
// If there are no sessions for this channel, the channel
|
||||
// details can be deleted.
|
||||
chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions)
|
||||
if chanSessIDsBkt == nil {
|
||||
return chanDetailsBkt.DeleteNestedBucket(chanID[:])
|
||||
}
|
||||
|
||||
// Otherwise, mark the channel as closed.
|
||||
var height [4]byte
|
||||
byteOrder.PutUint32(height[:], blockHeight)
|
||||
|
||||
err := chanDetails.Put(cChanClosedHeight, height[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now iterate through all the sessions of the channel to check
|
||||
// if any of them are closeable.
|
||||
return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error {
|
||||
sessDBIDInt, err := readBigSize(sessDBID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the session-ID index to get the real session ID.
|
||||
sID, err := getRealSessionID(
|
||||
sessIDIndexBkt, sessDBIDInt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isClosable, err := isSessionClosable(
|
||||
sessionsBkt, chanDetailsBkt, chanIDIndexBkt,
|
||||
sID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isClosable {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add session to "closableSessions" list and add the
|
||||
// block height that this last channel was closed in.
|
||||
// This will be used in future to determine when we
|
||||
// should delete the session.
|
||||
var height [4]byte
|
||||
byteOrder.PutUint32(height[:], blockHeight)
|
||||
err = closableSessBkt.Put(sessDBID, height[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
closableSessions = append(closableSessions, *sID)
|
||||
|
||||
return nil
|
||||
})
|
||||
}, func() {
|
||||
closableSessions = nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return closableSessions, nil
|
||||
}
|
||||
|
||||
// isSessionClosable returns true if a session is considered closable. A session
|
||||
// is considered closable only if all the following points are true:
|
||||
// 1) It has no un-acked updates.
|
||||
// 2) It is exhausted (ie it can't accept any more updates)
|
||||
// 3) All the channels that it has acked updates for are closed.
|
||||
func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
id *SessionID) (bool, error) {
|
||||
|
||||
sessBkt := sessionsBkt.NestedReadBucket(id[:])
|
||||
if sessBkt == nil {
|
||||
return false, ErrSessionNotFound
|
||||
}
|
||||
|
||||
commitsBkt := sessBkt.NestedReadBucket(cSessionCommits)
|
||||
if commitsBkt == nil {
|
||||
// If the session has no cSessionCommits bucket then we can be
|
||||
// sure that no updates have ever been committed to the session
|
||||
// and so it is not yet exhausted.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If the session has any un-acked updates, then it is not yet closable.
|
||||
err := commitsBkt.ForEach(func(_, _ []byte) error {
|
||||
return errSessionHasUnackedUpdates
|
||||
})
|
||||
if errors.Is(err, errSessionHasUnackedUpdates) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
session, err := getClientSessionBody(sessionsBkt, id[:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// We have already checked that the session has no more committed
|
||||
// updates. So now we can check if the session is exhausted.
|
||||
if session.SeqNum < session.Policy.MaxUpdates {
|
||||
// If the session is not yet exhausted, it is not yet closable.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If the session has no acked-updates, then something is wrong since
|
||||
// the above check ensures that this session has been exhausted meaning
|
||||
// that it should have MaxUpdates acked updates.
|
||||
ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex)
|
||||
if ackedRangeBkt == nil {
|
||||
return false, fmt.Errorf("no acked-updates found for "+
|
||||
"exhausted session %s", id)
|
||||
}
|
||||
|
||||
// Iterate over each of the channels that the session has acked-updates
|
||||
// for. If any of those channels are not closed, then the session is
|
||||
// not yet closable.
|
||||
err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error {
|
||||
dbChanIDInt, err := readBigSize(dbChanID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the channel details bucket for the channel.
|
||||
chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:])
|
||||
if chanDetails == nil {
|
||||
return fmt.Errorf("no channel details found for "+
|
||||
"channel %s referenced by session %s", chanID,
|
||||
id)
|
||||
}
|
||||
|
||||
// If a closed height has been set, then the channel is closed.
|
||||
closedHeight := chanDetails.Get(cChanClosedHeight)
|
||||
if len(closedHeight) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, the channel is not yet closed meaning that the
|
||||
// session is not yet closable. We break the ForEach by
|
||||
// returning an error to indicate this.
|
||||
return errSessionHasOpenChannels
|
||||
})
|
||||
if errors.Is(err, errSessionHasOpenChannels) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (c *ClientDB) CommitUpdate(id *SessionID,
|
||||
|
@ -2016,6 +2239,44 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64,
|
|||
return id, idBytes, nil
|
||||
}
|
||||
|
||||
func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID,
|
||||
error) {
|
||||
|
||||
dbIDBytes, err := writeBigSize(dbID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessIDBytes := sessIDIndexBkt.Get(dbIDBytes)
|
||||
if len(sessIDBytes) != SessionIDSize {
|
||||
return nil, fmt.Errorf("session ID not found")
|
||||
}
|
||||
|
||||
var sessID SessionID
|
||||
copy(sessID[:], sessIDBytes)
|
||||
|
||||
return &sessID, nil
|
||||
}
|
||||
|
||||
func getRealChannelID(chanIDIndexBkt kvdb.RBucket,
|
||||
dbID uint64) (*lnwire.ChannelID, error) {
|
||||
|
||||
dbIDBytes, err := writeBigSize(dbID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chanIDBytes := chanIDIndexBkt.Get(dbIDBytes)
|
||||
if len(chanIDBytes) != 32 { //nolint:gomnd
|
||||
return nil, fmt.Errorf("channel ID not found")
|
||||
}
|
||||
|
||||
var chanIDS lnwire.ChannelID
|
||||
copy(chanIDS[:], chanIDBytes)
|
||||
|
||||
return &chanIDS, nil
|
||||
}
|
||||
|
||||
// writeBigSize will encode the given uint64 as a BigSize byte slice.
|
||||
func writeBigSize(i uint64) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
|
|
|
@ -3,6 +3,7 @@ package wtdb_test
|
|||
import (
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
|
@ -17,6 +18,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
||||
// pseudoAddr is a fake network address to be used for testing purposes.
|
||||
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
|
@ -193,6 +196,17 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
|||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID,
|
||||
blockHeight uint32, expErr error) []wtdb.SessionID {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
closableSessions, err := h.db.MarkChannelClosed(id, blockHeight)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return closableSessions
|
||||
}
|
||||
|
||||
// newTower is a helper function that creates a new tower with a randomly
|
||||
// generated public key and inserts it into the client DB.
|
||||
func (h *clientDBHarness) newTower() *wtdb.Tower {
|
||||
|
@ -605,6 +619,105 @@ func testCommitUpdate(h *clientDBHarness) {
|
|||
}, nil)
|
||||
}
|
||||
|
||||
// testMarkChannelClosed asserts the behaviour of MarkChannelClosed.
|
||||
func testMarkChannelClosed(h *clientDBHarness) {
|
||||
tower := h.newTower()
|
||||
|
||||
// Create channel 1.
|
||||
chanID1 := randChannelID(h.t)
|
||||
|
||||
// Since we have not yet registered the channel, we expect an error
|
||||
// when attempting to mark it as closed.
|
||||
h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered)
|
||||
|
||||
// Now register the channel.
|
||||
h.registerChan(chanID1, nil, nil)
|
||||
|
||||
// Since there are still no sessions that would have updates for the
|
||||
// channel, marking it as closed now should succeed.
|
||||
h.markChannelClosed(chanID1, 1, nil)
|
||||
|
||||
// Register channel 2.
|
||||
chanID2 := randChannelID(h.t)
|
||||
h.registerChan(chanID2, nil, nil)
|
||||
|
||||
// Create session1 with MaxUpdates set to 5.
|
||||
session1 := h.randSession(h.t, tower.ID, 5)
|
||||
h.insertSession(session1, nil)
|
||||
|
||||
// Add an update for channel 2 in session 1 and ack it too.
|
||||
update := randCommittedUpdateForChannel(h.t, chanID2, 1)
|
||||
lastApplied := h.commitUpdate(&session1.ID, update, nil)
|
||||
require.Zero(h.t, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 1, 1, nil)
|
||||
|
||||
// Marking channel 2 now should not result in any closable sessions
|
||||
// since session 1 is not yet exhausted.
|
||||
sl := h.markChannelClosed(chanID2, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Create channel 3 and 4.
|
||||
chanID3 := randChannelID(h.t)
|
||||
h.registerChan(chanID3, nil, nil)
|
||||
|
||||
chanID4 := randChannelID(h.t)
|
||||
h.registerChan(chanID4, nil, nil)
|
||||
|
||||
// Add an update for channel 4 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID4, 2)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 1, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 2, 2, nil)
|
||||
|
||||
// Add an update for channel 3 in session 1. But dont ack it yet.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID2, 3)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 2, lastApplied)
|
||||
|
||||
// Mark channel 4 as closed & assert that session 1 is not seen as
|
||||
// closable since it still has committed updates.
|
||||
sl = h.markChannelClosed(chanID4, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Now ack the update we added above.
|
||||
h.ackUpdate(&session1.ID, 3, 3, nil)
|
||||
|
||||
// Mark channel 3 as closed & assert that session 1 is still not seen as
|
||||
// closable since it is not yet exhausted.
|
||||
sl = h.markChannelClosed(chanID3, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Create channel 5 and 6.
|
||||
chanID5 := randChannelID(h.t)
|
||||
h.registerChan(chanID5, nil, nil)
|
||||
|
||||
chanID6 := randChannelID(h.t)
|
||||
h.registerChan(chanID6, nil, nil)
|
||||
|
||||
// Add an update for channel 5 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID5, 4)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 3, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 4, 4, nil)
|
||||
|
||||
// Add an update for channel 6 and ack it.
|
||||
update = randCommittedUpdateForChannel(h.t, chanID6, 5)
|
||||
lastApplied = h.commitUpdate(&session1.ID, update, nil)
|
||||
require.EqualValues(h.t, 4, lastApplied)
|
||||
h.ackUpdate(&session1.ID, 5, 5, nil)
|
||||
|
||||
// The session is no exhausted.
|
||||
// If we now close channel 5, session 1 should still not be closable
|
||||
// since it has an update for channel 6 which is still open.
|
||||
sl = h.markChannelClosed(chanID5, 1, nil)
|
||||
require.Empty(h.t, sl)
|
||||
|
||||
// Finally, if we close channel 6, session 1 _should_ be in the closable
|
||||
// list.
|
||||
sl = h.markChannelClosed(chanID6, 1, nil)
|
||||
require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID})
|
||||
}
|
||||
|
||||
// testAckUpdate asserts the behavior of AckUpdate.
|
||||
func testAckUpdate(h *clientDBHarness) {
|
||||
const blobType = blob.TypeAltruistCommit
|
||||
|
@ -821,6 +934,10 @@ func TestClientDB(t *testing.T) {
|
|||
name: "ack update",
|
||||
run: testAckUpdate,
|
||||
},
|
||||
{
|
||||
name: "mark channel closed",
|
||||
run: testMarkChannelClosed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
|
@ -841,12 +958,32 @@ func TestClientDB(t *testing.T) {
|
|||
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
t.Helper()
|
||||
|
||||
chanID := randChannelID(t)
|
||||
|
||||
return randCommittedUpdateForChannel(t, chanID, seqNum)
|
||||
}
|
||||
|
||||
func randChannelID(t *testing.T) lnwire.ChannelID {
|
||||
t.Helper()
|
||||
|
||||
var chanID lnwire.ChannelID
|
||||
_, err := io.ReadFull(crand.Reader, chanID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
return chanID
|
||||
}
|
||||
|
||||
// randCommittedUpdateForChannel generates a random committed update for the
|
||||
// given channel ID.
|
||||
func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID,
|
||||
seqNum uint16) *wtdb.CommittedUpdate {
|
||||
|
||||
t.Helper()
|
||||
|
||||
var hint blob.BreachHint
|
||||
_, err = io.ReadFull(crand.Reader, hint[:])
|
||||
_, err := io.ReadFull(crand.Reader, hint[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
|
@ -865,3 +1002,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) randSession(t *testing.T,
|
||||
towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession {
|
||||
|
||||
t.Helper()
|
||||
|
||||
var id wtdb.SessionID
|
||||
rand.Read(id[:])
|
||||
|
||||
return &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: towerID,
|
||||
Policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blobType,
|
||||
},
|
||||
MaxUpdates: maxUpdates,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
KeyIndex: h.nextKeyIndex(towerID, blobType),
|
||||
},
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
|
|||
|
||||
type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
|
||||
|
||||
type channel struct {
|
||||
summary *wtdb.ClientChanSummary
|
||||
closedHeight uint32
|
||||
sessions map[wtdb.SessionID]bool
|
||||
}
|
||||
|
||||
// ClientDB is a mock, in-memory database or testing the watchtower client
|
||||
// behavior.
|
||||
type ClientDB struct {
|
||||
nextTowerID uint64 // to be used atomically
|
||||
|
||||
mu sync.Mutex
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
channels map[lnwire.ChannelID]*channel
|
||||
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
||||
ackedUpdates rangeIndexArrayMap
|
||||
persistedAckedUpdates rangeIndexKVStore
|
||||
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
closableSessions map[wtdb.SessionID]uint32
|
||||
|
||||
nextIndex uint32
|
||||
indexes map[keyIndexKey]uint32
|
||||
|
@ -47,9 +54,7 @@ type ClientDB struct {
|
|||
// NewClientDB initializes a new mock ClientDB.
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
summaries: make(
|
||||
map[lnwire.ChannelID]wtdb.ClientChanSummary,
|
||||
),
|
||||
channels: make(map[lnwire.ChannelID]*channel),
|
||||
activeSessions: make(
|
||||
map[wtdb.SessionID]wtdb.ClientSession,
|
||||
),
|
||||
|
@ -58,10 +63,11 @@ func NewClientDB() *ClientDB {
|
|||
committedUpdates: make(
|
||||
map[wtdb.SessionID][]wtdb.CommittedUpdate,
|
||||
),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
closableSessions: make(map[wtdb.SessionID]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
|
|||
continue
|
||||
}
|
||||
|
||||
// Add sessionID to channel.
|
||||
channel, ok := m.channels[update.BackupID.ChanID]
|
||||
if !ok {
|
||||
return wtdb.ErrChannelNotRegistered
|
||||
}
|
||||
channel.sessions[*id] = true
|
||||
|
||||
// Remove the committed update from disk and mark the update as
|
||||
// acked. The tower last applied value is also recorded to send
|
||||
// along with the next update.
|
||||
|
@ -545,15 +558,107 @@ func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
|
|||
defer m.mu.Unlock()
|
||||
|
||||
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
|
||||
for chanID, summary := range m.summaries {
|
||||
for chanID, channel := range m.channels {
|
||||
summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(summary.SweepPkScript),
|
||||
SweepPkScript: cloneBytes(
|
||||
channel.summary.SweepPkScript,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting
|
||||
// its closed-height as the given block height. It returns a list of
|
||||
// session IDs for sessions that are now considered closable due to the
|
||||
// close of this channel.
|
||||
func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
|
||||
blockHeight uint32) ([]wtdb.SessionID, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
channel, ok := m.channels[chanID]
|
||||
if !ok {
|
||||
return nil, wtdb.ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
// If there are no sessions for this channel, the channel details can be
|
||||
// deleted.
|
||||
if len(channel.sessions) == 0 {
|
||||
delete(m.channels, chanID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Mark the channel as closed.
|
||||
channel.closedHeight = blockHeight
|
||||
|
||||
// Now iterate through all the sessions of the channel to check if any
|
||||
// of them are closeable.
|
||||
var closableSessions []wtdb.SessionID
|
||||
for sessID := range channel.sessions {
|
||||
isClosable, err := m.isSessionClosable(sessID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isClosable {
|
||||
continue
|
||||
}
|
||||
|
||||
closableSessions = append(closableSessions, sessID)
|
||||
|
||||
// Add session to "closableSessions" list and add the block
|
||||
// height that this last channel was closed in. This will be
|
||||
// used in future to determine when we should delete the
|
||||
// session.
|
||||
m.closableSessions[sessID] = blockHeight
|
||||
}
|
||||
|
||||
return closableSessions, nil
|
||||
}
|
||||
|
||||
// isSessionClosable returns true if a session is considered closable. A session
|
||||
// is considered closable only if:
|
||||
// 1) It has no un-acked updates
|
||||
// 2) It is exhausted (ie it cant accept any more updates)
|
||||
// 3) All the channels that it has acked-updates for are closed.
|
||||
func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) {
|
||||
// The session is not closable if it has un-acked updates.
|
||||
if len(m.committedUpdates[id]) > 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
sess, ok := m.activeSessions[id]
|
||||
if !ok {
|
||||
return false, wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// The session is not closable if it is not yet exhausted.
|
||||
if sess.SeqNum != sess.Policy.MaxUpdates {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Iterate over each of the channels that the session has acked-updates
|
||||
// for. If any of those channels are not closed, then the session is
|
||||
// not yet closable.
|
||||
for chanID := range m.ackedUpdates[id] {
|
||||
channel, ok := m.channels[chanID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Channel is not yet closed, and so we can not yet delete the
|
||||
// session.
|
||||
if channel.closedHeight == 0 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetClientSession loads the ClientSession with the given ID from the DB.
|
||||
func (m *ClientDB) GetClientSession(id wtdb.SessionID,
|
||||
opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) {
|
||||
|
@ -595,12 +700,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.summaries[chanID]; ok {
|
||||
if _, ok := m.channels[chanID]; ok {
|
||||
return wtdb.ErrChannelAlreadyRegistered
|
||||
}
|
||||
|
||||
m.summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(sweepPkScript),
|
||||
m.channels[chanID] = &channel{
|
||||
summary: &wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(sweepPkScript),
|
||||
},
|
||||
sessions: make(map[wtdb.SessionID]bool),
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Reference in a new issue