database: pull out database code into a new module

We're going to reuse the database controllers for the accounting plugin
This commit is contained in:
niftynei 2022-01-03 12:45:35 -06:00 committed by Rusty Russell
parent 03c950bae8
commit ce12d2b8a9
37 changed files with 1478 additions and 1292 deletions

4
.gitattributes vendored
View file

@ -7,8 +7,8 @@ configure text eol=lf
# The following files are generated and should not be shown in the
# diffs by default on Github.
doc/lightning*.7 linguist-generated=true
wallet/db_*_sqlgen.c linguist-generated=true
wallet/statements_gettextgen.po linguist-generated=true
db_*_sqlgen.c linguist-generated=true
statements_gettextgen.po linguist-generated=true
*_wiregen.? linguist-generated=true
*_printgen.? linguist-generated=true

View file

@ -336,6 +336,7 @@ include external/Makefile
include bitcoin/Makefile
include common/Makefile
include wire/Makefile
include db/Makefile
include hsmd/Makefile
include gossipd/Makefile
include openingd/Makefile

23
db/Makefile Normal file
View file

@ -0,0 +1,23 @@
#! /usr/bin/make
DB_LIB_SRC := \
db/bindings.c \
db/exec.c \
db/utils.c
DB_DRIVERS := \
db/db_postgres.c \
db/db_sqlite3.c
DB_SRC := $(DB_LIB_SRC) $(DB_DRIVERS)
DB_HEADERS := $(DB_LIB_SRC:.c=.h) db/common.h
DB_OBJS := $(DB_LIB_SRC:.c=.o) $(DB_DRIVERS:.c=.o)
$(DB_OBJS): $(DB_HEADERS)
# Make sure these depend on everything.
ALL_C_SOURCES += $(DB_SRC)
ALL_C_HEADERS += $(DB_HEADERS)
# DB_SQL_FILES is the list of database files
DB_SQL_FILES := db/exec.c

554
db/bindings.c Normal file
View file

