From 001c5b0e0ba4da97675d1547349d4a0bf3b9bfb8 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 1 Jun 2023 17:52:41 -0700 Subject: [PATCH] input: use multmutex to increase concurrency for musig session manager By using the multimutex here, we'll no longer rely on a single mutex for the entire musig session set like we used to. Instead, we can use the session ID to key into a map of mutexes and use those directly. --- input/musig2_session_manager.go | 39 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/input/musig2_session_manager.go b/input/musig2_session_manager.go index 3572614fb..3bf26004a 100644 --- a/input/musig2_session_manager.go +++ b/input/musig2_session_manager.go @@ -3,12 +3,12 @@ package input import ( "crypto/sha256" "fmt" - "sync" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/multimutex" ) // MuSig2State is a struct that holds on to the internal signing session state @@ -34,10 +34,10 @@ type PrivKeyFetcher func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error) // musig sessions. Each session is identified by a unique session ID which is // used by callers to interact with a given session. type MusigSessionManager struct { - sync.Mutex - keyFetcher PrivKeyFetcher + sessionMtx *multimutex.Mutex[MuSig2SessionID] + musig2Sessions map[MuSig2SessionID]*MuSig2State } @@ -45,7 +45,9 @@ type MusigSessionManager struct { // fetcher. func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager { return &MusigSessionManager{ - keyFetcher: keyFetcher, + keyFetcher: keyFetcher, + musig2Sessions: make(map[MuSig2SessionID]*MuSig2State), + sessionMtx: multimutex.NewMutex[MuSig2SessionID](), } } @@ -70,7 +72,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version, KeyLocator: keyLoc, }) if err != nil { - return nil, fmt.Errorf("error deriving private key: %v", err) + return nil, fmt.Errorf("error deriving private key: %w", err) } // Create a signing context and session with the given private key and @@ -98,7 +100,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version, // Register the new session. combinedKey, err := musigContext.CombinedKey() if err != nil { - return nil, fmt.Errorf("error getting combined key: %v", err) + return nil, fmt.Errorf("error getting combined key: %w", err) } session := &MuSig2State{ MuSig2SessionInfo: MuSig2SessionInfo{ @@ -120,7 +122,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version, if tweaks.HasTaprootTweak() { internalKey, err := musigContext.TaprootInternalKey() if err != nil { - return nil, fmt.Errorf("error getting internal key: %v", + return nil, fmt.Errorf("error getting internal key: %w", err) } session.TaprootInternalKey = internalKey @@ -129,9 +131,12 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version, // Since we generate new nonces for every session, there is no way that // a session with the same ID already exists. So even if we call the API // twice with the same signers, we still get a new ID. - m.Lock() + // + // We'll use just all zeroes as the session ID for the mutex, as this + // is a "global" action. + m.sessionMtx.Lock(MuSig2SessionID{}) m.musig2Sessions[session.SessionID] = session - m.Unlock() + m.sessionMtx.Unlock(MuSig2SessionID{}) return &session.MuSig2SessionInfo, nil } @@ -149,8 +154,8 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID, // We hold the lock during the whole operation, we don't want any // interference with calls that might come through in parallel for the // same session. - m.Lock() - defer m.Unlock() + m.sessionMtx.Lock(sessionID) + defer m.sessionMtx.Unlock(sessionID) session, ok := m.musig2Sessions[sessionID] if !ok { @@ -190,8 +195,8 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID, // We hold the lock during the whole operation, we don't want any // interference with calls that might come through in parallel for the // same session. - m.Lock() - defer m.Unlock() + m.sessionMtx.Lock(sessionID) + defer m.sessionMtx.Unlock(sessionID) session, ok := m.musig2Sessions[sessionID] if !ok { @@ -237,8 +242,8 @@ func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error { // We hold the lock during the whole operation, we don't want any // interference with calls that might come through in parallel for the // same session. - m.Lock() - defer m.Unlock() + m.sessionMtx.Lock(sessionID) + defer m.sessionMtx.Unlock(sessionID) _, ok := m.musig2Sessions[sessionID] if !ok { @@ -259,8 +264,8 @@ func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID, // We hold the lock during the whole operation, we don't want any // interference with calls that might come through in parallel for the // same session. - m.Lock() - defer m.Unlock() + m.sessionMtx.Lock(sessionID) + defer m.sessionMtx.Unlock(sessionID) session, ok := m.musig2Sessions[sessionID] if !ok {