//go:build kvdb_postgres || (kvdb_sqlite && !(windows && (arm || 386)) && !(linux && (ppc64 || mips || mipsle || mips64))) package sqlbase import ( "database/sql" "errors" "fmt" "github.com/btcsuite/btcwallet/walletdb" ) // readWriteBucket stores the bucket id and the buckets transaction. type readWriteBucket struct { // id is used to identify the bucket. If id is null, it refers to the // root bucket. id *int64 // tx holds the parent transaction. tx *readWriteTx table string } // newReadWriteBucket creates a new rw bucket with the passed transaction // and bucket id. func newReadWriteBucket(tx *readWriteTx, id *int64) *readWriteBucket { return &readWriteBucket{ id: id, tx: tx, table: tx.db.table, } } // NestedReadBucket retrieves a nested read bucket with the given key. // Returns nil if the bucket does not exist. func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { return b.NestedReadWriteBucket(key) } func parentSelector(id *int64) string { if id == nil { return "parent_id IS NULL" } return fmt.Sprintf("parent_id=%v", *id) } // ForEach invokes the passed function with every key/value pair in // the bucket. This includes nested buckets, in which case the value // is nil, but it does not include the key/value pairs within those // nested buckets. func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error { cursor := b.ReadWriteCursor() k, v := cursor.First() for k != nil { err := cb(k, v) if err != nil { return err } k, v = cursor.Next() } return nil } // Get returns the value for the given key. Returns nil if the key does // not exist in this bucket. func (b *readWriteBucket) Get(key []byte) []byte { // Return nil if the key is empty. if len(key) == 0 { return nil } var value *[]byte row, cancel := b.tx.QueryRow( "SELECT value FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1", key, ) defer cancel() err := row.Scan(&value) switch { case err == sql.ErrNoRows: return nil case err != nil: panic(err) } // When an empty byte array is stored as the value, Sqlite will decode // that into nil whereas postgres will decode that as an empty byte // array. Since returning nil is taken to mean that no value has ever // been written, we ensure here that we at least return an empty array // so that nil checks will fail. if len(*value) == 0 { return []byte{} } return *value } // ReadCursor returns a new read-only cursor for this bucket. func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor { return newReadWriteCursor(b) } // NestedReadWriteBucket retrieves a nested bucket with the given key. // Returns nil if the bucket does not exist. func (b *readWriteBucket) NestedReadWriteBucket( key []byte) walletdb.ReadWriteBucket { if len(key) == 0 { return nil } var id int64 row, cancel := b.tx.QueryRow( "SELECT id FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1 AND value IS NULL", key, ) defer cancel() err := row.Scan(&id) switch { case err == sql.ErrNoRows: return nil case err != nil: panic(err) } return newReadWriteBucket(b.tx, &id) } // CreateBucket creates and returns a new nested bucket with the given key. // Returns ErrBucketExists if the bucket already exists, ErrBucketNameRequired // if the key is empty, or ErrIncompatibleValue if the key value is otherwise // invalid for the particular database implementation. Other errors are // possible depending on the implementation. func (b *readWriteBucket) CreateBucket(key []byte) ( walletdb.ReadWriteBucket, error) { if len(key) == 0 { return nil, walletdb.ErrBucketNameRequired } // Check to see if the bucket already exists. var ( value *[]byte id int64 ) row, cancel := b.tx.QueryRow( "SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1", key, ) defer cancel() err := row.Scan(&id, &value) switch { case err == sql.ErrNoRows: case err == nil && value == nil: return nil, walletdb.ErrBucketExists case err == nil && value != nil: return nil, walletdb.ErrIncompatibleValue case err != nil: return nil, err } // Bucket does not yet exist, so create it. Postgres will generate a // bucket id for the new bucket. row, cancel = b.tx.QueryRow( "INSERT INTO "+b.table+" (parent_id, key) "+ "VALUES($1, $2) RETURNING id", b.id, key, ) defer cancel() err = row.Scan(&id) if err != nil { return nil, err } return newReadWriteBucket(b.tx, &id), nil } // CreateBucketIfNotExists creates and returns a new nested bucket with // the given key if it does not already exist. Returns // ErrBucketNameRequired if the key is empty or ErrIncompatibleValue // if the key value is otherwise invalid for the particular database // backend. Other errors are possible depending on the implementation. func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) ( walletdb.ReadWriteBucket, error) { if len(key) == 0 { return nil, walletdb.ErrBucketNameRequired } // Check to see if the bucket already exists. var ( value *[]byte id int64 ) row, cancel := b.tx.QueryRow( "SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1", key, ) defer cancel() err := row.Scan(&id, &value) switch { // Bucket does not yet exist, so create it now. Postgres will generate a // bucket id for the new bucket. case err == sql.ErrNoRows: row, cancel := b.tx.QueryRow( "INSERT INTO "+b.table+" (parent_id, key) "+ "VALUES($1, $2) RETURNING id", b.id, key, ) defer cancel() err := row.Scan(&id) if err != nil { return nil, err } case err == nil && value != nil: return nil, walletdb.ErrIncompatibleValue case err != nil: return nil, err } return newReadWriteBucket(b.tx, &id), nil } // DeleteNestedBucket deletes the nested bucket and its sub-buckets // pointed to by the passed key. All values in the bucket and sub-buckets // will be deleted as well. func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { if len(key) == 0 { return walletdb.ErrIncompatibleValue } result, err := b.tx.Exec( "DELETE FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1 AND value IS NULL", key, ) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows == 0 { return walletdb.ErrBucketNotFound } return nil } // Put updates the value for the passed key. // Returns ErrKeyRequired if te passed key is empty. func (b *readWriteBucket) Put(key, value []byte) error { if len(key) == 0 { return walletdb.ErrKeyRequired } // Prevent NULL being written for an empty value slice. if value == nil { value = []byte{} } var ( result sql.Result err error ) // We are putting a value in a bucket in this table. Try to insert the // key first. If the key already exists (ON CONFLICT), update the key. // Do not update a NULL value, because this indicates that the key // contains a sub-bucket. This case will be caught via RowsAffected // below. if b.id == nil { // ON CONFLICT requires the WHERE parent_id IS NULL hint to let // Postgres find the NULL-parent_id unique index (_unp). result, err = b.tx.Exec( "INSERT INTO "+b.table+" (key, value) VALUES($1, $2) "+ "ON CONFLICT (key) WHERE parent_id IS NULL "+ "DO UPDATE SET value=$2 "+ "WHERE "+b.table+".value IS NOT NULL", key, value, ) } else { // ON CONFLICT requires the WHERE parent_id NOT IS NULL hint to // let Postgres find the non-NULL-parent_id unique index // (
_up). result, err = b.tx.Exec( "INSERT INTO "+b.table+" (key, value, parent_id) "+ "VALUES($1, $2, $3) "+ "ON CONFLICT (key, parent_id) "+ "WHERE parent_id IS NOT NULL "+ "DO UPDATE SET value=$2 "+ "WHERE "+b.table+".value IS NOT NULL", key, value, b.id, ) } if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return walletdb.ErrIncompatibleValue } return nil } // Delete deletes the key/value pointed to by the passed key. // Returns ErrKeyRequired if the passed key is empty. func (b *readWriteBucket) Delete(key []byte) error { if key == nil { return nil } if len(key) == 0 { return walletdb.ErrKeyRequired } // Check to see if a bucket with this key exists. var dummy int row, cancel := b.tx.QueryRow( "SELECT 1 FROM "+b.table+" WHERE "+parentSelector(b.id)+ " AND key=$1 AND value IS NULL", key, ) defer cancel() err := row.Scan(&dummy) switch { // No bucket exists, proceed to deletion of the key. case err == sql.ErrNoRows: case err != nil: return err // Bucket exists. default: return walletdb.ErrIncompatibleValue } _, err = b.tx.Exec( "DELETE FROM "+b.table+" WHERE key=$1 AND "+ parentSelector(b.id)+" AND value IS NOT NULL", key, ) if err != nil { return err } return nil } // ReadWriteCursor returns a new read-write cursor for this bucket. func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor { return newReadWriteCursor(b) } // Tx returns the buckets transaction. func (b *readWriteBucket) Tx() walletdb.ReadWriteTx { return b.tx } // NextSequence returns an autoincrementing sequence number for this bucket. // Note that this is not a thread safe function and as such it must not be used // for synchronization. func (b *readWriteBucket) NextSequence() (uint64, error) { seq := b.Sequence() + 1 return seq, b.SetSequence(seq) } // SetSequence updates the sequence number for the bucket. func (b *readWriteBucket) SetSequence(v uint64) error { if b.id == nil { panic("sequence not supported on top level bucket") } result, err := b.tx.Exec( "UPDATE "+b.table+" SET sequence=$2 WHERE id=$1", b.id, int64(v), ) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.New("cannot set sequence") } return nil } // Sequence returns the current sequence number for this bucket without // incrementing it. func (b *readWriteBucket) Sequence() uint64 { if b.id == nil { panic("sequence not supported on top level bucket") } var seq int64 row, cancel := b.tx.QueryRow( "SELECT sequence FROM "+b.table+" WHERE id=$1 "+ "AND sequence IS NOT NULL", b.id, ) defer cancel() err := row.Scan(&seq) switch { case err == sql.ErrNoRows: return 0 case err != nil: panic(err) } return uint64(seq) } // Prefetch will attempt to prefetch all values under a path from the passed // bucket. func (b *readWriteBucket) Prefetch(paths ...[]string) {} // ForAll is an optimized version of ForEach with the limitation that no // additional queries can be executed within the callback. func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error { rows, cancel, err := b.tx.Query( "SELECT key, value FROM " + b.table + " WHERE " + parentSelector(b.id) + " ORDER BY key", ) if err != nil { return err } defer cancel() for rows.Next() { var key, value []byte err := rows.Scan(&key, &value) if err != nil { return err } err = cb(key, value) if err != nil { return err } } return nil }