1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 18:10:42 +01:00

Use proper data type for timestamps in Postgres (#1778)

Did some refactoring in tests and introduced a new `migrationCheck`
helper method.

Note that the change of data type in sqlite for the `commitment_number`
field (from `BLOB` to `INTEGER`) is not a migration. If the table has
been created before, it will stay like it was. It doesn't matter due to
how sqlite stores data, and we make sure in tests that there is no
regression.
This commit is contained in:
Pierre-Marie Padiou 2021-04-22 10:16:40 +02:00 committed by GitHub
parent 4a1dfd2a27
commit e14c40d7c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 533 additions and 394 deletions

View File

@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import java.sql.{Connection, ResultSet, Statement}
import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue
@ -123,18 +123,16 @@ trait JdbcUtils {
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
val s = rs.getString(columnLabel)
if (rs.wasNull()) None else {
Some(ByteVector32(ByteVector.fromValidHex(s)))
}
if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s)))
}
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
def getByteVectorNullable(columnLabel: String): ByteVector = {
def getByteVectorNullable(columnLabel: String): Option[ByteVector] = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
if (rs.wasNull()) None else Some(ByteVector(result))
}
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
@ -164,6 +162,11 @@ trait JdbcUtils {
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}
def getTimestampNullable(label: String): Option[Timestamp] = {
val result = rs.getTimestamp(label)
if (rs.wasNull()) None else Some(result)
}
}
object ExtendedResultSet {

View File

@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong}
import grizzled.slf4j.Logging
import java.sql.Statement
import java.sql.{Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue
@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
import ExtendedResultSet._
val DB_NAME = "audit"
val CURRENT_VERSION = 5
val CURRENT_VERSION = 6
case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)")
}
def migration56(statement: Statement): Unit = {
statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
}
getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)")
@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
case Some(v@4) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration45(statement)
migration56(statement)
case Some(v@5) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration56(statement)
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setBoolean(4, e.isFunder)
statement.setBoolean(5, e.isPrivate)
statement.setString(6, e.event.label)
statement.setLong(7, System.currentTimeMillis)
statement.setTimestamp(7, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(7, e.paymentPreimage.toHex)
statement.setString(8, e.recipientNodeId.value.toHex)
statement.setString(9, p.toChannelId.toHex)
statement.setLong(10, p.timestamp)
statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
statement.addBatch()
})
statement.executeBatch()
@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setLong(1, p.amount.toLong)
statement.setString(2, e.paymentHash.toHex)
statement.setString(3, p.fromChannelId.toHex)
statement.setLong(4, p.timestamp)
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
statement.addBatch()
})
statement.executeBatch()
@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(1, e.paymentHash.toHex)
statement.setLong(2, nextTrampolineAmount.toLong)
statement.setString(3, nextTrampolineNodeId.value.toHex)
statement.setLong(4, e.timestamp)
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
statement.executeUpdate()
}
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, p.channelId.toHex)
statement.setString(4, p.direction)
statement.setString(5, p.relayType)
statement.setLong(6, e.timestamp)
statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
statement.executeUpdate()
}
}
@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, e.tx.txid.toHex)
statement.setLong(4, e.fee.toLong)
statement.setString(5, e.txType)
statement.setLong(6, System.currentTimeMillis)
statement.setTimestamp(6, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
@ -189,7 +204,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, errorName)
statement.setString(4, errorMessage)
statement.setBoolean(5, e.isFatal)
statement.setLong(6, System.currentTimeMillis)
statement.setTimestamp(6, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
@ -197,9 +212,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
override def listSent(from: Long, to: Long): Seq[PaymentSent] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var sentByParentId = Map.empty[UUID, PaymentSent]
while (rs.next()) {
@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
@ -229,9 +244,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
while (rs.next()) {
@ -239,7 +254,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
inTransaction { pg =>
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
trampolineByHash += (paymentHash -> (amount, nodeId))
}
}
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
while (rs.next()) {
@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
relayedByHash.flatMap {
@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var q: Queue[NetworkFee] = Queue()
while (rs.next()) {
@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
txId = rs.getByteVector32FromHex("tx_id"),
fee = Satoshi(rs.getLong("fee_sat")),
txType = rs.getString("tx_type"),
timestamp = rs.getLong("timestamp"))
timestamp = rs.getTimestamp("timestamp").getTime)
}
q
}

View File

@ -27,7 +27,8 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
import grizzled.slf4j.Logging
import java.sql.Statement
import java.sql.{Statement, Timestamp}
import java.time.Instant
import javax.sql.DataSource
import scala.collection.immutable.Queue
@ -38,7 +39,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
import lock._
val DB_NAME = "channels"
val CURRENT_VERSION = 3
val CURRENT_VERSION = 4
inTransaction { pg =>
using(pg.createStatement()) { statement =>
@ -51,14 +52,28 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN closed_timestamp BIGINT")
}
def migration34(statement: Statement): Unit = {
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_received_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_received_timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT")
}
getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)")
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)")
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
case Some(v@2) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration23(statement)
migration34(statement)
case Some(v@3) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration34(statement)
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
@ -89,7 +104,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
private def updateChannelMetaTimestampColumn(channelId: ByteVector32, columnName: String): Unit = {
inTransaction { pg =>
using(pg.prepareStatement(s"UPDATE local_channels SET $columnName=? WHERE channel_id=?")) { statement =>
statement.setLong(1, System.currentTimeMillis)
statement.setTimestamp(1, Timestamp.from(Instant.now()))
statement.setString(2, channelId.toHex)
statement.executeUpdate()
}
@ -152,7 +167,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
withLock { pg =>
using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.setString(2, commitmentNumber.toString)
statement.setLong(2, commitmentNumber)
val rs = statement.executeQuery
var q: Queue[(ByteVector32, CltvExpiry)] = Queue()
while (rs.next()) {

View File

@ -23,7 +23,6 @@ import fr.acinq.eclair.db.ChannelsDb
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, PaymentEvent, PaymentReceived, PaymentRelayed, PaymentSent}
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
import grizzled.slf4j.Logging
@ -64,7 +63,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)")
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
case Some(v@1) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")

View File

@ -4,8 +4,8 @@ import akka.actor.ActorSystem
import com.opentable.db.postgres.embedded.EmbeddedPostgres
import com.zaxxer.hikari.HikariConfig
import fr.acinq.eclair.db._
import fr.acinq.eclair.db.pg.PgUtils.PgLock
import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler
import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using}
import org.postgresql.jdbc.PgConnection
import org.sqlite.SQLiteConnection
@ -64,6 +64,7 @@ object TestDatabases {
// @formatter:off
override val connection: PgConnection = pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection]
// NB: we use a lazy val here: databases won't be initialized until we reference that variable
override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile), readOnlyUser_opt = None)
override def close(): Unit = pg.close()
// @formatter:on
@ -77,4 +78,23 @@ object TestDatabases {
// @formatter:on
}
def migrationCheck(dbs: TestDatabases,
initializeTables: Connection => Unit,
dbName: String,
targetVersion: Int,
postCheck: Connection => Unit
): Unit = {
val connection = dbs.connection
// initialize the database to a previous version and populate data
initializeTables(connection)
// this will trigger the initialization of tables and the migration
val _ = dbs.db
// check that db version was updated
using(connection.createStatement()) { statement =>
assert(getVersion(statement, dbName).contains(targetVersion))
}
// post-migration checks
postCheck(connection)
}
}

