diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 883a148ac..6f1a5f325 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -30,10 +30,10 @@ import scodec.Attempt import scodec.bits.BitVector import scodec.codecs._ -import java.sql.{ResultSet, Statement} +import java.sql.{ResultSet, Statement, Timestamp} +import java.time.Instant import java.util.UUID import javax.sql.DataSource -import scala.concurrent.duration._ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb with Logging { @@ -42,7 +42,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit import lock._ val DB_NAME = "payments" - val CURRENT_VERSION = 5 + val CURRENT_VERSION = 6 private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary] private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte) @@ -62,12 +62,21 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.executeUpdate("ALTER TABLE sent SET SCHEMA payments") } + def migration56(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE payments.received ALTER COLUMN created_at SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_at * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE payments.received ALTER COLUMN expire_at SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + expire_at * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE payments.received ALTER COLUMN received_at SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + received_at * interval '1 millisecond'") + + statement.executeUpdate("ALTER TABLE payments.sent ALTER COLUMN created_at SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_at * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE payments.sent ALTER COLUMN completed_at SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + completed_at * interval '1 millisecond'") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA payments") - statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)") - statement.executeUpdate("CREATE TABLE payments.sent (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)") + statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE payments.sent (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at TIMESTAMP WITH TIME ZONE NOT NULL, completed_at TIMESTAMP WITH TIME ZONE)") statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON payments.sent(parent_id)") statement.executeUpdate("CREATE INDEX sent_payment_hash_idx ON payments.sent(payment_hash)") @@ -76,6 +85,10 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit case Some(v@4) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration45(statement) + migration56(statement) + case Some(v@5) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration56(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -95,7 +108,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.setLong(6, sent.amount.toLong) statement.setLong(7, sent.recipientAmount.toLong) statement.setString(8, sent.recipientNodeId.value.toHex) - statement.setLong(9, sent.createdAt) + statement.setTimestamp(9, Timestamp.from(Instant.ofEpochMilli(sent.createdAt))) statement.setString(10, sent.paymentRequest.map(PaymentRequest.write).orNull) statement.executeUpdate() } @@ -106,7 +119,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("UPDATE payments.sent SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => paymentResult.parts.foreach(p => { - statement.setLong(1, p.timestamp) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(p.timestamp))) statement.setString(2, paymentResult.paymentPreimage.toHex) statement.setLong(3, p.feesPaid.toLong) statement.setBytes(4, paymentRouteCodec.encode(p.route.getOrElse(Nil).map(h => HopSummary(h)).toList).require.toByteArray) @@ -121,7 +134,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = withMetrics("payments/update-outgoing-failed", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("UPDATE payments.sent SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => - statement.setLong(1, paymentResult.timestamp) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(paymentResult.timestamp))) statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) statement.setString(3, paymentResult.id.toString) if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})") @@ -134,7 +147,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit rs.getByteVector32FromHexNullable("payment_preimage"), rs.getMilliSatoshiNullable("fees_msat"), rs.getBitVectorOpt("payment_route"), - rs.getLongNullable("completed_at"), + rs.getTimestampNullable("completed_at").map(_.getTime), rs.getBitVectorOpt("failures")) OutgoingPayment( @@ -146,7 +159,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit MilliSatoshi(rs.getLong("amount_msat")), MilliSatoshi(rs.getLong("recipient_amount_msat")), PublicKey(rs.getByteVectorFromHex("recipient_node_id")), - rs.getLong("created_at"), + rs.getTimestamp("created_at").getTime, rs.getStringNullable("payment_request").map(PaymentRequest.read), status ) @@ -207,8 +220,8 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-timestamp", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.sent WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery().map { rs => parseOutgoingPayment(rs) }.toSeq @@ -223,8 +236,8 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.setString(2, preimage.toHex) statement.setString(3, paymentType) statement.setString(4, PaymentRequest.write(pr)) - statement.setLong(5, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds - statement.setLong(6, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis) + statement.setTimestamp(5, Timestamp.from(Instant.ofEpochSecond(pr.timestamp))) // BOLT11 timestamp is in seconds + statement.setTimestamp(6, Timestamp.from(Instant.ofEpochSecond(pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)))) statement.executeUpdate() } } @@ -234,7 +247,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => update.setLong(1, amount.toLong) - update.setLong(2, receivedAt) + update.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(receivedAt))) update.setString(3, paymentHash.toHex) val updated = update.executeUpdate() if (updated == 0) { @@ -250,8 +263,8 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit PaymentRequest.read(paymentRequest), rs.getByteVector32FromHex("payment_preimage"), rs.getString("payment_type"), - rs.getLong("created_at"), - buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), Some(paymentRequest), rs.getLongNullable("received_at"))) + rs.getTimestamp("created_at").getTime, + buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), Some(paymentRequest), rs.getTimestampNullable("received_at").map(_.getTime))) } private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], serializedPaymentRequest_opt: Option[String], receivedAt_opt: Option[Long]): IncomingPaymentStatus = { @@ -274,8 +287,8 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -284,8 +297,8 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -294,9 +307,9 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) - statement.setLong(3, System.currentTimeMillis) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) + statement.setTimestamp(3, Timestamp.from(Instant.now())) statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -305,9 +318,9 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) - statement.setLong(3, System.currentTimeMillis) + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) + statement.setTimestamp(3, Timestamp.from(Instant.now())) statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -366,9 +379,9 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit val paymentType = rs.getString("payment_type") val paymentRequest_opt = rs.getStringNullable("payment_request") val amount_opt = rs.getMilliSatoshiNullable("final_amount") - val createdAt = rs.getLong("created_at") - val completedAt_opt = rs.getLongNullable("completed_at") - val expireAt_opt = rs.getLongNullable("expire_at") + val createdAt = rs.getTimestamp("created_at").getTime + val completedAt_opt = rs.getTimestampNullable("completed_at").map(_.getTime) + val expireAt_opt = rs.getTimestampNullable("expire_at").map(_.getTime) if (rs.getString("type") == "received") { val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 312df91d6..e55067370 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -18,8 +18,9 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, forAllDbs} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, forAllDbs, migrationCheck} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.db.jdbc.JdbcUtils.{setVersion, using} import fr.acinq.eclair.db.pg.PgPaymentsDb import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.payment._ @@ -28,6 +29,7 @@ import fr.acinq.eclair.wire.protocol.{ChannelUpdate, UnknownNextPeer} import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TestDatabases, randomBytes32, randomBytes64, randomKey} import org.scalatest.funsuite.AnyFunSuite +import java.time.Instant import java.util.UUID import scala.concurrent.duration._ @@ -46,22 +48,17 @@ class PaymentsDbSpec extends AnyFunSuite { } } - test("handle version migration 1->4") { - import fr.acinq.eclair.db.sqlite.SqliteUtils._ - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - val connection = dbs.connection + test("migrate sqlite payments db v1 -> v4") { + val dbs = TestSqliteDatabases() + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS payments (payment_hash BLOB NOT NULL PRIMARY KEY, amount_msat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") setVersion(statement, "payments", 1) } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(1)) - } - // Changes between version 1 and 2: // - the monolithic payments table has been replaced by two tables, received_payments and sent_payments // - old records from the payments table are ignored (not migrated to the new tables) @@ -71,68 +68,55 @@ class PaymentsDbSpec extends AnyFunSuite { statement.setLong(3, 1000) // received_at statement.executeUpdate() } - - val preMigrationDb = new SqlitePaymentsDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) - } - + }, + dbName = "payments", + targetVersion = 4, + postCheck = _ => { + val db = dbs.db.payments // the existing received payment can NOT be queried anymore - assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty) + assert(db.getIncomingPayment(paymentHash1).isEmpty) // add a few rows val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending) val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", CltvExpiryDelta(18), expirySeconds = None, timestamp = 1) val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100)) - preMigrationDb.addOutgoingPayment(ps1) - preMigrationDb.addIncomingPayment(i1, preimage1) - preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100) + db.addOutgoingPayment(ps1) + db.addIncomingPayment(i1, preimage1) + db.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100) - assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) - assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) + assert(db.listIncomingPayments(1, 1500) === Seq(pr1)) + assert(db.listOutgoingPayments(1, 1500) === Seq(ps1)) - val postMigrationDb = new SqlitePaymentsDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) - } - - assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) - assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) - } + } + ) } - test("handle version migration 2->4") { - import fr.acinq.eclair.db.sqlite.SqliteUtils._ - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - val connection = dbs.connection + test("migrate sqlite payments db v2 -> v4") { + val dbs = TestSqliteDatabases() + // Test data + val id1 = UUID.randomUUID() + val id2 = UUID.randomUUID() + val id3 = UUID.randomUUID() + val ps1 = OutgoingPayment(id1, id1, None, randomBytes32(), PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) + val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) + val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", CltvExpiryDelta(18), expirySeconds = None, timestamp = 1) + val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090)) + val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", CltvExpiryDelta(18), expirySeconds = Some(30), timestamp = 1) + val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") setVersion(statement, "payments", 2) } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(2)) - } - // Insert a bunch of old version 2 rows. - val id1 = UUID.randomUUID() - val id2 = UUID.randomUUID() - val id3 = UUID.randomUUID() - val ps1 = OutgoingPayment(id1, id1, None, randomBytes32(), PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) - val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) - val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) - val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", CltvExpiryDelta(18), expirySeconds = None, timestamp = 1) - val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090)) - val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", CltvExpiryDelta(18), expirySeconds = Some(30), timestamp = 1) - val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) // Changes between version 2 and 3 to sent_payments: // - removed the status column @@ -194,49 +178,49 @@ class PaymentsDbSpec extends AnyFunSuite { statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis) statement.executeUpdate() } + }, + dbName = "payments", + targetVersion = 4, + postCheck = _ => { + val db = dbs.db.payments - val preMigrationDb = new SqlitePaymentsDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) - } - - assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1)) - assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2)) - assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3)) - - val postMigrationDb = new SqlitePaymentsDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) - } + assert(db.getIncomingPayment(i1.paymentHash) === Some(pr1)) + assert(db.getIncomingPayment(i2.paymentHash) === Some(pr2)) + assert(db.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3)) val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", CltvExpiryDelta(18), expirySeconds = Some(30)) val pr3 = IncomingPayment(i3, preimage3, PaymentType.Standard, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) - postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage) + db.addIncomingPayment(i3, pr3.paymentPreimage) val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32(), PaymentType.Standard, 123 msat, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending) val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32(), PaymentType.Standard, 456 msat, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180)) val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32(), PaymentType.Standard, 789 msat, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) - postMigrationDb.addOutgoingPayment(ps4) - postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) - postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32(), None, 1180)))) - postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) - postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) + db.addOutgoingPayment(ps4) + db.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) + db.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32(), None, 1180)))) + db.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) + db.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) - assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6)) - assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3)) - assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2)) - } + assert(db.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6)) + assert(db.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3)) + assert(db.listExpiredIncomingPayments(1, 2000) === Seq(pr2)) + }) } - test("handle version migration 3->4") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils._ - val connection = dbs.connection + test("migrate sqlite payments db v3 -> v4") { + val dbs = TestSqliteDatabases() + // Test data + val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID()) + val parentId = UUID.randomUUID() + val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", CltvExpiryDelta(18), expirySeconds = Some(30)) + val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32(), PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020)) + val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending) + val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060)) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") @@ -249,17 +233,7 @@ class PaymentsDbSpec extends AnyFunSuite { setVersion(statement, "payments", 3) } - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(3)) - } - // Insert a bunch of old version 3 rows. - val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID()) - val parentId = UUID.randomUUID() - val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", CltvExpiryDelta(18), expirySeconds = Some(30)) - val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32(), PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020)) - val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending) - val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060)) using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, failures) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, ps1.id.toString) @@ -305,22 +279,119 @@ class PaymentsDbSpec extends AnyFunSuite { // - added payment type column, with a default to "Standard" // - renamed target_node_id -> recipient_node_id // - re-ordered columns + }, + dbName = "payments", + targetVersion = 4, + postCheck = _ => { + val db = dbs.db.payments + assert(db.getOutgoingPayment(id1) === Some(ps1)) + assert(db.listOutgoingPayments(parentId) === Seq(ps2, ps3)) + } + ) + } - val preMigrationDb = new SqlitePaymentsDb(connection) + test("migrate postgres payments db v4 -> v6") { + val dbs = TestPgDatabases() + // Test data + val id1 = UUID.randomUUID() + val id2 = UUID.randomUUID() + val id3 = UUID.randomUUID() + val ps1 = OutgoingPayment(id1, id1, None, randomBytes32(), PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, Instant.parse("2021-01-01T10:15:30.00Z").toEpochMilli, None, OutgoingPaymentStatus.Pending) + val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, Instant.parse("2020-05-14T13:47:21.00Z").toEpochMilli, None, OutgoingPaymentStatus.Failed(Nil, Instant.parse("2021-05-15T04:12:40.00Z").toEpochMilli)) + val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, Instant.parse("2021-01-28T09:12:05.00Z").toEpochMilli, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, Instant.now().toEpochMilli)) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", CltvExpiryDelta(18), expirySeconds = None, timestamp = Instant.now().getEpochSecond) + val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, Instant.now().toEpochMilli)) + val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", CltvExpiryDelta(18), expirySeconds = Some(24 * 3600), timestamp = Instant.parse("2020-12-30T10:00:55.00Z").getEpochSecond) + val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) + statement.executeUpdate("CREATE TABLE received_payments (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)") + statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)") + + statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON sent_payments(parent_id)") + statement.executeUpdate("CREATE INDEX sent_payment_hash_idx ON sent_payments(payment_hash)") + statement.executeUpdate("CREATE INDEX sent_created_idx ON sent_payments(created_at)") + statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") + + setVersion(statement, "payments", 4) + } + // insert test data + Seq(ps1, ps2, ps3).foreach { sent => + using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, payment_type, amount_msat, recipient_amount_msat, recipient_node_id, created_at, payment_request, completed_at, payment_preimage) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, sent.id.toString) + statement.setString(2, sent.parentId.toString) + statement.setString(3, sent.externalId.orNull) + statement.setString(4, sent.paymentHash.toHex) + statement.setString(5, sent.paymentType) + statement.setLong(6, sent.amount.toLong) + statement.setLong(7, sent.recipientAmount.toLong) + statement.setString(8, sent.recipientNodeId.value.toHex) + statement.setLong(9, sent.createdAt) + statement.setString(10, sent.paymentRequest.map(PaymentRequest.write).orNull) + sent.status match { + case s: OutgoingPaymentStatus.Succeeded => + statement.setLong(11, s.completedAt) + statement.setString(12, s.paymentPreimage.toHex) + case s: OutgoingPaymentStatus.Failed => + statement.setLong(11, s.completedAt) + statement.setObject(12, null) + case _ => + statement.setObject(11, null) + statement.setObject(12, null) + } + statement.executeUpdate() + } } - assert(preMigrationDb.getOutgoingPayment(id1) === Some(ps1)) - assert(preMigrationDb.listOutgoingPayments(parentId) === Seq(ps2, ps3)) - - val postMigrationDb = new SqlitePaymentsDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments").contains(4)) + Seq((i1, preimage1), (i2, preimage2)).foreach { case (pr, preimage) => + using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, pr.paymentHash.toHex) + statement.setString(2, preimage.toHex) + statement.setString(3, PaymentType.Standard) + statement.setString(4, PaymentRequest.write(pr)) + statement.setLong(5, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds + statement.setLong(6, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis) + statement.executeUpdate() + } } - } + + using(connection.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => + update.setLong(1, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) + update.setLong(2, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt) + update.setString(3, pr1.paymentRequest.paymentHash.toHex) + val updated = update.executeUpdate() + if (updated == 0) { + throw new IllegalArgumentException("Inserted a received payment without having an invoice") + } + } + + import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ + assert(connection.createStatement().executeQuery("SELECT * FROM received_payments").map(rs => rs.getString("payment_hash")).toSeq.size > 0) + + }, + dbName = "payments", + targetVersion = 6, + postCheck = _ => { + val db = dbs.db.payments + + assert(db.getIncomingPayment(i1.paymentHash) === Some(pr1)) + assert(db.getIncomingPayment(i2.paymentHash) === Some(pr2)) + assert(db.listIncomingPayments(Instant.parse("2020-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2100-12-31T23:59:59.00Z").toEpochMilli) === Seq(pr2, pr1)) + assert(db.listIncomingPayments(Instant.parse("2020-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2020-12-31T23:59:59.00Z").toEpochMilli) === Seq(pr2)) + assert(db.listIncomingPayments(Instant.parse("2010-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2011-12-31T23:59:59.00Z").toEpochMilli) === Seq.empty) + assert(db.listExpiredIncomingPayments(Instant.parse("2020-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2100-12-31T23:59:59.00Z").toEpochMilli) === Seq(pr2)) + assert(db.listExpiredIncomingPayments(Instant.parse("2020-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2020-12-31T23:59:59.00Z").toEpochMilli) === Seq(pr2)) + assert(db.listExpiredIncomingPayments(Instant.parse("2010-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2011-12-31T23:59:59.00Z").toEpochMilli) === Seq.empty) + + assert(db.listOutgoingPayments(Instant.parse("2020-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2021-12-31T23:59:59.00Z").toEpochMilli) === Seq(ps2, ps1, ps3)) + assert(db.listOutgoingPayments(Instant.parse("2010-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2021-01-15T23:59:59.00Z").toEpochMilli) === Seq(ps2, ps1)) + assert(db.listOutgoingPayments(Instant.parse("2010-01-01T00:00:00.00Z").toEpochMilli, Instant.parse("2011-12-31T23:59:59.00Z").toEpochMilli) === Seq.empty) + } + ) } test("add/retrieve/update incoming payments") {