input: use multmutex to increase concurrency for musig session manager

By using the multimutex here, we'll no longer rely on a single mutex for
the entire musig session set like we used to. Instead, we can use the
session ID to key into a map of mutexes and use those directly.
This commit is contained in:
Olaoluwa Osuntokun 2023-06-01 17:52:41 -07:00
parent dafc2a3e5a
commit 001c5b0e0b
No known key found for this signature in database
GPG key ID: 3BBD59E99B280306

View file

@ -3,12 +3,12 @@ package input
import ( import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"sync"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/multimutex"
) )
// MuSig2State is a struct that holds on to the internal signing session state // MuSig2State is a struct that holds on to the internal signing session state
@ -34,10 +34,10 @@ type PrivKeyFetcher func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error)
// musig sessions. Each session is identified by a unique session ID which is // musig sessions. Each session is identified by a unique session ID which is
// used by callers to interact with a given session. // used by callers to interact with a given session.
type MusigSessionManager struct { type MusigSessionManager struct {
sync.Mutex
keyFetcher PrivKeyFetcher keyFetcher PrivKeyFetcher
sessionMtx *multimutex.Mutex[MuSig2SessionID]
musig2Sessions map[MuSig2SessionID]*MuSig2State musig2Sessions map[MuSig2SessionID]*MuSig2State
} }
@ -45,7 +45,9 @@ type MusigSessionManager struct {
// fetcher. // fetcher.
func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager { func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager {
return &MusigSessionManager{ return &MusigSessionManager{
keyFetcher: keyFetcher, keyFetcher: keyFetcher,
musig2Sessions: make(map[MuSig2SessionID]*MuSig2State),
sessionMtx: multimutex.NewMutex[MuSig2SessionID](),
} }
} }
@ -70,7 +72,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
KeyLocator: keyLoc, KeyLocator: keyLoc,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error deriving private key: %v", err) return nil, fmt.Errorf("error deriving private key: %w", err)
} }
// Create a signing context and session with the given private key and // Create a signing context and session with the given private key and
@ -98,7 +100,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
// Register the new session. // Register the new session.
combinedKey, err := musigContext.CombinedKey() combinedKey, err := musigContext.CombinedKey()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting combined key: %v", err) return nil, fmt.Errorf("error getting combined key: %w", err)
} }
session := &MuSig2State{ session := &MuSig2State{
MuSig2SessionInfo: MuSig2SessionInfo{ MuSig2SessionInfo: MuSig2SessionInfo{
@ -120,7 +122,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
if tweaks.HasTaprootTweak() { if tweaks.HasTaprootTweak() {
internalKey, err := musigContext.TaprootInternalKey() internalKey, err := musigContext.TaprootInternalKey()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting internal key: %v", return nil, fmt.Errorf("error getting internal key: %w",
err) err)
} }
session.TaprootInternalKey = internalKey session.TaprootInternalKey = internalKey
@ -129,9 +131,12 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
// Since we generate new nonces for every session, there is no way that // Since we generate new nonces for every session, there is no way that
// a session with the same ID already exists. So even if we call the API // a session with the same ID already exists. So even if we call the API
// twice with the same signers, we still get a new ID. // twice with the same signers, we still get a new ID.
m.Lock() //
// We'll use just all zeroes as the session ID for the mutex, as this
// is a "global" action.
m.sessionMtx.Lock(MuSig2SessionID{})
m.musig2Sessions[session.SessionID] = session m.musig2Sessions[session.SessionID] = session
m.Unlock() m.sessionMtx.Unlock(MuSig2SessionID{})
return &session.MuSig2SessionInfo, nil return &session.MuSig2SessionInfo, nil
} }
@ -149,8 +154,8 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any // We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the // interference with calls that might come through in parallel for the
// same session. // same session.
m.Lock() m.sessionMtx.Lock(sessionID)
defer m.Unlock() defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions[sessionID]
if !ok { if !ok {
@ -190,8 +195,8 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any // We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the // interference with calls that might come through in parallel for the
// same session. // same session.
m.Lock() m.sessionMtx.Lock(sessionID)
defer m.Unlock() defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions[sessionID]
if !ok { if !ok {
@ -237,8 +242,8 @@ func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error {
// We hold the lock during the whole operation, we don't want any // We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the // interference with calls that might come through in parallel for the
// same session. // same session.
m.Lock() m.sessionMtx.Lock(sessionID)
defer m.Unlock() defer m.sessionMtx.Unlock(sessionID)
_, ok := m.musig2Sessions[sessionID] _, ok := m.musig2Sessions[sessionID]
if !ok { if !ok {
@ -259,8 +264,8 @@ func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID,
// We hold the lock during the whole operation, we don't want any // We hold the lock during the whole operation, we don't want any
// interference with calls that might come through in parallel for the // interference with calls that might come through in parallel for the
// same session. // same session.
m.Lock() m.sessionMtx.Lock(sessionID)
defer m.Unlock() defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions[sessionID]
if !ok { if !ok {