lnd/kvdb/sqlbase/readwrite_bucket.go
2023-01-23 09:14:08 +02:00

471 lines
11 KiB
Go

//go:build kvdb_postgres || (kvdb_sqlite && !(windows && (arm || 386)) && !(linux && (ppc64 || mips || mipsle || mips64)))
package sqlbase
import (
"database/sql"
"errors"
"fmt"
"github.com/btcsuite/btcwallet/walletdb"
)
// readWriteBucket stores the bucket id and the buckets transaction.
type readWriteBucket struct {
// id is used to identify the bucket. If id is null, it refers to the
// root bucket.
id *int64
// tx holds the parent transaction.
tx *readWriteTx
table string
}
// newReadWriteBucket creates a new rw bucket with the passed transaction
// and bucket id.
func newReadWriteBucket(tx *readWriteTx, id *int64) *readWriteBucket {
return &readWriteBucket{
id: id,
tx: tx,
table: tx.db.table,
}
}
// NestedReadBucket retrieves a nested read bucket with the given key.
// Returns nil if the bucket does not exist.
func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket {
return b.NestedReadWriteBucket(key)
}
func parentSelector(id *int64) string {
if id == nil {
return "parent_id IS NULL"
}
return fmt.Sprintf("parent_id=%v", *id)
}
// ForEach invokes the passed function with every key/value pair in
// the bucket. This includes nested buckets, in which case the value
// is nil, but it does not include the key/value pairs within those
// nested buckets.
func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error {
cursor := b.ReadWriteCursor()
k, v := cursor.First()
for k != nil {
err := cb(k, v)
if err != nil {
return err
}
k, v = cursor.Next()
}
return nil
}
// Get returns the value for the given key. Returns nil if the key does
// not exist in this bucket.
func (b *readWriteBucket) Get(key []byte) []byte {
// Return nil if the key is empty.
if len(key) == 0 {
return nil
}
var value *[]byte
row, cancel := b.tx.QueryRow(
"SELECT value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
)
defer cancel()
err := row.Scan(&value)
switch {
case err == sql.ErrNoRows:
return nil
case err != nil:
panic(err)
}
// When an empty byte array is stored as the value, Sqlite will decode
// that into nil whereas postgres will decode that as an empty byte
// array. Since returning nil is taken to mean that no value has ever
// been written, we ensure here that we at least return an empty array
// so that nil checks will fail.
if len(*value) == 0 {
return []byte{}
}
return *value
}
// ReadCursor returns a new read-only cursor for this bucket.
func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor {
return newReadWriteCursor(b)
}
// NestedReadWriteBucket retrieves a nested bucket with the given key.
// Returns nil if the bucket does not exist.
func (b *readWriteBucket) NestedReadWriteBucket(
key []byte) walletdb.ReadWriteBucket {
if len(key) == 0 {
return nil
}
var id int64
row, cancel := b.tx.QueryRow(
"SELECT id FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL", key,
)
defer cancel()
err := row.Scan(&id)
switch {
case err == sql.ErrNoRows:
return nil
case err != nil:
panic(err)
}
return newReadWriteBucket(b.tx, &id)
}
// CreateBucket creates and returns a new nested bucket with the given key.
// Returns ErrBucketExists if the bucket already exists, ErrBucketNameRequired
// if the key is empty, or ErrIncompatibleValue if the key value is otherwise
// invalid for the particular database implementation. Other errors are
// possible depending on the implementation.
func (b *readWriteBucket) CreateBucket(key []byte) (
walletdb.ReadWriteBucket, error) {
if len(key) == 0 {
return nil, walletdb.ErrBucketNameRequired
}
// Check to see if the bucket already exists.
var (
value *[]byte
id int64
)
row, cancel := b.tx.QueryRow(
"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
)
defer cancel()
err := row.Scan(&id, &value)
switch {
case err == sql.ErrNoRows:
case err == nil && value == nil:
return nil, walletdb.ErrBucketExists
case err == nil && value != nil:
return nil, walletdb.ErrIncompatibleValue
case err != nil:
return nil, err
}
// Bucket does not yet exist, so create it. Postgres will generate a
// bucket id for the new bucket.
row, cancel = b.tx.QueryRow(
"INSERT INTO "+b.table+" (parent_id, key) "+
"VALUES($1, $2) RETURNING id", b.id, key,
)
defer cancel()
err = row.Scan(&id)
if err != nil {
return nil, err
}
return newReadWriteBucket(b.tx, &id), nil
}
// CreateBucketIfNotExists creates and returns a new nested bucket with
// the given key if it does not already exist. Returns
// ErrBucketNameRequired if the key is empty or ErrIncompatibleValue
// if the key value is otherwise invalid for the particular database
// backend. Other errors are possible depending on the implementation.
func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) (
walletdb.ReadWriteBucket, error) {
if len(key) == 0 {
return nil, walletdb.ErrBucketNameRequired
}
// Check to see if the bucket already exists.
var (
value *[]byte
id int64
)
row, cancel := b.tx.QueryRow(
"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
)
defer cancel()
err := row.Scan(&id, &value)
switch {
// Bucket does not yet exist, so create it now. Postgres will generate a
// bucket id for the new bucket.
case err == sql.ErrNoRows:
row, cancel := b.tx.QueryRow(
"INSERT INTO "+b.table+" (parent_id, key) "+
"VALUES($1, $2) RETURNING id", b.id, key,
)
defer cancel()
err := row.Scan(&id)
if err != nil {
return nil, err
}
case err == nil && value != nil:
return nil, walletdb.ErrIncompatibleValue
case err != nil:
return nil, err
}
return newReadWriteBucket(b.tx, &id), nil
}
// DeleteNestedBucket deletes the nested bucket and its sub-buckets
// pointed to by the passed key. All values in the bucket and sub-buckets
// will be deleted as well.
func (b *readWriteBucket) DeleteNestedBucket(key []byte) error {
if len(key) == 0 {
return walletdb.ErrIncompatibleValue
}
result, err := b.tx.Exec(
"DELETE FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL",
key,
)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return walletdb.ErrBucketNotFound
}
return nil
}
// Put updates the value for the passed key.
// Returns ErrKeyRequired if te passed key is empty.
func (b *readWriteBucket) Put(key, value []byte) error {
if len(key) == 0 {
return walletdb.ErrKeyRequired
}
// Prevent NULL being written for an empty value slice.
if value == nil {
value = []byte{}
}
var (
result sql.Result
err error
)
// We are putting a value in a bucket in this table. Try to insert the
// key first. If the key already exists (ON CONFLICT), update the key.
// Do not update a NULL value, because this indicates that the key
// contains a sub-bucket. This case will be caught via RowsAffected
// below.
if b.id == nil {
// ON CONFLICT requires the WHERE parent_id IS NULL hint to let
// Postgres find the NULL-parent_id unique index (<table>_unp).
result, err = b.tx.Exec(
"INSERT INTO "+b.table+" (key, value) VALUES($1, $2) "+
"ON CONFLICT (key) WHERE parent_id IS NULL "+
"DO UPDATE SET value=$2 "+
"WHERE "+b.table+".value IS NOT NULL",
key, value,
)
} else {
// ON CONFLICT requires the WHERE parent_id NOT IS NULL hint to
// let Postgres find the non-NULL-parent_id unique index
// (<table>_up).
result, err = b.tx.Exec(
"INSERT INTO "+b.table+" (key, value, parent_id) "+
"VALUES($1, $2, $3) "+
"ON CONFLICT (key, parent_id) "+
"WHERE parent_id IS NOT NULL "+
"DO UPDATE SET value=$2 "+
"WHERE "+b.table+".value IS NOT NULL",
key, value, b.id,
)
}
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return walletdb.ErrIncompatibleValue
}
return nil
}
// Delete deletes the key/value pointed to by the passed key.
// Returns ErrKeyRequired if the passed key is empty.
func (b *readWriteBucket) Delete(key []byte) error {
if key == nil {
return nil
}
if len(key) == 0 {
return walletdb.ErrKeyRequired
}
// Check to see if a bucket with this key exists.
var dummy int
row, cancel := b.tx.QueryRow(
"SELECT 1 FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL", key,
)
defer cancel()
err := row.Scan(&dummy)
switch {
// No bucket exists, proceed to deletion of the key.
case err == sql.ErrNoRows:
case err != nil:
return err
// Bucket exists.
default:
return walletdb.ErrIncompatibleValue
}
_, err = b.tx.Exec(
"DELETE FROM "+b.table+" WHERE key=$1 AND "+
parentSelector(b.id)+" AND value IS NOT NULL",
key,
)
if err != nil {
return err
}
return nil
}
// ReadWriteCursor returns a new read-write cursor for this bucket.
func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor {
return newReadWriteCursor(b)
}
// Tx returns the buckets transaction.
func (b *readWriteBucket) Tx() walletdb.ReadWriteTx {
return b.tx
}
// NextSequence returns an autoincrementing sequence number for this bucket.
// Note that this is not a thread safe function and as such it must not be used
// for synchronization.
func (b *readWriteBucket) NextSequence() (uint64, error) {
seq := b.Sequence() + 1
return seq, b.SetSequence(seq)
}
// SetSequence updates the sequence number for the bucket.
func (b *readWriteBucket) SetSequence(v uint64) error {
if b.id == nil {
panic("sequence not supported on top level bucket")
}
result, err := b.tx.Exec(
"UPDATE "+b.table+" SET sequence=$2 WHERE id=$1",
b.id, int64(v),
)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return errors.New("cannot set sequence")
}
return nil
}
// Sequence returns the current sequence number for this bucket without
// incrementing it.
func (b *readWriteBucket) Sequence() uint64 {
if b.id == nil {
panic("sequence not supported on top level bucket")
}
var seq int64
row, cancel := b.tx.QueryRow(
"SELECT sequence FROM "+b.table+" WHERE id=$1 "+
"AND sequence IS NOT NULL",
b.id,
)
defer cancel()
err := row.Scan(&seq)
switch {
case err == sql.ErrNoRows:
return 0
case err != nil:
panic(err)
}
return uint64(seq)
}
// Prefetch will attempt to prefetch all values under a path from the passed
// bucket.
func (b *readWriteBucket) Prefetch(paths ...[]string) {}
// ForAll is an optimized version of ForEach with the limitation that no
// additional queries can be executed within the callback.
func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
rows, cancel, err := b.tx.Query(
"SELECT key, value FROM " + b.table + " WHERE " +
parentSelector(b.id) + " ORDER BY key",
)
if err != nil {
return err
}
defer cancel()
for rows.Next() {
var key, value []byte
err := rows.Scan(&key, &value)
if err != nil {
return err
}
err = cb(key, value)
if err != nil {
return err
}
}
return nil
}