mirror of
https://github.com/ACINQ/eclair.git
synced 2024-11-19 18:10:42 +01:00
Use proper data type for timestamps in Postgres (#1778)
Did some refactoring in tests and introduced a new `migrationCheck` helper method. Note that the change of data type in sqlite for the `commitment_number` field (from `BLOB` to `INTEGER`) is not a migration. If the table has been created before, it will stay like it was. It doesn't matter due to how sqlite stores data, and we make sure in tests that there is no regression.
This commit is contained in:
parent
4a1dfd2a27
commit
e14c40d7c3
@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection
|
||||
import scodec.Codec
|
||||
import scodec.bits.{BitVector, ByteVector}
|
||||
|
||||
import java.sql.{Connection, ResultSet, Statement}
|
||||
import java.sql.{Connection, ResultSet, Statement, Timestamp}
|
||||
import java.util.UUID
|
||||
import javax.sql.DataSource
|
||||
import scala.collection.immutable.Queue
|
||||
@ -123,18 +123,16 @@ trait JdbcUtils {
|
||||
|
||||
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
|
||||
val s = rs.getString(columnLabel)
|
||||
if (rs.wasNull()) None else {
|
||||
Some(ByteVector32(ByteVector.fromValidHex(s)))
|
||||
}
|
||||
if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s)))
|
||||
}
|
||||
|
||||
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
|
||||
|
||||
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
|
||||
|
||||
def getByteVectorNullable(columnLabel: String): ByteVector = {
|
||||
def getByteVectorNullable(columnLabel: String): Option[ByteVector] = {
|
||||
val result = rs.getBytes(columnLabel)
|
||||
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
|
||||
if (rs.wasNull()) None else Some(ByteVector(result))
|
||||
}
|
||||
|
||||
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
|
||||
@ -164,6 +162,11 @@ trait JdbcUtils {
|
||||
if (rs.wasNull()) None else Some(MilliSatoshi(result))
|
||||
}
|
||||
|
||||
def getTimestampNullable(label: String): Option[Timestamp] = {
|
||||
val result = rs.getTimestamp(label)
|
||||
if (rs.wasNull()) None else Some(result)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object ExtendedResultSet {
|
||||
|
@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
|
||||
import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong}
|
||||
import grizzled.slf4j.Logging
|
||||
|
||||
import java.sql.Statement
|
||||
import java.sql.{Statement, Timestamp}
|
||||
import java.time.Instant
|
||||
import java.util.UUID
|
||||
import javax.sql.DataSource
|
||||
import scala.collection.immutable.Queue
|
||||
@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
import ExtendedResultSet._
|
||||
|
||||
val DB_NAME = "audit"
|
||||
val CURRENT_VERSION = 5
|
||||
val CURRENT_VERSION = 6
|
||||
|
||||
case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
|
||||
|
||||
@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)")
|
||||
}
|
||||
|
||||
def migration56(statement: Statement): Unit = {
|
||||
statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
|
||||
}
|
||||
|
||||
getVersion(statement, DB_NAME) match {
|
||||
case None =>
|
||||
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)")
|
||||
@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
case Some(v@4) =>
|
||||
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
|
||||
migration45(statement)
|
||||
migration56(statement)
|
||||
case Some(v@5) =>
|
||||
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
|
||||
migration56(statement)
|
||||
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
|
||||
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
|
||||
}
|
||||
@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setBoolean(4, e.isFunder)
|
||||
statement.setBoolean(5, e.isPrivate)
|
||||
statement.setString(6, e.event.label)
|
||||
statement.setLong(7, System.currentTimeMillis)
|
||||
statement.setTimestamp(7, Timestamp.from(Instant.now()))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setString(7, e.paymentPreimage.toHex)
|
||||
statement.setString(8, e.recipientNodeId.value.toHex)
|
||||
statement.setString(9, p.toChannelId.toHex)
|
||||
statement.setLong(10, p.timestamp)
|
||||
statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
|
||||
statement.addBatch()
|
||||
})
|
||||
statement.executeBatch()
|
||||
@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setLong(1, p.amount.toLong)
|
||||
statement.setString(2, e.paymentHash.toHex)
|
||||
statement.setString(3, p.fromChannelId.toHex)
|
||||
statement.setLong(4, p.timestamp)
|
||||
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
|
||||
statement.addBatch()
|
||||
})
|
||||
statement.executeBatch()
|
||||
@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setString(1, e.paymentHash.toHex)
|
||||
statement.setLong(2, nextTrampolineAmount.toLong)
|
||||
statement.setString(3, nextTrampolineNodeId.value.toHex)
|
||||
statement.setLong(4, e.timestamp)
|
||||
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
|
||||
@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setString(3, p.channelId.toHex)
|
||||
statement.setString(4, p.direction)
|
||||
statement.setString(5, p.relayType)
|
||||
statement.setLong(6, e.timestamp)
|
||||
statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setString(3, e.tx.txid.toHex)
|
||||
statement.setLong(4, e.fee.toLong)
|
||||
statement.setString(5, e.txType)
|
||||
statement.setLong(6, System.currentTimeMillis)
|
||||
statement.setTimestamp(6, Timestamp.from(Instant.now()))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
@ -189,7 +204,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
statement.setString(3, errorName)
|
||||
statement.setString(4, errorMessage)
|
||||
statement.setBoolean(5, e.isFatal)
|
||||
statement.setLong(6, System.currentTimeMillis)
|
||||
statement.setTimestamp(6, Timestamp.from(Instant.now()))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
@ -197,9 +212,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
|
||||
override def listSent(from: Long, to: Long): Seq[PaymentSent] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
|
||||
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
|
||||
val rs = statement.executeQuery()
|
||||
var sentByParentId = Map.empty[UUID, PaymentSent]
|
||||
while (rs.next()) {
|
||||
@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
MilliSatoshi(rs.getLong("fees_msat")),
|
||||
rs.getByteVector32FromHex("to_channel_id"),
|
||||
None, // we don't store the route in the audit DB
|
||||
rs.getLong("timestamp"))
|
||||
rs.getTimestamp("timestamp").getTime)
|
||||
val sent = sentByParentId.get(parentId) match {
|
||||
case Some(s) => s.copy(parts = s.parts :+ part)
|
||||
case None => PaymentSent(
|
||||
@ -229,9 +244,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
|
||||
override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
|
||||
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
|
||||
val rs = statement.executeQuery()
|
||||
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
|
||||
while (rs.next()) {
|
||||
@ -239,7 +254,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
val part = PaymentReceived.PartialPayment(
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
rs.getByteVector32FromHex("from_channel_id"),
|
||||
rs.getLong("timestamp"))
|
||||
rs.getTimestamp("timestamp").getTime)
|
||||
val received = receivedByHash.get(paymentHash) match {
|
||||
case Some(r) => r.copy(parts = r.parts :+ part)
|
||||
case None => PaymentReceived(paymentHash, Seq(part))
|
||||
@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
|
||||
inTransaction { pg =>
|
||||
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
|
||||
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
|
||||
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
|
||||
val rs = statement.executeQuery()
|
||||
while (rs.next()) {
|
||||
val paymentHash = rs.getByteVector32FromHex("payment_hash")
|
||||
@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
trampolineByHash += (paymentHash -> (amount, nodeId))
|
||||
}
|
||||
}
|
||||
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
|
||||
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
|
||||
val rs = statement.executeQuery()
|
||||
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
|
||||
while (rs.next()) {
|
||||
@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
rs.getString("direction"),
|
||||
rs.getString("relay_type"),
|
||||
rs.getLong("timestamp"))
|
||||
rs.getTimestamp("timestamp").getTime)
|
||||
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
|
||||
}
|
||||
relayedByHash.flatMap {
|
||||
@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
|
||||
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
|
||||
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[NetworkFee] = Queue()
|
||||
while (rs.next()) {
|
||||
@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
txId = rs.getByteVector32FromHex("tx_id"),
|
||||
fee = Satoshi(rs.getLong("fee_sat")),
|
||||
txType = rs.getString("tx_type"),
|
||||
timestamp = rs.getLong("timestamp"))
|
||||
timestamp = rs.getTimestamp("timestamp").getTime)
|
||||
}
|
||||
q
|
||||
}
|
||||
|
@ -27,7 +27,8 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
|
||||
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
|
||||
import grizzled.slf4j.Logging
|
||||
|
||||
import java.sql.Statement
|
||||
import java.sql.{Statement, Timestamp}
|
||||
import java.time.Instant
|
||||
import javax.sql.DataSource
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
@ -38,7 +39,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
|
||||
import lock._
|
||||
|
||||
val DB_NAME = "channels"
|
||||
val CURRENT_VERSION = 3
|
||||
val CURRENT_VERSION = 4
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
@ -51,14 +52,28 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
|
||||
statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN closed_timestamp BIGINT")
|
||||
}
|
||||
|
||||
def migration34(statement: Statement): Unit = {
|
||||
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_received_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_received_timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'")
|
||||
statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'")
|
||||
|
||||
statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT")
|
||||
}
|
||||
|
||||
getVersion(statement, DB_NAME) match {
|
||||
case None =>
|
||||
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)")
|
||||
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)")
|
||||
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
case Some(v@2) =>
|
||||
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
|
||||
migration23(statement)
|
||||
migration34(statement)
|
||||
case Some(v@3) =>
|
||||
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
|
||||
migration34(statement)
|
||||
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
|
||||
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
|
||||
}
|
||||
@ -89,7 +104,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
|
||||
private def updateChannelMetaTimestampColumn(channelId: ByteVector32, columnName: String): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement(s"UPDATE local_channels SET $columnName=? WHERE channel_id=?")) { statement =>
|
||||
statement.setLong(1, System.currentTimeMillis)
|
||||
statement.setTimestamp(1, Timestamp.from(Instant.now()))
|
||||
statement.setString(2, channelId.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
@ -152,7 +167,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.setString(2, commitmentNumber.toString)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
val rs = statement.executeQuery
|
||||
var q: Queue[(ByteVector32, CltvExpiry)] = Queue()
|
||||
while (rs.next()) {
|
||||
|
@ -23,7 +23,6 @@ import fr.acinq.eclair.db.ChannelsDb
|
||||
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
|
||||
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
|
||||
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
|
||||
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, PaymentEvent, PaymentReceived, PaymentRelayed, PaymentSent}
|
||||
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
|
||||
import grizzled.slf4j.Logging
|
||||
|
||||
@ -64,7 +63,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
|
||||
getVersion(statement, DB_NAME) match {
|
||||
case None =>
|
||||
statement.executeUpdate("CREATE TABLE local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)")
|
||||
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
case Some(v@1) =>
|
||||
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
|
||||
|
@ -4,8 +4,8 @@ import akka.actor.ActorSystem
|
||||
import com.opentable.db.postgres.embedded.EmbeddedPostgres
|
||||
import com.zaxxer.hikari.HikariConfig
|
||||
import fr.acinq.eclair.db._
|
||||
import fr.acinq.eclair.db.pg.PgUtils.PgLock
|
||||
import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler
|
||||
import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using}
|
||||
import org.postgresql.jdbc.PgConnection
|
||||
import org.sqlite.SQLiteConnection
|
||||
|
||||
@ -64,6 +64,7 @@ object TestDatabases {
|
||||
|
||||
// @formatter:off
|
||||
override val connection: PgConnection = pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection]
|
||||
// NB: we use a lazy val here: databases won't be initialized until we reference that variable
|
||||
override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile), readOnlyUser_opt = None)
|
||||
override def close(): Unit = pg.close()
|
||||
// @formatter:on
|
||||
@ -77,4 +78,23 @@ object TestDatabases {
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
def migrationCheck(dbs: TestDatabases,
|
||||
initializeTables: Connection => Unit,
|
||||
dbName: String,
|
||||
targetVersion: Int,
|
||||
postCheck: Connection => Unit
|
||||
): Unit = {
|
||||
val connection = dbs.connection
|
||||
// initialize the database to a previous version and populate data
|
||||
initializeTables(connection)
|
||||
// this will trigger the initialization of tables and the migration
|
||||
val _ = dbs.db
|
||||
// check that db version was updated
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, dbName).contains(targetVersion))
|
||||
}
|
||||
// post-migration checks
|
||||
postCheck(connection)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ package fr.acinq.eclair.db
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PrivateKey
|
||||
import fr.acinq.bitcoin.{ByteVector32, SatoshiLong, Transaction}
|
||||
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases}
|
||||
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck}
|
||||
import fr.acinq.eclair._
|
||||
import fr.acinq.eclair.channel.Helpers.Closing.MutualClose
|
||||
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
|
||||
@ -26,7 +26,7 @@ import fr.acinq.eclair.db.AuditDb.Stats
|
||||
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
|
||||
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
|
||||
import fr.acinq.eclair.db.pg.PgAuditDb
|
||||
import fr.acinq.eclair.db.pg.PgUtils.{inTransaction, setVersion}
|
||||
import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion}
|
||||
import fr.acinq.eclair.db.sqlite.SqliteAuditDb
|
||||
import fr.acinq.eclair.payment._
|
||||
import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
|
||||
@ -34,8 +34,9 @@ import fr.acinq.eclair.wire.protocol.Error
|
||||
import org.scalatest.Tag
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import java.sql.Timestamp
|
||||
import java.time.Instant
|
||||
import java.util.UUID
|
||||
import javax.sql.DataSource
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.Random
|
||||
|
||||
@ -182,13 +183,20 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("handle migration version 1 -> 5") {
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
test("migrate sqlite audit database v1 -> v5") {
|
||||
|
||||
val dbs = TestSqliteDatabases()
|
||||
|
||||
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
@ -208,17 +216,6 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
setVersion(statement, "audit", 1)
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(1))
|
||||
}
|
||||
|
||||
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
// add a row (no ID on sent)
|
||||
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, ps.recipientAmount.toLong)
|
||||
@ -229,15 +226,12 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
statement.setLong(6, ps.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
},
|
||||
dbName = "audit",
|
||||
targetVersion = 5,
|
||||
postCheck = connection => {
|
||||
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
|
||||
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
|
||||
assert(dbs.audit.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
@ -252,16 +246,19 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
|
||||
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
|
||||
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
test("handle migration version 2 -> 5") {
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
test("migrate sqlite audit database v2 -> v5") {
|
||||
val dbs = TestSqliteDatabases()
|
||||
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
@ -280,39 +277,39 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
|
||||
setVersion(statement, "audit", 2)
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(2))
|
||||
}
|
||||
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
},
|
||||
dbName = "audit",
|
||||
targetVersion = 5,
|
||||
postCheck = connection => {
|
||||
val migratedDb = dbs.audit
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
migratedDb.add(e1)
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
postMigrationDb.add(e2)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
test("handle migration version 3 -> 5") {
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
test("migrate sqlite audit database v3 -> v5") {
|
||||
|
||||
val dbs = TestSqliteDatabases()
|
||||
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
|
||||
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
@ -334,14 +331,6 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
setVersion(statement, "audit", 3)
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(3))
|
||||
}
|
||||
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
|
||||
for (pp <- Seq(pp1, pp2)) {
|
||||
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, pp.amount.toLong)
|
||||
@ -355,9 +344,6 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
}
|
||||
}
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
|
||||
|
||||
for (relayed <- Seq(relayed1, relayed2)) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, relayed.amountIn.toLong)
|
||||
@ -369,12 +355,14 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
},
|
||||
dbName = "audit",
|
||||
targetVersion = 5,
|
||||
postCheck = connection => {
|
||||
val migratedDb = dbs.audit
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
assert(migratedDb.listSent(50, 150).toSet === Set(
|
||||
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
|
||||
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
|
||||
@ -382,214 +370,194 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
|
||||
))
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
|
||||
|
||||
postMigrationDb.add(ps2)
|
||||
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
test("handle migration version 4 -> 5") {
|
||||
test("migrate audit database v4 -> v5/v6") {
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
|
||||
|
||||
forAllDbs {
|
||||
case dbs: TestPgDatabases =>
|
||||
import fr.acinq.eclair.db.pg.PgUtils.getVersion
|
||||
implicit val datasource: DataSource = dbs.datasource
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
|
||||
// simulate existing previous version db
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
setVersion(statement, "audit", 4)
|
||||
}
|
||||
|
||||
setVersion(statement, "audit", 4)
|
||||
}
|
||||
}
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(4))
|
||||
}
|
||||
}
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed1.paymentHash.toHex)
|
||||
statement.setLong(2, relayed1.amountIn.toLong)
|
||||
statement.setString(3, relayed1.fromChannelId.toHex)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed1.paymentHash.toHex)
|
||||
statement.setLong(2, relayed1.amountOut.toLong)
|
||||
statement.setString(3, relayed1.toChannelId.toHex)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
for (incoming <- relayed2.incoming) {
|
||||
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed2.paymentHash.toHex)
|
||||
statement.setLong(2, incoming.amount.toLong)
|
||||
statement.setString(3, incoming.channelId.toHex)
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed1.paymentHash.toHex)
|
||||
statement.setLong(2, relayed1.amountIn.toLong)
|
||||
statement.setString(3, relayed1.fromChannelId.toHex)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
for (outgoing <- relayed2.outgoing) {
|
||||
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed2.paymentHash.toHex)
|
||||
statement.setLong(2, outgoing.amount.toLong)
|
||||
statement.setString(3, outgoing.channelId.toHex)
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed1.paymentHash.toHex)
|
||||
statement.setLong(2, relayed1.amountOut.toLong)
|
||||
statement.setString(3, relayed1.toChannelId.toHex)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
for (incoming <- relayed2.incoming) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed2.paymentHash.toHex)
|
||||
statement.setLong(2, incoming.amount.toLong)
|
||||
statement.setString(3, incoming.channelId.toHex)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
for (outgoing <- relayed2.outgoing) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, relayed2.paymentHash.toHex)
|
||||
statement.setLong(2, outgoing.amount.toLong)
|
||||
statement.setString(3, outgoing.channelId.toHex)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
},
|
||||
dbName = "audit",
|
||||
targetVersion = 6,
|
||||
postCheck = connection => {
|
||||
val migratedDb = dbs.audit
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(6))
|
||||
}
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new PgAuditDb()(dbs.datasource)
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(6))
|
||||
}
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
}
|
||||
}
|
||||
|
||||
val migratedDb = new PgAuditDb()(datasource)
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
}
|
||||
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new PgAuditDb()(datasource)
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
}
|
||||
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
|
||||
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
)
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
setVersion(statement, "audit", 4)
|
||||
}
|
||||
|
||||
setVersion(statement, "audit", 4)
|
||||
}
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed1.paymentHash.toArray)
|
||||
statement.setLong(2, relayed1.amountIn.toLong)
|
||||
statement.setBytes(3, relayed1.fromChannelId.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed1.paymentHash.toArray)
|
||||
statement.setLong(2, relayed1.amountOut.toLong)
|
||||
statement.setBytes(3, relayed1.toChannelId.toArray)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
for (incoming <- relayed2.incoming) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed2.paymentHash.toArray)
|
||||
statement.setLong(2, incoming.amount.toLong)
|
||||
statement.setBytes(3, incoming.channelId.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
for (outgoing <- relayed2.outgoing) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed2.paymentHash.toArray)
|
||||
statement.setLong(2, outgoing.amount.toLong)
|
||||
statement.setBytes(3, outgoing.channelId.toArray)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
},
|
||||
dbName = "audit",
|
||||
targetVersion = 5,
|
||||
postCheck = connection => {
|
||||
val migratedDb = dbs.audit
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(4))
|
||||
}
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110)
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed1.paymentHash.toArray)
|
||||
statement.setLong(2, relayed1.amountIn.toLong)
|
||||
statement.setBytes(3, relayed1.fromChannelId.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed1.paymentHash.toArray)
|
||||
statement.setLong(2, relayed1.amountOut.toLong)
|
||||
statement.setBytes(3, relayed1.toChannelId.toArray)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, relayed1.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
for (incoming <- relayed2.incoming) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed2.paymentHash.toArray)
|
||||
statement.setLong(2, incoming.amount.toLong)
|
||||
statement.setBytes(3, incoming.channelId.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
}
|
||||
}
|
||||
for (outgoing <- relayed2.outgoing) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, relayed2.paymentHash.toArray)
|
||||
statement.setLong(2, outgoing.amount.toLong)
|
||||
statement.setBytes(3, outgoing.channelId.toArray)
|
||||
statement.setString(4, "OUT")
|
||||
statement.setString(5, "trampoline")
|
||||
statement.setLong(6, relayed2.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit").contains(5))
|
||||
}
|
||||
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150)
|
||||
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -605,7 +573,7 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "unknown") // invalid relay type
|
||||
statement.setLong(6, 10)
|
||||
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(10))) else statement.setLong(6, 10)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
@ -615,7 +583,7 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "UP") // invalid direction
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 20)
|
||||
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(20))) else statement.setLong(6, 20)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
@ -628,7 +596,7 @@ class AuditDbSpec extends AnyFunSuite {
|
||||
if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
|
||||
statement.setString(4, "IN") // missing a corresponding OUT
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 30)
|
||||
if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(30))) else statement.setLong(6, 30)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
|
@ -18,23 +18,25 @@ package fr.acinq.eclair.db
|
||||
|
||||
import com.softwaremill.quicklens._
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases}
|
||||
import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck}
|
||||
import fr.acinq.eclair.db.ChannelsDbSpec.{getPgTimestamp, getTimestamp, testCases}
|
||||
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
|
||||
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
|
||||
import fr.acinq.eclair.db.pg.PgChannelsDb
|
||||
import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion}
|
||||
import fr.acinq.eclair.db.pg.{PgChannelsDb, PgUtils}
|
||||
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._
|
||||
import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec
|
||||
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
|
||||
import fr.acinq.eclair.{CltvExpiry, ShortChannelId, randomBytes32}
|
||||
import fr.acinq.eclair.{CltvExpiry, ShortChannelId, TestDatabases, randomBytes32}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
import java.sql.SQLException
|
||||
import java.sql.{Connection, SQLException}
|
||||
import java.util.concurrent.Executors
|
||||
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
|
||||
import scala.util.Random
|
||||
|
||||
class ChannelsDbSpec extends AnyFunSuite {
|
||||
|
||||
@ -107,56 +109,42 @@ class ChannelsDbSpec extends AnyFunSuite {
|
||||
test("channel metadata") {
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.channels
|
||||
val connection = dbs.connection
|
||||
|
||||
val channel1 = ChannelCodecsSpec.normal
|
||||
val channel2 = channel1.modify(_.commitments.channelId).setTo(randomBytes32)
|
||||
|
||||
def getTimestamp(channelId: ByteVector32, columnName: String): Option[Long] = {
|
||||
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
|
||||
// data type differs depending on underlying database system
|
||||
dbs match {
|
||||
case _: TestPgDatabases => statement.setString(1, channelId.toHex)
|
||||
case _: TestSqliteDatabases => statement.setBytes(1, channelId.toArray)
|
||||
}
|
||||
val rs = statement.executeQuery()
|
||||
rs.next()
|
||||
rs.getLongNullable(columnName)
|
||||
}
|
||||
}
|
||||
|
||||
// first we add channels
|
||||
db.addOrUpdateChannel(channel1)
|
||||
db.addOrUpdateChannel(channel2)
|
||||
|
||||
// make sure initially all metadata are empty
|
||||
assert(getTimestamp(channel1.channelId, "created_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel1.channelId, "last_connected_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel1.channelId, "closed_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").isEmpty)
|
||||
|
||||
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Created)
|
||||
assert(getTimestamp(channel1.channelId, "created_timestamp").nonEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").nonEmpty)
|
||||
|
||||
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentSent)
|
||||
assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").nonEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").nonEmpty)
|
||||
|
||||
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentReceived)
|
||||
assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").nonEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").nonEmpty)
|
||||
|
||||
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Connected)
|
||||
assert(getTimestamp(channel1.channelId, "last_connected_timestamp").nonEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").nonEmpty)
|
||||
|
||||
db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Closed(null))
|
||||
assert(getTimestamp(channel1.channelId, "closed_timestamp").nonEmpty)
|
||||
assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").nonEmpty)
|
||||
|
||||
// make sure all metadata are still empty for channel 2
|
||||
assert(getTimestamp(channel2.channelId, "created_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel2.channelId, "last_payment_sent_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel2.channelId, "last_payment_received_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel2.channelId, "last_connected_timestamp").isEmpty)
|
||||
assert(getTimestamp(channel2.channelId, "closed_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel2.channelId, "created_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel2.channelId, "last_payment_sent_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel2.channelId, "last_payment_received_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel2.channelId, "last_connected_timestamp").isEmpty)
|
||||
assert(getTimestamp(dbs, channel2.channelId, "closed_timestamp").isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,13 +163,22 @@ class ChannelsDbSpec extends AnyFunSuite {
|
||||
setVersion(statement, "channels", 1)
|
||||
}
|
||||
|
||||
// insert 1 row
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val data = stateDataCodec.encode(channel).require.toByteArray
|
||||
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
|
||||
statement.setBytes(1, channel.channelId.toArray)
|
||||
statement.setBytes(2, data)
|
||||
statement.executeUpdate()
|
||||
// insert data
|
||||
for (testCase <- testCases) {
|
||||
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
|
||||
statement.setBytes(1, testCase.channelId.toArray)
|
||||
statement.setBytes(2, testCase.data.toArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
for (commitmentNumber <- testCase.commitmentNumbers) {
|
||||
using(sqlite.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, testCase.channelId.toArray)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setLong(4, 500000 + Random.nextInt(500000))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check that db migration works
|
||||
@ -189,71 +186,193 @@ class ChannelsDbSpec extends AnyFunSuite {
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "channels").contains(3))
|
||||
}
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
|
||||
assert(db.listLocalChannels().size === testCases.size)
|
||||
for (testCase <- testCases) {
|
||||
db.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail
|
||||
for (commitmentNumber <- testCase.commitmentNumbers) {
|
||||
assert(db.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("migrate channel database v2 -> v3") {
|
||||
test("migrate channel database v2 -> v3/v4") {
|
||||
def postCheck(channelsDb: ChannelsDb): Unit = {
|
||||
assert(channelsDb.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
|
||||
for (testCase <- testCases.filterNot(_.isClosed)) {
|
||||
channelsDb.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail
|
||||
for (commitmentNumber <- testCase.commitmentNumbers) {
|
||||
assert(channelsDb.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
forAllDbs {
|
||||
case dbs: TestPgDatabases =>
|
||||
val pg = dbs.connection
|
||||
|
||||
// create a v2 channels database
|
||||
using(pg.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
setVersion(statement, "channels", 2)
|
||||
}
|
||||
|
||||
// insert 1 row
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val data = stateDataCodec.encode(channel).require.toByteArray
|
||||
using(pg.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setString(1, channel.channelId.toHex)
|
||||
statement.setBytes(2, data)
|
||||
statement.setBoolean(3, false)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// check that db migration works
|
||||
val db = dbs.channels
|
||||
using(pg.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "channels").contains(3))
|
||||
}
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
|
||||
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// initialize a v2 database
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
setVersion(statement, "channels", 2)
|
||||
}
|
||||
// insert data
|
||||
testCases.foreach { testCase =>
|
||||
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setString(1, testCase.channelId.toHex)
|
||||
statement.setBytes(2, testCase.data.toArray)
|
||||
statement.setBoolean(3, testCase.isClosed)
|
||||
statement.executeUpdate()
|
||||
for (commitmentNumber <- testCase.commitmentNumbers) {
|
||||
using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, testCase.channelId.toHex)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
statement.setString(3, randomBytes32.toHex)
|
||||
statement.setLong(4, 500000 + Random.nextInt(500000))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
dbName = "channels",
|
||||
targetVersion = 4,
|
||||
postCheck = _ => postCheck(dbs.channels)
|
||||
)
|
||||
case dbs: TestSqliteDatabases =>
|
||||
val sqlite = dbs.connection
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
// create a v2 channels database
|
||||
using(connection.createStatement()) { statement =>
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
setVersion(statement, "channels", 2)
|
||||
}
|
||||
// insert data
|
||||
testCases.foreach { testCase =>
|
||||
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, testCase.channelId.toArray)
|
||||
statement.setBytes(2, testCase.data.toArray)
|
||||
statement.setBoolean(3, testCase.isClosed)
|
||||
statement.executeUpdate()
|
||||
for (commitmentNumber <- testCase.commitmentNumbers) {
|
||||
using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, testCase.channelId.toArray)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setLong(4, 500000 + Random.nextInt(500000))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
dbName = "channels",
|
||||
targetVersion = 3,
|
||||
postCheck = _ => postCheck(dbs.channels)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// create a v2 channels database
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
setVersion(statement, "channels", 2)
|
||||
}
|
||||
test("migrate pg channel database v3->v4") {
|
||||
val dbs = TestPgDatabases()
|
||||
|
||||
// insert 1 row
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val data = stateDataCodec.encode(channel).require.toByteArray
|
||||
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, channel.channelId.toArray)
|
||||
statement.setBytes(2, data)
|
||||
statement.setBoolean(3, false)
|
||||
statement.executeUpdate()
|
||||
migrationCheck(
|
||||
dbs = dbs,
|
||||
initializeTables = connection => {
|
||||
using(connection.createStatement()) { statement =>
|
||||
// initialize a v3 database
|
||||
statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)")
|
||||
statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
PgUtils.setVersion(statement, "channels", 3)
|
||||
}
|
||||
// insert data
|
||||
testCases.foreach { testCase =>
|
||||
using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed, created_timestamp, last_payment_sent_timestamp, last_payment_received_timestamp, last_connected_timestamp, closed_timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, testCase.channelId.toHex)
|
||||
statement.setBytes(2, testCase.data.toArray)
|
||||
statement.setBoolean(3, testCase.isClosed)
|
||||
statement.setObject(4, testCase.createdTimestamp.orNull)
|
||||
statement.setObject(5, testCase.lastPaymentSentTimestamp.orNull)
|
||||
statement.setObject(6, testCase.lastPaymentReceivedTimestamp.orNull)
|
||||
statement.setObject(7, testCase.lastConnectedTimestamp.orNull)
|
||||
statement.setObject(8, testCase.closedTimestamp.orNull)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
},
|
||||
dbName = "channels",
|
||||
targetVersion = 4,
|
||||
postCheck = connection => {
|
||||
assert(dbs.channels.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
|
||||
testCases.foreach { testCase =>
|
||||
assert(getPgTimestamp(connection, testCase.channelId, "created_timestamp") === testCase.createdTimestamp)
|
||||
assert(getPgTimestamp(connection, testCase.channelId, "last_payment_sent_timestamp") === testCase.lastPaymentSentTimestamp)
|
||||
assert(getPgTimestamp(connection, testCase.channelId, "last_payment_received_timestamp") === testCase.lastPaymentReceivedTimestamp)
|
||||
assert(getPgTimestamp(connection, testCase.channelId, "last_connected_timestamp") === testCase.lastConnectedTimestamp)
|
||||
assert(getPgTimestamp(connection, testCase.channelId, "closed_timestamp") === testCase.closedTimestamp)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// check that db migration works
|
||||
val db = dbs.channels
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "channels").contains(3))
|
||||
}
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail
|
||||
}
|
||||
}
|
||||
|
||||
object ChannelsDbSpec {
|
||||
|
||||
case class TestCase(channelId: ByteVector32,
|
||||
data: ByteVector,
|
||||
isClosed: Boolean,
|
||||
createdTimestamp: Option[Long],
|
||||
lastPaymentSentTimestamp: Option[Long],
|
||||
lastPaymentReceivedTimestamp: Option[Long],
|
||||
lastConnectedTimestamp: Option[Long],
|
||||
closedTimestamp: Option[Long],
|
||||
commitmentNumbers: Seq[Int]
|
||||
)
|
||||
|
||||
private val data = stateDataCodec.encode(ChannelCodecsSpec.normal).require.bytes
|
||||
val testCases: Seq[TestCase] = for (_ <- 0 until 10) yield TestCase(
|
||||
channelId = randomBytes32,
|
||||
data = data,
|
||||
isClosed = Random.nextBoolean(),
|
||||
createdTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
|
||||
lastPaymentSentTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
|
||||
lastPaymentReceivedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
|
||||
lastConnectedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
|
||||
closedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None,
|
||||
commitmentNumbers = for (_ <- 0 until Random.nextInt(10)) yield Random.nextInt(5) // there will be repetitions, on purpose
|
||||
)
|
||||
|
||||
def getTimestamp(dbs: TestDatabases, channelId: ByteVector32, columnName: String): Option[Long] = {
|
||||
dbs match {
|
||||
case _: TestPgDatabases => getPgTimestamp(dbs.connection, channelId, columnName)
|
||||
case _: TestSqliteDatabases => getSqliteTimestamp(dbs.connection, channelId, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
def getSqliteTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = {
|
||||
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
|
||||
statement.setBytes(1, channelId.toArray)
|
||||
val rs = statement.executeQuery()
|
||||
rs.next()
|
||||
rs.getLongNullable(columnName)
|
||||
}
|
||||
}
|
||||
|
||||
def getPgTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = {
|
||||
using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
rs.next()
|
||||
rs.getTimestampNullable(columnName).map(_.getTime)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user