@ -0,0 +1,554 @@
#include "config.h"
#include <bitcoin/privkey.h>
#include <bitcoin/psbt.h>
#include <ccan/mem/mem.h>
#include <ccan/take/take.h>
#include <ccan/tal/str/str.h>
#include <ccan/tal/tal.h>
#include <common/channel_id.h>
#include <common/htlc_state.h>
#include <common/node_id.h>
#include <common/onionreply.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/utils.h>
#define NSEC_IN_SEC 1000000000
/* Local helpers once you have column number */
static bool db_column_is_null(struct db_stmt *stmt, int col)
{
return stmt->db->config->column_is_null_fn(stmt, col);
}
/* Returns true (and warns) if it's nul */
static bool db_column_null_warn(struct db_stmt *stmt, const char *colname,
int col)
{
if (!db_column_is_null(stmt, col))
return false;
/* FIXME: log broken? */
#if DEVELOPER
db_fatal("Accessing a null column %s/%i in query %s",
colname, col, stmt->query->query);
#endif /* DEVELOPER */
return true;
}
void db_bind_int(struct db_stmt *stmt, int pos, int val)
{
assert(pos < tal_count(stmt->bindings));
memcheck(&val, sizeof(val));
stmt->bindings[pos].type = DB_BINDING_INT;
stmt->bindings[pos].v.i = val;
}
int db_col_int(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_int_fn(stmt, col);
}
int db_col_is_null(struct db_stmt *stmt, const char *colname)
{
return db_column_is_null(stmt, db_query_colnum(stmt, colname));
}
void db_bind_null(struct db_stmt *stmt, int pos)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_NULL;
}
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val)
{
memcheck(&val, sizeof(val));
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_UINT64;
stmt->bindings[pos].v.u64 = val;
}
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_BLOB;
stmt->bindings[pos].v.blob = memcheck(val, len);
stmt->bindings[pos].len = len;
}
void db_bind_text(struct db_stmt *stmt, int pos, const char *val)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_TEXT;
stmt->bindings[pos].v.text = val;
stmt->bindings[pos].len = strlen(val);
}
void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p)
{
db_bind_blob(stmt, pos, p->r, sizeof(struct preimage));
}
void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s)
{
db_bind_blob(stmt, pos, s->u.u8, sizeof(struct sha256));
}
void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s)
{
db_bind_sha256(stmt, pos, &s->sha);
}
void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s)
{
assert(sizeof(s->data) == 32);
db_bind_blob(stmt, pos, s->data, sizeof(s->data));
}
void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s)
{
size_t num = tal_count(s), elsize = sizeof(s->data);
u8 *ser = tal_arr(stmt, u8, num * elsize);
for (size_t i = 0; i < num; ++i)
memcpy(ser + i * elsize, &s[i], elsize);
db_bind_blob(stmt, col, ser, tal_count(ser));
}
void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t)
{
db_bind_sha256d(stmt, pos, &t->shad);
}
void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id)
{
db_bind_blob(stmt, pos, id->id, sizeof(id->id));
}
void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *id)
{
db_bind_blob(stmt, pos, id->k, sizeof(id->k));
}
void db_bind_node_id_arr(struct db_stmt *stmt, int col,
const struct node_id *ids)
{
/* Copy into contiguous array: ARM will add padding to struct node_id! */
size_t n = tal_count(ids);
u8 *arr = tal_arr(stmt, u8, n * sizeof(ids[0].k));
for (size_t i = 0; i < n; ++i) {
assert(node_id_valid(&ids[i]));
memcpy(arr + sizeof(ids[i].k) * i,
ids[i].k,
sizeof(ids[i].k));
}
db_bind_blob(stmt, col, arr, tal_count(arr));
}
void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *pk)
{
u8 *der = tal_arr(stmt, u8, PUBKEY_CMPR_LEN);
pubkey_to_der(der, pk);
db_bind_blob(stmt, pos, der, PUBKEY_CMPR_LEN);
}
void db_bind_short_channel_id(struct db_stmt *stmt, int col,
const struct short_channel_id *id)
{
char *ser = short_channel_id_to_str(stmt, id);
db_bind_text(stmt, col, ser);
}
void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col,
const struct short_channel_id *id)
{
u8 *ser = tal_arr(stmt, u8, 0);
size_t num = tal_count(id);
for (size_t i = 0; i < num; ++i)
towire_short_channel_id(&ser, &id[i]);
db_bind_talarr(stmt, col, ser);
}
void db_bind_signature(struct db_stmt *stmt, int col,
const secp256k1_ecdsa_signature *sig)
{
u8 *buf = tal_arr(stmt, u8, 64);
int ret = secp256k1_ecdsa_signature_serialize_compact(secp256k1_ctx,
buf, sig);
assert(ret == 1);
db_bind_blob(stmt, col, buf, 64);
}
void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t)
{
u64 timestamp = t.ts.tv_nsec + (((u64) t.ts.tv_sec) * ((u64) NSEC_IN_SEC));
db_bind_u64(stmt, col, timestamp);
}
void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx)
{
u8 *ser = linearize_wtx(stmt, tx);
assert(ser);
db_bind_talarr(stmt, col, ser);
}
void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt)
{
size_t bytes_written;
const u8 *ser = psbt_get_bytes(stmt, psbt, &bytes_written);
assert(ser);
db_bind_blob(stmt, col, ser, bytes_written);
}
void db_bind_amount_msat(struct db_stmt *stmt, int pos,
const struct amount_msat *msat)
{
db_bind_u64(stmt, pos, msat->millisatoshis); /* Raw: low level function */
}
void db_bind_amount_sat(struct db_stmt *stmt, int pos,
const struct amount_sat *sat)
{
db_bind_u64(stmt, pos, sat->satoshis); /* Raw: low level function */
}
void db_bind_json_escape(struct db_stmt *stmt, int pos,
const struct json_escape *esc)
{
db_bind_text(stmt, pos, esc->s);
}
void db_bind_onionreply(struct db_stmt *stmt, int pos, const struct onionreply *r)
{
db_bind_talarr(stmt, pos, r->contents);
}
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr)
{
if (!arr)
db_bind_null(stmt, col);
else
db_bind_blob(stmt, col, arr, tal_bytelen(arr));
}
static size_t db_column_bytes(struct db_stmt *stmt, int col)
{
if (db_column_is_null(stmt, col))
return 0;
return stmt->db->config->column_bytes_fn(stmt, col);
}
static const void *db_column_blob(struct db_stmt *stmt, int col)
{
if (db_column_is_null(stmt, col))
return NULL;
return stmt->db->config->column_blob_fn(stmt, col);
}
u64 db_col_u64(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_u64_fn(stmt, col);
}
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
return def;
else
return stmt->db->config->column_int_fn(stmt, col);
}
size_t db_col_bytes(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_bytes_fn(stmt, col);
}
const void *db_col_blob(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return NULL;
return stmt->db->config->column_blob_fn(stmt, col);
}
char *db_col_strdup(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return NULL;
return tal_strdup(ctx, (char *)stmt->db->config->column_text_fn(stmt, col));
}
void db_col_preimage(struct db_stmt *stmt, const char *colname,
struct preimage *preimage)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct preimage);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(preimage, raw, size);
}
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->id));
memcpy(dest->id, db_column_blob(stmt, col), sizeof(dest->id));
}
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->k));
memcpy(dest->k, db_column_blob(stmt, col), sizeof(dest->k));
}
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct node_id *ret;
size_t n = db_column_bytes(stmt, col) / sizeof(ret->k);
const u8 *arr = db_column_blob(stmt, col);
assert(n * sizeof(ret->k) == (size_t)db_column_bytes(stmt, col));
ret = tal_arr(ctx, struct node_id, n);
db_column_null_warn(stmt, colname, col);
for (size_t i = 0; i < n; i++)
memcpy(ret[i].k, arr + i * sizeof(ret[i].k), sizeof(ret[i].k));
return ret;
}
void db_col_pubkey(struct db_stmt *stmt,
const char *colname,
struct pubkey *dest)
{
size_t col = db_query_colnum(stmt, colname);
bool ok;
assert(db_column_bytes(stmt, col) == PUBKEY_CMPR_LEN);
ok = pubkey_from_der(db_column_blob(stmt, col), PUBKEY_CMPR_LEN, dest);
assert(ok);
}
/* Yes, we put this in as a string. Past mistakes; do not use! */
bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
const char *source = db_column_blob(stmt, col);
size_t sourcelen = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return short_channel_id_from_str(source, sourcelen, dest);
}
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *ser;
size_t len;
struct short_channel_id *ret;
db_column_null_warn(stmt, colname, col);
ser = db_column_blob(stmt, col);
len = db_column_bytes(stmt, col);
ret = tal_arr(ctx, struct short_channel_id, 0);
while (len != 0) {
struct short_channel_id scid;
fromwire_short_channel_id(&ser, &len, &scid);
tal_arr_expand(&ret, scid);
}
return ret;
}
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == 64);
return secp256k1_ecdsa_signature_parse_compact(
secp256k1_ctx, sig, db_column_blob(stmt, col)) == 1;
}
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname)
{
struct timeabs t;
u64 timestamp = db_col_u64(stmt, colname);
t.ts.tv_sec = timestamp / NSEC_IN_SEC;
t.ts.tv_nsec = timestamp % NSEC_IN_SEC;
return t;
}
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return pull_bitcoin_tx(ctx, &src, &len);
}
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return psbt_from_bytes(ctx, src, len);
}
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
struct wally_psbt *psbt = db_col_psbt(ctx, stmt, colname);
if (!psbt)
return NULL;
return bitcoin_tx_with_psbt(ctx, psbt);
}
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller)
{
size_t col = db_query_colnum(stmt, colname);
size_t sourcelen;
void *p;
if (db_column_is_null(stmt, col))
return NULL;
sourcelen = db_column_bytes(stmt, col);
if (sourcelen % bytes != 0)
db_fatal("%s: %s/%zu column size for %zu not a multiple of %s (%zu)",
caller, colname, col, sourcelen, label, bytes);
p = tal_arr_label(ctx, char, sourcelen, label);
memcpy(p, db_column_blob(stmt, col), sourcelen);
return p;
}
void db_col_amount_msat_or_default(struct db_stmt *stmt,
const char *colname,
struct amount_msat *msat,
struct amount_msat def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
*msat = def;
else
msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
void db_col_amount_msat(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat)
{
msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat)
{
sat->satoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
struct json_escape *db_col_json_escape(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
return json_escape_string_(ctx, db_column_blob(stmt, col),
db_column_bytes(stmt, col));
}
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(sha, raw, size);
}
void db_col_sha256d(struct db_stmt *stmt, const char *colname,
struct sha256_double *shad)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256_double);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(shad, raw, size);
}
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
assert(db_column_bytes(stmt, col) == sizeof(struct secret));
raw = db_column_blob(stmt, col);
memcpy(s, raw, sizeof(struct secret));
}
struct secret *db_col_secret_arr(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname)
{
return db_col_arr(ctx, stmt, colname, struct secret);
}
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t)
{
db_col_sha256d(stmt, colname, &t->shad);
}
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
struct onionreply *r = tal(ctx, struct onionreply);
r->contents = db_col_arr(ctx, stmt, colname, u8);
return r;
}
void db_col_ignore(struct db_stmt *stmt, const char *colname)
{
#if DEVELOPER
db_query_colnum(stmt, colname);
#endif
}

118
db/bindings.h Normal file
View file

