diff --git a/wallet/db.c b/wallet/db.c index 33a62d27e..63e84e687 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -690,54 +690,6 @@ void db_exec_prepared_(const char *caller, struct db *db, sqlite3_stmt *stmt) db_stmt_done(stmt); } -/* This one doesn't check if we're in a transaction. */ -static void db_do_exec(const char *caller, struct db *db, const char *cmd) -{ - char *errmsg; - int err; - - err = sqlite3_exec(db->sql, cmd, NULL, NULL, &errmsg); - if (err != SQLITE_OK) { - db_fatal("%s:%s:%s:%s", caller, sqlite3_errstr(err), cmd, errmsg); - /* Only reached in testing */ - sqlite3_free(errmsg); - } -#if HAVE_SQLITE3_EXPANDED_SQL - tal_arr_expand(&db->changes, tal_strdup(db->changes, cmd)); -#endif -} - -static void PRINTF_FMT(3, 4) - db_exec(const char *caller, struct db *db, const char *fmt, ...) -{ - va_list ap; - char *cmd; - - assert(db->in_transaction); - - va_start(ap, fmt); - cmd = tal_vfmt(db, fmt, ap); - va_end(ap); - - db_do_exec(caller, db, cmd); - tal_free(cmd); -} - -/* This one can fail: returns NULL if so */ -static sqlite3_stmt *db_query(const char *location, - struct db *db, const char *query) -{ - sqlite3_stmt *stmt; - - assert(db->in_transaction); - - /* Sets stmt to NULL if not SQLITE_OK */ - sqlite3_prepare_v2(db->sql, query, -1, &stmt, NULL); - if (stmt) - dev_statement_start(stmt, location); - return stmt; -} - sqlite3_stmt *db_select_(const char *location, struct db *db, const char *query) { sqlite3_stmt *stmt; @@ -892,18 +844,17 @@ static struct db *db_open(const tal_t *ctx, char *filename) */ static int db_get_version(struct db *db) { - int res; - sqlite3_stmt *stmt = db_query(__func__, - db, "SELECT version FROM version LIMIT 1"); + int res = -1; + struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1")); + if (!db_query_prepared(stmt)) { + tal_free(stmt); + return res; + } - if (!stmt) - return -1; + if (db_step(stmt)) + res = db_column_u64(stmt, 0); - if (!db_select_step(db, stmt)) - return -1; - - res = sqlite3_column_int64(stmt, 0); - db_stmt_done(stmt); + tal_free(stmt); return res; } @@ -914,6 +865,7 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) { /* Attempt to read the version from the database */ int current, orig, available; + struct db_stmt *stmt; db_begin_transaction(db); @@ -942,13 +894,20 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) } /* Finally update the version number in the version table */ - db_exec(__func__, db, "UPDATE version SET version=%d;", available); + stmt = db_prepare_v2(db, SQL("UPDATE version SET version=?;")); + db_bind_u64(stmt, 0, available); + db_exec_prepared_v2(stmt); + tal_free(stmt); /* Annotate that we did upgrade, if any. */ - if (current != orig) - db_exec(__func__, db, - "INSERT INTO db_upgrades VALUES (%i, '%s');", - orig, version()); + if (current != orig) { + stmt = db_prepare_v2( + db, SQL("INSERT INTO db_upgrades VALUES (?, ?);")); + db_bind_u64(stmt, 0, orig); + db_bind_text(stmt, 1, version()); + db_exec_prepared_v2(stmt); + tal_free(stmt); + } db_commit_transaction(db); } @@ -963,39 +922,42 @@ struct db *db_setup(const tal_t *ctx, struct lightningd *ld, struct log *log) s64 db_get_intvar(struct db *db, char *varname, s64 defval) { - s64 res; - sqlite3_stmt *stmt; - const char *query; + s64 res = defval; + struct db_stmt *stmt = db_prepare_v2( + db, SQL("SELECT val FROM vars WHERE name= ? LIMIT 1")); + db_bind_text(stmt, 0, varname); + if (!db_query_prepared(stmt)) + goto done; - query = tal_fmt(db, "SELECT val FROM vars WHERE name='%s' LIMIT 1", varname); - stmt = db_query(__func__, db, query); - tal_free(query); - - if (!stmt) - return defval; - - if (db_select_step(db, stmt)) { - const unsigned char *stringvar = sqlite3_column_text(stmt, 0); - res = atol((const char *)stringvar); - db_stmt_done(stmt); - } else - res = defval; + if (db_step(stmt)) + res = atol((const char*)db_column_text(stmt, 0)); +done: + tal_free(stmt); return res; } void db_set_intvar(struct db *db, char *varname, s64 val) { - /* Attempt to update */ - db_exec(__func__, db, - "UPDATE vars SET val='%" PRId64 "' WHERE name='%s';", val, - varname); - if (sqlite3_changes(db->sql) == 0) - db_exec( - __func__, db, - "INSERT INTO vars (name, val) VALUES ('%s', '%" PRId64 - "');", - varname, val); + char *v = tal_fmt(NULL, "%"PRIi64, val); + size_t changes; + struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET val=? WHERE name=?;")); + db_bind_text(stmt, 0, v); + db_bind_text(stmt, 1, varname); + if (!db_exec_prepared_v2(stmt)) + db_fatal("Error executing update: %s", stmt->error); + changes = db_count_changes(stmt); + tal_free(stmt); + + if (changes == 0) { + stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, val) VALUES (?, ?);")); + db_bind_text(stmt, 0, varname); + db_bind_text(stmt, 1, v); + if (!db_exec_prepared_v2(stmt)) + db_fatal("Error executing insert: %s", stmt->error); + tal_free(stmt); + } + tal_free(v); } void *sqlite3_column_arr_(const tal_t *ctx, sqlite3_stmt *stmt, int col, diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index ba36aa152..bd8642250 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -78,7 +78,9 @@ static bool test_empty_db_migrate(struct lightningd *ld) static bool test_primitives(void) { + struct db_stmt *stmt; struct db *db = create_test_db(); + db_err = NULL; db_begin_transaction(db); CHECK(db->in_transaction); db_commit_transaction(db); @@ -87,11 +89,15 @@ static bool test_primitives(void) db_commit_transaction(db); db_begin_transaction(db); - db_exec(__func__, db, "SELECT name FROM sqlite_master WHERE type='table';"); + stmt = db_prepare_v2(db, SQL("SELECT name FROM sqlite_master WHERE type='table';")); + CHECK_MSG(db_exec_prepared_v2(stmt), "db_exec_prepared must succeed"); CHECK_MSG(!db_err, "Simple correct SQL command"); + tal_free(stmt); - db_exec(__func__, db, "not a valid SQL statement"); + stmt = db_prepare_v2(db, SQL("not a valid SQL statement")); + CHECK_MSG(!db_exec_prepared_v2(stmt), "db_exec_prepared must fail"); CHECK_MSG(db_err, "Failing SQL command"); + tal_free(stmt); db_err = tal_free(db_err); db_commit_transaction(db); CHECK(!db->in_transaction); diff --git a/wallet/test/run-wallet.c b/wallet/test/run-wallet.c index 33c581a4a..93e4ecb8b 100644 --- a/wallet/test/run-wallet.c +++ b/wallet/test/run-wallet.c @@ -1126,6 +1126,7 @@ static bool test_channel_config_crud(struct lightningd *ld, const tal_t *ctx) static bool test_htlc_crud(struct lightningd *ld, const tal_t *ctx) { + struct db_stmt *stmt; struct htlc_in in, *hin; struct htlc_out out, *hout; struct preimage payment_key; @@ -1136,9 +1137,11 @@ static bool test_htlc_crud(struct lightningd *ld, const tal_t *ctx) struct htlc_out_map *htlcs_out = tal(ctx, struct htlc_out_map); /* Make sure we have our references correct */ - CHECK(transaction_wrap(w->db, - db_exec(__func__, w->db, "INSERT INTO channels (id) VALUES (1);"))); db_begin_transaction(w->db); + char *query = SQL("INSERT INTO channels (id) VALUES (1);"); + stmt = db_prepare_v2(w->db, query); + db_exec_prepared_v2(stmt); + tal_free(stmt); db_commit_transaction(w->db); chan->dbid = 1;