multimutex: remove HashMutex, make Mutex type a type param

In this commit, we eliminate some code duplication by removing the old
`HashMutex` struct as it just duplicates all the code with a different
type (uint64 and hash). We then make the main Mutex struct take a type
param, so the key can be parametrized when the struct is instantiated.
This commit is contained in:
Olaoluwa Osuntokun 2023-06-01 17:38:37 -07:00
parent f9d4600ff8
commit a7d6826f60
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
7 changed files with 45 additions and 142 deletions

View File

@ -14,7 +14,7 @@ import (
// BlockCache is an lru cache for blocks. // BlockCache is an lru cache for blocks.
type BlockCache struct { type BlockCache struct {
Cache *lru.Cache[wire.InvVect, *neutrino.CacheableBlock] Cache *lru.Cache[wire.InvVect, *neutrino.CacheableBlock]
HashMutex *multimutex.HashMutex HashMutex *multimutex.Mutex[lntypes.Hash]
} }
// NewBlockCache creates a new BlockCache with the given maximum capacity. // NewBlockCache creates a new BlockCache with the given maximum capacity.
@ -23,7 +23,7 @@ func NewBlockCache(capacity uint64) *BlockCache {
Cache: lru.NewCache[wire.InvVect, *neutrino.CacheableBlock]( Cache: lru.NewCache[wire.InvVect, *neutrino.CacheableBlock](
capacity, capacity,
), ),
HashMutex: multimutex.NewHashMutex(), HashMutex: multimutex.NewMutex[lntypes.Hash](),
} }
} }

View File

@ -450,7 +450,7 @@ type AuthenticatedGossiper struct {
// goroutine per channel ID. This is done to ensure that when // goroutine per channel ID. This is done to ensure that when
// the gossiper is handling an announcement, the db state stays // the gossiper is handling an announcement, the db state stays
// consistent between when the DB is first read until it's written. // consistent between when the DB is first read until it's written.
channelMtx *multimutex.Mutex channelMtx *multimutex.Mutex[uint64]
recentRejects *lru.Cache[rejectCacheKey, *cachedReject] recentRejects *lru.Cache[rejectCacheKey, *cachedReject]
@ -496,7 +496,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper
prematureChannelUpdates: lru.NewCache[uint64, *cachedNetworkMsg]( //nolint: lll prematureChannelUpdates: lru.NewCache[uint64, *cachedNetworkMsg]( //nolint: lll
maxPrematureUpdates, maxPrematureUpdates,
), ),
channelMtx: multimutex.NewMutex(), channelMtx: multimutex.NewMutex[uint64](),
recentRejects: lru.NewCache[rejectCacheKey, *cachedReject]( recentRejects: lru.NewCache[rejectCacheKey, *cachedReject](
maxRejectedUpdates, maxRejectedUpdates,
), ),

View File

@ -93,14 +93,14 @@ type networkResultStore struct {
// paymentIDMtx is a multimutex used to make sure the database and // paymentIDMtx is a multimutex used to make sure the database and
// result subscribers map is consistent for each payment ID in case of // result subscribers map is consistent for each payment ID in case of
// concurrent callers. // concurrent callers.
paymentIDMtx *multimutex.Mutex paymentIDMtx *multimutex.Mutex[uint64]
} }
func newNetworkResultStore(db kvdb.Backend) *networkResultStore { func newNetworkResultStore(db kvdb.Backend) *networkResultStore {
return &networkResultStore{ return &networkResultStore{
backend: db, backend: db,
results: make(map[uint64][]chan *networkResult), results: make(map[uint64][]chan *networkResult),
paymentIDMtx: multimutex.NewMutex(), paymentIDMtx: multimutex.NewMutex[uint64](),
} }
} }

View File

