diff --git a/kvdb/go.mod b/kvdb/go.mod index a07c76e05..daca93119 100644 --- a/kvdb/go.mod +++ b/kvdb/go.mod @@ -6,6 +6,8 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/fergusstrange/embedded-postgres v1.10.0 github.com/google/btree v1.0.1 + github.com/jackc/pgconn v1.14.0 + github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgx/v4 v4.18.1 github.com/lightningnetwork/lnd/healthcheck v1.0.0 github.com/stretchr/testify v1.8.2 @@ -37,7 +39,6 @@ require ( github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.14.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.2 // indirect diff --git a/kvdb/go.sum b/kvdb/go.sum index 0bb9edee8..9986ec980 100644 --- a/kvdb/go.sum +++ b/kvdb/go.sum @@ -143,6 +143,8 @@ github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8 github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= diff --git a/kvdb/sqlbase/db.go b/kvdb/sqlbase/db.go index 7defed2d7..f6f2c7590 100644 --- a/kvdb/sqlbase/db.go +++ b/kvdb/sqlbase/db.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math" + "math/rand" "sync" "time" @@ -37,12 +38,6 @@ const ( DefaultMaxRetryDelay = time.Second * 5 ) -var ( - // ErrRetriesExceeded is returned when a transaction is retried more - // than the max allowed valued without a success. - ErrRetriesExceeded = errors.New("db tx retries exceeded") -) - // Config holds a set of configuration options of a sql database connection. type Config struct { // DriverName is the string that defines the registered sql driver that @@ -238,7 +233,7 @@ func randRetryDelay(initialRetryDelay, maxRetryDelay time.Duration, attempt int) time.Duration { halfDelay := initialRetryDelay / 2 - randDelay := prand.Int63n(int64(initialRetryDelay)) //nolint:gosec + randDelay := rand.Int63n(int64(initialRetryDelay)) //nolint:gosec // 50% plus 0%-100% gives us the range of 50%-150%. initialDelay := halfDelay + time.Duration(randDelay) @@ -286,11 +281,11 @@ func (db *db) executeTransaction(f func(tx walletdb.ReadWriteTx) error, select { // Before we try again, we'll wait with a random backoff based // on the retry delay. - case time.After(retryDelay): + case <-time.After(retryDelay): return true // If the daemon is shutting down, then we'll exit early. - case <-db.Context.Done(): + case <-db.ctx.Done(): return false } } diff --git a/kvdb/sqlbase/sqlerrors.go b/kvdb/sqlbase/sqlerrors.go new file mode 100644 index 000000000..ccbb9c610 --- /dev/null +++ b/kvdb/sqlbase/sqlerrors.go @@ -0,0 +1,115 @@ +//go:build kvdb_postgres || (kvdb_sqlite && !(windows && (arm || 386)) && !(linux && (ppc64 || mips || mipsle || mips64))) + +package sqlbase + +import ( + "errors" + "fmt" + + "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" + "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" +) + +var ( + // ErrRetriesExceeded is returned when a transaction is retried more + // than the max allowed valued without a success. + ErrRetriesExceeded = errors.New("db tx retries exceeded") +) + +// MapSQLError attempts to interpret a given error as a database agnostic SQL +// error. +func MapSQLError(err error) error { + // Attempt to interpret the error as a sqlite error. + var sqliteErr *sqlite.Error + if errors.As(err, &sqliteErr) { + return parseSqliteError(sqliteErr) + } + + // Attempt to interpret the error as a postgres error. + var pqErr *pgconn.PgError + if errors.As(err, &pqErr) { + return parsePostgresError(pqErr) + } + + // Return original error if it could not be classified as a database + // specific error. + return err +} + +// parsePostgresError attempts to parse a sqlite error as a database agnostic +// SQL error. +func parseSqliteError(sqliteErr *sqlite.Error) error { + switch sqliteErr.Code() { + // Handle unique constraint violation error. + case sqlite3.SQLITE_CONSTRAINT_UNIQUE: + return &ErrSQLUniqueConstraintViolation{ + DBError: sqliteErr, + } + + // Database is currently busy, so we'll need to try again. + case sqlite3.SQLITE_BUSY: + return &ErrSerializationError{ + DBError: sqliteErr, + } + + default: + return fmt.Errorf("unknown sqlite error: %w", sqliteErr) + } +} + +// parsePostgresError attempts to parse a postgres error as a database agnostic +// SQL error. +func parsePostgresError(pqErr *pgconn.PgError) error { + switch pqErr.Code { + // Handle unique constraint violation error. + case pgerrcode.UniqueViolation: + return &ErrSQLUniqueConstraintViolation{ + DBError: pqErr, + } + + // Unable to serialize the transaction, so we'll need to try again. + case pgerrcode.SerializationFailure: + return &ErrSerializationError{ + DBError: pqErr, + } + + default: + return fmt.Errorf("unknown postgres error: %w", pqErr) + } +} + +// ErrSQLUniqueConstraintViolation is an error type which represents a database +// agnostic SQL unique constraint violation. +type ErrSQLUniqueConstraintViolation struct { + DBError error +} + +func (e ErrSQLUniqueConstraintViolation) Error() string { + return fmt.Sprintf("sql unique constraint violation: %v", e.DBError) +} + +// ErrSerializationError is an error type which represents a database agnostic +// error that a transaction couldn't be serialized with other concurrent db +// transactions. +type ErrSerializationError struct { + DBError error +} + +// Unwrap returns the wrapped error. +func (e ErrSerializationError) Unwrap() error { + return e.DBError +} + +// Error returns the error message. +func (e ErrSerializationError) Error() string { + return e.DBError.Error() +} + +// IsSerializationError returns true if the given error is a serialization +// error. +func IsSerializationError(err error) bool { + var serializationError *ErrSerializationError + return errors.As(err, &serializationError) +}