diff --git a/wallet/wallet.c b/wallet/wallet.c index 651fa816a..33deec9db 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -1096,12 +1096,99 @@ bool wallet_htlcs_reconnect(struct wallet *wallet, if (!hout->in) { log_broken( wallet->log, - "Unable to find corresponding htlc_in %" PRIu64 - " for htlc_out %" PRIu64, + "Unable to find corresponding htlc_in %"PRIu64" for htlc_out %"PRIu64, hout->origin_htlc_id, hout->dbid); - return false; } } return true; } + +bool wallet_invoice_save(struct wallet *wallet, struct invoice *inv) +{ + /* Need to use the lower level API of sqlite3 to bind + * label. Otherwise we'd need to implement sanitization of + * that string for sql injections... */ + sqlite3_stmt *stmt; + if (!inv->id) { + stmt = db_prepare(wallet->db, + "INSERT INTO invoices (payment_hash, payment_key, state, msatoshi, label) VALUES (?, ?, ?, ?, ?);"); + if (!stmt) { + log_broken(wallet->log, "Could not prepare statement: %s", wallet->db->err); + return false; + } + + sqlite3_bind_blob(stmt, 1, &inv->rhash, sizeof(inv->rhash), SQLITE_TRANSIENT); + sqlite3_bind_blob(stmt, 2, &inv->r, sizeof(inv->r), SQLITE_TRANSIENT); + sqlite3_bind_int(stmt, 3, inv->state); + sqlite3_bind_int64(stmt, 4, inv->msatoshi); + sqlite3_bind_text(stmt, 5, inv->label, strlen(inv->label), SQLITE_TRANSIENT); + + if (!db_exec_prepared(wallet->db, stmt)) { + log_broken(wallet->log, "Could not exec prepared statement: %s", wallet->db->err); + return false; + } + + inv->id = sqlite3_last_insert_rowid(wallet->db->sql); + return true; + } else { + stmt = db_prepare(wallet->db, "UPDATE invoices SET state=? WHERE id=?;"); + + if (!stmt) { + log_broken(wallet->log, "Could not prepare statement: %s", wallet->db->err); + return false; + } + + sqlite3_bind_int(stmt, 1, inv->state); + sqlite3_bind_int64(stmt, 2, inv->id); + + if (!db_exec_prepared(wallet->db, stmt)) { + log_broken(wallet->log, "Could not exec prepared statement: %s", wallet->db->err); + return false; + } else { + return true; + } + } +} + +static bool wallet_stmt2invoice(sqlite3_stmt *stmt, struct invoice *inv) +{ + inv->id = sqlite3_column_int64(stmt, 0); + inv->state = sqlite3_column_int(stmt, 1); + + assert(sqlite3_column_bytes(stmt, 2) == sizeof(struct preimage)); + memcpy(&inv->r, sqlite3_column_blob(stmt, 2), sqlite3_column_bytes(stmt, 2)); + + assert(sqlite3_column_bytes(stmt, 3) == sizeof(struct sha256)); + memcpy(&inv->rhash, sqlite3_column_blob(stmt, 3), sqlite3_column_bytes(stmt, 3)); + + inv->label = tal_strndup(inv, sqlite3_column_blob(stmt, 4), sqlite3_column_bytes(stmt, 4)); + inv->msatoshi = sqlite3_column_int64(stmt, 5); + return true; +} + +bool wallet_invoices_load(struct wallet *wallet, struct invoices *invs) +{ + struct invoice *i; + int count = 0; + sqlite3_stmt *stmt = db_query(__func__, wallet->db, + "SELECT id, state, payment_key, payment_hash, " + "label, msatoshi FROM invoices;"); + if (!stmt) { + log_broken(wallet->log, "Could not load invoices: %s", wallet->db->err); + return false; + } + + while (sqlite3_step(stmt) == SQLITE_ROW) { + i = tal(invs, struct invoice); + if (!wallet_stmt2invoice(stmt, i)) { + log_broken(wallet->log, "Error deserializing invoice"); + return false; + } + invoice_add(invs, i); + count++; + } + + log_debug(wallet->log, "Loaded %d invoices from DB", count); + return true; +} diff --git a/wallet/wallet.h b/wallet/wallet.h index 0ed8837a8..dbd0768e6 100644 --- a/wallet/wallet.h +++ b/wallet/wallet.h @@ -10,6 +10,7 @@ #include #include #include +#include #include struct lightningd; @@ -300,4 +301,26 @@ bool wallet_htlcs_reconnect(struct wallet *wallet, struct htlc_in_map *htlcs_in, struct htlc_out_map *htlcs_out); +/** + * wallet_invoice_save -- Save/update an invoice to the wallet + * + * Save or update the invoice in the wallet. If `inv->id` is 0 this + * invoice will be considered a new invoice and result in an intert + * into the database, otherwise it'll be updated. + * + * @wallet: Wallet to store in + * @inv: Invoice to save + */ +bool wallet_invoice_save(struct wallet *wallet, struct invoice *inv); + +/** + * wallet_invoices_load -- Load all invoices into memory + * + * Load all invoices into the given `invoices` struct. + * + * @wallet: Wallet to load invoices from + * @invs: invoices container to load into + */ +bool wallet_invoices_load(struct wallet *wallet, struct invoices *invs); + #endif /* WALLET_WALLET_H */ diff --git a/wallet/wallet_tests.c b/wallet/wallet_tests.c index 5d490a78c..a6c05086d 100644 --- a/wallet/wallet_tests.c +++ b/wallet/wallet_tests.c @@ -8,6 +8,9 @@ #include #include +void invoice_add(struct invoices *invs, + struct invoice *inv){} + static struct wallet *create_test_wallet(const tal_t *ctx) { char filename[] = "/tmp/ldb-XXXXXX";