@ -0,0 +1,118 @@
#ifndef LIGHTNING_DB_BINDINGS_H
#define LIGHTNING_DB_BINDINGS_H
#include "config.h"
#include <bitcoin/preimage.h>
#include <bitcoin/pubkey.h>
#include <bitcoin/short_channel_id.h>
#include <bitcoin/tx.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/time/time.h>
struct channel_id;
struct db_stmt;
struct node_id;
struct onionreply;
struct wally_psbt;
struct wally_tx;
int db_col_is_null(struct db_stmt *stmt, const char *colname);
void db_bind_int(struct db_stmt *stmt, int pos, int val);
int db_col_int(struct db_stmt *stmt, const char *colname);
void db_bind_null(struct db_stmt *stmt, int pos);
void db_bind_int(struct db_stmt *stmt, int pos, int val);
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val);
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len);
void db_bind_text(struct db_stmt *stmt, int pos, const char *val);
void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p);
void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s);
void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s);
void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s);
void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s);
void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t);
void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id);
void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *ni);
void db_bind_node_id_arr(struct db_stmt *stmt, int col,
const struct node_id *ids);
void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *p);
void db_bind_short_channel_id(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_signature(struct db_stmt *stmt, int col,
const secp256k1_ecdsa_signature *sig);
void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t);
void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx);
void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt);
void db_bind_amount_msat(struct db_stmt *stmt, int pos,
const struct amount_msat *msat);
void db_bind_amount_sat(struct db_stmt *stmt, int pos,
const struct amount_sat *sat);
void db_bind_json_escape(struct db_stmt *stmt, int pos,
const struct json_escape *esc);
void db_bind_onionreply(struct db_stmt *stmt, int col,
const struct onionreply *r);
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr);
/* Modern variants: get columns by name from SELECT */
/* Bridge function to get column number from SELECT
(must exist) */
size_t db_query_colnum(const struct db_stmt *stmt, const char *colname);
u64 db_col_u64(struct db_stmt *stmt, const char *colname);
size_t db_col_bytes(struct db_stmt *stmt, const char *colname);
const void* db_col_blob(struct db_stmt *stmt, const char *colname);
char *db_col_strdup(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname);
void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage);
void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat);
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat);
struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha);
void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad);
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s);
struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t);
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest);
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni);
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_pubkey(struct db_stmt *stmt, const char *colname,
struct pubkey *p);
bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest);
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig);
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname);
#define db_col_arr(ctx, stmt, colname, type) \
((type *)db_col_arr_((ctx), (stmt), (colname), \
sizeof(type), TAL_LABEL(type, "[]"), \
__func__))
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller);
/* Some useful default variants */
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def);
void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat,
struct amount_msat def);
/* Explicitly ignore a column (so we don't complain you didn't use it!) */
void db_col_ignore(struct db_stmt *stmt, const char *colname);
#endif /* LIGHTNING_DB_BINDINGS_H */

View file