@ -1,90 +0,0 @@
package multimutex
import (
"fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes"
)
// HashMutex is a struct that keeps track of a set of mutexes with a given hash.
// It can be used for making sure only one goroutine gets given the mutex per
// hash.
type HashMutex struct {
// mutexes is a map of hashes to a cntMutex. The cntMutex for
// a given hash will hold the mutex to be used by all
// callers requesting access for the hash, in addition to
// the count of callers.
mutexes map[lntypes.Hash]*cntMutex
// mapMtx is used to give synchronize concurrent access
// to the mutexes map.
mapMtx sync.Mutex
}
// NewHashMutex creates a new Mutex.
func NewHashMutex() *HashMutex {
return &HashMutex{
mutexes: make(map[lntypes.Hash]*cntMutex),
}
}
// Lock locks the mutex by the given hash. If the mutex is already
// locked by this hash, Lock blocks until the mutex is available.
func (c *HashMutex) Lock(hash lntypes.Hash) {
c.mapMtx.Lock()
mtx, ok := c.mutexes[hash]
if ok {
// If the mutex already existed in the map, we
// increment its counter, to indicate that there
// now is one more goroutine waiting for it.
mtx.cnt++
} else {
// If it was not in the map, it means no other
// goroutine has locked the mutex for this hash,
// and we can create a new mutex with count 1
// and add it to the map.
mtx = &cntMutex{
cnt: 1,
}
c.mutexes[hash] = mtx
}
c.mapMtx.Unlock()
// Acquire the mutex for this hash.
mtx.Lock()
}
// Unlock unlocks the mutex by the given hash. It is a run-time
// error if the mutex is not locked by the hash on entry to Unlock.
func (c *HashMutex) Unlock(hash lntypes.Hash) {
// Since we are done with all the work for this
// update, we update the map to reflect that.
c.mapMtx.Lock()
mtx, ok := c.mutexes[hash]
if !ok {
// The mutex not existing in the map means
// an unlock for an hash not currently locked
// was attempted.
panic(fmt.Sprintf("double unlock for hash %v",
hash))
}
// Decrement the counter. If the count goes to
// zero, it means this caller was the last one
// to wait for the mutex, and we can delete it
// from the map. We can do this safely since we
// are under the mapMtx, meaning that all other
// goroutines waiting for the mutex already
// have incremented it, or will create a new
// mutex when they get the mapMtx.
mtx.cnt--
if mtx.cnt == 0 {
delete(c.mutexes, hash)
}
c.mapMtx.Unlock()
// Unlock the mutex for this hash.
mtx.Unlock()
}

View File

