diff --git a/blockcache/blockcache.go b/blockcache/blockcache.go index 6c1f5f74c..09d5d3395 100644 --- a/blockcache/blockcache.go +++ b/blockcache/blockcache.go @@ -14,7 +14,7 @@ import ( // BlockCache is an lru cache for blocks. type BlockCache struct { Cache *lru.Cache[wire.InvVect, *neutrino.CacheableBlock] - HashMutex *multimutex.HashMutex + HashMutex *multimutex.Mutex[lntypes.Hash] } // NewBlockCache creates a new BlockCache with the given maximum capacity. @@ -23,7 +23,7 @@ func NewBlockCache(capacity uint64) *BlockCache { Cache: lru.NewCache[wire.InvVect, *neutrino.CacheableBlock]( capacity, ), - HashMutex: multimutex.NewHashMutex(), + HashMutex: multimutex.NewMutex[lntypes.Hash](), } } diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 72e5876ac..526f7d79b 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -450,7 +450,7 @@ type AuthenticatedGossiper struct { // goroutine per channel ID. This is done to ensure that when // the gossiper is handling an announcement, the db state stays // consistent between when the DB is first read until it's written. - channelMtx *multimutex.Mutex + channelMtx *multimutex.Mutex[uint64] recentRejects *lru.Cache[rejectCacheKey, *cachedReject] @@ -496,7 +496,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper prematureChannelUpdates: lru.NewCache[uint64, *cachedNetworkMsg]( //nolint: lll maxPrematureUpdates, ), - channelMtx: multimutex.NewMutex(), + channelMtx: multimutex.NewMutex[uint64](), recentRejects: lru.NewCache[rejectCacheKey, *cachedReject]( maxRejectedUpdates, ), diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 8d6cb5b3a..cd982b8bb 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -93,14 +93,14 @@ type networkResultStore struct { // paymentIDMtx is a multimutex used to make sure the database and // result subscribers map is consistent for each payment ID in case of // concurrent callers. - paymentIDMtx *multimutex.Mutex + paymentIDMtx *multimutex.Mutex[uint64] } func newNetworkResultStore(db kvdb.Backend) *networkResultStore { return &networkResultStore{ backend: db, results: make(map[uint64][]chan *networkResult), - paymentIDMtx: multimutex.NewMutex(), + paymentIDMtx: multimutex.NewMutex[uint64](), } } diff --git a/multimutex/hash_mutex.go b/multimutex/hash_mutex.go deleted file mode 100644 index 4a65394d1..000000000 --- a/multimutex/hash_mutex.go +++ /dev/null @@ -1,90 +0,0 @@ -package multimutex - -import ( - "fmt" - "sync" - - "github.com/lightningnetwork/lnd/lntypes" -) - -// HashMutex is a struct that keeps track of a set of mutexes with a given hash. -// It can be used for making sure only one goroutine gets given the mutex per -// hash. -type HashMutex struct { - // mutexes is a map of hashes to a cntMutex. The cntMutex for - // a given hash will hold the mutex to be used by all - // callers requesting access for the hash, in addition to - // the count of callers. - mutexes map[lntypes.Hash]*cntMutex - - // mapMtx is used to give synchronize concurrent access - // to the mutexes map. - mapMtx sync.Mutex -} - -// NewHashMutex creates a new Mutex. -func NewHashMutex() *HashMutex { - return &HashMutex{ - mutexes: make(map[lntypes.Hash]*cntMutex), - } -} - -// Lock locks the mutex by the given hash. If the mutex is already -// locked by this hash, Lock blocks until the mutex is available. -func (c *HashMutex) Lock(hash lntypes.Hash) { - c.mapMtx.Lock() - mtx, ok := c.mutexes[hash] - if ok { - // If the mutex already existed in the map, we - // increment its counter, to indicate that there - // now is one more goroutine waiting for it. - mtx.cnt++ - } else { - // If it was not in the map, it means no other - // goroutine has locked the mutex for this hash, - // and we can create a new mutex with count 1 - // and add it to the map. - mtx = &cntMutex{ - cnt: 1, - } - c.mutexes[hash] = mtx - } - c.mapMtx.Unlock() - - // Acquire the mutex for this hash. - mtx.Lock() -} - -// Unlock unlocks the mutex by the given hash. It is a run-time -// error if the mutex is not locked by the hash on entry to Unlock. -func (c *HashMutex) Unlock(hash lntypes.Hash) { - // Since we are done with all the work for this - // update, we update the map to reflect that. - c.mapMtx.Lock() - - mtx, ok := c.mutexes[hash] - if !ok { - // The mutex not existing in the map means - // an unlock for an hash not currently locked - // was attempted. - panic(fmt.Sprintf("double unlock for hash %v", - hash)) - } - - // Decrement the counter. If the count goes to - // zero, it means this caller was the last one - // to wait for the mutex, and we can delete it - // from the map. We can do this safely since we - // are under the mapMtx, meaning that all other - // goroutines waiting for the mutex already - // have incremented it, or will create a new - // mutex when they get the mapMtx. - mtx.cnt-- - if mtx.cnt == 0 { - delete(c.mutexes, hash) - } - c.mapMtx.Unlock() - - // Unlock the mutex for this hash. - mtx.Unlock() -} diff --git a/multimutex/multimutex.go b/multimutex/multimutex.go index e37c88d51..4180f3e53 100644 --- a/multimutex/multimutex.go +++ b/multimutex/multimutex.go @@ -5,51 +5,48 @@ import ( "sync" ) -// cntMutex is a struct that wraps a counter and a mutex, and is used -// to keep track of the number of goroutines waiting for access to the +// cntMutex is a struct that wraps a counter and a mutex, and is used to keep +// track of the number of goroutines waiting for access to the // mutex, such that we can forget about it when the counter is zero. type cntMutex struct { cnt int sync.Mutex } -// Mutex is a struct that keeps track of a set of mutexes with -// a given ID. It can be used for making sure only one goroutine -// gets given the mutex per ID. -type Mutex struct { - // mutexes is a map of IDs to a cntMutex. The cntMutex for - // a given ID will hold the mutex to be used by all - // callers requesting access for the ID, in addition to - // the count of callers. - mutexes map[uint64]*cntMutex +// Mutex is a struct that keeps track of a set of mutexes with a given ID. It +// can be used for making sure only one goroutine gets given the mutex per ID. +type Mutex[T comparable] struct { + // mutexes is a map of IDs to a cntMutex. The cntMutex for a given ID + // will hold the mutex to be used by all callers requesting access for + // the ID, in addition to the count of callers. + mutexes map[T]*cntMutex - // mapMtx is used to give synchronize concurrent access - // to the mutexes map. + // mapMtx is used to give synchronize concurrent access to the mutexes + // map. mapMtx sync.Mutex } // NewMutex creates a new Mutex. -func NewMutex() *Mutex { - return &Mutex{ - mutexes: make(map[uint64]*cntMutex), +func NewMutex[T comparable]() *Mutex[T] { + return &Mutex[T]{ + mutexes: make(map[T]*cntMutex), } } -// Lock locks the mutex by the given ID. If the mutex is already -// locked by this ID, Lock blocks until the mutex is available. -func (c *Mutex) Lock(id uint64) { +// Lock locks the mutex by the given ID. If the mutex is already locked by this +// ID, Lock blocks until the mutex is available. +func (c *Mutex[T]) Lock(id T) { c.mapMtx.Lock() mtx, ok := c.mutexes[id] if ok { - // If the mutex already existed in the map, we - // increment its counter, to indicate that there - // now is one more goroutine waiting for it. + // If the mutex already existed in the map, we increment its + // counter, to indicate that there now is one more goroutine + // waiting for it. mtx.cnt++ } else { - // If it was not in the map, it means no other - // goroutine has locked the mutex for this ID, - // and we can create a new mutex with count 1 - // and add it to the map. + // If it was not in the map, it means no other goroutine has + // locked the mutex for this ID, and we can create a new mutex + // with count 1 and add it to the map. mtx = &cntMutex{ cnt: 1, } @@ -61,30 +58,26 @@ func (c *Mutex) Lock(id uint64) { mtx.Lock() } -// Unlock unlocks the mutex by the given ID. It is a run-time -// error if the mutex is not locked by the ID on entry to Unlock. -func (c *Mutex) Unlock(id uint64) { - // Since we are done with all the work for this - // update, we update the map to reflect that. +// Unlock unlocks the mutex by the given ID. It is a run-time error if the +// mutex is not locked by the ID on entry to Unlock. +func (c *Mutex[T]) Unlock(id T) { + // Since we are done with all the work for this update, we update the + // map to reflect that. c.mapMtx.Lock() mtx, ok := c.mutexes[id] if !ok { - // The mutex not existing in the map means - // an unlock for an ID not currently locked - // was attempted. + // The mutex not existing in the map means an unlock for an ID + // not currently locked was attempted. panic(fmt.Sprintf("double unlock for id %v", id)) } - // Decrement the counter. If the count goes to - // zero, it means this caller was the last one - // to wait for the mutex, and we can delete it - // from the map. We can do this safely since we - // are under the mapMtx, meaning that all other - // goroutines waiting for the mutex already - // have incremented it, or will create a new - // mutex when they get the mapMtx. + // Decrement the counter. If the count goes to zero, it means this + // caller was the last one to wait for the mutex, and we can delete it + // from the map. We can do this safely since we are under the mapMtx, + // meaning that all other goroutines waiting for the mutex already have + // incremented it, or will create a new mutex when they get the mapMtx. mtx.cnt-- if mtx.cnt == 0 { delete(c.mutexes, id) diff --git a/routing/control_tower.go b/routing/control_tower.go index a0c5b1df7..d2cbc6bbf 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -132,7 +132,7 @@ type controlTower struct { // paymentsMtx provides synchronization on the payment level to ensure // that no race conditions occur in between updating the database and // sending a notification. - paymentsMtx *multimutex.HashMutex + paymentsMtx *multimutex.Mutex[lntypes.Hash] } // NewControlTower creates a new instance of the controlTower. @@ -143,7 +143,7 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower { map[uint64]*controlTowerSubscriberImpl, ), subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl), - paymentsMtx: multimutex.NewHashMutex(), + paymentsMtx: multimutex.NewMutex[lntypes.Hash](), } } diff --git a/routing/router.go b/routing/router.go index 84bfbd71e..123d13c94 100644 --- a/routing/router.go +++ b/routing/router.go @@ -442,7 +442,7 @@ type ChannelRouter struct { // channelEdgeMtx is a mutex we use to make sure we process only one // ChannelEdgePolicy at a time for a given channelID, to ensure // consistency between the various database accesses. - channelEdgeMtx *multimutex.Mutex + channelEdgeMtx *multimutex.Mutex[uint64] // statTicker is a resumable ticker that logs the router's progress as // it discovers channels or receives updates. @@ -480,7 +480,7 @@ func New(cfg Config) (*ChannelRouter, error) { networkUpdates: make(chan *routingMsg), topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, ntfnClientUpdates: make(chan *topologyClientUpdate), - channelEdgeMtx: multimutex.NewMutex(), + channelEdgeMtx: multimutex.NewMutex[uint64](), selfNode: selfNode, statTicker: ticker.New(defaultStatInterval), stats: new(routerStats),