diff --git a/wallet/db_common.h b/wallet/db_common.h index c89240284..0ab196c29 100644 --- a/wallet/db_common.h +++ b/wallet/db_common.h @@ -151,6 +151,13 @@ struct db_config { void (*teardown_fn)(struct db *db); bool (*vacuum_fn)(struct db *db); + + bool (*rename_column)(struct db *db, + const char *tablename, + const char *from, const char *to); + bool (*delete_columns)(struct db *db, + const char *tablename, + const char **colnames, size_t num_cols); }; /* Provide a way for DB backends to register themselves */ diff --git a/wallet/db_postgres.c b/wallet/db_postgres.c index a8a239e93..34f6562d8 100644 --- a/wallet/db_postgres.c +++ b/wallet/db_postgres.c @@ -273,6 +273,51 @@ static bool db_postgres_vacuum(struct db *db) return true; } +static bool db_postgres_rename_column(struct db *db, + const char *tablename, + const char *from, const char *to) +{ + PGresult *res; + char *cmd; + + cmd = tal_fmt(db, "ALTER TABLE %s RENAME %s TO %s;", + tablename, from, to); + res = PQexec(db->conn, cmd); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + db->error = tal_fmt(db, "Rename '%s' failed: %s", + cmd, PQerrorMessage(db->conn)); + PQclear(res); + return false; + } + PQclear(res); + return true; +} + +static bool db_postgres_delete_columns(struct db *db, + const char *tablename, + const char **colnames, size_t num_cols) +{ + PGresult *res; + char *cmd; + + cmd = tal_fmt(db, "ALTER TABLE %s ", tablename); + for (size_t i = 0; i < num_cols; i++) { + if (i != 0) + tal_append_fmt(&cmd, ", "); + tal_append_fmt(&cmd, "DROP %s", colnames[i]); + } + tal_append_fmt(&cmd, ";"); + res = PQexec(db->conn, cmd); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + db->error = tal_fmt(db, "Delete '%s' failed: %s", + cmd, PQerrorMessage(db->conn)); + PQclear(res); + return false; + } + PQclear(res); + return true; +} + struct db_config db_postgres_config = { .name = "postgres", .query_table = db_postgres_queries, @@ -296,6 +341,8 @@ struct db_config db_postgres_config = { .setup_fn = db_postgres_setup, .teardown_fn = db_postgres_teardown, .vacuum_fn = db_postgres_vacuum, + .rename_column = db_postgres_rename_column, + .delete_columns = db_postgres_delete_columns, }; AUTODATA(db_backends, &db_postgres_config); diff --git a/wallet/db_sqlite3.c b/wallet/db_sqlite3.c index 00d6af7da..11d639a55 100644 --- a/wallet/db_sqlite3.c +++ b/wallet/db_sqlite3.c @@ -1,5 +1,6 @@ #include "db_sqlite3_sqlgen.c" #include +#include #include #if HAVE_SQLITE3 @@ -251,6 +252,238 @@ static bool db_sqlite3_vacuum(struct db *db) return err == SQLITE_DONE; } +static bool colname_to_delete(const char **colnames, + size_t num_colnames, + const char *columnname) +{ + for (size_t i = 0; i < num_colnames; i++) { + if (streq(columnname, colnames[i])) + return true; + } + return false; +} + +static const char *find_column_name(const tal_t *ctx, + const char *sqlpart, + size_t *after) +{ + size_t start = 0; + + while (isspace(sqlpart[start])) + start++; + *after = strspn(sqlpart + start, "abcdefghijklmnopqrstuvwxyz_0123456789") + start; + if (*after == start) + return NULL; + return tal_strndup(ctx, sqlpart + start, *after - start); +} + +/* Move table out the way, return columns */ +static char **prepare_table_manip(const tal_t *ctx, + struct db *db, const char *tablename) +{ + sqlite3_stmt *stmt; + const char *sql; + char *cmd, *bracket; + char **parts; + int err; + + /* Get schema. */ + sqlite3_prepare_v2(db->conn, "SELECT sql FROM sqlite_master WHERE type = ? AND name = ?;", -1, &stmt, NULL); + sqlite3_bind_text(stmt, 1, "table", strlen("table"), SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 2, tablename, strlen(tablename), SQLITE_TRANSIENT); + err = sqlite3_step(stmt); + if (err != SQLITE_ROW) { + db->error = tal_fmt(db, "getting schema: %s", + sqlite3_errmsg(db->conn)); + sqlite3_finalize(stmt); + return NULL; + } + + sql = tal_strdup(tmpctx, (const char *)sqlite3_column_text(stmt, 0)); + sqlite3_finalize(stmt); + + bracket = strchr(sql, '('); + if (!strstarts(sql, "CREATE TABLE") || !bracket) { + db->error = tal_fmt(db, "strange schema for %s: %s", + tablename, sql); + return NULL; + } + + /* Split after ( by commas: any lower case is assumed to be a field */ + parts = tal_strsplit(ctx, bracket + 1, ",", STR_EMPTY_OK); + + /* Turn off foreign keys first. */ + sqlite3_prepare_v2(db->conn, "PRAGMA foreign_keys = OFF;", -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + cmd = tal_fmt(tmpctx, "ALTER TABLE %s RENAME TO temp_%s;", + tablename, tablename); + sqlite3_prepare_v2(db->conn, cmd, -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + return parts; + +sqlite_stmt_err: + db->error = tal_fmt(db, "%s", sqlite3_errmsg(db->conn)); + sqlite3_finalize(stmt); + return tal_free(parts); +} + +static bool complete_table_manip(struct db *db, + const char *tablename, + const char **coldefs, + const char **oldcolnames) +{ + sqlite3_stmt *stmt; + char *create_cmd, *insert_cmd, *drop_cmd; + + /* Create table */ + create_cmd = tal_fmt(tmpctx, "CREATE TABLE %s (", tablename); + for (size_t i = 0; i < tal_count(coldefs); i++) { + if (i != 0) + tal_append_fmt(&create_cmd, ", "); + tal_append_fmt(&create_cmd, "%s", coldefs[i]); + } + tal_append_fmt(&create_cmd, ";"); + + sqlite3_prepare_v2(db->conn, create_cmd, -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + /* Populate table from old one */ + insert_cmd = tal_fmt(tmpctx, "INSERT INTO %s SELECT ", tablename); + for (size_t i = 0; i < tal_count(oldcolnames); i++) { + if (i != 0) + tal_append_fmt(&insert_cmd, ", "); + tal_append_fmt(&insert_cmd, "%s", oldcolnames[i]); + } + tal_append_fmt(&insert_cmd, " FROM temp_%s;", tablename); + + sqlite3_prepare_v2(db->conn, insert_cmd, -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + /* Cleanup temp table */ + drop_cmd = tal_fmt(tmpctx, "DROP TABLE temp_%s;", tablename); + sqlite3_prepare_v2(db->conn, drop_cmd, -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + /* Allow links between them (esp. cascade deletes!) */ + sqlite3_prepare_v2(db->conn, "PRAGMA foreign_keys = ON;", -1, &stmt, NULL); + if (sqlite3_step(stmt) != SQLITE_DONE) + goto sqlite_stmt_err; + sqlite3_finalize(stmt); + + return true; + +sqlite_stmt_err: + db->error = tal_fmt(db, "%s", sqlite3_errmsg(db->conn)); + sqlite3_finalize(stmt); + return false; +} + +static bool db_sqlite3_rename_column(struct db *db, + const char *tablename, + const char *from, const char *to) +{ + char **parts; + const char **coldefs, **oldcolnames; + bool colname_found = false; + + parts = prepare_table_manip(tmpctx, db, tablename); + if (!parts) + return false; + + coldefs = tal_arr(tmpctx, const char *, 0); + oldcolnames = tal_arr(tmpctx, const char *, 0); + + for (size_t i = 0; parts[i]; i++) { + /* columnname DETAILS */ + size_t after_name; + const char *colname = find_column_name(tmpctx, parts[i], + &after_name); + + /* Things like "PRIMARY KEY xxx" must be copied verbatim */ + if (!colname) { + tal_arr_expand(&coldefs, parts[i]); + continue; + } + if (streq(colname, from)) { + char *newdef; + colname_found = true; + /* Create column with new name */ + newdef = tal_fmt(coldefs, + "%s%s", to, parts[i] + after_name); + tal_arr_expand(&coldefs, newdef); + tal_arr_expand(&oldcolnames, colname); + } else { + /* Not mentioned, keep it as is! */ + tal_arr_expand(&coldefs, parts[i]); + tal_arr_expand(&oldcolnames, colname); + } + } + + if (!colname_found) { + db->error = tal_fmt(db, "No column called %s", from); + return false; + } + return complete_table_manip(db, tablename, coldefs, oldcolnames); +} + +static bool db_sqlite3_delete_columns(struct db *db, + const char *tablename, + const char **colnames, size_t num_cols) +{ + char **parts; + const char **coldefs, **oldcolnames; + size_t colnames_found = 0; + + parts = prepare_table_manip(tmpctx, db, tablename); + if (!parts) + return false; + + coldefs = tal_arr(tmpctx, const char *, 0); + oldcolnames = tal_arr(tmpctx, const char *, 0); + + for (size_t i = 0; parts[i]; i++) { + /* columnname DETAILS */ + size_t after_name; + const char *colname = find_column_name(tmpctx, parts[i], + &after_name); + + /* Things like "PRIMARY KEY xxx" must be copied verbatim */ + if (!colname) { + tal_arr_expand(&coldefs, parts[i]); + continue; + } + + /* Don't mention columns we're supposed to delete */ + if (colname_to_delete(colnames, num_cols, colname)) { + colnames_found++; + continue; + } + + /* Keep it as is! */ + tal_arr_expand(&coldefs, parts[i]); + tal_arr_expand(&oldcolnames, colname); + } + + if (colnames_found != num_cols) { + db->error = tal_fmt(db, "Only %zu/%zu columns found", + colnames_found, num_cols); + return false; + } + return complete_table_manip(db, tablename, coldefs, oldcolnames); +} + struct db_config db_sqlite3_config = { .name = "sqlite3", .query_table = db_sqlite3_queries, @@ -275,6 +508,8 @@ struct db_config db_sqlite3_config = { .teardown_fn = &db_sqlite3_close, .vacuum_fn = db_sqlite3_vacuum, + .rename_column = db_sqlite3_rename_column, + .delete_columns = db_sqlite3_delete_columns, }; AUTODATA(db_backends, &db_sqlite3_config); diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index 92ae8a87a..c6123e9ab 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -167,6 +167,91 @@ static bool test_vars(struct lightningd *ld) return true; } +static bool test_manip_columns(void) +{ + struct db_stmt *stmt; + struct db *db = create_test_db(); + const char *field1 = "field1"; + + db_begin_transaction(db); + /* tablea refers to tableb */ + stmt = db_prepare_v2(db, SQL("CREATE TABLE tablea (" + " id BIGSERIAL" + ", field1 INTEGER" + ", PRIMARY KEY (id))")); + CHECK_MSG(db_exec_prepared_v2(stmt), "db_exec_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + tal_free(stmt); + + stmt = db_prepare_v2(db, SQL("INSERT INTO tablea (id, field1) VALUES (0, 1);")); + CHECK_MSG(db_exec_prepared_v2(stmt), "db_exec_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + tal_free(stmt); + + stmt = db_prepare_v2(db, SQL("CREATE TABLE tableb (" + " id REFERENCES tablea(id) ON DELETE CASCADE" + ", field1 INTEGER" + ", field2 INTEGER);")); + CHECK_MSG(db_exec_prepared_v2(stmt), "db_exec_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + tal_free(stmt); + + stmt = db_prepare_v2(db, SQL("INSERT INTO tableb (id, field1, field2) VALUES (0, 1, 2);")); + CHECK_MSG(db_exec_prepared_v2(stmt), "db_exec_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + tal_free(stmt); + /* Don't let it try to set a version field (we don't have one!) */ + db->dirty = false; + db->changes = tal_arr(db, const char *, 0); + db_commit_transaction(db); + + /* Rename tablea.field1 -> table1.field1a. */ + CHECK(db->config->rename_column(db, "tablea", "field1", "field1a")); + /* Remove tableb.field1. */ + CHECK(db->config->delete_columns(db, "tableb", &field1, 1)); + + db_begin_transaction(db); + stmt = db_prepare_v2(db, SQL("SELECT id, field1a FROM tablea;")); + CHECK_MSG(db_query_prepared(stmt), "db_query_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + CHECK(db_step(stmt)); + CHECK(db_col_u64(stmt, "id") == 0); + CHECK(db_col_u64(stmt, "field1a") == 1); + CHECK(!db_step(stmt)); + tal_free(stmt); + + stmt = db_prepare_v2(db, SQL("SELECT id, field2 FROM tableb;")); + CHECK_MSG(db_query_prepared(stmt), "db_query_prepared must succeed"); + CHECK_MSG(!db_err, "Simple correct SQL command"); + CHECK(db_step(stmt)); + CHECK(db_col_u64(stmt, "id") == 0); + CHECK(db_col_u64(stmt, "field2") == 2); + CHECK(!db_step(stmt)); + tal_free(stmt); + db->dirty = false; + db->changes = tal_arr(db, const char *, 0); + db_commit_transaction(db); + + db_begin_transaction(db); + /* This will actually fail */ + stmt = db_prepare_v2(db, SQL("SELECT field1 FROM tablea;")); + CHECK_MSG(!db_query_prepared(stmt), "db_query_prepared must fail"); + db->dirty = false; + db->changes = tal_arr(db, const char *, 0); + db_commit_transaction(db); + + db_begin_transaction(db); + /* This will actually fail */ + stmt = db_prepare_v2(db, SQL("SELECT field1 FROM tableb;")); + CHECK_MSG(!db_query_prepared(stmt), "db_query_prepared must fail"); + db->dirty = false; + db->changes = tal_arr(db, const char *, 0); + db_commit_transaction(db); + + tal_free(db); + return true; +} + int main(int argc, char *argv[]) { bool ok = true; @@ -179,6 +264,7 @@ int main(int argc, char *argv[]) ok &= test_empty_db_migrate(ld); ok &= test_vars(ld); ok &= test_primitives(); + ok &= test_manip_columns(); tal_free(ld); common_shutdown();