From 571966440c28f640c6bf2ab1385844a770b03f6c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:21:18 +0200 Subject: [PATCH] watchtower: add MarkChannelClosed db method This commit adds a `MarkChannelClosed` method to the tower client DB. This function can be called when a channel is closed and it will check the channel's associated sessions to see if any of them are "closable". Any closable sessions are added to a new `cClosableSessionsBkt` bucket so that they can be evaluated in future. Note that only the logic for this function is added in this commit and it is not yet called. --- watchtower/wtclient/interface.go | 9 ++ watchtower/wtdb/client_db.go | 261 ++++++++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 163 ++++++++++++++++++- watchtower/wtmock/client_db.go | 134 +++++++++++++-- 4 files changed, 553 insertions(+), 14 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 32ebe93ad..e5fc5d22b 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -86,6 +86,15 @@ type DB interface { // their channel summaries. FetchChanSummaries() (wtdb.ChannelSummaries, error) + // MarkChannelClosed will mark a registered channel as closed by setting + // its closed-height as the given block height. It returns a list of + // session IDs for sessions that are now considered closable due to the + // close of this channel. The details for this channel will be deleted + // from the DB if there are no more sessions in the DB that contain + // updates for this channel. + MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) ( + []wtdb.SessionID, error) + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index eaf188470..d88fd631e 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -24,6 +24,7 @@ var ( // channel-id => cChannelSummary -> encoded ClientChanSummary. // => cChanDBID -> db-assigned-id // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height cChanDetailsBkt = []byte("client-channel-detail-bucket") // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: @@ -34,6 +35,12 @@ var ( // db-assigned-id of a channel. cChanDBID = []byte("client-channel-db-id") + // cChanClosedHeight is a key used in the cChanDetailsBkt to store the + // block height at which the channel's closing transaction was mined in. + // If this there is no associated value for this key, then the channel + // has not yet been marked as closed. + cChanClosedHeight = []byte("client-channel-closed-height") + // cChannelSummary is a key used in cChanDetailsBkt to store the encoded // body of ClientChanSummary. cChannelSummary = []byte("client-channel-summary") @@ -83,6 +90,10 @@ var ( "client-tower-to-session-index-bucket", ) + // cClosableSessionsBkt is a top-level bucket storing: + // db-session-id -> last-channel-close-height + cClosableSessionsBkt = []byte("client-closable-sessions-bucket") + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -156,6 +167,14 @@ var ( // ErrSessionFailedFilterFn indicates that a particular session did // not pass the filter func provided by the caller. ErrSessionFailedFilterFn = errors.New("session failed filter func") + + // errSessionHasOpenChannels is an error used to indicate that a + // session has updates for channels that are still open. + errSessionHasOpenChannels = errors.New("session has open channels") + + // errSessionHasUnackedUpdates is an error used to indicate that a + // session has un-acked updates. + errSessionHasUnackedUpdates = errors.New("session has un-acked updates") ) // NewBoltBackendCreator returns a function that creates a new bbolt backend for @@ -256,6 +275,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cTowerToSessionIndexBkt, cChanIDIndexBkt, cSessionIDIndexBkt, + cClosableSessionsBkt, } for _, bucket := range buckets { @@ -1365,6 +1385,209 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, return nil } +// MarkChannelClosed will mark a registered channel as closed by setting its +// closed-height as the given block height. It returns a list of session IDs for +// sessions that are now considered closable due to the close of this channel. +// The details for this channel will be deleted from the DB if there are no more +// sessions in the DB that contain updates for this channel. +func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]SessionID, error) { + + var closableSessions []SessionID + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableSessBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel + // details can be deleted. + chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions) + if chanSessIDsBkt == nil { + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + } + + // Otherwise, mark the channel as closed. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + + err := chanDetails.Put(cChanClosedHeight, height[:]) + if err != nil { + return err + } + + // Now iterate through all the sessions of the channel to check + // if any of them are closeable. + return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error { + sessDBIDInt, err := readBigSize(sessDBID) + if err != nil { + return err + } + + // Use the session-ID index to get the real session ID. + sID, err := getRealSessionID( + sessIDIndexBkt, sessDBIDInt, + ) + if err != nil { + return err + } + + isClosable, err := isSessionClosable( + sessionsBkt, chanDetailsBkt, chanIDIndexBkt, + sID, + ) + if err != nil { + return err + } + + if !isClosable { + return nil + } + + // Add session to "closableSessions" list and add the + // block height that this last channel was closed in. + // This will be used in future to determine when we + // should delete the session. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + err = closableSessBkt.Put(sessDBID, height[:]) + if err != nil { + return err + } + + closableSessions = append(closableSessions, *sID) + + return nil + }) + }, func() { + closableSessions = nil + }) + if err != nil { + return nil, err + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if all the following points are true: +// 1) It has no un-acked updates. +// 2) It is exhausted (ie it can't accept any more updates) +// 3) All the channels that it has acked updates for are closed. +func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket, + id *SessionID) (bool, error) { + + sessBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessBkt == nil { + return false, ErrSessionNotFound + } + + commitsBkt := sessBkt.NestedReadBucket(cSessionCommits) + if commitsBkt == nil { + // If the session has no cSessionCommits bucket then we can be + // sure that no updates have ever been committed to the session + // and so it is not yet exhausted. + return false, nil + } + + // If the session has any un-acked updates, then it is not yet closable. + err := commitsBkt.ForEach(func(_, _ []byte) error { + return errSessionHasUnackedUpdates + }) + if errors.Is(err, errSessionHasUnackedUpdates) { + return false, nil + } else if err != nil { + return false, err + } + + session, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return false, err + } + + // We have already checked that the session has no more committed + // updates. So now we can check if the session is exhausted. + if session.SeqNum < session.Policy.MaxUpdates { + // If the session is not yet exhausted, it is not yet closable. + return false, nil + } + + // If the session has no acked-updates, then something is wrong since + // the above check ensures that this session has been exhausted meaning + // that it should have MaxUpdates acked updates. + ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackedRangeBkt == nil { + return false, fmt.Errorf("no acked-updates found for "+ + "exhausted session %s", id) + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error { + dbChanIDInt, err := readBigSize(dbChanID) + if err != nil { + return err + } + + chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt) + if err != nil { + return err + } + + // Get the channel details bucket for the channel. + chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:]) + if chanDetails == nil { + return fmt.Errorf("no channel details found for "+ + "channel %s referenced by session %s", chanID, + id) + } + + // If a closed height has been set, then the channel is closed. + closedHeight := chanDetails.Get(cChanClosedHeight) + if len(closedHeight) > 0 { + return nil + } + + // Otherwise, the channel is not yet closed meaning that the + // session is not yet closable. We break the ForEach by + // returning an error to indicate this. + return errSessionHasOpenChannels + }) + if errors.Is(err, errSessionHasOpenChannels) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (c *ClientDB) CommitUpdate(id *SessionID, @@ -2016,6 +2239,44 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, return id, idBytes, nil } +func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID, + error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + sessIDBytes := sessIDIndexBkt.Get(dbIDBytes) + if len(sessIDBytes) != SessionIDSize { + return nil, fmt.Errorf("session ID not found") + } + + var sessID SessionID + copy(sessID[:], sessIDBytes) + + return &sessID, nil +} + +func getRealChannelID(chanIDIndexBkt kvdb.RBucket, + dbID uint64) (*lnwire.ChannelID, error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + chanIDBytes := chanIDIndexBkt.Get(dbIDBytes) + if len(chanIDBytes) != 32 { //nolint:gomnd + return nil, fmt.Errorf("channel ID not found") + } + + var chanIDS lnwire.ChannelID + copy(chanIDS[:], chanIDBytes) + + return &chanIDS, nil +} + // writeBigSize will encode the given uint64 as a BigSize byte slice. func writeBigSize(i uint64) ([]byte, error) { var b bytes.Buffer diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index cd77ec77e..4f5f80749 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -3,6 +3,7 @@ package wtdb_test import ( crand "crypto/rand" "io" + "math/rand" "net" "testing" @@ -17,6 +18,8 @@ import ( "github.com/stretchr/testify/require" ) +const blobType = blob.TypeAltruistCommit + // pseudoAddr is a fake network address to be used for testing purposes. var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} @@ -193,6 +196,17 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, + blockHeight uint32, expErr error) []wtdb.SessionID { + + h.t.Helper() + + closableSessions, err := h.db.MarkChannelClosed(id, blockHeight) + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -605,6 +619,105 @@ func testCommitUpdate(h *clientDBHarness) { }, nil) } +// testMarkChannelClosed asserts the behaviour of MarkChannelClosed. +func testMarkChannelClosed(h *clientDBHarness) { + tower := h.newTower() + + // Create channel 1. + chanID1 := randChannelID(h.t) + + // Since we have not yet registered the channel, we expect an error + // when attempting to mark it as closed. + h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered) + + // Now register the channel. + h.registerChan(chanID1, nil, nil) + + // Since there are still no sessions that would have updates for the + // channel, marking it as closed now should succeed. + h.markChannelClosed(chanID1, 1, nil) + + // Register channel 2. + chanID2 := randChannelID(h.t) + h.registerChan(chanID2, nil, nil) + + // Create session1 with MaxUpdates set to 5. + session1 := h.randSession(h.t, tower.ID, 5) + h.insertSession(session1, nil) + + // Add an update for channel 2 in session 1 and ack it too. + update := randCommittedUpdateForChannel(h.t, chanID2, 1) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + require.Zero(h.t, lastApplied) + h.ackUpdate(&session1.ID, 1, 1, nil) + + // Marking channel 2 now should not result in any closable sessions + // since session 1 is not yet exhausted. + sl := h.markChannelClosed(chanID2, 1, nil) + require.Empty(h.t, sl) + + // Create channel 3 and 4. + chanID3 := randChannelID(h.t) + h.registerChan(chanID3, nil, nil) + + chanID4 := randChannelID(h.t) + h.registerChan(chanID4, nil, nil) + + // Add an update for channel 4 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID4, 2) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 1, lastApplied) + h.ackUpdate(&session1.ID, 2, 2, nil) + + // Add an update for channel 3 in session 1. But dont ack it yet. + update = randCommittedUpdateForChannel(h.t, chanID2, 3) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 2, lastApplied) + + // Mark channel 4 as closed & assert that session 1 is not seen as + // closable since it still has committed updates. + sl = h.markChannelClosed(chanID4, 1, nil) + require.Empty(h.t, sl) + + // Now ack the update we added above. + h.ackUpdate(&session1.ID, 3, 3, nil) + + // Mark channel 3 as closed & assert that session 1 is still not seen as + // closable since it is not yet exhausted. + sl = h.markChannelClosed(chanID3, 1, nil) + require.Empty(h.t, sl) + + // Create channel 5 and 6. + chanID5 := randChannelID(h.t) + h.registerChan(chanID5, nil, nil) + + chanID6 := randChannelID(h.t) + h.registerChan(chanID6, nil, nil) + + // Add an update for channel 5 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID5, 4) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 3, lastApplied) + h.ackUpdate(&session1.ID, 4, 4, nil) + + // Add an update for channel 6 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID6, 5) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 4, lastApplied) + h.ackUpdate(&session1.ID, 5, 5, nil) + + // The session is no exhausted. + // If we now close channel 5, session 1 should still not be closable + // since it has an update for channel 6 which is still open. + sl = h.markChannelClosed(chanID5, 1, nil) + require.Empty(h.t, sl) + + // Finally, if we close channel 6, session 1 _should_ be in the closable + // list. + sl = h.markChannelClosed(chanID6, 1, nil) + require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID}) +} + // testAckUpdate asserts the behavior of AckUpdate. func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit @@ -821,6 +934,10 @@ func TestClientDB(t *testing.T) { name: "ack update", run: testAckUpdate, }, + { + name: "mark channel closed", + run: testMarkChannelClosed, + }, } for _, database := range dbs { @@ -841,12 +958,32 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { + t.Helper() + + chanID := randChannelID(t) + + return randCommittedUpdateForChannel(t, chanID, seqNum) +} + +func randChannelID(t *testing.T) lnwire.ChannelID { + t.Helper() + var chanID lnwire.ChannelID _, err := io.ReadFull(crand.Reader, chanID[:]) require.NoError(t, err) + return chanID +} + +// randCommittedUpdateForChannel generates a random committed update for the +// given channel ID. +func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID, + seqNum uint16) *wtdb.CommittedUpdate { + + t.Helper() + var hint blob.BreachHint - _, err = io.ReadFull(crand.Reader, hint[:]) + _, err := io.ReadFull(crand.Reader, hint[:]) require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) @@ -865,3 +1002,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { }, } } + +func (h *clientDBHarness) randSession(t *testing.T, + towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession { + + t.Helper() + + var id wtdb.SessionID + rand.Read(id[:]) + + return &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: towerID, + Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, + MaxUpdates: maxUpdates, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + KeyIndex: h.nextKeyIndex(towerID, blobType), + }, + ID: id, + } +} diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 9d38c2da2..2820d74cd 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore +type channel struct { + summary *wtdb.ClientChanSummary + closedHeight uint32 + sessions map[wtdb.SessionID]bool +} + // ClientDB is a mock, in-memory database or testing the watchtower client // behavior. type ClientDB struct { nextTowerID uint64 // to be used atomically mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + channels map[lnwire.ChannelID]*channel activeSessions map[wtdb.SessionID]wtdb.ClientSession ackedUpdates rangeIndexArrayMap persistedAckedUpdates rangeIndexKVStore committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower + closableSessions map[wtdb.SessionID]uint32 nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -47,9 +54,7 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make( - map[lnwire.ChannelID]wtdb.ClientChanSummary, - ), + channels: make(map[lnwire.ChannelID]*channel), activeSessions: make( map[wtdb.SessionID]wtdb.ClientSession, ), @@ -58,10 +63,11 @@ func NewClientDB() *ClientDB { committedUpdates: make( map[wtdb.SessionID][]wtdb.CommittedUpdate, ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), + closableSessions: make(map[wtdb.SessionID]uint32), } } @@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, continue } + // Add sessionID to channel. + channel, ok := m.channels[update.BackupID.ChanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + channel.sessions[*id] = true + // Remove the committed update from disk and mark the update as // acked. The tower last applied value is also recorded to send // along with the next update. @@ -545,15 +558,107 @@ func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { defer m.mu.Unlock() summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, summary := range m.summaries { + for chanID, channel := range m.channels { summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(summary.SweepPkScript), + SweepPkScript: cloneBytes( + channel.summary.SweepPkScript, + ), } } return summaries, nil } +// MarkChannelClosed will mark a registered channel as closed by setting +// its closed-height as the given block height. It returns a list of +// session IDs for sessions that are now considered closable due to the +// close of this channel. +func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]wtdb.SessionID, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + channel, ok := m.channels[chanID] + if !ok { + return nil, wtdb.ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel details can be + // deleted. + if len(channel.sessions) == 0 { + delete(m.channels, chanID) + return nil, nil + } + + // Mark the channel as closed. + channel.closedHeight = blockHeight + + // Now iterate through all the sessions of the channel to check if any + // of them are closeable. + var closableSessions []wtdb.SessionID + for sessID := range channel.sessions { + isClosable, err := m.isSessionClosable(sessID) + if err != nil { + return nil, err + } + + if !isClosable { + continue + } + + closableSessions = append(closableSessions, sessID) + + // Add session to "closableSessions" list and add the block + // height that this last channel was closed in. This will be + // used in future to determine when we should delete the + // session. + m.closableSessions[sessID] = blockHeight + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if: +// 1) It has no un-acked updates +// 2) It is exhausted (ie it cant accept any more updates) +// 3) All the channels that it has acked-updates for are closed. +func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { + // The session is not closable if it has un-acked updates. + if len(m.committedUpdates[id]) > 0 { + return false, nil + } + + sess, ok := m.activeSessions[id] + if !ok { + return false, wtdb.ErrClientSessionNotFound + } + + // The session is not closable if it is not yet exhausted. + if sess.SeqNum != sess.Policy.MaxUpdates { + return false, nil + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + for chanID := range m.ackedUpdates[id] { + channel, ok := m.channels[chanID] + if !ok { + continue + } + + // Channel is not yet closed, and so we can not yet delete the + // session. + if channel.closedHeight == 0 { + return false, nil + } + } + + return true, nil +} + // GetClientSession loads the ClientSession with the given ID from the DB. func (m *ClientDB) GetClientSession(id wtdb.SessionID, opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { @@ -595,12 +700,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.summaries[chanID]; ok { + if _, ok := m.channels[chanID]; ok { return wtdb.ErrChannelAlreadyRegistered } - m.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), + m.channels[chanID] = &channel{ + summary: &wtdb.ClientChanSummary{ + SweepPkScript: cloneBytes(sweepPkScript), + }, + sessions: make(map[wtdb.SessionID]bool), } return nil