1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 01:43:22 +01:00

Postgresql support (#1249)

Add beta support for PostgreSQL database backend.
This commit is contained in:
rorp 2020-07-01 05:52:36 -07:00 committed by GitHub
parent 68dfc6cb7c
commit b63c4aa5a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 2974 additions and 1018 deletions

2
.gitignore vendored
View File

@ -25,3 +25,5 @@ target/
project/target
DeleteMe*.*
*~
.DS_Store

114
docs/PostgreSQL.md Normal file
View File

@ -0,0 +1,114 @@
## PostgreSQL Configuration
By default Eclair stores its data on the machine's local file system (typically in `~/.eclair` directory) using SQLite.
It also supports PostgreSQL version 10.6 and higher as a database backend.
To enable PostgreSQL support set the `driver` parameter to `postgres`:
```
eclair.db.driver = postgres
```
### Connection settings
To configure the connection settings use the `database`, `host`, `port` `username` and `password` parameters:
```
eclair.db.postgres.database = "mydb"
eclair.db.postgres.host = "127.0.0.1" # Default: "localhost"
eclair.db.postgres.port = 12345 # Default: 5432
eclair.db.postgres.username = "myuser"
eclair.db.postgres.password = "mypassword"
```
Eclair uses Hikari connection pool (https://github.com/brettwooldridge/HikariCP) which has a lot of configuration
parameters. Some of them can be set in Eclair config file. The most important is `pool.max-size`, it defines the maximum
allowed number of simultaneous connections to the database.
A good rule of thumb is to set `pool.max-size` to the CPU core count times 2.
See https://github.com/brettwooldridge/HikariCP/wiki/About-Pool-Sizing for better estimation.
```
eclair.db.postgres.pool {
max-size = 8 # Default: 10
connection-timeout = 10 seconds # Default: 30 seconds
idle-timeout = 1 minute # Default: 10 minutes
max-life-time = 15 minutes # Default: 30 minutes
}
```
### Locking settings
Running multiple Eclair processes connected to the same database can lead to data corruption and loss of funds.
That's why Eclair supports database locking mechanisms to prevent multiple Eclair instances from accessing one database together.
Use `postgres.lock-type` parameter to set the locking schemes.
Lock type | Description
---|---
`lease` | At the beginning, Eclair acquires a lease for the database that expires after some time. Then it constantly extends the lease. On each lease extension and each database transaction, Eclair checks if the lease belongs to the Eclair instance. If it doesn't, Eclair assumes that the database was updated by another Eclair process and terminates. Note that this is just a safeguard feature for Eclair rather than a bulletproof database-wide lock, because third-party applications still have the ability to access the database without honoring this locking scheme.
`none` | No locking at all. Useful for tests. DO NOT USE ON MAINNET!
```
eclair.db.postgres.lock-type = "none" // Default: "lease"
```
#### Database Lease Settings
There are two main configuration parameters for the lease locking scheme: `lease.interval` and `lease.renew-interval`.
`lease.interval` defines lease validity time. During the lease time no other node can acquire the lock, except the lease holder.
After that time the lease is assumed expired, any node can acquire the lease. So that only one node can update the database
at a time. Eclair extends the lease every `lease.renew-interval` until terminated.
```
eclair.db.postgres.lease {
interval = 30 seconds // Default: 5 minutes
renew-interval = 10 seconds // Default: 1 minute
}
```
### Backups and replication
The PostgreSQL driver doesn't support Eclair's built-in online backups. Instead, you should use the tools provided
by PostgreSQL.
#### Backup/Restore
For nodes with infrequent channel updates its easier to use `pg_dump` to perform the task.
It's important to stop the node to prevent any channel updates while a backup/restore operation is in progress. It makes
sense to backup the database after each channel update, to prevent restoring an outdated channel's state and consequently
losing the funds associated with that channel.
For more information about backup refer to the official PostgreSQL documentation: https://www.postgresql.org/docs/current/backup.html
#### Replication
For busier nodes it isn't practical to use `pg_dump`. Fortunately, PostgreSQL provides built-in database replication which makes the backup/restore process more seamless.
To set up database replication you need to create a main database, that accepts all changes from the node, and a replica database.
Once replication is configured, the main database will automatically send all the changes to the replica.
In case of failure of the main database, the node can be simply reconfigured to use the replica instead of the main database.
PostgreSQL supports [different types of replication](https://www.postgresql.org/docs/current/different-replication-solutions.html).
The most suitable type for an Eclair node is [synchronous streaming replication](https://www.postgresql.org/docs/current/warm-standby.html#SYNCHRONOUS-REPLICATION),
because it provides a very important feature, that helps keep the replicated channel's state up to date:
> When requesting synchronous replication, each commit of a write transaction will wait until confirmation is received that the commit has been written to the write-ahead log on disk of both the primary and standby server.
Follow the official PostgreSQL high availability documentation for the instructions to set up synchronous streaming replication: https://www.postgresql.org/docs/current/high-availability.html
### Safeguard to prevent accidental loss of funds due to database misconfiguration
Using Eclair with an outdated version of the database or a database created with another seed might lead to loss of funds.
Every time Eclair starts, it checks if the Postgres database connection settings were changed since the last start.
If in fact the settings were changed, Eclair stops immediately to prevent potentially dangerous
but accidental configuration changes to come into effect.
Eclair stores the latest database settings in the `${data-dir}/last_jdbcurl` file, and compares its contents with the database settings from the config file.
The node operator can force Eclair to accept new database
connection settings by removing the `last_jdbcurl` file.

View File

@ -198,6 +198,16 @@
<artifactId>sqlite-jdbc</artifactId>
<version>3.27.2.1</version>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<version>42.2.12</version>
<artifactId>postgresql</artifactId>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
<version>3.4.2</version>
</dependency>
<dependency>
<!-- This is to get rid of '[WARNING] warning: Class javax.annotation.Nonnull not found - continuing with a stub.' compile errors -->
<groupId>com.google.code.findbugs</groupId>
@ -264,5 +274,11 @@
<version>1.5.9</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.opentable.components</groupId>
<artifactId>otj-pg-embedded</artifactId>
<version>0.13.3</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -183,6 +183,28 @@ eclair {
port = 9051
private-key-file = "tor.dat"
}
db {
driver = "sqlite" // sqlite, postgres
postgres {
database = "eclair"
host = "localhost"
port = 5432
username = ""
password = ""
pool {
max-size = 10 // recommended value = number_of_cpu_cores * 2
connection-timeout = 30 seconds
idle-timeout = 10 minutes
max-life-time = 30 minutes
}
lease {
interval = 5 minutes // lease-interval must be greater than lease-renew-interval
renew-interval = 1 minute
}
lock-type = "lease" // lease or none
}
}
}
// do not edit or move this section

View File

@ -45,7 +45,7 @@ import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress])
case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress], instanceId: String)
case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], relayed: Seq[PaymentRelayed])
@ -372,7 +372,8 @@ class EclairImpl(appKit: Kit) extends Eclair {
chainHash = appKit.nodeParams.chainHash,
network = NodeParams.chainFromHash(appKit.nodeParams.chainHash),
blockHeight = appKit.nodeParams.currentBlockHeight.toInt,
publicAddresses = appKit.nodeParams.publicAddresses)
publicAddresses = appKit.nodeParams.publicAddresses,
instanceId = appKit.nodeParams.instanceId.toString)
)
override def usableBalances()(implicit timeout: Timeout): Future[Iterable[UsableBalance]] =

View File

