mirror of
https://github.com/ACINQ/eclair.git
synced 2024-11-19 01:43:22 +01:00
Postgresql support (#1249)
Add beta support for PostgreSQL database backend.
This commit is contained in:
parent
68dfc6cb7c
commit
b63c4aa5a4
2
.gitignore
vendored
2
.gitignore
vendored
@ -25,3 +25,5 @@ target/
|
||||
project/target
|
||||
DeleteMe*.*
|
||||
*~
|
||||
|
||||
.DS_Store
|
||||
|
114
docs/PostgreSQL.md
Normal file
114
docs/PostgreSQL.md
Normal file
@ -0,0 +1,114 @@
|
||||
## PostgreSQL Configuration
|
||||
|
||||
By default Eclair stores its data on the machine's local file system (typically in `~/.eclair` directory) using SQLite.
|
||||
|
||||
It also supports PostgreSQL version 10.6 and higher as a database backend.
|
||||
|
||||
To enable PostgreSQL support set the `driver` parameter to `postgres`:
|
||||
|
||||
```
|
||||
eclair.db.driver = postgres
|
||||
```
|
||||
|
||||
### Connection settings
|
||||
|
||||
To configure the connection settings use the `database`, `host`, `port` `username` and `password` parameters:
|
||||
|
||||
```
|
||||
eclair.db.postgres.database = "mydb"
|
||||
eclair.db.postgres.host = "127.0.0.1" # Default: "localhost"
|
||||
eclair.db.postgres.port = 12345 # Default: 5432
|
||||
eclair.db.postgres.username = "myuser"
|
||||
eclair.db.postgres.password = "mypassword"
|
||||
```
|
||||
|
||||
Eclair uses Hikari connection pool (https://github.com/brettwooldridge/HikariCP) which has a lot of configuration
|
||||
parameters. Some of them can be set in Eclair config file. The most important is `pool.max-size`, it defines the maximum
|
||||
allowed number of simultaneous connections to the database.
|
||||
|
||||
A good rule of thumb is to set `pool.max-size` to the CPU core count times 2.
|
||||
See https://github.com/brettwooldridge/HikariCP/wiki/About-Pool-Sizing for better estimation.
|
||||
|
||||
```
|
||||
eclair.db.postgres.pool {
|
||||
max-size = 8 # Default: 10
|
||||
connection-timeout = 10 seconds # Default: 30 seconds
|
||||
idle-timeout = 1 minute # Default: 10 minutes
|
||||
max-life-time = 15 minutes # Default: 30 minutes
|
||||
}
|
||||
```
|
||||
|
||||
### Locking settings
|
||||
|
||||
Running multiple Eclair processes connected to the same database can lead to data corruption and loss of funds.
|
||||
That's why Eclair supports database locking mechanisms to prevent multiple Eclair instances from accessing one database together.
|
||||
|
||||
Use `postgres.lock-type` parameter to set the locking schemes.
|
||||
|
||||
Lock type | Description
|
||||
---|---
|
||||
`lease` | At the beginning, Eclair acquires a lease for the database that expires after some time. Then it constantly extends the lease. On each lease extension and each database transaction, Eclair checks if the lease belongs to the Eclair instance. If it doesn't, Eclair assumes that the database was updated by another Eclair process and terminates. Note that this is just a safeguard feature for Eclair rather than a bulletproof database-wide lock, because third-party applications still have the ability to access the database without honoring this locking scheme.
|
||||
`none` | No locking at all. Useful for tests. DO NOT USE ON MAINNET!
|
||||
|
||||
```
|
||||
eclair.db.postgres.lock-type = "none" // Default: "lease"
|
||||
```
|
||||
|
||||
#### Database Lease Settings
|
||||
|
||||
There are two main configuration parameters for the lease locking scheme: `lease.interval` and `lease.renew-interval`.
|
||||
`lease.interval` defines lease validity time. During the lease time no other node can acquire the lock, except the lease holder.
|
||||
After that time the lease is assumed expired, any node can acquire the lease. So that only one node can update the database
|
||||
at a time. Eclair extends the lease every `lease.renew-interval` until terminated.
|
||||
|
||||
```
|
||||
eclair.db.postgres.lease {
|
||||
interval = 30 seconds // Default: 5 minutes
|
||||
renew-interval = 10 seconds // Default: 1 minute
|
||||
}
|
||||
```
|
||||
|
||||
### Backups and replication
|
||||
|
||||
The PostgreSQL driver doesn't support Eclair's built-in online backups. Instead, you should use the tools provided
|
||||
by PostgreSQL.
|
||||
|
||||
#### Backup/Restore
|
||||
|
||||
For nodes with infrequent channel updates its easier to use `pg_dump` to perform the task.
|
||||
|
||||
It's important to stop the node to prevent any channel updates while a backup/restore operation is in progress. It makes
|
||||
sense to backup the database after each channel update, to prevent restoring an outdated channel's state and consequently
|
||||
losing the funds associated with that channel.
|
||||
|
||||
For more information about backup refer to the official PostgreSQL documentation: https://www.postgresql.org/docs/current/backup.html
|
||||
|
||||
#### Replication
|
||||
|
||||
For busier nodes it isn't practical to use `pg_dump`. Fortunately, PostgreSQL provides built-in database replication which makes the backup/restore process more seamless.
|
||||
|
||||
To set up database replication you need to create a main database, that accepts all changes from the node, and a replica database.
|
||||
Once replication is configured, the main database will automatically send all the changes to the replica.
|
||||
In case of failure of the main database, the node can be simply reconfigured to use the replica instead of the main database.
|
||||
|
||||
PostgreSQL supports [different types of replication](https://www.postgresql.org/docs/current/different-replication-solutions.html).
|
||||
The most suitable type for an Eclair node is [synchronous streaming replication](https://www.postgresql.org/docs/current/warm-standby.html#SYNCHRONOUS-REPLICATION),
|
||||
because it provides a very important feature, that helps keep the replicated channel's state up to date:
|
||||
|
||||
> When requesting synchronous replication, each commit of a write transaction will wait until confirmation is received that the commit has been written to the write-ahead log on disk of both the primary and standby server.
|
||||
|
||||
Follow the official PostgreSQL high availability documentation for the instructions to set up synchronous streaming replication: https://www.postgresql.org/docs/current/high-availability.html
|
||||
|
||||
### Safeguard to prevent accidental loss of funds due to database misconfiguration
|
||||
|
||||
Using Eclair with an outdated version of the database or a database created with another seed might lead to loss of funds.
|
||||
|
||||
Every time Eclair starts, it checks if the Postgres database connection settings were changed since the last start.
|
||||
If in fact the settings were changed, Eclair stops immediately to prevent potentially dangerous
|
||||
but accidental configuration changes to come into effect.
|
||||
|
||||
Eclair stores the latest database settings in the `${data-dir}/last_jdbcurl` file, and compares its contents with the database settings from the config file.
|
||||
|
||||
The node operator can force Eclair to accept new database
|
||||
connection settings by removing the `last_jdbcurl` file.
|
||||
|
@ -198,6 +198,16 @@
|
||||
<artifactId>sqlite-jdbc</artifactId>
|
||||
<version>3.27.2.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.postgresql</groupId>
|
||||
<version>42.2.12</version>
|
||||
<artifactId>postgresql</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.zaxxer</groupId>
|
||||
<artifactId>HikariCP</artifactId>
|
||||
<version>3.4.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<!-- This is to get rid of '[WARNING] warning: Class javax.annotation.Nonnull not found - continuing with a stub.' compile errors -->
|
||||
<groupId>com.google.code.findbugs</groupId>
|
||||
@ -264,5 +274,11 @@
|
||||
<version>1.5.9</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.opentable.components</groupId>
|
||||
<artifactId>otj-pg-embedded</artifactId>
|
||||
<version>0.13.3</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
@ -183,6 +183,28 @@ eclair {
|
||||
port = 9051
|
||||
private-key-file = "tor.dat"
|
||||
}
|
||||
|
||||
db {
|
||||
driver = "sqlite" // sqlite, postgres
|
||||
postgres {
|
||||
database = "eclair"
|
||||
host = "localhost"
|
||||
port = 5432
|
||||
username = ""
|
||||
password = ""
|
||||
pool {
|
||||
max-size = 10 // recommended value = number_of_cpu_cores * 2
|
||||
connection-timeout = 30 seconds
|
||||
idle-timeout = 10 minutes
|
||||
max-life-time = 30 minutes
|
||||
}
|
||||
lease {
|
||||
interval = 5 minutes // lease-interval must be greater than lease-renew-interval
|
||||
renew-interval = 1 minute
|
||||
}
|
||||
lock-type = "lease" // lease or none
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// do not edit or move this section
|
||||
|
@ -45,7 +45,7 @@ import scala.concurrent.duration._
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress])
|
||||
case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress], instanceId: String)
|
||||
|
||||
case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], relayed: Seq[PaymentRelayed])
|
||||
|
||||
@ -372,7 +372,8 @@ class EclairImpl(appKit: Kit) extends Eclair {
|
||||
chainHash = appKit.nodeParams.chainHash,
|
||||
network = NodeParams.chainFromHash(appKit.nodeParams.chainHash),
|
||||
blockHeight = appKit.nodeParams.currentBlockHeight.toInt,
|
||||
publicAddresses = appKit.nodeParams.publicAddresses)
|
||||
publicAddresses = appKit.nodeParams.publicAddresses,
|
||||
instanceId = appKit.nodeParams.instanceId.toString)
|
||||
)
|
||||
|
||||
override def usableBalances()(implicit timeout: Timeout): Future[Iterable[UsableBalance]] =
|
||||
|
@ -19,6 +19,7 @@ package fr.acinq.eclair
|
||||
import java.io.File
|
||||
import java.net.InetSocketAddress
|
||||
import java.nio.file.Files
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
@ -42,6 +43,7 @@ import scala.jdk.CollectionConverters._
|
||||
* Created by PM on 26/02/2017.
|
||||
*/
|
||||
case class NodeParams(keyManager: KeyManager,
|
||||
instanceId: UUID, // a unique instance ID regenerated after each restart
|
||||
private val blockCount: AtomicLong,
|
||||
alias: String,
|
||||
color: Color,
|
||||
@ -132,7 +134,7 @@ object NodeParams {
|
||||
|
||||
def chainFromHash(chainHash: ByteVector32): String = chain2Hash.map(_.swap).getOrElse(chainHash, throw new RuntimeException(s"invalid chainHash '$chainHash'"))
|
||||
|
||||
def makeNodeParams(config: Config, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = {
|
||||
def makeNodeParams(config: Config, instanceId: UUID, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = {
|
||||
// check configuration for keys that have been renamed
|
||||
val deprecatedKeyPaths = Map(
|
||||
// v0.3.2
|
||||
@ -235,6 +237,7 @@ object NodeParams {
|
||||
|
||||
NodeParams(
|
||||
keyManager = keyManager,
|
||||
instanceId = instanceId,
|
||||
blockCount = blockCount,
|
||||
alias = nodeAlias,
|
||||
color = Color(color(0), color(1), color(2)),
|
||||
|
@ -19,6 +19,7 @@ package fr.acinq.eclair
|
||||
import java.io.File
|
||||
import java.net.InetSocketAddress
|
||||
import java.sql.DriverManager
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
|
||||
|
||||
@ -27,7 +28,6 @@ import akka.actor.{ActorRef, ActorSystem, Props, SupervisorStrategy}
|
||||
import akka.pattern.after
|
||||
import akka.util.Timeout
|
||||
import com.softwaremill.sttp.okhttp.OkHttpFutureBackend
|
||||
import com.typesafe.config.{Config, ConfigFactory}
|
||||
import fr.acinq.bitcoin.{Block, ByteVector32}
|
||||
import fr.acinq.eclair.NodeParams.{BITCOIND, ELECTRUM}
|
||||
import fr.acinq.eclair.blockchain.bitcoind.rpc.{BasicBitcoinJsonRPCClient, BatchingBitcoinJsonRPCClient, ExtendedBitcoinClient}
|
||||
@ -41,7 +41,8 @@ import fr.acinq.eclair.blockchain.fee.{ConstantFeeProvider, _}
|
||||
import fr.acinq.eclair.blockchain.{EclairWallet, _}
|
||||
import fr.acinq.eclair.channel.Register
|
||||
import fr.acinq.eclair.crypto.LocalKeyManager
|
||||
import fr.acinq.eclair.db.{BackupHandler, Databases}
|
||||
import fr.acinq.eclair.db.Databases.FileBackup
|
||||
import fr.acinq.eclair.db.{Databases, FileBackupHandler}
|
||||
import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard}
|
||||
import fr.acinq.eclair.payment.Auditor
|
||||
import fr.acinq.eclair.payment.receive.PaymentHandler
|
||||
@ -90,11 +91,11 @@ class Setup(datadir: File,
|
||||
val chain = config.getString("chain")
|
||||
val chaindir = new File(datadir, chain)
|
||||
val keyManager = new LocalKeyManager(seed, NodeParams.hashFromChain(chain))
|
||||
val instanceId = UUID.randomUUID()
|
||||
|
||||
val database = db match {
|
||||
case Some(d) => d
|
||||
case None => Databases.sqliteJDBC(chaindir)
|
||||
}
|
||||
logger.info(s"instanceid=$instanceId")
|
||||
|
||||
val databases = Databases.init(config.getConfig("db"), instanceId, datadir, chaindir, db)
|
||||
|
||||
/**
|
||||
* This counter holds the current blockchain height.
|
||||
@ -123,7 +124,7 @@ class Setup(datadir: File,
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
val nodeParams = NodeParams.makeNodeParams(config, keyManager, initTor(), database, blockCount, feeEstimator)
|
||||
val nodeParams = NodeParams.makeNodeParams(config, instanceId, keyManager, initTor(), databases, blockCount, feeEstimator)
|
||||
|
||||
val serverBindingAddress = new InetSocketAddress(
|
||||
config.getString("server.binding-ip"),
|
||||
@ -281,12 +282,16 @@ class Setup(datadir: File,
|
||||
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
|
||||
|
||||
backupHandler = if (config.getBoolean("enable-db-backup")) {
|
||||
system.actorOf(SimpleSupervisor.props(
|
||||
BackupHandler.props(
|
||||
nodeParams.db,
|
||||
new File(chaindir, "eclair.sqlite.bak"),
|
||||
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None
|
||||
), "backuphandler", SupervisorStrategy.Resume))
|
||||
nodeParams.db match {
|
||||
case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
|
||||
FileBackupHandler.props(
|
||||
fileBackup,
|
||||
new File(chaindir, "eclair.sqlite.bak"),
|
||||
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None),
|
||||
"backuphandler", SupervisorStrategy.Resume))
|
||||
case _ =>
|
||||
system.deadLetters
|
||||
}
|
||||
} else {
|
||||
logger.warn("database backup is disabled")
|
||||
system.deadLetters
|
||||
@ -363,6 +368,7 @@ class Setup(datadir: File,
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// @formatter:off
|
||||
|
@ -17,11 +17,20 @@
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file._
|
||||
import java.sql.{Connection, DriverManager}
|
||||
import java.util.UUID
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import com.typesafe.config.Config
|
||||
import fr.acinq.eclair.db.pg.PgUtils.LockType.LockType
|
||||
import fr.acinq.eclair.db.pg.PgUtils._
|
||||
import fr.acinq.eclair.db.pg._
|
||||
import fr.acinq.eclair.db.sqlite._
|
||||
import grizzled.slf4j.Logging
|
||||
import org.sqlite.SQLiteException
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
trait Databases {
|
||||
|
||||
@ -36,12 +45,53 @@ trait Databases {
|
||||
val payments: PaymentsDb
|
||||
|
||||
val pendingRelay: PendingRelayDb
|
||||
|
||||
def backup(file: File): Unit
|
||||
}
|
||||
|
||||
object Databases extends Logging {
|
||||
|
||||
trait FileBackup { this: Databases =>
|
||||
def backup(backupFile: File): Unit
|
||||
}
|
||||
|
||||
trait ExclusiveLock { this: Databases =>
|
||||
def obtainExclusiveLock(): Unit
|
||||
}
|
||||
|
||||
def init(dbConfig: Config, instanceId: UUID, datadir: File, chaindir: File, db: Option[Databases] = None)(implicit system: ActorSystem): Databases = {
|
||||
db match {
|
||||
case Some(d) => d
|
||||
case None =>
|
||||
dbConfig.getString("driver") match {
|
||||
case "sqlite" => Databases.sqliteJDBC(chaindir)
|
||||
case "postgres" =>
|
||||
val pg = Databases.setupPgDatabases(dbConfig, instanceId, datadir, { ex =>
|
||||
logger.error("fatal error: Cannot obtain lock on the database.\n", ex)
|
||||
sys.exit(-2)
|
||||
})
|
||||
if (LockType(dbConfig.getString("postgres.lock-type")) == LockType.LEASE) {
|
||||
val dbLockLeaseRenewInterval = dbConfig.getDuration("postgres.lease.renew-interval").toSeconds.seconds
|
||||
val dbLockLeaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds
|
||||
if (dbLockLeaseInterval <= dbLockLeaseRenewInterval)
|
||||
throw new RuntimeException("Invalid configuration: `db.postgres.lease.interval` must be greater than `db.postgres.lease.renew-interval`")
|
||||
import system.dispatcher
|
||||
system.scheduler.scheduleWithFixedDelay(dbLockLeaseRenewInterval, dbLockLeaseRenewInterval)(new Runnable {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
pg.obtainExclusiveLock()
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
logger.error("fatal error: Cannot obtain the database lease.\n", e)
|
||||
sys.exit(-1)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
pg
|
||||
case driver => throw new RuntimeException(s"Unknown database driver `$driver`")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a parent folder it creates or loads all the databases from a JDBC connection
|
||||
*
|
||||
@ -59,7 +109,7 @@ object Databases extends Logging {
|
||||
sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}")
|
||||
SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file
|
||||
logger.info("successful lock on eclair.sqlite")
|
||||
databaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
|
||||
sqliteDatabaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
|
||||
} catch {
|
||||
case t: Throwable => {
|
||||
logger.error("could not create connection to sqlite databases: ", t)
|
||||
@ -69,23 +119,117 @@ object Databases extends Logging {
|
||||
throw t
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
def databaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection) = new Databases {
|
||||
def postgresJDBC(database: String, host: String, port: Int,
|
||||
username: Option[String], password: Option[String],
|
||||
poolProperties: Map[String, Long],
|
||||
instanceId: UUID,
|
||||
databaseLeaseInterval: FiniteDuration,
|
||||
lockExceptionHandler: LockExceptionHandler = { _ => () },
|
||||
lockType: LockType = LockType.NONE, datadir: File): Databases with ExclusiveLock = {
|
||||
val url = s"jdbc:postgresql://${host}:${port}/${database}"
|
||||
|
||||
checkIfDatabaseUrlIsUnchanged(url, datadir)
|
||||
|
||||
implicit val lock: DatabaseLock = lockType match {
|
||||
case LockType.NONE => NoLock
|
||||
case LockType.LEASE => LeaseLock(instanceId, databaseLeaseInterval, lockExceptionHandler)
|
||||
case _ => throw new RuntimeException(s"Unknown postgres lock type: `$lockType`")
|
||||
}
|
||||
|
||||
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
|
||||
|
||||
val config = new HikariConfig()
|
||||
config.setJdbcUrl(url)
|
||||
username.foreach(config.setUsername)
|
||||
password.foreach(config.setPassword)
|
||||
poolProperties.get("max-size").foreach(x => config.setMaximumPoolSize(x.toInt))
|
||||
poolProperties.get("connection-timeout").foreach(config.setConnectionTimeout)
|
||||
poolProperties.get("idle-timeout").foreach(config.setIdleTimeout)
|
||||
poolProperties.get("max-life-time").foreach(config.setMaxLifetime)
|
||||
|
||||
implicit val ds: DataSource = new HikariDataSource(config)
|
||||
|
||||
val databases = new Databases with ExclusiveLock {
|
||||
override val network = new PgNetworkDb
|
||||
override val audit = new PgAuditDb
|
||||
override val channels = new PgChannelsDb
|
||||
override val peers = new PgPeersDb
|
||||
override val payments = new PgPaymentsDb
|
||||
override val pendingRelay = new PgPendingRelayDb
|
||||
override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock
|
||||
}
|
||||
databases.obtainExclusiveLock()
|
||||
databases
|
||||
}
|
||||
|
||||
def sqliteDatabaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection): Databases = new Databases with FileBackup {
|
||||
override val network = new SqliteNetworkDb(networkJdbc)
|
||||
override val audit = new SqliteAuditDb(auditJdbc)
|
||||
override val channels = new SqliteChannelsDb(eclairJdbc)
|
||||
override val peers = new SqlitePeersDb(eclairJdbc)
|
||||
override val payments = new SqlitePaymentsDb(eclairJdbc)
|
||||
override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc)
|
||||
override def backup(backupFile: File): Unit = {
|
||||
|
||||
override def backup(file: File): Unit = {
|
||||
SqliteUtils.using(eclairJdbc.createStatement()) {
|
||||
statement => {
|
||||
statement.executeUpdate(s"backup to ${file.getAbsolutePath}")
|
||||
statement.executeUpdate(s"backup to ${backupFile.getAbsolutePath}")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
def setupPgDatabases(dbConfig: Config, instanceId: UUID, datadir: File, lockExceptionHandler: LockExceptionHandler): Databases with ExclusiveLock = {
|
||||
val database = dbConfig.getString("postgres.database")
|
||||
val host = dbConfig.getString("postgres.host")
|
||||
val port = dbConfig.getInt("postgres.port")
|
||||
val username = if (dbConfig.getIsNull("postgres.username") || dbConfig.getString("postgres.username").isEmpty)
|
||||
None
|
||||
else
|
||||
Some(dbConfig.getString("postgres.username"))
|
||||
val password = if (dbConfig.getIsNull("postgres.password") || dbConfig.getString("postgres.password").isEmpty)
|
||||
None
|
||||
else
|
||||
Some(dbConfig.getString("postgres.password"))
|
||||
val properties = {
|
||||
val poolConfig = dbConfig.getConfig("postgres.pool")
|
||||
Map.empty
|
||||
.updated("max-size", poolConfig.getInt("max-size").toLong)
|
||||
.updated("connection-timeout", poolConfig.getDuration("connection-timeout").toMillis)
|
||||
.updated("idle-timeout", poolConfig.getDuration("idle-timeout").toMillis)
|
||||
.updated("max-life-time", poolConfig.getDuration("max-life-time").toMillis)
|
||||
|
||||
}
|
||||
val lockType = LockType(dbConfig.getString("postgres.lock-type"))
|
||||
val leaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds
|
||||
|
||||
Databases.postgresJDBC(
|
||||
database = database, host = host, port = port,
|
||||
username = username, password = password,
|
||||
poolProperties = properties,
|
||||
instanceId = instanceId,
|
||||
databaseLeaseInterval = leaseInterval,
|
||||
lockExceptionHandler = lockExceptionHandler, lockType = lockType, datadir = datadir
|
||||
)
|
||||
}
|
||||
|
||||
private def checkIfDatabaseUrlIsUnchanged(url: String, datadir: File ): Unit = {
|
||||
val urlFile = new File(datadir, "last_jdbcurl")
|
||||
|
||||
def readString(path: Path): String = Files.readAllLines(path).get(0)
|
||||
|
||||
def writeString(path: Path, string: String): Unit = Files.write(path, java.util.Arrays.asList(string))
|
||||
|
||||
if (urlFile.exists()) {
|
||||
val oldUrl = readString(urlFile.toPath)
|
||||
if (oldUrl != url)
|
||||
throw new RuntimeException(s"The database URL has changed since the last start. It was `$oldUrl`, now it's `$url`")
|
||||
} else {
|
||||
writeString(urlFile.toPath, url)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -22,6 +22,7 @@ import java.nio.file.{Files, StandardCopyOption}
|
||||
import akka.actor.{Actor, ActorLogging, Props}
|
||||
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
|
||||
import fr.acinq.eclair.channel.ChannelPersisted
|
||||
import fr.acinq.eclair.db.Databases.FileBackup
|
||||
|
||||
import scala.sys.process.Process
|
||||
import scala.util.{Failure, Success, Try}
|
||||
@ -46,7 +47,7 @@ import scala.util.{Failure, Success, Try}
|
||||
*
|
||||
* Constructor is private so users will have to use BackupHandler.props() which always specific a custom mailbox
|
||||
*/
|
||||
class BackupHandler private(databases: Databases, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging {
|
||||
class FileBackupHandler private(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging {
|
||||
|
||||
// we listen to ChannelPersisted events, which will trigger a backup
|
||||
context.system.eventStream.subscribe(self, classOf[ChannelPersisted])
|
||||
@ -56,6 +57,7 @@ class BackupHandler private(databases: Databases, backupFile: File, backupScript
|
||||
val start = System.currentTimeMillis()
|
||||
val tmpFile = new File(backupFile.getAbsolutePath.concat(".tmp"))
|
||||
databases.backup(tmpFile)
|
||||
|
||||
// this will throw an exception if it fails, which is possible if the backup file is not on the same filesystem
|
||||
// as the temporary file
|
||||
Files.move(tmpFile.toPath, backupFile.toPath, StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE)
|
||||
@ -83,8 +85,8 @@ sealed trait BackupEvent
|
||||
// this notification is sent when we have completed our backup process (our backup file is ready to be used)
|
||||
case object BackupCompleted extends BackupEvent
|
||||
|
||||
object BackupHandler {
|
||||
object FileBackupHandler {
|
||||
// using this method is the only way to create a BackupHandler actor
|
||||
// we make sure that it uses a custom bounded mailbox, and a custom pinned dispatcher (i.e our actor will have its own thread pool with 1 single thread)
|
||||
def props(databases: Databases, backupFile: File, backupScript_opt: Option[String]) = Props(new BackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher")
|
||||
def props(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) = Props(new FileBackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher")
|
||||
}
|
@ -0,0 +1,139 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.jdbc
|
||||
|
||||
import java.sql.{Connection, ResultSet, Statement}
|
||||
import java.util.UUID
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.MilliSatoshi
|
||||
import javax.sql.DataSource
|
||||
import scodec.Codec
|
||||
import scodec.bits.{BitVector, ByteVector}
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
trait JdbcUtils {
|
||||
|
||||
def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
|
||||
val connection = dataSource.getConnection()
|
||||
try {
|
||||
f(connection)
|
||||
} finally {
|
||||
connection.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This helper makes sure statements are correctly closed.
|
||||
*
|
||||
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
|
||||
*/
|
||||
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
|
||||
val autoCommit = statement.getConnection.getAutoCommit
|
||||
try {
|
||||
if (inTransaction) statement.getConnection.setAutoCommit(false)
|
||||
val res = block(statement)
|
||||
if (inTransaction) statement.getConnection.commit()
|
||||
res
|
||||
} catch {
|
||||
case t: Exception =>
|
||||
if (inTransaction) statement.getConnection.rollback()
|
||||
throw t
|
||||
} finally {
|
||||
if (inTransaction) statement.getConnection.setAutoCommit(autoCommit)
|
||||
if (statement != null) statement.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This helper assumes that there is a "data" column available, decodable with the provided codec
|
||||
*
|
||||
* TODO: we should use an scala.Iterator instead
|
||||
*/
|
||||
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
|
||||
var q: Queue[T] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
|
||||
}
|
||||
q
|
||||
}
|
||||
|
||||
case class ExtendedResultSet(rs: ResultSet) {
|
||||
|
||||
def getByteVectorFromHex(columnLabel: String): ByteVector = {
|
||||
val s = rs.getString(columnLabel).stripPrefix("\\x")
|
||||
ByteVector.fromValidHex(s)
|
||||
}
|
||||
|
||||
def getByteVector32FromHex(columnLabel: String): ByteVector32 = {
|
||||
val s = rs.getString(columnLabel)
|
||||
ByteVector32(ByteVector.fromValidHex(s))
|
||||
}
|
||||
|
||||
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
|
||||
val s = rs.getString(columnLabel)
|
||||
if (rs.wasNull()) None else {
|
||||
Some(ByteVector32(ByteVector.fromValidHex(s)))
|
||||
}
|
||||
}
|
||||
|
||||
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
|
||||
|
||||
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
|
||||
|
||||
def getByteVectorNullable(columnLabel: String): ByteVector = {
|
||||
val result = rs.getBytes(columnLabel)
|
||||
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
|
||||
}
|
||||
|
||||
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
|
||||
|
||||
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
|
||||
val bytes = rs.getBytes(columnLabel)
|
||||
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
|
||||
}
|
||||
|
||||
def getStringNullable(columnLabel: String): Option[String] = {
|
||||
val result = rs.getString(columnLabel)
|
||||
if (rs.wasNull()) None else Some(result)
|
||||
}
|
||||
|
||||
def getLongNullable(columnLabel: String): Option[Long] = {
|
||||
val result = rs.getLong(columnLabel)
|
||||
if (rs.wasNull()) None else Some(result)
|
||||
}
|
||||
|
||||
def getUUIDNullable(label: String): Option[UUID] = {
|
||||
val result = rs.getString(label)
|
||||
if (rs.wasNull()) None else Some(UUID.fromString(result))
|
||||
}
|
||||
|
||||
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
|
||||
val result = rs.getLong(label)
|
||||
if (rs.wasNull()) None else Some(MilliSatoshi(result))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object ExtendedResultSet {
|
||||
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object JdbcUtils extends JdbcUtils
|
321
eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
Normal file
321
eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
Normal file
@ -0,0 +1,321 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
|
||||
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
|
||||
import fr.acinq.eclair.db._
|
||||
import fr.acinq.eclair.payment._
|
||||
import fr.acinq.eclair.{LongToBtcAmount, MilliSatoshi}
|
||||
import grizzled.slf4j.Logging
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
|
||||
|
||||
import PgUtils._
|
||||
import ExtendedResultSet._
|
||||
|
||||
val DB_NAME = "audit"
|
||||
val CURRENT_VERSION = 4
|
||||
|
||||
case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
|
||||
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
|
||||
case CURRENT_VERSION =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
case unknownVersion =>
|
||||
throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: ChannelLifecycleEvent): Unit =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO channel_events VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, e.channelId.toHex)
|
||||
statement.setString(2, e.remoteNodeId.value.toHex)
|
||||
statement.setLong(3, e.capacity.toLong)
|
||||
statement.setBoolean(4, e.isFunder)
|
||||
statement.setBoolean(5, e.isPrivate)
|
||||
statement.setString(6, e.event)
|
||||
statement.setLong(7, System.currentTimeMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: PaymentSent): Unit =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
e.parts.foreach(p => {
|
||||
statement.setLong(1, p.amount.toLong)
|
||||
statement.setLong(2, p.feesPaid.toLong)
|
||||
statement.setLong(3, e.recipientAmount.toLong)
|
||||
statement.setString(4, p.id.toString)
|
||||
statement.setString(5, e.id.toString)
|
||||
statement.setString(6, e.paymentHash.toHex)
|
||||
statement.setString(7, e.paymentPreimage.toHex)
|
||||
statement.setString(8, e.recipientNodeId.value.toHex)
|
||||
statement.setString(9, p.toChannelId.toHex)
|
||||
statement.setLong(10, p.timestamp)
|
||||
statement.addBatch()
|
||||
})
|
||||
statement.executeBatch()
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: PaymentReceived): Unit =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement =>
|
||||
e.parts.foreach(p => {
|
||||
statement.setLong(1, p.amount.toLong)
|
||||
statement.setString(2, e.paymentHash.toHex)
|
||||
statement.setString(3, p.fromChannelId.toHex)
|
||||
statement.setLong(4, p.timestamp)
|
||||
statement.addBatch()
|
||||
})
|
||||
statement.executeBatch()
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: PaymentRelayed): Unit =
|
||||
inTransaction { pg =>
|
||||
val payments = e match {
|
||||
case ChannelPaymentRelayed(amountIn, amountOut, _, fromChannelId, toChannelId, ts) =>
|
||||
// non-trampoline relayed payments have one input and one output
|
||||
Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", ts), RelayedPart(toChannelId, amountOut, "OUT", "channel", ts))
|
||||
case TrampolinePaymentRelayed(_, incoming, outgoing, ts) =>
|
||||
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
|
||||
incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", ts)) ++ outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", ts))
|
||||
}
|
||||
for (p <- payments) {
|
||||
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, e.paymentHash.toHex)
|
||||
statement.setLong(2, p.amount.toLong)
|
||||
statement.setString(3, p.channelId.toHex)
|
||||
statement.setString(4, p.direction)
|
||||
statement.setString(5, p.relayType)
|
||||
statement.setLong(6, e.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: NetworkFeePaid): Unit =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO network_fees VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, e.channelId.toHex)
|
||||
statement.setString(2, e.remoteNodeId.value.toHex)
|
||||
statement.setString(3, e.tx.txid.toHex)
|
||||
statement.setLong(4, e.fee.toLong)
|
||||
statement.setString(5, e.txType)
|
||||
statement.setLong(6, System.currentTimeMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
override def add(e: ChannelErrorOccurred): Unit =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
val (errorName, errorMessage) = e.error match {
|
||||
case LocalError(t) => (t.getClass.getSimpleName, t.getMessage)
|
||||
case RemoteError(error) => ("remote", error.toAscii)
|
||||
}
|
||||
statement.setString(1, e.channelId.toHex)
|
||||
statement.setString(2, e.remoteNodeId.value.toHex)
|
||||
statement.setString(3, errorName)
|
||||
statement.setString(4, errorMessage)
|
||||
statement.setBoolean(5, e.isFatal)
|
||||
statement.setLong(6, System.currentTimeMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
override def listSent(from: Long, to: Long): Seq[PaymentSent] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var sentByParentId = Map.empty[UUID, PaymentSent]
|
||||
while (rs.next()) {
|
||||
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
|
||||
val part = PaymentSent.PartialPayment(
|
||||
UUID.fromString(rs.getString("payment_id")),
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
MilliSatoshi(rs.getLong("fees_msat")),
|
||||
rs.getByteVector32FromHex("to_channel_id"),
|
||||
None, // we don't store the route in the audit DB
|
||||
rs.getLong("timestamp"))
|
||||
val sent = sentByParentId.get(parentId) match {
|
||||
case Some(s) => s.copy(parts = s.parts :+ part)
|
||||
case None => PaymentSent(
|
||||
parentId,
|
||||
rs.getByteVector32FromHex("payment_hash"),
|
||||
rs.getByteVector32FromHex("payment_preimage"),
|
||||
MilliSatoshi(rs.getLong("recipient_amount_msat")),
|
||||
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
|
||||
Seq(part))
|
||||
}
|
||||
sentByParentId = sentByParentId + (parentId -> sent)
|
||||
}
|
||||
sentByParentId.values.toSeq.sortBy(_.timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
|
||||
while (rs.next()) {
|
||||
val paymentHash = rs.getByteVector32FromHex("payment_hash")
|
||||
val part = PaymentReceived.PartialPayment(
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
rs.getByteVector32FromHex("from_channel_id"),
|
||||
rs.getLong("timestamp"))
|
||||
val received = receivedByHash.get(paymentHash) match {
|
||||
case Some(r) => r.copy(parts = r.parts :+ part)
|
||||
case None => PaymentReceived(paymentHash, Seq(part))
|
||||
}
|
||||
receivedByHash = receivedByHash + (paymentHash -> received)
|
||||
}
|
||||
receivedByHash.values.toSeq.sortBy(_.timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
|
||||
while (rs.next()) {
|
||||
val paymentHash = rs.getByteVector32FromHex("payment_hash")
|
||||
val part = RelayedPart(
|
||||
rs.getByteVector32FromHex("channel_id"),
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
rs.getString("direction"),
|
||||
rs.getString("relay_type"),
|
||||
rs.getLong("timestamp"))
|
||||
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
|
||||
}
|
||||
relayedByHash.flatMap {
|
||||
case (paymentHash, parts) =>
|
||||
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
|
||||
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
|
||||
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
|
||||
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
|
||||
parts.headOption match {
|
||||
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
|
||||
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
|
||||
}
|
||||
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => TrampolinePaymentRelayed(paymentHash, incoming, outgoing, timestamp) :: Nil
|
||||
case _ => Nil
|
||||
}
|
||||
}.toSeq.sortBy(_.timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[NetworkFee] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ NetworkFee(
|
||||
remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")),
|
||||
channelId = rs.getByteVector32FromHex("channel_id"),
|
||||
txId = rs.getByteVector32FromHex("tx_id"),
|
||||
fee = Satoshi(rs.getLong("fee_sat")),
|
||||
txType = rs.getString("tx_type"),
|
||||
timestamp = rs.getLong("timestamp"))
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def stats(from: Long, to: Long): Seq[Stats] = {
|
||||
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) =>
|
||||
feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee))
|
||||
}
|
||||
case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String)
|
||||
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) =>
|
||||
// NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones.
|
||||
val current = e match {
|
||||
case c: ChannelPaymentRelayed => Map(
|
||||
c.fromChannelId -> (Relayed(c.amountIn, 0 msat, "IN") +: previous.getOrElse(c.fromChannelId, Nil)),
|
||||
c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)),
|
||||
)
|
||||
case t: TrampolinePaymentRelayed =>
|
||||
// We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were
|
||||
// sent from/to the same channel, we group them).
|
||||
val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq
|
||||
val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) =>
|
||||
val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels
|
||||
(channelId, Relayed(parts.map(_.amount).sum, fee, "OUT"))
|
||||
}.toSeq
|
||||
(in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) }
|
||||
}
|
||||
previous ++ current
|
||||
}
|
||||
// Channels opened by our peers won't have any entry in the network_fees table, but we still want to compute stats for them.
|
||||
val allChannels = networkFees.keySet ++ relayed.keySet
|
||||
allChannels.toSeq.flatMap(channelId => {
|
||||
val networkFee = networkFees.getOrElse(channelId, 0 sat)
|
||||
val (in, out) = relayed.getOrElse(channelId, Nil).partition(_.direction == "IN")
|
||||
((in, "IN") :: (out, "OUT") :: Nil).map { case (r, direction) =>
|
||||
val paymentCount = r.length
|
||||
if (paymentCount == 0) {
|
||||
Stats(channelId, direction, 0 sat, 0, 0 sat, networkFee)
|
||||
} else {
|
||||
val avgPaymentAmount = r.map(_.amount).sum / paymentCount
|
||||
val relayFee = r.map(_.fee).sum
|
||||
Stats(channelId, direction, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
|
||||
}
|
@ -0,0 +1,124 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.CltvExpiry
|
||||
import fr.acinq.eclair.channel.HasCommitments
|
||||
import fr.acinq.eclair.db.ChannelsDb
|
||||
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
|
||||
import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec
|
||||
import grizzled.slf4j.Logging
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
class PgChannelsDb(implicit ds: DataSource, lock: DatabaseLock) extends ChannelsDb with Logging {
|
||||
|
||||
import PgUtils.ExtendedResultSet._
|
||||
import PgUtils._
|
||||
import lock._
|
||||
|
||||
val DB_NAME = "channels"
|
||||
val CURRENT_VERSION = 2
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
|
||||
case CURRENT_VERSION =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def addOrUpdateChannel(state: HasCommitments): Unit = {
|
||||
withLock { pg =>
|
||||
val data = stateDataCodec.encode(state).require.toByteArray
|
||||
using(pg.prepareStatement("UPDATE local_channels SET data=? WHERE channel_id=?")) { update =>
|
||||
update.setBytes(1, data)
|
||||
update.setString(2, state.channelId.toHex)
|
||||
if (update.executeUpdate() == 0) {
|
||||
using(pg.prepareStatement("INSERT INTO local_channels VALUES (?, ?, FALSE)")) { statement =>
|
||||
statement.setString(1, state.channelId.toHex)
|
||||
statement.setBytes(2, data)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removeChannel(channelId: ByteVector32): Unit = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(pg.prepareStatement("DELETE FROM htlc_infos WHERE channel_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(pg.prepareStatement("UPDATE local_channels SET is_closed=TRUE WHERE channel_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listLocalChannels(): Seq[HasCommitments] = {
|
||||
withLock { pg =>
|
||||
using(pg.createStatement) { statement =>
|
||||
val rs = statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=FALSE")
|
||||
codecSequence(rs, stateDataCodec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
statement.setString(3, paymentHash.toHex)
|
||||
statement.setLong(4, cltvExpiry.toLong)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.setString(2, commitmentNumber.toString)
|
||||
val rs = statement.executeQuery
|
||||
var q: Queue[(ByteVector32, CltvExpiry)] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ (ByteVector32(rs.getByteVector32FromHex("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry")))
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
}
|
@ -0,0 +1,183 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import fr.acinq.bitcoin.{ByteVector32, Crypto, Satoshi}
|
||||
import fr.acinq.eclair.ShortChannelId
|
||||
import fr.acinq.eclair.db.NetworkDb
|
||||
import fr.acinq.eclair.router.Router.PublicChannel
|
||||
import fr.acinq.eclair.wire.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
|
||||
import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
|
||||
import grizzled.slf4j.Logging
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.collection.immutable.SortedMap
|
||||
|
||||
class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging {
|
||||
|
||||
import PgUtils.ExtendedResultSet._
|
||||
import PgUtils._
|
||||
|
||||
val DB_NAME = "network"
|
||||
val CURRENT_VERSION = 2
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
|
||||
case CURRENT_VERSION => () // nothing to do
|
||||
case unknown => throw new IllegalArgumentException(s"unknown version $unknown for network db")
|
||||
}
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id BIGINT NOT NULL PRIMARY KEY, txid TEXT NOT NULL, channel_announcement BYTEA NOT NULL, capacity_sat BIGINT NOT NULL, channel_update_1 BYTEA NULL, channel_update_2 BYTEA NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id BIGINT NOT NULL PRIMARY KEY)")
|
||||
}
|
||||
}
|
||||
|
||||
override def addNode(n: NodeAnnouncement): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO nodes VALUES (?, ?) ON CONFLICT DO NOTHING")) { statement =>
|
||||
statement.setString(1, n.nodeId.value.toHex)
|
||||
statement.setBytes(2, nodeAnnouncementCodec.encode(n).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def updateNode(n: NodeAnnouncement): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("UPDATE nodes SET data=? WHERE node_id=?")) { statement =>
|
||||
statement.setBytes(1, nodeAnnouncementCodec.encode(n).require.toByteArray)
|
||||
statement.setString(2, n.nodeId.value.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement =>
|
||||
statement.setString(1, nodeId.value.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
codecSequence(rs, nodeAnnouncementCodec).headOption
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removeNode(nodeId: Crypto.PublicKey): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("DELETE FROM nodes WHERE node_id=?")) { statement =>
|
||||
statement.setString(1, nodeId.value.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listNodes(): Seq[NodeAnnouncement] = {
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
val rs = statement.executeQuery("SELECT data FROM nodes")
|
||||
codecSequence(rs, nodeAnnouncementCodec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO channels VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING")) { statement =>
|
||||
statement.setLong(1, c.shortChannelId.toLong)
|
||||
statement.setString(2, txid.toHex)
|
||||
statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray)
|
||||
statement.setLong(4, capacity.toLong)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def updateChannel(u: ChannelUpdate): Unit = {
|
||||
val column = if (u.isNode1) "channel_update_1" else "channel_update_2"
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement(s"UPDATE channels SET $column=? WHERE short_channel_id=?")) { statement =>
|
||||
statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray)
|
||||
statement.setLong(2, u.shortChannelId.toLong)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = {
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels")
|
||||
var m = SortedMap.empty[ShortChannelId, PublicChannel]
|
||||
while (rs.next()) {
|
||||
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
|
||||
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
|
||||
val capacity = rs.getLong("capacity_sat")
|
||||
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
|
||||
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
|
||||
m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
|
||||
}
|
||||
m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement) { statement =>
|
||||
shortChannelIds
|
||||
.grouped(1000) // remove channels by batch of 1000
|
||||
.foreach { _ =>
|
||||
val ids = shortChannelIds.map(_.toLong).mkString(",")
|
||||
statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO pruned VALUES (?) ON CONFLICT DO NOTHING")) { statement =>
|
||||
shortChannelIds.foreach(shortChannelId => {
|
||||
statement.setLong(1, shortChannelId.toLong)
|
||||
statement.addBatch()
|
||||
})
|
||||
statement.executeBatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removeFromPruned(shortChannelId: ShortChannelId): Unit = {
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement) { statement =>
|
||||
statement.executeUpdate(s"DELETE FROM pruned WHERE short_channel_id=${shortChannelId.toLong}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def isPruned(shortChannelId: ShortChannelId): Boolean = {
|
||||
inTransaction { pg =>
|
||||
using(pg.prepareStatement("SELECT short_channel_id from pruned WHERE short_channel_id=?")) { statement =>
|
||||
statement.setLong(1, shortChannelId.toLong)
|
||||
val rs = statement.executeQuery()
|
||||
rs.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
}
|
@ -0,0 +1,406 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import java.sql.ResultSet
|
||||
import java.util.UUID
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.eclair.MilliSatoshi
|
||||
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
|
||||
import fr.acinq.eclair.db.{HopSummary, _}
|
||||
import fr.acinq.eclair.payment.{PaymentFailed, PaymentRequest, PaymentSent}
|
||||
import fr.acinq.eclair.wire.CommonCodecs
|
||||
import grizzled.slf4j.Logging
|
||||
import javax.sql.DataSource
|
||||
import scodec.Attempt
|
||||
import scodec.bits.BitVector
|
||||
import scodec.codecs._
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class PgPaymentsDb(implicit ds: DataSource, lock: DatabaseLock) extends PaymentsDb with Logging {
|
||||
|
||||
import PgUtils.ExtendedResultSet._
|
||||
import PgUtils._
|
||||
import lock._
|
||||
|
||||
val DB_NAME = "payments"
|
||||
val CURRENT_VERSION = 4
|
||||
|
||||
private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary]
|
||||
private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte)
|
||||
.typecase(0x01, listOfN(uint8, hopSummaryCodec))
|
||||
private val failureSummaryCodec = (("type" | enumerated(uint8, FailureType)) :: ("message" | ascii32) :: paymentRouteCodec).as[FailureSummary]
|
||||
private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte)
|
||||
.typecase(0x01, listOfN(uint8, failureSummaryCodec))
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
|
||||
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
|
||||
case CURRENT_VERSION =>
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
|
||||
case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def addOutgoingPayment(sent: OutgoingPayment): Unit = {
|
||||
require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})")
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, payment_type, amount_msat, recipient_amount_msat, recipient_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, sent.id.toString)
|
||||
statement.setString(2, sent.parentId.toString)
|
||||
statement.setString(3, sent.externalId.orNull)
|
||||
statement.setString(4, sent.paymentHash.toHex)
|
||||
statement.setString(5, sent.paymentType)
|
||||
statement.setLong(6, sent.amount.toLong)
|
||||
statement.setLong(7, sent.recipientAmount.toLong)
|
||||
statement.setString(8, sent.recipientNodeId.value.toHex)
|
||||
statement.setLong(9, sent.createdAt)
|
||||
statement.setString(10, sent.paymentRequest.map(PaymentRequest.write).orNull)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def updateOutgoingPayment(paymentResult: PaymentSent): Unit =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement =>
|
||||
paymentResult.parts.foreach(p => {
|
||||
statement.setLong(1, p.timestamp)
|
||||
statement.setString(2, paymentResult.paymentPreimage.toHex)
|
||||
statement.setLong(3, p.feesPaid.toLong)
|
||||
statement.setBytes(4, paymentRouteCodec.encode(p.route.getOrElse(Nil).map(h => HopSummary(h)).toList).require.toByteArray)
|
||||
statement.setString(5, p.id.toString)
|
||||
statement.addBatch()
|
||||
})
|
||||
if (statement.executeBatch().contains(0)) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})")
|
||||
}
|
||||
}
|
||||
|
||||
override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement =>
|
||||
statement.setLong(1, paymentResult.timestamp)
|
||||
statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray)
|
||||
statement.setString(3, paymentResult.id.toString)
|
||||
if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})")
|
||||
}
|
||||
}
|
||||
|
||||
private def parseOutgoingPayment(rs: ResultSet): OutgoingPayment = {
|
||||
val status = buildOutgoingPaymentStatus(
|
||||
rs.getByteVector32FromHexNullable("payment_preimage"),
|
||||
rs.getMilliSatoshiNullable("fees_msat"),
|
||||
rs.getBitVectorOpt("payment_route"),
|
||||
rs.getLongNullable("completed_at"),
|
||||
rs.getBitVectorOpt("failures"))
|
||||
|
||||
OutgoingPayment(
|
||||
UUID.fromString(rs.getString("id")),
|
||||
UUID.fromString(rs.getString("parent_id")),
|
||||
rs.getStringNullable("external_id"),
|
||||
rs.getByteVector32FromHex("payment_hash"),
|
||||
rs.getString("payment_type"),
|
||||
MilliSatoshi(rs.getLong("amount_msat")),
|
||||
MilliSatoshi(rs.getLong("recipient_amount_msat")),
|
||||
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
|
||||
rs.getLong("created_at"),
|
||||
rs.getStringNullable("payment_request").map(PaymentRequest.read),
|
||||
status
|
||||
)
|
||||
}
|
||||
|
||||
private def buildOutgoingPaymentStatus(preimage_opt: Option[ByteVector32], fees_opt: Option[MilliSatoshi], paymentRoute_opt: Option[BitVector], completedAt_opt: Option[Long], failures: Option[BitVector]): OutgoingPaymentStatus = {
|
||||
preimage_opt match {
|
||||
// If we have a pre-image, the payment succeeded.
|
||||
case Some(preimage) => OutgoingPaymentStatus.Succeeded(
|
||||
preimage, fees_opt.getOrElse(MilliSatoshi(0)), paymentRoute_opt.map(b => paymentRouteCodec.decode(b) match {
|
||||
case Attempt.Successful(route) => route.value
|
||||
case Attempt.Failure(_) => Nil
|
||||
}).getOrElse(Nil),
|
||||
completedAt_opt.getOrElse(0)
|
||||
)
|
||||
case None => completedAt_opt match {
|
||||
// Otherwise if the payment was marked completed, it's a failure.
|
||||
case Some(completedAt) => OutgoingPaymentStatus.Failed(
|
||||
failures.map(b => paymentFailuresCodec.decode(b) match {
|
||||
case Attempt.Successful(f) => f.value
|
||||
case Attempt.Failure(_) => Nil
|
||||
}).getOrElse(Nil),
|
||||
completedAt
|
||||
)
|
||||
// Else it's still pending.
|
||||
case _ => OutgoingPaymentStatus.Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement =>
|
||||
statement.setString(1, id.toString)
|
||||
val rs = statement.executeQuery()
|
||||
if (rs.next()) {
|
||||
Some(parseOutgoingPayment(rs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement =>
|
||||
statement.setString(1, parentId.toString)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[OutgoingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseOutgoingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement =>
|
||||
statement.setString(1, paymentHash.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[OutgoingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseOutgoingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[OutgoingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseOutgoingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String): Unit =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, pr.paymentHash.toHex)
|
||||
statement.setString(2, preimage.toHex)
|
||||
statement.setString(3, paymentType)
|
||||
statement.setString(4, PaymentRequest.write(pr))
|
||||
statement.setLong(5, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds
|
||||
statement.setLong(6, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
|
||||
update.setLong(1, amount.toLong)
|
||||
update.setLong(2, receivedAt)
|
||||
update.setString(3, paymentHash.toHex)
|
||||
val updated = update.executeUpdate()
|
||||
if (updated == 0) {
|
||||
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def parseIncomingPayment(rs: ResultSet): IncomingPayment = {
|
||||
val paymentRequest = rs.getString("payment_request")
|
||||
IncomingPayment(
|
||||
PaymentRequest.read(paymentRequest),
|
||||
rs.getByteVector32FromHex("payment_preimage"),
|
||||
rs.getString("payment_type"),
|
||||
rs.getLong("created_at"),
|
||||
buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), Some(paymentRequest), rs.getLongNullable("received_at")))
|
||||
}
|
||||
|
||||
private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], serializedPaymentRequest_opt: Option[String], receivedAt_opt: Option[Long]): IncomingPaymentStatus = {
|
||||
amount_opt match {
|
||||
case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0))
|
||||
case None if serializedPaymentRequest_opt.exists(PaymentRequest.fastHasExpired) => IncomingPaymentStatus.Expired
|
||||
case None => IncomingPaymentStatus.Pending
|
||||
}
|
||||
}
|
||||
|
||||
override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement =>
|
||||
statement.setString(1, paymentHash.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
if (rs.next()) {
|
||||
Some(parseIncomingPayment(rs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[IncomingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseIncomingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[IncomingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseIncomingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
statement.setLong(3, System.currentTimeMillis)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[IncomingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseIncomingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement =>
|
||||
statement.setLong(1, from)
|
||||
statement.setLong(2, to)
|
||||
statement.setLong(3, System.currentTimeMillis)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[IncomingPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ parseIncomingPayment(rs)
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
|
||||
override def listPaymentsOverview(limit: Int): Seq[PlainPayment] = {
|
||||
// This query is an UNION of the ``sent_payments`` and ``received_payments`` table
|
||||
// - missing fields set to NULL when needed.
|
||||
// - only retrieve incoming payments that did receive funds.
|
||||
// - outgoing payments are grouped by parent_id.
|
||||
// - order by completion date (or creation date if nothing else).
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement(
|
||||
"""
|
||||
|SELECT * FROM (
|
||||
| SELECT 'received' as type,
|
||||
| NULL as parent_id,
|
||||
| NULL as external_id,
|
||||
| payment_hash,
|
||||
| payment_preimage,
|
||||
| payment_type,
|
||||
| received_msat as final_amount,
|
||||
| payment_request,
|
||||
| created_at,
|
||||
| received_at as completed_at,
|
||||
| expire_at,
|
||||
| NULL as order_trick
|
||||
| FROM received_payments
|
||||
| WHERE received_msat > 0
|
||||
|UNION ALL
|
||||
| SELECT 'sent' as type,
|
||||
| parent_id,
|
||||
| external_id,
|
||||
| payment_hash,
|
||||
| payment_preimage,
|
||||
| payment_type,
|
||||
| sum(amount_msat + fees_msat) as final_amount,
|
||||
| payment_request,
|
||||
| created_at,
|
||||
| completed_at,
|
||||
| NULL as expire_at,
|
||||
| MAX(coalesce(completed_at, created_at)) as order_trick
|
||||
| FROM sent_payments
|
||||
| GROUP BY parent_id,external_id,payment_hash,payment_preimage,payment_type,payment_request,created_at,completed_at
|
||||
|) q
|
||||
|ORDER BY coalesce(q.completed_at, q.created_at) DESC
|
||||
|LIMIT ?
|
||||
""".stripMargin
|
||||
)) { statement =>
|
||||
statement.setInt(1, limit)
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[PlainPayment] = Queue()
|
||||
while (rs.next()) {
|
||||
val parentId = rs.getUUIDNullable("parent_id")
|
||||
val externalId_opt = rs.getStringNullable("external_id")
|
||||
val paymentHash = rs.getByteVector32FromHex("payment_hash")
|
||||
val paymentType = rs.getString("payment_type")
|
||||
val paymentRequest_opt = rs.getStringNullable("payment_request")
|
||||
val amount_opt = rs.getMilliSatoshiNullable("final_amount")
|
||||
val createdAt = rs.getLong("created_at")
|
||||
val completedAt_opt = rs.getLongNullable("completed_at")
|
||||
val expireAt_opt = rs.getLongNullable("expire_at")
|
||||
|
||||
val p = if (rs.getString("type") == "received") {
|
||||
val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt)
|
||||
PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt)
|
||||
} else {
|
||||
val preimage_opt = rs.getByteVector32Nullable("payment_preimage")
|
||||
// note that the resulting status will not contain any details (routes, failures...)
|
||||
val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None)
|
||||
PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt)
|
||||
}
|
||||
q = q :+ p
|
||||
}
|
||||
q
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import fr.acinq.bitcoin.Crypto
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.eclair.db.PeersDb
|
||||
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
|
||||
import fr.acinq.eclair.wire._
|
||||
import javax.sql.DataSource
|
||||
import scodec.bits.BitVector
|
||||
|
||||
class PgPeersDb(implicit ds: DataSource, lock: DatabaseLock) extends PeersDb {
|
||||
|
||||
import PgUtils.ExtendedResultSet._
|
||||
import PgUtils._
|
||||
import lock._
|
||||
|
||||
val DB_NAME = "peers"
|
||||
val CURRENT_VERSION = 1
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
|
||||
}
|
||||
}
|
||||
|
||||
override def addOrUpdatePeer(nodeId: Crypto.PublicKey, nodeaddress: NodeAddress): Unit = {
|
||||
withLock { pg =>
|
||||
val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray
|
||||
using(pg.prepareStatement("UPDATE peers SET data=? WHERE node_id=?")) { update =>
|
||||
update.setBytes(1, data)
|
||||
update.setString(2, nodeId.value.toHex)
|
||||
if (update.executeUpdate() == 0) {
|
||||
using(pg.prepareStatement("INSERT INTO peers VALUES (?, ?)")) { statement =>
|
||||
statement.setString(1, nodeId.value.toHex)
|
||||
statement.setBytes(2, data)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removePeer(nodeId: Crypto.PublicKey): Unit = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("DELETE FROM peers WHERE node_id=?")) { statement =>
|
||||
statement.setString(1, nodeId.value.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def getPeer(nodeId: PublicKey): Option[NodeAddress] = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement =>
|
||||
statement.setString(1, nodeId.value.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
codecSequence(rs, CommonCodecs.nodeaddress).headOption
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listPeers(): Map[PublicKey, NodeAddress] = {
|
||||
withLock { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
val rs = statement.executeQuery("SELECT node_id, data FROM peers")
|
||||
var m: Map[PublicKey, NodeAddress] = Map()
|
||||
while (rs.next()) {
|
||||
val nodeid = PublicKey(rs.getByteVectorFromHex("node_id"))
|
||||
val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value
|
||||
m += (nodeid -> nodeaddress)
|
||||
}
|
||||
m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.channel.{Command, HasHtlcId}
|
||||
import fr.acinq.eclair.db.PendingRelayDb
|
||||
import fr.acinq.eclair.db.pg.PgUtils._
|
||||
import fr.acinq.eclair.wire.CommandCodecs.cmdCodec
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
class PgPendingRelayDb(implicit ds: DataSource, lock: DatabaseLock) extends PendingRelayDb {
|
||||
|
||||
import PgUtils.ExtendedResultSet._
|
||||
import PgUtils._
|
||||
import lock._
|
||||
|
||||
val DB_NAME = "pending_relay"
|
||||
val CURRENT_VERSION = 1
|
||||
|
||||
inTransaction { pg =>
|
||||
using(pg.createStatement()) { statement =>
|
||||
require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
|
||||
// note: should we use a foreign key to local_channels table here?
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))")
|
||||
}
|
||||
}
|
||||
|
||||
override def addPendingRelay(channelId: ByteVector32, cmd: Command with HasHtlcId): Unit = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("INSERT INTO pending_relay VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.setLong(2, cmd.id)
|
||||
statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=? AND htlc_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
statement.setLong(2, htlcId)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listPendingRelay(channelId: ByteVector32): Seq[Command with HasHtlcId] = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT htlc_id, data FROM pending_relay WHERE channel_id=?")) { statement =>
|
||||
statement.setString(1, channelId.toHex)
|
||||
val rs = statement.executeQuery()
|
||||
codecSequence(rs, cmdCodec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def listPendingRelay(): Set[(ByteVector32, Long)] = {
|
||||
withLock { pg =>
|
||||
using(pg.prepareStatement("SELECT channel_id, htlc_id FROM pending_relay")) { statement =>
|
||||
val rs = statement.executeQuery()
|
||||
var q: Queue[(ByteVector32, Long)] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ (rs.getByteVector32FromHex("channel_id"), rs.getLong("htlc_id"))
|
||||
}
|
||||
q.toSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = ()
|
||||
}
|
247
eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala
Normal file
247
eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala
Normal file
@ -0,0 +1,247 @@
|
||||
/*
|
||||
* Copyright 2019 ACINQ SAS
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package fr.acinq.eclair.db.pg
|
||||
|
||||
import java.sql.{Connection, Statement, Timestamp}
|
||||
import java.util.UUID
|
||||
|
||||
import fr.acinq.eclair.db.jdbc.JdbcUtils
|
||||
import grizzled.slf4j.Logging
|
||||
import javax.sql.DataSource
|
||||
import org.postgresql.util.{PGInterval, PSQLException}
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
object PgUtils extends JdbcUtils with Logging {
|
||||
|
||||
val LeaseTable = "lease"
|
||||
|
||||
val LockTimeout = 5 seconds
|
||||
|
||||
val TransactionIsolationLevel = Connection.TRANSACTION_SERIALIZABLE
|
||||
|
||||
object LockType extends Enumeration {
|
||||
type LockType = Value
|
||||
|
||||
val NONE, LEASE = Value
|
||||
|
||||
def apply(s: String): LockType = s match {
|
||||
case "none" => NONE
|
||||
case "lease" => LEASE
|
||||
case _ => throw new RuntimeException(s"Unknown postgres lock type: `$s`")
|
||||
}
|
||||
}
|
||||
|
||||
case class LockLease(expiresAt: Timestamp, instanceId: UUID, expired: Boolean)
|
||||
|
||||
// @formatter:off
|
||||
class TooManyLockAttempts(msg: String) extends RuntimeException(msg)
|
||||
class UninitializedLockTable(msg: String) extends RuntimeException(msg)
|
||||
class LockException(msg: String, cause: Option[Throwable] = None) extends RuntimeException(msg, cause.orNull)
|
||||
class LeaseException(msg: String) extends RuntimeException(msg)
|
||||
// @formatter:on
|
||||
|
||||
type LockExceptionHandler = LockException => Unit
|
||||
|
||||
sealed trait DatabaseLock {
|
||||
def obtainExclusiveLock(implicit ds: DataSource): Unit
|
||||
|
||||
def withLock[T](f: Connection => T)(implicit ds: DataSource): T
|
||||
}
|
||||
|
||||
case object NoLock extends DatabaseLock {
|
||||
override def obtainExclusiveLock(implicit ds: DataSource): Unit = ()
|
||||
|
||||
override def withLock[T](f: Connection => T)(implicit ds: DataSource): T =
|
||||
inTransaction(f)
|
||||
}
|
||||
|
||||
/**
|
||||
* This class represents a lease based locking mechanism [[https://en.wikipedia.org/wiki/Lease_(computer_science]].
|
||||
* It allows only one process to access the database at a time.
|
||||
*
|
||||
* `obtainExclusiveLock` method updates the record in `lease` table with the instance id and the expiration date
|
||||
* calculated as the current time plus the lease duration. If the current lease is not expired or it belongs to
|
||||
* another instance `obtainExclusiveLock` throws an exception.
|
||||
*
|
||||
* withLock method executes its `f` function and reads the record from lease table to checks if this instance still
|
||||
* holds the lease and it's not expired. If so, the database transaction gets committed, otherwise en exception is thrown.
|
||||
*
|
||||
* `lockExceptionHandler` provides a lock exception handler to customize the behavior when locking errors occur.
|
||||
*/
|
||||
case class LeaseLock(instanceId: UUID, leaseDuration: FiniteDuration, lockExceptionHandler: LockExceptionHandler) extends DatabaseLock {
|
||||
override def obtainExclusiveLock(implicit ds: DataSource): Unit =
|
||||
obtainDatabaseLease(instanceId, leaseDuration)
|
||||
|
||||
override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = {
|
||||
inTransaction { connection =>
|
||||
val res = f(connection)
|
||||
checkDatabaseLease(connection, instanceId, lockExceptionHandler)
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def inTransaction[T](connection: Connection)(f: Connection => T): T = {
|
||||
val autoCommit = connection.getAutoCommit
|
||||
connection.setAutoCommit(false)
|
||||
val isolationLevel = connection.getTransactionIsolation
|
||||
connection.setTransactionIsolation(TransactionIsolationLevel)
|
||||
try {
|
||||
val res = f(connection)
|
||||
connection.commit()
|
||||
res
|
||||
} catch {
|
||||
case ex: Throwable =>
|
||||
connection.rollback()
|
||||
throw ex
|
||||
} finally {
|
||||
connection.setAutoCommit(autoCommit)
|
||||
connection.setTransactionIsolation(isolationLevel)
|
||||
}
|
||||
}
|
||||
|
||||
def inTransaction[T](f: Connection => T)(implicit dataSource: DataSource): T = {
|
||||
withConnection { connection =>
|
||||
inTransaction(connection)(f)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Several logical databases (channels, network, peers) may be stored in the same physical postgres database.
|
||||
* We keep track of their respective version using a dedicated table. The version entry will be created if
|
||||
* there is none but will never be updated here (use setVersion to do that).
|
||||
*/
|
||||
def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = {
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
|
||||
// if there was no version for the current db, then insert the current version
|
||||
statement.executeUpdate(s"INSERT INTO versions VALUES ('$db_name', $currentVersion) ON CONFLICT DO NOTHING")
|
||||
// if there was a previous version installed, this will return a different value from current version
|
||||
val res = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
|
||||
res.next()
|
||||
res.getInt("version")
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the version for a particular logical database, it will overwrite the previous version.
|
||||
*/
|
||||
def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = {
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
|
||||
// overwrite the existing version
|
||||
statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'")
|
||||
}
|
||||
|
||||
private def obtainDatabaseLease(instanceId: UUID, leaseDuration: FiniteDuration, attempt: Int = 1)(implicit ds: DataSource): Unit = synchronized {
|
||||
logger.debug(s"trying to acquire database lease (attempt #$attempt) instance ID=${instanceId}")
|
||||
|
||||
if (attempt > 3) throw new TooManyLockAttempts("Too many attempts to acquire database lease")
|
||||
|
||||
try {
|
||||
inTransaction { implicit connection =>
|
||||
acquireExclusiveTableLock()
|
||||
getCurrentLease match {
|
||||
case Some(lease) =>
|
||||
if (lease.instanceId == instanceId || lease.expired)
|
||||
updateLease(instanceId, leaseDuration)
|
||||
else
|
||||
throw new LeaseException(s"The database is locked by instance ID=${lease.instanceId}")
|
||||
case None =>
|
||||
updateLease(instanceId, leaseDuration, insertNew = true)
|
||||
}
|
||||
}
|
||||
logger.debug("database lease was successfully acquired")
|
||||
} catch {
|
||||
case e: PSQLException if (e.getServerErrorMessage != null && e.getServerErrorMessage.getSQLState == "42P01") =>
|
||||
withConnection {
|
||||
connection =>
|
||||
logger.warn(s"table $LeaseTable does not exist, trying to recreate it")
|
||||
initializeLeaseTable(connection)
|
||||
obtainDatabaseLease(instanceId, leaseDuration, attempt + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def initializeLeaseTable(implicit connection: Connection): Unit = {
|
||||
using(connection.createStatement()) {
|
||||
statement =>
|
||||
// allow only one row in the ownership lease table
|
||||
statement.executeUpdate(s"CREATE TABLE IF NOT EXISTS $LeaseTable (id INTEGER PRIMARY KEY default(1), expires_at TIMESTAMP NOT NULL, instance VARCHAR NOT NULL, CONSTRAINT one_row CHECK (id = 1))")
|
||||
}
|
||||
}
|
||||
|
||||
private def acquireExclusiveTableLock()(implicit connection: Connection): Unit = {
|
||||
using(connection.createStatement()) {
|
||||
statement =>
|
||||
statement.executeUpdate(s"SET lock_timeout TO '${LockTimeout.toSeconds}s'")
|
||||
statement.executeUpdate(s"LOCK TABLE $LeaseTable IN ACCESS EXCLUSIVE MODE")
|
||||
}
|
||||
}
|
||||
|
||||
private def checkDatabaseLease(connection: Connection, instanceId: UUID, lockExceptionHandler: LockExceptionHandler): Unit = {
|
||||
Try {
|
||||
getCurrentLease(connection) match {
|
||||
case Some(lease) =>
|
||||
if (!(lease.instanceId == instanceId) || lease.expired) {
|
||||
logger.info(s"database lease: $lease")
|
||||
throw new LockException("This Eclair instance is not a database owner")
|
||||
}
|
||||
case None =>
|
||||
throw new LockException("No database lease info")
|
||||
}
|
||||
} match {
|
||||
case Success(_) => ()
|
||||
case Failure(ex) =>
|
||||
val lex = ex match {
|
||||
case e: LockException => e
|
||||
case t: Throwable => new LockException("Cannot check database lease", Some(t))
|
||||
}
|
||||
lockExceptionHandler(lex)
|
||||
throw lex
|
||||
}
|
||||
}
|
||||
|
||||
private def getCurrentLease(implicit connection: Connection): Option[LockLease] = {
|
||||
using(connection.createStatement()) {
|
||||
statement =>
|
||||
val rs = statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1")
|
||||
if (rs.next())
|
||||
Some(LockLease(
|
||||
expiresAt = rs.getTimestamp("expires_at"),
|
||||
instanceId = UUID.fromString(rs.getString("instance")),
|
||||
expired = rs.getBoolean("expired")))
|
||||
else
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
private def updateLease(instanceId: UUID, leaseDuration: FiniteDuration, insertNew: Boolean = false)(implicit connection: Connection): Unit = {
|
||||
val sql = if (insertNew)
|
||||
s"INSERT INTO $LeaseTable (expires_at, instance) VALUES (now() + ?, ?)"
|
||||
else
|
||||
s"UPDATE $LeaseTable SET expires_at = now() + ?, instance = ? WHERE id = 1"
|
||||
using(connection.prepareStatement(sql)) {
|
||||
statement =>
|
||||
statement.setObject(1, new PGInterval(s"${
|
||||
leaseDuration.toSeconds
|
||||
} seconds"))
|
||||
statement.setString(2, instanceId.toString)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -101,7 +101,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
|
||||
}
|
||||
}
|
||||
|
||||
def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
|
||||
override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
|
||||
using(sqlite.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, channelId.toArray)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
@ -111,7 +111,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
|
||||
}
|
||||
}
|
||||
|
||||
def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
|
||||
override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
|
||||
using(sqlite.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
|
||||
statement.setBytes(1, channelId.toArray)
|
||||
statement.setLong(2, commitmentNumber)
|
||||
|
@ -16,38 +16,11 @@
|
||||
|
||||
package fr.acinq.eclair.db.sqlite
|
||||
|
||||
import java.sql.{Connection, ResultSet, Statement}
|
||||
import java.util.UUID
|
||||
import java.sql.{Connection, Statement}
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.MilliSatoshi
|
||||
import scodec.Codec
|
||||
import scodec.bits.{BitVector, ByteVector}
|
||||
import fr.acinq.eclair.db.jdbc.JdbcUtils
|
||||
|
||||
import scala.collection.immutable.Queue
|
||||
|
||||
object SqliteUtils {
|
||||
|
||||
/**
|
||||
* This helper makes sure statements are correctly closed.
|
||||
*
|
||||
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
|
||||
*/
|
||||
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
|
||||
try {
|
||||
if (inTransaction) statement.getConnection.setAutoCommit(false)
|
||||
val res = block(statement)
|
||||
if (inTransaction) statement.getConnection.commit()
|
||||
res
|
||||
} catch {
|
||||
case t: Exception =>
|
||||
if (inTransaction) statement.getConnection.rollback()
|
||||
throw t
|
||||
} finally {
|
||||
if (inTransaction) statement.getConnection.setAutoCommit(true)
|
||||
if (statement != null) statement.close()
|
||||
}
|
||||
}
|
||||
object SqliteUtils extends JdbcUtils {
|
||||
|
||||
/**
|
||||
* Several logical databases (channels, network, peers) may be stored in the same physical sqlite database.
|
||||
@ -72,19 +45,6 @@ object SqliteUtils {
|
||||
statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'")
|
||||
}
|
||||
|
||||
/**
|
||||
* This helper assumes that there is a "data" column available, decodable with the provided codec
|
||||
*
|
||||
* TODO: we should use an scala.Iterator instead
|
||||
*/
|
||||
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
|
||||
var q: Queue[T] = Queue()
|
||||
while (rs.next()) {
|
||||
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
|
||||
}
|
||||
q
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process
|
||||
* accesses the database file (see https://www.sqlite.org/pragma.html).
|
||||
@ -99,48 +59,4 @@ object SqliteUtils {
|
||||
statement.executeUpdate("INSERT INTO dummy_table_for_locking VALUES (42)")
|
||||
}
|
||||
|
||||
case class ExtendedResultSet(rs: ResultSet) {
|
||||
|
||||
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
|
||||
|
||||
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
|
||||
|
||||
def getByteVectorNullable(columnLabel: String): ByteVector = {
|
||||
val result = rs.getBytes(columnLabel)
|
||||
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
|
||||
}
|
||||
|
||||
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
|
||||
|
||||
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
|
||||
val bytes = rs.getBytes(columnLabel)
|
||||
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
|
||||
}
|
||||
|
||||
def getStringNullable(columnLabel: String): Option[String] = {
|
||||
val result = rs.getString(columnLabel)
|
||||
if (rs.wasNull()) None else Some(result)
|
||||
}
|
||||
|
||||
def getLongNullable(columnLabel: String): Option[Long] = {
|
||||
val result = rs.getLong(columnLabel)
|
||||
if (rs.wasNull()) None else Some(result)
|
||||
}
|
||||
|
||||
def getUUIDNullable(label: String): Option[UUID] = {
|
||||
val result = rs.getString(label)
|
||||
if (rs.wasNull()) None else Some(UUID.fromString(result))
|
||||
}
|
||||
|
||||
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
|
||||
val result = rs.getLong(label)
|
||||
if (rs.wasNull()) None else Some(MilliSatoshi(result))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object ExtendedResultSet {
|
||||
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -47,7 +47,7 @@
|
||||
<logger name="fr.acinq.eclair.channel" level="WARN"/>
|
||||
<logger name="fr.acinq.eclair.Diagnostics" level="OFF"/>
|
||||
<logger name="fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher" level="OFF"/>
|
||||
<logger name="fr.acinq.eclair.db.BackupHandler" level="OFF"/>
|
||||
<logger name="fr.acinq.eclair.db.FileBackupHandler" level="OFF"/>
|
||||
|
||||
<root level="INFO">
|
||||
<!--appender-ref ref="FILE"/>
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
package fr.acinq.eclair
|
||||
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
import com.typesafe.config.{Config, ConfigFactory}
|
||||
@ -39,7 +40,7 @@ class StartupSpec extends AnyFunSuite {
|
||||
val keyManager = new LocalKeyManager(seed = randomBytes32, chainHash = Block.TestnetGenesisBlock.hash)
|
||||
val feeEstimator = new TestConstants.TestFeeEstimator
|
||||
val db = TestConstants.inMemoryDb()
|
||||
NodeParams.makeNodeParams(conf, keyManager, None, db, blockCount, feeEstimator)
|
||||
NodeParams.makeNodeParams(conf, UUID.fromString("01234567-0123-4567-89ab-0123456789ab"), keyManager, None, db, blockCount, feeEstimator)
|
||||
}
|
||||
|
||||
test("check configuration") {
|
||||
|
@ -16,9 +16,11 @@
|
||||
|
||||
package fr.acinq.eclair
|
||||
|
||||
import java.sql.{Connection, DriverManager}
|
||||
import java.sql.{Connection, DriverManager, Statement}
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
import com.opentable.db.postgres.embedded.EmbeddedPostgres
|
||||
import fr.acinq.bitcoin.Crypto.PrivateKey
|
||||
import fr.acinq.bitcoin.{Block, ByteVector32, Script}
|
||||
import fr.acinq.eclair.FeatureSupport.Optional
|
||||
@ -27,6 +29,9 @@ import fr.acinq.eclair.NodeParams.BITCOIND
|
||||
import fr.acinq.eclair.blockchain.fee._
|
||||
import fr.acinq.eclair.crypto.LocalKeyManager
|
||||
import fr.acinq.eclair.db._
|
||||
import fr.acinq.eclair.db.pg.PgUtils.NoLock
|
||||
import fr.acinq.eclair.db.pg._
|
||||
import fr.acinq.eclair.db.sqlite._
|
||||
import fr.acinq.eclair.io.Peer
|
||||
import fr.acinq.eclair.router.Router.RouterConf
|
||||
import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
|
||||
@ -57,9 +62,62 @@ object TestConstants {
|
||||
}
|
||||
}
|
||||
|
||||
def sqliteInMemory() = DriverManager.getConnection("jdbc:sqlite::memory:")
|
||||
sealed trait TestDatabases {
|
||||
val connection: Connection
|
||||
def network(): NetworkDb
|
||||
def audit(): AuditDb
|
||||
def channels(): ChannelsDb
|
||||
def peers(): PeersDb
|
||||
def payments(): PaymentsDb
|
||||
def pendingRelay(): PendingRelayDb
|
||||
def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int
|
||||
def close(): Unit
|
||||
}
|
||||
|
||||
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.databaseByConnections(connection, connection, connection)
|
||||
case class TestSqliteDatabases(connection: Connection = sqliteInMemory()) extends TestDatabases {
|
||||
override def network(): NetworkDb = new SqliteNetworkDb(connection)
|
||||
override def audit(): AuditDb = new SqliteAuditDb(connection)
|
||||
override def channels(): ChannelsDb = new SqliteChannelsDb(connection)
|
||||
override def peers(): PeersDb = new SqlitePeersDb(connection)
|
||||
override def payments(): PaymentsDb = new SqlitePaymentsDb(connection)
|
||||
override def pendingRelay(): PendingRelayDb = new SqlitePendingRelayDb(connection)
|
||||
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion)
|
||||
override def close(): Unit = ()
|
||||
}
|
||||
|
||||
case class TestPgDatabases() extends TestDatabases {
|
||||
private val pg = EmbeddedPostgres.start()
|
||||
|
||||
override val connection: Connection = pg.getPostgresDatabase.getConnection
|
||||
|
||||
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
|
||||
|
||||
val config = new HikariConfig
|
||||
config.setDataSource(pg.getPostgresDatabase)
|
||||
|
||||
implicit val ds = new HikariDataSource(config)
|
||||
|
||||
implicit val lock = NoLock
|
||||
|
||||
override def network(): NetworkDb = new PgNetworkDb
|
||||
override def audit(): AuditDb = new PgAuditDb
|
||||
override def channels(): ChannelsDb = new PgChannelsDb
|
||||
override def peers(): PeersDb = new PgPeersDb
|
||||
override def payments(): PaymentsDb = new PgPaymentsDb
|
||||
override def pendingRelay(): PendingRelayDb = new PgPendingRelayDb
|
||||
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = PgUtils.getVersion(statement, db_name, currentVersion)
|
||||
override def close(): Unit = pg.close()
|
||||
}
|
||||
|
||||
def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:")
|
||||
|
||||
def forAllDbs(f: TestDatabases => Unit): Unit = {
|
||||
def using(dbs: TestDatabases)(g: TestDatabases => Unit): Unit = try g(dbs) finally dbs.close()
|
||||
using(TestSqliteDatabases())(f)
|
||||
using(TestPgDatabases())(f)
|
||||
}
|
||||
|
||||
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.sqliteDatabaseByConnections(connection, connection, connection)
|
||||
|
||||
object Alice {
|
||||
val seed = ByteVector32(ByteVector.fill(32)(1))
|
||||
@ -139,7 +197,8 @@ object TestConstants {
|
||||
),
|
||||
socksProxy_opt = None,
|
||||
maxPaymentAttempts = 5,
|
||||
enableTrampolinePayment = true
|
||||
enableTrampolinePayment = true,
|
||||
instanceId = UUID.fromString("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
)
|
||||
|
||||
def channelParams = Peer.makeChannelParams(
|
||||
@ -225,7 +284,8 @@ object TestConstants {
|
||||
),
|
||||
socksProxy_opt = None,
|
||||
maxPaymentAttempts = 5,
|
||||
enableTrampolinePayment = true
|
||||
enableTrampolinePayment = true,
|
||||
instanceId = UUID.fromString("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
|
||||
)
|
||||
|
||||
def channelParams = Peer.makeChannelParams(
|
||||
|
@ -20,15 +20,15 @@ import java.io.File
|
||||
import java.sql.DriverManager
|
||||
import java.util.UUID
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.testkit.{TestKit, TestProbe}
|
||||
import akka.testkit.TestProbe
|
||||
import fr.acinq.eclair.channel.ChannelPersisted
|
||||
import fr.acinq.eclair.db.Databases.FileBackup
|
||||
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
|
||||
import fr.acinq.eclair.wire.ChannelCodecsSpec
|
||||
import fr.acinq.eclair.{TestConstants, TestKitBaseClass, TestUtils, randomBytes32}
|
||||
import org.scalatest.funsuite.AnyFunSuiteLike
|
||||
|
||||
class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
|
||||
class FileBackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
|
||||
|
||||
test("process backups") {
|
||||
val db = TestConstants.inMemoryDb()
|
||||
@ -40,7 +40,7 @@ class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
|
||||
db.channels.addOrUpdateChannel(channel)
|
||||
assert(db.channels.listLocalChannels() == Seq(channel))
|
||||
|
||||
val handler = system.actorOf(BackupHandler.props(db, dest, None))
|
||||
val handler = system.actorOf(FileBackupHandler.props(db.asInstanceOf[FileBackup], dest, None))
|
||||
val probe = TestProbe()
|
||||
system.eventStream.subscribe(probe.ref, classOf[BackupEvent])
|
||||
|
@ -20,10 +20,11 @@ import java.util.UUID
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PrivateKey
|
||||
import fr.acinq.bitcoin.{ByteVector32, Transaction}
|
||||
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
|
||||
import fr.acinq.eclair._
|
||||
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
|
||||
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
|
||||
import fr.acinq.eclair.db.sqlite.SqliteAuditDb
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
|
||||
import fr.acinq.eclair.payment._
|
||||
import org.scalatest.Tag
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
@ -36,372 +37,394 @@ class SqliteAuditDbSpec extends AnyFunSuite {
|
||||
val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db1 = new SqliteAuditDb(sqlite)
|
||||
val db2 = new SqliteAuditDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.audit()
|
||||
val db2 = dbs.audit()
|
||||
}
|
||||
}
|
||||
|
||||
test("add/list events") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteAuditDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.audit()
|
||||
|
||||
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
|
||||
val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32)
|
||||
val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil)
|
||||
val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32)
|
||||
val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual")
|
||||
val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0)
|
||||
val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1)
|
||||
val e5 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84100 msat, randomKey.publicKey, pp5a :: pp5b :: Nil)
|
||||
val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (System.currentTimeMillis.milliseconds + 10.minutes).toMillis)
|
||||
val e6 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, randomKey.publicKey, pp6 :: Nil)
|
||||
val e7 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual")
|
||||
val e8 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
val e10 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(20000 msat, randomBytes32), PaymentRelayed.Part(22000 msat, randomBytes32)), Seq(PaymentRelayed.Part(10000 msat, randomBytes32), PaymentRelayed.Part(12000 msat, randomBytes32), PaymentRelayed.Part(15000 msat, randomBytes32)))
|
||||
val multiPartPaymentHash = randomBytes32
|
||||
val now = System.currentTimeMillis
|
||||
val e11 = ChannelPaymentRelayed(13000 msat, 11000 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
|
||||
val e12 = ChannelPaymentRelayed(15000 msat, 12500 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
|
||||
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
|
||||
val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32)
|
||||
val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil)
|
||||
val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32)
|
||||
val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual")
|
||||
val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0)
|
||||
val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1)
|
||||
val e5 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84100 msat, randomKey.publicKey, pp5a :: pp5b :: Nil)
|
||||
val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (System.currentTimeMillis.milliseconds + 10.minutes).toMillis)
|
||||
val e6 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, randomKey.publicKey, pp6 :: Nil)
|
||||
val e7 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual")
|
||||
val e8 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
val e10 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(20000 msat, randomBytes32), PaymentRelayed.Part(22000 msat, randomBytes32)), Seq(PaymentRelayed.Part(10000 msat, randomBytes32), PaymentRelayed.Part(12000 msat, randomBytes32), PaymentRelayed.Part(15000 msat, randomBytes32)))
|
||||
val multiPartPaymentHash = randomBytes32
|
||||
val now = System.currentTimeMillis
|
||||
val e11 = ChannelPaymentRelayed(13000 msat, 11000 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
|
||||
val e12 = ChannelPaymentRelayed(15000 msat, 12500 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
|
||||
|
||||
db.add(e1)
|
||||
db.add(e2)
|
||||
db.add(e3)
|
||||
db.add(e4)
|
||||
db.add(e5)
|
||||
db.add(e6)
|
||||
db.add(e7)
|
||||
db.add(e8)
|
||||
db.add(e9)
|
||||
db.add(e10)
|
||||
db.add(e11)
|
||||
db.add(e12)
|
||||
db.add(e1)
|
||||
db.add(e2)
|
||||
db.add(e3)
|
||||
db.add(e4)
|
||||
db.add(e5)
|
||||
db.add(e6)
|
||||
db.add(e7)
|
||||
db.add(e8)
|
||||
db.add(e9)
|
||||
db.add(e10)
|
||||
db.add(e11)
|
||||
db.add(e12)
|
||||
|
||||
assert(db.listSent(from = 0L, to = (System.currentTimeMillis.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5, e6))
|
||||
assert(db.listSent(from = 100000L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e1))
|
||||
assert(db.listReceived(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e2))
|
||||
assert(db.listRelayed(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e3, e10, e11, e12))
|
||||
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1)
|
||||
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual")
|
||||
assert(db.listSent(from = 0L, to = (System.currentTimeMillis.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5, e6))
|
||||
assert(db.listSent(from = 100000L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e1))
|
||||
assert(db.listReceived(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e2))
|
||||
assert(db.listRelayed(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e3, e10, e11, e12))
|
||||
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1)
|
||||
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual")
|
||||
}
|
||||
}
|
||||
|
||||
test("stats") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteAuditDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.audit()
|
||||
|
||||
val n2 = randomKey.publicKey
|
||||
val n3 = randomKey.publicKey
|
||||
val n4 = randomKey.publicKey
|
||||
val n2 = randomKey.publicKey
|
||||
val n3 = randomKey.publicKey
|
||||
val n4 = randomKey.publicKey
|
||||
|
||||
val c1 = randomBytes32
|
||||
val c2 = randomBytes32
|
||||
val c3 = randomBytes32
|
||||
val c4 = randomBytes32
|
||||
val c5 = randomBytes32
|
||||
val c6 = randomBytes32
|
||||
val c1 = randomBytes32
|
||||
val c2 = randomBytes32
|
||||
val c3 = randomBytes32
|
||||
val c4 = randomBytes32
|
||||
val c5 = randomBytes32
|
||||
val c6 = randomBytes32
|
||||
|
||||
db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1))
|
||||
db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1))
|
||||
db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1))
|
||||
db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2))
|
||||
db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4))))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4))))
|
||||
db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1))
|
||||
db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1))
|
||||
db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1))
|
||||
db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2))
|
||||
db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4))))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4))))
|
||||
|
||||
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding"))
|
||||
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual"))
|
||||
db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding"))
|
||||
db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding"))
|
||||
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding"))
|
||||
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual"))
|
||||
db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding"))
|
||||
db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding"))
|
||||
|
||||
// NB: we only count a relay fee for the outgoing channel, no the incoming one.
|
||||
assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set(
|
||||
Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
|
||||
Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
|
||||
Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat),
|
||||
))
|
||||
// NB: we only count a relay fee for the outgoing channel, no the incoming one.
|
||||
assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set(
|
||||
Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
|
||||
Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
|
||||
Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat),
|
||||
Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat),
|
||||
Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
ignore("relay stats performance", Tag("perf")) {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteAuditDb(sqlite)
|
||||
val nodeCount = 100
|
||||
val channelCount = 1000
|
||||
val eventCount = 100000
|
||||
val nodeIds = (1 to nodeCount).map(_ => randomKey.publicKey)
|
||||
val channelIds = (1 to channelCount).map(_ => randomBytes32)
|
||||
// Fund channels.
|
||||
channelIds.foreach(channelId => {
|
||||
val nodeId = nodeIds(Random.nextInt(nodeCount))
|
||||
db.add(NetworkFeePaid(null, nodeId, channelId, Transaction(0, Seq.empty, Seq.empty, 0), 100 sat, "funding"))
|
||||
})
|
||||
// Add relay events.
|
||||
(1 to eventCount).foreach(_ => {
|
||||
// 25% trampoline relays.
|
||||
if (Random.nextInt(4) == 0) {
|
||||
val outgoingCount = 1 + Random.nextInt(4)
|
||||
val incoming = Seq(PaymentRelayed.Part(10000 msat, randomBytes32))
|
||||
val outgoing = (1 to outgoingCount).map(_ => PaymentRelayed.Part(Random.nextInt(2000).msat, channelIds(Random.nextInt(channelCount))))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, incoming, outgoing))
|
||||
} else {
|
||||
val toChannelId = channelIds(Random.nextInt(channelCount))
|
||||
db.add(ChannelPaymentRelayed(10000 msat, Random.nextInt(10000).msat, randomBytes32, randomBytes32, toChannelId))
|
||||
}
|
||||
})
|
||||
// Test starts here.
|
||||
val start = System.currentTimeMillis
|
||||
assert(db.stats(0, start + 1).nonEmpty)
|
||||
val end = System.currentTimeMillis
|
||||
fail(s"took ${end - start}ms")
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.audit()
|
||||
val nodeCount = 100
|
||||
val channelCount = 1000
|
||||
val eventCount = 100000
|
||||
val nodeIds = (1 to nodeCount).map(_ => randomKey.publicKey)
|
||||
val channelIds = (1 to channelCount).map(_ => randomBytes32)
|
||||
// Fund channels.
|
||||
channelIds.foreach(channelId => {
|
||||
val nodeId = nodeIds(Random.nextInt(nodeCount))
|
||||
db.add(NetworkFeePaid(null, nodeId, channelId, Transaction(0, Seq.empty, Seq.empty, 0), 100 sat, "funding"))
|
||||
})
|
||||
// Add relay events.
|
||||
(1 to eventCount).foreach(_ => {
|
||||
// 25% trampoline relays.
|
||||
if (Random.nextInt(4) == 0) {
|
||||
val outgoingCount = 1 + Random.nextInt(4)
|
||||
val incoming = Seq(PaymentRelayed.Part(10000 msat, randomBytes32))
|
||||
val outgoing = (1 to outgoingCount).map(_ => PaymentRelayed.Part(Random.nextInt(2000).msat, channelIds(Random.nextInt(channelCount))))
|
||||
db.add(TrampolinePaymentRelayed(randomBytes32, incoming, outgoing))
|
||||
} else {
|
||||
val toChannelId = channelIds(Random.nextInt(channelCount))
|
||||
db.add(ChannelPaymentRelayed(10000 msat, Random.nextInt(10000).msat, randomBytes32, randomBytes32, toChannelId))
|
||||
}
|
||||
})
|
||||
// Test starts here.
|
||||
val start = System.currentTimeMillis
|
||||
assert(db.stats(0, start + 1).nonEmpty)
|
||||
val end = System.currentTimeMillis
|
||||
fail(s"took ${end - start}ms")
|
||||
}
|
||||
}
|
||||
|
||||
test("handle migration version 1 -> 4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 1)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 1)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 1) // we expect version 1
|
||||
}
|
||||
|
||||
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
// add a row (no ID on sent)
|
||||
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, ps.recipientAmount.toLong)
|
||||
statement.setLong(2, ps.feesPaid.toLong)
|
||||
statement.setBytes(3, ps.paymentHash.toArray)
|
||||
statement.setBytes(4, ps.paymentPreimage.toArray)
|
||||
statement.setBytes(5, ps.parts.head.toChannelId.toArray)
|
||||
statement.setLong(6, ps.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4
|
||||
}
|
||||
|
||||
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
|
||||
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
postMigrationDb.add(ps1)
|
||||
postMigrationDb.add(e1)
|
||||
postMigrationDb.add(e2)
|
||||
|
||||
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
|
||||
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
|
||||
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 1) // we expect version 1
|
||||
}
|
||||
|
||||
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
// add a row (no ID on sent)
|
||||
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, ps.recipientAmount.toLong)
|
||||
statement.setLong(2, ps.feesPaid.toLong)
|
||||
statement.setBytes(3, ps.paymentHash.toArray)
|
||||
statement.setBytes(4, ps.paymentPreimage.toArray)
|
||||
statement.setBytes(5, ps.parts.head.toChannelId.toArray)
|
||||
statement.setLong(6, ps.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4
|
||||
}
|
||||
|
||||
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
|
||||
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
postMigrationDb.add(ps1)
|
||||
postMigrationDb.add(e1)
|
||||
postMigrationDb.add(e2)
|
||||
|
||||
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
|
||||
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
|
||||
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
|
||||
}
|
||||
|
||||
test("handle migration version 2 -> 4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 2)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 2)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now
|
||||
}
|
||||
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4
|
||||
}
|
||||
|
||||
migratedDb.add(e1)
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
postMigrationDb.add(e2)
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now
|
||||
}
|
||||
|
||||
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
|
||||
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4
|
||||
}
|
||||
|
||||
migratedDb.add(e1)
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
postMigrationDb.add(e2)
|
||||
}
|
||||
|
||||
test("handle migration version 3 -> 4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
|
||||
val connection = dbs.connection
|
||||
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 3)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
// simulate existing previous version db
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "audit", 3)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now
|
||||
}
|
||||
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
|
||||
for (pp <- Seq(pp1, pp2)) {
|
||||
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, pp.amount.toLong)
|
||||
statement.setLong(2, pp.feesPaid.toLong)
|
||||
statement.setBytes(3, ps1.paymentHash.toArray)
|
||||
statement.setBytes(4, ps1.paymentPreimage.toArray)
|
||||
statement.setBytes(5, pp.toChannelId.toArray)
|
||||
statement.setLong(6, pp.timestamp)
|
||||
statement.setBytes(7, pp.id.toString.getBytes)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
|
||||
|
||||
for (relayed <- Seq(relayed1, relayed2)) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, relayed.amountIn.toLong)
|
||||
statement.setLong(2, relayed.amountOut.toLong)
|
||||
statement.setBytes(3, relayed.paymentHash.toArray)
|
||||
statement.setBytes(4, relayed.fromChannelId.toArray)
|
||||
statement.setBytes(5, relayed.toChannelId.toArray)
|
||||
statement.setLong(6, relayed.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4
|
||||
}
|
||||
|
||||
assert(migratedDb.listSent(50, 150).toSet === Set(
|
||||
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
|
||||
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
|
||||
))
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
|
||||
))
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), 150)
|
||||
|
||||
postMigrationDb.add(ps2)
|
||||
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now
|
||||
}
|
||||
|
||||
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
|
||||
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
|
||||
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
|
||||
|
||||
for (pp <- Seq(pp1, pp2)) {
|
||||
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, pp.amount.toLong)
|
||||
statement.setLong(2, pp.feesPaid.toLong)
|
||||
statement.setBytes(3, ps1.paymentHash.toArray)
|
||||
statement.setBytes(4, ps1.paymentPreimage.toArray)
|
||||
statement.setBytes(5, pp.toChannelId.toArray)
|
||||
statement.setLong(6, pp.timestamp)
|
||||
statement.setBytes(7, pp.id.toString.getBytes)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
|
||||
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
|
||||
|
||||
for (relayed <- Seq(relayed1, relayed2)) {
|
||||
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setLong(1, relayed.amountIn.toLong)
|
||||
statement.setLong(2, relayed.amountOut.toLong)
|
||||
statement.setBytes(3, relayed.paymentHash.toArray)
|
||||
statement.setBytes(4, relayed.fromChannelId.toArray)
|
||||
statement.setBytes(5, relayed.toChannelId.toArray)
|
||||
statement.setLong(6, relayed.timestamp)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
val migratedDb = new SqliteAuditDb(connection)
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4
|
||||
}
|
||||
|
||||
assert(migratedDb.listSent(50, 150).toSet === Set(
|
||||
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
|
||||
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
|
||||
))
|
||||
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
|
||||
|
||||
val postMigrationDb = new SqliteAuditDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "audit", 4) == 4) // version 4
|
||||
}
|
||||
|
||||
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
|
||||
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
|
||||
))
|
||||
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), 150)
|
||||
|
||||
postMigrationDb.add(ps2)
|
||||
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
|
||||
postMigrationDb.add(relayed3)
|
||||
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
|
||||
}
|
||||
|
||||
test("ignore invalid values in the DB") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteAuditDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.audit()
|
||||
val sqlite = dbs.connection
|
||||
val isPg = dbs.isInstanceOf[TestPgDatabases]
|
||||
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, randomBytes32.toArray)
|
||||
statement.setLong(2, 42)
|
||||
statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "unknown") // invalid relay type
|
||||
statement.setLong(6, 10)
|
||||
statement.executeUpdate()
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
|
||||
statement.setLong(2, 42)
|
||||
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "IN")
|
||||
statement.setString(5, "unknown") // invalid relay type
|
||||
statement.setLong(6, 10)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
|
||||
statement.setLong(2, 51)
|
||||
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "UP") // invalid direction
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 20)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val paymentHash = randomBytes32
|
||||
val channelId = randomBytes32
|
||||
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
if (isPg) statement.setString(1, paymentHash.toHex) else statement.setBytes(1, paymentHash.toArray)
|
||||
statement.setLong(2, 65)
|
||||
if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
|
||||
statement.setString(4, "IN") // missing a corresponding OUT
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 30)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
assert(db.listRelayed(0, 40) === Nil)
|
||||
}
|
||||
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, randomBytes32.toArray)
|
||||
statement.setLong(2, 51)
|
||||
statement.setBytes(3, randomBytes32.toArray)
|
||||
statement.setString(4, "UP") // invalid direction
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 20)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val paymentHash = randomBytes32
|
||||
val channelId = randomBytes32
|
||||
|
||||
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, paymentHash.toArray)
|
||||
statement.setLong(2, 65)
|
||||
statement.setBytes(3, channelId.toArray)
|
||||
statement.setString(4, "IN") // missing a corresponding OUT
|
||||
statement.setString(5, "channel")
|
||||
statement.setLong(6, 30)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
assert(db.listRelayed(0, 40) === Nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -16,81 +16,89 @@
|
||||
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import java.sql.SQLException
|
||||
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.eclair.CltvExpiry
|
||||
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
|
||||
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
|
||||
import fr.acinq.eclair.db.sqlite.{SqliteChannelsDb, SqlitePendingRelayDb}
|
||||
import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec
|
||||
import fr.acinq.eclair.wire.ChannelCodecsSpec
|
||||
import fr.acinq.eclair.{CltvExpiry, TestConstants}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.sqlite.SQLiteException
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
class SqliteChannelsDbSpec extends AnyFunSuite {
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db1 = new SqliteChannelsDb(sqlite)
|
||||
val db2 = new SqliteChannelsDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.channels()
|
||||
val db2 = dbs.channels()
|
||||
}
|
||||
}
|
||||
|
||||
test("add/remove/list channels") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteChannelsDb(sqlite)
|
||||
new SqlitePendingRelayDb(sqlite) // needed by db.removeChannel
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.channels()
|
||||
dbs.pendingRelay() // needed by db.removeChannel
|
||||
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
|
||||
val commitNumber = 42
|
||||
val paymentHash1 = ByteVector32.Zeroes
|
||||
val cltvExpiry1 = CltvExpiry(123)
|
||||
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
|
||||
val cltvExpiry2 = CltvExpiry(656)
|
||||
val commitNumber = 42
|
||||
val paymentHash1 = ByteVector32.Zeroes
|
||||
val cltvExpiry1 = CltvExpiry(123)
|
||||
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
|
||||
val cltvExpiry2 = CltvExpiry(656)
|
||||
|
||||
intercept[SQLiteException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
|
||||
intercept[SQLException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
|
||||
|
||||
assert(db.listLocalChannels().toSet === Set.empty)
|
||||
db.addOrUpdateChannel(channel)
|
||||
db.addOrUpdateChannel(channel)
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
assert(db.listLocalChannels().toSet === Set.empty)
|
||||
db.addOrUpdateChannel(channel)
|
||||
db.addOrUpdateChannel(channel)
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
|
||||
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
|
||||
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == List((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
|
||||
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
|
||||
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
|
||||
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList.toSet == Set((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
|
||||
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
|
||||
|
||||
db.removeChannel(channel.channelId)
|
||||
assert(db.listLocalChannels() === Nil)
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
|
||||
db.removeChannel(channel.channelId)
|
||||
assert(db.listLocalChannels() === Nil)
|
||||
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("migrate channel database v1 -> v2") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
val sqlite = dbs.connection
|
||||
|
||||
// create a v1 channels database
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
getVersion(statement, "channels", 1)
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
}
|
||||
// create a v1 channels database
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
getVersion(statement, "channels", 1)
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
|
||||
}
|
||||
|
||||
// insert 1 row
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val data = stateDataCodec.encode(channel).require.toByteArray
|
||||
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
|
||||
statement.setBytes(1, channel.channelId.toArray)
|
||||
statement.setBytes(2, data)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
// insert 1 row
|
||||
val channel = ChannelCodecsSpec.normal
|
||||
val data = stateDataCodec.encode(channel).require.toByteArray
|
||||
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
|
||||
statement.setBytes(1, channel.channelId.toArray)
|
||||
statement.setBytes(2, data)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// check that db migration works
|
||||
val db = new SqliteChannelsDb(sqlite)
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "channels", 1) == 2) // version changed from 1 -> 2
|
||||
// check that db migration works
|
||||
val db = new SqliteChannelsDb(sqlite)
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "channels", 1) == 2) // version changed from 1 -> 2
|
||||
}
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
}
|
||||
assert(db.listLocalChannels() === List(channel))
|
||||
}
|
||||
}
|
@ -16,107 +16,112 @@
|
||||
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import java.sql.Connection
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PrivateKey
|
||||
import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi}
|
||||
import fr.acinq.eclair.FeatureSupport.Optional
|
||||
import fr.acinq.eclair.Features.VariableLengthOnion
|
||||
import fr.acinq.eclair.db.sqlite.SqliteNetworkDb
|
||||
import fr.acinq.eclair.TestConstants.{TestDatabases, TestPgDatabases, TestSqliteDatabases}
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils._
|
||||
import fr.acinq.eclair.router.Announcements
|
||||
import fr.acinq.eclair.router.Router.PublicChannel
|
||||
import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2}
|
||||
import fr.acinq.eclair.{ActivatedFeature, CltvExpiryDelta, Features, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomKey}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import scodec.bits.HexStringSyntax
|
||||
|
||||
import scala.collection.{SortedMap, mutable}
|
||||
|
||||
class SqliteNetworkDbSpec extends AnyFunSuite {
|
||||
|
||||
import TestConstants.forAllDbs
|
||||
|
||||
val shortChannelIds = (42 to (5000 + 42)).map(i => ShortChannelId(i))
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db1 = new SqliteNetworkDb(sqlite)
|
||||
val db2 = new SqliteNetworkDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.network()
|
||||
val db2 = dbs.network()
|
||||
}
|
||||
}
|
||||
|
||||
test("migration test 1->2") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
getVersion(statement, "network", 1) // this will set version to 1
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
|
||||
}
|
||||
using(dbs.connection.createStatement()) { statement =>
|
||||
dbs.getVersion(statement, "network", 1) // this will set version to 1
|
||||
statement.execute("PRAGMA foreign_keys = ON")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
|
||||
}
|
||||
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "network", 2) == 1)
|
||||
}
|
||||
using(dbs.connection.createStatement()) { statement =>
|
||||
assert(dbs.getVersion(statement, "network", 2) == 1)
|
||||
}
|
||||
|
||||
// first round: this will trigger a migration
|
||||
simpleTest(sqlite)
|
||||
// first round: this will trigger a migration
|
||||
simpleTest(dbs)
|
||||
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "network", 2) == 2)
|
||||
}
|
||||
using(dbs.connection.createStatement()) { statement =>
|
||||
assert(dbs.getVersion(statement, "network", 2) == 2)
|
||||
}
|
||||
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
statement.executeUpdate("DELETE FROM nodes")
|
||||
statement.executeUpdate("DELETE FROM channels")
|
||||
}
|
||||
using(dbs.connection.createStatement()) { statement =>
|
||||
statement.executeUpdate("DELETE FROM nodes")
|
||||
statement.executeUpdate("DELETE FROM channels")
|
||||
}
|
||||
|
||||
// second round: no migration
|
||||
simpleTest(sqlite)
|
||||
// second round: no migration
|
||||
simpleTest(dbs)
|
||||
|
||||
using(sqlite.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "network", 2) == 2)
|
||||
using(dbs.connection.createStatement()) { statement =>
|
||||
assert(dbs.getVersion(statement, "network", 2) == 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("add/remove/list nodes") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteNetworkDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.network()
|
||||
|
||||
val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty)
|
||||
val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
|
||||
val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
|
||||
val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, Features.empty)
|
||||
val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty)
|
||||
val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
|
||||
val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
|
||||
val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, Features.empty)
|
||||
|
||||
assert(db.listNodes().toSet === Set.empty)
|
||||
db.addNode(node_1)
|
||||
db.addNode(node_1) // duplicate is ignored
|
||||
assert(db.getNode(node_1.nodeId) === Some(node_1))
|
||||
assert(db.listNodes().size === 1)
|
||||
db.addNode(node_2)
|
||||
db.addNode(node_3)
|
||||
db.addNode(node_4)
|
||||
assert(db.listNodes().toSet === Set(node_1, node_2, node_3, node_4))
|
||||
db.removeNode(node_2.nodeId)
|
||||
assert(db.listNodes().toSet === Set(node_1, node_3, node_4))
|
||||
db.updateNode(node_1)
|
||||
assert(db.listNodes().toSet === Set.empty)
|
||||
db.addNode(node_1)
|
||||
db.addNode(node_1) // duplicate is ignored
|
||||
assert(db.getNode(node_1.nodeId) === Some(node_1))
|
||||
assert(db.listNodes().size === 1)
|
||||
db.addNode(node_2)
|
||||
db.addNode(node_3)
|
||||
db.addNode(node_4)
|
||||
assert(db.listNodes().toSet === Set(node_1, node_2, node_3, node_4))
|
||||
db.removeNode(node_2.nodeId)
|
||||
assert(db.listNodes().toSet === Set(node_1, node_3, node_4))
|
||||
db.updateNode(node_1)
|
||||
|
||||
assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000)))
|
||||
assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000)))
|
||||
}
|
||||
}
|
||||
|
||||
test("correctly handle txids that start with 0") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteNetworkDb(sqlite)
|
||||
val sig = ByteVector64.Zeroes
|
||||
val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig)
|
||||
val txid = ByteVector32.fromValidHex("0001" * 16)
|
||||
db.addChannel(c, txid, Satoshi(42))
|
||||
assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c, txid, Satoshi(42), None, None, None)))
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.network()
|
||||
val sig = ByteVector64.Zeroes
|
||||
val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig)
|
||||
val txid = ByteVector32.fromValidHex("0001" * 16)
|
||||
db.addChannel(c, txid, Satoshi(42))
|
||||
assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c, txid, Satoshi(42), None, None, None)))
|
||||
}
|
||||
}
|
||||
|
||||
def simpleTest(sqlite: Connection) = {
|
||||
val db = new SqliteNetworkDb(sqlite)
|
||||
def simpleTest(dbs: TestDatabases) = {
|
||||
val db = dbs.network()
|
||||
|
||||
def sig = Crypto.sign(randomBytes32, randomKey)
|
||||
|
||||
@ -172,75 +177,85 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("add/remove/list channels and channel_updates") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
simpleTest(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
simpleTest(dbs)
|
||||
}
|
||||
}
|
||||
|
||||
test("creating a table that already exists but with different column types is ignored") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
using(sqlite.createStatement(), inTransaction = true) { statement =>
|
||||
statement.execute("CREATE TABLE IF NOT EXISTS test (txid STRING NOT NULL)")
|
||||
}
|
||||
// column type is STRING
|
||||
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING")
|
||||
forAllDbs { dbs =>
|
||||
|
||||
// insert and read back random values
|
||||
val txids = for (_ <- 0 until 1000) yield randomBytes32
|
||||
txids.foreach { txid =>
|
||||
using(sqlite.prepareStatement("INSERT OR IGNORE INTO test VALUES (?)")) { statement =>
|
||||
statement.setString(1, txid.toHex)
|
||||
statement.executeUpdate()
|
||||
using(dbs.connection.createStatement(), inTransaction = true) { statement =>
|
||||
statement.execute("CREATE TABLE IF NOT EXISTS test (txid VARCHAR NOT NULL)")
|
||||
}
|
||||
}
|
||||
// column type is VARCHAR
|
||||
val rs = dbs.connection.getMetaData.getColumns(null, null, "test", null)
|
||||
assert(rs.next())
|
||||
assert(rs.getString("TYPE_NAME").toLowerCase == "varchar")
|
||||
|
||||
val check = using(sqlite.createStatement()) { statement =>
|
||||
val rs = statement.executeQuery("SELECT txid FROM test")
|
||||
val q = new mutable.Queue[ByteVector32]()
|
||||
while (rs.next()) {
|
||||
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
|
||||
q.enqueue(txId)
|
||||
|
||||
// insert and read back random values
|
||||
val txids = for (_ <- 0 until 1000) yield randomBytes32
|
||||
txids.foreach { txid =>
|
||||
using(dbs.connection.prepareStatement("INSERT INTO test VALUES (?)")) { statement =>
|
||||
statement.setString(1, txid.toHex)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
}
|
||||
q
|
||||
|
||||
val check = using(dbs.connection.createStatement()) { statement =>
|
||||
val rs = statement.executeQuery("SELECT txid FROM test")
|
||||
val q = new mutable.Queue[ByteVector32]()
|
||||
while (rs.next()) {
|
||||
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
|
||||
q.enqueue(txId)
|
||||
}
|
||||
q
|
||||
}
|
||||
assert(txids.toSet == check.toSet)
|
||||
|
||||
|
||||
using(dbs.connection.createStatement(), inTransaction = true) { statement =>
|
||||
statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)")
|
||||
}
|
||||
|
||||
// column type has not changed
|
||||
val rs1 = dbs.connection.getMetaData.getColumns(null, null, "test", null)
|
||||
assert(rs1.next())
|
||||
assert(rs1.getString("TYPE_NAME").toLowerCase == "varchar")
|
||||
}
|
||||
assert(txids.toSet == check.toSet)
|
||||
|
||||
|
||||
using(sqlite.createStatement(), inTransaction = true) { statement =>
|
||||
statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)")
|
||||
}
|
||||
|
||||
// column type has not changed
|
||||
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING")
|
||||
}
|
||||
|
||||
test("remove many channels") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteNetworkDb(sqlite)
|
||||
val sig = Crypto.sign(randomBytes32, randomKey)
|
||||
val priv = randomKey
|
||||
val pub = priv.publicKey
|
||||
val capacity = 10000 sat
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.network()
|
||||
val sig = Crypto.sign(randomBytes32, randomKey)
|
||||
val priv = randomKey
|
||||
val pub = priv.publicKey
|
||||
val capacity = 10000 sat
|
||||
|
||||
val channels = shortChannelIds.map(id => Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, id, pub, pub, pub, pub, sig, sig, sig, sig))
|
||||
val template = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv, pub, ShortChannelId(42), CltvExpiryDelta(5), 7000000 msat, 50000 msat, 100, 500000000L msat, true)
|
||||
val updates = shortChannelIds.map(id => template.copy(shortChannelId = id))
|
||||
val txid = randomBytes32
|
||||
channels.foreach(ca => db.addChannel(ca, txid, capacity))
|
||||
updates.foreach(u => db.updateChannel(u))
|
||||
assert(db.listChannels().keySet === channels.map(_.shortChannelId).toSet)
|
||||
val channels = shortChannelIds.map(id => Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, id, pub, pub, pub, pub, sig, sig, sig, sig))
|
||||
val template = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv, pub, ShortChannelId(42), CltvExpiryDelta(5), 7000000 msat, 50000 msat, 100, 500000000L msat, true)
|
||||
val updates = shortChannelIds.map(id => template.copy(shortChannelId = id))
|
||||
val txid = randomBytes32
|
||||
channels.foreach(ca => db.addChannel(ca, txid, capacity))
|
||||
updates.foreach(u => db.updateChannel(u))
|
||||
assert(db.listChannels().keySet === channels.map(_.shortChannelId).toSet)
|
||||
|
||||
val toDelete = channels.map(_.shortChannelId).drop(500).take(2500)
|
||||
db.removeChannels(toDelete)
|
||||
assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete))
|
||||
val toDelete = channels.map(_.shortChannelId).drop(500).take(2500)
|
||||
db.removeChannels(toDelete)
|
||||
assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete))
|
||||
}
|
||||
}
|
||||
|
||||
test("prune many channels") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqliteNetworkDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.network()
|
||||
|
||||
db.addToPruned(shortChannelIds)
|
||||
shortChannelIds.foreach { id => assert(db.isPruned((id))) }
|
||||
db.removeFromPruned(ShortChannelId(5))
|
||||
assert(!db.isPruned(ShortChannelId(5)))
|
||||
db.addToPruned(shortChannelIds)
|
||||
shortChannelIds.foreach { id => assert(db.isPruned((id))) }
|
||||
db.removeFromPruned(ShortChannelId(5))
|
||||
assert(!db.isPruned(ShortChannelId(5)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ import java.util.UUID
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PrivateKey
|
||||
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto}
|
||||
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
|
||||
import fr.acinq.eclair.crypto.Sphinx
|
||||
import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils._
|
||||
import fr.acinq.eclair.payment._
|
||||
import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop}
|
||||
import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer}
|
||||
@ -36,383 +36,396 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
|
||||
import SqlitePaymentsDbSpec._
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db1 = new SqlitePaymentsDb(sqlite)
|
||||
val db2 = new SqlitePaymentsDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.payments()
|
||||
val db2 = dbs.payments()
|
||||
}
|
||||
}
|
||||
|
||||
test("handle version migration 1->4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils._
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
val connection = dbs.connection
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 1)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS payments (payment_hash BLOB NOT NULL PRIMARY KEY, amount_msat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 1)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS payments (payment_hash BLOB NOT NULL PRIMARY KEY, amount_msat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now
|
||||
}
|
||||
|
||||
// Changes between version 1 and 2:
|
||||
// - the monolithic payments table has been replaced by two tables, received_payments and sent_payments
|
||||
// - old records from the payments table are ignored (not migrated to the new tables)
|
||||
using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, paymentHash1.toArray)
|
||||
statement.setLong(2, (123 msat).toLong)
|
||||
statement.setLong(3, 1000) // received_at
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 1) == 4) // version changed from 1 -> 4
|
||||
}
|
||||
|
||||
// the existing received payment can NOT be queried anymore
|
||||
assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty)
|
||||
|
||||
// add a few rows
|
||||
val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending)
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
|
||||
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100))
|
||||
|
||||
preMigrationDb.addOutgoingPayment(ps1)
|
||||
preMigrationDb.addIncomingPayment(i1, preimage1)
|
||||
preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100)
|
||||
|
||||
assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
|
||||
assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
|
||||
assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
|
||||
assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now
|
||||
}
|
||||
|
||||
// Changes between version 1 and 2:
|
||||
// - the monolithic payments table has been replaced by two tables, received_payments and sent_payments
|
||||
// - old records from the payments table are ignored (not migrated to the new tables)
|
||||
using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, paymentHash1.toArray)
|
||||
statement.setLong(2, (123 msat).toLong)
|
||||
statement.setLong(3, 1000) // received_at
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 1) == 4) // version changed from 1 -> 4
|
||||
}
|
||||
|
||||
// the existing received payment can NOT be queried anymore
|
||||
assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty)
|
||||
|
||||
// add a few rows
|
||||
val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending)
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
|
||||
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100))
|
||||
|
||||
preMigrationDb.addOutgoingPayment(ps1)
|
||||
preMigrationDb.addIncomingPayment(i1, preimage1)
|
||||
preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100)
|
||||
|
||||
assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
|
||||
assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
|
||||
assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
|
||||
assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
|
||||
}
|
||||
|
||||
test("handle version migration 2->4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils._
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
val connection = dbs.connection
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 2)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)")
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 2)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now
|
||||
}
|
||||
|
||||
// Insert a bunch of old version 2 rows.
|
||||
val id1 = UUID.randomUUID()
|
||||
val id2 = UUID.randomUUID()
|
||||
val id3 = UUID.randomUUID()
|
||||
val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending)
|
||||
val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050))
|
||||
val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060))
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
|
||||
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090))
|
||||
val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1)
|
||||
val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
|
||||
// Changes between version 2 and 3 to sent_payments:
|
||||
// - removed the status column
|
||||
// - added optional payment failures
|
||||
// - added optional payment success details (fees paid and route)
|
||||
// - added optional payment request
|
||||
// - added target node ID
|
||||
// - added externalID and parentID
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps1.id.toString)
|
||||
statement.setBytes(2, ps1.paymentHash.toArray)
|
||||
statement.setLong(3, ps1.amount.toLong)
|
||||
statement.setLong(4, ps1.createdAt)
|
||||
statement.setString(5, "PENDING")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps2.id.toString)
|
||||
statement.setBytes(2, ps2.paymentHash.toArray)
|
||||
statement.setLong(3, ps2.amount.toLong)
|
||||
statement.setLong(4, ps2.createdAt)
|
||||
statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
|
||||
statement.setString(6, "FAILED")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps3.id.toString)
|
||||
statement.setBytes(2, ps3.paymentHash.toArray)
|
||||
statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
|
||||
statement.setLong(4, ps3.amount.toLong)
|
||||
statement.setLong(5, ps3.createdAt)
|
||||
statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
|
||||
statement.setString(7, "SUCCEEDED")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// Changes between version 2 and 3 to received_payments:
|
||||
// - renamed the preimage column
|
||||
// - made expire_at not null
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, i1.paymentHash.toArray)
|
||||
statement.setBytes(2, pr1.paymentPreimage.toArray)
|
||||
statement.setString(3, PaymentRequest.write(i1))
|
||||
statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong)
|
||||
statement.setLong(5, pr1.createdAt)
|
||||
statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, i2.paymentHash.toArray)
|
||||
statement.setBytes(2, pr2.paymentPreimage.toArray)
|
||||
statement.setString(3, PaymentRequest.write(i2))
|
||||
statement.setLong(4, pr2.createdAt)
|
||||
statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 2) == 4) // version changed from 2 -> 4
|
||||
}
|
||||
|
||||
assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1))
|
||||
assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2))
|
||||
assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
|
||||
val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30))
|
||||
val pr3 = IncomingPayment(i3, preimage3, PaymentType.Standard, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage)
|
||||
|
||||
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, PaymentType.Standard, 123 msat, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending)
|
||||
val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, PaymentType.Standard, 456 msat, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180))
|
||||
val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, PaymentType.Standard, 789 msat, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300))
|
||||
postMigrationDb.addOutgoingPayment(ps4)
|
||||
postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending))
|
||||
postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180))))
|
||||
postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending))
|
||||
postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300))
|
||||
|
||||
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6))
|
||||
assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3))
|
||||
assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2))
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now
|
||||
}
|
||||
|
||||
// Insert a bunch of old version 2 rows.
|
||||
val id1 = UUID.randomUUID()
|
||||
val id2 = UUID.randomUUID()
|
||||
val id3 = UUID.randomUUID()
|
||||
val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending)
|
||||
val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050))
|
||||
val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060))
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
|
||||
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090))
|
||||
val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1)
|
||||
val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
|
||||
// Changes between version 2 and 3 to sent_payments:
|
||||
// - removed the status column
|
||||
// - added optional payment failures
|
||||
// - added optional payment success details (fees paid and route)
|
||||
// - added optional payment request
|
||||
// - added target node ID
|
||||
// - added externalID and parentID
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps1.id.toString)
|
||||
statement.setBytes(2, ps1.paymentHash.toArray)
|
||||
statement.setLong(3, ps1.amount.toLong)
|
||||
statement.setLong(4, ps1.createdAt)
|
||||
statement.setString(5, "PENDING")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps2.id.toString)
|
||||
statement.setBytes(2, ps2.paymentHash.toArray)
|
||||
statement.setLong(3, ps2.amount.toLong)
|
||||
statement.setLong(4, ps2.createdAt)
|
||||
statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
|
||||
statement.setString(6, "FAILED")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps3.id.toString)
|
||||
statement.setBytes(2, ps3.paymentHash.toArray)
|
||||
statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
|
||||
statement.setLong(4, ps3.amount.toLong)
|
||||
statement.setLong(5, ps3.createdAt)
|
||||
statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
|
||||
statement.setString(7, "SUCCEEDED")
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// Changes between version 2 and 3 to received_payments:
|
||||
// - renamed the preimage column
|
||||
// - made expire_at not null
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, i1.paymentHash.toArray)
|
||||
statement.setBytes(2, pr1.paymentPreimage.toArray)
|
||||
statement.setString(3, PaymentRequest.write(i1))
|
||||
statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong)
|
||||
statement.setLong(5, pr1.createdAt)
|
||||
statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setBytes(1, i2.paymentHash.toArray)
|
||||
statement.setBytes(2, pr2.paymentPreimage.toArray)
|
||||
statement.setString(3, PaymentRequest.write(i2))
|
||||
statement.setLong(4, pr2.createdAt)
|
||||
statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 2) == 4) // version changed from 2 -> 4
|
||||
}
|
||||
|
||||
assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1))
|
||||
assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2))
|
||||
assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
|
||||
val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30))
|
||||
val pr3 = IncomingPayment(i3, preimage3, PaymentType.Standard, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage)
|
||||
|
||||
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, PaymentType.Standard, 123 msat, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending)
|
||||
val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, PaymentType.Standard, 456 msat, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180))
|
||||
val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, PaymentType.Standard, 789 msat, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300))
|
||||
postMigrationDb.addOutgoingPayment(ps4)
|
||||
postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending))
|
||||
postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180))))
|
||||
postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending))
|
||||
postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300))
|
||||
|
||||
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6))
|
||||
assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3))
|
||||
assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2))
|
||||
}
|
||||
|
||||
test("handle version migration 3->4") {
|
||||
val connection = TestConstants.sqliteInMemory()
|
||||
forAllDbs {
|
||||
case _: TestPgDatabases => // no migration
|
||||
case dbs: TestSqliteDatabases =>
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils._
|
||||
val connection = dbs.connection
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 3)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)")
|
||||
using(connection.createStatement()) { statement =>
|
||||
getVersion(statement, "payments", 3)
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)")
|
||||
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)")
|
||||
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
|
||||
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 3) == 3) // version 3 is deployed now
|
||||
}
|
||||
|
||||
// Insert a bunch of old version 3 rows.
|
||||
val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID())
|
||||
val parentId = UUID.randomUUID()
|
||||
val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", expirySeconds = Some(30))
|
||||
val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32, PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020))
|
||||
val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending)
|
||||
val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060))
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, failures) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps1.id.toString)
|
||||
statement.setString(2, ps1.parentId.toString)
|
||||
statement.setString(3, ps1.externalId.get)
|
||||
statement.setBytes(4, ps1.paymentHash.toArray)
|
||||
statement.setLong(5, ps1.amount.toLong)
|
||||
statement.setBytes(6, ps1.recipientNodeId.value.toArray)
|
||||
statement.setLong(7, ps1.createdAt)
|
||||
statement.setLong(8, ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
|
||||
statement.setBytes(9, SqlitePaymentsDb.paymentFailuresCodec.encode(ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].failures.toList).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps2.id.toString)
|
||||
statement.setString(2, ps2.parentId.toString)
|
||||
statement.setString(3, ps2.externalId.get)
|
||||
statement.setBytes(4, ps2.paymentHash.toArray)
|
||||
statement.setLong(5, ps2.amount.toLong)
|
||||
statement.setBytes(6, ps2.recipientNodeId.value.toArray)
|
||||
statement.setLong(7, ps2.createdAt)
|
||||
statement.setString(8, PaymentRequest.write(invoice1))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage, fees_msat, payment_route) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps3.id.toString)
|
||||
statement.setString(2, ps3.parentId.toString)
|
||||
statement.setBytes(3, ps3.paymentHash.toArray)
|
||||
statement.setLong(4, ps3.amount.toLong)
|
||||
statement.setBytes(5, ps3.recipientNodeId.value.toArray)
|
||||
statement.setLong(6, ps3.createdAt)
|
||||
statement.setLong(7, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
|
||||
statement.setBytes(8, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
|
||||
statement.setLong(9, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid.toLong)
|
||||
statement.setBytes(10, SqlitePaymentsDb.paymentRouteCodec.encode(ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].route.toList).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// Changes between version 3 and 4 to sent_payments:
|
||||
// - added final amount column
|
||||
// - added payment type column, with a default to "Standard"
|
||||
// - renamed target_node_id -> recipient_node_id
|
||||
// - re-ordered columns
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 3) == 4) // version changed from 3 -> 4
|
||||
}
|
||||
|
||||
assert(preMigrationDb.getOutgoingPayment(id1) === Some(ps1))
|
||||
assert(preMigrationDb.listOutgoingPayments(parentId) === Seq(ps2, ps3))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
}
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 3) == 3) // version 3 is deployed now
|
||||
}
|
||||
|
||||
// Insert a bunch of old version 3 rows.
|
||||
val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID())
|
||||
val parentId = UUID.randomUUID()
|
||||
val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", expirySeconds = Some(30))
|
||||
val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32, PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020))
|
||||
val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending)
|
||||
val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060))
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, failures) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps1.id.toString)
|
||||
statement.setString(2, ps1.parentId.toString)
|
||||
statement.setString(3, ps1.externalId.get)
|
||||
statement.setBytes(4, ps1.paymentHash.toArray)
|
||||
statement.setLong(5, ps1.amount.toLong)
|
||||
statement.setBytes(6, ps1.recipientNodeId.value.toArray)
|
||||
statement.setLong(7, ps1.createdAt)
|
||||
statement.setLong(8, ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
|
||||
statement.setBytes(9, SqlitePaymentsDb.paymentFailuresCodec.encode(ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].failures.toList).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps2.id.toString)
|
||||
statement.setString(2, ps2.parentId.toString)
|
||||
statement.setString(3, ps2.externalId.get)
|
||||
statement.setBytes(4, ps2.paymentHash.toArray)
|
||||
statement.setLong(5, ps2.amount.toLong)
|
||||
statement.setBytes(6, ps2.recipientNodeId.value.toArray)
|
||||
statement.setLong(7, ps2.createdAt)
|
||||
statement.setString(8, PaymentRequest.write(invoice1))
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage, fees_msat, payment_route) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
|
||||
statement.setString(1, ps3.id.toString)
|
||||
statement.setString(2, ps3.parentId.toString)
|
||||
statement.setBytes(3, ps3.paymentHash.toArray)
|
||||
statement.setLong(4, ps3.amount.toLong)
|
||||
statement.setBytes(5, ps3.recipientNodeId.value.toArray)
|
||||
statement.setLong(6, ps3.createdAt)
|
||||
statement.setLong(7, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
|
||||
statement.setBytes(8, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
|
||||
statement.setLong(9, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid.toLong)
|
||||
statement.setBytes(10, SqlitePaymentsDb.paymentRouteCodec.encode(ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].route.toList).require.toByteArray)
|
||||
statement.executeUpdate()
|
||||
}
|
||||
|
||||
// Changes between version 3 and 4 to sent_payments:
|
||||
// - added final amount column
|
||||
// - added payment type column, with a default to "Standard"
|
||||
// - renamed target_node_id -> recipient_node_id
|
||||
// - re-ordered columns
|
||||
|
||||
val preMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 3) == 4) // version changed from 3 -> 4
|
||||
}
|
||||
|
||||
assert(preMigrationDb.getOutgoingPayment(id1) === Some(ps1))
|
||||
assert(preMigrationDb.listOutgoingPayments(parentId) === Seq(ps2, ps3))
|
||||
|
||||
val postMigrationDb = new SqlitePaymentsDb(connection)
|
||||
|
||||
using(connection.createStatement()) { statement =>
|
||||
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
|
||||
}
|
||||
|
||||
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, randomBytes32, PaymentType.SwapOut, 50 msat, 100 msat, carol, 1100, Some(invoice1), OutgoingPaymentStatus.Pending)
|
||||
postMigrationDb.addOutgoingPayment(ps4)
|
||||
postMigrationDb.updateOutgoingPayment(PaymentSent(parentId, paymentHash1, preimage1, ps2.recipientAmount, ps2.recipientNodeId, Seq(PaymentSent.PartialPayment(id2, ps2.amount, 15 msat, randomBytes32, Some(Seq(hop_ab)), 1105))))
|
||||
|
||||
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Seq(HopSummary(hop_ab)), 1105)), ps3, ps4))
|
||||
}
|
||||
|
||||
test("add/retrieve/update incoming payments") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqlitePaymentsDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.payments()
|
||||
|
||||
// can't receive a payment without an invoice associated with it
|
||||
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat))
|
||||
// can't receive a payment without an invoice associated with it
|
||||
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat))
|
||||
|
||||
val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1)
|
||||
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30))
|
||||
val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, PaymentType.Standard, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, PaymentType.Standard, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1)
|
||||
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30))
|
||||
val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, PaymentType.Standard, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, PaymentType.Standard, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
|
||||
|
||||
val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3")
|
||||
val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30))
|
||||
val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, PaymentType.Standard, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, PaymentType.SwapIn, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3")
|
||||
val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30))
|
||||
val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, PaymentType.Standard, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, PaymentType.SwapIn, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
|
||||
|
||||
val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5")
|
||||
val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60))
|
||||
val receivedAt1 = System.currentTimeMillis + 1
|
||||
val receivedAt2 = System.currentTimeMillis + 2
|
||||
val payment1 = IncomingPayment(paidInvoice1, randomBytes32, PaymentType.Standard, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt2))
|
||||
val payment2 = IncomingPayment(paidInvoice2, randomBytes32, PaymentType.Standard, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2))
|
||||
val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5")
|
||||
val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60))
|
||||
val receivedAt1 = System.currentTimeMillis + 1
|
||||
val receivedAt2 = System.currentTimeMillis + 2
|
||||
val payment1 = IncomingPayment(paidInvoice1, randomBytes32, PaymentType.Standard, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt2))
|
||||
val payment2 = IncomingPayment(paidInvoice2, randomBytes32, PaymentType.Standard, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2))
|
||||
|
||||
db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage)
|
||||
db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn)
|
||||
db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage)
|
||||
db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage)
|
||||
db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage)
|
||||
db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage)
|
||||
db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage)
|
||||
db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn)
|
||||
db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage)
|
||||
db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage)
|
||||
db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage)
|
||||
db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage)
|
||||
|
||||
assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1))
|
||||
assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2))
|
||||
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending)))
|
||||
assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1))
|
||||
assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2))
|
||||
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending)))
|
||||
|
||||
val now = System.currentTimeMillis
|
||||
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
|
||||
assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2))
|
||||
assert(db.listReceivedIncomingPayments(0, now) === Nil)
|
||||
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
|
||||
val now = System.currentTimeMillis
|
||||
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
|
||||
assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2))
|
||||
assert(db.listReceivedIncomingPayments(0, now) === Nil)
|
||||
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
|
||||
|
||||
db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1)
|
||||
db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum
|
||||
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)
|
||||
db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1)
|
||||
db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum
|
||||
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)
|
||||
|
||||
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))
|
||||
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
|
||||
assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
|
||||
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2))
|
||||
assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2))
|
||||
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))
|
||||
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
|
||||
assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
|
||||
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2))
|
||||
assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2))
|
||||
}
|
||||
}
|
||||
|
||||
test("add/retrieve/update outgoing payments") {
|
||||
val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory())
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.payments()
|
||||
|
||||
val parentId = UUID.randomUUID()
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0)
|
||||
val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, PaymentType.Standard, 123 msat, 600 msat, dave, 100, Some(i1), OutgoingPaymentStatus.Pending)
|
||||
val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, PaymentType.SwapOut, 456 msat, 600 msat, dave, 200, None, OutgoingPaymentStatus.Pending)
|
||||
val parentId = UUID.randomUUID()
|
||||
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0)
|
||||
val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, PaymentType.Standard, 123 msat, 600 msat, dave, 100, Some(i1), OutgoingPaymentStatus.Pending)
|
||||
val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, PaymentType.SwapOut, 456 msat, 600 msat, dave, 200, None, OutgoingPaymentStatus.Pending)
|
||||
|
||||
assert(db.listOutgoingPayments(0, System.currentTimeMillis).isEmpty)
|
||||
db.addOutgoingPayment(s1)
|
||||
db.addOutgoingPayment(s2)
|
||||
assert(db.listOutgoingPayments(0, System.currentTimeMillis).isEmpty)
|
||||
db.addOutgoingPayment(s1)
|
||||
db.addOutgoingPayment(s2)
|
||||
|
||||
// can't add an outgoing payment in non-pending state
|
||||
assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110))))
|
||||
// can't add an outgoing payment in non-pending state
|
||||
assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110))))
|
||||
|
||||
assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(1, 150).toList == Seq(s1))
|
||||
assert(db.listOutgoingPayments(150, 250).toList == Seq(s2))
|
||||
assert(db.getOutgoingPayment(s1.id) === Some(s1))
|
||||
assert(db.getOutgoingPayment(UUID.randomUUID()) === None)
|
||||
assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(s1.id) === Nil)
|
||||
assert(db.listOutgoingPayments(parentId) === Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil)
|
||||
assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(1, 150).toList == Seq(s1))
|
||||
assert(db.listOutgoingPayments(150, 250).toList == Seq(s2))
|
||||
assert(db.getOutgoingPayment(s1.id) === Some(s1))
|
||||
assert(db.getOutgoingPayment(UUID.randomUUID()) === None)
|
||||
assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(s1.id) === Nil)
|
||||
assert(db.listOutgoingPayments(parentId) === Seq(s1, s2))
|
||||
assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil)
|
||||
|
||||
val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300)
|
||||
val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 300)
|
||||
db.addOutgoingPayment(s3)
|
||||
db.addOutgoingPayment(s4)
|
||||
val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300)
|
||||
val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 301)
|
||||
db.addOutgoingPayment(s3)
|
||||
db.addOutgoingPayment(s4)
|
||||
|
||||
db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310))
|
||||
val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310))
|
||||
assert(db.getOutgoingPayment(s3.id) === Some(ss3))
|
||||
db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(Seq(hop_ab), new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320))
|
||||
val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", List(HopSummary(alice, bob, Some(ShortChannelId(42))))), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)))), 320))
|
||||
assert(db.getOutgoingPayment(s4.id) === Some(ss4))
|
||||
db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310))
|
||||
val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310))
|
||||
assert(db.getOutgoingPayment(s3.id) === Some(ss3))
|
||||
db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(Seq(hop_ab), new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320))
|
||||
val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", List(HopSummary(alice, bob, Some(ShortChannelId(42))))), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)))), 320))
|
||||
assert(db.getOutgoingPayment(s4.id) === Some(ss4))
|
||||
|
||||
// can't update again once it's in a final state
|
||||
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, s3.recipientAmount, s3.recipientNodeId, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None)))))
|
||||
// can't update again once it's in a final state
|
||||
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, s3.recipientAmount, s3.recipientNodeId, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None)))))
|
||||
|
||||
val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, 600 msat, carol, Seq(
|
||||
PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400),
|
||||
PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410)
|
||||
))
|
||||
val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400))
|
||||
val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)), 410))
|
||||
db.updateOutgoingPayment(paymentSent)
|
||||
assert(db.getOutgoingPayment(s1.id) === Some(ss1))
|
||||
assert(db.getOutgoingPayment(s2.id) === Some(ss2))
|
||||
assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4))
|
||||
val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, 600 msat, carol, Seq(
|
||||
PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400),
|
||||
PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410)
|
||||
))
|
||||
val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400))
|
||||
val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)), 410))
|
||||
db.updateOutgoingPayment(paymentSent)
|
||||
assert(db.getOutgoingPayment(s1.id) === Some(ss1))
|
||||
assert(db.getOutgoingPayment(s2.id) === Some(ss2))
|
||||
assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4))
|
||||
|
||||
// can't update again once it's in a final state
|
||||
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil)))
|
||||
// can't update again once it's in a final state
|
||||
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil)))
|
||||
}
|
||||
}
|
||||
|
||||
test("high level payments overview") {
|
||||
|
@ -16,50 +16,49 @@
|
||||
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import java.sql.DriverManager
|
||||
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.eclair.db.sqlite.SqlitePeersDb
|
||||
import fr.acinq.eclair.randomKey
|
||||
import fr.acinq.eclair.wire.{NodeAddress, Tor2, Tor3}
|
||||
import fr.acinq.eclair.{TestConstants, randomKey}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
|
||||
class SqlitePeersDbSpec extends AnyFunSuite {
|
||||
|
||||
def inmem = DriverManager.getConnection("jdbc:sqlite::memory:")
|
||||
import TestConstants.forAllDbs
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = inmem
|
||||
val db1 = new SqlitePeersDb(sqlite)
|
||||
val db2 = new SqlitePeersDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.peers()
|
||||
val db2 = dbs.peers()
|
||||
}
|
||||
}
|
||||
|
||||
test("add/remove/get/list peers") {
|
||||
val sqlite = inmem
|
||||
val db = new SqlitePeersDb(sqlite)
|
||||
test("add/remove/list peers") {
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.peers()
|
||||
|
||||
case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress)
|
||||
case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress)
|
||||
|
||||
val peer_1 = TestCase(randomKey.publicKey, NodeAddress.fromParts("127.0.0.1", 42000).get)
|
||||
val peer_1_bis = TestCase(peer_1.nodeId, NodeAddress.fromParts("127.0.0.1", 1112).get)
|
||||
val peer_2 = TestCase(randomKey.publicKey, Tor2("z4zif3fy7fe7bpg3", 4231))
|
||||
val peer_3 = TestCase(randomKey.publicKey, Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231))
|
||||
val peer_1 = TestCase(randomKey.publicKey, NodeAddress.fromParts("127.0.0.1", 42000).get)
|
||||
val peer_1_bis = TestCase(peer_1.nodeId, NodeAddress.fromParts("127.0.0.1", 1112).get)
|
||||
val peer_2 = TestCase(randomKey.publicKey, Tor2("z4zif3fy7fe7bpg3", 4231))
|
||||
val peer_3 = TestCase(randomKey.publicKey, Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231))
|
||||
|
||||
assert(db.listPeers().toSet === Set.empty)
|
||||
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress)
|
||||
assert(db.getPeer(peer_1.nodeId) === Some(peer_1.nodeAddress))
|
||||
assert(db.getPeer(peer_2.nodeId) === None)
|
||||
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress) // duplicate is ignored
|
||||
assert(db.listPeers().size === 1)
|
||||
db.addOrUpdatePeer(peer_2.nodeId, peer_2.nodeAddress)
|
||||
db.addOrUpdatePeer(peer_3.nodeId, peer_3.nodeAddress)
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_2, peer_3))
|
||||
db.removePeer(peer_2.nodeId)
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_3))
|
||||
db.addOrUpdatePeer(peer_1_bis.nodeId, peer_1_bis.nodeAddress)
|
||||
assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress))
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3))
|
||||
assert(db.listPeers().toSet === Set.empty)
|
||||
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress)
|
||||
assert(db.getPeer(peer_1.nodeId) === Some(peer_1.nodeAddress))
|
||||
assert(db.getPeer(peer_2.nodeId) === None)
|
||||
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress) // duplicate is ignored
|
||||
assert(db.listPeers().size === 1)
|
||||
db.addOrUpdatePeer(peer_2.nodeId, peer_2.nodeAddress)
|
||||
db.addOrUpdatePeer(peer_3.nodeId, peer_3.nodeAddress)
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_2, peer_3))
|
||||
db.removePeer(peer_2.nodeId)
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_3))
|
||||
db.addOrUpdatePeer(peer_1_bis.nodeId, peer_1_bis.nodeAddress)
|
||||
assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress))
|
||||
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -17,46 +17,49 @@
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC}
|
||||
import fr.acinq.eclair.db.sqlite.SqlitePendingRelayDb
|
||||
import fr.acinq.eclair.{TestConstants, randomBytes32}
|
||||
import fr.acinq.eclair.wire.FailureMessageCodecs
|
||||
import fr.acinq.eclair.{TestConstants, randomBytes32}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
|
||||
class SqlitePendingRelayDbSpec extends AnyFunSuite {
|
||||
|
||||
import TestConstants.forAllDbs
|
||||
|
||||
test("init sqlite 2 times in a row") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db1 = new SqlitePendingRelayDb(sqlite)
|
||||
val db2 = new SqlitePendingRelayDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db1 = dbs.pendingRelay()
|
||||
val db2 = dbs.pendingRelay()
|
||||
}
|
||||
}
|
||||
|
||||
test("add/remove/list messages") {
|
||||
val sqlite = TestConstants.sqliteInMemory()
|
||||
val db = new SqlitePendingRelayDb(sqlite)
|
||||
forAllDbs { dbs =>
|
||||
val db = dbs.pendingRelay()
|
||||
|
||||
val channelId1 = randomBytes32
|
||||
val channelId2 = randomBytes32
|
||||
val msg0 = CMD_FULFILL_HTLC(0, randomBytes32)
|
||||
val msg1 = CMD_FULFILL_HTLC(1, randomBytes32)
|
||||
val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32))
|
||||
val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32))
|
||||
val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32, FailureMessageCodecs.BADONION)
|
||||
val channelId1 = randomBytes32
|
||||
val channelId2 = randomBytes32
|
||||
val msg0 = CMD_FULFILL_HTLC(0, randomBytes32)
|
||||
val msg1 = CMD_FULFILL_HTLC(1, randomBytes32)
|
||||
val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32))
|
||||
val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32))
|
||||
val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32, FailureMessageCodecs.BADONION)
|
||||
|
||||
assert(db.listPendingRelay(channelId1).toSet === Set.empty)
|
||||
db.addPendingRelay(channelId1, msg0)
|
||||
db.addPendingRelay(channelId1, msg0) // duplicate
|
||||
db.addPendingRelay(channelId1, msg1)
|
||||
db.addPendingRelay(channelId1, msg2)
|
||||
db.addPendingRelay(channelId1, msg3)
|
||||
db.addPendingRelay(channelId1, msg4)
|
||||
db.addPendingRelay(channelId2, msg0) // same messages but for different channel
|
||||
db.addPendingRelay(channelId2, msg1)
|
||||
assert(db.listPendingRelay(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4))
|
||||
assert(db.listPendingRelay(channelId2).toSet === Set(msg0, msg1))
|
||||
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg1.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
|
||||
db.removePendingRelay(channelId1, msg1.id)
|
||||
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
|
||||
assert(db.listPendingRelay(channelId1).toSet === Set.empty)
|
||||
db.addPendingRelay(channelId1, msg0)
|
||||
db.addPendingRelay(channelId1, msg0) // duplicate
|
||||
db.addPendingRelay(channelId1, msg1)
|
||||
db.addPendingRelay(channelId1, msg2)
|
||||
db.addPendingRelay(channelId1, msg3)
|
||||
db.addPendingRelay(channelId1, msg4)
|
||||
db.addPendingRelay(channelId2, msg0) // same messages but for different channel
|
||||
db.addPendingRelay(channelId2, msg1)
|
||||
assert(db.listPendingRelay(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4))
|
||||
assert(db.listPendingRelay(channelId2).toSet === Set(msg0, msg1))
|
||||
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg1.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
|
||||
db.removePendingRelay(channelId1, msg1.id)
|
||||
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -16,10 +16,11 @@
|
||||
|
||||
package fr.acinq.eclair.db
|
||||
|
||||
import java.sql.SQLException
|
||||
|
||||
import fr.acinq.eclair.TestConstants
|
||||
import fr.acinq.eclair.db.sqlite.SqliteUtils.using
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.sqlite.SQLiteException
|
||||
|
||||
class SqliteUtilsSpec extends AnyFunSuite {
|
||||
|
||||
@ -41,7 +42,7 @@ class SqliteUtilsSpec extends AnyFunSuite {
|
||||
assert(!results.next())
|
||||
}
|
||||
|
||||
assertThrows[SQLiteException](using(conn.createStatement(), inTransaction = true) { statement =>
|
||||
assertThrows[SQLException](using(conn.createStatement(), inTransaction = true) { statement =>
|
||||
statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)")
|
||||
statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation)
|
||||
})
|
||||
|
@ -1 +1 @@
|
||||
{"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"]}
|
||||
{"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"],"instanceId":"01234567-0123-4567-89ab-0123456789ab"}
|
@ -180,7 +180,8 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
|
||||
chainHash = ByteVector32(hex"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"),
|
||||
network = "regtest",
|
||||
blockHeight = 9999,
|
||||
publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil
|
||||
publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil,
|
||||
instanceId = "01234567-0123-4567-89ab-0123456789ab"
|
||||
))
|
||||
|
||||
Post("/getinfo") ~>
|
||||
|
Loading…
Reference in New Issue
Block a user