@ -1,10 +1,32 @@
#ifndef LIGHTNING_WALLET_DB_COMMON_H
#define LIGHTNING_WALLET_DB_COMMON_H
#ifndef LIGHTNING_DB_COMMON_H
#define LIGHTNING_DB_COMMON_H
#include "config.h"
#include <ccan/list/list.h>
#include <ccan/short_types/short_types.h>
#include <ccan/strset/strset.h>
#include <common/autodata.h>
#include <common/utils.h>
/**
* Macro to annotate a named SQL query.
*
* This macro is used to annotate SQL queries that might need rewriting for
* different SQL dialects. It is used both as a marker for the query
* extraction logic in devtools/sql-rewrite.py to identify queries, as well as
* a way to swap out the query text with it's name so that the query execution
* engine can then look up the rewritten query using its name.
*
*/
#define NAMED_SQL(name,x) x
/**
* Simple annotation macro that auto-generates names for NAMED_SQL
*
* If this macro is changed it is likely that the extraction logic in
* devtools/sql-rewrite.py needs to change as well, since they need to
* generate identical names to work correctly.
*/
#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x)
struct db {
char *filename;
@ -13,18 +35,18 @@ struct db {
/* DB-specific context */
void *conn;
/* The configuration, including translated queries for the current
* instance. */
/* The configuration for the current database driver */
const struct db_config *config;
/* Translated queries for the current database domain + driver */
const struct db_query_set *queries;
const char **changes;
/* List of statements that have been created but not executed yet. */
struct list_head pending_statements;
char *error;
struct log *log;
/* Were there any modifying statements in the current transaction?
* Used to bump the data_version in the DB.*/
bool dirty;
@ -32,6 +54,8 @@ struct db {
/* The current DB version we expect to update if changes are
* committed. */
u32 data_version;
void (*report_changes_fn)(struct db *);
};
struct db_query {
@ -102,10 +126,14 @@ struct db_stmt {
#endif
};
struct db_config {
struct db_query_set {
const char *name;
const struct db_query *query_table;
size_t query_table_size;
};
struct db_config {
const char *name;
/* Function used to execute a statement that doesn't result in a
* response. */
@ -155,20 +183,14 @@ struct db_config {
const char **colnames, size_t num_cols);
};
/* Provide a way for DB backends to register themselves */
AUTODATA_TYPE(db_backends, struct db_config);
void db_fatal(const char *fmt, ...)
PRINTF_FMT(1, 2);
/**
* Report a statement that changes the wallet
*
* Allows the DB driver to report an expanded statement during
* execution. Changes are queued up and reported to the `db_write` plugin hook
* upon committing.
*/
void db_changes_add(struct db_stmt *db_stmt, const char * expanded);
/* Provide a way for DB backends to register themselves */
AUTODATA_TYPE(db_backends, struct db_config);
/* Provide a way for DB query sets to register themselves */
AUTODATA_TYPE(db_queries, struct db_query_set);
/* devtools/sql-rewrite.py generates this simple htable */
struct sqlname_map {
@ -176,4 +198,4 @@ struct sqlname_map {
int val;
};
#endif /* LIGHTNING_WALLET_DB_COMMON_H */
#endif /* LIGHTNING_DB_COMMON_H */

View file

@ -1,9 +1,8 @@
#include "config.h"
#include <ccan/ccan/tal/str/str.h>
#include <ccan/endian/endian.h>
#include <lightningd/log.h>
#include <wallet/db_common.h>
#include <wallet/db_postgres_sqlgen.c>
#include <db/common.h>
#include <db/utils.h>
#if HAVE_POSTGRES
/* Indented in order not to trigger the inclusion order check */
@ -321,8 +320,6 @@ static bool db_postgres_delete_columns(struct db *db,
struct db_config db_postgres_config = {
.name = "postgres",
.query_table = db_postgres_queries,
.query_table_size = ARRAY_SIZE(db_postgres_queries),
.exec_fn = db_postgres_exec,
.query_fn = db_postgres_query,
.step_fn = db_postgres_step,
@ -348,4 +345,4 @@ struct db_config db_postgres_config = {
AUTODATA(db_backends, &db_postgres_config);
#endif
#endif /* HAVE_POSTGRES */

View file

@ -1,8 +1,8 @@
#include "config.h"
#include "db_sqlite3_sqlgen.c"
#include <ccan/ccan/tal/str/str.h>
#include <common/utils.h>
#include <lightningd/log.h>
#include <db/common.h>
#include <db/utils.h>
#if HAVE_SQLITE3
#include <sqlite3.h>
@ -682,8 +682,6 @@ static bool db_sqlite3_delete_columns(struct db *db,
struct db_config db_sqlite3_config = {
.name = "sqlite3",
.query_table = db_sqlite3_queries,
.query_table_size = ARRAY_SIZE(db_sqlite3_queries),
.exec_fn = &db_sqlite3_exec,
.query_fn = &db_sqlite3_query,
.step_fn = &db_sqlite3_step,
@ -710,4 +708,4 @@ struct db_config db_sqlite3_config = {
AUTODATA(db_backends, &db_sqlite3_config);
#endif
#endif /* HAVE_SQLITE3 */

162
db/exec.c Normal file
View file

@ -0,0 +1,162 @@
#include "config.h"
#include <ccan/tal/tal.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/exec.h>
#include <db/utils.h>
/**
* db_get_version - Determine the current DB schema version
*
* Will attempt to determine the current schema version of the
* database @db by querying the `version` table. If the table does not
* exist it'll return schema version -1, so that migration 0 is
* applied, which should create the `version` table.
*/
int db_get_version(struct db *db)
{
int res = -1;
struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1"));
/*
* Tentatively execute a query, but allow failures. Some databases
* like postgres will terminate the DB transaction if there is an
* error during the execution of a query, e.g., trying to access a
* table that doesn't exist yet, so we need to terminate and restart
* the DB transaction.
*/
if (!db_query_prepared(stmt)) {
db_commit_transaction(stmt->db);
db_begin_transaction(stmt->db);
tal_free(stmt);
return res;
}
if (db_step(stmt))
res = db_col_int(stmt, "version");
tal_free(stmt);
return res;
}
u32 db_data_version_get(struct db *db)
{
struct db_stmt *stmt;
u32 version;
stmt = db_prepare_v2(db, SQL("SELECT intval FROM vars WHERE name = 'data_version'"));
db_query_prepared(stmt);
db_step(stmt);
version = db_col_int(stmt, "intval");
tal_free(stmt);
return version;
}
void db_set_intvar(struct db *db, char *varname, s64 val)
{
size_t changes;
struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET intval=? WHERE name=?;"));
db_bind_int(stmt, 0, val);
db_bind_text(stmt, 1, varname);
if (!db_exec_prepared_v2(stmt))
db_fatal("Error executing update: %s", stmt->error);
changes = db_count_changes(stmt);
tal_free(stmt);
if (changes == 0) {
stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, intval) VALUES (?, ?);"));
db_bind_text(stmt, 0, varname);
db_bind_int(stmt, 1, val);
if (!db_exec_prepared_v2(stmt))
db_fatal("Error executing insert: %s", stmt->error);
tal_free(stmt);
}
}
s64 db_get_intvar(struct db *db, char *varname, s64 defval)
{
s64 res = defval;
struct db_stmt *stmt = db_prepare_v2(
db, SQL("SELECT intval FROM vars WHERE name= ? LIMIT 1"));
db_bind_text(stmt, 0, varname);
if (!db_query_prepared(stmt))
goto done;
if (db_step(stmt))
res = db_col_int(stmt, "intval");
done:
tal_free(stmt);
return res;
}
/* Leak tracking. */
/* By making the update conditional on the current value we expect we
* are implementing an optimistic lock: if the update results in
* changes on the DB we know that the data_version did not change
* under our feet and no other transaction ran in the meantime.
*
* Notice that this update effectively locks the row, so that other
* operations attempting to change this outside the transaction will
* wait for this transaction to complete. The external change will
* ultimately fail the changes test below, it'll just delay its abort
* until our transaction is committed.
*/
static void db_data_version_incr(struct db *db)
{
struct db_stmt *stmt = db_prepare_v2(
db, SQL("UPDATE vars "
"SET intval = intval + 1 "
"WHERE name = 'data_version'"
" AND intval = ?"));
db_bind_int(stmt, 0, db->data_version);
db_exec_prepared_v2(stmt);
if (db_count_changes(stmt) != 1)
db_fatal("Optimistic lock on the database failed. There"
" may be a concurrent access to the database."
" Aborting since concurrent access is unsafe.");
tal_free(stmt);
db->data_version++;
}
void db_begin_transaction_(struct db *db, const char *location)
{
bool ok;
if (db->in_transaction)
db_fatal("Already in transaction from %s", db->in_transaction);
/* No writes yet. */
db->dirty = false;
db_prepare_for_changes(db);
ok = db->config->begin_tx_fn(db);
if (!ok)
db_fatal("Failed to start DB transaction: %s", db->error);
db->in_transaction = location;
}
bool db_in_transaction(struct db *db)
{
return db->in_transaction;
}
void db_commit_transaction(struct db *db)
{
bool ok;
assert(db->in_transaction);
db_assert_no_outstanding_statements(db);
/* Increment before reporting changes to an eventual plugin. */
if (db->dirty)
db_data_version_incr(db);
db_report_changes(db, NULL, 0);
ok = db->config->commit_tx_fn(db);
if (!ok)
db_fatal("Failed to commit DB transaction: %s", db->error);
db->in_transaction = NULL;
db->dirty = false;
}

52
db/exec.h Normal file
View file

@ -0,0 +1,52 @@
#ifndef LIGHTNING_DB_EXEC_H
#define LIGHTNING_DB_EXEC_H
#include "config.h"
#include <ccan/short_types/short_types.h>
#include <ccan/take/take.h>
struct db;
/**
* db_set_intvar - Set an integer variable in the database
*
* Utility function to store generic integer values in the
* database.
*/
void db_set_intvar(struct db *db, char *varname, s64 val);
/**
* db_get_intvar - Retrieve an integer variable from the database
*
* Either returns the value in the database, or @defval if
* the query failed or no such variable exists.
*/
s64 db_get_intvar(struct db *db, char *varname, s64 defval);
/* Get the current data version (entries). */
u32 db_data_version_get(struct db *db);
/* Get the current database version (migrations). */
int db_get_version(struct db *db);
/**
* db_begin_transaction - Begin a transaction
*
* Begin a new DB transaction. fatal() on database error.
*/
#define db_begin_transaction(db) \
db_begin_transaction_((db), __FILE__ ":" stringify(__LINE__))
void db_begin_transaction_(struct db *db, const char *location);
bool db_in_transaction(struct db *db);
/**
* db_commit_transaction - Commit a running transaction
*
* Requires that we are currently in a transaction. fatal() if we
* fail to commit.
*/
void db_commit_transaction(struct db *db);
#endif /* LIGHTNING_DB_EXEC_H */

324
db/utils.c Normal file
View file

@ -0,0 +1,324 @@
#include "config.h"
#include <ccan/tal/str/str.h>
#include <common/utils.h>
#include <db/common.h>
#include <db/utils.h>
/* Matches the hash function used in devtools/sql-rewrite.py */
static u32 hash_djb2(const char *str)
{
u32 hash = 5381;
for (size_t i = 0; str[i]; i++)
hash = ((hash << 5) + hash) ^ str[i];
return hash;
}
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname)
{
u32 col;
assert(stmt->query->colnames != NULL);
col = hash_djb2(colname) % stmt->query->num_colnames;
/* Will crash on NULL, which is the Right Thing */
while (!streq(stmt->query->colnames[col].sqlname,
colname)) {
col = (col + 1) % stmt->query->num_colnames;
}
#if DEVELOPER
strset_add(stmt->cols_used, colname);
#endif
return stmt->query->colnames[col].val;
}
static void db_stmt_free(struct db_stmt *stmt)
{
if (!stmt->executed)
db_fatal("Freeing an un-executed statement from %s: %s",
stmt->location, stmt->query->query);
#if DEVELOPER
/* If they never got a db_step, we don't track */
if (stmt->cols_used) {
for (size_t i = 0; i < stmt->query->num_colnames; i++) {
if (!stmt->query->colnames[i].sqlname)
continue;
if (!strset_get(stmt->cols_used,
stmt->query->colnames[i].sqlname)) {
db_fatal("Never accessed column %s in query %s",
stmt->query->colnames[i].sqlname,
stmt->query->query);
}
}
strset_clear(stmt->cols_used);
}
#endif
if (stmt->inner_stmt)
stmt->db->config->stmt_free_fn(stmt);
assert(stmt->inner_stmt == NULL);
}
struct db_stmt *db_prepare_v2_(const char *location, struct db *db,
const char *query_id)
{
struct db_stmt *stmt = tal(db, struct db_stmt);
size_t num_slots, pos;
/* Normalize query_id paths, because unit tests are compiled with this
* prefix. */
if (strncmp(query_id, "./", 2) == 0)
query_id += 2;
if (!db->in_transaction)
db_fatal("Attempting to prepare a db_stmt outside of a "
"transaction: %s", location);
/* Look up the query by its ID */
pos = hash_djb2(query_id) % db->queries->query_table_size;
for (;;) {
if (!db->queries->query_table[pos].name)
db_fatal("Could not resolve query %s", query_id);
if (streq(query_id, db->queries->query_table[pos].name)) {
stmt->query = &db->queries->query_table[pos];
break;
}
pos = (pos + 1) % db->queries->query_table_size;
}
num_slots = stmt->query->placeholders;
/* Allocate the slots for placeholders/bindings, zeroed next since
* that sets the type to DB_BINDING_UNINITIALIZED for later checks. */
stmt->bindings = tal_arr(stmt, struct db_binding, num_slots);
for (size_t i=0; i<num_slots; i++)
stmt->bindings[i].type = DB_BINDING_UNINITIALIZED;
stmt->location = location;
stmt->error = NULL;
stmt->db = db;
stmt->executed = false;
stmt->inner_stmt = NULL;
tal_add_destructor(stmt, db_stmt_free);
list_add(&db->pending_statements, &stmt->list);
#if DEVELOPER
stmt->cols_used = NULL;
#endif /* DEVELOPER */
return stmt;
}
#define db_prepare_v2(db,query) \
db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query)
bool db_query_prepared(struct db_stmt *stmt)
{
/* Make sure we don't accidentally execute a modifying query using a
* read-only path. */
bool ret;
assert(stmt->query->readonly);
ret = stmt->db->config->query_fn(stmt);
stmt->executed = true;
list_del_from(&stmt->db->pending_statements, &stmt->list);
return ret;
}
bool db_step(struct db_stmt *stmt)
{
bool ret;
assert(stmt->executed);
ret = stmt->db->config->step_fn(stmt);
#if DEVELOPER
/* We only track cols_used if we return a result! */
if (ret && !stmt->cols_used) {
stmt->cols_used = tal(stmt, struct strset);
strset_init(stmt->cols_used);
}
#endif
return ret;
}
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES)
{
bool ret = stmt->db->config->exec_fn(stmt);
/* If this was a write we need to bump the data_version upon commit. */
stmt->db->dirty = stmt->db->dirty || !stmt->query->readonly;
stmt->executed = true;
list_del_from(&stmt->db->pending_statements, &stmt->list);
/* The driver itself doesn't call `fatal` since we want to override it
* for testing. Instead we check here that the error message is set if
* we report an error. */
if (!ret) {
assert(stmt->error);
db_fatal("Error executing statement: %s", stmt->error);
}
if (taken(stmt))
tal_free(stmt);
return ret;
}
size_t db_count_changes(struct db_stmt *stmt)
{
assert(stmt->executed);
return stmt->db->config->count_changes_fn(stmt);
}
const char **db_changes(struct db *db)
{
return db->changes;
}
u64 db_last_insert_id_v2(struct db_stmt *stmt TAKES)
{
u64 id;
assert(stmt->executed);
id = stmt->db->config->last_insert_id_fn(stmt);
if (taken(stmt))
tal_free(stmt);
return id;
}
/* We expect min changes (ie. BEGIN TRANSACTION): report if more.
* Optionally add "final" at the end (ie. COMMIT). */
void db_report_changes(struct db *db, const char *final, size_t min)
{
assert(db->changes);
assert(tal_count(db->changes) >= min);
/* Having changes implies that we have a dirty TX. The opposite is
* currently not true, e.g., the postgres driver doesn't record
* changes yet. */
assert(!tal_count(db->changes) || db->dirty);
if (tal_count(db->changes) > min && db->report_changes_fn)
db->report_changes_fn(db);
db->changes = tal_free(db->changes);
}
void db_changes_add(struct db_stmt *stmt, const char * expanded)
{
struct db *db = stmt->db;
if (stmt->query->readonly) {
return;
}
/* We get a "COMMIT;" after we've sent our changes. */
if (!db->changes) {
assert(streq(expanded, "COMMIT;"));
return;
}
tal_arr_expand(&db->changes, tal_strdup(db->changes, expanded));
}
#if DEVELOPER
void db_assert_no_outstanding_statements(struct db *db)
{
struct db_stmt *stmt;
stmt = list_top(&db->pending_statements, struct db_stmt, list);
if (stmt)
db_fatal("Unfinalized statement %s", stmt->location);
}
#else
void db_assert_no_outstanding_statements(struct db *db)
{
}
#endif
static void destroy_db(struct db *db)
{
db_assert_no_outstanding_statements(db);
if (db->config->teardown_fn)
db->config->teardown_fn(db);
}
static struct db_config *db_config_find(const char *dsn)
{
size_t num_configs;
struct db_config **configs = autodata_get(db_backends, &num_configs);
const char *sep, *driver_name;
sep = strstr(dsn, "://");
if (!sep)
db_fatal("%s doesn't look like a valid data-source name (missing \"://\" separator.", dsn);
driver_name = tal_strndup(tmpctx, dsn, sep - dsn);
for (size_t i=0; i<num_configs; i++) {
if (streq(driver_name, configs[i]->name)) {
tal_free(driver_name);
return configs[i];
}
}
tal_free(driver_name);
return NULL;
}
static struct db_query_set *db_queries_find(const struct db_config *config)
{
size_t num_queries;
struct db_query_set **queries = autodata_get(db_queries, &num_queries);
for (size_t i = 0; i < num_queries; i++) {
if (streq(config->name, queries[i]->name)) {
return queries[i];
}
}
return NULL;
}
void db_prepare_for_changes(struct db *db)
{
assert(!db->changes);
db->changes = tal_arr(db, const char *, 0);
}
struct db *db_open(const tal_t *ctx, char *filename)
{
struct db *db;
db = tal(ctx, struct db);
db->filename = tal_strdup(db, filename);
list_head_init(&db->pending_statements);
if (!strstr(db->filename, "://"))
db_fatal("Could not extract driver name from \"%s\"", db->filename);
db->config = db_config_find(db->filename);
if (!db->config)
db_fatal("Unable to find DB driver for %s", db->filename);
db->queries = db_queries_find(db->config);
if (!db->queries)
db_fatal("Unable to find DB queries for %s", db->config->name);
tal_add_destructor(db, destroy_db);
db->in_transaction = NULL;
db->changes = NULL;
/* This must be outside a transaction, so catch it */
assert(!db->in_transaction);
db_prepare_for_changes(db);
if (db->config->setup_fn && !db->config->setup_fn(db))
db_fatal("Error calling DB setup: %s", db->error);
db_report_changes(db, NULL, 0);
return db;
}

100
db/utils.h Normal file
View file

@ -0,0 +1,100 @@
#ifndef LIGHTNING_DB_UTILS_H
#define LIGHTNING_DB_UTILS_H
#include "config.h"
#include <ccan/take/take.h>
#include <ccan/tal/tal.h>
struct db;
struct db_stmt;
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname);
/* Return next 'row' result of statement */
bool db_step(struct db_stmt *stmt);
/* TODO(cdecker) Remove the v2 suffix after finishing the migration */
#define db_prepare_v2(db,query) \
db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query)
/**
* db_exec_prepared -- Execute a prepared statement
*
* After preparing a statement using `db_prepare`, and after binding all
* non-null variables using the `db_bind_*` functions, it can be executed with
* this function. It is a small, transaction-aware, wrapper around `db_step`,
* that calls fatal() if the execution fails. This may take ownership of
* `stmt` if annotated with `take()`and will free it before returning.
*
* If you'd like to issue a query and access the rows returned by the query
* please use `db_query_prepared` instead, since this function will not expose
* returned results, and the `stmt` can only be used for calls to
* `db_count_changes` and `db_last_insert_id` after executing.
*
* @stmt: The prepared statement to execute
*/
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES);
/**
* db_query_prepared -- Execute a prepared query
*
* After preparing a query using `db_prepare`, and after binding all non-null
* variables using the `db_bind_*` functions, it can be executed with this
* function. This function must be called before calling `db_step` or any of
* the `db_col_*` column access functions.
*
* If you are not executing a read-only statement, please use
* `db_exec_prepared` instead.
*
* @stmt: The prepared statement to execute
*/
bool db_query_prepared(struct db_stmt *stmt);
size_t db_count_changes(struct db_stmt *stmt);
void db_report_changes(struct db *db, const char *final, size_t min);
void db_prepare_for_changes(struct db *db);
u64 db_last_insert_id_v2(struct db_stmt *stmt);
/**
* db_prepare -- Prepare a DB query/command
*
* Create an instance of `struct db_stmt` that encapsulates a SQL query or command.
*
* @query MUST be wrapped in a `SQL()` macro call, since that allows the
* extraction and translation of the query into the target SQL dialect.
*
* It does not execute the query and does not check its validity, but
* allocates the placeholders detected in the query. The placeholders in the
* `stmt` can then be bound using the `db_bind_*` functions, and executed
* using `db_exec_prepared` for write-only statements and `db_query_prepared`
* for read-only statements.
*
* @db: Database to query/exec
* @query: The SQL statement to compile
*/
struct db_stmt *db_prepare_v2_(const char *location, struct db *db,
const char *query_id);
/**
* db_open - Open or create a database
*/
struct db *db_open(const tal_t *ctx, char *filename);
/**
* Report a statement that changes the wallet
*
* Allows the DB driver to report an expanded statement during
* execution. Changes are queued up and reported to the `db_write` plugin hook
* upon committing.
*/
void db_changes_add(struct db_stmt *db_stmt, const char * expanded);
void db_assert_no_outstanding_statements(struct db *db);
/**
* Access pending changes that have been added to the current transaction.
*/
const char **db_changes(struct db *db);
#endif /* LIGHTNING_DB_UTILS_H */

View file

@ -126,7 +126,8 @@ template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()}
#include <config.h>
#include <ccan/array_size/array_size.h>
#include <wallet/db_common.h>
#include <db/common.h>
#include <db/utils.h>
#if HAVE_${f.upper()}
% for colname, table in colhtables.items():

View file

@ -130,9 +130,11 @@ LIGHTNINGD_COMMON_OBJS := \
common/wallet.o \
common/wire_error.o \
common/wireaddr.o \
db/bindings.o \
db/exec.o \
$(LIGHTNINGD_OBJS): $(LIGHTNINGD_HDRS)
$(WALLET_OBJS): $(LIGHTNINGD_HDRS)
$(WALLET_OBJS): $(LIGHTNINGD_HDRS) $(DB_HEADERS)
# Only the plugin component needs to depend on this header.
lightningd/plugin.o: plugins/list_of_builtin_plugins_gen.h
@ -140,6 +142,6 @@ lightningd/plugin.o: plugins/list_of_builtin_plugins_gen.h
lightningd/channel_state_names_gen.h: lightningd/channel_state.h ccan/ccan/cdump/tools/cdump-enumstr
ccan/ccan/cdump/tools/cdump-enumstr lightningd/channel_state.h > $@
lightningd/lightningd: $(LIGHTNINGD_OBJS) $(WALLET_OBJS) $(LIGHTNINGD_COMMON_OBJS) $(BITCOIN_OBJS) $(WIRE_OBJS) $(WIRE_ONION_OBJS) $(LIGHTNINGD_CONTROL_OBJS) $(HSMD_CLIENT_OBJS)
lightningd/lightningd: $(LIGHTNINGD_OBJS) $(WALLET_OBJS) $(LIGHTNINGD_COMMON_OBJS) $(BITCOIN_OBJS) $(WIRE_OBJS) $(WIRE_ONION_OBJS) $(LIGHTNINGD_CONTROL_OBJS) $(HSMD_CLIENT_OBJS) $(DB_OBJS)
include lightningd/test/Makefile

View file

@ -14,6 +14,7 @@
#include <ccan/tal/str/str.h>
#include <common/json_helpers.h>
#include <common/memleak.h>
#include <db/exec.h>
#include <lightningd/bitcoind.h>
#include <lightningd/chaintopology.h>
#include <lightningd/io_loop_with_timers.h>

View file

@ -11,6 +11,7 @@
#include <common/param.h>
#include <common/timeout.h>
#include <common/type_to_string.h>
#include <db/exec.h>
#include <lightningd/bitcoind.h>
#include <lightningd/chaintopology.h>
#include <lightningd/channel.h>

View file

@ -2,6 +2,7 @@
#include <ccan/array_size/array_size.h>
#include <ccan/asort/asort.h>
#include <ccan/cast/cast.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/str/hex/hex.h>
#include <ccan/tal/str/str.h>
#include <common/bolt11_json.h>
@ -16,6 +17,7 @@
#include <common/random_select.h>
#include <common/timeout.h>
#include <common/type_to_string.h>
#include <db/exec.h>
#include <errno.h>
#include <hsmd/hsmd_wiregen.h>
#include <lightningd/channel.h>

View file

@ -1,6 +1,7 @@
#include "config.h"
#include <ccan/io/io.h>
#include <common/timeout.h>
#include <db/exec.h>
#include <lightningd/io_loop_with_timers.h>
#include <lightningd/lightningd.h>

View file

@ -17,6 +17,7 @@
#include <ccan/asort/asort.h>
#include <ccan/err/err.h>
#include <ccan/io/io.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/json_out/json_out.h>
#include <ccan/tal/str/str.h>
#include <common/configdir.h>
@ -26,6 +27,7 @@
#include <common/memleak.h>
#include <common/param.h>
#include <common/timeout.h>
#include <db/exec.h>
#include <fcntl.h>
#include <lightningd/jsonrpc.h>
#include <lightningd/plugin_hook.h>

View file

@ -37,6 +37,7 @@
*/
#include <ccan/array_size/array_size.h>
#include <ccan/closefrom/closefrom.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/opt/opt.h>
#include <ccan/pipecmd/pipecmd.h>
#include <ccan/read_write_all/read_write_all.h>
@ -53,6 +54,7 @@
#include <common/timeout.h>
#include <common/type_to_string.h>
#include <common/version.h>
#include <db/exec.h>
#include <errno.h>
#include <fcntl.h>

View file

@ -1,5 +1,7 @@
#include "config.h"
#include <ccan/cast/cast.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/take/take.h>
#include <common/bolt12_merkle.h>
#include <common/json_command.h>
#include <common/json_helpers.h>

View file

@ -2,6 +2,7 @@
#include <bitcoin/feerate.h>
#include <common/key_derive.h>
#include <common/type_to_string.h>
#include <db/exec.h>
#include <errno.h>
#include <hsmd/capabilities.h>
#include <inttypes.h>

View file

@ -1,6 +1,7 @@
#include "config.h"
#include <ccan/array_size/array_size.h>
#include <ccan/err/err.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/mem/mem.h>
#include <ccan/opt/opt.h>
#include <ccan/opt/private.h>

View file

@ -13,6 +13,7 @@
#include <common/param.h>
#include <common/timeout.h>
#include <common/type_to_string.h>
#include <db/exec.h>
#include <gossipd/gossipd_wiregen.h>
#include <lightningd/chaintopology.h>
#include <lightningd/channel.h>

View file

@ -1,6 +1,8 @@
#include "config.h"
#include <ccan/io/io.h>
#include <common/memleak.h>
#include <db/exec.h>
#include <db/utils.h>
#include <lightningd/plugin_hook.h>
/* Struct containing all the information needed to deserialize and

View file

@ -10,6 +10,7 @@
#include <common/peer_status_wiregen.h>
#include <common/status_wiregen.h>
#include <common/version.h>
#include <db/exec.h>
#include <errno.h>
#include <fcntl.h>
#include <lightningd/lightningd.h>

View file

@ -7,37 +7,40 @@ WALLET_LIB_SRC := \
wallet/wallet.c \
wallet/walletrpc.c
WALLET_LIB_SRC_NOHDR := \
wallet/reservation.c
WALLET_LIB_SRC_NOHDR := \
wallet/reservation.c \
wallet/db_queries_postgres.c \
wallet/db_queries_sqlite3.c
WALLET_DB_DRIVERS := \
wallet/db_postgres.c \
wallet/db_sqlite3.c
WALLET_DB_QUERIES := \
wallet/db_sqlite3_sqlgen.c \
wallet/db_postgres_sqlgen.c
WALLET_SRC := $(WALLET_LIB_SRC) $(WALLET_LIB_SRC_NOHDR) $(WALLET_DB_DRIVERS)
WALLET_SRC := $(WALLET_LIB_SRC) $(WALLET_LIB_SRC_NOHDR)
WALLET_HDRS := $(WALLET_LIB_SRC:.c=.h)
WALLET_OBJS := $(WALLET_SRC:.c=.o)
# Make sure these depend on everything.
ALL_C_SOURCES += $(WALLET_SRC) wallet/db_sqlite3_sqlgen.c wallet/db_postgres_sqlgen.c
ALL_C_SOURCES += $(WALLET_SRC) $(WALLET_DB_QUERIES)
ALL_C_HEADERS += $(WALLET_HDRS)
# Each database driver depends on its rewritten statements.
wallet/db_sqlite3.o: wallet/db_sqlite3_sqlgen.c
wallet/db_postgres.o: wallet/db_postgres_sqlgen.c
# Query sets depend on the rewritten queries
wallet/db_queries_sqlite3.o: $(WALLET_DB_QUERIES)
wallet/db_queries_postgres.o: $(WALLET_DB_QUERIES)
# The following files contain SQL-annotated statements that we need to extact
SQL_FILES := \
WALLET_SQL_FILES := \
$(DB_SQL_FILES) \
wallet/db.c \
wallet/invoices.c \
wallet/wallet.c \
wallet/test/run-db.c \
wallet/test/run-wallet.c \
wallet/statements_gettextgen.po: $(SQL_FILES) $(FORCE)
wallet/statements_gettextgen.po: $(WALLET_SQL_FILES) $(FORCE)
@if $(call SHA256STAMP_CHANGED); then \
$(call VERBOSE,"xgettext $@",xgettext -kNAMED_SQL -kSQL --add-location --no-wrap --omit-header -o $@ $(SQL_FILES) && $(call SHA256STAMP,# ,)); \
$(call VERBOSE,"xgettext $@",xgettext -kNAMED_SQL -kSQL --add-location --no-wrap --omit-header -o $@ $(WALLET_SQL_FILES) && $(call SHA256STAMP,# ,)); \
fi
wallet/db_%_sqlgen.c: wallet/statements_gettextgen.po devtools/sql-rewrite.py $(FORCE)
@ -50,7 +53,6 @@ clean: wallet-maintainer-clean
wallet-maintainer-clean:
$(RM) wallet/statements.po
$(RM) wallet/statements_gettextgen.po
$(RM) wallet/db_sqlite3_sqlgen.c
$(RM) wallet/db_postgres_sqlgen.c
$(RM) $(WALLET_DB_QUERIES)
include wallet/test/Makefile

File diff suppressed because it is too large Load diff

View file

@ -2,45 +2,10 @@
#define LIGHTNING_WALLET_DB_H
#include "config.h"
#include <bitcoin/preimage.h>
#include <bitcoin/pubkey.h>
#include <bitcoin/short_channel_id.h>
#include <bitcoin/tx.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/time/time.h>
struct channel_id;
struct ext_key;
struct lightningd;
struct log;
struct node_id;
struct onionreply;
struct db_stmt;
struct db;
struct wally_psbt;
struct wally_tx;
/**
* Macro to annotate a named SQL query.
*
* This macro is used to annotate SQL queries that might need rewriting for
* different SQL dialects. It is used both as a marker for the query
* extraction logic in devtools/sql-rewrite.py to identify queries, as well as
* a way to swap out the query text with it's name so that the query execution
* engine can then look up the rewritten query using its name.
*
*/
#define NAMED_SQL(name,x) x
/**
* Simple annotation macro that auto-generates names for NAMED_SQL
*
* If this macro is changed it is likely that the extraction logic in
* devtools/sql-rewrite.py needs to change as well, since they need to
* generate identical names to work correctly.
*/
#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x)
/**
* db_setup - Open a the lightningd database and update the schema
@ -57,204 +22,4 @@ struct wally_tx;
struct db *db_setup(const tal_t *ctx, struct lightningd *ld,
const struct ext_key *bip32_base);
/**
* db_begin_transaction - Begin a transaction
*
* Begin a new DB transaction. fatal() on database error.
*/
#define db_begin_transaction(db) \
db_begin_transaction_((db), __FILE__ ":" stringify(__LINE__))
void db_begin_transaction_(struct db *db, const char *location);
bool db_in_transaction(struct db *db);
/**
* db_commit_transaction - Commit a running transaction
*
* Requires that we are currently in a transaction. fatal() if we
* fail to commit.
*/
void db_commit_transaction(struct db *db);
/**
* db_set_intvar - Set an integer variable in the database
*
* Utility function to store generic integer values in the
* database.
*/
void db_set_intvar(struct db *db, char *varname, s64 val);
/**
* db_get_intvar - Retrieve an integer variable from the database
*
* Either returns the value in the database, or @defval if
* the query failed or no such variable exists.
*/
s64 db_get_intvar(struct db *db, char *varname, s64 defval);
void db_bind_null(struct db_stmt *stmt, int pos);
void db_bind_int(struct db_stmt *stmt, int pos, int val);
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val);
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len);
void db_bind_text(struct db_stmt *stmt, int pos, const char *val);
void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p);
void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s);
void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s);
void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s);
void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s);
void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t);
void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id);
void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *ni);
void db_bind_node_id_arr(struct db_stmt *stmt, int col,
const struct node_id *ids);
void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *p);
void db_bind_short_channel_id(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_signature(struct db_stmt *stmt, int col,
const secp256k1_ecdsa_signature *sig);
void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t);
void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx);
void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt);
void db_bind_amount_msat(struct db_stmt *stmt, int pos,
const struct amount_msat *msat);
void db_bind_amount_sat(struct db_stmt *stmt, int pos,
const struct amount_sat *sat);
void db_bind_json_escape(struct db_stmt *stmt, int pos,
const struct json_escape *esc);
void db_bind_onionreply(struct db_stmt *stmt, int col,
const struct onionreply *r);
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr);
bool db_step(struct db_stmt *stmt);
/* Modern variants: get columns by name from SELECT */
/* Bridge function to get column number from SELECT
(must exist) */
size_t db_query_colnum(const struct db_stmt *stmt, const char *colname);
u64 db_col_u64(struct db_stmt *stmt, const char *colname);
int db_col_int(struct db_stmt *stmt, const char *colname);
size_t db_col_bytes(struct db_stmt *stmt, const char *colname);
int db_col_is_null(struct db_stmt *stmt, const char *colname);
const void* db_col_blob(struct db_stmt *stmt, const char *colname);
char *db_col_strdup(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname);
void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage);
void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat);
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat);
struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha);
void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad);
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s);
struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t);
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest);
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni);
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_pubkey(struct db_stmt *stmt, const char *colname,
struct pubkey *p);
bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest);
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig);
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname);
#define db_col_arr(ctx, stmt, colname, type) \
((type *)db_col_arr_((ctx), (stmt), (colname), \
sizeof(type), TAL_LABEL(type, "[]"), \
__func__))
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller);
/* Some useful default variants */
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def);
void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat,
struct amount_msat def);
/* Explicitly ignore a column (so we don't complain you didn't use it!) */
void db_col_ignore(struct db_stmt *stmt, const char *colname);
/**
* db_exec_prepared -- Execute a prepared statement
*
* After preparing a statement using `db_prepare`, and after binding all
* non-null variables using the `db_bind_*` functions, it can be executed with
* this function. It is a small, transaction-aware, wrapper around `db_step`,
* that calls fatal() if the execution fails. This may take ownership of
* `stmt` if annotated with `take()`and will free it before returning.
*
* If you'd like to issue a query and access the rows returned by the query
* please use `db_query_prepared` instead, since this function will not expose
* returned results, and the `stmt` can only be used for calls to
* `db_count_changes` and `db_last_insert_id` after executing.
*
* @stmt: The prepared statement to execute
*/
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES);
/**
* db_query_prepared -- Execute a prepared query
*
* After preparing a query using `db_prepare`, and after binding all non-null
* variables using the `db_bind_*` functions, it can be executed with this
* function. This function must be called before calling `db_step` or any of
* the `db_col_*` column access functions.
*
* If you are not executing a read-only statement, please use
* `db_exec_prepared` instead.
*
* @stmt: The prepared statement to execute
*/
bool db_query_prepared(struct db_stmt *stmt);
size_t db_count_changes(struct db_stmt *stmt);
u64 db_last_insert_id_v2(struct db_stmt *stmt);
/**
* db_prepare -- Prepare a DB query/command
*
* Create an instance of `struct db_stmt` that encapsulates a SQL query or command.
*
* @query MUST be wrapped in a `SQL()` macro call, since that allows the
* extraction and translation of the query into the target SQL dialect.
*
* It does not execute the query and does not check its validity, but
* allocates the placeholders detected in the query. The placeholders in the
* `stmt` can then be bound using the `db_bind_*` functions, and executed
* using `db_exec_prepared` for write-only statements and `db_query_prepared`
* for read-only statements.
*
* @db: Database to query/exec
* @query: The SQL statement to compile
*/
struct db_stmt *db_prepare_v2_(const char *location, struct db *db,
const char *query_id);
/* TODO(cdecker) Remove the v2 suffix after finishing the migration */
#define db_prepare_v2(db,query) \
db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query)
/**
* Access pending changes that have been added to the current transaction.
*/
const char **db_changes(struct db *db);
/* Get the current data version. */
u32 db_data_version_get(struct db *db);
#endif /* LIGHTNING_WALLET_DB_H */

