From ae09ab2a21421951b23d3dcf3950c99913db3f49 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sat, 25 Nov 2023 04:46:55 +0800 Subject: [PATCH] input: use `lnutils.SyncMap` to store musig2 sessions --- input/musig2_session_manager.go | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/input/musig2_session_manager.go b/input/musig2_session_manager.go index 827652cc2..b2cac4899 100644 --- a/input/musig2_session_manager.go +++ b/input/musig2_session_manager.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/multimutex" ) @@ -38,16 +39,18 @@ type MusigSessionManager struct { sessionMtx *multimutex.Mutex[MuSig2SessionID] - musig2Sessions map[MuSig2SessionID]*MuSig2State + musig2Sessions *lnutils.SyncMap[MuSig2SessionID, *MuSig2State] } // NewMusigSessionManager creates a new musig manager given an abstract key // fetcher. func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager { return &MusigSessionManager{ - keyFetcher: keyFetcher, - musig2Sessions: make(map[MuSig2SessionID]*MuSig2State), - sessionMtx: multimutex.NewMutex[MuSig2SessionID](), + keyFetcher: keyFetcher, + musig2Sessions: &lnutils.SyncMap[ + MuSig2SessionID, *MuSig2State, + ]{}, + sessionMtx: multimutex.NewMutex[MuSig2SessionID](), } } @@ -134,9 +137,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version, // // We'll use just all zeroes as the session ID for the mutex, as this // is a "global" action. - m.sessionMtx.Lock(MuSig2SessionID{}) - m.musig2Sessions[session.SessionID] = session - m.sessionMtx.Unlock(MuSig2SessionID{}) + m.musig2Sessions.Store(session.SessionID, session) return &session.MuSig2SessionInfo, nil } @@ -157,7 +158,7 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID, m.sessionMtx.Lock(sessionID) defer m.sessionMtx.Unlock(sessionID) - session, ok := m.musig2Sessions[sessionID] + session, ok := m.musig2Sessions.Load(sessionID) if !ok { return nil, fmt.Errorf("session with ID %x not found", sessionID[:]) @@ -178,7 +179,7 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID, // Clean up our local state if requested. if cleanUp { - delete(m.musig2Sessions, sessionID) + m.musig2Sessions.Delete(sessionID) } return partialSig, nil @@ -198,7 +199,7 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID, m.sessionMtx.Lock(sessionID) defer m.sessionMtx.Unlock(sessionID) - session, ok := m.musig2Sessions[sessionID] + session, ok := m.musig2Sessions.Load(sessionID) if !ok { return nil, false, fmt.Errorf("session with ID %x not found", sessionID[:]) @@ -231,7 +232,7 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID, // there is nothing more left to do. if session.HaveAllSigs { finalSig = session.session.FinalSig() - delete(m.musig2Sessions, sessionID) + m.musig2Sessions.Delete(sessionID) } return finalSig, session.HaveAllSigs, nil @@ -245,12 +246,12 @@ func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error { m.sessionMtx.Lock(sessionID) defer m.sessionMtx.Unlock(sessionID) - _, ok := m.musig2Sessions[sessionID] + _, ok := m.musig2Sessions.Load(sessionID) if !ok { return fmt.Errorf("session with ID %x not found", sessionID[:]) } - delete(m.musig2Sessions, sessionID) + m.musig2Sessions.Delete(sessionID) return nil } @@ -267,7 +268,7 @@ func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID, m.sessionMtx.Lock(sessionID) defer m.sessionMtx.Unlock(sessionID) - session, ok := m.musig2Sessions[sessionID] + session, ok := m.musig2Sessions.Load(sessionID) if !ok { return false, fmt.Errorf("session with ID %x not found", sessionID[:])