diff --git a/docs/PostgreSQL.md b/docs/PostgreSQL.md index e312da36d..315d45b3f 100644 --- a/docs/PostgreSQL.md +++ b/docs/PostgreSQL.md @@ -111,4 +111,34 @@ Eclair stores the latest database settings in the `${data-dir}/last_jdbcurl` fil The node operator can force Eclair to accept new database connection settings by removing the `last_jdbcurl` file. - \ No newline at end of file + +### Migrating from Sqlite to Postgres + +Eclair supports migrating your existing node from Sqlite to Postgres. Note that the opposite (from Postgres to Sqlite) is not supported. + +:warning: Once you have migrated from Sqlite to Postgres there is no going back! + +To migrate from Sqlite to Postgres, follow these steps: +1. Stop Eclair +2. Edit `eclair.conf` + 1. Set `eclair.db.postgres.*` as explained in the section [Connection Settings](#connection-settings). + 2. Set `eclair.db.driver=dual-sqlite-primary`. This will make Eclair use both databases backends. All calls to sqlite will be replicated in postgres. + 3. Set `eclair.db.dual.migrate-on-restart=true`. This will make Eclair migrate the data from Sqlite to Postgres at startup. + 4. Set `eclair.db.dual.compare-on-restart=true`. This will make Eclair compare Sqlite and Postgres at startup. The result of the comparison is displayed in the logs. +3. Delete the file `~/.eclair/last_jdbcurl`. The purpose of this file is to prevent accidental change in the database backend. +4. Start Eclair. You should see in the logs: + 1. `migrating all tables...` + 2. `migration complete` + 3. `comparing all tables...` + 4. `comparison complete identical=true` (NB: if `identical=false`, contact support) +5. Eclair should then finish startup and operate normally. Data has been migrated to Postgres, and Sqlite/Postgres will be maintained in sync going forward. +6. Edit `eclair.conf` and set `eclair.db.dual.migrate-on-restart=false` but do not restart Eclair yet. +7. We recommend that you leave Eclair in dual db mode for a while, to make sure that you don't have issues with your new Postgres database. This a good time to set up [Backups and replication](#backups-and-replication). +8. After some time has passed, restart Eclair. You should see in the logs: + 1. `comparing all tables...` + 2. `comparison complete identical=true` (NB: if `identical=false`, contact support) +9. At this point we have confidence that the Postgres backend works normally, and we are ready to drop Sqlite for good. +10. Edit `eclair.conf` + 1. Set `eclair.db.driver=postgres` + 2. Set `eclair.db.dual.compare-on-restart=false` +11. Restart Eclair. From this moment, you cannot go back to Sqlite! If you try to do so, Eclair will refuse to start. \ No newline at end of file diff --git a/eclair-core/pom.xml b/eclair-core/pom.xml index 80603e99c..a027974b5 100644 --- a/eclair-core/pom.xml +++ b/eclair-core/pom.xml @@ -204,6 +204,12 @@ scodec-core_${scala.version.short} 1.11.8 + + + org.scodec + scodec-bits_${scala.version.short} + 1.1.25 + commons-codec commons-codec diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 10bab482b..ee5e42285 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -311,7 +311,7 @@ eclair { } db { - driver = "sqlite" // sqlite, postgres + driver = "sqlite" // sqlite, postgres, dual-sqlite-primary, dual-postgres-primary postgres { database = "eclair" host = "localhost" @@ -353,6 +353,10 @@ eclair { } } } + dual { + migrate-on-restart = false // migrate sqlite -> postgres on restart (only applies if sqlite is primary) + compare-on-restart = false // compare sqlite and postgres dbs on restart (only applies if sqlite is primary) + } } file-backup { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala index 8f6f768f6..9713cfbf1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala @@ -21,6 +21,7 @@ import akka.actor.{ActorSystem, CoordinatedShutdown} import com.typesafe.config.Config import com.zaxxer.hikari.{HikariConfig, HikariDataSource} import fr.acinq.eclair.TimestampMilli +import fr.acinq.eclair.db.migration.{CompareDb, MigrateDb} import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler import fr.acinq.eclair.db.pg.PgUtils._ import fr.acinq.eclair.db.pg._ @@ -267,13 +268,25 @@ object Databases extends Logging { db match { case Some(d) => d case None => + val jdbcUrlFile = new File(chaindir, "last_jdbcurl") dbConfig.getString("driver") match { - case "sqlite" => Databases.sqlite(chaindir) - case "postgres" => Databases.postgres(dbConfig, instanceId, chaindir) - case "dual" => - val sqlite = Databases.sqlite(chaindir) - val postgres = Databases.postgres(dbConfig, instanceId, chaindir) - DualDatabases(sqlite, postgres) + case "sqlite" => Databases.sqlite(chaindir, jdbcUrlFile_opt = Some(jdbcUrlFile)) + case "postgres" => Databases.postgres(dbConfig, instanceId, chaindir, jdbcUrlFile_opt = Some(jdbcUrlFile)) + case dual@("dual-sqlite-primary" | "dual-postgres-primary") => + logger.info(s"using $dual database mode") + val sqlite = Databases.sqlite(chaindir, jdbcUrlFile_opt = None) + val postgres = Databases.postgres(dbConfig, instanceId, chaindir, jdbcUrlFile_opt = None) + val (primary, secondary) = if (dual == "dual-sqlite-primary") (sqlite, postgres) else (postgres, sqlite) + val dualDb = DualDatabases(primary, secondary) + if (primary == sqlite) { + if (dbConfig.getBoolean("dual.migrate-on-restart")) { + MigrateDb.migrateAll(dualDb) + } + if (dbConfig.getBoolean("dual.compare-on-restart")) { + CompareDb.compareAll(dualDb) + } + } + dualDb case driver => throw new RuntimeException(s"unknown database driver `$driver`") } } @@ -282,18 +295,17 @@ object Databases extends Logging { /** * Given a parent folder it creates or loads all the databases from a JDBC connection */ - def sqlite(dbdir: File): SqliteDatabases = { + def sqlite(dbdir: File, jdbcUrlFile_opt: Option[File]): SqliteDatabases = { dbdir.mkdirs() - val jdbcUrlFile = new File(dbdir, "last_jdbcurl") SqliteDatabases( eclairJdbc = SqliteUtils.openSqliteFile(dbdir, "eclair.sqlite", exclusiveLock = true, journalMode = "wal", syncFlag = "full"), // there should only be one process writing to this file networkJdbc = SqliteUtils.openSqliteFile(dbdir, "network.sqlite", exclusiveLock = false, journalMode = "wal", syncFlag = "normal"), // we don't need strong durability guarantees on the network db auditJdbc = SqliteUtils.openSqliteFile(dbdir, "audit.sqlite", exclusiveLock = false, journalMode = "wal", syncFlag = "full"), - jdbcUrlFile_opt = Some(jdbcUrlFile) + jdbcUrlFile_opt = jdbcUrlFile_opt ) } - def postgres(dbConfig: Config, instanceId: UUID, dbdir: File, lockExceptionHandler: LockFailureHandler = LockFailureHandler.logAndStop)(implicit system: ActorSystem): PostgresDatabases = { + def postgres(dbConfig: Config, instanceId: UUID, dbdir: File, jdbcUrlFile_opt: Option[File], lockExceptionHandler: LockFailureHandler = LockFailureHandler.logAndStop)(implicit system: ActorSystem): PostgresDatabases = { dbdir.mkdirs() val database = dbConfig.getString("postgres.database") val host = dbConfig.getString("postgres.host") @@ -328,8 +340,6 @@ object Databases extends Logging { case unknownLock => throw new RuntimeException(s"unknown postgres lock type: `$unknownLock`") } - val jdbcUrlFile = new File(dbdir, "last_jdbcurl") - val safetyChecks_opt = if (dbConfig.getBoolean("postgres.safety-checks.enabled")) { Some(PostgresDatabases.SafetyChecks( localChannelsMaxAge = FiniteDuration(dbConfig.getDuration("postgres.safety-checks.max-age.local-channels").getSeconds, TimeUnit.SECONDS), @@ -345,7 +355,7 @@ object Databases extends Logging { hikariConfig = hikariConfig, instanceId = instanceId, lock = lock, - jdbcUrlFile_opt = Some(jdbcUrlFile), + jdbcUrlFile_opt = jdbcUrlFile_opt, readOnlyUser_opt = readOnlyUser_opt, resetJsonColumns = resetJsonColumns, safetyChecks_opt = safetyChecks_opt diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index 9ca1a2ed5..29f9b3e1c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -6,8 +6,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.Databases.{FileBackup, PostgresDatabases, SqliteDatabases} import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.DualDatabases.runAsync -import fr.acinq.eclair.db.pg._ -import fr.acinq.eclair.db.sqlite._ +import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.router.Router @@ -23,25 +22,30 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} /** - * An implementation of [[Databases]] where there are two separate underlying db, one sqlite and one postgres. - * Sqlite is the main database, but we also replicate all calls to postgres. - * Calls to postgres are made asynchronously in a dedicated thread pool, so that it doesn't have any performance impact. + * An implementation of [[Databases]] where there are two separate underlying db, one primary and one secondary. + * All calls to primary are replicated asynchronously to secondary. + * Calls to secondary are made asynchronously in a dedicated thread pool, so that it doesn't have any performance impact. */ -case class DualDatabases(sqlite: SqliteDatabases, postgres: PostgresDatabases) extends Databases with FileBackup { +case class DualDatabases(primary: Databases, secondary: Databases) extends Databases with FileBackup { - override val network: NetworkDb = DualNetworkDb(sqlite.network, postgres.network) + override val network: NetworkDb = DualNetworkDb(primary.network, secondary.network) - override val audit: AuditDb = DualAuditDb(sqlite.audit, postgres.audit) + override val audit: AuditDb = DualAuditDb(primary.audit, secondary.audit) - override val channels: ChannelsDb = DualChannelsDb(sqlite.channels, postgres.channels) + override val channels: ChannelsDb = DualChannelsDb(primary.channels, secondary.channels) - override val peers: PeersDb = DualPeersDb(sqlite.peers, postgres.peers) + override val peers: PeersDb = DualPeersDb(primary.peers, secondary.peers) - override val payments: PaymentsDb = DualPaymentsDb(sqlite.payments, postgres.payments) + override val payments: PaymentsDb = DualPaymentsDb(primary.payments, secondary.payments) - override val pendingCommands: PendingCommandsDb = DualPendingCommandsDb(sqlite.pendingCommands, postgres.pendingCommands) + override val pendingCommands: PendingCommandsDb = DualPendingCommandsDb(primary.pendingCommands, secondary.pendingCommands) - override def backup(backupFile: File): Unit = sqlite.backup(backupFile) + /** if one of the database supports file backup, we use it */ + override def backup(backupFile: File): Unit = (primary, secondary) match { + case (f: FileBackup, _) => f.backup(backupFile) + case (_, f: FileBackup) => f.backup(backupFile) + case _ => () + } } object DualDatabases extends Logging { @@ -55,360 +59,369 @@ object DualDatabases extends Logging { throw t } } + + def getDatabases(dualDatabases: DualDatabases): (SqliteDatabases, PostgresDatabases) = + (dualDatabases.primary, dualDatabases.secondary) match { + case (sqliteDb: SqliteDatabases, postgresDb: PostgresDatabases) => + (sqliteDb, postgresDb) + case (postgresDb: PostgresDatabases, sqliteDb: SqliteDatabases) => + (sqliteDb, postgresDb) + case _ => throw new IllegalArgumentException("there must be one sqlite and one postgres in dual db mode") + } } -case class DualNetworkDb(sqlite: SqliteNetworkDb, postgres: PgNetworkDb) extends NetworkDb { +case class DualNetworkDb(primary: NetworkDb, secondary: NetworkDb) extends NetworkDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-network").build())) override def addNode(n: NodeAnnouncement): Unit = { - runAsync(postgres.addNode(n)) - sqlite.addNode(n) + runAsync(secondary.addNode(n)) + primary.addNode(n) } override def updateNode(n: NodeAnnouncement): Unit = { - runAsync(postgres.updateNode(n)) - sqlite.updateNode(n) + runAsync(secondary.updateNode(n)) + primary.updateNode(n) } override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = { - runAsync(postgres.getNode(nodeId)) - sqlite.getNode(nodeId) + runAsync(secondary.getNode(nodeId)) + primary.getNode(nodeId) } override def removeNode(nodeId: Crypto.PublicKey): Unit = { - runAsync(postgres.removeNode(nodeId)) - sqlite.removeNode(nodeId) + runAsync(secondary.removeNode(nodeId)) + primary.removeNode(nodeId) } override def listNodes(): Seq[NodeAnnouncement] = { - runAsync(postgres.listNodes()) - sqlite.listNodes() + runAsync(secondary.listNodes()) + primary.listNodes() } override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = { - runAsync(postgres.addChannel(c, txid, capacity)) - sqlite.addChannel(c, txid, capacity) + runAsync(secondary.addChannel(c, txid, capacity)) + primary.addChannel(c, txid, capacity) } override def updateChannel(u: ChannelUpdate): Unit = { - runAsync(postgres.updateChannel(u)) - sqlite.updateChannel(u) + runAsync(secondary.updateChannel(u)) + primary.updateChannel(u) } override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = { - runAsync(postgres.removeChannels(shortChannelIds)) - sqlite.removeChannels(shortChannelIds) + runAsync(secondary.removeChannels(shortChannelIds)) + primary.removeChannels(shortChannelIds) } override def listChannels(): SortedMap[ShortChannelId, Router.PublicChannel] = { - runAsync(postgres.listChannels()) - sqlite.listChannels() + runAsync(secondary.listChannels()) + primary.listChannels() } override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = { - runAsync(postgres.addToPruned(shortChannelIds)) - sqlite.addToPruned(shortChannelIds) + runAsync(secondary.addToPruned(shortChannelIds)) + primary.addToPruned(shortChannelIds) } override def removeFromPruned(shortChannelId: ShortChannelId): Unit = { - runAsync(postgres.removeFromPruned(shortChannelId)) - sqlite.removeFromPruned(shortChannelId) + runAsync(secondary.removeFromPruned(shortChannelId)) + primary.removeFromPruned(shortChannelId) } override def isPruned(shortChannelId: ShortChannelId): Boolean = { - runAsync(postgres.isPruned(shortChannelId)) - sqlite.isPruned(shortChannelId) + runAsync(secondary.isPruned(shortChannelId)) + primary.isPruned(shortChannelId) } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } } -case class DualAuditDb(sqlite: SqliteAuditDb, postgres: PgAuditDb) extends AuditDb { +case class DualAuditDb(primary: AuditDb, secondary: AuditDb) extends AuditDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-audit").build())) override def add(channelLifecycle: DbEventHandler.ChannelEvent): Unit = { - runAsync(postgres.add(channelLifecycle)) - sqlite.add(channelLifecycle) + runAsync(secondary.add(channelLifecycle)) + primary.add(channelLifecycle) } override def add(paymentSent: PaymentSent): Unit = { - runAsync(postgres.add(paymentSent)) - sqlite.add(paymentSent) + runAsync(secondary.add(paymentSent)) + primary.add(paymentSent) } override def add(paymentReceived: PaymentReceived): Unit = { - runAsync(postgres.add(paymentReceived)) - sqlite.add(paymentReceived) + runAsync(secondary.add(paymentReceived)) + primary.add(paymentReceived) } override def add(paymentRelayed: PaymentRelayed): Unit = { - runAsync(postgres.add(paymentRelayed)) - sqlite.add(paymentRelayed) + runAsync(secondary.add(paymentRelayed)) + primary.add(paymentRelayed) } override def add(txPublished: TransactionPublished): Unit = { - runAsync(postgres.add(txPublished)) - sqlite.add(txPublished) + runAsync(secondary.add(txPublished)) + primary.add(txPublished) } override def add(txConfirmed: TransactionConfirmed): Unit = { - runAsync(postgres.add(txConfirmed)) - sqlite.add(txConfirmed) + runAsync(secondary.add(txConfirmed)) + primary.add(txConfirmed) } override def add(channelErrorOccurred: ChannelErrorOccurred): Unit = { - runAsync(postgres.add(channelErrorOccurred)) - sqlite.add(channelErrorOccurred) + runAsync(secondary.add(channelErrorOccurred)) + primary.add(channelErrorOccurred) } override def addChannelUpdate(channelUpdateParametersChanged: ChannelUpdateParametersChanged): Unit = { - runAsync(postgres.addChannelUpdate(channelUpdateParametersChanged)) - sqlite.addChannelUpdate(channelUpdateParametersChanged) + runAsync(secondary.addChannelUpdate(channelUpdateParametersChanged)) + primary.addChannelUpdate(channelUpdateParametersChanged) } override def addPathFindingExperimentMetrics(metrics: PathFindingExperimentMetrics): Unit = { - runAsync(postgres.addPathFindingExperimentMetrics(metrics)) - sqlite.addPathFindingExperimentMetrics(metrics) + runAsync(secondary.addPathFindingExperimentMetrics(metrics)) + primary.addPathFindingExperimentMetrics(metrics) } override def listSent(from: TimestampMilli, to: TimestampMilli): Seq[PaymentSent] = { - runAsync(postgres.listSent(from, to)) - sqlite.listSent(from, to) + runAsync(secondary.listSent(from, to)) + primary.listSent(from, to) } override def listReceived(from: TimestampMilli, to: TimestampMilli): Seq[PaymentReceived] = { - runAsync(postgres.listReceived(from, to)) - sqlite.listReceived(from, to) + runAsync(secondary.listReceived(from, to)) + primary.listReceived(from, to) } override def listRelayed(from: TimestampMilli, to: TimestampMilli): Seq[PaymentRelayed] = { - runAsync(postgres.listRelayed(from, to)) - sqlite.listRelayed(from, to) + runAsync(secondary.listRelayed(from, to)) + primary.listRelayed(from, to) } override def listNetworkFees(from: TimestampMilli, to: TimestampMilli): Seq[AuditDb.NetworkFee] = { - runAsync(postgres.listNetworkFees(from, to)) - sqlite.listNetworkFees(from, to) + runAsync(secondary.listNetworkFees(from, to)) + primary.listNetworkFees(from, to) } override def stats(from: TimestampMilli, to: TimestampMilli): Seq[AuditDb.Stats] = { - runAsync(postgres.stats(from, to)) - sqlite.stats(from, to) + runAsync(secondary.stats(from, to)) + primary.stats(from, to) } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } } -case class DualChannelsDb(sqlite: SqliteChannelsDb, postgres: PgChannelsDb) extends ChannelsDb { +case class DualChannelsDb(primary: ChannelsDb, secondary: ChannelsDb) extends ChannelsDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-channels").build())) override def addOrUpdateChannel(state: HasCommitments): Unit = { - runAsync(postgres.addOrUpdateChannel(state)) - sqlite.addOrUpdateChannel(state) + runAsync(secondary.addOrUpdateChannel(state)) + primary.addOrUpdateChannel(state) } override def getChannel(channelId: ByteVector32): Option[HasCommitments] = { - runAsync(postgres.getChannel(channelId)) - sqlite.getChannel(channelId) + runAsync(secondary.getChannel(channelId)) + primary.getChannel(channelId) } override def updateChannelMeta(channelId: ByteVector32, event: ChannelEvent.EventType): Unit = { - runAsync(postgres.updateChannelMeta(channelId, event)) - sqlite.updateChannelMeta(channelId, event) + runAsync(secondary.updateChannelMeta(channelId, event)) + primary.updateChannelMeta(channelId, event) } override def removeChannel(channelId: ByteVector32): Unit = { - runAsync(postgres.removeChannel(channelId)) - sqlite.removeChannel(channelId) + runAsync(secondary.removeChannel(channelId)) + primary.removeChannel(channelId) } override def listLocalChannels(): Seq[HasCommitments] = { - runAsync(postgres.listLocalChannels()) - sqlite.listLocalChannels() + runAsync(secondary.listLocalChannels()) + primary.listLocalChannels() } override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = { - runAsync(postgres.addHtlcInfo(channelId, commitmentNumber, paymentHash, cltvExpiry)) - sqlite.addHtlcInfo(channelId, commitmentNumber, paymentHash, cltvExpiry) + runAsync(secondary.addHtlcInfo(channelId, commitmentNumber, paymentHash, cltvExpiry)) + primary.addHtlcInfo(channelId, commitmentNumber, paymentHash, cltvExpiry) } override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = { - runAsync(postgres.listHtlcInfos(channelId, commitmentNumber)) - sqlite.listHtlcInfos(channelId, commitmentNumber) + runAsync(secondary.listHtlcInfos(channelId, commitmentNumber)) + primary.listHtlcInfos(channelId, commitmentNumber) } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } } -case class DualPeersDb(sqlite: SqlitePeersDb, postgres: PgPeersDb) extends PeersDb { +case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-peers").build())) override def addOrUpdatePeer(nodeId: Crypto.PublicKey, address: NodeAddress): Unit = { - runAsync(postgres.addOrUpdatePeer(nodeId, address)) - sqlite.addOrUpdatePeer(nodeId, address) + runAsync(secondary.addOrUpdatePeer(nodeId, address)) + primary.addOrUpdatePeer(nodeId, address) } override def removePeer(nodeId: Crypto.PublicKey): Unit = { - runAsync(postgres.removePeer(nodeId)) - sqlite.removePeer(nodeId) + runAsync(secondary.removePeer(nodeId)) + primary.removePeer(nodeId) } override def getPeer(nodeId: Crypto.PublicKey): Option[NodeAddress] = { - runAsync(postgres.getPeer(nodeId)) - sqlite.getPeer(nodeId) + runAsync(secondary.getPeer(nodeId)) + primary.getPeer(nodeId) } override def listPeers(): Map[Crypto.PublicKey, NodeAddress] = { - runAsync(postgres.listPeers()) - sqlite.listPeers() + runAsync(secondary.listPeers()) + primary.listPeers() } override def addOrUpdateRelayFees(nodeId: Crypto.PublicKey, fees: RelayFees): Unit = { - runAsync(postgres.addOrUpdateRelayFees(nodeId, fees)) - sqlite.addOrUpdateRelayFees(nodeId, fees) + runAsync(secondary.addOrUpdateRelayFees(nodeId, fees)) + primary.addOrUpdateRelayFees(nodeId, fees) } override def getRelayFees(nodeId: Crypto.PublicKey): Option[RelayFees] = { - runAsync(postgres.getRelayFees(nodeId)) - sqlite.getRelayFees(nodeId) + runAsync(secondary.getRelayFees(nodeId)) + primary.getRelayFees(nodeId) } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } } -case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) extends PaymentsDb { +case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-payments").build())) override def listPaymentsOverview(limit: Int): Seq[PlainPayment] = { - runAsync(postgres.listPaymentsOverview(limit)) - sqlite.listPaymentsOverview(limit) + runAsync(secondary.listPaymentsOverview(limit)) + primary.listPaymentsOverview(limit) } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String): Unit = { - runAsync(postgres.addIncomingPayment(pr, preimage, paymentType)) - sqlite.addIncomingPayment(pr, preimage, paymentType) + runAsync(secondary.addIncomingPayment(pr, preimage, paymentType)) + primary.addIncomingPayment(pr, preimage, paymentType) } override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = { - runAsync(postgres.receiveIncomingPayment(paymentHash, amount, receivedAt)) - sqlite.receiveIncomingPayment(paymentHash, amount, receivedAt) + runAsync(secondary.receiveIncomingPayment(paymentHash, amount, receivedAt)) + primary.receiveIncomingPayment(paymentHash, amount, receivedAt) } override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = { - runAsync(postgres.getIncomingPayment(paymentHash)) - sqlite.getIncomingPayment(paymentHash) + runAsync(secondary.getIncomingPayment(paymentHash)) + primary.getIncomingPayment(paymentHash) } override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = { - runAsync(postgres.removeIncomingPayment(paymentHash)) - sqlite.removeIncomingPayment(paymentHash) + runAsync(secondary.removeIncomingPayment(paymentHash)) + primary.removeIncomingPayment(paymentHash) } override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { - runAsync(postgres.listIncomingPayments(from, to)) - sqlite.listIncomingPayments(from, to) + runAsync(secondary.listIncomingPayments(from, to)) + primary.listIncomingPayments(from, to) } override def listPendingIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { - runAsync(postgres.listPendingIncomingPayments(from, to)) - sqlite.listPendingIncomingPayments(from, to) + runAsync(secondary.listPendingIncomingPayments(from, to)) + primary.listPendingIncomingPayments(from, to) } override def listExpiredIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { - runAsync(postgres.listExpiredIncomingPayments(from, to)) - sqlite.listExpiredIncomingPayments(from, to) + runAsync(secondary.listExpiredIncomingPayments(from, to)) + primary.listExpiredIncomingPayments(from, to) } override def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { - runAsync(postgres.listReceivedIncomingPayments(from, to)) - sqlite.listReceivedIncomingPayments(from, to) + runAsync(secondary.listReceivedIncomingPayments(from, to)) + primary.listReceivedIncomingPayments(from, to) } override def addOutgoingPayment(outgoingPayment: OutgoingPayment): Unit = { - runAsync(postgres.addOutgoingPayment(outgoingPayment)) - sqlite.addOutgoingPayment(outgoingPayment) + runAsync(secondary.addOutgoingPayment(outgoingPayment)) + primary.addOutgoingPayment(outgoingPayment) } override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = { - runAsync(postgres.updateOutgoingPayment(paymentResult)) - sqlite.updateOutgoingPayment(paymentResult) + runAsync(secondary.updateOutgoingPayment(paymentResult)) + primary.updateOutgoingPayment(paymentResult) } override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = { - runAsync(postgres.updateOutgoingPayment(paymentResult)) - sqlite.updateOutgoingPayment(paymentResult) + runAsync(secondary.updateOutgoingPayment(paymentResult)) + primary.updateOutgoingPayment(paymentResult) } override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = { - runAsync(postgres.getOutgoingPayment(id)) - sqlite.getOutgoingPayment(id) + runAsync(secondary.getOutgoingPayment(id)) + primary.getOutgoingPayment(id) } override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] = { - runAsync(postgres.listOutgoingPayments(parentId)) - sqlite.listOutgoingPayments(parentId) + runAsync(secondary.listOutgoingPayments(parentId)) + primary.listOutgoingPayments(parentId) } override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = { - runAsync(postgres.listOutgoingPayments(paymentHash)) - sqlite.listOutgoingPayments(paymentHash) + runAsync(secondary.listOutgoingPayments(paymentHash)) + primary.listOutgoingPayments(paymentHash) } override def listOutgoingPayments(from: TimestampMilli, to: TimestampMilli): Seq[OutgoingPayment] = { - runAsync(postgres.listOutgoingPayments(from, to)) - sqlite.listOutgoingPayments(from, to) + runAsync(secondary.listOutgoingPayments(from, to)) + primary.listOutgoingPayments(from, to) } } -case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb { +case class DualPendingCommandsDb(primary: PendingCommandsDb, secondary: PendingCommandsDb) extends PendingCommandsDb { private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-pending-commands").build())) override def addSettlementCommand(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = { - runAsync(postgres.addSettlementCommand(channelId, cmd)) - sqlite.addSettlementCommand(channelId, cmd) + runAsync(secondary.addSettlementCommand(channelId, cmd)) + primary.addSettlementCommand(channelId, cmd) } override def removeSettlementCommand(channelId: ByteVector32, htlcId: Long): Unit = { - runAsync(postgres.removeSettlementCommand(channelId, htlcId)) - sqlite.removeSettlementCommand(channelId, htlcId) + runAsync(secondary.removeSettlementCommand(channelId, htlcId)) + primary.removeSettlementCommand(channelId, htlcId) } override def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] = { - runAsync(postgres.listSettlementCommands(channelId)) - sqlite.listSettlementCommands(channelId) + runAsync(secondary.listSettlementCommands(channelId)) + primary.listSettlementCommands(channelId) } override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = { - runAsync(postgres.listSettlementCommands()) - sqlite.listSettlementCommands() + runAsync(secondary.listSettlementCommands()) + primary.listSettlementCommands() } override def close(): Unit = { - runAsync(postgres.close()) - sqlite.close() + runAsync(secondary.close()) + primary.close() } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala index dc0838168..c6771c856 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala @@ -153,6 +153,11 @@ trait JdbcUtils { ByteVector.fromValidHex(s) } + def getByteVectorFromHexNullable(columnLabel: String): Option[ByteVector] = { + val s = rs.getString(columnLabel) + if (rs.wasNull()) None else Some(ByteVector.fromValidHex(s)) + } + def getByteVector32FromHex(columnLabel: String): ByteVector32 = { val s = rs.getString(columnLabel) ByteVector32(ByteVector.fromValidHex(s)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareAuditDb.scala new file mode 100644 index 000000000..5d532e70a --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareAuditDb.scala @@ -0,0 +1,280 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.migration.CompareDb._ +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object CompareAuditDb { + + private def compareSentTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "sent" + val table2 = "audit.sent" + + def hash1(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + long(rs, "fees_msat") ++ + long(rs, "recipient_amount_msat") ++ + string(rs, "payment_id") ++ + string(rs, "parent_payment_id") ++ + bytes(rs, "payment_hash") ++ + bytes(rs, "payment_preimage") ++ + bytes(rs, "recipient_node_id") ++ + bytes(rs, "to_channel_id") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + long(rs, "fees_msat") ++ + long(rs, "recipient_amount_msat") ++ + string(rs, "payment_id") ++ + string(rs, "parent_payment_id") ++ + hex(rs, "payment_hash") ++ + hex(rs, "payment_preimage") ++ + hex(rs, "recipient_node_id") ++ + hex(rs, "to_channel_id") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareReceivedTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "received" + val table2 = "audit.received" + + def hash1(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + bytes(rs, "payment_hash") ++ + bytes(rs, "from_channel_id") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + hex(rs, "payment_hash") ++ + hex(rs, "from_channel_id") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareRelayedTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "relayed" + val table2 = "audit.relayed" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "payment_hash") ++ + long(rs, "amount_msat") ++ + bytes(rs, "channel_id") ++ + string(rs, "direction") ++ + string(rs, "relay_type") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "payment_hash") ++ + long(rs, "amount_msat") ++ + hex(rs, "channel_id") ++ + string(rs, "direction") ++ + string(rs, "relay_type") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareRelayedTrampolineTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "relayed_trampoline" + val table2 = "audit.relayed_trampoline" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "payment_hash") ++ + long(rs, "amount_msat") ++ + bytes(rs, "next_node_id") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "payment_hash") ++ + long(rs, "amount_msat") ++ + hex(rs, "next_node_id") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareTransactionsPublishedTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "transactions_published" + val table2 = "audit.transactions_published" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "tx_id") ++ + bytes(rs, "channel_id") ++ + bytes(rs, "node_id") ++ + long(rs, "mining_fee_sat") ++ + string(rs, "tx_type") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "tx_id") ++ + hex(rs, "channel_id") ++ + hex(rs, "node_id") ++ + long(rs, "mining_fee_sat") ++ + string(rs, "tx_type") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareTransactionsConfirmedTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "transactions_confirmed" + val table2 = "audit.transactions_confirmed" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "tx_id") ++ + bytes(rs, "channel_id") ++ + bytes(rs, "node_id") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "tx_id") ++ + hex(rs, "channel_id") ++ + hex(rs, "node_id") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareChannelEventsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "channel_events" + val table2 = "audit.channel_events" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "channel_id") ++ + bytes(rs, "node_id") ++ + long(rs, "capacity_sat") ++ + bool(rs, "is_funder") ++ + bool(rs, "is_private") ++ + string(rs, "event") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "channel_id") ++ + hex(rs, "node_id") ++ + long(rs, "capacity_sat") ++ + bool(rs, "is_funder") ++ + bool(rs, "is_private") ++ + string(rs, "event") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareChannelErrorsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "channel_errors WHERE error_name <> 'CannotAffordFees'" + val table2 = "audit.channel_errors" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "channel_id") ++ + bytes(rs, "node_id") ++ + string(rs, "error_name") ++ + string(rs, "error_message") ++ + bool(rs, "is_fatal") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "channel_id") ++ + hex(rs, "node_id") ++ + string(rs, "error_name") ++ + string(rs, "error_message") ++ + bool(rs, "is_fatal") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareChannelUpdatesTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "channel_updates" + val table2 = "audit.channel_updates" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "channel_id") ++ + bytes(rs, "node_id") ++ + long(rs, "fee_base_msat") ++ + long(rs, "fee_proportional_millionths") ++ + long(rs, "cltv_expiry_delta") ++ + long(rs, "htlc_minimum_msat") ++ + long(rs, "htlc_maximum_msat") ++ + longts(rs, "timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "channel_id") ++ + hex(rs, "node_id") ++ + long(rs, "fee_base_msat") ++ + long(rs, "fee_proportional_millionths") ++ + long(rs, "cltv_expiry_delta") ++ + long(rs, "htlc_minimum_msat") ++ + long(rs, "htlc_maximum_msat") ++ + ts(rs, "timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def comparePathFindingMetricsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "path_finding_metrics" + val table2 = "audit.path_finding_metrics" + + def hash1(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + long(rs, "fees_msat") ++ + string(rs, "status") ++ + long(rs, "duration_ms") ++ + longts(rs, "timestamp") ++ + bool(rs, "is_mpp") ++ + string(rs, "experiment_name") ++ + bytes(rs, "recipient_node_id") + + } + + def hash2(rs: ResultSet): ByteVector = { + long(rs, "amount_msat") ++ + long(rs, "fees_msat") ++ + string(rs, "status") ++ + long(rs, "duration_ms") ++ + ts(rs, "timestamp") ++ + bool(rs, "is_mpp") ++ + string(rs, "experiment_name") ++ + hex(rs, "recipient_node_id") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + compareSentTable(conn1, conn2) && + compareReceivedTable(conn1, conn2) && + compareRelayedTable(conn1, conn2) && + compareRelayedTrampolineTable(conn1, conn2) && + compareTransactionsPublishedTable(conn1, conn2) && + compareTransactionsConfirmedTable(conn1, conn2) && + compareChannelEventsTable(conn1, conn2) && + compareChannelErrorsTable(conn1, conn2) && + compareChannelUpdatesTable(conn1, conn2) && + comparePathFindingMetricsTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareChannelsDb.scala new file mode 100644 index 000000000..7f166a4d3 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareChannelsDb.scala @@ -0,0 +1,80 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.BlockHeight +import fr.acinq.eclair.channel.{DATA_CLOSING, DATA_WAIT_FOR_FUNDING_CONFIRMED} +import fr.acinq.eclair.db.migration.CompareDb._ +import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object CompareChannelsDb { + + private def compareChannelsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "local_channels" + val table2 = "local.channels" + + def hash1(rs: ResultSet): ByteVector = { + val data = ByteVector(rs.getBytes("data")) + val data_modified = stateDataCodec.decode(data.bits).require.value match { + case c: DATA_WAIT_FOR_FUNDING_CONFIRMED => stateDataCodec.encode(c.copy(waitingSince = BlockHeight(0))).require.toByteVector + case c: DATA_CLOSING => stateDataCodec.encode(c.copy(waitingSince = BlockHeight(0))).require.toByteVector + case _ => data + } + bytes(rs, "channel_id") ++ + data_modified ++ + bool(rs, "is_closed") ++ + longtsnull(rs, "created_timestamp") ++ + longtsnull(rs, "last_payment_sent_timestamp") ++ + longtsnull(rs, "last_payment_received_timestamp") ++ + longtsnull(rs, "last_connected_timestamp") ++ + longtsnull(rs, "closed_timestamp") + } + + def hash2(rs: ResultSet): ByteVector = { + val data = ByteVector(rs.getBytes("data")) + val data_modified = stateDataCodec.decode(data.bits).require.value match { + case c: DATA_WAIT_FOR_FUNDING_CONFIRMED => stateDataCodec.encode(c.copy(waitingSince = BlockHeight(0))).require.toByteVector + case c: DATA_CLOSING => stateDataCodec.encode(c.copy(waitingSince = BlockHeight(0))).require.toByteVector + case _ => data + } + hex(rs, "channel_id") ++ + data_modified ++ + bool(rs, "is_closed") ++ + tsnull(rs, "created_timestamp") ++ + tsnull(rs, "last_payment_sent_timestamp") ++ + tsnull(rs, "last_payment_received_timestamp") ++ + tsnull(rs, "last_connected_timestamp") ++ + tsnull(rs, "closed_timestamp") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareHtlcInfosTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "htlc_infos" + val table2 = "local.htlc_infos" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "channel_id") ++ + long(rs, "commitment_number") ++ + bytes(rs, "payment_hash") ++ + long(rs, "cltv_expiry") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "channel_id") ++ + long(rs, "commitment_number") ++ + hex(rs, "payment_hash") ++ + long(rs, "cltv_expiry") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + compareChannelsTable(conn1, conn2) && + compareHtlcInfosTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareDb.scala new file mode 100644 index 000000000..d5efd6047 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareDb.scala @@ -0,0 +1,81 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.Databases.{PostgresDatabases, SqliteDatabases} +import fr.acinq.eclair.db.DualDatabases +import fr.acinq.eclair.db.pg.PgUtils +import grizzled.slf4j.Logging +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object CompareDb extends Logging { + + def compareTable(conn1: Connection, + conn2: Connection, + table1: String, + table2: String, + hash1: ResultSet => ByteVector, + hash2: ResultSet => ByteVector): Boolean = { + val rs1 = conn1.prepareStatement(s"SELECT * FROM $table1").executeQuery() + val rs2 = conn2.prepareStatement(s"SELECT * FROM $table2").executeQuery() + + var hashes1 = List.empty[ByteVector] + while (rs1.next()) { + hashes1 = hash1(rs1) +: hashes1 + } + + var hashes2 = List.empty[ByteVector] + while (rs2.next()) { + hashes2 = hash2(rs2) +: hashes2 + } + + val res = hashes1.sorted == hashes2.sorted + + if (res) { + logger.info(s"tables $table1/$table2 are identical") + } else { + val diff1 = hashes1 diff hashes2 + val diff2 = hashes2 diff hashes1 + logger.warn(s"tables $table1/$table2 are different diff1=${diff1.take(3).map(_.toHex.take(128))} diff2=${diff2.take(3).map(_.toHex.take(128))}") + } + + res + } + + // @formatter:off + import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ + def bytes(rs: ResultSet, columnName: String): ByteVector = rs.getByteVector(columnName) + def bytesnull(rs: ResultSet, columnName: String): ByteVector = rs.getByteVectorNullable(columnName).getOrElse(ByteVector.fromValidHex("deadbeef")) + def hex(rs: ResultSet, columnName: String): ByteVector = rs.getByteVectorFromHex(columnName) + def hexnull(rs: ResultSet, columnName: String): ByteVector = rs.getByteVectorFromHexNullable(columnName).getOrElse(ByteVector.fromValidHex("deadbeef")) + def string(rs: ResultSet, columnName: String): ByteVector = ByteVector(rs.getString(columnName).getBytes) + def stringnull(rs: ResultSet, columnName: String): ByteVector = ByteVector(rs.getStringNullable(columnName).getOrElse("").getBytes) + def bool(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromByte(if (rs.getBoolean(columnName)) 1 else 0) + def long(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong(rs.getLong(columnName)) + def longnull(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong(rs.getLongNullable(columnName).getOrElse(42)) + def longts(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong((rs.getLong(columnName).toDouble / 1_000_000).round) + def longtsnull(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong(rs.getLongNullable(columnName).map(l => (l.toDouble/1_000_000).round).getOrElse(42)) + def int(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromInt(rs.getInt(columnName)) + def ts(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong((rs.getTimestamp(columnName).getTime.toDouble / 1_000_000).round) + def tsnull(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong(rs.getTimestampNullable(columnName).map(t => (t.getTime.toDouble / 1_000_000).round).getOrElse(42)) + def tssec(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong((rs.getTimestamp(columnName).toInstant.getEpochSecond.toDouble / 1_000_000).round) + def tssecnull(rs: ResultSet, columnName: String): ByteVector = ByteVector.fromLong(rs.getTimestampNullable(columnName).map(t => (t.toInstant.getEpochSecond.toDouble / 1_000_000).round).getOrElse(42)) + // @formatter:on + + def compareAll(dualDatabases: DualDatabases): Unit = { + logger.info("comparing all tables...") + val (sqliteDb: SqliteDatabases, postgresDb: PostgresDatabases) = DualDatabases.getDatabases(dualDatabases) + PgUtils.inTransaction { postgres => + val result = List( + CompareChannelsDb.compareAllTables(sqliteDb.channels.sqlite, postgres), + ComparePendingCommandsDb.compareAllTables(sqliteDb.pendingCommands.sqlite, postgres), + ComparePeersDb.compareAllTables(sqliteDb.peers.sqlite, postgres), + ComparePaymentsDb.compareAllTables(sqliteDb.payments.sqlite, postgres), + CompareNetworkDb.compareAllTables(sqliteDb.network.sqlite, postgres), + CompareAuditDb.compareAllTables(sqliteDb.audit.sqlite, postgres) + ).forall(_ == true) + logger.info(s"comparison complete identical=$result") + }(postgresDb.dataSource) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareNetworkDb.scala new file mode 100644 index 000000000..8b91bb31d --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/CompareNetworkDb.scala @@ -0,0 +1,73 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.migration.CompareDb._ +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object CompareNetworkDb { + + private def compareNodesTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "nodes" + val table2 = "network.nodes" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "node_id") ++ + bytes(rs, "data") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "node_id") ++ + bytes(rs, "data") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareChannelsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "channels" + val table2 = "network.public_channels" + + def hash1(rs: ResultSet): ByteVector = { + long(rs, "short_channel_id") ++ + string(rs, "txid") ++ + bytes(rs, "channel_announcement") ++ + long(rs, "capacity_sat") ++ + bytesnull(rs, "channel_update_1") ++ + bytesnull(rs, "channel_update_2") + } + + def hash2(rs: ResultSet): ByteVector = { + long(rs, "short_channel_id") ++ + string(rs, "txid") ++ + bytes(rs, "channel_announcement") ++ + long(rs, "capacity_sat") ++ + bytesnull(rs, "channel_update_1") ++ + bytesnull(rs, "channel_update_2") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def comparePrunedTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "pruned" + val table2 = "network.pruned_channels" + + def hash1(rs: ResultSet): ByteVector = { + long(rs, "short_channel_id") + } + + def hash2(rs: ResultSet): ByteVector = { + long(rs, "short_channel_id") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + compareNodesTable(conn1, conn2) && + compareChannelsTable(conn1, conn2) && + comparePrunedTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePaymentsDb.scala new file mode 100644 index 000000000..17265931c --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePaymentsDb.scala @@ -0,0 +1,87 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.migration.CompareDb._ +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object ComparePaymentsDb { + + private def compareReceivedPaymentsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "received_payments" + val table2 = "payments.received" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "payment_hash") ++ + string(rs, "payment_type") ++ + bytes(rs, "payment_preimage") ++ + string(rs, "payment_request") ++ + longnull(rs, "received_msat") ++ + longts(rs, "created_at") ++ + longts(rs, "expire_at") ++ + longtsnull(rs, "received_at") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "payment_hash") ++ + string(rs, "payment_type") ++ + hex(rs, "payment_preimage") ++ + string(rs, "payment_request") ++ + longnull(rs, "received_msat") ++ + ts(rs, "created_at") ++ + ts(rs, "expire_at") ++ + tsnull(rs, "received_at") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareSentPaymentsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "sent_payments" + val table2 = "payments.sent" + + def hash1(rs: ResultSet): ByteVector = { + string(rs, "id") ++ + string(rs, "parent_id") ++ + stringnull(rs, "external_id") ++ + bytes(rs, "payment_hash") ++ + bytesnull(rs, "payment_preimage") ++ + string(rs, "payment_type") ++ + long(rs, "amount_msat") ++ + longnull(rs, "fees_msat") ++ + long(rs, "recipient_amount_msat") ++ + bytes(rs, "recipient_node_id") ++ + stringnull(rs, "payment_request") ++ + bytesnull(rs, "payment_route") ++ + bytesnull(rs, "failures") ++ + longts(rs, "created_at") ++ + longtsnull(rs, "completed_at") + } + + def hash2(rs: ResultSet): ByteVector = { + string(rs, "id") ++ + string(rs, "parent_id") ++ + stringnull(rs, "external_id") ++ + hex(rs, "payment_hash") ++ + hexnull(rs, "payment_preimage") ++ + string(rs, "payment_type") ++ + long(rs, "amount_msat") ++ + longnull(rs, "fees_msat") ++ + long(rs, "recipient_amount_msat") ++ + hex(rs, "recipient_node_id") ++ + stringnull(rs, "payment_request") ++ + bytesnull(rs, "payment_route") ++ + bytesnull(rs, "failures") ++ + ts(rs, "created_at") ++ + tsnull(rs, "completed_at") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + compareReceivedPaymentsTable(conn1, conn2) && + compareSentPaymentsTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePeersDb.scala new file mode 100644 index 000000000..c0c034fa9 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePeersDb.scala @@ -0,0 +1,51 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.migration.CompareDb._ +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object ComparePeersDb { + + private def comparePeersTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "peers" + val table2 = "local.peers" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "node_id") ++ + bytes(rs, "data") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "node_id") ++ + bytes(rs, "data") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + private def compareRelayFeesTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "relay_fees" + val table2 = "local.relay_fees" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "node_id") ++ + long(rs, "fee_base_msat") ++ + long(rs, "fee_proportional_millionths") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "node_id") ++ + long(rs, "fee_base_msat") ++ + long(rs, "fee_proportional_millionths") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + comparePeersTable(conn1, conn2) && + compareRelayFeesTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePendingCommandsDb.scala new file mode 100644 index 000000000..eb099ceea --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/ComparePendingCommandsDb.scala @@ -0,0 +1,33 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.migration.CompareDb._ +import scodec.bits.ByteVector + +import java.sql.{Connection, ResultSet} + +object ComparePendingCommandsDb { + + private def comparePendingSettlementCommandsTable(conn1: Connection, conn2: Connection): Boolean = { + val table1 = "pending_settlement_commands" + val table2 = "local.pending_settlement_commands" + + def hash1(rs: ResultSet): ByteVector = { + bytes(rs, "channel_id") ++ + long(rs, "htlc_id") ++ + bytes(rs, "data") + } + + def hash2(rs: ResultSet): ByteVector = { + hex(rs, "channel_id") ++ + long(rs, "htlc_id") ++ + bytes(rs, "data") + } + + compareTable(conn1, conn2, table1, table2, hash1, hash2) + } + + def compareAllTables(conn1: Connection, conn2: Connection): Boolean = { + comparePendingSettlementCommandsTable(conn1, conn2) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateAuditDb.scala new file mode 100644 index 000000000..28705cd27 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateAuditDb.scala @@ -0,0 +1,188 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} + +import java.sql.{Connection, PreparedStatement, ResultSet, Timestamp} +import java.time.Instant + +object MigrateAuditDb { + + private def migrateSentTable(source: Connection, destination: Connection): Int = { + val sourceTable = "sent" + val insertSql = "INSERT INTO audit.sent (amount_msat, fees_msat, recipient_amount_msat, payment_id, parent_payment_id, payment_hash, payment_preimage, recipient_node_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setLong(1, rs.getLong("amount_msat")) + insertStatement.setLong(2, rs.getLong("fees_msat")) + insertStatement.setLong(3, rs.getLong("recipient_amount_msat")) + insertStatement.setString(4, rs.getString("payment_id")) + insertStatement.setString(5, rs.getString("parent_payment_id")) + insertStatement.setString(6, rs.getByteVector32("payment_hash").toHex) + insertStatement.setString(7, rs.getByteVector32("payment_preimage").toHex) + insertStatement.setString(8, rs.getByteVector("recipient_node_id").toHex) + insertStatement.setString(9, rs.getByteVector32("to_channel_id").toHex) + insertStatement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateReceivedTable(source: Connection, destination: Connection): Int = { + val sourceTable = "received" + val insertSql = "INSERT INTO audit.received (amount_msat, payment_hash, from_channel_id, timestamp) VALUES (?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setLong(1, rs.getLong("amount_msat")) + insertStatement.setString(2, rs.getByteVector32("payment_hash").toHex) + insertStatement.setString(3, rs.getByteVector32("from_channel_id").toHex) + insertStatement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateRelayedTable(source: Connection, destination: Connection): Int = { + val sourceTable = "relayed" + val insertSql = "INSERT INTO audit.relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("payment_hash").toHex) + insertStatement.setLong(2, rs.getLong("amount_msat")) + insertStatement.setString(3, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(4, rs.getString("direction")) + insertStatement.setString(5, rs.getString("relay_type")) + insertStatement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateRelayedTrampolineTable(source: Connection, destination: Connection): Int = { + val sourceTable = "relayed_trampoline" + val insertSql = "INSERT INTO audit.relayed_trampoline (payment_hash, amount_msat, next_node_id, timestamp) VALUES (?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("payment_hash").toHex) + insertStatement.setLong(2, rs.getLong("amount_msat")) + insertStatement.setString(3, rs.getByteVector("next_node_id").toHex) + insertStatement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateTransactionsPublishedTable(source: Connection, destination: Connection): Int = { + val sourceTable = "transactions_published" + val insertSql = "INSERT INTO audit.transactions_published (tx_id, channel_id, node_id, mining_fee_sat, tx_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("tx_id").toHex) + insertStatement.setString(2, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(3, rs.getByteVector("node_id").toHex) + insertStatement.setLong(4, rs.getLong("mining_fee_sat")) + insertStatement.setString(5, rs.getString("tx_type")) + insertStatement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateTransactionsConfirmedTable(source: Connection, destination: Connection): Int = { + val sourceTable = "transactions_confirmed" + val insertSql = "INSERT INTO audit.transactions_confirmed (tx_id, channel_id, node_id, timestamp) VALUES (?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("tx_id").toHex) + insertStatement.setString(2, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(3, rs.getByteVector("node_id").toHex) + insertStatement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateChannelEventsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "channel_events" + val insertSql = "INSERT INTO audit.channel_events (channel_id, node_id, capacity_sat, is_funder, is_private, event, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(2, rs.getByteVector("node_id").toHex) + insertStatement.setLong(3, rs.getLong("capacity_sat")) + insertStatement.setBoolean(4, rs.getBoolean("is_funder")) + insertStatement.setBoolean(5, rs.getBoolean("is_private")) + insertStatement.setString(6, rs.getString("event")) + insertStatement.setTimestamp(7, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateChannelErrorsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "channel_errors WHERE error_name <> 'CannotAffordFees'" + val insertSql = "INSERT INTO audit.channel_errors (channel_id, node_id, error_name, error_message, is_fatal, timestamp) VALUES (?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(2, rs.getByteVector("node_id").toHex) + insertStatement.setString(3, rs.getString("error_name")) + insertStatement.setString(4, rs.getString("error_message")) + insertStatement.setBoolean(5, rs.getBoolean("is_fatal")) + insertStatement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateChannelUpdatesTable(source: Connection, destination: Connection): Int = { + val sourceTable = "channel_updates" + val insertSql = "INSERT INTO audit.channel_updates (channel_id, node_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("channel_id").toHex) + insertStatement.setString(2, rs.getByteVector("node_id").toHex) + insertStatement.setLong(3, rs.getLong("fee_base_msat")) + insertStatement.setLong(4, rs.getLong("fee_proportional_millionths")) + insertStatement.setLong(5, rs.getLong("cltv_expiry_delta")) + insertStatement.setLong(6, rs.getLong("htlc_minimum_msat")) + insertStatement.setLong(7, rs.getLong("htlc_maximum_msat")) + insertStatement.setTimestamp(8, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migratePathFindingMetricsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "path_finding_metrics" + val insertSql = "INSERT INTO audit.path_finding_metrics (amount_msat, fees_msat, status, duration_ms, timestamp, is_mpp, experiment_name, recipient_node_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setLong(1, rs.getLong("amount_msat")) + insertStatement.setLong(2, rs.getLong("fees_msat")) + insertStatement.setString(3, rs.getString("status")) + insertStatement.setLong(4, rs.getLong("duration_ms")) + insertStatement.setTimestamp(5, Timestamp.from(Instant.ofEpochMilli(rs.getLong("timestamp")))) + insertStatement.setBoolean(6, rs.getBoolean("is_mpp")) + insertStatement.setString(7, rs.getString("experiment_name")) + insertStatement.setString(8, rs.getByteVector("recipient_node_id").toHex) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "audit", 8, 10) + migrateSentTable(source, destination) + migrateReceivedTable(source, destination) + migrateRelayedTable(source, destination) + migrateRelayedTrampolineTable(source, destination) + migrateTransactionsPublishedTable(source, destination) + migrateTransactionsConfirmedTable(source, destination) + migrateChannelEventsTable(source, destination) + migrateChannelErrorsTable(source, destination) + migrateChannelUpdatesTable(source, destination) + migratePathFindingMetricsTable(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateChannelsDb.scala new file mode 100644 index 000000000..0758b1efe --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateChannelsDb.scala @@ -0,0 +1,56 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} +import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec +import scodec.bits.BitVector + +import java.sql.{Connection, PreparedStatement, ResultSet, Timestamp} +import java.time.Instant + +object MigrateChannelsDb { + + private def migrateChannelsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "local_channels" + val insertSql = "INSERT INTO local.channels (channel_id, data, json, is_closed, created_timestamp, last_payment_sent_timestamp, last_payment_received_timestamp, last_connected_timestamp, closed_timestamp) VALUES (?, ?, ?::JSONB, ?, ?, ?, ?, ?, ?)" + + import fr.acinq.eclair.json.JsonSerializers._ + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("channel_id").toHex) + insertStatement.setBytes(2, rs.getBytes("data")) + val state = stateDataCodec.decode(BitVector(rs.getBytes("data"))).require.value + val json = serialization.writePretty(state) + insertStatement.setString(3, json) + insertStatement.setBoolean(4, rs.getBoolean("is_closed")) + insertStatement.setTimestamp(5, rs.getLongNullable("created_timestamp").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + insertStatement.setTimestamp(6, rs.getLongNullable("last_payment_sent_timestamp").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + insertStatement.setTimestamp(7, rs.getLongNullable("last_payment_received_timestamp").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + insertStatement.setTimestamp(8, rs.getLongNullable("last_connected_timestamp").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + insertStatement.setTimestamp(9, rs.getLongNullable("closed_timestamp").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateHtlcInfos(source: Connection, destination: Connection): Int = { + val sourceTable = "htlc_infos" + val insertSql = "INSERT INTO local.htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector32("channel_id").toHex) + insertStatement.setLong(2, rs.getLong("commitment_number")) + insertStatement.setString(3, rs.getByteVector32("payment_hash").toHex) + insertStatement.setLong(4, rs.getLong("cltv_expiry")) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "channels", 4, 7) + migrateChannelsTable(source, destination) + migrateHtlcInfos(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateDb.scala new file mode 100644 index 000000000..331e62b1d --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateDb.scala @@ -0,0 +1,53 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.Databases.{PostgresDatabases, SqliteDatabases} +import fr.acinq.eclair.db.DualDatabases +import fr.acinq.eclair.db.jdbc.JdbcUtils +import fr.acinq.eclair.db.pg.PgUtils +import grizzled.slf4j.Logging + +import java.sql.{Connection, PreparedStatement, ResultSet} + +object MigrateDb extends Logging { + + private def getVersion(conn: Connection, + dbName: String): Int = { + val statement = conn.prepareStatement(s"SELECT version FROM versions WHERE db_name='$dbName'") + val res = statement.executeQuery() + res.next() + res.getInt("version") + } + + def checkVersions(source: Connection, + destination: Connection, + dbName: String, + expectedSourceVersion: Int, + expectedDestinationVersion: Int): Unit = { + val actualSourceVersion = getVersion(source, dbName) + val actualDestinationVersion = getVersion(destination, dbName) + require(actualSourceVersion == expectedSourceVersion, s"unexpected version for source db=$dbName expected=$expectedSourceVersion actual=$actualSourceVersion") + require(actualDestinationVersion == expectedDestinationVersion, s"unexpected version for destination db=$dbName expected=$expectedDestinationVersion actual=$actualDestinationVersion") + } + + def migrateTable(source: Connection, + destination: Connection, + sourceTable: String, + insertSql: String, + migrate: (ResultSet, PreparedStatement) => Unit): Int = + JdbcUtils.migrateTable(source, destination, sourceTable, insertSql, migrate)(logger) + + def migrateAll(dualDatabases: DualDatabases): Unit = { + logger.info("migrating all tables...") + val (sqliteDb: SqliteDatabases, postgresDb: PostgresDatabases) = DualDatabases.getDatabases(dualDatabases) + PgUtils.inTransaction { postgres => + MigrateChannelsDb.migrateAllTables(sqliteDb.channels.sqlite, postgres) + MigratePendingCommandsDb.migrateAllTables(sqliteDb.pendingCommands.sqlite, postgres) + MigratePeersDb.migrateAllTables(sqliteDb.peers.sqlite, postgres) + MigratePaymentsDb.migrateAllTables(sqliteDb.payments.sqlite, postgres) + MigrateNetworkDb.migrateAllTables(sqliteDb.network.sqlite, postgres) + MigrateAuditDb.migrateAllTables(sqliteDb.audit.sqlite, postgres) + logger.info("migration complete") + }(postgresDb.dataSource) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateNetworkDb.scala new file mode 100644 index 000000000..807e97f18 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigrateNetworkDb.scala @@ -0,0 +1,74 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} +import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec} +import scodec.bits.BitVector + +import java.sql.{Connection, PreparedStatement, ResultSet} + +object MigrateNetworkDb { + + private def migrateNodesTable(source: Connection, destination: Connection): Int = { + val sourceTable = "nodes" + val insertSql = "INSERT INTO network.nodes (node_id, data, json) VALUES (?, ?, ?::JSONB)" + + import fr.acinq.eclair.json.JsonSerializers._ + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector("node_id").toHex) + insertStatement.setBytes(2, rs.getBytes("data")) + val state = nodeAnnouncementCodec.decode(BitVector(rs.getBytes("data"))).require.value + val json = serialization.writePretty(state) + insertStatement.setString(3, json) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateChannelsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "channels" + val insertSql = "INSERT INTO network.public_channels (short_channel_id, txid, channel_announcement, capacity_sat, channel_update_1, channel_update_2, channel_announcement_json, channel_update_1_json, channel_update_2_json) VALUES (?, ?, ?, ?, ?, ?, ?::JSONB, ?::JSONB, ?::JSONB)" + + import fr.acinq.eclair.json.JsonSerializers._ + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setLong(1, rs.getLong("short_channel_id")) + insertStatement.setString(2, rs.getString("txid")) + insertStatement.setBytes(3, rs.getBytes("channel_announcement")) + insertStatement.setLong(4, rs.getLong("capacity_sat")) + insertStatement.setBytes(5, rs.getBytes("channel_update_1")) + insertStatement.setBytes(6, rs.getBytes("channel_update_2")) + val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value + val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) + val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) + val json = serialization.writePretty(ann) + val u1_json = channel_update_1_opt.map(serialization.writePretty(_)).orNull + val u2_json = channel_update_2_opt.map(serialization.writePretty(_)).orNull + insertStatement.setString(7, json) + insertStatement.setString(8, u1_json) + insertStatement.setString(9, u2_json) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migratePrunedTable(source: Connection, destination: Connection): Int = { + val sourceTable = "pruned" + val insertSql = "INSERT INTO network.pruned_channels (short_channel_id) VALUES (?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setLong(1, rs.getLong("short_channel_id")) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "network", 2, 4) + migrateNodesTable(source, destination) + migrateChannelsTable(source, destination) + migratePrunedTable(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePaymentsDb.scala new file mode 100644 index 000000000..ac2e885bf --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePaymentsDb.scala @@ -0,0 +1,60 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} + +import java.sql.{Connection, PreparedStatement, ResultSet, Timestamp} +import java.time.Instant + +object MigratePaymentsDb { + + private def migrateReceivedPaymentsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "received_payments" + val insertSql = "INSERT INTO payments.received (payment_hash, payment_type, payment_preimage, payment_request, received_msat, created_at, expire_at, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector("payment_hash").toHex) + insertStatement.setString(2, rs.getString("payment_type")) + insertStatement.setString(3, rs.getByteVector("payment_preimage").toHex) + insertStatement.setString(4, rs.getString("payment_request")) + insertStatement.setObject(5, rs.getLongNullable("received_msat").orNull) + insertStatement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(rs.getLong("created_at")))) + insertStatement.setTimestamp(7, Timestamp.from(Instant.ofEpochMilli(rs.getLong("expire_at")))) + insertStatement.setObject(8, rs.getLongNullable("received_at").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateSentPaymentsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "sent_payments" + val insertSql = "INSERT INTO payments.sent (id, parent_id, external_id, payment_hash, payment_preimage, payment_type, amount_msat, fees_msat, recipient_amount_msat, recipient_node_id, payment_request, payment_route, failures, created_at, completed_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getString("id")) + insertStatement.setString(2, rs.getString("parent_id")) + insertStatement.setString(3, rs.getStringNullable("external_id").orNull) + insertStatement.setString(4, rs.getByteVector("payment_hash").toHex) + insertStatement.setString(5, rs.getByteVector32Nullable("payment_preimage").map(_.toHex).orNull) + insertStatement.setString(6, rs.getString("payment_type")) + insertStatement.setLong(7, rs.getLong("amount_msat")) + insertStatement.setObject(8, rs.getLongNullable("fees_msat").orNull) + insertStatement.setLong(9, rs.getLong("recipient_amount_msat")) + insertStatement.setString(10, rs.getByteVector("recipient_node_id").toHex) + insertStatement.setString(11, rs.getStringNullable("payment_request").orNull) + insertStatement.setBytes(12, rs.getBytes("payment_route")) + insertStatement.setBytes(13, rs.getBytes("failures")) + insertStatement.setTimestamp(14, Timestamp.from(Instant.ofEpochMilli(rs.getLong("created_at")))) + insertStatement.setObject(15, rs.getLongNullable("completed_at").map(l => Timestamp.from(Instant.ofEpochMilli(l))).orNull) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "payments", 4, 6) + migrateReceivedPaymentsTable(source, destination) + migrateSentPaymentsTable(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePeersDb.scala new file mode 100644 index 000000000..77021f3a3 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePeersDb.scala @@ -0,0 +1,41 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} + +import java.sql.{Connection, PreparedStatement, ResultSet} + +object MigratePeersDb { + + private def migratePeersTable(source: Connection, destination: Connection): Int = { + val sourceTable = "peers" + val insertSql = "INSERT INTO local.peers (node_id, data) VALUES (?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector("node_id").toHex) + insertStatement.setBytes(2, rs.getBytes("data")) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + private def migrateRelayFeesTable(source: Connection, destination: Connection): Int = { + val sourceTable = "relay_fees" + val insertSql = "INSERT INTO local.relay_fees (node_id, fee_base_msat, fee_proportional_millionths) VALUES (?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector("node_id").toHex) + insertStatement.setLong(2, rs.getLong("fee_base_msat")) + insertStatement.setLong(3, rs.getLong("fee_proportional_millionths")) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "peers", 2, 3) + migratePeersTable(source, destination) + migrateRelayFeesTable(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePendingCommandsDb.scala new file mode 100644 index 000000000..6a0120876 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/migration/MigratePendingCommandsDb.scala @@ -0,0 +1,28 @@ +package fr.acinq.eclair.db.migration + +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ +import fr.acinq.eclair.db.migration.MigrateDb.{checkVersions, migrateTable} + +import java.sql.{Connection, PreparedStatement, ResultSet} + +object MigratePendingCommandsDb { + + private def migratePendingSettlementCommandsTable(source: Connection, destination: Connection): Int = { + val sourceTable = "pending_settlement_commands" + val insertSql = "INSERT INTO local.pending_settlement_commands (channel_id, htlc_id, data) VALUES (?, ?, ?)" + + def migrate(rs: ResultSet, insertStatement: PreparedStatement): Unit = { + insertStatement.setString(1, rs.getByteVector("channel_id").toHex) + insertStatement.setLong(2, rs.getLong("htlc_id")) + insertStatement.setBytes(3, rs.getBytes("data")) + } + + migrateTable(source, destination, sourceTable, insertSql, migrate) + } + + def migrateAllTables(source: Connection, destination: Connection): Unit = { + checkVersions(source, destination, "pending_relay", 2, 3) + migratePendingSettlementCommandsTable(source, destination) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 74c01a145..ed2f40324 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -37,7 +37,7 @@ object SqliteAuditDb { val CURRENT_VERSION = 8 } -class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { +class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { import SqliteUtils._ import ExtendedResultSet._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index a376a4295..3da17af57 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -34,7 +34,7 @@ object SqliteNetworkDb { val DB_NAME = "network" } -class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { +class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging { import SqliteNetworkDb._ import SqliteUtils.ExtendedResultSet._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 2dce58d6e..fbaef7851 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -33,7 +33,7 @@ import java.util.UUID import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} -class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { +class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { import SqlitePaymentsDb._ import SqliteUtils.ExtendedResultSet._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 7b4e0f2f2..04bfcb2be 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -35,7 +35,7 @@ object SqlitePeersDb { val CURRENT_VERSION = 2 } -class SqlitePeersDb(sqlite: Connection) extends PeersDb with Logging { +class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { import SqlitePeersDb._ import SqliteUtils.ExtendedResultSet._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala index 82ae1111f..46b6ea80f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala @@ -31,7 +31,7 @@ object SqlitePendingCommandsDb { val CURRENT_VERSION = 2 } -class SqlitePendingCommandsDb(sqlite: Connection) extends PendingCommandsDb with Logging { +class SqlitePendingCommandsDb(val sqlite: Connection) extends PendingCommandsDb with Logging { import SqlitePendingCommandsDb._ import SqliteUtils.ExtendedResultSet._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/DbMigrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/DbMigrationSpec.scala new file mode 100644 index 000000000..5d1a4fef8 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/DbMigrationSpec.scala @@ -0,0 +1,117 @@ +package fr.acinq.eclair.db + +import akka.actor.ActorSystem +import com.opentable.db.postgres.embedded.EmbeddedPostgres +import com.zaxxer.hikari.HikariConfig +import fr.acinq.eclair.db.Databases.{PostgresDatabases, SqliteDatabases} +import fr.acinq.eclair.db.migration._ +import fr.acinq.eclair.db.pg.PgUtils.PgLock +import fr.acinq.eclair.db.pg._ +import org.scalatest.Ignore +import org.scalatest.funsuite.AnyFunSuite +import org.sqlite.SQLiteConfig + +import java.io.File +import java.sql.{Connection, DriverManager} +import java.util.UUID +import javax.sql.DataSource + +/** + * To run this test, create a `migration` directory in your project's `user.dir` + * and copy your sqlite files to it (eclair.sqlite, network.sqlite, audit.sqlite). + * Then remove the `Ignore` annotation and run the test. + */ +@Ignore +class DbMigrationSpec extends AnyFunSuite { + + import DbMigrationSpec._ + + test("eclair migration test") { + val sqlite = loadSqlite("migration\\eclair.sqlite") + val postgresDatasource = EmbeddedPostgres.start().getPostgresDatabase + + new PgChannelsDb()(postgresDatasource, PgLock.NoLock) + new PgPendingCommandsDb()(postgresDatasource, PgLock.NoLock) + new PgPeersDb()(postgresDatasource, PgLock.NoLock) + new PgPaymentsDb()(postgresDatasource, PgLock.NoLock) + + PgUtils.inTransaction { postgres => + MigrateChannelsDb.migrateAllTables(sqlite, postgres) + MigratePendingCommandsDb.migrateAllTables(sqlite, postgres) + MigratePeersDb.migrateAllTables(sqlite, postgres) + MigratePaymentsDb.migrateAllTables(sqlite, postgres) + assert(CompareChannelsDb.compareAllTables(sqlite, postgres)) + assert(ComparePendingCommandsDb.compareAllTables(sqlite, postgres)) + assert(ComparePeersDb.compareAllTables(sqlite, postgres)) + assert(ComparePaymentsDb.compareAllTables(sqlite, postgres)) + }(postgresDatasource) + + sqlite.close() + } + + test("network migration test") { + val sqlite = loadSqlite("migration\\network.sqlite") + val postgresDatasource = EmbeddedPostgres.start().getPostgresDatabase + + new PgNetworkDb()(postgresDatasource) + + PgUtils.inTransaction { postgres => + MigrateNetworkDb.migrateAllTables(sqlite, postgres) + assert(CompareNetworkDb.compareAllTables(sqlite, postgres)) + }(postgresDatasource) + + sqlite.close() + } + + test("audit migration test") { + val sqlite = loadSqlite("migration\\audit.sqlite") + val postgresDatasource = EmbeddedPostgres.start().getPostgresDatabase + + new PgAuditDb()(postgresDatasource) + + PgUtils.inTransaction { postgres => + MigrateAuditDb.migrateAllTables(sqlite, postgres) + assert(CompareAuditDb.compareAllTables(sqlite, postgres)) + }(postgresDatasource) + + sqlite.close() + } + + test("full migration") { + // we need to open in read/write because of the getVersion call + val sqlite = SqliteDatabases( + auditJdbc = loadSqlite("migration\\audit.sqlite", readOnly = false), + eclairJdbc = loadSqlite("migration\\eclair.sqlite", readOnly = false), + networkJdbc = loadSqlite("migration\\network.sqlite", readOnly = false), + jdbcUrlFile_opt = None + ) + val postgres = { + val pg = EmbeddedPostgres.start() + val datasource: DataSource = pg.getPostgresDatabase + val hikariConfig = new HikariConfig + hikariConfig.setDataSource(datasource) + PostgresDatabases( + hikariConfig = hikariConfig, + instanceId = UUID.randomUUID(), + lock = PgLock.NoLock, + jdbcUrlFile_opt = None, + readOnlyUser_opt = None, + resetJsonColumns = false, + safetyChecks_opt = None + )(ActorSystem()) + } + val dualDb = DualDatabases(sqlite, postgres) + MigrateDb.migrateAll(dualDb) + CompareDb.compareAll(dualDb) + } + +} + +object DbMigrationSpec { + def loadSqlite(path: String, readOnly: Boolean = true): Connection = { + val sqliteConfig = new SQLiteConfig() + sqliteConfig.setReadOnly(readOnly) + val dbFile = new File(path) + DriverManager.getConnection(s"jdbc:sqlite:$dbFile", sqliteConfig.toProperties) + } +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/DualDatabasesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/DualDatabasesSpec.scala new file mode 100644 index 000000000..a9ae922f2 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/DualDatabasesSpec.scala @@ -0,0 +1,86 @@ +package fr.acinq.eclair.db + +import com.opentable.db.postgres.embedded.EmbeddedPostgres +import com.typesafe.config.{Config, ConfigFactory} +import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec +import fr.acinq.eclair.{TestKitBaseClass, TestUtils} +import org.scalatest.funsuite.AnyFunSuiteLike + +import java.io.File +import java.util.UUID + +class DualDatabasesSpec extends TestKitBaseClass with AnyFunSuiteLike { + + def fixture(driver: String): DualDatabases = { + val pg = EmbeddedPostgres.start() + val config = DualDatabasesSpec.testConfig(pg.getPort, driver) + val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}") + datadir.mkdirs() + val instanceId = UUID.randomUUID() + Databases.init(config, instanceId, datadir).asInstanceOf[DualDatabases] + } + + test("sqlite primary") { + val db = fixture("dual-sqlite-primary") + + db.channels.addOrUpdateChannel(ChannelCodecsSpec.normal) + assert(db.primary.channels.listLocalChannels().nonEmpty) + awaitCond(db.primary.channels.listLocalChannels() === db.secondary.channels.listLocalChannels()) + } + + test("postgres primary") { + val db = fixture("dual-postgres-primary") + + db.channels.addOrUpdateChannel(ChannelCodecsSpec.normal) + assert(db.primary.channels.listLocalChannels().nonEmpty) + awaitCond(db.primary.channels.listLocalChannels() === db.secondary.channels.listLocalChannels()) + } +} + +object DualDatabasesSpec { + def testConfig(port: Int, driver: String): Config = + ConfigFactory.parseString( + s""" + |driver = $driver + |postgres { + | database = "" + | host = "localhost" + | port = $port + | username = "postgres" + | password = "" + | readonly-user = "" + | reset-json-columns = false + | pool { + | max-size = 10 // recommended value = number_of_cpu_cores * 2 + | connection-timeout = 30 seconds + | idle-timeout = 10 minutes + | max-life-time = 30 minutes + | } + | lock-type = "lease" // lease or none (do not use none in production) + | lease { + | interval = 5 seconds // lease-interval must be greater than lease-renew-interval + | renew-interval = 2 seconds + | lock-timeout = 5 seconds // timeout for the lock statement on the lease table + | auto-release-at-shutdown = false // automatically release the lock when eclair is stopping + | } + | safety-checks { + | // a set of basic checks on data to make sure we use the correct database + | enabled = false + | max-age { + | local-channels = 3 minutes + | network-nodes = 30 minutes + | audit-relayed = 10 minutes + | } + | min-count { + | local-channels = 10 + | network-nodes = 3000 + | network-channels = 20000 + | } + | } + |} + |dual { + | migrate-on-restart = false // migrate sqlite -> postgres on restart (only applies if sqlite is primary) + | compare-on-restart = false // compare sqlite and postgres dbs on restart + |} + |""".stripMargin) +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala index 3a65c0d8f..a11283f4a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala @@ -32,12 +32,12 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually datadir.mkdirs() val instanceId1 = UUID.randomUUID() // this will lock the database for this instance id - val db1 = Databases.postgres(config, instanceId1, datadir, LockFailureHandler.logAndThrow) + val db1 = Databases.postgres(config, instanceId1, datadir, None, LockFailureHandler.logAndThrow) assert( intercept[LockFailureHandler.LockException] { // this will fail because the database is already locked for a different instance id - Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) }.lockFailure === LockFailure.AlreadyLocked(instanceId1)) // we can renew the lease at will @@ -48,7 +48,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually assert( intercept[LockFailureHandler.LockException] { // this will fail because the database is already locked for a different instance id - Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) }.lockFailure === LockFailure.AlreadyLocked(instanceId1)) // we close the first connection @@ -59,7 +59,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually // now we can put a lock with a different instance id val instanceId2 = UUID.randomUUID() - val db2 = Databases.postgres(config, instanceId2, datadir, LockFailureHandler.logAndThrow) + val db2 = Databases.postgres(config, instanceId2, datadir, None, LockFailureHandler.logAndThrow) // we close the second connection db2.dataSource.close() @@ -68,7 +68,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually // but we don't wait for the previous lease to expire, so we can't take over right now assert(intercept[LockFailureHandler.LockException] { // this will fail because even if we have acquired the table lock, the previous lease still hasn't expired - Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) }.lockFailure === LockFailure.AlreadyLocked(instanceId2)) pg.close() @@ -81,7 +81,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually datadir.mkdirs() val instanceId1 = UUID.randomUUID() // this will lock the database for this instance id - val db = Databases.postgres(config, instanceId1, datadir, LockFailureHandler.logAndThrow) + val db = Databases.postgres(config, instanceId1, datadir, None, LockFailureHandler.logAndThrow) implicit val ds: DataSource = db.dataSource // dummy query works @@ -133,8 +133,9 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually val config = PgUtilsSpec.testConfig(pg.getPort) val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}") datadir.mkdirs() + val jdbcUrlPath = new File(datadir, "last_jdbcurl") // this will lock the database for this instance id - val db = Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + val db = Databases.postgres(config, UUID.randomUUID(), datadir, Some(jdbcUrlPath), LockFailureHandler.logAndThrow) // we close the first connection db.dataSource.close() @@ -143,7 +144,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually // here we change the config to simulate an involuntary change in the server we connect to val config1 = ConfigFactory.parseString("postgres.port=1234").withFallback(config) intercept[JdbcUrlChanged] { - Databases.postgres(config1, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config1, UUID.randomUUID(), datadir, Some(jdbcUrlPath), LockFailureHandler.logAndThrow) } pg.close() @@ -156,7 +157,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually .withFallback(PgUtilsSpec.testConfig(pg.getPort)) val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}") datadir.mkdirs() - Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) } test("safety checks") { @@ -166,7 +167,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually datadir.mkdirs() { - val db = Databases.postgres(baseConfig, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + val db = Databases.postgres(baseConfig, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) db.channels.addOrUpdateChannel(ChannelCodecsSpec.normal) db.channels.updateChannelMeta(ChannelCodecsSpec.normal.channelId, ChannelEvent.EventType.Created) db.network.addNode(Announcements.makeNodeAnnouncement(randomKey(), "node-A", Color(50, 99, -80), Nil, Features.empty, TimestampSecond.now() - 45.days)) @@ -196,7 +197,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually | } |}""".stripMargin) val config = safetyConfig.withFallback(baseConfig) - val db = Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + val db = Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) db.dataSource.close() } @@ -221,7 +222,7 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually |}""".stripMargin) val config = safetyConfig.withFallback(baseConfig) intercept[IllegalArgumentException] { - Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + Databases.postgres(config, UUID.randomUUID(), datadir, None, LockFailureHandler.logAndThrow) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala index feead2ce8..93ed45f2e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala @@ -81,17 +81,17 @@ class SqliteUtilsSpec extends AnyFunSuite { test("jdbc url check") { val datadir = new File(TestUtils.BUILD_DIRECTORY, s"sqlite_test_${UUID.randomUUID()}") - val jdbcUrlPath = new File(datadir, "last_jdbcurl") datadir.mkdirs() + val jdbcUrlPath = new File(datadir, "last_jdbcurl") // first start : write to file - val db1 = Databases.sqlite(datadir) + val db1 = Databases.sqlite(datadir, Some(jdbcUrlPath)) db1.channels.close() assert(Files.readString(jdbcUrlPath.toPath).trim == "sqlite") // 2nd start : no-op - val db2 = Databases.sqlite(datadir) + val db2 = Databases.sqlite(datadir, Some(jdbcUrlPath)) db2.channels.close() // we modify the file @@ -99,7 +99,7 @@ class SqliteUtilsSpec extends AnyFunSuite { // boom intercept[JdbcUrlChanged] { - Databases.sqlite(datadir) + Databases.sqlite(datadir, Some(jdbcUrlPath)) } }