View file

@ -0,0 +1,13 @@
#include "config.h"
#include "db_postgres_sqlgen.c"
#if HAVE_POSTGRES
struct db_query_set postgres_query_set = {
.name = "postgres",
.query_table = db_postgres_queries,
.query_table_size = ARRAY_SIZE(db_postgres_queries),
};
AUTODATA(db_queries, &postgres_query_set);
#endif /* HAVE_POSTGRES */

View file

@ -0,0 +1,12 @@
#include "config.h"
#include "db_sqlite3_sqlgen.c"
#if HAVE_SQLITE3
struct db_query_set sqlite3_query_set = {
.name = "sqlite3",
.query_table = db_sqlite3_queries,
.query_table_size = ARRAY_SIZE(db_sqlite3_queries),
};
AUTODATA(db_queries, &sqlite3_query_set);
#endif /* HAVE_SQLITE3 */

View file

@ -1,7 +1,10 @@
#include "config.h"
#include <ccan/tal/str/str.h>
#include <common/timeout.h>
#include <wallet/db.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/exec.h>
#include <db/utils.h>
#include <wallet/invoices.h>
#include <wallet/wallet.h>

View file

@ -27,7 +27,7 @@ WALLET_TEST_COMMON_OBJS := \
common/utils.o \
common/wireaddr.o \
common/version.o \
wallet/db_sqlite3.o \
wallet/db_queries_sqlite3.o \
wire/towire.o \
wire/fromwire.o

