diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 863208a12..cbe43cdcd 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -299,10 +299,9 @@ def test_db_hook(node_factory, executor): # It should see the db being created, and sometime later actually get # initted. # This precedes startup, so needle already past - assert l1.daemon.is_in_log('plugin-dblog.py deferring 1 commands') + assert l1.daemon.is_in_log(r'plugin-dblog.py deferring \d+ commands') l1.daemon.logsearch_start = 0 l1.daemon.wait_for_log('plugin-dblog.py replaying pre-init data:') - l1.daemon.wait_for_log('plugin-dblog.py PRAGMA foreign_keys = ON;') l1.daemon.wait_for_log('plugin-dblog.py CREATE TABLE version \\(version INTEGER\\)') l1.daemon.wait_for_log("plugin-dblog.py initialized.* 'startup': True") diff --git a/wallet/db.c b/wallet/db.c index 8d4c03105..3a58e09a2 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -826,10 +826,17 @@ static void setup_open_db(struct db *db) assert(!db->in_transaction); db_prepare_for_changes(db); - db_do_exec(__func__, db, "PRAGMA foreign_keys = ON;"); + if (db->config->setup_fn) + db->config->setup_fn(db); db_report_changes(db, NULL, 0); } +void db_close(struct db *db) +{ + if (db->config->teardown_fn) + db->config->teardown_fn(db); +} + /** * db_open - Open or create a sqlite3 database */ diff --git a/wallet/db.h b/wallet/db.h index a81dccfde..5bf0f1b59 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -242,6 +242,8 @@ void sqlite3_bind_timeabs(sqlite3_stmt *stmt, int col, struct timeabs t); struct timeabs sqlite3_column_timeabs(sqlite3_stmt *stmt, int col); +void db_close(struct db *db); + void db_bind_null(struct db_stmt *stmt, int pos); void db_bind_int(struct db_stmt *stmt, int pos, int val); void db_bind_u64(struct db_stmt *stmt, int pos, u64 val); diff --git a/wallet/db_common.h b/wallet/db_common.h index 31aef4efe..9fbae2790 100644 --- a/wallet/db_common.h +++ b/wallet/db_common.h @@ -124,6 +124,9 @@ struct db_config { s64 (*column_int_fn)(struct db_stmt *stmt, int col); size_t (*count_changes_fn)(struct db_stmt *stmt); + + bool (*setup_fn)(struct db *db); + void (*teardown_fn)(struct db *db); }; /* Provide a way for DB backends to register themselves */ diff --git a/wallet/db_sqlite3.c b/wallet/db_sqlite3.c index 0da021cd4..484dde743 100644 --- a/wallet/db_sqlite3.c +++ b/wallet/db_sqlite3.c @@ -28,6 +28,16 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt) sqlite3_errmsg(stmt->db->conn)); } +static bool db_sqlite3_setup(struct db *db) +{ + sqlite3_stmt *stmt; + int err; + sqlite3_prepare_v2(db->conn, "PRAGMA foreign_keys = ON;", -1, &stmt, NULL); + err = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return err == SQLITE_DONE; +} + static bool db_sqlite3_query(struct db_stmt *stmt) { sqlite3_stmt *s; @@ -173,6 +183,11 @@ static size_t db_sqlite3_count_changes(struct db_stmt *stmt) return sqlite3_changes(s); } +static void db_sqlite3_close(struct db *db) +{ + sqlite3_close(db->sql); +} + struct db_config db_sqlite3_config = { .name = "sqlite3", .queries = db_sqlite3_queries, @@ -193,6 +208,8 @@ struct db_config db_sqlite3_config = { .column_text_fn = &db_sqlite3_column_text, .count_changes_fn = &db_sqlite3_count_changes, + .setup_fn = &db_sqlite3_setup, + .teardown_fn = &db_sqlite3_close, }; AUTODATA(db_backends, &db_sqlite3_config);