input: use multmutex to increase concurrency for musig session manager

By using the multimutex here, we'll no longer rely on a single mutex for
the entire musig session set like we used to. Instead, we can use the
session ID to key into a map of mutexes and use those directly.
This commit is contained in:
Olaoluwa Osuntokun 2023-06-01 17:52:41 -07:00
parent dafc2a3e5a
commit 001c5b0e0b
No known key found for this signature in database
GPG key ID: 3BBD59E99B280306

View file

@ -3,12 +3,12 @@ package input
import (
"crypto/sha256"
"fmt"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/multimutex"
)
// MuSig2State is a struct that holds on to the internal signing session state
@ -34,10 +34,10 @@ type PrivKeyFetcher func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error)
// musig sessions. Each session is identified by a unique session ID which is
// used by callers to interact with a given session.
type MusigSessionManager struct {
sync.Mutex
keyFetcher PrivKeyFetcher
sessionMtx *multimutex.Mutex[MuSig2SessionID]
musig2Sessions map[MuSig2SessionID]*MuSig2State
}
@ -45,7 +45,9 @@ type MusigSessionManager struct {
// fetcher.
func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager {
return &MusigSessionManager{
keyFetcher: keyFetcher,
keyFetcher: keyFetcher,
musig2Sessions: make(map[MuSig2SessionID]*MuSig2State),
sessionMtx: multimutex.NewMutex[MuSig2SessionID](),
}
}
@ -70,7 +72,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
KeyLocator: keyLoc,
})
if err != nil {
return nil, fmt.Errorf("error deriving private key: %v", err)
return nil, fmt.Errorf("error deriving private key: %w", err)
}
// Create a signing context and session with the given private key and
@ -98,7 +100,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
// Register the new session.
combinedKey, err := musigContext.CombinedKey()
if err != nil {
return nil, fmt.Errorf("error getting combined key: %v", err)
return nil, fmt.Errorf("error getting combined key: %w", err)
}
session := &MuSig2State{
MuSig2SessionInfo: MuSig2SessionInfo{
@ -120,7 +122,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
if tweaks.HasTaprootTweak() {
internalKey, err := musigContext.TaprootInternalKey()
if err != nil {
return nil, fmt.Errorf("error getting internal key: %v",
return nil, fmt.Errorf("error getting internal key: %w",
err)
}
session.TaprootInternalKey = internalKey
@ -129,9 +131,12 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
// Since we generate new nonces for every session, there is no way that
// a session with the same ID already exists. So even if we call the API
// twice with the same signers, we still get a new ID.
m.Lock()
//
// We'll use just all zeroes as the session ID for the mutex, as this
// is a "global" action.
m.sessionMtx.Lock(MuSig2SessionID{})
m.musig2Sessions[session.SessionID] = session
m.Unlock()
m.sessionMtx.Unlock(MuSig2SessionID{})
return &session.MuSig2SessionInfo, nil
}
@ -149,8 +154,8 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.Lock()
defer m.Unlock()
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID]
if !ok {
@ -190,8 +195,8 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.Lock()
defer m.Unlock()
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID]
if !ok {
@ -237,8 +242,8 @@ func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error {
// We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.Lock()
defer m.Unlock()
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
_, ok := m.musig2Sessions[sessionID]
if !ok {
@ -259,8 +264,8 @@ func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.Lock()
defer m.Unlock()
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID]
if !ok {