diff --git a/.gitattributes b/.gitattributes index c40e14769..ddec20639 100644 --- a/.gitattributes +++ b/.gitattributes @@ -7,8 +7,8 @@ configure text eol=lf # The following files are generated and should not be shown in the # diffs by default on Github. doc/lightning*.7 linguist-generated=true -wallet/db_*_sqlgen.c linguist-generated=true -wallet/statements_gettextgen.po linguist-generated=true +db_*_sqlgen.c linguist-generated=true +statements_gettextgen.po linguist-generated=true *_wiregen.? linguist-generated=true *_printgen.? linguist-generated=true diff --git a/Makefile b/Makefile index 62fec37c1..66232988b 100644 --- a/Makefile +++ b/Makefile @@ -336,6 +336,7 @@ include external/Makefile include bitcoin/Makefile include common/Makefile include wire/Makefile +include db/Makefile include hsmd/Makefile include gossipd/Makefile include openingd/Makefile diff --git a/db/Makefile b/db/Makefile new file mode 100644 index 000000000..43380acd6 --- /dev/null +++ b/db/Makefile @@ -0,0 +1,23 @@ +#! /usr/bin/make + +DB_LIB_SRC := \ + db/bindings.c \ + db/exec.c \ + db/utils.c + +DB_DRIVERS := \ + db/db_postgres.c \ + db/db_sqlite3.c + +DB_SRC := $(DB_LIB_SRC) $(DB_DRIVERS) +DB_HEADERS := $(DB_LIB_SRC:.c=.h) db/common.h + +DB_OBJS := $(DB_LIB_SRC:.c=.o) $(DB_DRIVERS:.c=.o) +$(DB_OBJS): $(DB_HEADERS) + +# Make sure these depend on everything. +ALL_C_SOURCES += $(DB_SRC) +ALL_C_HEADERS += $(DB_HEADERS) + +# DB_SQL_FILES is the list of database files +DB_SQL_FILES := db/exec.c diff --git a/db/bindings.c b/db/bindings.c new file mode 100644 index 000000000..df8b3827c --- /dev/null +++ b/db/bindings.c @@ -0,0 +1,554 @@ +#include "config.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NSEC_IN_SEC 1000000000 + +/* Local helpers once you have column number */ +static bool db_column_is_null(struct db_stmt *stmt, int col) +{ + return stmt->db->config->column_is_null_fn(stmt, col); +} + +/* Returns true (and warns) if it's nul */ +static bool db_column_null_warn(struct db_stmt *stmt, const char *colname, + int col) +{ + if (!db_column_is_null(stmt, col)) + return false; + + /* FIXME: log broken? */ +#if DEVELOPER + db_fatal("Accessing a null column %s/%i in query %s", + colname, col, stmt->query->query); +#endif /* DEVELOPER */ + return true; +} + +void db_bind_int(struct db_stmt *stmt, int pos, int val) +{ + assert(pos < tal_count(stmt->bindings)); + memcheck(&val, sizeof(val)); + stmt->bindings[pos].type = DB_BINDING_INT; + stmt->bindings[pos].v.i = val; +} + +int db_col_int(struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_null_warn(stmt, colname, col)) + return 0; + + return stmt->db->config->column_int_fn(stmt, col); +} + +int db_col_is_null(struct db_stmt *stmt, const char *colname) +{ + return db_column_is_null(stmt, db_query_colnum(stmt, colname)); +} + +void db_bind_null(struct db_stmt *stmt, int pos) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_NULL; +} + +void db_bind_u64(struct db_stmt *stmt, int pos, u64 val) +{ + memcheck(&val, sizeof(val)); + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_UINT64; + stmt->bindings[pos].v.u64 = val; +} + +void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_BLOB; + stmt->bindings[pos].v.blob = memcheck(val, len); + stmt->bindings[pos].len = len; +} + +void db_bind_text(struct db_stmt *stmt, int pos, const char *val) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_TEXT; + stmt->bindings[pos].v.text = val; + stmt->bindings[pos].len = strlen(val); +} + +void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p) +{ + db_bind_blob(stmt, pos, p->r, sizeof(struct preimage)); +} + +void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s) +{ + db_bind_blob(stmt, pos, s->u.u8, sizeof(struct sha256)); +} + +void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s) +{ + db_bind_sha256(stmt, pos, &s->sha); +} + +void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s) +{ + assert(sizeof(s->data) == 32); + db_bind_blob(stmt, pos, s->data, sizeof(s->data)); +} + +void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s) +{ + size_t num = tal_count(s), elsize = sizeof(s->data); + u8 *ser = tal_arr(stmt, u8, num * elsize); + + for (size_t i = 0; i < num; ++i) + memcpy(ser + i * elsize, &s[i], elsize); + + db_bind_blob(stmt, col, ser, tal_count(ser)); +} + +void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t) +{ + db_bind_sha256d(stmt, pos, &t->shad); +} + +void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id) +{ + db_bind_blob(stmt, pos, id->id, sizeof(id->id)); +} + +void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *id) +{ + db_bind_blob(stmt, pos, id->k, sizeof(id->k)); +} + +void db_bind_node_id_arr(struct db_stmt *stmt, int col, + const struct node_id *ids) +{ + /* Copy into contiguous array: ARM will add padding to struct node_id! */ + size_t n = tal_count(ids); + u8 *arr = tal_arr(stmt, u8, n * sizeof(ids[0].k)); + + for (size_t i = 0; i < n; ++i) { + assert(node_id_valid(&ids[i])); + memcpy(arr + sizeof(ids[i].k) * i, + ids[i].k, + sizeof(ids[i].k)); + } + db_bind_blob(stmt, col, arr, tal_count(arr)); +} + +void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *pk) +{ + u8 *der = tal_arr(stmt, u8, PUBKEY_CMPR_LEN); + pubkey_to_der(der, pk); + db_bind_blob(stmt, pos, der, PUBKEY_CMPR_LEN); +} + +void db_bind_short_channel_id(struct db_stmt *stmt, int col, + const struct short_channel_id *id) +{ + char *ser = short_channel_id_to_str(stmt, id); + db_bind_text(stmt, col, ser); +} + +void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col, + const struct short_channel_id *id) +{ + u8 *ser = tal_arr(stmt, u8, 0); + size_t num = tal_count(id); + + for (size_t i = 0; i < num; ++i) + towire_short_channel_id(&ser, &id[i]); + + db_bind_talarr(stmt, col, ser); +} + +void db_bind_signature(struct db_stmt *stmt, int col, + const secp256k1_ecdsa_signature *sig) +{ + u8 *buf = tal_arr(stmt, u8, 64); + int ret = secp256k1_ecdsa_signature_serialize_compact(secp256k1_ctx, + buf, sig); + assert(ret == 1); + db_bind_blob(stmt, col, buf, 64); +} + +void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t) +{ + u64 timestamp = t.ts.tv_nsec + (((u64) t.ts.tv_sec) * ((u64) NSEC_IN_SEC)); + db_bind_u64(stmt, col, timestamp); +} + +void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx) +{ + u8 *ser = linearize_wtx(stmt, tx); + assert(ser); + db_bind_talarr(stmt, col, ser); +} + +void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt) +{ + size_t bytes_written; + const u8 *ser = psbt_get_bytes(stmt, psbt, &bytes_written); + assert(ser); + db_bind_blob(stmt, col, ser, bytes_written); +} + +void db_bind_amount_msat(struct db_stmt *stmt, int pos, + const struct amount_msat *msat) +{ + db_bind_u64(stmt, pos, msat->millisatoshis); /* Raw: low level function */ +} + +void db_bind_amount_sat(struct db_stmt *stmt, int pos, + const struct amount_sat *sat) +{ + db_bind_u64(stmt, pos, sat->satoshis); /* Raw: low level function */ +} + +void db_bind_json_escape(struct db_stmt *stmt, int pos, + const struct json_escape *esc) +{ + db_bind_text(stmt, pos, esc->s); +} + +void db_bind_onionreply(struct db_stmt *stmt, int pos, const struct onionreply *r) +{ + db_bind_talarr(stmt, pos, r->contents); +} + +void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr) +{ + if (!arr) + db_bind_null(stmt, col); + else + db_bind_blob(stmt, col, arr, tal_bytelen(arr)); +} + +static size_t db_column_bytes(struct db_stmt *stmt, int col) +{ + if (db_column_is_null(stmt, col)) + return 0; + return stmt->db->config->column_bytes_fn(stmt, col); +} + +static const void *db_column_blob(struct db_stmt *stmt, int col) +{ + if (db_column_is_null(stmt, col)) + return NULL; + return stmt->db->config->column_blob_fn(stmt, col); +} + + +u64 db_col_u64(struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_null_warn(stmt, colname, col)) + return 0; + + return stmt->db->config->column_u64_fn(stmt, col); +} + +int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_is_null(stmt, col)) + return def; + else + return stmt->db->config->column_int_fn(stmt, col); +} + +size_t db_col_bytes(struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_null_warn(stmt, colname, col)) + return 0; + + return stmt->db->config->column_bytes_fn(stmt, col); +} + +const void *db_col_blob(struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_null_warn(stmt, colname, col)) + return NULL; + + return stmt->db->config->column_blob_fn(stmt, col); +} + +char *db_col_strdup(const tal_t *ctx, + struct db_stmt *stmt, + const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_null_warn(stmt, colname, col)) + return NULL; + + return tal_strdup(ctx, (char *)stmt->db->config->column_text_fn(stmt, col)); +} + +void db_col_preimage(struct db_stmt *stmt, const char *colname, + struct preimage *preimage) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *raw; + size_t size = sizeof(struct preimage); + assert(db_column_bytes(stmt, col) == size); + raw = db_column_blob(stmt, col); + memcpy(preimage, raw, size); +} + +void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest) +{ + size_t col = db_query_colnum(stmt, colname); + + assert(db_column_bytes(stmt, col) == sizeof(dest->id)); + memcpy(dest->id, db_column_blob(stmt, col), sizeof(dest->id)); +} + +void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *dest) +{ + size_t col = db_query_colnum(stmt, colname); + + assert(db_column_bytes(stmt, col) == sizeof(dest->k)); + memcpy(dest->k, db_column_blob(stmt, col), sizeof(dest->k)); +} + +struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt, + const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + struct node_id *ret; + size_t n = db_column_bytes(stmt, col) / sizeof(ret->k); + const u8 *arr = db_column_blob(stmt, col); + assert(n * sizeof(ret->k) == (size_t)db_column_bytes(stmt, col)); + ret = tal_arr(ctx, struct node_id, n); + + db_column_null_warn(stmt, colname, col); + for (size_t i = 0; i < n; i++) + memcpy(ret[i].k, arr + i * sizeof(ret[i].k), sizeof(ret[i].k)); + + return ret; +} + +void db_col_pubkey(struct db_stmt *stmt, + const char *colname, + struct pubkey *dest) +{ + size_t col = db_query_colnum(stmt, colname); + bool ok; + assert(db_column_bytes(stmt, col) == PUBKEY_CMPR_LEN); + ok = pubkey_from_der(db_column_blob(stmt, col), PUBKEY_CMPR_LEN, dest); + assert(ok); +} + +/* Yes, we put this in as a string. Past mistakes; do not use! */ +bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname, + struct short_channel_id *dest) +{ + size_t col = db_query_colnum(stmt, colname); + const char *source = db_column_blob(stmt, col); + size_t sourcelen = db_column_bytes(stmt, col); + db_column_null_warn(stmt, colname, col); + return short_channel_id_from_str(source, sourcelen, dest); +} + +struct short_channel_id * +db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *ser; + size_t len; + struct short_channel_id *ret; + + db_column_null_warn(stmt, colname, col); + ser = db_column_blob(stmt, col); + len = db_column_bytes(stmt, col); + ret = tal_arr(ctx, struct short_channel_id, 0); + + while (len != 0) { + struct short_channel_id scid; + fromwire_short_channel_id(&ser, &len, &scid); + tal_arr_expand(&ret, scid); + } + + return ret; +} + +bool db_col_signature(struct db_stmt *stmt, const char *colname, + secp256k1_ecdsa_signature *sig) +{ + size_t col = db_query_colnum(stmt, colname); + assert(db_column_bytes(stmt, col) == 64); + return secp256k1_ecdsa_signature_parse_compact( + secp256k1_ctx, sig, db_column_blob(stmt, col)) == 1; +} + +struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname) +{ + struct timeabs t; + u64 timestamp = db_col_u64(stmt, colname); + t.ts.tv_sec = timestamp / NSEC_IN_SEC; + t.ts.tv_nsec = timestamp % NSEC_IN_SEC; + return t; + +} + +struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *src = db_column_blob(stmt, col); + size_t len = db_column_bytes(stmt, col); + + db_column_null_warn(stmt, colname, col); + return pull_bitcoin_tx(ctx, &src, &len); +} + +struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *src = db_column_blob(stmt, col); + size_t len = db_column_bytes(stmt, col); + + db_column_null_warn(stmt, colname, col); + return psbt_from_bytes(ctx, src, len); +} + +struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname) +{ + struct wally_psbt *psbt = db_col_psbt(ctx, stmt, colname); + if (!psbt) + return NULL; + return bitcoin_tx_with_psbt(ctx, psbt); +} + +void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname, + size_t bytes, const char *label, const char *caller) +{ + size_t col = db_query_colnum(stmt, colname); + size_t sourcelen; + void *p; + + if (db_column_is_null(stmt, col)) + return NULL; + + sourcelen = db_column_bytes(stmt, col); + + if (sourcelen % bytes != 0) + db_fatal("%s: %s/%zu column size for %zu not a multiple of %s (%zu)", + caller, colname, col, sourcelen, label, bytes); + + p = tal_arr_label(ctx, char, sourcelen, label); + memcpy(p, db_column_blob(stmt, col), sourcelen); + return p; +} + +void db_col_amount_msat_or_default(struct db_stmt *stmt, + const char *colname, + struct amount_msat *msat, + struct amount_msat def) +{ + size_t col = db_query_colnum(stmt, colname); + + if (db_column_is_null(stmt, col)) + *msat = def; + else + msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */ +} + +void db_col_amount_msat(struct db_stmt *stmt, const char *colname, + struct amount_msat *msat) +{ + msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */ +} + +void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat) +{ + sat->satoshis = db_col_u64(stmt, colname); /* Raw: low level function */ +} + +struct json_escape *db_col_json_escape(const tal_t *ctx, + struct db_stmt *stmt, const char *colname) +{ + size_t col = db_query_colnum(stmt, colname); + + return json_escape_string_(ctx, db_column_blob(stmt, col), + db_column_bytes(stmt, col)); +} + +void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *raw; + size_t size = sizeof(struct sha256); + assert(db_column_bytes(stmt, col) == size); + raw = db_column_blob(stmt, col); + memcpy(sha, raw, size); +} + +void db_col_sha256d(struct db_stmt *stmt, const char *colname, + struct sha256_double *shad) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *raw; + size_t size = sizeof(struct sha256_double); + assert(db_column_bytes(stmt, col) == size); + raw = db_column_blob(stmt, col); + memcpy(shad, raw, size); +} + +void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s) +{ + size_t col = db_query_colnum(stmt, colname); + const u8 *raw; + assert(db_column_bytes(stmt, col) == sizeof(struct secret)); + raw = db_column_blob(stmt, col); + memcpy(s, raw, sizeof(struct secret)); +} + +struct secret *db_col_secret_arr(const tal_t *ctx, + struct db_stmt *stmt, + const char *colname) +{ + return db_col_arr(ctx, stmt, colname, struct secret); +} + +void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t) +{ + db_col_sha256d(stmt, colname, &t->shad); +} + +struct onionreply *db_col_onionreply(const tal_t *ctx, + struct db_stmt *stmt, const char *colname) +{ + struct onionreply *r = tal(ctx, struct onionreply); + r->contents = db_col_arr(ctx, stmt, colname, u8); + return r; +} + +void db_col_ignore(struct db_stmt *stmt, const char *colname) +{ +#if DEVELOPER + db_query_colnum(stmt, colname); +#endif +} diff --git a/db/bindings.h b/db/bindings.h new file mode 100644 index 000000000..298dc6d48 --- /dev/null +++ b/db/bindings.h @@ -0,0 +1,118 @@ +#ifndef LIGHTNING_DB_BINDINGS_H +#define LIGHTNING_DB_BINDINGS_H +#include "config.h" + +#include +#include +#include +#include +#include +#include + +struct channel_id; +struct db_stmt; +struct node_id; +struct onionreply; +struct wally_psbt; +struct wally_tx; + +int db_col_is_null(struct db_stmt *stmt, const char *colname); + +void db_bind_int(struct db_stmt *stmt, int pos, int val); +int db_col_int(struct db_stmt *stmt, const char *colname); + +void db_bind_null(struct db_stmt *stmt, int pos); +void db_bind_int(struct db_stmt *stmt, int pos, int val); +void db_bind_u64(struct db_stmt *stmt, int pos, u64 val); +void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len); +void db_bind_text(struct db_stmt *stmt, int pos, const char *val); +void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p); +void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s); +void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s); +void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s); +void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s); +void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t); +void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id); +void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *ni); +void db_bind_node_id_arr(struct db_stmt *stmt, int col, + const struct node_id *ids); +void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *p); +void db_bind_short_channel_id(struct db_stmt *stmt, int col, + const struct short_channel_id *id); +void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col, + const struct short_channel_id *id); +void db_bind_signature(struct db_stmt *stmt, int col, + const secp256k1_ecdsa_signature *sig); +void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t); +void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx); +void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt); +void db_bind_amount_msat(struct db_stmt *stmt, int pos, + const struct amount_msat *msat); +void db_bind_amount_sat(struct db_stmt *stmt, int pos, + const struct amount_sat *sat); +void db_bind_json_escape(struct db_stmt *stmt, int pos, + const struct json_escape *esc); +void db_bind_onionreply(struct db_stmt *stmt, int col, + const struct onionreply *r); +void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr); + +/* Modern variants: get columns by name from SELECT */ +/* Bridge function to get column number from SELECT + (must exist) */ +size_t db_query_colnum(const struct db_stmt *stmt, const char *colname); + +u64 db_col_u64(struct db_stmt *stmt, const char *colname); +size_t db_col_bytes(struct db_stmt *stmt, const char *colname); +const void* db_col_blob(struct db_stmt *stmt, const char *colname); +char *db_col_strdup(const tal_t *ctx, + struct db_stmt *stmt, + const char *colname); +void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage); +void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat); +void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat); +struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname); +void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha); +void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad); +void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s); +struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt, + const char *colname); +void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t); +void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest); +void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni); +struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt, + const char *colname); +void db_col_pubkey(struct db_stmt *stmt, const char *colname, + struct pubkey *p); +bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname, + struct short_channel_id *dest); +struct short_channel_id * +db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname); +bool db_col_signature(struct db_stmt *stmt, const char *colname, + secp256k1_ecdsa_signature *sig); +struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname); +struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname); +struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname); +struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname); + +struct onionreply *db_col_onionreply(const tal_t *ctx, + struct db_stmt *stmt, const char *colname); + +#define db_col_arr(ctx, stmt, colname, type) \ + ((type *)db_col_arr_((ctx), (stmt), (colname), \ + sizeof(type), TAL_LABEL(type, "[]"), \ + __func__)) +void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname, + size_t bytes, const char *label, const char *caller); + + +/* Some useful default variants */ +int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def); +void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname, + struct amount_msat *msat, + struct amount_msat def); + + +/* Explicitly ignore a column (so we don't complain you didn't use it!) */ +void db_col_ignore(struct db_stmt *stmt, const char *colname); + +#endif /* LIGHTNING_DB_BINDINGS_H */ diff --git a/wallet/db_common.h b/db/common.h similarity index 77% rename from wallet/db_common.h rename to db/common.h index 233d7b93a..4ced5132a 100644 --- a/wallet/db_common.h +++ b/db/common.h @@ -1,10 +1,32 @@ -#ifndef LIGHTNING_WALLET_DB_COMMON_H -#define LIGHTNING_WALLET_DB_COMMON_H +#ifndef LIGHTNING_DB_COMMON_H +#define LIGHTNING_DB_COMMON_H #include "config.h" #include #include #include #include +#include + +/** + * Macro to annotate a named SQL query. + * + * This macro is used to annotate SQL queries that might need rewriting for + * different SQL dialects. It is used both as a marker for the query + * extraction logic in devtools/sql-rewrite.py to identify queries, as well as + * a way to swap out the query text with it's name so that the query execution + * engine can then look up the rewritten query using its name. + * + */ +#define NAMED_SQL(name,x) x + +/** + * Simple annotation macro that auto-generates names for NAMED_SQL + * + * If this macro is changed it is likely that the extraction logic in + * devtools/sql-rewrite.py needs to change as well, since they need to + * generate identical names to work correctly. + */ +#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x) struct db { char *filename; @@ -13,18 +35,18 @@ struct db { /* DB-specific context */ void *conn; - /* The configuration, including translated queries for the current - * instance. */ + /* The configuration for the current database driver */ const struct db_config *config; + /* Translated queries for the current database domain + driver */ + const struct db_query_set *queries; + const char **changes; /* List of statements that have been created but not executed yet. */ struct list_head pending_statements; char *error; - struct log *log; - /* Were there any modifying statements in the current transaction? * Used to bump the data_version in the DB.*/ bool dirty; @@ -32,6 +54,8 @@ struct db { /* The current DB version we expect to update if changes are * committed. */ u32 data_version; + + void (*report_changes_fn)(struct db *); }; struct db_query { @@ -102,10 +126,14 @@ struct db_stmt { #endif }; -struct db_config { +struct db_query_set { const char *name; const struct db_query *query_table; size_t query_table_size; +}; + +struct db_config { + const char *name; /* Function used to execute a statement that doesn't result in a * response. */ @@ -155,20 +183,14 @@ struct db_config { const char **colnames, size_t num_cols); }; -/* Provide a way for DB backends to register themselves */ -AUTODATA_TYPE(db_backends, struct db_config); - void db_fatal(const char *fmt, ...) PRINTF_FMT(1, 2); -/** - * Report a statement that changes the wallet - * - * Allows the DB driver to report an expanded statement during - * execution. Changes are queued up and reported to the `db_write` plugin hook - * upon committing. - */ -void db_changes_add(struct db_stmt *db_stmt, const char * expanded); +/* Provide a way for DB backends to register themselves */ +AUTODATA_TYPE(db_backends, struct db_config); + +/* Provide a way for DB query sets to register themselves */ +AUTODATA_TYPE(db_queries, struct db_query_set); /* devtools/sql-rewrite.py generates this simple htable */ struct sqlname_map { @@ -176,4 +198,4 @@ struct sqlname_map { int val; }; -#endif /* LIGHTNING_WALLET_DB_COMMON_H */ +#endif /* LIGHTNING_DB_COMMON_H */ diff --git a/wallet/db_postgres.c b/db/db_postgres.c similarity index 97% rename from wallet/db_postgres.c rename to db/db_postgres.c index 91db5beff..7f6f143bc 100644 --- a/wallet/db_postgres.c +++ b/db/db_postgres.c @@ -1,9 +1,8 @@ #include "config.h" #include #include -#include -#include -#include +#include +#include #if HAVE_POSTGRES /* Indented in order not to trigger the inclusion order check */ @@ -321,8 +320,6 @@ static bool db_postgres_delete_columns(struct db *db, struct db_config db_postgres_config = { .name = "postgres", - .query_table = db_postgres_queries, - .query_table_size = ARRAY_SIZE(db_postgres_queries), .exec_fn = db_postgres_exec, .query_fn = db_postgres_query, .step_fn = db_postgres_step, @@ -348,4 +345,4 @@ struct db_config db_postgres_config = { AUTODATA(db_backends, &db_postgres_config); -#endif +#endif /* HAVE_POSTGRES */ diff --git a/wallet/db_sqlite3.c b/db/db_sqlite3.c similarity index 99% rename from wallet/db_sqlite3.c rename to db/db_sqlite3.c index ad81bc20b..90f0a4aa2 100644 --- a/wallet/db_sqlite3.c +++ b/db/db_sqlite3.c @@ -1,8 +1,8 @@ #include "config.h" -#include "db_sqlite3_sqlgen.c" #include #include -#include +#include +#include #if HAVE_SQLITE3 #include @@ -682,8 +682,6 @@ static bool db_sqlite3_delete_columns(struct db *db, struct db_config db_sqlite3_config = { .name = "sqlite3", - .query_table = db_sqlite3_queries, - .query_table_size = ARRAY_SIZE(db_sqlite3_queries), .exec_fn = &db_sqlite3_exec, .query_fn = &db_sqlite3_query, .step_fn = &db_sqlite3_step, @@ -710,4 +708,4 @@ struct db_config db_sqlite3_config = { AUTODATA(db_backends, &db_sqlite3_config); -#endif +#endif /* HAVE_SQLITE3 */ diff --git a/db/exec.c b/db/exec.c new file mode 100644 index 000000000..77c9597ab --- /dev/null +++ b/db/exec.c @@ -0,0 +1,162 @@ +#include "config.h" +#include +#include +#include +#include +#include + +/** + * db_get_version - Determine the current DB schema version + * + * Will attempt to determine the current schema version of the + * database @db by querying the `version` table. If the table does not + * exist it'll return schema version -1, so that migration 0 is + * applied, which should create the `version` table. + */ +int db_get_version(struct db *db) +{ + int res = -1; + struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1")); + + /* + * Tentatively execute a query, but allow failures. Some databases + * like postgres will terminate the DB transaction if there is an + * error during the execution of a query, e.g., trying to access a + * table that doesn't exist yet, so we need to terminate and restart + * the DB transaction. + */ + if (!db_query_prepared(stmt)) { + db_commit_transaction(stmt->db); + db_begin_transaction(stmt->db); + tal_free(stmt); + return res; + } + + if (db_step(stmt)) + res = db_col_int(stmt, "version"); + + tal_free(stmt); + return res; +} + +u32 db_data_version_get(struct db *db) +{ + struct db_stmt *stmt; + u32 version; + stmt = db_prepare_v2(db, SQL("SELECT intval FROM vars WHERE name = 'data_version'")); + db_query_prepared(stmt); + db_step(stmt); + version = db_col_int(stmt, "intval"); + tal_free(stmt); + return version; +} + +void db_set_intvar(struct db *db, char *varname, s64 val) +{ + size_t changes; + struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET intval=? WHERE name=?;")); + db_bind_int(stmt, 0, val); + db_bind_text(stmt, 1, varname); + if (!db_exec_prepared_v2(stmt)) + db_fatal("Error executing update: %s", stmt->error); + changes = db_count_changes(stmt); + tal_free(stmt); + + if (changes == 0) { + stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, intval) VALUES (?, ?);")); + db_bind_text(stmt, 0, varname); + db_bind_int(stmt, 1, val); + if (!db_exec_prepared_v2(stmt)) + db_fatal("Error executing insert: %s", stmt->error); + tal_free(stmt); + } +} + +s64 db_get_intvar(struct db *db, char *varname, s64 defval) +{ + s64 res = defval; + struct db_stmt *stmt = db_prepare_v2( + db, SQL("SELECT intval FROM vars WHERE name= ? LIMIT 1")); + db_bind_text(stmt, 0, varname); + if (!db_query_prepared(stmt)) + goto done; + + if (db_step(stmt)) + res = db_col_int(stmt, "intval"); + +done: + tal_free(stmt); + return res; +} + +/* Leak tracking. */ + +/* By making the update conditional on the current value we expect we + * are implementing an optimistic lock: if the update results in + * changes on the DB we know that the data_version did not change + * under our feet and no other transaction ran in the meantime. + * + * Notice that this update effectively locks the row, so that other + * operations attempting to change this outside the transaction will + * wait for this transaction to complete. The external change will + * ultimately fail the changes test below, it'll just delay its abort + * until our transaction is committed. + */ +static void db_data_version_incr(struct db *db) +{ + struct db_stmt *stmt = db_prepare_v2( + db, SQL("UPDATE vars " + "SET intval = intval + 1 " + "WHERE name = 'data_version'" + " AND intval = ?")); + db_bind_int(stmt, 0, db->data_version); + db_exec_prepared_v2(stmt); + if (db_count_changes(stmt) != 1) + db_fatal("Optimistic lock on the database failed. There" + " may be a concurrent access to the database." + " Aborting since concurrent access is unsafe."); + tal_free(stmt); + db->data_version++; +} + +void db_begin_transaction_(struct db *db, const char *location) +{ + bool ok; + if (db->in_transaction) + db_fatal("Already in transaction from %s", db->in_transaction); + + /* No writes yet. */ + db->dirty = false; + + db_prepare_for_changes(db); + ok = db->config->begin_tx_fn(db); + if (!ok) + db_fatal("Failed to start DB transaction: %s", db->error); + + db->in_transaction = location; +} + +bool db_in_transaction(struct db *db) +{ + return db->in_transaction; +} + +void db_commit_transaction(struct db *db) +{ + bool ok; + assert(db->in_transaction); + db_assert_no_outstanding_statements(db); + + /* Increment before reporting changes to an eventual plugin. */ + if (db->dirty) + db_data_version_incr(db); + + db_report_changes(db, NULL, 0); + ok = db->config->commit_tx_fn(db); + + if (!ok) + db_fatal("Failed to commit DB transaction: %s", db->error); + + db->in_transaction = NULL; + db->dirty = false; +} diff --git a/db/exec.h b/db/exec.h new file mode 100644 index 000000000..70799532a --- /dev/null +++ b/db/exec.h @@ -0,0 +1,52 @@ +#ifndef LIGHTNING_DB_EXEC_H +#define LIGHTNING_DB_EXEC_H +#include "config.h" + +#include +#include + +struct db; + +/** + * db_set_intvar - Set an integer variable in the database + * + * Utility function to store generic integer values in the + * database. + */ +void db_set_intvar(struct db *db, char *varname, s64 val); + +/** + * db_get_intvar - Retrieve an integer variable from the database + * + * Either returns the value in the database, or @defval if + * the query failed or no such variable exists. + */ +s64 db_get_intvar(struct db *db, char *varname, s64 defval); + +/* Get the current data version (entries). */ +u32 db_data_version_get(struct db *db); + +/* Get the current database version (migrations). */ +int db_get_version(struct db *db); + +/** + * db_begin_transaction - Begin a transaction + * + * Begin a new DB transaction. fatal() on database error. + */ +#define db_begin_transaction(db) \ + db_begin_transaction_((db), __FILE__ ":" stringify(__LINE__)) +void db_begin_transaction_(struct db *db, const char *location); + +bool db_in_transaction(struct db *db); + +/** + * db_commit_transaction - Commit a running transaction + * + * Requires that we are currently in a transaction. fatal() if we + * fail to commit. + */ +void db_commit_transaction(struct db *db); + + +#endif /* LIGHTNING_DB_EXEC_H */ diff --git a/db/utils.c b/db/utils.c new file mode 100644 index 000000000..11d2654f6 --- /dev/null +++ b/db/utils.c @@ -0,0 +1,324 @@ +#include "config.h" +#include +#include +#include +#include + +/* Matches the hash function used in devtools/sql-rewrite.py */ +static u32 hash_djb2(const char *str) +{ + u32 hash = 5381; + for (size_t i = 0; str[i]; i++) + hash = ((hash << 5) + hash) ^ str[i]; + return hash; +} + +size_t db_query_colnum(const struct db_stmt *stmt, + const char *colname) +{ + u32 col; + + assert(stmt->query->colnames != NULL); + + col = hash_djb2(colname) % stmt->query->num_colnames; + /* Will crash on NULL, which is the Right Thing */ + while (!streq(stmt->query->colnames[col].sqlname, + colname)) { + col = (col + 1) % stmt->query->num_colnames; + } + +#if DEVELOPER + strset_add(stmt->cols_used, colname); +#endif + + return stmt->query->colnames[col].val; +} + +static void db_stmt_free(struct db_stmt *stmt) +{ + if (!stmt->executed) + db_fatal("Freeing an un-executed statement from %s: %s", + stmt->location, stmt->query->query); +#if DEVELOPER + /* If they never got a db_step, we don't track */ + if (stmt->cols_used) { + for (size_t i = 0; i < stmt->query->num_colnames; i++) { + if (!stmt->query->colnames[i].sqlname) + continue; + if (!strset_get(stmt->cols_used, + stmt->query->colnames[i].sqlname)) { + db_fatal("Never accessed column %s in query %s", + stmt->query->colnames[i].sqlname, + stmt->query->query); + } + } + strset_clear(stmt->cols_used); + } +#endif + if (stmt->inner_stmt) + stmt->db->config->stmt_free_fn(stmt); + assert(stmt->inner_stmt == NULL); +} + + +struct db_stmt *db_prepare_v2_(const char *location, struct db *db, + const char *query_id) +{ + struct db_stmt *stmt = tal(db, struct db_stmt); + size_t num_slots, pos; + + /* Normalize query_id paths, because unit tests are compiled with this + * prefix. */ + if (strncmp(query_id, "./", 2) == 0) + query_id += 2; + + if (!db->in_transaction) + db_fatal("Attempting to prepare a db_stmt outside of a " + "transaction: %s", location); + + /* Look up the query by its ID */ + pos = hash_djb2(query_id) % db->queries->query_table_size; + for (;;) { + if (!db->queries->query_table[pos].name) + db_fatal("Could not resolve query %s", query_id); + if (streq(query_id, db->queries->query_table[pos].name)) { + stmt->query = &db->queries->query_table[pos]; + break; + } + pos = (pos + 1) % db->queries->query_table_size; + } + + num_slots = stmt->query->placeholders; + /* Allocate the slots for placeholders/bindings, zeroed next since + * that sets the type to DB_BINDING_UNINITIALIZED for later checks. */ + stmt->bindings = tal_arr(stmt, struct db_binding, num_slots); + for (size_t i=0; ibindings[i].type = DB_BINDING_UNINITIALIZED; + + stmt->location = location; + stmt->error = NULL; + stmt->db = db; + stmt->executed = false; + stmt->inner_stmt = NULL; + + tal_add_destructor(stmt, db_stmt_free); + + list_add(&db->pending_statements, &stmt->list); + +#if DEVELOPER + stmt->cols_used = NULL; +#endif /* DEVELOPER */ + + return stmt; +} + +#define db_prepare_v2(db,query) \ + db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) + +bool db_query_prepared(struct db_stmt *stmt) +{ + /* Make sure we don't accidentally execute a modifying query using a + * read-only path. */ + bool ret; + assert(stmt->query->readonly); + ret = stmt->db->config->query_fn(stmt); + stmt->executed = true; + list_del_from(&stmt->db->pending_statements, &stmt->list); + return ret; +} + +bool db_step(struct db_stmt *stmt) +{ + bool ret; + + assert(stmt->executed); + ret = stmt->db->config->step_fn(stmt); + +#if DEVELOPER + /* We only track cols_used if we return a result! */ + if (ret && !stmt->cols_used) { + stmt->cols_used = tal(stmt, struct strset); + strset_init(stmt->cols_used); + } +#endif + return ret; +} + +bool db_exec_prepared_v2(struct db_stmt *stmt TAKES) +{ + bool ret = stmt->db->config->exec_fn(stmt); + + /* If this was a write we need to bump the data_version upon commit. */ + stmt->db->dirty = stmt->db->dirty || !stmt->query->readonly; + + stmt->executed = true; + list_del_from(&stmt->db->pending_statements, &stmt->list); + + /* The driver itself doesn't call `fatal` since we want to override it + * for testing. Instead we check here that the error message is set if + * we report an error. */ + if (!ret) { + assert(stmt->error); + db_fatal("Error executing statement: %s", stmt->error); + } + + if (taken(stmt)) + tal_free(stmt); + + return ret; +} + +size_t db_count_changes(struct db_stmt *stmt) +{ + assert(stmt->executed); + return stmt->db->config->count_changes_fn(stmt); +} + +const char **db_changes(struct db *db) +{ + return db->changes; +} + +u64 db_last_insert_id_v2(struct db_stmt *stmt TAKES) +{ + u64 id; + assert(stmt->executed); + id = stmt->db->config->last_insert_id_fn(stmt); + + if (taken(stmt)) + tal_free(stmt); + + return id; +} + +/* We expect min changes (ie. BEGIN TRANSACTION): report if more. + * Optionally add "final" at the end (ie. COMMIT). */ +void db_report_changes(struct db *db, const char *final, size_t min) +{ + assert(db->changes); + assert(tal_count(db->changes) >= min); + + /* Having changes implies that we have a dirty TX. The opposite is + * currently not true, e.g., the postgres driver doesn't record + * changes yet. */ + assert(!tal_count(db->changes) || db->dirty); + + if (tal_count(db->changes) > min && db->report_changes_fn) + db->report_changes_fn(db); + db->changes = tal_free(db->changes); +} + +void db_changes_add(struct db_stmt *stmt, const char * expanded) +{ + struct db *db = stmt->db; + + if (stmt->query->readonly) { + return; + } + /* We get a "COMMIT;" after we've sent our changes. */ + if (!db->changes) { + assert(streq(expanded, "COMMIT;")); + return; + } + + tal_arr_expand(&db->changes, tal_strdup(db->changes, expanded)); +} + +#if DEVELOPER +void db_assert_no_outstanding_statements(struct db *db) +{ + struct db_stmt *stmt; + + stmt = list_top(&db->pending_statements, struct db_stmt, list); + if (stmt) + db_fatal("Unfinalized statement %s", stmt->location); +} +#else +void db_assert_no_outstanding_statements(struct db *db) +{ +} +#endif + + +static void destroy_db(struct db *db) +{ + db_assert_no_outstanding_statements(db); + + if (db->config->teardown_fn) + db->config->teardown_fn(db); +} + +static struct db_config *db_config_find(const char *dsn) +{ + size_t num_configs; + struct db_config **configs = autodata_get(db_backends, &num_configs); + const char *sep, *driver_name; + sep = strstr(dsn, "://"); + + if (!sep) + db_fatal("%s doesn't look like a valid data-source name (missing \"://\" separator.", dsn); + + driver_name = tal_strndup(tmpctx, dsn, sep - dsn); + + for (size_t i=0; iname)) { + tal_free(driver_name); + return configs[i]; + } + } + + tal_free(driver_name); + return NULL; +} + +static struct db_query_set *db_queries_find(const struct db_config *config) +{ + size_t num_queries; + struct db_query_set **queries = autodata_get(db_queries, &num_queries); + + for (size_t i = 0; i < num_queries; i++) { + if (streq(config->name, queries[i]->name)) { + return queries[i]; + } + } + return NULL; +} + +void db_prepare_for_changes(struct db *db) +{ + assert(!db->changes); + db->changes = tal_arr(db, const char *, 0); +} + +struct db *db_open(const tal_t *ctx, char *filename) +{ + struct db *db; + + db = tal(ctx, struct db); + db->filename = tal_strdup(db, filename); + list_head_init(&db->pending_statements); + if (!strstr(db->filename, "://")) + db_fatal("Could not extract driver name from \"%s\"", db->filename); + + db->config = db_config_find(db->filename); + if (!db->config) + db_fatal("Unable to find DB driver for %s", db->filename); + + db->queries = db_queries_find(db->config); + if (!db->queries) + db_fatal("Unable to find DB queries for %s", db->config->name); + + tal_add_destructor(db, destroy_db); + db->in_transaction = NULL; + db->changes = NULL; + + /* This must be outside a transaction, so catch it */ + assert(!db->in_transaction); + + db_prepare_for_changes(db); + if (db->config->setup_fn && !db->config->setup_fn(db)) + db_fatal("Error calling DB setup: %s", db->error); + db_report_changes(db, NULL, 0); + + return db; +} diff --git a/db/utils.h b/db/utils.h new file mode 100644 index 000000000..2b5a21dd7 --- /dev/null +++ b/db/utils.h @@ -0,0 +1,100 @@ +#ifndef LIGHTNING_DB_UTILS_H +#define LIGHTNING_DB_UTILS_H +#include "config.h" +#include +#include + +struct db; +struct db_stmt; + +size_t db_query_colnum(const struct db_stmt *stmt, + const char *colname); + +/* Return next 'row' result of statement */ +bool db_step(struct db_stmt *stmt); + +/* TODO(cdecker) Remove the v2 suffix after finishing the migration */ +#define db_prepare_v2(db,query) \ + db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) + + +/** + * db_exec_prepared -- Execute a prepared statement + * + * After preparing a statement using `db_prepare`, and after binding all + * non-null variables using the `db_bind_*` functions, it can be executed with + * this function. It is a small, transaction-aware, wrapper around `db_step`, + * that calls fatal() if the execution fails. This may take ownership of + * `stmt` if annotated with `take()`and will free it before returning. + * + * If you'd like to issue a query and access the rows returned by the query + * please use `db_query_prepared` instead, since this function will not expose + * returned results, and the `stmt` can only be used for calls to + * `db_count_changes` and `db_last_insert_id` after executing. + * + * @stmt: The prepared statement to execute + */ +bool db_exec_prepared_v2(struct db_stmt *stmt TAKES); + +/** + * db_query_prepared -- Execute a prepared query + * + * After preparing a query using `db_prepare`, and after binding all non-null + * variables using the `db_bind_*` functions, it can be executed with this + * function. This function must be called before calling `db_step` or any of + * the `db_col_*` column access functions. + * + * If you are not executing a read-only statement, please use + * `db_exec_prepared` instead. + * + * @stmt: The prepared statement to execute + */ +bool db_query_prepared(struct db_stmt *stmt); + +size_t db_count_changes(struct db_stmt *stmt); +void db_report_changes(struct db *db, const char *final, size_t min); +void db_prepare_for_changes(struct db *db); + +u64 db_last_insert_id_v2(struct db_stmt *stmt); + +/** + * db_prepare -- Prepare a DB query/command + * + * Create an instance of `struct db_stmt` that encapsulates a SQL query or command. + * + * @query MUST be wrapped in a `SQL()` macro call, since that allows the + * extraction and translation of the query into the target SQL dialect. + * + * It does not execute the query and does not check its validity, but + * allocates the placeholders detected in the query. The placeholders in the + * `stmt` can then be bound using the `db_bind_*` functions, and executed + * using `db_exec_prepared` for write-only statements and `db_query_prepared` + * for read-only statements. + * + * @db: Database to query/exec + * @query: The SQL statement to compile + */ +struct db_stmt *db_prepare_v2_(const char *location, struct db *db, + const char *query_id); + +/** + * db_open - Open or create a database + */ +struct db *db_open(const tal_t *ctx, char *filename); + +/** + * Report a statement that changes the wallet + * + * Allows the DB driver to report an expanded statement during + * execution. Changes are queued up and reported to the `db_write` plugin hook + * upon committing. + */ +void db_changes_add(struct db_stmt *db_stmt, const char * expanded); +void db_assert_no_outstanding_statements(struct db *db); + +/** + * Access pending changes that have been added to the current transaction. + */ +const char **db_changes(struct db *db); + +#endif /* LIGHTNING_DB_UTILS_H */ diff --git a/devtools/sql-rewrite.py b/devtools/sql-rewrite.py index e6f70689c..1aa9c4469 100755 --- a/devtools/sql-rewrite.py +++ b/devtools/sql-rewrite.py @@ -126,7 +126,8 @@ template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()} #include #include -#include +#include +#include #if HAVE_${f.upper()} % for colname, table in colhtables.items(): diff --git a/lightningd/Makefile b/lightningd/Makefile index 3a9b58bb7..844d2748f 100644 --- a/lightningd/Makefile +++ b/lightningd/Makefile @@ -130,9 +130,11 @@ LIGHTNINGD_COMMON_OBJS := \ common/wallet.o \ common/wire_error.o \ common/wireaddr.o \ + db/bindings.o \ + db/exec.o \ $(LIGHTNINGD_OBJS): $(LIGHTNINGD_HDRS) -$(WALLET_OBJS): $(LIGHTNINGD_HDRS) +$(WALLET_OBJS): $(LIGHTNINGD_HDRS) $(DB_HEADERS) # Only the plugin component needs to depend on this header. lightningd/plugin.o: plugins/list_of_builtin_plugins_gen.h @@ -140,6 +142,6 @@ lightningd/plugin.o: plugins/list_of_builtin_plugins_gen.h lightningd/channel_state_names_gen.h: lightningd/channel_state.h ccan/ccan/cdump/tools/cdump-enumstr ccan/ccan/cdump/tools/cdump-enumstr lightningd/channel_state.h > $@ -lightningd/lightningd: $(LIGHTNINGD_OBJS) $(WALLET_OBJS) $(LIGHTNINGD_COMMON_OBJS) $(BITCOIN_OBJS) $(WIRE_OBJS) $(WIRE_ONION_OBJS) $(LIGHTNINGD_CONTROL_OBJS) $(HSMD_CLIENT_OBJS) +lightningd/lightningd: $(LIGHTNINGD_OBJS) $(WALLET_OBJS) $(LIGHTNINGD_COMMON_OBJS) $(BITCOIN_OBJS) $(WIRE_OBJS) $(WIRE_ONION_OBJS) $(LIGHTNINGD_CONTROL_OBJS) $(HSMD_CLIENT_OBJS) $(DB_OBJS) include lightningd/test/Makefile diff --git a/lightningd/bitcoind.c b/lightningd/bitcoind.c index 91711b62e..251e3f11a 100644 --- a/lightningd/bitcoind.c +++ b/lightningd/bitcoind.c @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/chaintopology.c b/lightningd/chaintopology.c index a998065ce..f965ae90c 100644 --- a/lightningd/chaintopology.c +++ b/lightningd/chaintopology.c @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/invoice.c b/lightningd/invoice.c index f462a4e29..365826521 100644 --- a/lightningd/invoice.c +++ b/lightningd/invoice.c @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -16,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/io_loop_with_timers.c b/lightningd/io_loop_with_timers.c index 246ca9979..9d52d6a4c 100644 --- a/lightningd/io_loop_with_timers.c +++ b/lightningd/io_loop_with_timers.c @@ -1,6 +1,7 @@ #include "config.h" #include #include +#include #include #include diff --git a/lightningd/jsonrpc.c b/lightningd/jsonrpc.c index e80d3edbc..367e3ab4e 100644 --- a/lightningd/jsonrpc.c +++ b/lightningd/jsonrpc.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/lightningd.c b/lightningd/lightningd.c index 113abcdb6..440a34694 100644 --- a/lightningd/lightningd.c +++ b/lightningd/lightningd.c @@ -37,6 +37,7 @@ */ #include #include +#include #include #include #include @@ -53,6 +54,7 @@ #include #include #include +#include #include #include diff --git a/lightningd/offer.c b/lightningd/offer.c index e8118a493..48739c85f 100644 --- a/lightningd/offer.c +++ b/lightningd/offer.c @@ -1,5 +1,7 @@ #include "config.h" #include +#include +#include #include #include #include diff --git a/lightningd/onchain_control.c b/lightningd/onchain_control.c index 17e32f204..fa8c2139c 100644 --- a/lightningd/onchain_control.c +++ b/lightningd/onchain_control.c @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/options.c b/lightningd/options.c index d14b917fd..0577b0a1e 100644 --- a/lightningd/options.c +++ b/lightningd/options.c @@ -1,6 +1,7 @@ #include "config.h" #include #include +#include #include #include #include diff --git a/lightningd/peer_htlcs.c b/lightningd/peer_htlcs.c index 20932c567..cf910dfa2 100644 --- a/lightningd/peer_htlcs.c +++ b/lightningd/peer_htlcs.c @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include diff --git a/lightningd/plugin_hook.c b/lightningd/plugin_hook.c index 29a5de2aa..4d9d82b43 100644 --- a/lightningd/plugin_hook.c +++ b/lightningd/plugin_hook.c @@ -1,6 +1,8 @@ #include "config.h" #include #include +#include +#include #include /* Struct containing all the information needed to deserialize and diff --git a/lightningd/subd.c b/lightningd/subd.c index 633976191..44bd85f3a 100644 --- a/lightningd/subd.c +++ b/lightningd/subd.c @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include diff --git a/wallet/Makefile b/wallet/Makefile index 2697febf1..91b10246b 100644 --- a/wallet/Makefile +++ b/wallet/Makefile @@ -7,37 +7,40 @@ WALLET_LIB_SRC := \ wallet/wallet.c \ wallet/walletrpc.c -WALLET_LIB_SRC_NOHDR := \ - wallet/reservation.c +WALLET_LIB_SRC_NOHDR := \ + wallet/reservation.c \ + wallet/db_queries_postgres.c \ + wallet/db_queries_sqlite3.c -WALLET_DB_DRIVERS := \ - wallet/db_postgres.c \ - wallet/db_sqlite3.c +WALLET_DB_QUERIES := \ + wallet/db_sqlite3_sqlgen.c \ + wallet/db_postgres_sqlgen.c -WALLET_SRC := $(WALLET_LIB_SRC) $(WALLET_LIB_SRC_NOHDR) $(WALLET_DB_DRIVERS) +WALLET_SRC := $(WALLET_LIB_SRC) $(WALLET_LIB_SRC_NOHDR) WALLET_HDRS := $(WALLET_LIB_SRC:.c=.h) WALLET_OBJS := $(WALLET_SRC:.c=.o) # Make sure these depend on everything. -ALL_C_SOURCES += $(WALLET_SRC) wallet/db_sqlite3_sqlgen.c wallet/db_postgres_sqlgen.c +ALL_C_SOURCES += $(WALLET_SRC) $(WALLET_DB_QUERIES) ALL_C_HEADERS += $(WALLET_HDRS) -# Each database driver depends on its rewritten statements. -wallet/db_sqlite3.o: wallet/db_sqlite3_sqlgen.c -wallet/db_postgres.o: wallet/db_postgres_sqlgen.c +# Query sets depend on the rewritten queries +wallet/db_queries_sqlite3.o: $(WALLET_DB_QUERIES) +wallet/db_queries_postgres.o: $(WALLET_DB_QUERIES) # The following files contain SQL-annotated statements that we need to extact -SQL_FILES := \ +WALLET_SQL_FILES := \ + $(DB_SQL_FILES) \ wallet/db.c \ wallet/invoices.c \ wallet/wallet.c \ wallet/test/run-db.c \ wallet/test/run-wallet.c \ -wallet/statements_gettextgen.po: $(SQL_FILES) $(FORCE) +wallet/statements_gettextgen.po: $(WALLET_SQL_FILES) $(FORCE) @if $(call SHA256STAMP_CHANGED); then \ - $(call VERBOSE,"xgettext $@",xgettext -kNAMED_SQL -kSQL --add-location --no-wrap --omit-header -o $@ $(SQL_FILES) && $(call SHA256STAMP,# ,)); \ + $(call VERBOSE,"xgettext $@",xgettext -kNAMED_SQL -kSQL --add-location --no-wrap --omit-header -o $@ $(WALLET_SQL_FILES) && $(call SHA256STAMP,# ,)); \ fi wallet/db_%_sqlgen.c: wallet/statements_gettextgen.po devtools/sql-rewrite.py $(FORCE) @@ -50,7 +53,6 @@ clean: wallet-maintainer-clean wallet-maintainer-clean: $(RM) wallet/statements.po $(RM) wallet/statements_gettextgen.po - $(RM) wallet/db_sqlite3_sqlgen.c - $(RM) wallet/db_postgres_sqlgen.c + $(RM) $(WALLET_DB_QUERIES) include wallet/test/Makefile diff --git a/wallet/db.c b/wallet/db.c index 23fa2866c..3225579eb 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -4,20 +4,19 @@ #include #include #include -#include #include -#include #include +#include +#include +#include +#include #include #include #include #include #include -#include #include -#define NSEC_IN_SEC 1000000000 - /* Small container for things that are needed by migrations. The * fields are guaranteed to be initialized and can be relied upon when * migrating. @@ -872,337 +871,6 @@ static struct migration dbmigrations[] = { {SQL("ALTER TABLE channel_funding_inflights ADD lease_fee BIGINT DEFAULT 0"), NULL}, }; -/* Leak tracking. */ -#if DEVELOPER -static void db_assert_no_outstanding_statements(struct db *db) -{ - struct db_stmt *stmt; - - stmt = list_top(&db->pending_statements, struct db_stmt, list); - if (stmt) - db_fatal("Unfinalized statement %s", stmt->location); -} -#else -static void db_assert_no_outstanding_statements(struct db *db) -{ -} -#endif - -static void db_stmt_free(struct db_stmt *stmt) -{ - if (!stmt->executed) - fatal("Freeing an un-executed statement from %s: %s", - stmt->location, stmt->query->query); -#if DEVELOPER - /* If they never got a db_step, we don't track */ - if (stmt->cols_used) { - for (size_t i = 0; i < stmt->query->num_colnames; i++) { - if (!stmt->query->colnames[i].sqlname) - continue; - if (!strset_get(stmt->cols_used, - stmt->query->colnames[i].sqlname)) { - log_broken(stmt->db->log, - "Never accessed column %s in query %s", - stmt->query->colnames[i].sqlname, - stmt->query->query); - } - } - strset_clear(stmt->cols_used); - } -#endif - if (stmt->inner_stmt) - stmt->db->config->stmt_free_fn(stmt); - assert(stmt->inner_stmt == NULL); -} - -/* Matches the hash function used in devtools/sql-rewrite.py */ -static u32 hash_djb2(const char *str) -{ - u32 hash = 5381; - for (size_t i = 0; str[i]; i++) - hash = ((hash << 5) + hash) ^ str[i]; - return hash; -} - -struct db_stmt *db_prepare_v2_(const char *location, struct db *db, - const char *query_id) -{ - struct db_stmt *stmt = tal(db, struct db_stmt); - size_t num_slots, pos; - - /* Normalize query_id paths, because unit tests are compiled with this - * prefix. */ - if (strncmp(query_id, "./", 2) == 0) - query_id += 2; - - if (!db->in_transaction) - db_fatal("Attempting to prepare a db_stmt outside of a " - "transaction: %s", location); - - /* Look up the query by its ID */ - pos = hash_djb2(query_id) % db->config->query_table_size; - for (;;) { - if (!db->config->query_table[pos].name) - fatal("Could not resolve query %s", query_id); - if (streq(query_id, db->config->query_table[pos].name)) { - stmt->query = &db->config->query_table[pos]; - break; - } - pos = (pos + 1) % db->config->query_table_size; - } - - num_slots = stmt->query->placeholders; - /* Allocate the slots for placeholders/bindings, zeroed next since - * that sets the type to DB_BINDING_UNINITIALIZED for later checks. */ - stmt->bindings = tal_arr(stmt, struct db_binding, num_slots); - for (size_t i=0; ibindings[i].type = DB_BINDING_UNINITIALIZED; - - stmt->location = location; - stmt->error = NULL; - stmt->db = db; - stmt->executed = false; - stmt->inner_stmt = NULL; - - tal_add_destructor(stmt, db_stmt_free); - - list_add(&db->pending_statements, &stmt->list); - -#if DEVELOPER - stmt->cols_used = NULL; -#endif /* DEVELOPER */ - - return stmt; -} - -#define db_prepare_v2(db,query) \ - db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) - -bool db_step(struct db_stmt *stmt) -{ - bool ret; - - assert(stmt->executed); - ret = stmt->db->config->step_fn(stmt); - -#if DEVELOPER - /* We only track cols_used if we return a result! */ - if (ret && !stmt->cols_used) { - stmt->cols_used = tal(stmt, struct strset); - strset_init(stmt->cols_used); - } -#endif - return ret; -} - -size_t db_count_changes(struct db_stmt *stmt) -{ - assert(stmt->executed); - return stmt->db->config->count_changes_fn(stmt); -} - -u64 db_last_insert_id_v2(struct db_stmt *stmt TAKES) -{ - u64 id; - assert(stmt->executed); - id = stmt->db->config->last_insert_id_fn(stmt); - - if (taken(stmt)) - tal_free(stmt); - - return id; -} - -static void destroy_db(struct db *db) -{ - db_assert_no_outstanding_statements(db); - - if (db->config->teardown_fn) - db->config->teardown_fn(db); -} - -/* We expect min changes (ie. BEGIN TRANSACTION): report if more. - * Optionally add "final" at the end (ie. COMMIT). */ -static void db_report_changes(struct db *db, const char *final, size_t min) -{ - assert(db->changes); - assert(tal_count(db->changes) >= min); - - /* Having changes implies that we have a dirty TX. The opposite is - * currently not true, e.g., the postgres driver doesn't record - * changes yet. */ - assert(!tal_count(db->changes) || db->dirty); - - if (tal_count(db->changes) > min) - plugin_hook_db_sync(db); - db->changes = tal_free(db->changes); -} - -static void db_prepare_for_changes(struct db *db) -{ - assert(!db->changes); - db->changes = tal_arr(db, const char *, 0); -} - -bool db_in_transaction(struct db *db) -{ - return db->in_transaction; -} - -void db_begin_transaction_(struct db *db, const char *location) -{ - bool ok; - if (db->in_transaction) - db_fatal("Already in transaction from %s", db->in_transaction); - - /* No writes yet. */ - db->dirty = false; - - db_prepare_for_changes(db); - ok = db->config->begin_tx_fn(db); - if (!ok) - db_fatal("Failed to start DB transaction: %s", db->error); - - db->in_transaction = location; -} - -/* By making the update conditional on the current value we expect we - * are implementing an optimistic lock: if the update results in - * changes on the DB we know that the data_version did not change - * under our feet and no other transaction ran in the meantime. - * - * Notice that this update effectively locks the row, so that other - * operations attempting to change this outside the transaction will - * wait for this transaction to complete. The external change will - * ultimately fail the changes test below, it'll just delay its abort - * until our transaction is committed. - */ -static void db_data_version_incr(struct db *db) -{ - struct db_stmt *stmt = db_prepare_v2( - db, SQL("UPDATE vars " - "SET intval = intval + 1 " - "WHERE name = 'data_version'" - " AND intval = ?")); - db_bind_int(stmt, 0, db->data_version); - db_exec_prepared_v2(stmt); - if (db_count_changes(stmt) != 1) - fatal("Optimistic lock on the database failed. There may be a " - "concurrent access to the database. Aborting since " - "concurrent access is unsafe."); - tal_free(stmt); - db->data_version++; -} - -void db_commit_transaction(struct db *db) -{ - bool ok; - assert(db->in_transaction); - db_assert_no_outstanding_statements(db); - - /* Increment before reporting changes to an eventual plugin. */ - if (db->dirty) - db_data_version_incr(db); - - db_report_changes(db, NULL, 0); - ok = db->config->commit_tx_fn(db); - - if (!ok) - db_fatal("Failed to commit DB transaction: %s", db->error); - - db->in_transaction = NULL; - db->dirty = false; -} - -static struct db_config *db_config_find(const char *dsn) -{ - size_t num_configs; - struct db_config **configs = autodata_get(db_backends, &num_configs); - const char *sep, *driver_name; - sep = strstr(dsn, "://"); - - if (!sep) - db_fatal("%s doesn't look like a valid data-source name (missing \"://\" separator.", dsn); - - driver_name = tal_strndup(tmpctx, dsn, sep - dsn); - - for (size_t i=0; iname)) { - tal_free(driver_name); - return configs[i]; - } - } - - tal_free(driver_name); - return NULL; -} - -/** - * db_open - Open or create a sqlite3 database - */ -static struct db *db_open(const tal_t *ctx, char *filename) -{ - struct db *db; - - db = tal(ctx, struct db); - db->filename = tal_strdup(db, filename); - list_head_init(&db->pending_statements); - if (!strstr(db->filename, "://")) - db_fatal("Could not extract driver name from \"%s\"", db->filename); - - db->config = db_config_find(db->filename); - if (!db->config) - db_fatal("Unable to find DB driver for %s", db->filename); - - tal_add_destructor(db, destroy_db); - db->in_transaction = NULL; - db->changes = NULL; - - /* This must be outside a transaction, so catch it */ - assert(!db->in_transaction); - - db_prepare_for_changes(db); - if (db->config->setup_fn && !db->config->setup_fn(db)) - fatal("Error calling DB setup: %s", db->error); - db_report_changes(db, NULL, 0); - - return db; -} - -/** - * db_get_version - Determine the current DB schema version - * - * Will attempt to determine the current schema version of the - * database @db by querying the `version` table. If the table does not - * exist it'll return schema version -1, so that migration 0 is - * applied, which should create the `version` table. - */ -static int db_get_version(struct db *db) -{ - int res = -1; - struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1")); - - /* - * Tentatively execute a query, but allow failures. Some databases - * like postgres will terminate the DB transaction if there is an - * error during the execution of a query, e.g., trying to access a - * table that doesn't exist yet, so we need to terminate and restart - * the DB transaction. - */ - if (!db_query_prepared(stmt)) { - db_commit_transaction(stmt->db); - db_begin_transaction(stmt->db); - tal_free(stmt); - return res; - } - - if (db_step(stmt)) - res = db_col_int(stmt, "version"); - - tal_free(stmt); - return res; -} - /** * db_migrate - Apply all remaining migrations from the current version */ @@ -1221,12 +889,12 @@ static bool db_migrate(struct lightningd *ld, struct db *db, available = ARRAY_SIZE(dbmigrations) - 1; if (current == -1) - log_info(db->log, "Creating database"); + log_info(ld->log, "Creating database"); else if (available < current) db_fatal("Refusing to migrate down from version %u to %u", current, available); else if (current != available) - log_info(db->log, "Updating database from version %u to %u", + log_info(ld->log, "Updating database from version %u to %u", current, available); while (current < available) { @@ -1260,24 +928,13 @@ static bool db_migrate(struct lightningd *ld, struct db *db, return current != orig; } -u32 db_data_version_get(struct db *db) -{ - struct db_stmt *stmt; - u32 version; - stmt = db_prepare_v2(db, SQL("SELECT intval FROM vars WHERE name = 'data_version'")); - db_query_prepared(stmt); - db_step(stmt); - version = db_col_int(stmt, "intval"); - tal_free(stmt); - return version; -} - struct db *db_setup(const tal_t *ctx, struct lightningd *ld, const struct ext_key *bip32_base) { struct db *db = db_open(ctx, ld->wallet_dsn); bool migrated; - db->log = new_log(db, ld->log_book, NULL, "database"); + + db->report_changes_fn = plugin_hook_db_sync; db_begin_transaction(db); @@ -1295,44 +952,6 @@ struct db *db_setup(const tal_t *ctx, struct lightningd *ld, return db; } -s64 db_get_intvar(struct db *db, char *varname, s64 defval) -{ - s64 res = defval; - struct db_stmt *stmt = db_prepare_v2( - db, SQL("SELECT intval FROM vars WHERE name= ? LIMIT 1")); - db_bind_text(stmt, 0, varname); - if (!db_query_prepared(stmt)) - goto done; - - if (db_step(stmt)) - res = db_col_int(stmt, "intval"); - -done: - tal_free(stmt); - return res; -} - -void db_set_intvar(struct db *db, char *varname, s64 val) -{ - size_t changes; - struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET intval=? WHERE name=?;")); - db_bind_int(stmt, 0, val); - db_bind_text(stmt, 1, varname); - if (!db_exec_prepared_v2(stmt)) - db_fatal("Error executing update: %s", stmt->error); - changes = db_count_changes(stmt); - tal_free(stmt); - - if (changes == 0) { - stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, intval) VALUES (?, ?);")); - db_bind_text(stmt, 0, varname); - db_bind_int(stmt, 1, val); - if (!db_exec_prepared_v2(stmt)) - db_fatal("Error executing insert: %s", stmt->error); - tal_free(stmt); - } -} - /* Will apply the current config fee settings to all channels */ static void migrate_pr2342_feerate_per_channel(struct lightningd *ld, struct db *db, const struct migration_context *mc) @@ -1770,617 +1389,3 @@ void migrate_last_tx_to_psbt(struct lightningd *ld, struct db *db, tal_free(stmt); } - -void db_bind_null(struct db_stmt *stmt, int pos) -{ - assert(pos < tal_count(stmt->bindings)); - stmt->bindings[pos].type = DB_BINDING_NULL; -} - -void db_bind_int(struct db_stmt *stmt, int pos, int val) -{ - assert(pos < tal_count(stmt->bindings)); - memcheck(&val, sizeof(val)); - stmt->bindings[pos].type = DB_BINDING_INT; - stmt->bindings[pos].v.i = val; -} - -void db_bind_u64(struct db_stmt *stmt, int pos, u64 val) -{ - memcheck(&val, sizeof(val)); - assert(pos < tal_count(stmt->bindings)); - stmt->bindings[pos].type = DB_BINDING_UINT64; - stmt->bindings[pos].v.u64 = val; -} - -void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len) -{ - assert(pos < tal_count(stmt->bindings)); - stmt->bindings[pos].type = DB_BINDING_BLOB; - stmt->bindings[pos].v.blob = memcheck(val, len); - stmt->bindings[pos].len = len; -} - -void db_bind_text(struct db_stmt *stmt, int pos, const char *val) -{ - assert(pos < tal_count(stmt->bindings)); - stmt->bindings[pos].type = DB_BINDING_TEXT; - stmt->bindings[pos].v.text = val; - stmt->bindings[pos].len = strlen(val); -} - -void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p) -{ - db_bind_blob(stmt, pos, p->r, sizeof(struct preimage)); -} - -void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s) -{ - db_bind_blob(stmt, pos, s->u.u8, sizeof(struct sha256)); -} - -void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s) -{ - db_bind_sha256(stmt, pos, &s->sha); -} - -void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s) -{ - assert(sizeof(s->data) == 32); - db_bind_blob(stmt, pos, s->data, sizeof(s->data)); -} - -void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s) -{ - size_t num = tal_count(s), elsize = sizeof(s->data); - u8 *ser = tal_arr(stmt, u8, num * elsize); - - for (size_t i = 0; i < num; ++i) - memcpy(ser + i * elsize, &s[i], elsize); - - db_bind_blob(stmt, col, ser, tal_count(ser)); -} - -void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t) -{ - db_bind_sha256d(stmt, pos, &t->shad); -} - -void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id) -{ - db_bind_blob(stmt, pos, id->id, sizeof(id->id)); -} - -void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *id) -{ - db_bind_blob(stmt, pos, id->k, sizeof(id->k)); -} - -void db_bind_node_id_arr(struct db_stmt *stmt, int col, - const struct node_id *ids) -{ - /* Copy into contiguous array: ARM will add padding to struct node_id! */ - size_t n = tal_count(ids); - u8 *arr = tal_arr(stmt, u8, n * sizeof(ids[0].k)); - - for (size_t i = 0; i < n; ++i) { - assert(node_id_valid(&ids[i])); - memcpy(arr + sizeof(ids[i].k) * i, - ids[i].k, - sizeof(ids[i].k)); - } - db_bind_blob(stmt, col, arr, tal_count(arr)); -} - -void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *pk) -{ - u8 *der = tal_arr(stmt, u8, PUBKEY_CMPR_LEN); - pubkey_to_der(der, pk); - db_bind_blob(stmt, pos, der, PUBKEY_CMPR_LEN); -} - -void db_bind_short_channel_id(struct db_stmt *stmt, int col, - const struct short_channel_id *id) -{ - char *ser = short_channel_id_to_str(stmt, id); - db_bind_text(stmt, col, ser); -} - -void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col, - const struct short_channel_id *id) -{ - u8 *ser = tal_arr(stmt, u8, 0); - size_t num = tal_count(id); - - for (size_t i = 0; i < num; ++i) - towire_short_channel_id(&ser, &id[i]); - - db_bind_talarr(stmt, col, ser); -} - -void db_bind_signature(struct db_stmt *stmt, int col, - const secp256k1_ecdsa_signature *sig) -{ - u8 *buf = tal_arr(stmt, u8, 64); - int ret = secp256k1_ecdsa_signature_serialize_compact(secp256k1_ctx, - buf, sig); - assert(ret == 1); - db_bind_blob(stmt, col, buf, 64); -} - -void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t) -{ - u64 timestamp = t.ts.tv_nsec + (((u64) t.ts.tv_sec) * ((u64) NSEC_IN_SEC)); - db_bind_u64(stmt, col, timestamp); -} - -void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx) -{ - u8 *ser = linearize_wtx(stmt, tx); - assert(ser); - db_bind_talarr(stmt, col, ser); -} - -void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt) -{ - size_t bytes_written; - const u8 *ser = psbt_get_bytes(stmt, psbt, &bytes_written); - assert(ser); - db_bind_blob(stmt, col, ser, bytes_written); -} - -void db_bind_amount_msat(struct db_stmt *stmt, int pos, - const struct amount_msat *msat) -{ - db_bind_u64(stmt, pos, msat->millisatoshis); /* Raw: low level function */ -} - -void db_bind_amount_sat(struct db_stmt *stmt, int pos, - const struct amount_sat *sat) -{ - db_bind_u64(stmt, pos, sat->satoshis); /* Raw: low level function */ -} - -void db_bind_json_escape(struct db_stmt *stmt, int pos, - const struct json_escape *esc) -{ - db_bind_text(stmt, pos, esc->s); -} - -void db_bind_onionreply(struct db_stmt *stmt, int pos, const struct onionreply *r) -{ - db_bind_talarr(stmt, pos, r->contents); -} - -void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr) -{ - if (!arr) - db_bind_null(stmt, col); - else - db_bind_blob(stmt, col, arr, tal_bytelen(arr)); -} - -/* Local helpers once you have column number */ -static bool db_column_is_null(struct db_stmt *stmt, int col) -{ - return stmt->db->config->column_is_null_fn(stmt, col); -} - -/* Returns true (and warns) if it's nul */ -static bool db_column_null_warn(struct db_stmt *stmt, const char *colname, - int col) -{ - if (!db_column_is_null(stmt, col)) - return false; - - log_broken(stmt->db->log, "Accessing a null column %s/%i in query %s", - colname, col, - stmt->query->query); - return true; -} - -static size_t db_column_bytes(struct db_stmt *stmt, int col) -{ - if (db_column_is_null(stmt, col)) - return 0; - return stmt->db->config->column_bytes_fn(stmt, col); -} - -static const void *db_column_blob(struct db_stmt *stmt, int col) -{ - if (db_column_is_null(stmt, col)) - return NULL; - return stmt->db->config->column_blob_fn(stmt, col); -} - - -u64 db_col_u64(struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_null_warn(stmt, colname, col)) - return 0; - - return stmt->db->config->column_u64_fn(stmt, col); -} - -int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_is_null(stmt, col)) - return def; - else - return stmt->db->config->column_int_fn(stmt, col); -} - -int db_col_int(struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_null_warn(stmt, colname, col)) - return 0; - - return stmt->db->config->column_int_fn(stmt, col); -} - -size_t db_col_bytes(struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_null_warn(stmt, colname, col)) - return 0; - - return stmt->db->config->column_bytes_fn(stmt, col); -} - -int db_col_is_null(struct db_stmt *stmt, const char *colname) -{ - return db_column_is_null(stmt, db_query_colnum(stmt, colname)); -} - -const void *db_col_blob(struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_null_warn(stmt, colname, col)) - return NULL; - - return stmt->db->config->column_blob_fn(stmt, col); -} - -char *db_col_strdup(const tal_t *ctx, - struct db_stmt *stmt, - const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_null_warn(stmt, colname, col)) - return NULL; - - return tal_strdup(ctx, (char *)stmt->db->config->column_text_fn(stmt, col)); -} - -void db_col_preimage(struct db_stmt *stmt, const char *colname, - struct preimage *preimage) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *raw; - size_t size = sizeof(struct preimage); - assert(db_column_bytes(stmt, col) == size); - raw = db_column_blob(stmt, col); - memcpy(preimage, raw, size); -} - -void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest) -{ - size_t col = db_query_colnum(stmt, colname); - - assert(db_column_bytes(stmt, col) == sizeof(dest->id)); - memcpy(dest->id, db_column_blob(stmt, col), sizeof(dest->id)); -} - -void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *dest) -{ - size_t col = db_query_colnum(stmt, colname); - - assert(db_column_bytes(stmt, col) == sizeof(dest->k)); - memcpy(dest->k, db_column_blob(stmt, col), sizeof(dest->k)); -} - -struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt, - const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - struct node_id *ret; - size_t n = db_column_bytes(stmt, col) / sizeof(ret->k); - const u8 *arr = db_column_blob(stmt, col); - assert(n * sizeof(ret->k) == (size_t)db_column_bytes(stmt, col)); - ret = tal_arr(ctx, struct node_id, n); - - db_column_null_warn(stmt, colname, col); - for (size_t i = 0; i < n; i++) - memcpy(ret[i].k, arr + i * sizeof(ret[i].k), sizeof(ret[i].k)); - - return ret; -} - -void db_col_pubkey(struct db_stmt *stmt, - const char *colname, - struct pubkey *dest) -{ - size_t col = db_query_colnum(stmt, colname); - bool ok; - assert(db_column_bytes(stmt, col) == PUBKEY_CMPR_LEN); - ok = pubkey_from_der(db_column_blob(stmt, col), PUBKEY_CMPR_LEN, dest); - assert(ok); -} - -/* Yes, we put this in as a string. Past mistakes; do not use! */ -bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname, - struct short_channel_id *dest) -{ - size_t col = db_query_colnum(stmt, colname); - const char *source = db_column_blob(stmt, col); - size_t sourcelen = db_column_bytes(stmt, col); - db_column_null_warn(stmt, colname, col); - return short_channel_id_from_str(source, sourcelen, dest); -} - -struct short_channel_id * -db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *ser; - size_t len; - struct short_channel_id *ret; - - db_column_null_warn(stmt, colname, col); - ser = db_column_blob(stmt, col); - len = db_column_bytes(stmt, col); - ret = tal_arr(ctx, struct short_channel_id, 0); - - while (len != 0) { - struct short_channel_id scid; - fromwire_short_channel_id(&ser, &len, &scid); - tal_arr_expand(&ret, scid); - } - - return ret; -} - -bool db_col_signature(struct db_stmt *stmt, const char *colname, - secp256k1_ecdsa_signature *sig) -{ - size_t col = db_query_colnum(stmt, colname); - assert(db_column_bytes(stmt, col) == 64); - return secp256k1_ecdsa_signature_parse_compact( - secp256k1_ctx, sig, db_column_blob(stmt, col)) == 1; -} - -struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname) -{ - struct timeabs t; - u64 timestamp = db_col_u64(stmt, colname); - t.ts.tv_sec = timestamp / NSEC_IN_SEC; - t.ts.tv_nsec = timestamp % NSEC_IN_SEC; - return t; - -} - -struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *src = db_column_blob(stmt, col); - size_t len = db_column_bytes(stmt, col); - - db_column_null_warn(stmt, colname, col); - return pull_bitcoin_tx(ctx, &src, &len); -} - -struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *src = db_column_blob(stmt, col); - size_t len = db_column_bytes(stmt, col); - - db_column_null_warn(stmt, colname, col); - return psbt_from_bytes(ctx, src, len); -} - -struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname) -{ - struct wally_psbt *psbt = db_col_psbt(ctx, stmt, colname); - if (!psbt) - return NULL; - return bitcoin_tx_with_psbt(ctx, psbt); -} - -void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname, - size_t bytes, const char *label, const char *caller) -{ - size_t col = db_query_colnum(stmt, colname); - size_t sourcelen; - void *p; - - if (db_column_is_null(stmt, col)) - return NULL; - - sourcelen = db_column_bytes(stmt, col); - - if (sourcelen % bytes != 0) - db_fatal("%s: %s/%zu column size for %zu not a multiple of %s (%zu)", - caller, colname, col, sourcelen, label, bytes); - - p = tal_arr_label(ctx, char, sourcelen, label); - memcpy(p, db_column_blob(stmt, col), sourcelen); - return p; -} - -void db_col_amount_msat_or_default(struct db_stmt *stmt, - const char *colname, - struct amount_msat *msat, - struct amount_msat def) -{ - size_t col = db_query_colnum(stmt, colname); - - if (db_column_is_null(stmt, col)) - *msat = def; - else - msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */ -} - -void db_col_amount_msat(struct db_stmt *stmt, const char *colname, - struct amount_msat *msat) -{ - msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */ -} - -void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat) -{ - sat->satoshis = db_col_u64(stmt, colname); /* Raw: low level function */ -} - -struct json_escape *db_col_json_escape(const tal_t *ctx, - struct db_stmt *stmt, const char *colname) -{ - size_t col = db_query_colnum(stmt, colname); - - return json_escape_string_(ctx, db_column_blob(stmt, col), - db_column_bytes(stmt, col)); -} - -void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *raw; - size_t size = sizeof(struct sha256); - assert(db_column_bytes(stmt, col) == size); - raw = db_column_blob(stmt, col); - memcpy(sha, raw, size); -} - -void db_col_sha256d(struct db_stmt *stmt, const char *colname, - struct sha256_double *shad) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *raw; - size_t size = sizeof(struct sha256_double); - assert(db_column_bytes(stmt, col) == size); - raw = db_column_blob(stmt, col); - memcpy(shad, raw, size); -} - -void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s) -{ - size_t col = db_query_colnum(stmt, colname); - const u8 *raw; - assert(db_column_bytes(stmt, col) == sizeof(struct secret)); - raw = db_column_blob(stmt, col); - memcpy(s, raw, sizeof(struct secret)); -} - -struct secret *db_col_secret_arr(const tal_t *ctx, - struct db_stmt *stmt, - const char *colname) -{ - return db_col_arr(ctx, stmt, colname, struct secret); -} - -void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t) -{ - db_col_sha256d(stmt, colname, &t->shad); -} - -struct onionreply *db_col_onionreply(const tal_t *ctx, - struct db_stmt *stmt, const char *colname) -{ - struct onionreply *r = tal(ctx, struct onionreply); - r->contents = db_col_arr(ctx, stmt, colname, u8); - return r; -} - -bool db_exec_prepared_v2(struct db_stmt *stmt TAKES) -{ - bool ret = stmt->db->config->exec_fn(stmt); - - /* If this was a write we need to bump the data_version upon commit. */ - stmt->db->dirty = stmt->db->dirty || !stmt->query->readonly; - - stmt->executed = true; - list_del_from(&stmt->db->pending_statements, &stmt->list); - - /* The driver itself doesn't call `fatal` since we want to override it - * for testing. Instead we check here that the error message is set if - * we report an error. */ - if (!ret) { - assert(stmt->error); - db_fatal("Error executing statement: %s", stmt->error); - } - - if (taken(stmt)) - tal_free(stmt); - - return ret; -} - -bool db_query_prepared(struct db_stmt *stmt) -{ - /* Make sure we don't accidentally execute a modifying query using a - * read-only path. */ - bool ret; - assert(stmt->query->readonly); - ret = stmt->db->config->query_fn(stmt); - stmt->executed = true; - list_del_from(&stmt->db->pending_statements, &stmt->list); - return ret; -} - -void db_changes_add(struct db_stmt *stmt, const char * expanded) -{ - struct db *db = stmt->db; - - if (stmt->query->readonly) { - return; - } - /* We get a "COMMIT;" after we've sent our changes. */ - if (!db->changes) { - assert(streq(expanded, "COMMIT;")); - return; - } - - tal_arr_expand(&db->changes, tal_strdup(db->changes, expanded)); -} - -const char **db_changes(struct db *db) -{ - return db->changes; -} - -size_t db_query_colnum(const struct db_stmt *stmt, - const char *colname) -{ - u32 col; - - assert(stmt->query->colnames != NULL); - - col = hash_djb2(colname) % stmt->query->num_colnames; - /* Will crash on NULL, which is the Right Thing */ - while (!streq(stmt->query->colnames[col].sqlname, - colname)) { - col = (col + 1) % stmt->query->num_colnames; - } - -#if DEVELOPER - strset_add(stmt->cols_used, colname); -#endif - - return stmt->query->colnames[col].val; -} - -void db_col_ignore(struct db_stmt *stmt, const char *colname) -{ -#if DEVELOPER - db_query_colnum(stmt, colname); -#endif -} diff --git a/wallet/db.h b/wallet/db.h index e51e75fd6..3d76d097a 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -2,45 +2,10 @@ #define LIGHTNING_WALLET_DB_H #include "config.h" -#include -#include -#include -#include -#include -#include - -struct channel_id; struct ext_key; struct lightningd; -struct log; -struct node_id; -struct onionreply; struct db_stmt; struct db; -struct wally_psbt; -struct wally_tx; - -/** - * Macro to annotate a named SQL query. - * - * This macro is used to annotate SQL queries that might need rewriting for - * different SQL dialects. It is used both as a marker for the query - * extraction logic in devtools/sql-rewrite.py to identify queries, as well as - * a way to swap out the query text with it's name so that the query execution - * engine can then look up the rewritten query using its name. - * - */ -#define NAMED_SQL(name,x) x - -/** - * Simple annotation macro that auto-generates names for NAMED_SQL - * - * If this macro is changed it is likely that the extraction logic in - * devtools/sql-rewrite.py needs to change as well, since they need to - * generate identical names to work correctly. - */ -#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x) - /** * db_setup - Open a the lightningd database and update the schema @@ -57,204 +22,4 @@ struct wally_tx; struct db *db_setup(const tal_t *ctx, struct lightningd *ld, const struct ext_key *bip32_base); -/** - * db_begin_transaction - Begin a transaction - * - * Begin a new DB transaction. fatal() on database error. - */ -#define db_begin_transaction(db) \ - db_begin_transaction_((db), __FILE__ ":" stringify(__LINE__)) -void db_begin_transaction_(struct db *db, const char *location); - -bool db_in_transaction(struct db *db); - -/** - * db_commit_transaction - Commit a running transaction - * - * Requires that we are currently in a transaction. fatal() if we - * fail to commit. - */ -void db_commit_transaction(struct db *db); - -/** - * db_set_intvar - Set an integer variable in the database - * - * Utility function to store generic integer values in the - * database. - */ -void db_set_intvar(struct db *db, char *varname, s64 val); - -/** - * db_get_intvar - Retrieve an integer variable from the database - * - * Either returns the value in the database, or @defval if - * the query failed or no such variable exists. - */ -s64 db_get_intvar(struct db *db, char *varname, s64 defval); - -void db_bind_null(struct db_stmt *stmt, int pos); -void db_bind_int(struct db_stmt *stmt, int pos, int val); -void db_bind_u64(struct db_stmt *stmt, int pos, u64 val); -void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len); -void db_bind_text(struct db_stmt *stmt, int pos, const char *val); -void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p); -void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s); -void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s); -void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s); -void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s); -void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t); -void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id); -void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *ni); -void db_bind_node_id_arr(struct db_stmt *stmt, int col, - const struct node_id *ids); -void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *p); -void db_bind_short_channel_id(struct db_stmt *stmt, int col, - const struct short_channel_id *id); -void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col, - const struct short_channel_id *id); -void db_bind_signature(struct db_stmt *stmt, int col, - const secp256k1_ecdsa_signature *sig); -void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t); -void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx); -void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt); -void db_bind_amount_msat(struct db_stmt *stmt, int pos, - const struct amount_msat *msat); -void db_bind_amount_sat(struct db_stmt *stmt, int pos, - const struct amount_sat *sat); -void db_bind_json_escape(struct db_stmt *stmt, int pos, - const struct json_escape *esc); -void db_bind_onionreply(struct db_stmt *stmt, int col, - const struct onionreply *r); -void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr); - -bool db_step(struct db_stmt *stmt); - -/* Modern variants: get columns by name from SELECT */ -/* Bridge function to get column number from SELECT - (must exist) */ -size_t db_query_colnum(const struct db_stmt *stmt, const char *colname); - -u64 db_col_u64(struct db_stmt *stmt, const char *colname); -int db_col_int(struct db_stmt *stmt, const char *colname); -size_t db_col_bytes(struct db_stmt *stmt, const char *colname); -int db_col_is_null(struct db_stmt *stmt, const char *colname); -const void* db_col_blob(struct db_stmt *stmt, const char *colname); -char *db_col_strdup(const tal_t *ctx, - struct db_stmt *stmt, - const char *colname); -void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage); -void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat); -void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat); -struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname); -void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha); -void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad); -void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s); -struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt, - const char *colname); -void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t); -void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest); -void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni); -struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt, - const char *colname); -void db_col_pubkey(struct db_stmt *stmt, const char *colname, - struct pubkey *p); -bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname, - struct short_channel_id *dest); -struct short_channel_id * -db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname); -bool db_col_signature(struct db_stmt *stmt, const char *colname, - secp256k1_ecdsa_signature *sig); -struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname); -struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname); -struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname); -struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname); - -struct onionreply *db_col_onionreply(const tal_t *ctx, - struct db_stmt *stmt, const char *colname); - -#define db_col_arr(ctx, stmt, colname, type) \ - ((type *)db_col_arr_((ctx), (stmt), (colname), \ - sizeof(type), TAL_LABEL(type, "[]"), \ - __func__)) -void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname, - size_t bytes, const char *label, const char *caller); - - -/* Some useful default variants */ -int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def); -void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname, - struct amount_msat *msat, - struct amount_msat def); - - -/* Explicitly ignore a column (so we don't complain you didn't use it!) */ -void db_col_ignore(struct db_stmt *stmt, const char *colname); - -/** - * db_exec_prepared -- Execute a prepared statement - * - * After preparing a statement using `db_prepare`, and after binding all - * non-null variables using the `db_bind_*` functions, it can be executed with - * this function. It is a small, transaction-aware, wrapper around `db_step`, - * that calls fatal() if the execution fails. This may take ownership of - * `stmt` if annotated with `take()`and will free it before returning. - * - * If you'd like to issue a query and access the rows returned by the query - * please use `db_query_prepared` instead, since this function will not expose - * returned results, and the `stmt` can only be used for calls to - * `db_count_changes` and `db_last_insert_id` after executing. - * - * @stmt: The prepared statement to execute - */ -bool db_exec_prepared_v2(struct db_stmt *stmt TAKES); - -/** - * db_query_prepared -- Execute a prepared query - * - * After preparing a query using `db_prepare`, and after binding all non-null - * variables using the `db_bind_*` functions, it can be executed with this - * function. This function must be called before calling `db_step` or any of - * the `db_col_*` column access functions. - * - * If you are not executing a read-only statement, please use - * `db_exec_prepared` instead. - * - * @stmt: The prepared statement to execute - */ -bool db_query_prepared(struct db_stmt *stmt); -size_t db_count_changes(struct db_stmt *stmt); -u64 db_last_insert_id_v2(struct db_stmt *stmt); - -/** - * db_prepare -- Prepare a DB query/command - * - * Create an instance of `struct db_stmt` that encapsulates a SQL query or command. - * - * @query MUST be wrapped in a `SQL()` macro call, since that allows the - * extraction and translation of the query into the target SQL dialect. - * - * It does not execute the query and does not check its validity, but - * allocates the placeholders detected in the query. The placeholders in the - * `stmt` can then be bound using the `db_bind_*` functions, and executed - * using `db_exec_prepared` for write-only statements and `db_query_prepared` - * for read-only statements. - * - * @db: Database to query/exec - * @query: The SQL statement to compile - */ -struct db_stmt *db_prepare_v2_(const char *location, struct db *db, - const char *query_id); - -/* TODO(cdecker) Remove the v2 suffix after finishing the migration */ -#define db_prepare_v2(db,query) \ - db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) - -/** - * Access pending changes that have been added to the current transaction. - */ -const char **db_changes(struct db *db); - -/* Get the current data version. */ -u32 db_data_version_get(struct db *db); - #endif /* LIGHTNING_WALLET_DB_H */ diff --git a/wallet/db_queries_postgres.c b/wallet/db_queries_postgres.c new file mode 100644 index 000000000..687f9e4d5 --- /dev/null +++ b/wallet/db_queries_postgres.c @@ -0,0 +1,13 @@ +#include "config.h" +#include "db_postgres_sqlgen.c" + +#if HAVE_POSTGRES + +struct db_query_set postgres_query_set = { + .name = "postgres", + .query_table = db_postgres_queries, + .query_table_size = ARRAY_SIZE(db_postgres_queries), +}; + +AUTODATA(db_queries, &postgres_query_set); +#endif /* HAVE_POSTGRES */ diff --git a/wallet/db_queries_sqlite3.c b/wallet/db_queries_sqlite3.c new file mode 100644 index 000000000..9411170fd --- /dev/null +++ b/wallet/db_queries_sqlite3.c @@ -0,0 +1,12 @@ +#include "config.h" +#include "db_sqlite3_sqlgen.c" + +#if HAVE_SQLITE3 +struct db_query_set sqlite3_query_set = { + .name = "sqlite3", + .query_table = db_sqlite3_queries, + .query_table_size = ARRAY_SIZE(db_sqlite3_queries), +}; + +AUTODATA(db_queries, &sqlite3_query_set); +#endif /* HAVE_SQLITE3 */ diff --git a/wallet/invoices.c b/wallet/invoices.c index 1d6e268ec..f8d46d8d6 100644 --- a/wallet/invoices.c +++ b/wallet/invoices.c @@ -1,7 +1,10 @@ #include "config.h" #include #include -#include +#include +#include +#include +#include #include #include diff --git a/wallet/test/Makefile b/wallet/test/Makefile index 46335b063..b8406b16c 100644 --- a/wallet/test/Makefile +++ b/wallet/test/Makefile @@ -27,7 +27,7 @@ WALLET_TEST_COMMON_OBJS := \ common/utils.o \ common/wireaddr.o \ common/version.o \ - wallet/db_sqlite3.o \ + wallet/db_queries_sqlite3.o \ wire/towire.o \ wire/fromwire.o diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index 637658526..966a0a9c8 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -6,6 +6,10 @@ static void db_log_(struct log *log UNUSED, enum log_level level UNUSED, const s } #define log_ db_log_ +#include "db/bindings.c" +#include "db/db_sqlite3.c" +#include "db/exec.c" +#include "db/utils.c" #include "wallet/db.c" #include "test_utils.h" @@ -86,6 +90,8 @@ static struct db *create_test_db(void) tal_free(filename); db = db_open(NULL, dsn); db->data_version = 0; + db->report_changes_fn = NULL; + tal_free(dsn); return db; } diff --git a/wallet/test/run-wallet.c b/wallet/test/run-wallet.c index fa12dcf62..7da3e5e41 100644 --- a/wallet/test/run-wallet.c +++ b/wallet/test/run-wallet.c @@ -3,7 +3,7 @@ #include "test_utils.h" #include -#include +#include static void db_log_(struct log *log UNUSED, enum log_level level UNUSED, const struct node_id *node_id UNUSED, bool call_notifier UNUSED, const char *fmt UNUSED, ...) { @@ -35,6 +35,10 @@ void db_fatal(const char *fmt, ...) #include "lightningd/peer_htlcs.c" #include "lightningd/channel.c" +#include "db/bindings.c" +#include "db/db_sqlite3.c" +#include "db/exec.c" +#include "db/utils.c" #include "wallet/db.c" #include @@ -908,6 +912,7 @@ static struct wallet *create_test_wallet(struct lightningd *ld, const tal_t *ctx dsn = tal_fmt(NULL, "sqlite3://%s", filename); w->db = db_open(w, dsn); + w->db->report_changes_fn = NULL; tal_free(dsn); tal_add_destructor2(w, cleanup_test_wallet, filename); diff --git a/wallet/wallet.c b/wallet/wallet.c index 3cbe424ec..16444bf98 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -8,13 +8,16 @@ #include #include #include +#include +#include +#include +#include #include #include #include #include #include #include -#include #include #include #include diff --git a/wallet/walletrpc.c b/wallet/walletrpc.c index 584dfa815..1d32d89de 100644 --- a/wallet/walletrpc.c +++ b/wallet/walletrpc.c @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include