plugins/sql: fix foreign keys.

I noticed that our subtables were not being cleaned, despite being "ON
DELETE CASCADE".  This is because foreign keys were not enabled, but
then I got foreign key errors: rowid cannot be a foreign key anyway!

So create a real "rowid" column.  We want "ON DELETE CASCADE" for
nodes and channels (and other tables in future) where we update
partially.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2023-02-01 12:28:32 +10:30 committed by Alex Myers
parent 3dde1ca399
commit f87c7ed439
3 changed files with 41 additions and 16 deletions

View file

@ -80,6 +80,12 @@ Additionally, only the following functions are allowed:
TABLES
------
Note that the first column of every table is a unique integer called
`rowid`: this is used for related tables to refer to specific rows in
their parent. sqlite3 usually has this as an implicit column, but we
make it explicit as the implicit version is not allowed to be used as
a foreign key.
[comment]: # (GENERATE-DOC-START)
The following tables are currently supported:
- `bkpr_accountevents` (see lightning-bkpr-listaccountevents(7))

View file

@ -117,6 +117,7 @@ static struct sqlite3 *db;
static const char *dbfilename;
static int gosstore_fd = -1;
static size_t gosstore_nodes_off = 0, gosstore_channels_off = 0;
static u64 next_rowid = 1;
/* It was tempting to put these in the schema, but they're really
* just for our usage. Though that would allow us to autogen the
@ -225,6 +226,10 @@ static struct sqlite3 *sqlite_setup(struct plugin *plugin)
if (err != SQLITE_OK)
plugin_err(plugin, "Could not set max_page_count: %s", errmsg);
err = sqlite3_exec(db, "PRAGMA foreign_keys = ON;", NULL, NULL, &errmsg);
if (err != SQLITE_OK)
plugin_err(plugin, "Could not set foreign_keys: %s", errmsg);
return db;
}
@ -423,16 +428,16 @@ static struct command_result *process_json_obj(struct command *cmd,
const jsmntok_t *t,
const struct table_desc *td,
size_t row,
const u64 *rowid,
u64 this_rowid,
const u64 *parent_rowid,
size_t *sqloff,
sqlite3_stmt *stmt)
{
int err;
u64 parent_rowid;
/* Subtables have row, arrindex as first two columns. */
if (rowid) {
sqlite3_bind_int64(stmt, (*sqloff)++, *rowid);
if (parent_rowid) {
sqlite3_bind_int64(stmt, (*sqloff)++, *parent_rowid);
sqlite3_bind_int64(stmt, (*sqloff)++, row);
}
@ -448,8 +453,8 @@ static struct command_result *process_json_obj(struct command *cmd,
continue;
coltok = json_get_member(buf, t, col->jsonname);
ret = process_json_obj(cmd, buf, coltok, col->sub, row, NULL,
sqloff, stmt);
ret = process_json_obj(cmd, buf, coltok, col->sub, row, this_rowid,
NULL, sqloff, stmt);
if (ret)
return ret;
continue;
@ -564,8 +569,6 @@ static struct command_result *process_json_obj(struct command *cmd,
sqlite3_errmsg(db));
}
/* Now we have rowid, we can insert into any subtables. */
parent_rowid = sqlite3_last_insert_rowid(db);
for (size_t i = 0; i < tal_count(td->columns); i++) {
const struct column *col = &td->columns[i];
const jsmntok_t *coltok;
@ -578,7 +581,7 @@ static struct command_result *process_json_obj(struct command *cmd,
if (!coltok)
continue;
ret = process_json_list(cmd, buf, coltok, &parent_rowid, col->sub);
ret = process_json_list(cmd, buf, coltok, &this_rowid, col->sub);
if (ret)
return ret;
}
@ -589,7 +592,7 @@ static struct command_result *process_json_obj(struct command *cmd,
static struct command_result *process_json_list(struct command *cmd,
const char *buf,
const jsmntok_t *arr,
const u64 *rowid,
const u64 *parent_rowid,
const struct table_desc *td)
{
size_t i;
@ -608,7 +611,11 @@ static struct command_result *process_json_list(struct command *cmd,
json_for_each_arr(i, t, arr) {
/* sqlite3 columns are 1-based! */
size_t off = 1;
ret = process_json_obj(cmd, buf, t, td, i, rowid, &off, stmt);
u64 this_rowid = next_rowid++;
/* First entry is always the rowid */
sqlite3_bind_int64(stmt, off++, this_rowid);
ret = process_json_obj(cmd, buf, t, td, i, this_rowid, parent_rowid, &off, stmt);
if (ret)
break;
sqlite3_reset(stmt);
@ -1049,6 +1056,7 @@ static void json_add_schema(struct json_stream *js,
/* This needs to be an array, not a dictionary, since dicts
* are often treated as unordered, and order is critical! */
json_array_start(js, "columns");
json_add_column(js, "rowid", "INTEGER");
if (td->parent) {
json_add_column(js, "row", "INTEGER");
json_add_column(js, "arrindex", "INTEGER");
@ -1118,8 +1126,11 @@ static void finish_td(struct plugin *plugin, struct table_desc *td)
if (td->is_subobject)
return;
create_stmt = tal_fmt(tmpctx, "CREATE TABLE %s (", td->name);
td->update_stmt = tal_fmt(td, "INSERT INTO %s VALUES (", td->name);
/* We make an explicit rowid in each table, for subtables to access. This is
* becuase the implicit rowid can't be used as a foreign key! */
create_stmt = tal_fmt(tmpctx, "CREATE TABLE %s (rowid INTEGER PRIMARY KEY, ",
td->name);
td->update_stmt = tal_fmt(td, "INSERT INTO %s VALUES (?, ", td->name);
/* If we're a child array, we reference the parent column */
if (td->parent) {

View file

@ -3294,6 +3294,9 @@ def test_sql(node_factory, bitcoind):
ret = l2.rpc.sql("SELECT * FROM forwards;")
assert ret == {'rows': []}
# Test that we correctly clean up subtables!
assert len(l2.rpc.sql("SELECT * from peerchannels_features")['rows']) == len(l2.rpc.sql("SELECT * from peerchannels_features")['rows'])
# This should create a forward through l2
l1.rpc.pay(l3.rpc.invoice(amount_msat=12300, label='inv1', description='description')['bolt11'])
@ -3798,12 +3801,13 @@ def test_sql(node_factory, bitcoind):
'number': 'REAL',
'short_channel_id': 'TEXT'}
# Check schemas match.
# Check schemas match (each one has rowid at start)
rowidcol = {'name': 'rowid', 'type': 'u64'}
for table, schema in expected_schemas.items():
res = only_one(l2.rpc.listsqlschemas(table)['schemas'])
assert res['tablename'] == table
assert res.get('indices') == schema.get('indices')
sqlcolumns = [{'name': c['name'], 'type': sqltypemap[c['type']]} for c in schema['columns']]
sqlcolumns = [{'name': c['name'], 'type': sqltypemap[c['type']]} for c in [rowidcol] + schema['columns']]
assert res['columns'] == sqlcolumns
# Make sure we didn't miss any
@ -3827,7 +3831,11 @@ def test_sql(node_factory, bitcoind):
for table, schema in expected_schemas.items():
ret = l2.rpc.sql("SELECT * FROM {};".format(table))
assert len(ret['rows'][0]) == len(schema['columns'])
assert len(ret['rows'][0]) == 1 + len(schema['columns'])
# First column is always rowid!
for row in ret['rows']:
assert row[0] > 0
for col in schema['columns']:
val = only_one(l2.rpc.sql("SELECT {} FROM {};".format(col['name'], table))['rows'][0])