diff --git a/chainntnfs/mempool.go b/chainntnfs/mempool.go index 132dce744..d613e829a 100644 --- a/chainntnfs/mempool.go +++ b/chainntnfs/mempool.go @@ -39,6 +39,12 @@ type MempoolSpendEvent struct { // NOTE: This channel must be buffered. Spend <-chan *SpendDetail + // id is the unique identifier of this subscription. + id uint64 + + // outpoint is the subscribed outpoint. + outpoint wire.OutPoint + // event is the channel that will be sent upon once the target outpoint // is spent. event chan *SpendDetail @@ -48,10 +54,12 @@ type MempoolSpendEvent struct { } // newMempoolSpendEvent returns a new instance of MempoolSpendEvent. -func newMempoolSpendEvent() *MempoolSpendEvent { +func newMempoolSpendEvent(id uint64, op wire.OutPoint) *MempoolSpendEvent { sub := &MempoolSpendEvent{ - event: make(chan *SpendDetail, 1), - cancel: make(chan struct{}), + id: id, + outpoint: op, + event: make(chan *SpendDetail, 1), + cancel: make(chan struct{}), } // Mount the receive only channel to the event channel. @@ -60,11 +68,6 @@ func newMempoolSpendEvent() *MempoolSpendEvent { return sub } -// Cancel cancels the subscription. -func (m *MempoolSpendEvent) Cancel() { - close(m.cancel) -} - // NewMempoolNotifier takes a chain connection and returns a new mempool // notifier. func NewMempoolNotifier() *MempoolNotifier { @@ -82,31 +85,61 @@ func NewMempoolNotifier() *MempoolNotifier { func (m *MempoolNotifier) SubscribeInput( outpoint wire.OutPoint) *MempoolSpendEvent { - Log.Debugf("Subscribing mempool event for input %s", outpoint) - // Get the current subscribers for this input or create a new one. clients := &lnutils.SyncMap[uint64, *MempoolSpendEvent]{} clients, _ = m.subscribedInputs.LoadOrStore(outpoint, clients) + // Increment the subscription counter and return the new value. + subscriptionID := m.sCounter.Add(1) + // Create a new subscription. - sub := newMempoolSpendEvent() + sub := newMempoolSpendEvent(subscriptionID, outpoint) // Add the subscriber with a unique id. - subscriptionID := m.sCounter.Add(1) clients.Store(subscriptionID, sub) // Update the subscribed inputs. m.subscribedInputs.Store(outpoint, clients) + Log.Debugf("Subscribed(id=%v) mempool event for input=%s", + subscriptionID, outpoint) + return sub } -// Unsubscribe removes the subscription for the given outpoint. -func (m *MempoolNotifier) Unsubscribe(outpoint wire.OutPoint) { +// UnsubscribeInput removes all the subscriptions for the given outpoint. +func (m *MempoolNotifier) UnsubscribeInput(outpoint wire.OutPoint) { Log.Debugf("Unsubscribing MempoolSpendEvent for input %s", outpoint) m.subscribedInputs.Delete(outpoint) } +// UnsubscribeEvent removes a given subscriber for the given MempoolSpendEvent. +func (m *MempoolNotifier) UnsubscribeEvent(sub *MempoolSpendEvent) { + Log.Debugf("Unsubscribing(id=%v) MempoolSpendEvent for input=%s", + sub.id, sub.outpoint) + + // Load all the subscribers for this input. + clients, loaded := m.subscribedInputs.Load(sub.outpoint) + if !loaded { + Log.Debugf("No subscribers for input %s", sub.outpoint) + return + } + + // Load the subscriber. + subscriber, loaded := clients.Load(sub.id) + if !loaded { + Log.Debugf("No subscribers for input %s with id %v", + sub.outpoint, sub.id) + return + } + + // Close the cancel channel in case it's been used in a goroutine. + close(subscriber.cancel) + + // Remove the subscriber. + clients.Delete(sub.id) +} + // ProcessRelevantSpendTx takes a transaction and checks whether it spends any // of the subscribed inputs. If so, spend notifications are sent to the // relevant subscribers. @@ -174,29 +207,15 @@ func (m *MempoolNotifier) notifySpent(spentInputs inputsWithTx) { defer m.wg.Done() - Log.Debugf("Notifying client %d", id) - // Send the spend details to the subscriber. select { case sub.event <- detail: - Log.Debugf("Notified mempool spent for input %s", op) + Log.Debugf("Notified(id=%v) mempool spent for input %s", + sub.id, op) case <-sub.cancel: - Log.Debugf("Subscription canceled, skipped notifying "+ - "mempool spent for input %s", op) - - // Find all the subscribers for this outpoint. - clients, loaded := m.subscribedInputs.Load(op) - if !loaded { - Log.Errorf("Client %d not found", id) - return - } - - // Delete the specific subscriber. - clients.Delete(id) - - // Update the subscribers map. - m.subscribedInputs.Store(op, clients) + Log.Debugf("Subscription(id=%v) canceled, skipped "+ + "notifying spent for input %s", sub.id, op) case <-m.quit: Log.Debugf("Mempool notifier quit, skipped notifying "+ @@ -212,7 +231,8 @@ func (m *MempoolNotifier) notifySpent(spentInputs inputsWithTx) { defer m.wg.Done() txid := detail.SpendingTx.TxHash() - Log.Debugf("Notifying the spend of %s in tx %s", op, txid) + Log.Debugf("Notifying all clients for the spend of %s in tx %s", + op, txid) // Load the subscriber. subs, loaded := m.subscribedInputs.Load(op)