@ -5,51 +5,48 @@ import (
"sync" "sync"
) )
// cntMutex is a struct that wraps a counter and a mutex, and is used // cntMutex is a struct that wraps a counter and a mutex, and is used to keep
// to keep track of the number of goroutines waiting for access to the // track of the number of goroutines waiting for access to the
// mutex, such that we can forget about it when the counter is zero. // mutex, such that we can forget about it when the counter is zero.
type cntMutex struct { type cntMutex struct {
cnt int cnt int
sync.Mutex sync.Mutex
} }
// Mutex is a struct that keeps track of a set of mutexes with // Mutex is a struct that keeps track of a set of mutexes with a given ID. It
// a given ID. It can be used for making sure only one goroutine // can be used for making sure only one goroutine gets given the mutex per ID.
// gets given the mutex per ID. type Mutex[T comparable] struct {
type Mutex struct { // mutexes is a map of IDs to a cntMutex. The cntMutex for a given ID
// mutexes is a map of IDs to a cntMutex. The cntMutex for // will hold the mutex to be used by all callers requesting access for
// a given ID will hold the mutex to be used by all // the ID, in addition to the count of callers.
// callers requesting access for the ID, in addition to mutexes map[T]*cntMutex
// the count of callers.
mutexes map[uint64]*cntMutex
// mapMtx is used to give synchronize concurrent access // mapMtx is used to give synchronize concurrent access to the mutexes
// to the mutexes map. // map.
mapMtx sync.Mutex mapMtx sync.Mutex
} }
// NewMutex creates a new Mutex. // NewMutex creates a new Mutex.
func NewMutex() *Mutex { func NewMutex[T comparable]() *Mutex[T] {
return &Mutex{ return &Mutex[T]{
mutexes: make(map[uint64]*cntMutex), mutexes: make(map[T]*cntMutex),
} }
} }
// Lock locks the mutex by the given ID. If the mutex is already // Lock locks the mutex by the given ID. If the mutex is already locked by this
// locked by this ID, Lock blocks until the mutex is available. // ID, Lock blocks until the mutex is available.
func (c *Mutex) Lock(id uint64) { func (c *Mutex[T]) Lock(id T) {
c.mapMtx.Lock() c.mapMtx.Lock()
mtx, ok := c.mutexes[id] mtx, ok := c.mutexes[id]
if ok { if ok {
// If the mutex already existed in the map, we // If the mutex already existed in the map, we increment its
// increment its counter, to indicate that there // counter, to indicate that there now is one more goroutine
// now is one more goroutine waiting for it. // waiting for it.
mtx.cnt++ mtx.cnt++
} else { } else {
// If it was not in the map, it means no other // If it was not in the map, it means no other goroutine has
// goroutine has locked the mutex for this ID, // locked the mutex for this ID, and we can create a new mutex
// and we can create a new mutex with count 1 // with count 1 and add it to the map.
// and add it to the map.
mtx = &cntMutex{ mtx = &cntMutex{
cnt: 1, cnt: 1,
} }
@ -61,30 +58,26 @@ func (c *Mutex) Lock(id uint64) {
mtx.Lock() mtx.Lock()
} }
// Unlock unlocks the mutex by the given ID. It is a run-time // Unlock unlocks the mutex by the given ID. It is a run-time error if the
// error if the mutex is not locked by the ID on entry to Unlock. // mutex is not locked by the ID on entry to Unlock.
func (c *Mutex) Unlock(id uint64) { func (c *Mutex[T]) Unlock(id T) {
// Since we are done with all the work for this // Since we are done with all the work for this update, we update the
// update, we update the map to reflect that. // map to reflect that.
c.mapMtx.Lock() c.mapMtx.Lock()
mtx, ok := c.mutexes[id] mtx, ok := c.mutexes[id]
if !ok { if !ok {
// The mutex not existing in the map means // The mutex not existing in the map means an unlock for an ID
// an unlock for an ID not currently locked // not currently locked was attempted.
// was attempted.
panic(fmt.Sprintf("double unlock for id %v", panic(fmt.Sprintf("double unlock for id %v",
id)) id))
} }
// Decrement the counter. If the count goes to // Decrement the counter. If the count goes to zero, it means this
// zero, it means this caller was the last one // caller was the last one to wait for the mutex, and we can delete it
// to wait for the mutex, and we can delete it // from the map. We can do this safely since we are under the mapMtx,
// from the map. We can do this safely since we // meaning that all other goroutines waiting for the mutex already have
// are under the mapMtx, meaning that all other // incremented it, or will create a new mutex when they get the mapMtx.
// goroutines waiting for the mutex already
// have incremented it, or will create a new
// mutex when they get the mapMtx.
mtx.cnt-- mtx.cnt--
if mtx.cnt == 0 { if mtx.cnt == 0 {
delete(c.mutexes, id) delete(c.mutexes, id)

View File

@ -132,7 +132,7 @@ type controlTower struct {
// paymentsMtx provides synchronization on the payment level to ensure // paymentsMtx provides synchronization on the payment level to ensure
// that no race conditions occur in between updating the database and // that no race conditions occur in between updating the database and
// sending a notification. // sending a notification.
paymentsMtx *multimutex.HashMutex paymentsMtx *multimutex.Mutex[lntypes.Hash]
} }
// NewControlTower creates a new instance of the controlTower. // NewControlTower creates a new instance of the controlTower.
@ -143,7 +143,7 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower {
map[uint64]*controlTowerSubscriberImpl, map[uint64]*controlTowerSubscriberImpl,
), ),
subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl), subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
paymentsMtx: multimutex.NewHashMutex(), paymentsMtx: multimutex.NewMutex[lntypes.Hash](),
} }
} }

View File

@ -442,7 +442,7 @@ type ChannelRouter struct {
// channelEdgeMtx is a mutex we use to make sure we process only one // channelEdgeMtx is a mutex we use to make sure we process only one
// ChannelEdgePolicy at a time for a given channelID, to ensure // ChannelEdgePolicy at a time for a given channelID, to ensure
// consistency between the various database accesses. // consistency between the various database accesses.
channelEdgeMtx *multimutex.Mutex channelEdgeMtx *multimutex.Mutex[uint64]
// statTicker is a resumable ticker that logs the router's progress as // statTicker is a resumable ticker that logs the router's progress as
// it discovers channels or receives updates. // it discovers channels or receives updates.
@ -480,7 +480,7 @@ func New(cfg Config) (*ChannelRouter, error) {
networkUpdates: make(chan *routingMsg), networkUpdates: make(chan *routingMsg),
topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{},
ntfnClientUpdates: make(chan *topologyClientUpdate), ntfnClientUpdates: make(chan *topologyClientUpdate),
channelEdgeMtx: multimutex.NewMutex(), channelEdgeMtx: multimutex.NewMutex[uint64](),
selfNode: selfNode, selfNode: selfNode,
statTicker: ticker.New(defaultStatInterval), statTicker: ticker.New(defaultStatInterval),
stats: new(routerStats), stats: new(routerStats),