@ -19,6 +19,7 @@ package fr.acinq.eclair
import java.io.File
import java.net.InetSocketAddress
import java.nio.file.Files
import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
@ -42,6 +43,7 @@ import scala.jdk.CollectionConverters._
* Created by PM on 26/02/2017.
*/
case class NodeParams(keyManager: KeyManager,
instanceId: UUID, // a unique instance ID regenerated after each restart
private val blockCount: AtomicLong,
alias: String,
color: Color,
@ -132,7 +134,7 @@ object NodeParams {
def chainFromHash(chainHash: ByteVector32): String = chain2Hash.map(_.swap).getOrElse(chainHash, throw new RuntimeException(s"invalid chainHash '$chainHash'"))
def makeNodeParams(config: Config, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = {
def makeNodeParams(config: Config, instanceId: UUID, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = {
// check configuration for keys that have been renamed
val deprecatedKeyPaths = Map(
// v0.3.2
@ -235,6 +237,7 @@ object NodeParams {
NodeParams(
keyManager = keyManager,
instanceId = instanceId,
blockCount = blockCount,
alias = nodeAlias,
color = Color(color(0), color(1), color(2)),

View File

@ -19,6 +19,7 @@ package fr.acinq.eclair
import java.io.File
import java.net.InetSocketAddress
import java.sql.DriverManager
import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
@ -27,7 +28,6 @@ import akka.actor.{ActorRef, ActorSystem, Props, SupervisorStrategy}
import akka.pattern.after
import akka.util.Timeout
import com.softwaremill.sttp.okhttp.OkHttpFutureBackend
import com.typesafe.config.{Config, ConfigFactory}
import fr.acinq.bitcoin.{Block, ByteVector32}
import fr.acinq.eclair.NodeParams.{BITCOIND, ELECTRUM}
import fr.acinq.eclair.blockchain.bitcoind.rpc.{BasicBitcoinJsonRPCClient, BatchingBitcoinJsonRPCClient, ExtendedBitcoinClient}
@ -41,7 +41,8 @@ import fr.acinq.eclair.blockchain.fee.{ConstantFeeProvider, _}
import fr.acinq.eclair.blockchain.{EclairWallet, _}
import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.crypto.LocalKeyManager
import fr.acinq.eclair.db.{BackupHandler, Databases}
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.{Databases, FileBackupHandler}
import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard}
import fr.acinq.eclair.payment.Auditor
import fr.acinq.eclair.payment.receive.PaymentHandler
@ -90,11 +91,11 @@ class Setup(datadir: File,
val chain = config.getString("chain")
val chaindir = new File(datadir, chain)
val keyManager = new LocalKeyManager(seed, NodeParams.hashFromChain(chain))
val instanceId = UUID.randomUUID()
val database = db match {
case Some(d) => d
case None => Databases.sqliteJDBC(chaindir)
}
logger.info(s"instanceid=$instanceId")
val databases = Databases.init(config.getConfig("db"), instanceId, datadir, chaindir, db)
/**
* This counter holds the current blockchain height.
@ -123,7 +124,7 @@ class Setup(datadir: File,
// @formatter:on
}
val nodeParams = NodeParams.makeNodeParams(config, keyManager, initTor(), database, blockCount, feeEstimator)
val nodeParams = NodeParams.makeNodeParams(config, instanceId, keyManager, initTor(), databases, blockCount, feeEstimator)
val serverBindingAddress = new InetSocketAddress(
config.getString("server.binding-ip"),
@ -281,12 +282,16 @@ class Setup(datadir: File,
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
backupHandler = if (config.getBoolean("enable-db-backup")) {
system.actorOf(SimpleSupervisor.props(
BackupHandler.props(
nodeParams.db,
new File(chaindir, "eclair.sqlite.bak"),
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None
), "backuphandler", SupervisorStrategy.Resume))
nodeParams.db match {
case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
FileBackupHandler.props(
fileBackup,
new File(chaindir, "eclair.sqlite.bak"),
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None),
"backuphandler", SupervisorStrategy.Resume))
case _ =>
system.deadLetters
}
} else {
logger.warn("database backup is disabled")
system.deadLetters
@ -363,6 +368,7 @@ class Setup(datadir: File,
None
}
}
}
// @formatter:off

View File

@ -17,11 +17,20 @@
package fr.acinq.eclair.db
import java.io.File
import java.nio.file._
import java.sql.{Connection, DriverManager}
import java.util.UUID
import akka.actor.ActorSystem
import com.typesafe.config.Config
import fr.acinq.eclair.db.pg.PgUtils.LockType.LockType
import fr.acinq.eclair.db.pg.PgUtils._
import fr.acinq.eclair.db.pg._
import fr.acinq.eclair.db.sqlite._
import grizzled.slf4j.Logging
import org.sqlite.SQLiteException
import javax.sql.DataSource
import scala.concurrent.duration._
trait Databases {
@ -36,12 +45,53 @@ trait Databases {
val payments: PaymentsDb
val pendingRelay: PendingRelayDb
def backup(file: File): Unit
}
object Databases extends Logging {
trait FileBackup { this: Databases =>
def backup(backupFile: File): Unit
}
trait ExclusiveLock { this: Databases =>
def obtainExclusiveLock(): Unit
}
def init(dbConfig: Config, instanceId: UUID, datadir: File, chaindir: File, db: Option[Databases] = None)(implicit system: ActorSystem): Databases = {
db match {
case Some(d) => d
case None =>
dbConfig.getString("driver") match {
case "sqlite" => Databases.sqliteJDBC(chaindir)
case "postgres" =>
val pg = Databases.setupPgDatabases(dbConfig, instanceId, datadir, { ex =>
logger.error("fatal error: Cannot obtain lock on the database.\n", ex)
sys.exit(-2)
})
if (LockType(dbConfig.getString("postgres.lock-type")) == LockType.LEASE) {
val dbLockLeaseRenewInterval = dbConfig.getDuration("postgres.lease.renew-interval").toSeconds.seconds
val dbLockLeaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds
if (dbLockLeaseInterval <= dbLockLeaseRenewInterval)
throw new RuntimeException("Invalid configuration: `db.postgres.lease.interval` must be greater than `db.postgres.lease.renew-interval`")
import system.dispatcher
system.scheduler.scheduleWithFixedDelay(dbLockLeaseRenewInterval, dbLockLeaseRenewInterval)(new Runnable {
override def run(): Unit = {
try {
pg.obtainExclusiveLock()
} catch {
case e: Throwable =>
logger.error("fatal error: Cannot obtain the database lease.\n", e)
sys.exit(-1)
}
}
})
}
pg
case driver => throw new RuntimeException(s"Unknown database driver `$driver`")
}
}
}
/**
* Given a parent folder it creates or loads all the databases from a JDBC connection
*
@ -59,7 +109,7 @@ object Databases extends Logging {
sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}")
SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file
logger.info("successful lock on eclair.sqlite")
databaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
sqliteDatabaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
} catch {
case t: Throwable => {
logger.error("could not create connection to sqlite databases: ", t)
@ -69,23 +119,117 @@ object Databases extends Logging {
throw t
}
}
}
def databaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection) = new Databases {
def postgresJDBC(database: String, host: String, port: Int,
username: Option[String], password: Option[String],
poolProperties: Map[String, Long],
instanceId: UUID,
databaseLeaseInterval: FiniteDuration,
lockExceptionHandler: LockExceptionHandler = { _ => () },
lockType: LockType = LockType.NONE, datadir: File): Databases with ExclusiveLock = {
val url = s"jdbc:postgresql://${host}:${port}/${database}"
checkIfDatabaseUrlIsUnchanged(url, datadir)
implicit val lock: DatabaseLock = lockType match {
case LockType.NONE => NoLock
case LockType.LEASE => LeaseLock(instanceId, databaseLeaseInterval, lockExceptionHandler)
case _ => throw new RuntimeException(s"Unknown postgres lock type: `$lockType`")
}
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
val config = new HikariConfig()
config.setJdbcUrl(url)
username.foreach(config.setUsername)
password.foreach(config.setPassword)
poolProperties.get("max-size").foreach(x => config.setMaximumPoolSize(x.toInt))
poolProperties.get("connection-timeout").foreach(config.setConnectionTimeout)
poolProperties.get("idle-timeout").foreach(config.setIdleTimeout)
poolProperties.get("max-life-time").foreach(config.setMaxLifetime)
implicit val ds: DataSource = new HikariDataSource(config)
val databases = new Databases with ExclusiveLock {
override val network = new PgNetworkDb
override val audit = new PgAuditDb
override val channels = new PgChannelsDb
override val peers = new PgPeersDb
override val payments = new PgPaymentsDb
override val pendingRelay = new PgPendingRelayDb
override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock
}
databases.obtainExclusiveLock()
databases
}
def sqliteDatabaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection): Databases = new Databases with FileBackup {
override val network = new SqliteNetworkDb(networkJdbc)
override val audit = new SqliteAuditDb(auditJdbc)
override val channels = new SqliteChannelsDb(eclairJdbc)
override val peers = new SqlitePeersDb(eclairJdbc)
override val payments = new SqlitePaymentsDb(eclairJdbc)
override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc)
override def backup(backupFile: File): Unit = {
override def backup(file: File): Unit = {
SqliteUtils.using(eclairJdbc.createStatement()) {
statement => {
statement.executeUpdate(s"backup to ${file.getAbsolutePath}")
statement.executeUpdate(s"backup to ${backupFile.getAbsolutePath}")
}
}
}
}
def setupPgDatabases(dbConfig: Config, instanceId: UUID, datadir: File, lockExceptionHandler: LockExceptionHandler): Databases with ExclusiveLock = {
val database = dbConfig.getString("postgres.database")
val host = dbConfig.getString("postgres.host")
val port = dbConfig.getInt("postgres.port")
val username = if (dbConfig.getIsNull("postgres.username") || dbConfig.getString("postgres.username").isEmpty)
None
else
Some(dbConfig.getString("postgres.username"))
val password = if (dbConfig.getIsNull("postgres.password") || dbConfig.getString("postgres.password").isEmpty)
None
else
Some(dbConfig.getString("postgres.password"))
val properties = {
val poolConfig = dbConfig.getConfig("postgres.pool")
Map.empty
.updated("max-size", poolConfig.getInt("max-size").toLong)
.updated("connection-timeout", poolConfig.getDuration("connection-timeout").toMillis)
.updated("idle-timeout", poolConfig.getDuration("idle-timeout").toMillis)
.updated("max-life-time", poolConfig.getDuration("max-life-time").toMillis)
}
val lockType = LockType(dbConfig.getString("postgres.lock-type"))
val leaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds
Databases.postgresJDBC(
database = database, host = host, port = port,
username = username, password = password,
poolProperties = properties,
instanceId = instanceId,
databaseLeaseInterval = leaseInterval,
lockExceptionHandler = lockExceptionHandler, lockType = lockType, datadir = datadir
)
}
private def checkIfDatabaseUrlIsUnchanged(url: String, datadir: File ): Unit = {
val urlFile = new File(datadir, "last_jdbcurl")
def readString(path: Path): String = Files.readAllLines(path).get(0)
def writeString(path: Path, string: String): Unit = Files.write(path, java.util.Arrays.asList(string))
if (urlFile.exists()) {
val oldUrl = readString(urlFile.toPath)
if (oldUrl != url)
throw new RuntimeException(s"The database URL has changed since the last start. It was `$oldUrl`, now it's `$url`")
} else {
writeString(urlFile.toPath, url)
}
}
}

View File

@ -22,6 +22,7 @@ import java.nio.file.{Files, StandardCopyOption}
import akka.actor.{Actor, ActorLogging, Props}
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import scala.sys.process.Process
import scala.util.{Failure, Success, Try}
@ -46,7 +47,7 @@ import scala.util.{Failure, Success, Try}
*
* Constructor is private so users will have to use BackupHandler.props() which always specific a custom mailbox
*/
class BackupHandler private(databases: Databases, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging {
class FileBackupHandler private(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging {
// we listen to ChannelPersisted events, which will trigger a backup
context.system.eventStream.subscribe(self, classOf[ChannelPersisted])
@ -56,6 +57,7 @@ class BackupHandler private(databases: Databases, backupFile: File, backupScript
val start = System.currentTimeMillis()
val tmpFile = new File(backupFile.getAbsolutePath.concat(".tmp"))
databases.backup(tmpFile)
// this will throw an exception if it fails, which is possible if the backup file is not on the same filesystem
// as the temporary file
Files.move(tmpFile.toPath, backupFile.toPath, StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE)
@ -83,8 +85,8 @@ sealed trait BackupEvent
// this notification is sent when we have completed our backup process (our backup file is ready to be used)
case object BackupCompleted extends BackupEvent
object BackupHandler {
object FileBackupHandler {
// using this method is the only way to create a BackupHandler actor
// we make sure that it uses a custom bounded mailbox, and a custom pinned dispatcher (i.e our actor will have its own thread pool with 1 single thread)
def props(databases: Databases, backupFile: File, backupScript_opt: Option[String]) = Props(new BackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher")
def props(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) = Props(new FileBackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher")
}

View File

@ -0,0 +1,139 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.jdbc
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.MilliSatoshi
import javax.sql.DataSource
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import scala.collection.immutable.Queue
trait JdbcUtils {
def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
val connection = dataSource.getConnection()
try {
f(connection)
} finally {
connection.close()
}
}
/**
* This helper makes sure statements are correctly closed.
*
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
*/
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
val autoCommit = statement.getConnection.getAutoCommit
try {
if (inTransaction) statement.getConnection.setAutoCommit(false)
val res = block(statement)
if (inTransaction) statement.getConnection.commit()
res
} catch {
case t: Exception =>
if (inTransaction) statement.getConnection.rollback()
throw t
} finally {
if (inTransaction) statement.getConnection.setAutoCommit(autoCommit)
if (statement != null) statement.close()
}
}
/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
}
q
}
case class ExtendedResultSet(rs: ResultSet) {
def getByteVectorFromHex(columnLabel: String): ByteVector = {
val s = rs.getString(columnLabel).stripPrefix("\\x")
ByteVector.fromValidHex(s)
}
def getByteVector32FromHex(columnLabel: String): ByteVector32 = {
val s = rs.getString(columnLabel)
ByteVector32(ByteVector.fromValidHex(s))
}
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
val s = rs.getString(columnLabel)
if (rs.wasNull()) None else {
Some(ByteVector32(ByteVector.fromValidHex(s)))
}
}
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
def getByteVectorNullable(columnLabel: String): ByteVector = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
}
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
val bytes = rs.getBytes(columnLabel)
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
}
def getStringNullable(columnLabel: String): Option[String] = {
val result = rs.getString(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getLongNullable(columnLabel: String): Option[Long] = {
val result = rs.getLong(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getUUIDNullable(label: String): Option[UUID] = {
val result = rs.getString(label)
if (rs.wasNull()) None else Some(UUID.fromString(result))
}
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
val result = rs.getLong(label)
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}
}
object ExtendedResultSet {
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
}
}
object JdbcUtils extends JdbcUtils

View File

@ -0,0 +1,321 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import java.util.UUID
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
import fr.acinq.eclair.db._
import fr.acinq.eclair.payment._
import fr.acinq.eclair.{LongToBtcAmount, MilliSatoshi}
import grizzled.slf4j.Logging
import javax.sql.DataSource
import scala.collection.immutable.Queue
class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
import PgUtils._
import ExtendedResultSet._
val DB_NAME = "audit"
val CURRENT_VERSION = 4
case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
inTransaction { pg =>
using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
case CURRENT_VERSION =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
case unknownVersion =>
throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
}
}
override def add(e: ChannelLifecycleEvent): Unit =
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO channel_events VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, e.channelId.toHex)
statement.setString(2, e.remoteNodeId.value.toHex)
statement.setLong(3, e.capacity.toLong)
statement.setBoolean(4, e.isFunder)
statement.setBoolean(5, e.isPrivate)
statement.setString(6, e.event)
statement.setLong(7, System.currentTimeMillis)
statement.executeUpdate()
}
}
override def add(e: PaymentSent): Unit =
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
e.parts.foreach(p => {
statement.setLong(1, p.amount.toLong)
statement.setLong(2, p.feesPaid.toLong)
statement.setLong(3, e.recipientAmount.toLong)
statement.setString(4, p.id.toString)
statement.setString(5, e.id.toString)
statement.setString(6, e.paymentHash.toHex)
statement.setString(7, e.paymentPreimage.toHex)
statement.setString(8, e.recipientNodeId.value.toHex)
statement.setString(9, p.toChannelId.toHex)
statement.setLong(10, p.timestamp)
statement.addBatch()
})
statement.executeBatch()
}
}
override def add(e: PaymentReceived): Unit =
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement =>
e.parts.foreach(p => {
statement.setLong(1, p.amount.toLong)
statement.setString(2, e.paymentHash.toHex)
statement.setString(3, p.fromChannelId.toHex)
statement.setLong(4, p.timestamp)
statement.addBatch()
})
statement.executeBatch()
}
}
override def add(e: PaymentRelayed): Unit =
inTransaction { pg =>
val payments = e match {
case ChannelPaymentRelayed(amountIn, amountOut, _, fromChannelId, toChannelId, ts) =>
// non-trampoline relayed payments have one input and one output
Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", ts), RelayedPart(toChannelId, amountOut, "OUT", "channel", ts))
case TrampolinePaymentRelayed(_, incoming, outgoing, ts) =>
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", ts)) ++ outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", ts))
}
for (p <- payments) {
using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, e.paymentHash.toHex)
statement.setLong(2, p.amount.toLong)
statement.setString(3, p.channelId.toHex)
statement.setString(4, p.direction)
statement.setString(5, p.relayType)
statement.setLong(6, e.timestamp)
statement.executeUpdate()
}
}
}
override def add(e: NetworkFeePaid): Unit =
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO network_fees VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, e.channelId.toHex)
statement.setString(2, e.remoteNodeId.value.toHex)
statement.setString(3, e.tx.txid.toHex)
statement.setLong(4, e.fee.toLong)
statement.setString(5, e.txType)
statement.setLong(6, System.currentTimeMillis)
statement.executeUpdate()
}
}
override def add(e: ChannelErrorOccurred): Unit =
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
val (errorName, errorMessage) = e.error match {
case LocalError(t) => (t.getClass.getSimpleName, t.getMessage)
case RemoteError(error) => ("remote", error.toAscii)
}
statement.setString(1, e.channelId.toHex)
statement.setString(2, e.remoteNodeId.value.toHex)
statement.setString(3, errorName)
statement.setString(4, errorMessage)
statement.setBoolean(5, e.isFatal)
statement.setLong(6, System.currentTimeMillis)
statement.executeUpdate()
}
}
override def listSent(from: Long, to: Long): Seq[PaymentSent] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var sentByParentId = Map.empty[UUID, PaymentSent]
while (rs.next()) {
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
val part = PaymentSent.PartialPayment(
UUID.fromString(rs.getString("payment_id")),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getLong("timestamp"))
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
parentId,
rs.getByteVector32FromHex("payment_hash"),
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
}
sentByParentId = sentByParentId + (parentId -> sent)
}
sentByParentId.values.toSeq.sortBy(_.timestamp)
}
}
override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getLong("timestamp"))
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
}
receivedByHash = receivedByHash + (paymentHash -> received)
}
receivedByHash.values.toSeq.sortBy(_.timestamp)
}
}
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = RelayedPart(
rs.getByteVector32FromHex("channel_id"),
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getLong("timestamp"))
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
relayedByHash.flatMap {
case (paymentHash, parts) =>
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
parts.headOption match {
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
}
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => TrampolinePaymentRelayed(paymentHash, incoming, outgoing, timestamp) :: Nil
case _ => Nil
}
}.toSeq.sortBy(_.timestamp)
}
}
override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var q: Queue[NetworkFee] = Queue()
while (rs.next()) {
q = q :+ NetworkFee(
remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")),
channelId = rs.getByteVector32FromHex("channel_id"),
txId = rs.getByteVector32FromHex("tx_id"),
fee = Satoshi(rs.getLong("fee_sat")),
txType = rs.getString("tx_type"),
timestamp = rs.getLong("timestamp"))
}
q
}
}
override def stats(from: Long, to: Long): Seq[Stats] = {
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) =>
feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee))
}
case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String)
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) =>
// NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones.
val current = e match {
case c: ChannelPaymentRelayed => Map(
c.fromChannelId -> (Relayed(c.amountIn, 0 msat, "IN") +: previous.getOrElse(c.fromChannelId, Nil)),
c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)),
)
case t: TrampolinePaymentRelayed =>
// We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were
// sent from/to the same channel, we group them).
val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq
val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) =>
val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels
(channelId, Relayed(parts.map(_.amount).sum, fee, "OUT"))
}.toSeq
(in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) }
}
previous ++ current
}
// Channels opened by our peers won't have any entry in the network_fees table, but we still want to compute stats for them.
val allChannels = networkFees.keySet ++ relayed.keySet
allChannels.toSeq.flatMap(channelId => {
val networkFee = networkFees.getOrElse(channelId, 0 sat)
val (in, out) = relayed.getOrElse(channelId, Nil).partition(_.direction == "IN")
((in, "IN") :: (out, "OUT") :: Nil).map { case (r, direction) =>
val paymentCount = r.length
if (paymentCount == 0) {
Stats(channelId, direction, 0 sat, 0, 0 sat, networkFee)
} else {
val avgPaymentAmount = r.map(_.amount).sum / paymentCount
val relayFee = r.map(_.fee).sum
Stats(channelId, direction, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee)
}
}
})
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.CltvExpiry
import fr.acinq.eclair.channel.HasCommitments
import fr.acinq.eclair.db.ChannelsDb
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec
import grizzled.slf4j.Logging
import javax.sql.DataSource
import scala.collection.immutable.Queue
class PgChannelsDb(implicit ds: DataSource, lock: DatabaseLock) extends ChannelsDb with Logging {
import PgUtils.ExtendedResultSet._
import PgUtils._
import lock._
val DB_NAME = "channels"
val CURRENT_VERSION = 2
inTransaction { pg =>
using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
case CURRENT_VERSION =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
}
}
override def addOrUpdateChannel(state: HasCommitments): Unit = {
withLock { pg =>
val data = stateDataCodec.encode(state).require.toByteArray
using(pg.prepareStatement("UPDATE local_channels SET data=? WHERE channel_id=?")) { update =>
update.setBytes(1, data)
update.setString(2, state.channelId.toHex)
if (update.executeUpdate() == 0) {
using(pg.prepareStatement("INSERT INTO local_channels VALUES (?, ?, FALSE)")) { statement =>
statement.setString(1, state.channelId.toHex)
statement.setBytes(2, data)
statement.executeUpdate()
}
}
}
}
}
override def removeChannel(channelId: ByteVector32): Unit = {
withLock { pg =>
using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.executeUpdate()
}
using(pg.prepareStatement("DELETE FROM htlc_infos WHERE channel_id=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.executeUpdate()
}
using(pg.prepareStatement("UPDATE local_channels SET is_closed=TRUE WHERE channel_id=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.executeUpdate()
}
}
}
override def listLocalChannels(): Seq[HasCommitments] = {
withLock { pg =>
using(pg.createStatement) { statement =>
val rs = statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=FALSE")
codecSequence(rs, stateDataCodec)
}
}
}
override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
withLock { pg =>
using(pg.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement =>
statement.setString(1, channelId.toHex)
statement.setLong(2, commitmentNumber)
statement.setString(3, paymentHash.toHex)
statement.setLong(4, cltvExpiry.toLong)
statement.executeUpdate()
}
}
}
override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
withLock { pg =>
using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.setString(2, commitmentNumber.toString)
val rs = statement.executeQuery
var q: Queue[(ByteVector32, CltvExpiry)] = Queue()
while (rs.next()) {
q = q :+ (ByteVector32(rs.getByteVector32FromHex("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry")))
}
q
}
}
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,183 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import fr.acinq.bitcoin.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import grizzled.slf4j.Logging
import javax.sql.DataSource
import scala.collection.immutable.SortedMap
class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging {
import PgUtils.ExtendedResultSet._
import PgUtils._
val DB_NAME = "network"
val CURRENT_VERSION = 2
inTransaction { pg =>
using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
case CURRENT_VERSION => () // nothing to do
case unknown => throw new IllegalArgumentException(s"unknown version $unknown for network db")
}
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id BIGINT NOT NULL PRIMARY KEY, txid TEXT NOT NULL, channel_announcement BYTEA NOT NULL, capacity_sat BIGINT NOT NULL, channel_update_1 BYTEA NULL, channel_update_2 BYTEA NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id BIGINT NOT NULL PRIMARY KEY)")
}
}
override def addNode(n: NodeAnnouncement): Unit = {
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO nodes VALUES (?, ?) ON CONFLICT DO NOTHING")) { statement =>
statement.setString(1, n.nodeId.value.toHex)
statement.setBytes(2, nodeAnnouncementCodec.encode(n).require.toByteArray)
statement.executeUpdate()
}
}
}
override def updateNode(n: NodeAnnouncement): Unit = {
inTransaction { pg =>
using(pg.prepareStatement("UPDATE nodes SET data=? WHERE node_id=?")) { statement =>
statement.setBytes(1, nodeAnnouncementCodec.encode(n).require.toByteArray)
statement.setString(2, n.nodeId.value.toHex)
statement.executeUpdate()
}
}
}
override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = {
inTransaction { pg =>
using(pg.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement =>
statement.setString(1, nodeId.value.toHex)
val rs = statement.executeQuery()
codecSequence(rs, nodeAnnouncementCodec).headOption
}
}
}
override def removeNode(nodeId: Crypto.PublicKey): Unit = {
inTransaction { pg =>
using(pg.prepareStatement("DELETE FROM nodes WHERE node_id=?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
}
}
override def listNodes(): Seq[NodeAnnouncement] = {
inTransaction { pg =>
using(pg.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT data FROM nodes")
codecSequence(rs, nodeAnnouncementCodec)
}
}
}
override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = {
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO channels VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING")) { statement =>
statement.setLong(1, c.shortChannelId.toLong)
statement.setString(2, txid.toHex)
statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray)
statement.setLong(4, capacity.toLong)
statement.executeUpdate()
}
}
}
override def updateChannel(u: ChannelUpdate): Unit = {
val column = if (u.isNode1) "channel_update_1" else "channel_update_2"
inTransaction { pg =>
using(pg.prepareStatement(s"UPDATE channels SET $column=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray)
statement.setLong(2, u.shortChannelId.toLong)
statement.executeUpdate()
}
}
}
override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = {
inTransaction { pg =>
using(pg.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels")
var m = SortedMap.empty[ShortChannelId, PublicChannel]
while (rs.next()) {
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
}
m
}
}
}
override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = {
inTransaction { pg =>
using(pg.createStatement) { statement =>
shortChannelIds
.grouped(1000) // remove channels by batch of 1000
.foreach { _ =>
val ids = shortChannelIds.map(_.toLong).mkString(",")
statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)")
}
}
}
}
override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = {
inTransaction { pg =>
using(pg.prepareStatement("INSERT INTO pruned VALUES (?) ON CONFLICT DO NOTHING")) { statement =>
shortChannelIds.foreach(shortChannelId => {
statement.setLong(1, shortChannelId.toLong)
statement.addBatch()
})
statement.executeBatch()
}
}
}
override def removeFromPruned(shortChannelId: ShortChannelId): Unit = {
inTransaction { pg =>
using(pg.createStatement) { statement =>
statement.executeUpdate(s"DELETE FROM pruned WHERE short_channel_id=${shortChannelId.toLong}")
}
}
}
override def isPruned(shortChannelId: ShortChannelId): Boolean = {
inTransaction { pg =>
using(pg.prepareStatement("SELECT short_channel_id from pruned WHERE short_channel_id=?")) { statement =>
statement.setLong(1, shortChannelId.toLong)
val rs = statement.executeQuery()
rs.next()
}
}
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,406 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import java.sql.ResultSet
import java.util.UUID
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.MilliSatoshi
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
import fr.acinq.eclair.db.{HopSummary, _}
import fr.acinq.eclair.payment.{PaymentFailed, PaymentRequest, PaymentSent}
import fr.acinq.eclair.wire.CommonCodecs
import grizzled.slf4j.Logging
import javax.sql.DataSource
import scodec.Attempt
import scodec.bits.BitVector
import scodec.codecs._
import scala.collection.immutable.Queue
import scala.concurrent.duration._
class PgPaymentsDb(implicit ds: DataSource, lock: DatabaseLock) extends PaymentsDb with Logging {
import PgUtils.ExtendedResultSet._
import PgUtils._
import lock._
val DB_NAME = "payments"
val CURRENT_VERSION = 4
private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary]
private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte)
.typecase(0x01, listOfN(uint8, hopSummaryCodec))
private val failureSummaryCodec = (("type" | enumerated(uint8, FailureType)) :: ("message" | ascii32) :: paymentRouteCodec).as[FailureSummary]
private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte)
.typecase(0x01, listOfN(uint8, failureSummaryCodec))
inTransaction { pg =>
using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
case CURRENT_VERSION =>
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
}
}
override def addOutgoingPayment(sent: OutgoingPayment): Unit = {
require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})")
withLock { pg =>
using(pg.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, payment_type, amount_msat, recipient_amount_msat, recipient_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, sent.id.toString)
statement.setString(2, sent.parentId.toString)
statement.setString(3, sent.externalId.orNull)
statement.setString(4, sent.paymentHash.toHex)
statement.setString(5, sent.paymentType)
statement.setLong(6, sent.amount.toLong)
statement.setLong(7, sent.recipientAmount.toLong)
statement.setString(8, sent.recipientNodeId.value.toHex)
statement.setLong(9, sent.createdAt)
statement.setString(10, sent.paymentRequest.map(PaymentRequest.write).orNull)
statement.executeUpdate()
}
}
}
override def updateOutgoingPayment(paymentResult: PaymentSent): Unit =
withLock { pg =>
using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement =>
paymentResult.parts.foreach(p => {
statement.setLong(1, p.timestamp)
statement.setString(2, paymentResult.paymentPreimage.toHex)
statement.setLong(3, p.feesPaid.toLong)
statement.setBytes(4, paymentRouteCodec.encode(p.route.getOrElse(Nil).map(h => HopSummary(h)).toList).require.toByteArray)
statement.setString(5, p.id.toString)
statement.addBatch()
})
if (statement.executeBatch().contains(0)) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})")
}
}
override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit =
withLock { pg =>
using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement =>
statement.setLong(1, paymentResult.timestamp)
statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray)
statement.setString(3, paymentResult.id.toString)
if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})")
}
}
private def parseOutgoingPayment(rs: ResultSet): OutgoingPayment = {
val status = buildOutgoingPaymentStatus(
rs.getByteVector32FromHexNullable("payment_preimage"),
rs.getMilliSatoshiNullable("fees_msat"),
rs.getBitVectorOpt("payment_route"),
rs.getLongNullable("completed_at"),
rs.getBitVectorOpt("failures"))
OutgoingPayment(
UUID.fromString(rs.getString("id")),
UUID.fromString(rs.getString("parent_id")),
rs.getStringNullable("external_id"),
rs.getByteVector32FromHex("payment_hash"),
rs.getString("payment_type"),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
rs.getLong("created_at"),
rs.getStringNullable("payment_request").map(PaymentRequest.read),
status
)
}
private def buildOutgoingPaymentStatus(preimage_opt: Option[ByteVector32], fees_opt: Option[MilliSatoshi], paymentRoute_opt: Option[BitVector], completedAt_opt: Option[Long], failures: Option[BitVector]): OutgoingPaymentStatus = {
preimage_opt match {
// If we have a pre-image, the payment succeeded.
case Some(preimage) => OutgoingPaymentStatus.Succeeded(
preimage, fees_opt.getOrElse(MilliSatoshi(0)), paymentRoute_opt.map(b => paymentRouteCodec.decode(b) match {
case Attempt.Successful(route) => route.value
case Attempt.Failure(_) => Nil
}).getOrElse(Nil),
completedAt_opt.getOrElse(0)
)
case None => completedAt_opt match {
// Otherwise if the payment was marked completed, it's a failure.
case Some(completedAt) => OutgoingPaymentStatus.Failed(
failures.map(b => paymentFailuresCodec.decode(b) match {
case Attempt.Successful(f) => f.value
case Attempt.Failure(_) => Nil
}).getOrElse(Nil),
completedAt
)
// Else it's still pending.
case _ => OutgoingPaymentStatus.Pending
}
}
}
override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement =>
statement.setString(1, id.toString)
val rs = statement.executeQuery()
if (rs.next()) {
Some(parseOutgoingPayment(rs))
} else {
None
}
}
}
override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement =>
statement.setString(1, parentId.toString)
val rs = statement.executeQuery()
var q: Queue[OutgoingPayment] = Queue()
while (rs.next()) {
q = q :+ parseOutgoingPayment(rs)
}
q
}
}
override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement =>
statement.setString(1, paymentHash.toHex)
val rs = statement.executeQuery()
var q: Queue[OutgoingPayment] = Queue()
while (rs.next()) {
q = q :+ parseOutgoingPayment(rs)
}
q
}
}
override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var q: Queue[OutgoingPayment] = Queue()
while (rs.next()) {
q = q :+ parseOutgoingPayment(rs)
}
q
}
}
override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String): Unit =
withLock { pg =>
using(pg.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, pr.paymentHash.toHex)
statement.setString(2, preimage.toHex)
statement.setString(3, paymentType)
statement.setString(4, PaymentRequest.write(pr))
statement.setLong(5, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds
statement.setLong(6, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis)
statement.executeUpdate()
}
}
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit =
withLock { pg =>
using(pg.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
update.setLong(1, amount.toLong)
update.setLong(2, receivedAt)
update.setString(3, paymentHash.toHex)
val updated = update.executeUpdate()
if (updated == 0) {
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
}
}
}
private def parseIncomingPayment(rs: ResultSet): IncomingPayment = {
val paymentRequest = rs.getString("payment_request")
IncomingPayment(
PaymentRequest.read(paymentRequest),
rs.getByteVector32FromHex("payment_preimage"),
rs.getString("payment_type"),
rs.getLong("created_at"),
buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), Some(paymentRequest), rs.getLongNullable("received_at")))
}
private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], serializedPaymentRequest_opt: Option[String], receivedAt_opt: Option[Long]): IncomingPaymentStatus = {
amount_opt match {
case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0))
case None if serializedPaymentRequest_opt.exists(PaymentRequest.fastHasExpired) => IncomingPaymentStatus.Expired
case None => IncomingPaymentStatus.Pending
}
}
override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
val rs = statement.executeQuery()
if (rs.next()) {
Some(parseIncomingPayment(rs))
} else {
None
}
}
}
override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var q: Queue[IncomingPayment] = Queue()
while (rs.next()) {
q = q :+ parseIncomingPayment(rs)
}
q
}
}
override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
val rs = statement.executeQuery()
var q: Queue[IncomingPayment] = Queue()
while (rs.next()) {
q = q :+ parseIncomingPayment(rs)
}
q
}
}
override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
statement.setLong(3, System.currentTimeMillis)
val rs = statement.executeQuery()
var q: Queue[IncomingPayment] = Queue()
while (rs.next()) {
q = q :+ parseIncomingPayment(rs)
}
q
}
}
override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] =
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
statement.setLong(3, System.currentTimeMillis)
val rs = statement.executeQuery()
var q: Queue[IncomingPayment] = Queue()
while (rs.next()) {
q = q :+ parseIncomingPayment(rs)
}
q
}
}
override def listPaymentsOverview(limit: Int): Seq[PlainPayment] = {
// This query is an UNION of the ``sent_payments`` and ``received_payments`` table
// - missing fields set to NULL when needed.
// - only retrieve incoming payments that did receive funds.
// - outgoing payments are grouped by parent_id.
// - order by completion date (or creation date if nothing else).
withLock { pg =>
using(pg.prepareStatement(
"""
|SELECT * FROM (
| SELECT 'received' as type,
| NULL as parent_id,
| NULL as external_id,
| payment_hash,
| payment_preimage,
| payment_type,
| received_msat as final_amount,
| payment_request,
| created_at,
| received_at as completed_at,
| expire_at,
| NULL as order_trick
| FROM received_payments
| WHERE received_msat > 0
|UNION ALL
| SELECT 'sent' as type,
| parent_id,
| external_id,
| payment_hash,
| payment_preimage,
| payment_type,
| sum(amount_msat + fees_msat) as final_amount,
| payment_request,
| created_at,
| completed_at,
| NULL as expire_at,
| MAX(coalesce(completed_at, created_at)) as order_trick
| FROM sent_payments
| GROUP BY parent_id,external_id,payment_hash,payment_preimage,payment_type,payment_request,created_at,completed_at
|) q
|ORDER BY coalesce(q.completed_at, q.created_at) DESC
|LIMIT ?
""".stripMargin
)) { statement =>
statement.setInt(1, limit)
val rs = statement.executeQuery()
var q: Queue[PlainPayment] = Queue()
while (rs.next()) {
val parentId = rs.getUUIDNullable("parent_id")
val externalId_opt = rs.getStringNullable("external_id")
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val paymentType = rs.getString("payment_type")
val paymentRequest_opt = rs.getStringNullable("payment_request")
val amount_opt = rs.getMilliSatoshiNullable("final_amount")
val createdAt = rs.getLong("created_at")
val completedAt_opt = rs.getLongNullable("completed_at")
val expireAt_opt = rs.getLongNullable("expire_at")
val p = if (rs.getString("type") == "received") {
val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt)
PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt)
} else {
val preimage_opt = rs.getByteVector32Nullable("payment_preimage")
// note that the resulting status will not contain any details (routes, failures...)
val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None)
PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt)
}
q = q :+ p
}
q
}
}
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,95 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import fr.acinq.bitcoin.Crypto
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.db.PeersDb
import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock
import fr.acinq.eclair.wire._
import javax.sql.DataSource
import scodec.bits.BitVector
class PgPeersDb(implicit ds: DataSource, lock: DatabaseLock) extends PeersDb {
import PgUtils.ExtendedResultSet._
import PgUtils._
import lock._
val DB_NAME = "peers"
val CURRENT_VERSION = 1
inTransaction { pg =>
using(pg.createStatement()) { statement =>
require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
}
}
override def addOrUpdatePeer(nodeId: Crypto.PublicKey, nodeaddress: NodeAddress): Unit = {
withLock { pg =>
val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray
using(pg.prepareStatement("UPDATE peers SET data=? WHERE node_id=?")) { update =>
update.setBytes(1, data)
update.setString(2, nodeId.value.toHex)
if (update.executeUpdate() == 0) {
using(pg.prepareStatement("INSERT INTO peers VALUES (?, ?)")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.setBytes(2, data)
statement.executeUpdate()
}
}
}
}
}
override def removePeer(nodeId: Crypto.PublicKey): Unit = {
withLock { pg =>
using(pg.prepareStatement("DELETE FROM peers WHERE node_id=?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
}
}
override def getPeer(nodeId: PublicKey): Option[NodeAddress] = {
withLock { pg =>
using(pg.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement =>
statement.setString(1, nodeId.value.toHex)
val rs = statement.executeQuery()
codecSequence(rs, CommonCodecs.nodeaddress).headOption
}
}
}
override def listPeers(): Map[PublicKey, NodeAddress] = {
withLock { pg =>
using(pg.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT node_id, data FROM peers")
var m: Map[PublicKey, NodeAddress] = Map()
while (rs.next()) {
val nodeid = PublicKey(rs.getByteVectorFromHex("node_id"))
val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value
m += (nodeid -> nodeaddress)
}
m
}
}
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,91 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.{Command, HasHtlcId}
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.db.pg.PgUtils._
import fr.acinq.eclair.wire.CommandCodecs.cmdCodec
import javax.sql.DataSource
import scala.collection.immutable.Queue
class PgPendingRelayDb(implicit ds: DataSource, lock: DatabaseLock) extends PendingRelayDb {
import PgUtils.ExtendedResultSet._
import PgUtils._
import lock._
val DB_NAME = "pending_relay"
val CURRENT_VERSION = 1
inTransaction { pg =>
using(pg.createStatement()) { statement =>
require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
// note: should we use a foreign key to local_channels table here?
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))")
}
}
override def addPendingRelay(channelId: ByteVector32, cmd: Command with HasHtlcId): Unit = {
withLock { pg =>
using(pg.prepareStatement("INSERT INTO pending_relay VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement =>
statement.setString(1, channelId.toHex)
statement.setLong(2, cmd.id)
statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray)
statement.executeUpdate()
}
}
}
override def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit = {
withLock { pg =>
using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=? AND htlc_id=?")) { statement =>
statement.setString(1, channelId.toHex)
statement.setLong(2, htlcId)
statement.executeUpdate()
}
}
}
override def listPendingRelay(channelId: ByteVector32): Seq[Command with HasHtlcId] = {
withLock { pg =>
using(pg.prepareStatement("SELECT htlc_id, data FROM pending_relay WHERE channel_id=?")) { statement =>
statement.setString(1, channelId.toHex)
val rs = statement.executeQuery()
codecSequence(rs, cmdCodec)
}
}
}
override def listPendingRelay(): Set[(ByteVector32, Long)] = {
withLock { pg =>
using(pg.prepareStatement("SELECT channel_id, htlc_id FROM pending_relay")) { statement =>
val rs = statement.executeQuery()
var q: Queue[(ByteVector32, Long)] = Queue()
while (rs.next()) {
q = q :+ (rs.getByteVector32FromHex("channel_id"), rs.getLong("htlc_id"))
}
q.toSet
}
}
}
override def close(): Unit = ()
}

View File

@ -0,0 +1,247 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.pg
import java.sql.{Connection, Statement, Timestamp}
import java.util.UUID
import fr.acinq.eclair.db.jdbc.JdbcUtils
import grizzled.slf4j.Logging
import javax.sql.DataSource
import org.postgresql.util.{PGInterval, PSQLException}
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}
object PgUtils extends JdbcUtils with Logging {
val LeaseTable = "lease"
val LockTimeout = 5 seconds
val TransactionIsolationLevel = Connection.TRANSACTION_SERIALIZABLE
object LockType extends Enumeration {
type LockType = Value
val NONE, LEASE = Value
def apply(s: String): LockType = s match {
case "none" => NONE
case "lease" => LEASE
case _ => throw new RuntimeException(s"Unknown postgres lock type: `$s`")
}
}
case class LockLease(expiresAt: Timestamp, instanceId: UUID, expired: Boolean)
// @formatter:off
class TooManyLockAttempts(msg: String) extends RuntimeException(msg)
class UninitializedLockTable(msg: String) extends RuntimeException(msg)
class LockException(msg: String, cause: Option[Throwable] = None) extends RuntimeException(msg, cause.orNull)
class LeaseException(msg: String) extends RuntimeException(msg)
// @formatter:on
type LockExceptionHandler = LockException => Unit
sealed trait DatabaseLock {
def obtainExclusiveLock(implicit ds: DataSource): Unit
def withLock[T](f: Connection => T)(implicit ds: DataSource): T
}
case object NoLock extends DatabaseLock {
override def obtainExclusiveLock(implicit ds: DataSource): Unit = ()
override def withLock[T](f: Connection => T)(implicit ds: DataSource): T =
inTransaction(f)
}
/**
* This class represents a lease based locking mechanism [[https://en.wikipedia.org/wiki/Lease_(computer_science]].
* It allows only one process to access the database at a time.
*
* `obtainExclusiveLock` method updates the record in `lease` table with the instance id and the expiration date
* calculated as the current time plus the lease duration. If the current lease is not expired or it belongs to
* another instance `obtainExclusiveLock` throws an exception.
*
* withLock method executes its `f` function and reads the record from lease table to checks if this instance still
* holds the lease and it's not expired. If so, the database transaction gets committed, otherwise en exception is thrown.
*
* `lockExceptionHandler` provides a lock exception handler to customize the behavior when locking errors occur.
*/
case class LeaseLock(instanceId: UUID, leaseDuration: FiniteDuration, lockExceptionHandler: LockExceptionHandler) extends DatabaseLock {
override def obtainExclusiveLock(implicit ds: DataSource): Unit =
obtainDatabaseLease(instanceId, leaseDuration)
override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = {
inTransaction { connection =>
val res = f(connection)
checkDatabaseLease(connection, instanceId, lockExceptionHandler)
res
}
}
}
def inTransaction[T](connection: Connection)(f: Connection => T): T = {
val autoCommit = connection.getAutoCommit
connection.setAutoCommit(false)
val isolationLevel = connection.getTransactionIsolation
connection.setTransactionIsolation(TransactionIsolationLevel)
try {
val res = f(connection)
connection.commit()
res
} catch {
case ex: Throwable =>
connection.rollback()
throw ex
} finally {
connection.setAutoCommit(autoCommit)
connection.setTransactionIsolation(isolationLevel)
}
}
def inTransaction[T](f: Connection => T)(implicit dataSource: DataSource): T = {
withConnection { connection =>
inTransaction(connection)(f)
}
}
/**
* Several logical databases (channels, network, peers) may be stored in the same physical postgres database.
* We keep track of their respective version using a dedicated table. The version entry will be created if
* there is none but will never be updated here (use setVersion to do that).
*/
def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = {
statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
// if there was no version for the current db, then insert the current version
statement.executeUpdate(s"INSERT INTO versions VALUES ('$db_name', $currentVersion) ON CONFLICT DO NOTHING")
// if there was a previous version installed, this will return a different value from current version
val res = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
res.next()
res.getInt("version")
}
/**
* Updates the version for a particular logical database, it will overwrite the previous version.
*/
def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = {
statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
// overwrite the existing version
statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'")
}
private def obtainDatabaseLease(instanceId: UUID, leaseDuration: FiniteDuration, attempt: Int = 1)(implicit ds: DataSource): Unit = synchronized {
logger.debug(s"trying to acquire database lease (attempt #$attempt) instance ID=${instanceId}")
if (attempt > 3) throw new TooManyLockAttempts("Too many attempts to acquire database lease")
try {
inTransaction { implicit connection =>
acquireExclusiveTableLock()
getCurrentLease match {
case Some(lease) =>
if (lease.instanceId == instanceId || lease.expired)
updateLease(instanceId, leaseDuration)
else
throw new LeaseException(s"The database is locked by instance ID=${lease.instanceId}")
case None =>
updateLease(instanceId, leaseDuration, insertNew = true)
}
}
logger.debug("database lease was successfully acquired")
} catch {
case e: PSQLException if (e.getServerErrorMessage != null && e.getServerErrorMessage.getSQLState == "42P01") =>
withConnection {
connection =>
logger.warn(s"table $LeaseTable does not exist, trying to recreate it")
initializeLeaseTable(connection)
obtainDatabaseLease(instanceId, leaseDuration, attempt + 1)
}
}
}
private def initializeLeaseTable(implicit connection: Connection): Unit = {
using(connection.createStatement()) {
statement =>
// allow only one row in the ownership lease table
statement.executeUpdate(s"CREATE TABLE IF NOT EXISTS $LeaseTable (id INTEGER PRIMARY KEY default(1), expires_at TIMESTAMP NOT NULL, instance VARCHAR NOT NULL, CONSTRAINT one_row CHECK (id = 1))")
}
}
private def acquireExclusiveTableLock()(implicit connection: Connection): Unit = {
using(connection.createStatement()) {
statement =>
statement.executeUpdate(s"SET lock_timeout TO '${LockTimeout.toSeconds}s'")
statement.executeUpdate(s"LOCK TABLE $LeaseTable IN ACCESS EXCLUSIVE MODE")
}
}
private def checkDatabaseLease(connection: Connection, instanceId: UUID, lockExceptionHandler: LockExceptionHandler): Unit = {
Try {
getCurrentLease(connection) match {
case Some(lease) =>
if (!(lease.instanceId == instanceId) || lease.expired) {
logger.info(s"database lease: $lease")
throw new LockException("This Eclair instance is not a database owner")
}
case None =>
throw new LockException("No database lease info")
}
} match {
case Success(_) => ()
case Failure(ex) =>
val lex = ex match {
case e: LockException => e
case t: Throwable => new LockException("Cannot check database lease", Some(t))
}
lockExceptionHandler(lex)
throw lex
}
}
private def getCurrentLease(implicit connection: Connection): Option[LockLease] = {
using(connection.createStatement()) {
statement =>
val rs = statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1")
if (rs.next())
Some(LockLease(
expiresAt = rs.getTimestamp("expires_at"),
instanceId = UUID.fromString(rs.getString("instance")),
expired = rs.getBoolean("expired")))
else
None
}
}
private def updateLease(instanceId: UUID, leaseDuration: FiniteDuration, insertNew: Boolean = false)(implicit connection: Connection): Unit = {
val sql = if (insertNew)
s"INSERT INTO $LeaseTable (expires_at, instance) VALUES (now() + ?, ?)"
else
s"UPDATE $LeaseTable SET expires_at = now() + ?, instance = ? WHERE id = 1"
using(connection.prepareStatement(sql)) {
statement =>
statement.setObject(1, new PGInterval(s"${
leaseDuration.toSeconds
} seconds"))
statement.setString(2, instanceId.toString)
statement.executeUpdate()
}
}
}

View File

@ -101,7 +101,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
}
}
def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
using(sqlite.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement =>
statement.setBytes(1, channelId.toArray)
statement.setLong(2, commitmentNumber)
@ -111,7 +111,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
}
}
def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
using(sqlite.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
statement.setBytes(1, channelId.toArray)
statement.setLong(2, commitmentNumber)

View File

@ -16,38 +16,11 @@
package fr.acinq.eclair.db.sqlite
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import java.sql.{Connection, Statement}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.MilliSatoshi
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import fr.acinq.eclair.db.jdbc.JdbcUtils
import scala.collection.immutable.Queue
object SqliteUtils {
/**
* This helper makes sure statements are correctly closed.
*
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
*/
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
try {
if (inTransaction) statement.getConnection.setAutoCommit(false)
val res = block(statement)
if (inTransaction) statement.getConnection.commit()
res
} catch {
case t: Exception =>
if (inTransaction) statement.getConnection.rollback()
throw t
} finally {
if (inTransaction) statement.getConnection.setAutoCommit(true)
if (statement != null) statement.close()
}
}
object SqliteUtils extends JdbcUtils {
/**
* Several logical databases (channels, network, peers) may be stored in the same physical sqlite database.
@ -72,19 +45,6 @@ object SqliteUtils {
statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'")
}
/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
}
q
}
/**
* Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process
* accesses the database file (see https://www.sqlite.org/pragma.html).
@ -99,48 +59,4 @@ object SqliteUtils {
statement.executeUpdate("INSERT INTO dummy_table_for_locking VALUES (42)")
}
case class ExtendedResultSet(rs: ResultSet) {
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
def getByteVectorNullable(columnLabel: String): ByteVector = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
}
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
val bytes = rs.getBytes(columnLabel)
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
}
def getStringNullable(columnLabel: String): Option[String] = {
val result = rs.getString(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getLongNullable(columnLabel: String): Option[Long] = {
val result = rs.getLong(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getUUIDNullable(label: String): Option[UUID] = {
val result = rs.getString(label)
if (rs.wasNull()) None else Some(UUID.fromString(result))
}
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
val result = rs.getLong(label)
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}
}
object ExtendedResultSet {
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
}
}

View File

@ -47,7 +47,7 @@
<logger name="fr.acinq.eclair.channel" level="WARN"/>
<logger name="fr.acinq.eclair.Diagnostics" level="OFF"/>
<logger name="fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher" level="OFF"/>
<logger name="fr.acinq.eclair.db.BackupHandler" level="OFF"/>
<logger name="fr.acinq.eclair.db.FileBackupHandler" level="OFF"/>
<root level="INFO">
<!--appender-ref ref="FILE"/>

View File

@ -16,6 +16,7 @@
package fr.acinq.eclair
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import com.typesafe.config.{Config, ConfigFactory}
@ -39,7 +40,7 @@ class StartupSpec extends AnyFunSuite {
val keyManager = new LocalKeyManager(seed = randomBytes32, chainHash = Block.TestnetGenesisBlock.hash)
val feeEstimator = new TestConstants.TestFeeEstimator
val db = TestConstants.inMemoryDb()
NodeParams.makeNodeParams(conf, keyManager, None, db, blockCount, feeEstimator)
NodeParams.makeNodeParams(conf, UUID.fromString("01234567-0123-4567-89ab-0123456789ab"), keyManager, None, db, blockCount, feeEstimator)
}
test("check configuration") {

View File

@ -16,9 +16,11 @@
package fr.acinq.eclair
import java.sql.{Connection, DriverManager}
import java.sql.{Connection, DriverManager, Statement}
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import com.opentable.db.postgres.embedded.EmbeddedPostgres
import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{Block, ByteVector32, Script}
import fr.acinq.eclair.FeatureSupport.Optional
@ -27,6 +29,9 @@ import fr.acinq.eclair.NodeParams.BITCOIND
import fr.acinq.eclair.blockchain.fee._
import fr.acinq.eclair.crypto.LocalKeyManager
import fr.acinq.eclair.db._
import fr.acinq.eclair.db.pg.PgUtils.NoLock
import fr.acinq.eclair.db.pg._
import fr.acinq.eclair.db.sqlite._
import fr.acinq.eclair.io.Peer
import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
@ -57,9 +62,62 @@ object TestConstants {
}
}
def sqliteInMemory() = DriverManager.getConnection("jdbc:sqlite::memory:")
sealed trait TestDatabases {
val connection: Connection
def network(): NetworkDb
def audit(): AuditDb
def channels(): ChannelsDb
def peers(): PeersDb
def payments(): PaymentsDb
def pendingRelay(): PendingRelayDb
def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int
def close(): Unit
}
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.databaseByConnections(connection, connection, connection)
case class TestSqliteDatabases(connection: Connection = sqliteInMemory()) extends TestDatabases {
override def network(): NetworkDb = new SqliteNetworkDb(connection)
override def audit(): AuditDb = new SqliteAuditDb(connection)
override def channels(): ChannelsDb = new SqliteChannelsDb(connection)
override def peers(): PeersDb = new SqlitePeersDb(connection)
override def payments(): PaymentsDb = new SqlitePaymentsDb(connection)
override def pendingRelay(): PendingRelayDb = new SqlitePendingRelayDb(connection)
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion)
override def close(): Unit = ()
}
case class TestPgDatabases() extends TestDatabases {
private val pg = EmbeddedPostgres.start()
override val connection: Connection = pg.getPostgresDatabase.getConnection
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
val config = new HikariConfig
config.setDataSource(pg.getPostgresDatabase)
implicit val ds = new HikariDataSource(config)
implicit val lock = NoLock
override def network(): NetworkDb = new PgNetworkDb
override def audit(): AuditDb = new PgAuditDb
override def channels(): ChannelsDb = new PgChannelsDb
override def peers(): PeersDb = new PgPeersDb
override def payments(): PaymentsDb = new PgPaymentsDb
override def pendingRelay(): PendingRelayDb = new PgPendingRelayDb
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = PgUtils.getVersion(statement, db_name, currentVersion)
override def close(): Unit = pg.close()
}
def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:")
def forAllDbs(f: TestDatabases => Unit): Unit = {
def using(dbs: TestDatabases)(g: TestDatabases => Unit): Unit = try g(dbs) finally dbs.close()
using(TestSqliteDatabases())(f)
using(TestPgDatabases())(f)
}
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.sqliteDatabaseByConnections(connection, connection, connection)
object Alice {
val seed = ByteVector32(ByteVector.fill(32)(1))
@ -139,7 +197,8 @@ object TestConstants {
),
socksProxy_opt = None,
maxPaymentAttempts = 5,
enableTrampolinePayment = true
enableTrampolinePayment = true,
instanceId = UUID.fromString("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
)
def channelParams = Peer.makeChannelParams(
@ -225,7 +284,8 @@ object TestConstants {
),
socksProxy_opt = None,
maxPaymentAttempts = 5,
enableTrampolinePayment = true
enableTrampolinePayment = true,
instanceId = UUID.fromString("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
)
def channelParams = Peer.makeChannelParams(

View File

@ -20,15 +20,15 @@ import java.io.File
import java.sql.DriverManager
import java.util.UUID
import akka.actor.ActorSystem
import akka.testkit.{TestKit, TestProbe}
import akka.testkit.TestProbe
import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
import fr.acinq.eclair.wire.ChannelCodecsSpec
import fr.acinq.eclair.{TestConstants, TestKitBaseClass, TestUtils, randomBytes32}
import org.scalatest.funsuite.AnyFunSuiteLike
class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
class FileBackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
test("process backups") {
val db = TestConstants.inMemoryDb()
@ -40,7 +40,7 @@ class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
db.channels.addOrUpdateChannel(channel)
assert(db.channels.listLocalChannels() == Seq(channel))
val handler = system.actorOf(BackupHandler.props(db, dest, None))
val handler = system.actorOf(FileBackupHandler.props(db.asInstanceOf[FileBackup], dest, None))
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[BackupEvent])

View File

@ -20,10 +20,11 @@ import java.util.UUID
import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{ByteVector32, Transaction}
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair._
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
import fr.acinq.eclair.db.sqlite.SqliteAuditDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
import fr.acinq.eclair.payment._
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite
@ -36,372 +37,394 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")
test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqliteAuditDb(sqlite)
val db2 = new SqliteAuditDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.audit()
val db2 = dbs.audit()
}
}
test("add/list events") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteAuditDb(sqlite)
forAllDbs { dbs =>
val db = dbs.audit()
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32)
val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil)
val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32)
val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual")
val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0)
val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1)
val e5 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84100 msat, randomKey.publicKey, pp5a :: pp5b :: Nil)
val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (System.currentTimeMillis.milliseconds + 10.minutes).toMillis)
val e6 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, randomKey.publicKey, pp6 :: Nil)
val e7 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual")
val e8 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
val e10 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(20000 msat, randomBytes32), PaymentRelayed.Part(22000 msat, randomBytes32)), Seq(PaymentRelayed.Part(10000 msat, randomBytes32), PaymentRelayed.Part(12000 msat, randomBytes32), PaymentRelayed.Part(15000 msat, randomBytes32)))
val multiPartPaymentHash = randomBytes32
val now = System.currentTimeMillis
val e11 = ChannelPaymentRelayed(13000 msat, 11000 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
val e12 = ChannelPaymentRelayed(15000 msat, 12500 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32)
val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil)
val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32)
val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual")
val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0)
val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1)
val e5 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84100 msat, randomKey.publicKey, pp5a :: pp5b :: Nil)
val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (System.currentTimeMillis.milliseconds + 10.minutes).toMillis)
val e6 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, randomKey.publicKey, pp6 :: Nil)
val e7 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual")
val e8 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
val e10 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(20000 msat, randomBytes32), PaymentRelayed.Part(22000 msat, randomBytes32)), Seq(PaymentRelayed.Part(10000 msat, randomBytes32), PaymentRelayed.Part(12000 msat, randomBytes32), PaymentRelayed.Part(15000 msat, randomBytes32)))
val multiPartPaymentHash = randomBytes32
val now = System.currentTimeMillis
val e11 = ChannelPaymentRelayed(13000 msat, 11000 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
val e12 = ChannelPaymentRelayed(15000 msat, 12500 msat, multiPartPaymentHash, randomBytes32, randomBytes32, now)
db.add(e1)
db.add(e2)
db.add(e3)
db.add(e4)
db.add(e5)
db.add(e6)
db.add(e7)
db.add(e8)
db.add(e9)
db.add(e10)
db.add(e11)
db.add(e12)
db.add(e1)
db.add(e2)
db.add(e3)
db.add(e4)
db.add(e5)
db.add(e6)
db.add(e7)
db.add(e8)
db.add(e9)
db.add(e10)
db.add(e11)
db.add(e12)
assert(db.listSent(from = 0L, to = (System.currentTimeMillis.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5, e6))
assert(db.listSent(from = 100000L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e1))
assert(db.listReceived(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e2))
assert(db.listRelayed(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e3, e10, e11, e12))
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1)
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual")
assert(db.listSent(from = 0L, to = (System.currentTimeMillis.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5, e6))
assert(db.listSent(from = 100000L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e1))
assert(db.listReceived(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e2))
assert(db.listRelayed(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).toList === List(e3, e10, e11, e12))
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1)
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual")
}
}
test("stats") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteAuditDb(sqlite)
forAllDbs { dbs =>
val db = dbs.audit()
val n2 = randomKey.publicKey
val n3 = randomKey.publicKey
val n4 = randomKey.publicKey
val n2 = randomKey.publicKey
val n3 = randomKey.publicKey
val n4 = randomKey.publicKey
val c1 = randomBytes32
val c2 = randomBytes32
val c3 = randomBytes32
val c4 = randomBytes32
val c5 = randomBytes32
val c6 = randomBytes32
val c1 = randomBytes32
val c2 = randomBytes32
val c3 = randomBytes32
val c4 = randomBytes32
val c5 = randomBytes32
val c6 = randomBytes32
db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1))
db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1))
db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1))
db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2))
db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6))
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4))))
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4))))
db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1))
db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1))
db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1))
db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2))
db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6))
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4))))
db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4))))
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding"))
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual"))
db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding"))
db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding"))
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding"))
db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual"))
db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding"))
db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding"))
// NB: we only count a relay fee for the outgoing channel, no the incoming one.
assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set(
Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat),
Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat),
Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat),
Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat),
))
// NB: we only count a relay fee for the outgoing channel, no the incoming one.
assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set(
Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat),
Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat),
Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat),
Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat),
Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat),
Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat),
Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat),
))
}
}
ignore("relay stats performance", Tag("perf")) {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteAuditDb(sqlite)
val nodeCount = 100
val channelCount = 1000
val eventCount = 100000
val nodeIds = (1 to nodeCount).map(_ => randomKey.publicKey)
val channelIds = (1 to channelCount).map(_ => randomBytes32)
// Fund channels.
channelIds.foreach(channelId => {
val nodeId = nodeIds(Random.nextInt(nodeCount))
db.add(NetworkFeePaid(null, nodeId, channelId, Transaction(0, Seq.empty, Seq.empty, 0), 100 sat, "funding"))
})
// Add relay events.
(1 to eventCount).foreach(_ => {
// 25% trampoline relays.
if (Random.nextInt(4) == 0) {
val outgoingCount = 1 + Random.nextInt(4)
val incoming = Seq(PaymentRelayed.Part(10000 msat, randomBytes32))
val outgoing = (1 to outgoingCount).map(_ => PaymentRelayed.Part(Random.nextInt(2000).msat, channelIds(Random.nextInt(channelCount))))
db.add(TrampolinePaymentRelayed(randomBytes32, incoming, outgoing))
} else {
val toChannelId = channelIds(Random.nextInt(channelCount))
db.add(ChannelPaymentRelayed(10000 msat, Random.nextInt(10000).msat, randomBytes32, randomBytes32, toChannelId))
}
})
// Test starts here.
val start = System.currentTimeMillis
assert(db.stats(0, start + 1).nonEmpty)
val end = System.currentTimeMillis
fail(s"took ${end - start}ms")
forAllDbs { dbs =>
val db = dbs.audit()
val nodeCount = 100
val channelCount = 1000
val eventCount = 100000
val nodeIds = (1 to nodeCount).map(_ => randomKey.publicKey)
val channelIds = (1 to channelCount).map(_ => randomBytes32)
// Fund channels.
channelIds.foreach(channelId => {
val nodeId = nodeIds(Random.nextInt(nodeCount))
db.add(NetworkFeePaid(null, nodeId, channelId, Transaction(0, Seq.empty, Seq.empty, 0), 100 sat, "funding"))
})
// Add relay events.
(1 to eventCount).foreach(_ => {
// 25% trampoline relays.
if (Random.nextInt(4) == 0) {
val outgoingCount = 1 + Random.nextInt(4)
val incoming = Seq(PaymentRelayed.Part(10000 msat, randomBytes32))
val outgoing = (1 to outgoingCount).map(_ => PaymentRelayed.Part(Random.nextInt(2000).msat, channelIds(Random.nextInt(channelCount))))
db.add(TrampolinePaymentRelayed(randomBytes32, incoming, outgoing))
} else {
val toChannelId = channelIds(Random.nextInt(channelCount))
db.add(ChannelPaymentRelayed(10000 msat, Random.nextInt(10000).msat, randomBytes32, randomBytes32, toChannelId))
}
})
// Test starts here.
val start = System.currentTimeMillis
assert(db.stats(0, start + 1).nonEmpty)
val end = System.currentTimeMillis
fail(s"took ${end - start}ms")
}
}
test("handle migration version 1 -> 4") {
val connection = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 1)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 1)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 1) // we expect version 1
}
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
// add a row (no ID on sent)
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, ps.recipientAmount.toLong)
statement.setLong(2, ps.feesPaid.toLong)
statement.setBytes(3, ps.paymentHash.toArray)
statement.setBytes(4, ps.paymentPreimage.toArray)
statement.setBytes(5, ps.parts.head.toChannelId.toArray)
statement.setLong(6, ps.timestamp)
statement.executeUpdate()
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4
}
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
postMigrationDb.add(ps1)
postMigrationDb.add(e1)
postMigrationDb.add(e2)
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 1) // we expect version 1
}
val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
// add a row (no ID on sent)
using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, ps.recipientAmount.toLong)
statement.setLong(2, ps.feesPaid.toLong)
statement.setBytes(3, ps.paymentHash.toArray)
statement.setBytes(4, ps.paymentPreimage.toArray)
statement.setBytes(5, ps.parts.head.toChannelId.toArray)
statement.setLong(6, ps.timestamp)
statement.executeUpdate()
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4
}
// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
postMigrationDb.add(ps1)
postMigrationDb.add(e1)
postMigrationDb.add(e2)
// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
}
test("handle migration version 2 -> 4") {
val connection = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 2)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 2)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now
}
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4
}
migratedDb.add(e1)
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
postMigrationDb.add(e2)
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now
}
val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true)
val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true)
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4
}
migratedDb.add(e1)
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
postMigrationDb.add(e2)
}
test("handle migration version 3 -> 4") {
val connection = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 3)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
// simulate existing previous version db
using(connection.createStatement()) { statement =>
getVersion(statement, "audit", 3)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now
}
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
for (pp <- Seq(pp1, pp2)) {
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, pp.amount.toLong)
statement.setLong(2, pp.feesPaid.toLong)
statement.setBytes(3, ps1.paymentHash.toArray)
statement.setBytes(4, ps1.paymentPreimage.toArray)
statement.setBytes(5, pp.toChannelId.toArray)
statement.setLong(6, pp.timestamp)
statement.setBytes(7, pp.id.toString.getBytes)
statement.executeUpdate()
}
}
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
for (relayed <- Seq(relayed1, relayed2)) {
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, relayed.amountIn.toLong)
statement.setLong(2, relayed.amountOut.toLong)
statement.setBytes(3, relayed.paymentHash.toArray)
statement.setBytes(4, relayed.fromChannelId.toArray)
statement.setBytes(5, relayed.toChannelId.toArray)
statement.setLong(6, relayed.timestamp)
statement.executeUpdate()
}
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4
}
assert(migratedDb.listSent(50, 150).toSet === Set(
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
))
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
))
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), 150)
postMigrationDb.add(ps2)
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now
}
val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100)
val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110)
val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil)
for (pp <- Seq(pp1, pp2)) {
using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, pp.amount.toLong)
statement.setLong(2, pp.feesPaid.toLong)
statement.setBytes(3, ps1.paymentHash.toArray)
statement.setBytes(4, ps1.paymentPreimage.toArray)
statement.setBytes(5, pp.toChannelId.toArray)
statement.setLong(6, pp.timestamp)
statement.setBytes(7, pp.id.toString.getBytes)
statement.executeUpdate()
}
}
val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105)
val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115)
for (relayed <- Seq(relayed1, relayed2)) {
using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setLong(1, relayed.amountIn.toLong)
statement.setLong(2, relayed.amountOut.toLong)
statement.setBytes(3, relayed.paymentHash.toArray)
statement.setBytes(4, relayed.fromChannelId.toArray)
statement.setBytes(5, relayed.toChannelId.toArray)
statement.setLong(6, relayed.timestamp)
statement.executeUpdate()
}
}
val migratedDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4
}
assert(migratedDb.listSent(50, 150).toSet === Set(
ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil),
ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil)
))
assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
val postMigrationDb = new SqliteAuditDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "audit", 4) == 4) // version 4
}
val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq(
PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160),
PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165)
))
val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), 150)
postMigrationDb.add(ps2)
assert(postMigrationDb.listSent(155, 200) === Seq(ps2))
postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
}
test("ignore invalid values in the DB") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteAuditDb(sqlite)
forAllDbs { dbs =>
val db = dbs.audit()
val sqlite = dbs.connection
val isPg = dbs.isInstanceOf[TestPgDatabases]
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 42)
statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "IN")
statement.setString(5, "unknown") // invalid relay type
statement.setLong(6, 10)
statement.executeUpdate()
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 42)
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "IN")
statement.setString(5, "unknown") // invalid relay type
statement.setLong(6, 10)
statement.executeUpdate()
}
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 51)
if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "UP") // invalid direction
statement.setString(5, "channel")
statement.setLong(6, 20)
statement.executeUpdate()
}
val paymentHash = randomBytes32
val channelId = randomBytes32
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
if (isPg) statement.setString(1, paymentHash.toHex) else statement.setBytes(1, paymentHash.toArray)
statement.setLong(2, 65)
if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
statement.setString(4, "IN") // missing a corresponding OUT
statement.setString(5, "channel")
statement.setLong(6, 30)
statement.executeUpdate()
}
assert(db.listRelayed(0, 40) === Nil)
}
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 51)
statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "UP") // invalid direction
statement.setString(5, "channel")
statement.setLong(6, 20)
statement.executeUpdate()
}
val paymentHash = randomBytes32
val channelId = randomBytes32
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, paymentHash.toArray)
statement.setLong(2, 65)
statement.setBytes(3, channelId.toArray)
statement.setString(4, "IN") // missing a corresponding OUT
statement.setString(5, "channel")
statement.setLong(6, 30)
statement.executeUpdate()
}
assert(db.listRelayed(0, 40) === Nil)
}
}

View File

@ -16,81 +16,89 @@
package fr.acinq.eclair.db
import java.sql.SQLException
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.CltvExpiry
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
import fr.acinq.eclair.db.sqlite.{SqliteChannelsDb, SqlitePendingRelayDb}
import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec
import fr.acinq.eclair.wire.ChannelCodecsSpec
import fr.acinq.eclair.{CltvExpiry, TestConstants}
import org.scalatest.funsuite.AnyFunSuite
import org.sqlite.SQLiteException
import scodec.bits.ByteVector
class SqliteChannelsDbSpec extends AnyFunSuite {
test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqliteChannelsDb(sqlite)
val db2 = new SqliteChannelsDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.channels()
val db2 = dbs.channels()
}
}
test("add/remove/list channels") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteChannelsDb(sqlite)
new SqlitePendingRelayDb(sqlite) // needed by db.removeChannel
forAllDbs { dbs =>
val db = dbs.channels()
dbs.pendingRelay() // needed by db.removeChannel
val channel = ChannelCodecsSpec.normal
val channel = ChannelCodecsSpec.normal
val commitNumber = 42
val paymentHash1 = ByteVector32.Zeroes
val cltvExpiry1 = CltvExpiry(123)
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
val cltvExpiry2 = CltvExpiry(656)
val commitNumber = 42
val paymentHash1 = ByteVector32.Zeroes
val cltvExpiry1 = CltvExpiry(123)
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
val cltvExpiry2 = CltvExpiry(656)
intercept[SQLiteException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
intercept[SQLException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
assert(db.listLocalChannels().toSet === Set.empty)
db.addOrUpdateChannel(channel)
db.addOrUpdateChannel(channel)
assert(db.listLocalChannels() === List(channel))
assert(db.listLocalChannels().toSet === Set.empty)
db.addOrUpdateChannel(channel)
db.addOrUpdateChannel(channel)
assert(db.listLocalChannels() === List(channel))
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == List((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList.toSet == Set((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
db.removeChannel(channel.channelId)
assert(db.listLocalChannels() === Nil)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
db.removeChannel(channel.channelId)
assert(db.listLocalChannels() === Nil)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
}
}
test("migrate channel database v1 -> v2") {
val sqlite = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
val sqlite = dbs.connection
// create a v1 channels database
using(sqlite.createStatement()) { statement =>
getVersion(statement, "channels", 1)
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
}
// create a v1 channels database
using(sqlite.createStatement()) { statement =>
getVersion(statement, "channels", 1)
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)")
}
// insert 1 row
val channel = ChannelCodecsSpec.normal
val data = stateDataCodec.encode(channel).require.toByteArray
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
statement.setBytes(1, channel.channelId.toArray)
statement.setBytes(2, data)
statement.executeUpdate()
}
// insert 1 row
val channel = ChannelCodecsSpec.normal
val data = stateDataCodec.encode(channel).require.toByteArray
using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement =>
statement.setBytes(1, channel.channelId.toArray)
statement.setBytes(2, data)
statement.executeUpdate()
}
// check that db migration works
val db = new SqliteChannelsDb(sqlite)
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "channels", 1) == 2) // version changed from 1 -> 2
// check that db migration works
val db = new SqliteChannelsDb(sqlite)
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "channels", 1) == 2) // version changed from 1 -> 2
}
assert(db.listLocalChannels() === List(channel))
}
assert(db.listLocalChannels() === List(channel))
}
}

View File

@ -16,107 +16,112 @@
package fr.acinq.eclair.db
import java.sql.Connection
import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi}
import fr.acinq.eclair.FeatureSupport.Optional
import fr.acinq.eclair.Features.VariableLengthOnion
import fr.acinq.eclair.db.sqlite.SqliteNetworkDb
import fr.acinq.eclair.TestConstants.{TestDatabases, TestPgDatabases, TestSqliteDatabases}
import fr.acinq.eclair.db.sqlite.SqliteUtils._
import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2}
import fr.acinq.eclair.{ActivatedFeature, CltvExpiryDelta, Features, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomKey}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits.HexStringSyntax
import scala.collection.{SortedMap, mutable}
class SqliteNetworkDbSpec extends AnyFunSuite {
import TestConstants.forAllDbs
val shortChannelIds = (42 to (5000 + 42)).map(i => ShortChannelId(i))
test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqliteNetworkDb(sqlite)
val db2 = new SqliteNetworkDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.network()
val db2 = dbs.network()
}
}
test("migration test 1->2") {
val sqlite = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
using(sqlite.createStatement()) { statement =>
getVersion(statement, "network", 1) // this will set version to 1
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
}
using(dbs.connection.createStatement()) { statement =>
dbs.getVersion(statement, "network", 1) // this will set version to 1
statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
}
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 1)
}
using(dbs.connection.createStatement()) { statement =>
assert(dbs.getVersion(statement, "network", 2) == 1)
}
// first round: this will trigger a migration
simpleTest(sqlite)
// first round: this will trigger a migration
simpleTest(dbs)
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 2)
}
using(dbs.connection.createStatement()) { statement =>
assert(dbs.getVersion(statement, "network", 2) == 2)
}
using(sqlite.createStatement()) { statement =>
statement.executeUpdate("DELETE FROM nodes")
statement.executeUpdate("DELETE FROM channels")
}
using(dbs.connection.createStatement()) { statement =>
statement.executeUpdate("DELETE FROM nodes")
statement.executeUpdate("DELETE FROM channels")
}
// second round: no migration
simpleTest(sqlite)
// second round: no migration
simpleTest(dbs)
using(sqlite.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 2)
using(dbs.connection.createStatement()) { statement =>
assert(dbs.getVersion(statement, "network", 2) == 2)
}
}
}
test("add/remove/list nodes") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteNetworkDb(sqlite)
forAllDbs { dbs =>
val db = dbs.network()
val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty)
val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, Features.empty)
val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty)
val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, Features.empty)
assert(db.listNodes().toSet === Set.empty)
db.addNode(node_1)
db.addNode(node_1) // duplicate is ignored
assert(db.getNode(node_1.nodeId) === Some(node_1))
assert(db.listNodes().size === 1)
db.addNode(node_2)
db.addNode(node_3)
db.addNode(node_4)
assert(db.listNodes().toSet === Set(node_1, node_2, node_3, node_4))
db.removeNode(node_2.nodeId)
assert(db.listNodes().toSet === Set(node_1, node_3, node_4))
db.updateNode(node_1)
assert(db.listNodes().toSet === Set.empty)
db.addNode(node_1)
db.addNode(node_1) // duplicate is ignored
assert(db.getNode(node_1.nodeId) === Some(node_1))
assert(db.listNodes().size === 1)
db.addNode(node_2)
db.addNode(node_3)
db.addNode(node_4)
assert(db.listNodes().toSet === Set(node_1, node_2, node_3, node_4))
db.removeNode(node_2.nodeId)
assert(db.listNodes().toSet === Set(node_1, node_3, node_4))
db.updateNode(node_1)
assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000)))
assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000)))
}
}
test("correctly handle txids that start with 0") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteNetworkDb(sqlite)
val sig = ByteVector64.Zeroes
val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig)
val txid = ByteVector32.fromValidHex("0001" * 16)
db.addChannel(c, txid, Satoshi(42))
assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c, txid, Satoshi(42), None, None, None)))
forAllDbs { dbs =>
val db = dbs.network()
val sig = ByteVector64.Zeroes
val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig)
val txid = ByteVector32.fromValidHex("0001" * 16)
db.addChannel(c, txid, Satoshi(42))
assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c, txid, Satoshi(42), None, None, None)))
}
}
def simpleTest(sqlite: Connection) = {
val db = new SqliteNetworkDb(sqlite)
def simpleTest(dbs: TestDatabases) = {
val db = dbs.network()
def sig = Crypto.sign(randomBytes32, randomKey)
@ -172,75 +177,85 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
}
test("add/remove/list channels and channel_updates") {
val sqlite = TestConstants.sqliteInMemory()
simpleTest(sqlite)
forAllDbs { dbs =>
simpleTest(dbs)
}
}
test("creating a table that already exists but with different column types is ignored") {
val sqlite = TestConstants.sqliteInMemory()
using(sqlite.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid STRING NOT NULL)")
}
// column type is STRING
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING")
forAllDbs { dbs =>
// insert and read back random values
val txids = for (_ <- 0 until 1000) yield randomBytes32
txids.foreach { txid =>
using(sqlite.prepareStatement("INSERT OR IGNORE INTO test VALUES (?)")) { statement =>
statement.setString(1, txid.toHex)
statement.executeUpdate()
using(dbs.connection.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid VARCHAR NOT NULL)")
}
}
// column type is VARCHAR
val rs = dbs.connection.getMetaData.getColumns(null, null, "test", null)
assert(rs.next())
assert(rs.getString("TYPE_NAME").toLowerCase == "varchar")
val check = using(sqlite.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT txid FROM test")
val q = new mutable.Queue[ByteVector32]()
while (rs.next()) {
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
q.enqueue(txId)
// insert and read back random values
val txids = for (_ <- 0 until 1000) yield randomBytes32
txids.foreach { txid =>
using(dbs.connection.prepareStatement("INSERT INTO test VALUES (?)")) { statement =>
statement.setString(1, txid.toHex)
statement.executeUpdate()
}
}
q
val check = using(dbs.connection.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT txid FROM test")
val q = new mutable.Queue[ByteVector32]()
while (rs.next()) {
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
q.enqueue(txId)
}
q
}
assert(txids.toSet == check.toSet)
using(dbs.connection.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)")
}
// column type has not changed
val rs1 = dbs.connection.getMetaData.getColumns(null, null, "test", null)
assert(rs1.next())
assert(rs1.getString("TYPE_NAME").toLowerCase == "varchar")
}
assert(txids.toSet == check.toSet)
using(sqlite.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)")
}
// column type has not changed
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING")
}
test("remove many channels") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteNetworkDb(sqlite)
val sig = Crypto.sign(randomBytes32, randomKey)
val priv = randomKey
val pub = priv.publicKey
val capacity = 10000 sat
forAllDbs { dbs =>
val db = dbs.network()
val sig = Crypto.sign(randomBytes32, randomKey)
val priv = randomKey
val pub = priv.publicKey
val capacity = 10000 sat
val channels = shortChannelIds.map(id => Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, id, pub, pub, pub, pub, sig, sig, sig, sig))
val template = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv, pub, ShortChannelId(42), CltvExpiryDelta(5), 7000000 msat, 50000 msat, 100, 500000000L msat, true)
val updates = shortChannelIds.map(id => template.copy(shortChannelId = id))
val txid = randomBytes32
channels.foreach(ca => db.addChannel(ca, txid, capacity))
updates.foreach(u => db.updateChannel(u))
assert(db.listChannels().keySet === channels.map(_.shortChannelId).toSet)
val channels = shortChannelIds.map(id => Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, id, pub, pub, pub, pub, sig, sig, sig, sig))
val template = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv, pub, ShortChannelId(42), CltvExpiryDelta(5), 7000000 msat, 50000 msat, 100, 500000000L msat, true)
val updates = shortChannelIds.map(id => template.copy(shortChannelId = id))
val txid = randomBytes32
channels.foreach(ca => db.addChannel(ca, txid, capacity))
updates.foreach(u => db.updateChannel(u))
assert(db.listChannels().keySet === channels.map(_.shortChannelId).toSet)
val toDelete = channels.map(_.shortChannelId).drop(500).take(2500)
db.removeChannels(toDelete)
assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete))
val toDelete = channels.map(_.shortChannelId).drop(500).take(2500)
db.removeChannels(toDelete)
assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete))
}
}
test("prune many channels") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteNetworkDb(sqlite)
forAllDbs { dbs =>
val db = dbs.network()
db.addToPruned(shortChannelIds)
shortChannelIds.foreach { id => assert(db.isPruned((id))) }
db.removeFromPruned(ShortChannelId(5))
assert(!db.isPruned(ShortChannelId(5)))
db.addToPruned(shortChannelIds)
shortChannelIds.foreach { id => assert(db.isPruned((id))) }
db.removeFromPruned(ShortChannelId(5))
assert(!db.isPruned(ShortChannelId(5)))
}
}
}

View File

@ -20,9 +20,9 @@ import java.util.UUID
import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto}
import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb
import fr.acinq.eclair.db.sqlite.SqliteUtils._
import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop}
import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer}
@ -36,383 +36,396 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
import SqlitePaymentsDbSpec._
test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqlitePaymentsDb(sqlite)
val db2 = new SqlitePaymentsDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.payments()
val db2 = dbs.payments()
}
}
test("handle version migration 1->4") {
val connection = TestConstants.sqliteInMemory()
import fr.acinq.eclair.db.sqlite.SqliteUtils._
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
val connection = dbs.connection
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 1)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS payments (payment_hash BLOB NOT NULL PRIMARY KEY, amount_msat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 1)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS payments (payment_hash BLOB NOT NULL PRIMARY KEY, amount_msat INTEGER NOT NULL, timestamp INTEGER NOT NULL)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now
}
// Changes between version 1 and 2:
// - the monolithic payments table has been replaced by two tables, received_payments and sent_payments
// - old records from the payments table are ignored (not migrated to the new tables)
using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement =>
statement.setBytes(1, paymentHash1.toArray)
statement.setLong(2, (123 msat).toLong)
statement.setLong(3, 1000) // received_at
statement.executeUpdate()
}
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 1) == 4) // version changed from 1 -> 4
}
// the existing received payment can NOT be queried anymore
assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty)
// add a few rows
val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending)
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100))
preMigrationDb.addOutgoingPayment(ps1)
preMigrationDb.addIncomingPayment(i1, preimage1)
preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100)
assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now
}
// Changes between version 1 and 2:
// - the monolithic payments table has been replaced by two tables, received_payments and sent_payments
// - old records from the payments table are ignored (not migrated to the new tables)
using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement =>
statement.setBytes(1, paymentHash1.toArray)
statement.setLong(2, (123 msat).toLong)
statement.setLong(3, 1000) // received_at
statement.executeUpdate()
}
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 1) == 4) // version changed from 1 -> 4
}
// the existing received payment can NOT be queried anymore
assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty)
// add a few rows
val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending)
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100))
preMigrationDb.addOutgoingPayment(ps1)
preMigrationDb.addIncomingPayment(i1, preimage1)
preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100)
assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
}
test("handle version migration 2->4") {
val connection = TestConstants.sqliteInMemory()
import fr.acinq.eclair.db.sqlite.SqliteUtils._
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
val connection = dbs.connection
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 2)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)")
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 2)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now
}
// Insert a bunch of old version 2 rows.
val id1 = UUID.randomUUID()
val id2 = UUID.randomUUID()
val id3 = UUID.randomUUID()
val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending)
val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050))
val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060))
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090))
val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1)
val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
// Changes between version 2 and 3 to sent_payments:
// - removed the status column
// - added optional payment failures
// - added optional payment success details (fees paid and route)
// - added optional payment request
// - added target node ID
// - added externalID and parentID
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps1.id.toString)
statement.setBytes(2, ps1.paymentHash.toArray)
statement.setLong(3, ps1.amount.toLong)
statement.setLong(4, ps1.createdAt)
statement.setString(5, "PENDING")
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps2.id.toString)
statement.setBytes(2, ps2.paymentHash.toArray)
statement.setLong(3, ps2.amount.toLong)
statement.setLong(4, ps2.createdAt)
statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
statement.setString(6, "FAILED")
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps3.id.toString)
statement.setBytes(2, ps3.paymentHash.toArray)
statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
statement.setLong(4, ps3.amount.toLong)
statement.setLong(5, ps3.createdAt)
statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
statement.setString(7, "SUCCEEDED")
statement.executeUpdate()
}
// Changes between version 2 and 3 to received_payments:
// - renamed the preimage column
// - made expire_at not null
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, i1.paymentHash.toArray)
statement.setBytes(2, pr1.paymentPreimage.toArray)
statement.setString(3, PaymentRequest.write(i1))
statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong)
statement.setLong(5, pr1.createdAt)
statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, i2.paymentHash.toArray)
statement.setBytes(2, pr2.paymentPreimage.toArray)
statement.setString(3, PaymentRequest.write(i2))
statement.setLong(4, pr2.createdAt)
statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis)
statement.executeUpdate()
}
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 2) == 4) // version changed from 2 -> 4
}
assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1))
assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2))
assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30))
val pr3 = IncomingPayment(i3, preimage3, PaymentType.Standard, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage)
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, PaymentType.Standard, 123 msat, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending)
val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, PaymentType.Standard, 456 msat, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180))
val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, PaymentType.Standard, 789 msat, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300))
postMigrationDb.addOutgoingPayment(ps4)
postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending))
postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180))))
postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending))
postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300))
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6))
assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3))
assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2))
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now
}
// Insert a bunch of old version 2 rows.
val id1 = UUID.randomUUID()
val id2 = UUID.randomUUID()
val id3 = UUID.randomUUID()
val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, PaymentType.Standard, 561 msat, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending)
val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050))
val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060))
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1)
val pr1 = IncomingPayment(i1, preimage1, PaymentType.Standard, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090))
val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1)
val pr2 = IncomingPayment(i2, preimage2, PaymentType.Standard, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
// Changes between version 2 and 3 to sent_payments:
// - removed the status column
// - added optional payment failures
// - added optional payment success details (fees paid and route)
// - added optional payment request
// - added target node ID
// - added externalID and parentID
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps1.id.toString)
statement.setBytes(2, ps1.paymentHash.toArray)
statement.setLong(3, ps1.amount.toLong)
statement.setLong(4, ps1.createdAt)
statement.setString(5, "PENDING")
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps2.id.toString)
statement.setBytes(2, ps2.paymentHash.toArray)
statement.setLong(3, ps2.amount.toLong)
statement.setLong(4, ps2.createdAt)
statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
statement.setString(6, "FAILED")
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps3.id.toString)
statement.setBytes(2, ps3.paymentHash.toArray)
statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
statement.setLong(4, ps3.amount.toLong)
statement.setLong(5, ps3.createdAt)
statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
statement.setString(7, "SUCCEEDED")
statement.executeUpdate()
}
// Changes between version 2 and 3 to received_payments:
// - renamed the preimage column
// - made expire_at not null
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, i1.paymentHash.toArray)
statement.setBytes(2, pr1.paymentPreimage.toArray)
statement.setString(3, PaymentRequest.write(i1))
statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong)
statement.setLong(5, pr1.createdAt)
statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, i2.paymentHash.toArray)
statement.setBytes(2, pr2.paymentPreimage.toArray)
statement.setString(3, PaymentRequest.write(i2))
statement.setLong(4, pr2.createdAt)
statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis)
statement.executeUpdate()
}
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 2) == 4) // version changed from 2 -> 4
}
assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1))
assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2))
assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30))
val pr3 = IncomingPayment(i3, preimage3, PaymentType.Standard, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage)
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, PaymentType.Standard, 123 msat, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending)
val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, PaymentType.Standard, 456 msat, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180))
val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, PaymentType.Standard, 789 msat, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300))
postMigrationDb.addOutgoingPayment(ps4)
postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending))
postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, ps5.amount, ps5.recipientNodeId, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180))))
postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending))
postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300))
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6))
assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3))
assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2))
}
test("handle version migration 3->4") {
val connection = TestConstants.sqliteInMemory()
forAllDbs {
case _: TestPgDatabases => // no migration
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils._
val connection = dbs.connection
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 3)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)")
using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 3)
statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)")
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 3) == 3) // version 3 is deployed now
}
// Insert a bunch of old version 3 rows.
val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID())
val parentId = UUID.randomUUID()
val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", expirySeconds = Some(30))
val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32, PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020))
val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending)
val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060))
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, failures) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps1.id.toString)
statement.setString(2, ps1.parentId.toString)
statement.setString(3, ps1.externalId.get)
statement.setBytes(4, ps1.paymentHash.toArray)
statement.setLong(5, ps1.amount.toLong)
statement.setBytes(6, ps1.recipientNodeId.value.toArray)
statement.setLong(7, ps1.createdAt)
statement.setLong(8, ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
statement.setBytes(9, SqlitePaymentsDb.paymentFailuresCodec.encode(ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].failures.toList).require.toByteArray)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps2.id.toString)
statement.setString(2, ps2.parentId.toString)
statement.setString(3, ps2.externalId.get)
statement.setBytes(4, ps2.paymentHash.toArray)
statement.setLong(5, ps2.amount.toLong)
statement.setBytes(6, ps2.recipientNodeId.value.toArray)
statement.setLong(7, ps2.createdAt)
statement.setString(8, PaymentRequest.write(invoice1))
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage, fees_msat, payment_route) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps3.id.toString)
statement.setString(2, ps3.parentId.toString)
statement.setBytes(3, ps3.paymentHash.toArray)
statement.setLong(4, ps3.amount.toLong)
statement.setBytes(5, ps3.recipientNodeId.value.toArray)
statement.setLong(6, ps3.createdAt)
statement.setLong(7, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
statement.setBytes(8, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
statement.setLong(9, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid.toLong)
statement.setBytes(10, SqlitePaymentsDb.paymentRouteCodec.encode(ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].route.toList).require.toByteArray)
statement.executeUpdate()
}
// Changes between version 3 and 4 to sent_payments:
// - added final amount column
// - added payment type column, with a default to "Standard"
// - renamed target_node_id -> recipient_node_id
// - re-ordered columns
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 3) == 4) // version changed from 3 -> 4
}
assert(preMigrationDb.getOutgoingPayment(id1) === Some(ps1))
assert(preMigrationDb.listOutgoingPayments(parentId) === Seq(ps2, ps3))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
}
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 3) == 3) // version 3 is deployed now
}
// Insert a bunch of old version 3 rows.
val (id1, id2, id3) = (UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID())
val parentId = UUID.randomUUID()
val invoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(2834 msat), paymentHash1, bobPriv, "invoice #1", expirySeconds = Some(30))
val ps1 = OutgoingPayment(id1, id1, Some("42"), randomBytes32, PaymentType.Standard, 561 msat, 561 msat, alice, 1000, None, OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.REMOTE, "no candy for you", List(HopSummary(hop_ab), HopSummary(hop_bc)))), 1020))
val ps2 = OutgoingPayment(id2, parentId, Some("42"), paymentHash1, PaymentType.Standard, 1105 msat, 1105 msat, bob, 1010, Some(invoice1), OutgoingPaymentStatus.Pending)
val ps3 = OutgoingPayment(id3, parentId, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, bob, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 10 msat, Seq(HopSummary(hop_ab), HopSummary(hop_bc)), 1060))
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, failures) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps1.id.toString)
statement.setString(2, ps1.parentId.toString)
statement.setString(3, ps1.externalId.get)
statement.setBytes(4, ps1.paymentHash.toArray)
statement.setLong(5, ps1.amount.toLong)
statement.setBytes(6, ps1.recipientNodeId.value.toArray)
statement.setLong(7, ps1.createdAt)
statement.setLong(8, ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt)
statement.setBytes(9, SqlitePaymentsDb.paymentFailuresCodec.encode(ps1.status.asInstanceOf[OutgoingPaymentStatus.Failed].failures.toList).require.toByteArray)
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps2.id.toString)
statement.setString(2, ps2.parentId.toString)
statement.setString(3, ps2.externalId.get)
statement.setBytes(4, ps2.paymentHash.toArray)
statement.setLong(5, ps2.amount.toLong)
statement.setBytes(6, ps2.recipientNodeId.value.toArray)
statement.setLong(7, ps2.createdAt)
statement.setString(8, PaymentRequest.write(invoice1))
statement.executeUpdate()
}
using(connection.prepareStatement("INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage, fees_msat, payment_route) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement =>
statement.setString(1, ps3.id.toString)
statement.setString(2, ps3.parentId.toString)
statement.setBytes(3, ps3.paymentHash.toArray)
statement.setLong(4, ps3.amount.toLong)
statement.setBytes(5, ps3.recipientNodeId.value.toArray)
statement.setLong(6, ps3.createdAt)
statement.setLong(7, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt)
statement.setBytes(8, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray)
statement.setLong(9, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid.toLong)
statement.setBytes(10, SqlitePaymentsDb.paymentRouteCodec.encode(ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].route.toList).require.toByteArray)
statement.executeUpdate()
}
// Changes between version 3 and 4 to sent_payments:
// - added final amount column
// - added payment type column, with a default to "Standard"
// - renamed target_node_id -> recipient_node_id
// - re-ordered columns
val preMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 3) == 4) // version changed from 3 -> 4
}
assert(preMigrationDb.getOutgoingPayment(id1) === Some(ps1))
assert(preMigrationDb.listOutgoingPayments(parentId) === Seq(ps2, ps3))
val postMigrationDb = new SqlitePaymentsDb(connection)
using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4
}
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, randomBytes32, PaymentType.SwapOut, 50 msat, 100 msat, carol, 1100, Some(invoice1), OutgoingPaymentStatus.Pending)
postMigrationDb.addOutgoingPayment(ps4)
postMigrationDb.updateOutgoingPayment(PaymentSent(parentId, paymentHash1, preimage1, ps2.recipientAmount, ps2.recipientNodeId, Seq(PaymentSent.PartialPayment(id2, ps2.amount, 15 msat, randomBytes32, Some(Seq(hop_ab)), 1105))))
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Seq(HopSummary(hop_ab)), 1105)), ps3, ps4))
}
test("add/retrieve/update incoming payments") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqlitePaymentsDb(sqlite)
forAllDbs { dbs =>
val db = dbs.payments()
// can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat))
// can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat))
val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1)
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30))
val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, PaymentType.Standard, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, PaymentType.Standard, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1)
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30))
val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, PaymentType.Standard, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, PaymentType.Standard, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired)
val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3")
val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30))
val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, PaymentType.Standard, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, PaymentType.SwapIn, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3")
val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30))
val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, PaymentType.Standard, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, PaymentType.SwapIn, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending)
val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5")
val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60))
val receivedAt1 = System.currentTimeMillis + 1
val receivedAt2 = System.currentTimeMillis + 2
val payment1 = IncomingPayment(paidInvoice1, randomBytes32, PaymentType.Standard, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt2))
val payment2 = IncomingPayment(paidInvoice2, randomBytes32, PaymentType.Standard, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2))
val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5")
val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60))
val receivedAt1 = System.currentTimeMillis + 1
val receivedAt2 = System.currentTimeMillis + 2
val payment1 = IncomingPayment(paidInvoice1, randomBytes32, PaymentType.Standard, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt2))
val payment2 = IncomingPayment(paidInvoice2, randomBytes32, PaymentType.Standard, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2))
db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage)
db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn)
db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage)
db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage)
db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage)
db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage)
db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage)
db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn)
db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage)
db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage)
db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage)
db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage)
assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1))
assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2))
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending)))
assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1))
assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2))
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending)))
val now = System.currentTimeMillis
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2))
assert(db.listReceivedIncomingPayments(0, now) === Nil)
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
val now = System.currentTimeMillis
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2))
assert(db.listReceivedIncomingPayments(0, now) === Nil)
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending)))
db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1)
db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)
db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1)
db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2))
assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))
assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2))
}
}
test("add/retrieve/update outgoing payments") {
val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory())
forAllDbs { dbs =>
val db = dbs.payments()
val parentId = UUID.randomUUID()
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0)
val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, PaymentType.Standard, 123 msat, 600 msat, dave, 100, Some(i1), OutgoingPaymentStatus.Pending)
val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, PaymentType.SwapOut, 456 msat, 600 msat, dave, 200, None, OutgoingPaymentStatus.Pending)
val parentId = UUID.randomUUID()
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0)
val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, PaymentType.Standard, 123 msat, 600 msat, dave, 100, Some(i1), OutgoingPaymentStatus.Pending)
val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, PaymentType.SwapOut, 456 msat, 600 msat, dave, 200, None, OutgoingPaymentStatus.Pending)
assert(db.listOutgoingPayments(0, System.currentTimeMillis).isEmpty)
db.addOutgoingPayment(s1)
db.addOutgoingPayment(s2)
assert(db.listOutgoingPayments(0, System.currentTimeMillis).isEmpty)
db.addOutgoingPayment(s1)
db.addOutgoingPayment(s2)
// can't add an outgoing payment in non-pending state
assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110))))
// can't add an outgoing payment in non-pending state
assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110))))
assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2))
assert(db.listOutgoingPayments(1, 150).toList == Seq(s1))
assert(db.listOutgoingPayments(150, 250).toList == Seq(s2))
assert(db.getOutgoingPayment(s1.id) === Some(s1))
assert(db.getOutgoingPayment(UUID.randomUUID()) === None)
assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2))
assert(db.listOutgoingPayments(s1.id) === Nil)
assert(db.listOutgoingPayments(parentId) === Seq(s1, s2))
assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil)
assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2))
assert(db.listOutgoingPayments(1, 150).toList == Seq(s1))
assert(db.listOutgoingPayments(150, 250).toList == Seq(s2))
assert(db.getOutgoingPayment(s1.id) === Some(s1))
assert(db.getOutgoingPayment(UUID.randomUUID()) === None)
assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2))
assert(db.listOutgoingPayments(s1.id) === Nil)
assert(db.listOutgoingPayments(parentId) === Seq(s1, s2))
assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil)
val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300)
val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 300)
db.addOutgoingPayment(s3)
db.addOutgoingPayment(s4)
val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300)
val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 301)
db.addOutgoingPayment(s3)
db.addOutgoingPayment(s4)
db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310))
val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310))
assert(db.getOutgoingPayment(s3.id) === Some(ss3))
db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(Seq(hop_ab), new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320))
val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", List(HopSummary(alice, bob, Some(ShortChannelId(42))))), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)))), 320))
assert(db.getOutgoingPayment(s4.id) === Some(ss4))
db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310))
val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310))
assert(db.getOutgoingPayment(s3.id) === Some(ss3))
db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(Seq(hop_ab), new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320))
val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", List(HopSummary(alice, bob, Some(ShortChannelId(42))))), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)))), 320))
assert(db.getOutgoingPayment(s4.id) === Some(ss4))
// can't update again once it's in a final state
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, s3.recipientAmount, s3.recipientNodeId, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None)))))
// can't update again once it's in a final state
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, s3.recipientAmount, s3.recipientNodeId, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None)))))
val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, 600 msat, carol, Seq(
PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400),
PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410)
))
val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400))
val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)), 410))
db.updateOutgoingPayment(paymentSent)
assert(db.getOutgoingPayment(s1.id) === Some(ss1))
assert(db.getOutgoingPayment(s2.id) === Some(ss2))
assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4))
val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, 600 msat, carol, Seq(
PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400),
PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410)
))
val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400))
val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, None)), 410))
db.updateOutgoingPayment(paymentSent)
assert(db.getOutgoingPayment(s1.id) === Some(ss1))
assert(db.getOutgoingPayment(s2.id) === Some(ss2))
assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4))
// can't update again once it's in a final state
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil)))
// can't update again once it's in a final state
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil)))
}
}
test("high level payments overview") {

View File

@ -16,50 +16,49 @@
package fr.acinq.eclair.db
import java.sql.DriverManager
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.db.sqlite.SqlitePeersDb
import fr.acinq.eclair.randomKey
import fr.acinq.eclair.wire.{NodeAddress, Tor2, Tor3}
import fr.acinq.eclair.{TestConstants, randomKey}
import org.scalatest.funsuite.AnyFunSuite
class SqlitePeersDbSpec extends AnyFunSuite {
def inmem = DriverManager.getConnection("jdbc:sqlite::memory:")
import TestConstants.forAllDbs
test("init sqlite 2 times in a row") {
val sqlite = inmem
val db1 = new SqlitePeersDb(sqlite)
val db2 = new SqlitePeersDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.peers()
val db2 = dbs.peers()
}
}
test("add/remove/get/list peers") {
val sqlite = inmem
val db = new SqlitePeersDb(sqlite)
test("add/remove/list peers") {
forAllDbs { dbs =>
val db = dbs.peers()
case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress)
case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress)
val peer_1 = TestCase(randomKey.publicKey, NodeAddress.fromParts("127.0.0.1", 42000).get)
val peer_1_bis = TestCase(peer_1.nodeId, NodeAddress.fromParts("127.0.0.1", 1112).get)
val peer_2 = TestCase(randomKey.publicKey, Tor2("z4zif3fy7fe7bpg3", 4231))
val peer_3 = TestCase(randomKey.publicKey, Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231))
val peer_1 = TestCase(randomKey.publicKey, NodeAddress.fromParts("127.0.0.1", 42000).get)
val peer_1_bis = TestCase(peer_1.nodeId, NodeAddress.fromParts("127.0.0.1", 1112).get)
val peer_2 = TestCase(randomKey.publicKey, Tor2("z4zif3fy7fe7bpg3", 4231))
val peer_3 = TestCase(randomKey.publicKey, Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231))
assert(db.listPeers().toSet === Set.empty)
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress)
assert(db.getPeer(peer_1.nodeId) === Some(peer_1.nodeAddress))
assert(db.getPeer(peer_2.nodeId) === None)
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress) // duplicate is ignored
assert(db.listPeers().size === 1)
db.addOrUpdatePeer(peer_2.nodeId, peer_2.nodeAddress)
db.addOrUpdatePeer(peer_3.nodeId, peer_3.nodeAddress)
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_2, peer_3))
db.removePeer(peer_2.nodeId)
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_3))
db.addOrUpdatePeer(peer_1_bis.nodeId, peer_1_bis.nodeAddress)
assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress))
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3))
assert(db.listPeers().toSet === Set.empty)
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress)
assert(db.getPeer(peer_1.nodeId) === Some(peer_1.nodeAddress))
assert(db.getPeer(peer_2.nodeId) === None)
db.addOrUpdatePeer(peer_1.nodeId, peer_1.nodeAddress) // duplicate is ignored
assert(db.listPeers().size === 1)
db.addOrUpdatePeer(peer_2.nodeId, peer_2.nodeAddress)
db.addOrUpdatePeer(peer_3.nodeId, peer_3.nodeAddress)
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_2, peer_3))
db.removePeer(peer_2.nodeId)
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1, peer_3))
db.addOrUpdatePeer(peer_1_bis.nodeId, peer_1_bis.nodeAddress)
assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress))
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3))
}
}
}

View File

@ -17,46 +17,49 @@
package fr.acinq.eclair.db
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC}
import fr.acinq.eclair.db.sqlite.SqlitePendingRelayDb
import fr.acinq.eclair.{TestConstants, randomBytes32}
import fr.acinq.eclair.wire.FailureMessageCodecs
import fr.acinq.eclair.{TestConstants, randomBytes32}
import org.scalatest.funsuite.AnyFunSuite
class SqlitePendingRelayDbSpec extends AnyFunSuite {
import TestConstants.forAllDbs
test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqlitePendingRelayDb(sqlite)
val db2 = new SqlitePendingRelayDb(sqlite)
forAllDbs { dbs =>
val db1 = dbs.pendingRelay()
val db2 = dbs.pendingRelay()
}
}
test("add/remove/list messages") {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqlitePendingRelayDb(sqlite)
forAllDbs { dbs =>
val db = dbs.pendingRelay()
val channelId1 = randomBytes32
val channelId2 = randomBytes32
val msg0 = CMD_FULFILL_HTLC(0, randomBytes32)
val msg1 = CMD_FULFILL_HTLC(1, randomBytes32)
val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32))
val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32))
val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32, FailureMessageCodecs.BADONION)
val channelId1 = randomBytes32
val channelId2 = randomBytes32
val msg0 = CMD_FULFILL_HTLC(0, randomBytes32)
val msg1 = CMD_FULFILL_HTLC(1, randomBytes32)
val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32))
val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32))
val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32, FailureMessageCodecs.BADONION)
assert(db.listPendingRelay(channelId1).toSet === Set.empty)
db.addPendingRelay(channelId1, msg0)
db.addPendingRelay(channelId1, msg0) // duplicate
db.addPendingRelay(channelId1, msg1)
db.addPendingRelay(channelId1, msg2)
db.addPendingRelay(channelId1, msg3)
db.addPendingRelay(channelId1, msg4)
db.addPendingRelay(channelId2, msg0) // same messages but for different channel
db.addPendingRelay(channelId2, msg1)
assert(db.listPendingRelay(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4))
assert(db.listPendingRelay(channelId2).toSet === Set(msg0, msg1))
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg1.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
db.removePendingRelay(channelId1, msg1.id)
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
assert(db.listPendingRelay(channelId1).toSet === Set.empty)
db.addPendingRelay(channelId1, msg0)
db.addPendingRelay(channelId1, msg0) // duplicate
db.addPendingRelay(channelId1, msg1)
db.addPendingRelay(channelId1, msg2)
db.addPendingRelay(channelId1, msg3)
db.addPendingRelay(channelId1, msg4)
db.addPendingRelay(channelId2, msg0) // same messages but for different channel
db.addPendingRelay(channelId2, msg1)
assert(db.listPendingRelay(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4))
assert(db.listPendingRelay(channelId2).toSet === Set(msg0, msg1))
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg1.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
db.removePendingRelay(channelId1, msg1.id)
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
}
}
}

View File

@ -16,10 +16,11 @@
package fr.acinq.eclair.db
import java.sql.SQLException
import fr.acinq.eclair.TestConstants
import fr.acinq.eclair.db.sqlite.SqliteUtils.using
import org.scalatest.funsuite.AnyFunSuite
import org.sqlite.SQLiteException
class SqliteUtilsSpec extends AnyFunSuite {
@ -41,7 +42,7 @@ class SqliteUtilsSpec extends AnyFunSuite {
assert(!results.next())
}
assertThrows[SQLiteException](using(conn.createStatement(), inTransaction = true) { statement =>
assertThrows[SQLException](using(conn.createStatement(), inTransaction = true) { statement =>
statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)")
statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation)
})

View File

@ -1 +1 @@
{"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"]}
{"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"],"instanceId":"01234567-0123-4567-89ab-0123456789ab"}

View File

@ -180,7 +180,8 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
chainHash = ByteVector32(hex"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"),
network = "regtest",
blockHeight = 9999,
publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil
publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil,
instanceId = "01234567-0123-4567-89ab-0123456789ab"
))
Post("/getinfo") ~>