View file

@ -6,6 +6,10 @@ static void db_log_(struct log *log UNUSED, enum log_level level UNUSED, const s
}
#define log_ db_log_
#include "db/bindings.c"
#include "db/db_sqlite3.c"
#include "db/exec.c"
#include "db/utils.c"
#include "wallet/db.c"
#include "test_utils.h"
@ -86,6 +90,8 @@ static struct db *create_test_db(void)
tal_free(filename);
db = db_open(NULL, dsn);
db->data_version = 0;
db->report_changes_fn = NULL;
tal_free(dsn);
return db;
}

View file

@ -3,7 +3,7 @@
#include "test_utils.h"
#include <ccan/tal/str/str.h>
#include <wallet/db_common.h>
#include <db/common.h>
static void db_log_(struct log *log UNUSED, enum log_level level UNUSED, const struct node_id *node_id UNUSED, bool call_notifier UNUSED, const char *fmt UNUSED, ...)
{
@ -35,6 +35,10 @@ void db_fatal(const char *fmt, ...)
#include "lightningd/peer_htlcs.c"
#include "lightningd/channel.c"
#include "db/bindings.c"
#include "db/db_sqlite3.c"
#include "db/exec.c"
#include "db/utils.c"
#include "wallet/db.c"
#include <common/setup.h>
@ -908,6 +912,7 @@ static struct wallet *create_test_wallet(struct lightningd *ld, const tal_t *ctx
dsn = tal_fmt(NULL, "sqlite3://%s", filename);
w->db = db_open(w, dsn);
w->db->report_changes_fn = NULL;
tal_free(dsn);
tal_add_destructor2(w, cleanup_test_wallet, filename);

View file

@ -8,13 +8,16 @@
#include <common/fee_states.h>
#include <common/onionreply.h>
#include <common/type_to_string.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/exec.h>
#include <db/utils.h>
#include <lightningd/chaintopology.h>
#include <lightningd/channel.h>
#include <lightningd/coin_mvts.h>
#include <lightningd/notification.h>
#include <lightningd/peer_control.h>
#include <onchaind/onchaind_wiregen.h>
#include <wallet/db_common.h>
#include <wallet/invoices.h>
#include <wallet/txfilter.h>
#include <wallet/wallet.h>

View file

@ -12,6 +12,7 @@
#include <common/param.h>
#include <common/psbt_open.h>
#include <common/type_to_string.h>
#include <db/exec.h>
#include <errno.h>
#include <hsmd/hsmd_wiregen.h>
#include <lightningd/chaintopology.h>