View File

@ -18,7 +18,7 @@ package fr.acinq.eclair.db
import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{ByteVector32, SatoshiLong, Transaction}
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases}
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck}
import fr.acinq.eclair._
import fr.acinq.eclair.channel.Helpers.Closing.MutualClose
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
@ -26,7 +26,7 @@ import fr.acinq.eclair.db.AuditDb.Stats
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
import fr.acinq.eclair.db.pg.PgAuditDb
import fr.acinq.eclair.db.pg.PgUtils.{inTransaction, setVersion}
import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion}
import fr.acinq.eclair.db.sqlite.SqliteAuditDb
import fr.acinq.eclair.payment._
import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
@ -34,8 +34,9 @@ import fr.acinq.eclair.wire.protocol.Error
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite
import java.sql.Timestamp
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.concurrent.duration._
import scala.util.Random
@ -182,13 +183,20 @@ class AuditDbSpec extends AnyFunSuite {
}
}
test("handle migration version 1 -> 5") {
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
test("migrate sqlite audit database v1 -> v5") {
val dbs = TestSqliteDatabases()
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
@ -208,17 +216,6 @@ class AuditDbSpec extends AnyFunSuite {
setVersion(statement, "audit", 1)
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(1))
}
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
// add a row (no ID on sent)
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, ps.recipientAmount.toLong)
@ -229,15 +226,12 @@ class AuditDbSpec extends AnyFunSuite {
statement.setLong(6, ps.timestamp)
statement.executeUpdate()
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
},
dbName = "audit",
targetVersion = 5,
postCheck = connection => {
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
assert(dbs.audit.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
val postMigrationDb = new SqliteAuditDb(connection)
@ -252,16 +246,19 @@ class AuditDbSpec extends AnyFunSuite {
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
}
}
)
}
test("handle migration version 2 -> 5") {
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
test("migrate sqlite audit database v2 -> v5") {
val dbs = TestSqliteDatabases()
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
@ -280,39 +277,39 @@ class AuditDbSpec extends AnyFunSuite {
setVersion(statement, "audit", 2)
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(2))
}
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
val migratedDb = new SqliteAuditDb(connection)
},
dbName = "audit",
targetVersion = 5,
postCheck = connection => {
val migratedDb = dbs.audit
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
migratedDb.add(e1)
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
postMigrationDb.add(e2)
}
}
)
}
test("handle migration version 3 -> 5") {
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
test("migrate sqlite audit database v3 -> v5") {
val dbs = TestSqliteDatabases()
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
@ -334,14 +331,6 @@ class AuditDbSpec extends AnyFunSuite {
setVersion(statement, "audit", 3)
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(3))
}
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
for (pp <- Seq(pp1, pp2)) {
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, pp.amount.toLong)
@ -355,9 +344,6 @@ class AuditDbSpec extends AnyFunSuite {
}
}
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
for (relayed <- Seq(relayed1, relayed2)) {
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, relayed.amountIn.toLong)
@ -369,12 +355,14 @@ class AuditDbSpec extends AnyFunSuite {
statement.executeUpdate()
}
}
val migratedDb = new SqliteAuditDb(connection)
},
dbName = "audit",
targetVersion = 5,
postCheck = connection => {
val migratedDb = dbs.audit
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
assert(migratedDb.listSent(50, 150).toSet === Set(
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
@ -382,214 +370,194 @@ class AuditDbSpec extends AnyFunSuite {
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
))
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
postMigrationDb.add(ps2)
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
}
}
)
}
test("handle migration version 4 -> 5") {
test("migrate audit database v4 -> v5/v6") {
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
forAllDbs {
case dbs: TestPgDatabases =>
import fr.acinq.eclair.db.pg.PgUtils.getVersion
implicit val datasource: DataSource = dbs.datasource
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
// simulate existing previous version db
inTransaction { pg =>
using(pg.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
setVersion(statement, "audit", 4)
}
setVersion(statement, "audit", 4)
}
}
inTransaction { pg =>
using(pg.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(4))
}
}
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed1.paymentHash.toHex)
statement.setLong(2, relayed1.amountIn.toLong)
statement.setString(3, relayed1.fromChannelId.toHex)
statement.setString(4, "IN")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed1.paymentHash.toHex)
statement.setLong(2, relayed1.amountOut.toLong)
statement.setString(3, relayed1.toChannelId.toHex)
statement.setString(4, "OUT")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
for (incoming <- relayed2.incoming) {
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed2.paymentHash.toHex)
statement.setLong(2, incoming.amount.toLong)
statement.setString(3, incoming.channelId.toHex)
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed1.paymentHash.toHex)
statement.setLong(2, relayed1.amountIn.toLong)
statement.setString(3, relayed1.fromChannelId.toHex)
statement.setString(4, "IN")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
}
for (outgoing <- relayed2.outgoing) {
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed2.paymentHash.toHex)
statement.setLong(2, outgoing.amount.toLong)
statement.setString(3, outgoing.channelId.toHex)
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed1.paymentHash.toHex)
statement.setLong(2, relayed1.amountOut.toLong)
statement.setString(3, relayed1.toChannelId.toHex)
statement.setString(4, "OUT")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
for (incoming <- relayed2.incoming) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed2.paymentHash.toHex)
statement.setLong(2, incoming.amount.toLong)
statement.setString(3, incoming.channelId.toHex)
statement.setString(4, "IN")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
}
}
for (outgoing <- relayed2.outgoing) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, relayed2.paymentHash.toHex)
statement.setLong(2, outgoing.amount.toLong)
statement.setString(3, outgoing.channelId.toHex)
statement.setString(4, "OUT")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
}
}
},
dbName = "audit",
targetVersion = 6,
postCheck = connection => {
val migratedDb = dbs.audit
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(6))
}
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new PgAuditDb()(dbs.datasource)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(6))
}
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
}
}
val migratedDb = new PgAuditDb()(datasource)
inTransaction { pg =>
using(pg.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
}
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new PgAuditDb()(datasource)
inTransaction { pg =>
using(pg.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
}
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
)
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
// simulate existing previous version db
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
setVersion(statement, "audit", 4)
}
setVersion(statement, "audit", 4)
}
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed1.paymentHash.toArray)
statement.setLong(2, relayed1.amountIn.toLong)
statement.setBytes(3, relayed1.fromChannelId.toArray)
statement.setString(4, "IN")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed1.paymentHash.toArray)
statement.setLong(2, relayed1.amountOut.toLong)
statement.setBytes(3, relayed1.toChannelId.toArray)
statement.setString(4, "OUT")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
for (incoming <- relayed2.incoming) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed2.paymentHash.toArray)
statement.setLong(2, incoming.amount.toLong)
statement.setBytes(3, incoming.channelId.toArray)
statement.setString(4, "IN")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
}
}
for (outgoing <- relayed2.outgoing) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed2.paymentHash.toArray)
statement.setLong(2, outgoing.amount.toLong)
statement.setBytes(3, outgoing.channelId.toArray)
statement.setString(4, "OUT")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
}
}
},
dbName = "audit",
targetVersion = 5,
postCheck = connection => {
val migratedDb = dbs.audit
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(4))
}
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed1.paymentHash.toArray)
statement.setLong(2, relayed1.amountIn.toLong)
statement.setBytes(3, relayed1.fromChannelId.toArray)
statement.setString(4, "IN")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed1.paymentHash.toArray)
statement.setLong(2, relayed1.amountOut.toLong)
statement.setBytes(3, relayed1.toChannelId.toArray)
statement.setString(4, "OUT")
statement.setString(5, "channel")
statement.setLong(6, relayed1.timestamp)
statement.executeUpdate()
}
for (incoming <- relayed2.incoming) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed2.paymentHash.toArray)
statement.setLong(2, incoming.amount.toLong)
statement.setBytes(3, incoming.channelId.toArray)
statement.setString(4, "IN")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
}
}
for (outgoing <- relayed2.outgoing) {
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, relayed2.paymentHash.toArray)
statement.setLong(2, outgoing.amount.toLong)
statement.setBytes(3, outgoing.channelId.toArray)
statement.setString(4, "OUT")
statement.setString(5, "trampoline")
statement.setLong(6, relayed2.timestamp)
statement.executeUpdate()
}
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit").contains(5))
}
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
)
}
}
@ -605,7 +573,7 @@ class AuditDbSpec extends AnyFunSuite {
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "IN")
statement.setString(5, "unknown") // invalid relay type
statement.setLong(6, 10)
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(10))) else statement.setLong(6, 10)
statement.executeUpdate()
}
@ -615,7 +583,7 @@ class AuditDbSpec extends AnyFunSuite {
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "UP") // invalid direction
statement.setString(5, "channel")
statement.setLong(6, 20)
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(20))) else statement.setLong(6, 20)
statement.executeUpdate()
}
@ -628,7 +596,7 @@ class AuditDbSpec extends AnyFunSuite {
if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
statement.setString(4, "IN") // missing a corresponding OUT
statement.setString(5, "channel")
statement.setLong(6, 30)
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(30))) else statement.setLong(6, 30)
statement.executeUpdate()
}

