kvdb/postgres: make global dbConns thread safe

In this commit, a mutex is added to guard access to the global dbConns
variable in the kvdb/postgres package.
This commit is contained in:
Elle Mouton 2022-12-15 17:38:14 +02:00
parent 72dbc3dbb4
commit c6abf585ee
No known key found for this signature in database
GPG Key ID: D7D916376026F177
2 changed files with 24 additions and 7 deletions

View File

@ -57,11 +57,21 @@ type db struct {
// Enforce db implements the walletdb.DB interface.
var _ walletdb.DB = (*db)(nil)
// Global set of database connections.
var dbConns *dbConnSet
var (
// dbConns is a global set of database connections.
dbConns *dbConnSet
dbConnsMu sync.Mutex
)
// Init initializes the global set of database connections.
func Init(maxConnections int) {
dbConnsMu.Lock()
defer dbConnsMu.Unlock()
if dbConns != nil {
return
}
dbConns = newDbConnSet(maxConnections)
}
@ -70,6 +80,9 @@ func Init(maxConnections int) {
func newPostgresBackend(ctx context.Context, config *Config, prefix string) (
*db, error) {
dbConnsMu.Lock()
defer dbConnsMu.Unlock()
if prefix == "" {
return nil, errors.New("empty postgres prefix")
}
@ -256,6 +269,9 @@ func (db *db) Copy(w io.Writer) error {
// Close cleanly shuts down the database and syncs all data.
// This function is part of the walletdb.Db interface implementation.
func (db *db) Close() error {
dbConnsMu.Lock()
defer dbConnsMu.Unlock()
log.Infof("Closing database %v", db.prefix)
return dbConns.Close(db.cfg.Dsn)

View File

@ -19,7 +19,8 @@ type dbConnSet struct {
dbConn map[string]*dbConn
maxConnections int
sync.Mutex
// mu is used to guard access to the dbConn map.
mu sync.Mutex
}
// newDbConnSet initializes a new set of connections.
@ -33,8 +34,8 @@ func newDbConnSet(maxConnections int) *dbConnSet {
// Open opens a new database connection. If a connection already exists for the
// given dsn, the existing connection is returned.
func (d *dbConnSet) Open(dsn string) (*sql.DB, error) {
d.Lock()
defer d.Unlock()
d.mu.Lock()
defer d.mu.Unlock()
if dbConn, ok := d.dbConn[dsn]; ok {
dbConn.count++
@ -66,8 +67,8 @@ func (d *dbConnSet) Open(dsn string) (*sql.DB, error) {
// Close closes the connection with the given dsn. If there are still other
// users of the same connection, this function does nothing.
func (d *dbConnSet) Close(dsn string) error {
d.Lock()
defer d.Unlock()
d.mu.Lock()
defer d.mu.Unlock()
dbConn, ok := d.dbConn[dsn]
if !ok {