View File

@ -18,23 +18,25 @@ package fr.acinq.eclair.db
import com.softwaremill.quicklens._
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases}
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck}
import fr.acinq.eclair.db.ChannelsDbSpec.{getPgTimestamp, getTimestamp, testCases}
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
import fr.acinq.eclair.db.pg.PgChannelsDb
import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion}
import fr.acinq.eclair.db.pg.{PgChannelsDb, PgUtils}
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
import fr.acinq.eclair.{CltvExpiry, ShortChannelId, randomBytes32}
import fr.acinq.eclair.{CltvExpiry, ShortChannelId, TestDatabases, randomBytes32}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits.ByteVector
import java.sql.SQLException
import java.sql.{Connection, SQLException}
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
import scala.util.Random
class ChannelsDbSpec extends AnyFunSuite {
@ -107,56 +109,42 @@ class ChannelsDbSpec extends AnyFunSuite {
test("channel metadata") {
forAllDbs { dbs =>
val db = dbs.channels
val connection = dbs.connection
val channel1 = ChannelCodecsSpec.normal
val channel2 = channel1.modify(_.commitments.channelId).setTo(randomBytes32)
def getTimestamp(channelId: ByteVector32, columnName: String): Option[Long] = {
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
// data type differs depending on underlying database system
dbs match {
case _: TestPgDatabases => statement.setString(1, channelId.toHex)
case _: TestSqliteDatabases => statement.setBytes(1, channelId.toArray)
}
val rs = statement.executeQuery()
rs.next()
rs.getLongNullable(columnName)
}
}
// first we add channels
db.addOrUpdateChannel(channel1)
db.addOrUpdateChannel(channel2)
// make sure initially all metadata are empty
assert(getTimestamp(channel1.channelId, "created_timestamp").isEmpty)
assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").isEmpty)
assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").isEmpty)
assert(getTimestamp(channel1.channelId, "last_connected_timestamp").isEmpty)
assert(getTimestamp(channel1.channelId, "closed_timestamp").isEmpty)
assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").isEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").isEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").isEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").isEmpty)
assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").isEmpty)
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Created)
assert(getTimestamp(channel1.channelId, "created_timestamp").nonEmpty)
assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").nonEmpty)
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentSent)
assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").nonEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").nonEmpty)
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentReceived)
assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").nonEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").nonEmpty)
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Connected)
assert(getTimestamp(channel1.channelId, "last_connected_timestamp").nonEmpty)
assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").nonEmpty)
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Closed(null))
assert(getTimestamp(channel1.channelId, "closed_timestamp").nonEmpty)
assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").nonEmpty)
// make sure all metadata are still empty for channel 2
assert(getTimestamp(channel2.channelId, "created_timestamp").isEmpty)
assert(getTimestamp(channel2.channelId, "last_payment_sent_timestamp").isEmpty)
assert(getTimestamp(channel2.channelId, "last_payment_received_timestamp").isEmpty)
assert(getTimestamp(channel2.channelId, "last_connected_timestamp").isEmpty)
assert(getTimestamp(channel2.channelId, "closed_timestamp").isEmpty)
assert(getTimestamp(dbs, channel2.channelId, "created_timestamp").isEmpty)
assert(getTimestamp(dbs, channel2.channelId, "last_payment_sent_timestamp").isEmpty)
assert(getTimestamp(dbs, channel2.channelId, "last_payment_received_timestamp").isEmpty)
assert(getTimestamp(dbs, channel2.channelId, "last_connected_timestamp").isEmpty)
assert(getTimestamp(dbs, channel2.channelId, "closed_timestamp").isEmpty)
}
}
@ -175,13 +163,22 @@ class ChannelsDbSpec extends AnyFunSuite {
setVersion(statement, "channels", 1)
}
// insert 1 row
val channel = ChannelCodecsSpec.normal
val data = stateDataCodec.encode(channel).require.toByteArray
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
statement.setBytes(1, channel.channelId.toArray)
statement.setBytes(2, data)
statement.executeUpdate()
// insert data
for (testCase <- testCases) {
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
statement.setBytes(1, testCase.channelId.toArray)
statement.setBytes(2, testCase.data.toArray)
statement.executeUpdate()
}
for (commitmentNumber <- testCase.commitmentNumbers) {
using(sqlite.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
statement.setBytes(1, testCase.channelId.toArray)
statement.setLong(2, commitmentNumber)
statement.setBytes(3, randomBytes32.toArray)
statement.setLong(4, 500000 + Random.nextInt(500000))
statement.executeUpdate()
}
}
}
// check that db migration works
@ -189,71 +186,193 @@ class ChannelsDbSpec extends AnyFunSuite {
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "channels").contains(3))
}
assert(db.listLocalChannels() === List(channel))
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
assert(db.listLocalChannels().size === testCases.size)
for (testCase <- testCases) {
db.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail
for (commitmentNumber <- testCase.commitmentNumbers) {
assert(db.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber))
}
}
}
}
test("migrate channel database v2 -> v3") {
test("migrate channel database v2 -> v3/v4") {
def postCheck(channelsDb: ChannelsDb): Unit = {
assert(channelsDb.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
for (testCase <- testCases.filterNot(_.isClosed)) {
channelsDb.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail
for (commitmentNumber <- testCase.commitmentNumbers) {
assert(channelsDb.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber))
}
}
}
forAllDbs {
case dbs: TestPgDatabases =>
val pg = dbs.connection
// create a v2 channels database
using(pg.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
setVersion(statement, "channels", 2)
}
// insert 1 row
val channel = ChannelCodecsSpec.normal
val data = stateDataCodec.encode(channel).require.toByteArray
using(pg.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
statement.setString(1, channel.channelId.toHex)
statement.setBytes(2, data)
statement.setBoolean(3, false)
statement.executeUpdate()
}
// check that db migration works
val db = dbs.channels
using(pg.createStatement()) { statement =>
assert(getVersion(statement, "channels").contains(3))
}
assert(db.listLocalChannels() === List(channel))
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// initialize a v2 database
using(connection.createStatement()) { statement =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
setVersion(statement, "channels", 2)
}
// insert data
testCases.foreach { testCase =>
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
statement.setString(1, testCase.channelId.toHex)
statement.setBytes(2, testCase.data.toArray)
statement.setBoolean(3, testCase.isClosed)
statement.executeUpdate()
for (commitmentNumber <- testCase.commitmentNumbers) {
using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
statement.setString(1, testCase.channelId.toHex)
statement.setLong(2, commitmentNumber)
statement.setString(3, randomBytes32.toHex)
statement.setLong(4, 500000 + Random.nextInt(500000))
statement.executeUpdate()
}
}
}
}
},
dbName = "channels",
targetVersion = 4,
postCheck = _ => postCheck(dbs.channels)
)
case dbs: TestSqliteDatabases =>
val sqlite = dbs.connection
migrationCheck(
dbs = dbs,
initializeTables = connection => {
// create a v2 channels database
using(connection.createStatement()) { statement =>
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
setVersion(statement, "channels", 2)
}
// insert data
testCases.foreach { testCase =>
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
statement.setBytes(1, testCase.channelId.toArray)
statement.setBytes(2, testCase.data.toArray)
statement.setBoolean(3, testCase.isClosed)
statement.executeUpdate()
for (commitmentNumber <- testCase.commitmentNumbers) {
using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
statement.setBytes(1, testCase.channelId.toArray)
statement.setLong(2, commitmentNumber)
statement.setBytes(3, randomBytes32.toArray)
statement.setLong(4, 500000 + Random.nextInt(500000))
statement.executeUpdate()
}
}
}
}
},
dbName = "channels",
targetVersion = 3,
postCheck = _ => postCheck(dbs.channels)
)
}
}
// create a v2 channels database
using(sqlite.createStatement()) { statement =>
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
setVersion(statement, "channels", 2)
}
test("migrate pg channel database v3->v4") {
val dbs = TestPgDatabases()
// insert 1 row
val channel = ChannelCodecsSpec.normal
val data = stateDataCodec.encode(channel).require.toByteArray
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?, ?)")) { statement =>
statement.setBytes(1, channel.channelId.toArray)
statement.setBytes(2, data)
statement.setBoolean(3, false)
statement.executeUpdate()
migrationCheck(
dbs = dbs,
initializeTables = connection => {
using(connection.createStatement()) { statement =>
// initialize a v3 database
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)")
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
PgUtils.setVersion(statement, "channels", 3)
}
// insert data
testCases.foreach { testCase =>
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed, created_timestamp, last_payment_sent_timestamp, last_payment_received_timestamp, last_connected_timestamp, closed_timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, testCase.channelId.toHex)
statement.setBytes(2, testCase.data.toArray)
statement.setBoolean(3, testCase.isClosed)
statement.setObject(4, testCase.createdTimestamp.orNull)
statement.setObject(5, testCase.lastPaymentSentTimestamp.orNull)
statement.setObject(6, testCase.lastPaymentReceivedTimestamp.orNull)
statement.setObject(7, testCase.lastConnectedTimestamp.orNull)
statement.setObject(8, testCase.closedTimestamp.orNull)
statement.executeUpdate()
}
}
},
dbName = "channels",
targetVersion = 4,
postCheck = connection => {
assert(dbs.channels.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
testCases.foreach { testCase =>
assert(getPgTimestamp(connection, testCase.channelId, "created_timestamp") === testCase.createdTimestamp)
assert(getPgTimestamp(connection, testCase.channelId, "last_payment_sent_timestamp") === testCase.lastPaymentSentTimestamp)
assert(getPgTimestamp(connection, testCase.channelId, "last_payment_received_timestamp") === testCase.lastPaymentReceivedTimestamp)
assert(getPgTimestamp(connection, testCase.channelId, "last_connected_timestamp") === testCase.lastConnectedTimestamp)
assert(getPgTimestamp(connection, testCase.channelId, "closed_timestamp") === testCase.closedTimestamp)
}
}
)
// check that db migration works
val db = dbs.channels
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "channels").contains(3))
}
assert(db.listLocalChannels() === List(channel))
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
}
}
object ChannelsDbSpec {
case class TestCase(channelId: ByteVector32,
data: ByteVector,
isClosed: Boolean,
createdTimestamp: Option[Long],
lastPaymentSentTimestamp: Option[Long],
lastPaymentReceivedTimestamp: Option[Long],
lastConnectedTimestamp: Option[Long],
closedTimestamp: Option[Long],
commitmentNumbers: Seq[Int]
)
private val data = stateDataCodec.encode(ChannelCodecsSpec.normal).require.bytes
val testCases: Seq[TestCase] = for (_ <- 0 until 10) yield TestCase(
channelId = randomBytes32,
data = data,
isClosed = Random.nextBoolean(),
createdTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
lastPaymentSentTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
lastPaymentReceivedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
lastConnectedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
closedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
commitmentNumbers = for (_ <- 0 until Random.nextInt(10)) yield Random.nextInt(5) // there will be repetitions, on purpose
)
def getTimestamp(dbs: TestDatabases, channelId: ByteVector32, columnName: String): Option[Long] = {
dbs match {
case _: TestPgDatabases => getPgTimestamp(dbs.connection, channelId, columnName)
case _: TestSqliteDatabases => getSqliteTimestamp(dbs.connection, channelId, columnName)
}
}
def getSqliteTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = {
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
statement.setBytes(1, channelId.toArray)
val rs = statement.executeQuery()
rs.next()
rs.getLongNullable(columnName)
}
}
def getPgTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = {
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
statement.setString(1, channelId.toHex)
val rs = statement.executeQuery()
rs.next()
rs.getTimestampNullable(columnName).map(_.getTime)
}
}
}