1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-03-15 12:20:13 +01:00

Merge commit 'b63c4aa5a' into android

This commit is contained in:
pm47 2020-10-09 12:25:17 +02:00
commit 648c841c0d
No known key found for this signature in database
GPG key ID: E434ED292E85643A
26 changed files with 1356 additions and 1037 deletions

2
.gitignore vendored
View file

@ -25,3 +25,5 @@ target/
project/target project/target
DeleteMe*.* DeleteMe*.*
*~ *~
.DS_Store

114
docs/PostgreSQL.md Normal file
View file

@ -0,0 +1,114 @@
## PostgreSQL Configuration
By default Eclair stores its data on the machine's local file system (typically in `~/.eclair` directory) using SQLite.
It also supports PostgreSQL version 10.6 and higher as a database backend.
To enable PostgreSQL support set the `driver` parameter to `postgres`:
```
eclair.db.driver = postgres
```
### Connection settings
To configure the connection settings use the `database`, `host`, `port` `username` and `password` parameters:
```
eclair.db.postgres.database = "mydb"
eclair.db.postgres.host = "127.0.0.1" # Default: "localhost"
eclair.db.postgres.port = 12345 # Default: 5432
eclair.db.postgres.username = "myuser"
eclair.db.postgres.password = "mypassword"
```
Eclair uses Hikari connection pool (https://github.com/brettwooldridge/HikariCP) which has a lot of configuration
parameters. Some of them can be set in Eclair config file. The most important is `pool.max-size`, it defines the maximum
allowed number of simultaneous connections to the database.
A good rule of thumb is to set `pool.max-size` to the CPU core count times 2.
See https://github.com/brettwooldridge/HikariCP/wiki/About-Pool-Sizing for better estimation.
```
eclair.db.postgres.pool {
max-size = 8 # Default: 10
connection-timeout = 10 seconds # Default: 30 seconds
idle-timeout = 1 minute # Default: 10 minutes
max-life-time = 15 minutes # Default: 30 minutes
}
```
### Locking settings
Running multiple Eclair processes connected to the same database can lead to data corruption and loss of funds.
That's why Eclair supports database locking mechanisms to prevent multiple Eclair instances from accessing one database together.
Use `postgres.lock-type` parameter to set the locking schemes.
Lock type | Description
---|---
`lease` | At the beginning, Eclair acquires a lease for the database that expires after some time. Then it constantly extends the lease. On each lease extension and each database transaction, Eclair checks if the lease belongs to the Eclair instance. If it doesn't, Eclair assumes that the database was updated by another Eclair process and terminates. Note that this is just a safeguard feature for Eclair rather than a bulletproof database-wide lock, because third-party applications still have the ability to access the database without honoring this locking scheme.
`none` | No locking at all. Useful for tests. DO NOT USE ON MAINNET!
```
eclair.db.postgres.lock-type = "none" // Default: "lease"
```
#### Database Lease Settings
There are two main configuration parameters for the lease locking scheme: `lease.interval` and `lease.renew-interval`.
`lease.interval` defines lease validity time. During the lease time no other node can acquire the lock, except the lease holder.
After that time the lease is assumed expired, any node can acquire the lease. So that only one node can update the database
at a time. Eclair extends the lease every `lease.renew-interval` until terminated.
```
eclair.db.postgres.lease {
interval = 30 seconds // Default: 5 minutes
renew-interval = 10 seconds // Default: 1 minute
}
```
### Backups and replication
The PostgreSQL driver doesn't support Eclair's built-in online backups. Instead, you should use the tools provided
by PostgreSQL.
#### Backup/Restore
For nodes with infrequent channel updates its easier to use `pg_dump` to perform the task.
It's important to stop the node to prevent any channel updates while a backup/restore operation is in progress. It makes
sense to backup the database after each channel update, to prevent restoring an outdated channel's state and consequently
losing the funds associated with that channel.
For more information about backup refer to the official PostgreSQL documentation: https://www.postgresql.org/docs/current/backup.html
#### Replication
For busier nodes it isn't practical to use `pg_dump`. Fortunately, PostgreSQL provides built-in database replication which makes the backup/restore process more seamless.
To set up database replication you need to create a main database, that accepts all changes from the node, and a replica database.
Once replication is configured, the main database will automatically send all the changes to the replica.
In case of failure of the main database, the node can be simply reconfigured to use the replica instead of the main database.
PostgreSQL supports [different types of replication](https://www.postgresql.org/docs/current/different-replication-solutions.html).
The most suitable type for an Eclair node is [synchronous streaming replication](https://www.postgresql.org/docs/current/warm-standby.html#SYNCHRONOUS-REPLICATION),
because it provides a very important feature, that helps keep the replicated channel's state up to date:
> When requesting synchronous replication, each commit of a write transaction will wait until confirmation is received that the commit has been written to the write-ahead log on disk of both the primary and standby server.
Follow the official PostgreSQL high availability documentation for the instructions to set up synchronous streaming replication: https://www.postgresql.org/docs/current/high-availability.html
### Safeguard to prevent accidental loss of funds due to database misconfiguration
Using Eclair with an outdated version of the database or a database created with another seed might lead to loss of funds.
Every time Eclair starts, it checks if the Postgres database connection settings were changed since the last start.
If in fact the settings were changed, Eclair stops immediately to prevent potentially dangerous
but accidental configuration changes to come into effect.
Eclair stores the latest database settings in the `${data-dir}/last_jdbcurl` file, and compares its contents with the database settings from the config file.
The node operator can force Eclair to accept new database
connection settings by removing the `last_jdbcurl` file.

View file

@ -182,6 +182,28 @@ eclair {
port = 9051 port = 9051
private-key-file = "tor.dat" private-key-file = "tor.dat"
} }
db {
driver = "sqlite" // sqlite, postgres
postgres {
database = "eclair"
host = "localhost"
port = 5432
username = ""
password = ""
pool {
max-size = 10 // recommended value = number_of_cpu_cores * 2
connection-timeout = 30 seconds
idle-timeout = 10 minutes
max-life-time = 30 minutes
}
lease {
interval = 5 minutes // lease-interval must be greater than lease-renew-interval
renew-interval = 1 minute
}
lock-type = "lease" // lease or none
}
}
} }
// do not edit or move this section // do not edit or move this section

View file

@ -18,6 +18,7 @@ package fr.acinq.eclair
import java.io.File import java.io.File
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util.UUID
import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import akka.actor.{Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Props, ReceiveTimeout, SupervisorStrategy} import akka.actor.{Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Props, ReceiveTimeout, SupervisorStrategy}
@ -55,6 +56,7 @@ class CheckElectrumSetup(datadir: File,
val config = system.settings.config.getConfig("eclair") val config = system.settings.config.getConfig("eclair")
val chain = config.getString("chain") val chain = config.getString("chain")
val keyManager = new LocalKeyManager(randomBytes(32), NodeParams.hashFromChain(chain)) val keyManager = new LocalKeyManager(randomBytes(32), NodeParams.hashFromChain(chain))
val instanceId = UUID.randomUUID()
val database = db match { val database = db match {
case Some(d) => d case Some(d) => d
case None => Databases.sqliteJDBC(new File(datadir, chain)) case None => Databases.sqliteJDBC(new File(datadir, chain))
@ -83,7 +85,7 @@ class CheckElectrumSetup(datadir: File,
override def getFeeratePerKw(target: Int): Long = feeratesPerKw.get().feePerBlock(target) override def getFeeratePerKw(target: Int): Long = feeratesPerKw.get().feePerBlock(target)
} }
val nodeParams = NodeParams.makeNodeParams(config, keyManager, None, database, blockCount, feeEstimator) val nodeParams = NodeParams.makeNodeParams(config, instanceId, keyManager, None, database, blockCount, feeEstimator)
logger.info(s"nodeid=${nodeParams.nodeId} alias=${nodeParams.alias}") logger.info(s"nodeid=${nodeParams.nodeId} alias=${nodeParams.alias}")
logger.info(s"using chain=$chain chainHash=${nodeParams.chainHash}") logger.info(s"using chain=$chain chainHash=${nodeParams.chainHash}")

View file

@ -45,7 +45,7 @@ import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag import scala.reflect.ClassTag
case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress]) case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, network: String, blockHeight: Int, publicAddresses: Seq[NodeAddress], instanceId: String)
case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], relayed: Seq[PaymentRelayed]) case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], relayed: Seq[PaymentRelayed])
@ -367,7 +367,8 @@ class EclairImpl(appKit: Kit) extends Eclair {
chainHash = appKit.nodeParams.chainHash, chainHash = appKit.nodeParams.chainHash,
network = NodeParams.chainFromHash(appKit.nodeParams.chainHash), network = NodeParams.chainFromHash(appKit.nodeParams.chainHash),
blockHeight = appKit.nodeParams.currentBlockHeight.toInt, blockHeight = appKit.nodeParams.currentBlockHeight.toInt,
publicAddresses = appKit.nodeParams.publicAddresses) publicAddresses = appKit.nodeParams.publicAddresses,
instanceId = appKit.nodeParams.instanceId.toString)
) )
override def usableBalances()(implicit timeout: Timeout): Future[Iterable[UsableBalance]] = override def usableBalances()(implicit timeout: Timeout): Future[Iterable[UsableBalance]] =

View file

@ -18,12 +18,12 @@ package fr.acinq.eclair
import java.io.File import java.io.File
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.sql.DriverManager import java.util.UUID
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import com.typesafe.config.{Config, ConfigFactory, ConfigValueType}
import com.google.common.io.Files import com.google.common.io.Files
import com.typesafe.config.{Config, ConfigFactory, ConfigValueType}
import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi} import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi}
import fr.acinq.eclair.NodeParams.WatcherType import fr.acinq.eclair.NodeParams.WatcherType
@ -36,14 +36,14 @@ import fr.acinq.eclair.tor.Socks5ProxyParams
import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
import scodec.bits.ByteVector import scodec.bits.ByteVector
import scala.concurrent.duration.FiniteDuration
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.concurrent.duration.FiniteDuration
/** /**
* Created by PM on 26/02/2017. * Created by PM on 26/02/2017.
*/ */
case class NodeParams(keyManager: KeyManager, case class NodeParams(keyManager: KeyManager,
instanceId: UUID, // a unique instance ID regenerated after each restart
private val blockCount: AtomicLong, private val blockCount: AtomicLong,
alias: String, alias: String,
color: Color, color: Color,
@ -134,7 +134,7 @@ object NodeParams {
def chainFromHash(chainHash: ByteVector32): String = chain2Hash.map(_.swap).getOrElse(chainHash, throw new RuntimeException(s"invalid chainHash '$chainHash'")) def chainFromHash(chainHash: ByteVector32): String = chain2Hash.map(_.swap).getOrElse(chainHash, throw new RuntimeException(s"invalid chainHash '$chainHash'"))
def makeNodeParams(config: Config, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = { def makeNodeParams(config: Config, instanceId: UUID, keyManager: KeyManager, torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator): NodeParams = {
// check configuration for keys that have been renamed // check configuration for keys that have been renamed
val deprecatedKeyPaths = Map( val deprecatedKeyPaths = Map(
// v0.3.2 // v0.3.2
@ -237,6 +237,7 @@ object NodeParams {
NodeParams( NodeParams(
keyManager = keyManager, keyManager = keyManager,
instanceId = instanceId,
blockCount = blockCount, blockCount = blockCount,
alias = nodeAlias, alias = nodeAlias,
color = Color(color(0), color(1), color(2)), color = Color(color(0), color(1), color(2)),

View file

@ -19,6 +19,7 @@ package fr.acinq.eclair
import java.io.File import java.io.File
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.sql.DriverManager import java.sql.DriverManager
import java.util.UUID
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
@ -36,8 +37,9 @@ import fr.acinq.eclair.blockchain.fee.{ConstantFeeProvider, _}
import fr.acinq.eclair.blockchain.{EclairWallet, _} import fr.acinq.eclair.blockchain.{EclairWallet, _}
import fr.acinq.eclair.channel.Register import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.crypto.LocalKeyManager
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.sqlite.SqliteFeeratesDb import fr.acinq.eclair.db.sqlite.SqliteFeeratesDb
import fr.acinq.eclair.db.{BackupHandler, Databases} import fr.acinq.eclair.db.{Databases, FileBackupHandler}
import fr.acinq.eclair.io.{ClientSpawner, Switchboard} import fr.acinq.eclair.io.{ClientSpawner, Switchboard}
import fr.acinq.eclair.payment.Auditor import fr.acinq.eclair.payment.Auditor
import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.receive.PaymentHandler
@ -79,11 +81,11 @@ class Setup(datadir: File,
val chain = config.getString("chain") val chain = config.getString("chain")
val chaindir = new File(datadir, chain) val chaindir = new File(datadir, chain)
val keyManager = new LocalKeyManager(seed, NodeParams.hashFromChain(chain)) val keyManager = new LocalKeyManager(seed, NodeParams.hashFromChain(chain))
val instanceId = UUID.randomUUID()
val database = db match { logger.info(s"instanceid=$instanceId")
case Some(d) => d
case None => Databases.sqliteJDBC(chaindir) val databases = Databases.init(config.getConfig("db"), instanceId, datadir, chaindir, db)
}
/** /**
* This counter holds the current blockchain height. * This counter holds the current blockchain height.
@ -111,7 +113,7 @@ class Setup(datadir: File,
// @formatter:on // @formatter:on
} }
val nodeParams = NodeParams.makeNodeParams(config, keyManager, None, database, blockCount, feeEstimator) val nodeParams = NodeParams.makeNodeParams(config, instanceId, keyManager, None, databases, blockCount, feeEstimator)
val serverBindingAddress = new InetSocketAddress( val serverBindingAddress = new InetSocketAddress(
config.getString("server.binding-ip"), config.getString("server.binding-ip"),
@ -228,12 +230,16 @@ class Setup(datadir: File,
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox // do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
backupHandler = if (config.getBoolean("enable-db-backup")) { backupHandler = if (config.getBoolean("enable-db-backup")) {
system.actorOf(SimpleSupervisor.props( nodeParams.db match {
BackupHandler.props( case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
nodeParams.db, FileBackupHandler.props(
fileBackup,
new File(chaindir, "eclair.sqlite.bak"), new File(chaindir, "eclair.sqlite.bak"),
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None),
), "backuphandler", SupervisorStrategy.Resume)) "backuphandler", SupervisorStrategy.Resume))
case _ =>
system.deadLetters
}
} else { } else {
logger.warn("database backup is disabled") logger.warn("database backup is disabled")
system.deadLetters system.deadLetters

View file

@ -17,6 +17,7 @@
package fr.acinq.eclair package fr.acinq.eclair
import java.io.File import java.io.File
import java.util.UUID
import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import akka.actor.{Actor, ActorLogging, ActorSystem, Props, ReceiveTimeout, SupervisorStrategy} import akka.actor.{Actor, ActorLogging, ActorSystem, Props, ReceiveTimeout, SupervisorStrategy}
@ -53,6 +54,7 @@ class SyncLiteSetup(datadir: File,
val config = system.settings.config.getConfig("eclair") val config = system.settings.config.getConfig("eclair")
val chain = config.getString("chain") val chain = config.getString("chain")
val keyManager = new LocalKeyManager(randomBytes32, NodeParams.hashFromChain(chain)) val keyManager = new LocalKeyManager(randomBytes32, NodeParams.hashFromChain(chain))
val instanceId = UUID.randomUUID()
val database = db match { val database = db match {
case Some(d) => d case Some(d) => d
case None => Databases.sqliteJDBC(new File(datadir, chain)) case None => Databases.sqliteJDBC(new File(datadir, chain))
@ -81,7 +83,7 @@ class SyncLiteSetup(datadir: File,
override def getFeeratePerKw(target: Int): Long = feeratesPerKw.get().feePerBlock(target) override def getFeeratePerKw(target: Int): Long = feeratesPerKw.get().feePerBlock(target)
} }
val nodeParams = NodeParams.makeNodeParams(config, keyManager, None, database, blockCount, feeEstimator) val nodeParams = NodeParams.makeNodeParams(config, instanceId, keyManager, None, database, blockCount, feeEstimator)
logger.info(s"nodeid=${nodeParams.nodeId} alias=${nodeParams.alias}") logger.info(s"nodeid=${nodeParams.nodeId} alias=${nodeParams.alias}")
logger.info(s"using chain=$chain chainHash=${nodeParams.chainHash}") logger.info(s"using chain=$chain chainHash=${nodeParams.chainHash}")

View file

@ -17,8 +17,12 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import java.io.File import java.io.File
import java.nio.file._
import java.sql.{Connection, DriverManager} import java.sql.{Connection, DriverManager}
import java.util.UUID
import akka.actor.ActorSystem
import com.typesafe.config.Config
import fr.acinq.eclair.db.sqlite._ import fr.acinq.eclair.db.sqlite._
import grizzled.slf4j.Logging import grizzled.slf4j.Logging
@ -35,12 +39,29 @@ trait Databases {
val payments: PaymentsDb val payments: PaymentsDb
val pendingRelay: PendingRelayDb val pendingRelay: PendingRelayDb
def backup(file: File): Unit
} }
object Databases extends Logging { object Databases extends Logging {
trait FileBackup { this: Databases =>
def backup(backupFile: File): Unit
}
trait ExclusiveLock { this: Databases =>
def obtainExclusiveLock(): Unit
}
def init(dbConfig: Config, instanceId: UUID, datadir: File, chaindir: File, db: Option[Databases] = None)(implicit system: ActorSystem): Databases = {
db match {
case Some(d) => d
case None =>
dbConfig.getString("driver") match {
case "sqlite" => Databases.sqliteJDBC(chaindir)
case driver => throw new RuntimeException(s"Unknown database driver `$driver`")
}
}
}
/** /**
* Given a parent folder it creates or loads all the databases from a JDBC connection * Given a parent folder it creates or loads all the databases from a JDBC connection
* *
@ -58,7 +79,7 @@ object Databases extends Logging {
sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}") sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}")
SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file
logger.info("successful lock on eclair.sqlite") logger.info("successful lock on eclair.sqlite")
databaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair) sqliteDatabaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
} catch { } catch {
case t: Throwable => { case t: Throwable => {
logger.error("could not create connection to sqlite databases: ", t) logger.error("could not create connection to sqlite databases: ", t)
@ -68,23 +89,40 @@ object Databases extends Logging {
throw t throw t
} }
} }
} }
def databaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection) = new Databases { def sqliteDatabaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection): Databases = new Databases with FileBackup {
override val network = new SqliteNetworkDb(networkJdbc) override val network = new SqliteNetworkDb(networkJdbc)
override val audit = new SqliteAuditDb(auditJdbc) override val audit = new SqliteAuditDb(auditJdbc)
override val channels = new SqliteChannelsDb(eclairJdbc) override val channels = new SqliteChannelsDb(eclairJdbc)
override val peers = new SqlitePeersDb(eclairJdbc) override val peers = new SqlitePeersDb(eclairJdbc)
override val payments = new SqlitePaymentsDb(eclairJdbc) override val payments = new SqlitePaymentsDb(eclairJdbc)
override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc) override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc)
override def backup(backupFile: File): Unit = {
override def backup(file: File): Unit = {
SqliteUtils.using(eclairJdbc.createStatement()) { SqliteUtils.using(eclairJdbc.createStatement()) {
statement => { statement => {
statement.executeUpdate(s"backup to ${file.getAbsolutePath}") statement.executeUpdate(s"backup to ${backupFile.getAbsolutePath}")
} }
} }
} }
} }
private def checkIfDatabaseUrlIsUnchanged(url: String, datadir: File ): Unit = {
val urlFile = new File(datadir, "last_jdbcurl")
def readString(path: Path): String = Files.readAllLines(path).get(0)
def writeString(path: Path, string: String): Unit = Files.write(path, java.util.Arrays.asList(string))
if (urlFile.exists()) {
val oldUrl = readString(urlFile.toPath)
if (oldUrl != url)
throw new RuntimeException(s"The database URL has changed since the last start. It was `$oldUrl`, now it's `$url`")
} else {
writeString(urlFile.toPath, url)
}
}
} }

View file

@ -21,6 +21,7 @@ import java.io.File
import akka.actor.{Actor, ActorLogging, Props} import akka.actor.{Actor, ActorLogging, Props}
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue} import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
import fr.acinq.eclair.channel.ChannelPersisted import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import scala.sys.process.Process import scala.sys.process.Process
import scala.util.{Failure, Success, Try} import scala.util.{Failure, Success, Try}
@ -45,7 +46,7 @@ import scala.util.{Failure, Success, Try}
* *
* Constructor is private so users will have to use BackupHandler.props() which always specific a custom mailbox * Constructor is private so users will have to use BackupHandler.props() which always specific a custom mailbox
*/ */
class BackupHandler private(databases: Databases, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging { class FileBackupHandler private(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) extends Actor with RequiresMessageQueue[BoundedMessageQueueSemantics] with ActorLogging {
// we listen to ChannelPersisted events, which will trigger a backup // we listen to ChannelPersisted events, which will trigger a backup
context.system.eventStream.subscribe(self, classOf[ChannelPersisted]) context.system.eventStream.subscribe(self, classOf[ChannelPersisted])
@ -55,6 +56,7 @@ class BackupHandler private(databases: Databases, backupFile: File, backupScript
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
val tmpFile = new File(backupFile.getAbsolutePath.concat(".tmp")) val tmpFile = new File(backupFile.getAbsolutePath.concat(".tmp"))
databases.backup(tmpFile) databases.backup(tmpFile)
// this will throw an exception if it fails, which is possible if the backup file is not on the same filesystem // this will throw an exception if it fails, which is possible if the backup file is not on the same filesystem
// as the temporary file // as the temporary file
// README: On Android we simply use renameTo because most Path methods are not available at our API level // README: On Android we simply use renameTo because most Path methods are not available at our API level
@ -83,8 +85,8 @@ sealed trait BackupEvent
// this notification is sent when we have completed our backup process (our backup file is ready to be used) // this notification is sent when we have completed our backup process (our backup file is ready to be used)
case object BackupCompleted extends BackupEvent case object BackupCompleted extends BackupEvent
object BackupHandler { object FileBackupHandler {
// using this method is the only way to create a BackupHandler actor // using this method is the only way to create a BackupHandler actor
// we make sure that it uses a custom bounded mailbox, and a custom pinned dispatcher (i.e our actor will have its own thread pool with 1 single thread) // we make sure that it uses a custom bounded mailbox, and a custom pinned dispatcher (i.e our actor will have its own thread pool with 1 single thread)
def props(databases: Databases, backupFile: File, backupScript_opt: Option[String]) = Props(new BackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher") def props(databases: FileBackup, backupFile: File, backupScript_opt: Option[String]) = Props(new FileBackupHandler(databases, backupFile, backupScript_opt)).withMailbox("eclair.backup-mailbox").withDispatcher("eclair.backup-dispatcher")
} }

View file

@ -0,0 +1,139 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.db.jdbc
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.MilliSatoshi
import javax.sql.DataSource
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import scala.collection.immutable.Queue
trait JdbcUtils {
def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
val connection = dataSource.getConnection()
try {
f(connection)
} finally {
connection.close()
}
}
/**
* This helper makes sure statements are correctly closed.
*
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
*/
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
val autoCommit = statement.getConnection.getAutoCommit
try {
if (inTransaction) statement.getConnection.setAutoCommit(false)
val res = block(statement)
if (inTransaction) statement.getConnection.commit()
res
} catch {
case t: Exception =>
if (inTransaction) statement.getConnection.rollback()
throw t
} finally {
if (inTransaction) statement.getConnection.setAutoCommit(autoCommit)
if (statement != null) statement.close()
}
}
/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
}
q
}
case class ExtendedResultSet(rs: ResultSet) {
def getByteVectorFromHex(columnLabel: String): ByteVector = {
val s = rs.getString(columnLabel).stripPrefix("\\x")
ByteVector.fromValidHex(s)
}
def getByteVector32FromHex(columnLabel: String): ByteVector32 = {
val s = rs.getString(columnLabel)
ByteVector32(ByteVector.fromValidHex(s))
}
def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
val s = rs.getString(columnLabel)
if (rs.wasNull()) None else {
Some(ByteVector32(ByteVector.fromValidHex(s)))
}
}
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
def getByteVectorNullable(columnLabel: String): ByteVector = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
}
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
val bytes = rs.getBytes(columnLabel)
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
}
def getStringNullable(columnLabel: String): Option[String] = {
val result = rs.getString(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getLongNullable(columnLabel: String): Option[Long] = {
val result = rs.getLong(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getUUIDNullable(label: String): Option[UUID] = {
val result = rs.getString(label)
if (rs.wasNull()) None else Some(UUID.fromString(result))
}
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
val result = rs.getLong(label)
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}
}
object ExtendedResultSet {
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
}
}
object JdbcUtils extends JdbcUtils

View file

@ -101,7 +101,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
} }
} }
def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = { override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = {
using(sqlite.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement => using(sqlite.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement =>
statement.setBytes(1, channelId.toArray) statement.setBytes(1, channelId.toArray)
statement.setLong(2, commitmentNumber) statement.setLong(2, commitmentNumber)
@ -111,7 +111,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging {
} }
} }
def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = { override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = {
using(sqlite.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => using(sqlite.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement =>
statement.setBytes(1, channelId.toArray) statement.setBytes(1, channelId.toArray)
statement.setLong(2, commitmentNumber) statement.setLong(2, commitmentNumber)

View file

@ -16,38 +16,11 @@
package fr.acinq.eclair.db.sqlite package fr.acinq.eclair.db.sqlite
import java.sql.{Connection, PreparedStatement, ResultSet, Statement} import java.sql.{Connection, Statement}
import java.util.UUID
import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.db.jdbc.JdbcUtils
import fr.acinq.eclair.MilliSatoshi
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import scala.collection.immutable.Queue object SqliteUtils extends JdbcUtils {
object SqliteUtils {
/**
* This helper makes sure statements are correctly closed.
*
* @param inTransaction if set to true, all updates in the block will be run in a transaction.
*/
def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
try {
if (inTransaction) statement.getConnection.setAutoCommit(false)
val res = block(statement)
if (inTransaction) statement.getConnection.commit()
res
} catch {
case t: Exception =>
if (inTransaction) statement.getConnection.rollback()
throw t
} finally {
if (inTransaction) statement.getConnection.setAutoCommit(true)
if (statement != null) statement.close()
}
}
/** /**
* Several logical databases (channels, network, peers) may be stored in the same physical sqlite database. * Several logical databases (channels, network, peers) may be stored in the same physical sqlite database.
@ -72,34 +45,6 @@ object SqliteUtils {
statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'") statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'")
} }
/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
}
q
}
/**
* This helper uses the proper way to set a nullable value.
* It is used on Android only
*
* @param statement
* @param parameterIndex
* @param value_opt
*/
def setNullableLong(statement: PreparedStatement, parameterIndex: Int, value_opt: Option[Long]) = {
value_opt match {
case Some(value) => statement.setLong(parameterIndex, value)
case None => statement.setNull(parameterIndex, java.sql.Types.INTEGER)
}
}
/** /**
* Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process * Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process
* accesses the database file (see https://www.sqlite.org/pragma.html). * accesses the database file (see https://www.sqlite.org/pragma.html).
@ -114,48 +59,4 @@ object SqliteUtils {
statement.executeUpdate("INSERT INTO dummy_table_for_locking VALUES (42)") statement.executeUpdate("INSERT INTO dummy_table_for_locking VALUES (42)")
} }
case class ExtendedResultSet(rs: ResultSet) {
def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))
def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))
def getByteVectorNullable(columnLabel: String): ByteVector = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
}
def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = {
val bytes = rs.getBytes(columnLabel)
if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes)))
}
def getStringNullable(columnLabel: String): Option[String] = {
val result = rs.getString(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getLongNullable(columnLabel: String): Option[Long] = {
val result = rs.getLong(columnLabel)
if (rs.wasNull()) None else Some(result)
}
def getUUIDNullable(label: String): Option[UUID] = {
val result = rs.getString(label)
if (rs.wasNull()) None else Some(UUID.fromString(result))
}
def getMilliSatoshiNullable(label: String): Option[MilliSatoshi] = {
val result = rs.getLong(label)
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}
}
object ExtendedResultSet {
implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs)
}
} }

View file

@ -46,7 +46,7 @@
<logger name="fr.acinq.eclair.router" level="WARN"/> <logger name="fr.acinq.eclair.router" level="WARN"/>
<logger name="fr.acinq.eclair.Diagnostics" level="OFF"/> <logger name="fr.acinq.eclair.Diagnostics" level="OFF"/>
<logger name="fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher" level="OFF"/> <logger name="fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher" level="OFF"/>
<logger name="fr.acinq.eclair.db.BackupHandler" level="OFF"/> <logger name="fr.acinq.eclair.db.FileBackupHandler" level="OFF"/>
<root level="INFO"> <root level="INFO">
<!--appender-ref ref="FILE"/> <!--appender-ref ref="FILE"/>

View file

@ -16,6 +16,7 @@
package fr.acinq.eclair package fr.acinq.eclair
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import com.typesafe.config.{Config, ConfigFactory} import com.typesafe.config.{Config, ConfigFactory}
@ -40,7 +41,7 @@ class StartupSpec extends AnyFunSuite {
val keyManager = new LocalKeyManager(seed = randomBytes32, chainHash = Block.TestnetGenesisBlock.hash) val keyManager = new LocalKeyManager(seed = randomBytes32, chainHash = Block.TestnetGenesisBlock.hash)
val feeEstimator = new TestConstants.TestFeeEstimator val feeEstimator = new TestConstants.TestFeeEstimator
val db = TestConstants.inMemoryDb() val db = TestConstants.inMemoryDb()
NodeParams.makeNodeParams(conf, keyManager, None, db, blockCount, feeEstimator) NodeParams.makeNodeParams(conf, UUID.fromString("01234567-0123-4567-89ab-0123456789ab"), keyManager, None, db, blockCount, feeEstimator)
} }
test("check configuration") { test("check configuration") {

View file

@ -16,7 +16,8 @@
package fr.acinq.eclair package fr.acinq.eclair
import java.sql.{Connection, DriverManager} import java.sql.{Connection, DriverManager, Statement}
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.Crypto.PrivateKey
@ -27,6 +28,7 @@ import fr.acinq.eclair.NodeParams.BITCOIND
import fr.acinq.eclair.blockchain.fee._ import fr.acinq.eclair.blockchain.fee._
import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.crypto.LocalKeyManager
import fr.acinq.eclair.db._ import fr.acinq.eclair.db._
import fr.acinq.eclair.db.sqlite._
import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer
import fr.acinq.eclair.router.Router.RouterConf import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
@ -57,9 +59,37 @@ object TestConstants {
} }
} }
def sqliteInMemory() = DriverManager.getConnection("jdbc:sqlite::memory:") sealed trait TestDatabases {
val connection: Connection
def network(): NetworkDb
def audit(): AuditDb
def channels(): ChannelsDb
def peers(): PeersDb
def payments(): PaymentsDb
def pendingRelay(): PendingRelayDb
def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int
def close(): Unit
}
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.databaseByConnections(connection, connection, connection) case class TestSqliteDatabases(connection: Connection = sqliteInMemory()) extends TestDatabases {
override def network(): NetworkDb = new SqliteNetworkDb(connection)
override def audit(): AuditDb = new SqliteAuditDb(connection)
override def channels(): ChannelsDb = new SqliteChannelsDb(connection)
override def peers(): PeersDb = new SqlitePeersDb(connection)
override def payments(): PaymentsDb = new SqlitePaymentsDb(connection)
override def pendingRelay(): PendingRelayDb = new SqlitePendingRelayDb(connection)
override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion)
override def close(): Unit = ()
}
def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:")
def forAllDbs(f: TestDatabases => Unit): Unit = {
def using(dbs: TestDatabases)(g: TestDatabases => Unit): Unit = try g(dbs) finally dbs.close()
using(TestSqliteDatabases())(f)
}
def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.sqliteDatabaseByConnections(connection, connection, connection)
object Alice { object Alice {
val seed = ByteVector32(ByteVector.fill(32)(1)) val seed = ByteVector32(ByteVector.fill(32)(1))
@ -139,7 +169,8 @@ object TestConstants {
), ),
socksProxy_opt = None, socksProxy_opt = None,
maxPaymentAttempts = 5, maxPaymentAttempts = 5,
enableTrampolinePayment = true enableTrampolinePayment = true,
instanceId = UUID.fromString("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
) )
def channelParams = Peer.makeChannelParams( def channelParams = Peer.makeChannelParams(
@ -225,7 +256,8 @@ object TestConstants {
), ),
socksProxy_opt = None, socksProxy_opt = None,
maxPaymentAttempts = 5, maxPaymentAttempts = 5,
enableTrampolinePayment = true enableTrampolinePayment = true,
instanceId = UUID.fromString("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
) )
def channelParams = Peer.makeChannelParams( def channelParams = Peer.makeChannelParams(

View file

@ -20,15 +20,15 @@ import java.io.File
import java.sql.DriverManager import java.sql.DriverManager
import java.util.UUID import java.util.UUID
import akka.actor.ActorSystem import akka.testkit.TestProbe
import akka.testkit.{TestKit, TestProbe}
import fr.acinq.eclair.channel.ChannelPersisted import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
import fr.acinq.eclair.wire.ChannelCodecsSpec import fr.acinq.eclair.wire.ChannelCodecsSpec
import fr.acinq.eclair.{TestConstants, TestKitBaseClass, TestUtils, randomBytes32} import fr.acinq.eclair.{TestConstants, TestKitBaseClass, TestUtils, randomBytes32}
import org.scalatest.funsuite.AnyFunSuiteLike import org.scalatest.funsuite.AnyFunSuiteLike
class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike { class FileBackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
test("process backups") { test("process backups") {
val db = TestConstants.inMemoryDb() val db = TestConstants.inMemoryDb()
@ -40,7 +40,7 @@ class BackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike {
db.channels.addOrUpdateChannel(channel) db.channels.addOrUpdateChannel(channel)
assert(db.channels.listLocalChannels() == Seq(channel)) assert(db.channels.listLocalChannels() == Seq(channel))
val handler = system.actorOf(BackupHandler.props(db, dest, None)) val handler = system.actorOf(FileBackupHandler.props(db.asInstanceOf[FileBackup], dest, None))
val probe = TestProbe() val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[BackupEvent]) system.eventStream.subscribe(probe.ref, classOf[BackupEvent])

View file

@ -20,10 +20,11 @@ import java.util.UUID
import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{ByteVector32, Transaction} import fr.acinq.bitcoin.{ByteVector32, Transaction}
import fr.acinq.eclair.TestConstants.{TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair._ import fr.acinq.eclair._
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError} import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
import fr.acinq.eclair.db.jdbc.JdbcUtils.using
import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.db.sqlite.SqliteAuditDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
import org.scalatest.Tag import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
@ -36,14 +37,15 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000") val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db1 = new SqliteAuditDb(sqlite) val db1 = dbs.audit()
val db2 = new SqliteAuditDb(sqlite) val db2 = dbs.audit()
}
} }
test("add/list events") { test("add/list events") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteAuditDb(sqlite) val db = dbs.audit()
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil) val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32) val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
@ -85,10 +87,11 @@ class SqliteAuditDbSpec extends AnyFunSuite {
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1) assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).size === 1)
assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual") assert(db.listNetworkFees(from = 0L, to = (System.currentTimeMillis.milliseconds + 1.minute).toMillis).head.txType === "mutual")
} }
}
test("stats") { test("stats") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteAuditDb(sqlite) val db = dbs.audit()
val n2 = randomKey.publicKey val n2 = randomKey.publicKey
val n3 = randomKey.publicKey val n3 = randomKey.publicKey
@ -130,10 +133,11 @@ class SqliteAuditDbSpec extends AnyFunSuite {
Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat) Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat)
)) ))
} }
}
ignore("relay stats performance", Tag("perf")) { ignore("relay stats performance", Tag("perf")) {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteAuditDb(sqlite) val db = dbs.audit()
val nodeCount = 100 val nodeCount = 100
val channelCount = 1000 val channelCount = 1000
val eventCount = 100000 val eventCount = 100000
@ -163,9 +167,13 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val end = System.currentTimeMillis val end = System.currentTimeMillis
fail(s"took ${end - start}ms") fail(s"took ${end - start}ms")
} }
}
test("handle migration version 1 -> 4") { test("handle migration version 1 -> 4") {
val connection = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db // simulate existing previous version db
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
@ -230,9 +238,13 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1) val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected) assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
} }
}
test("handle migration version 2 -> 4") { test("handle migration version 2 -> 4") {
val connection = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db // simulate existing previous version db
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
@ -275,9 +287,13 @@ class SqliteAuditDbSpec extends AnyFunSuite {
postMigrationDb.add(e2) postMigrationDb.add(e2)
} }
}
test("handle migration version 3 -> 4") { test("handle migration version 3 -> 4") {
val connection = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion
val connection = dbs.connection
// simulate existing previous version db // simulate existing previous version db
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
@ -363,15 +379,18 @@ class SqliteAuditDbSpec extends AnyFunSuite {
postMigrationDb.add(relayed3) postMigrationDb.add(relayed3)
assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3))
} }
}
test("ignore invalid values in the DB") { test("ignore invalid values in the DB") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteAuditDb(sqlite) val db = dbs.audit()
val sqlite = dbs.connection
val isPg = false
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement => using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, randomBytes32.toArray) if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 42) statement.setLong(2, 42)
statement.setBytes(3, randomBytes32.toArray) if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "IN") statement.setString(4, "IN")
statement.setString(5, "unknown") // invalid relay type statement.setString(5, "unknown") // invalid relay type
statement.setLong(6, 10) statement.setLong(6, 10)
@ -379,9 +398,9 @@ class SqliteAuditDbSpec extends AnyFunSuite {
} }
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement => using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, randomBytes32.toArray) if (isPg) statement.setString(1, randomBytes32.toHex) else statement.setBytes(1, randomBytes32.toArray)
statement.setLong(2, 51) statement.setLong(2, 51)
statement.setBytes(3, randomBytes32.toArray) if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray)
statement.setString(4, "UP") // invalid direction statement.setString(4, "UP") // invalid direction
statement.setString(5, "channel") statement.setString(5, "channel")
statement.setLong(6, 20) statement.setLong(6, 20)
@ -392,9 +411,9 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val channelId = randomBytes32 val channelId = randomBytes32
using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement => using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
statement.setBytes(1, paymentHash.toArray) if (isPg) statement.setString(1, paymentHash.toHex) else statement.setBytes(1, paymentHash.toArray)
statement.setLong(2, 65) statement.setLong(2, 65)
statement.setBytes(3, channelId.toArray) if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
statement.setString(4, "IN") // missing a corresponding OUT statement.setString(4, "IN") // missing a corresponding OUT
statement.setString(5, "channel") statement.setString(5, "channel")
statement.setLong(6, 30) statement.setLong(6, 30)
@ -403,5 +422,6 @@ class SqliteAuditDbSpec extends AnyFunSuite {
assert(db.listRelayed(0, 40) === Nil) assert(db.listRelayed(0, 40) === Nil)
} }
}
} }

View file

@ -16,28 +16,31 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import java.sql.SQLException
import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.CltvExpiry
import fr.acinq.eclair.TestConstants.{TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair.db.sqlite.SqliteChannelsDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using} import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
import fr.acinq.eclair.db.sqlite.{SqliteChannelsDb, SqlitePendingRelayDb}
import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec import fr.acinq.eclair.wire.ChannelCodecs.stateDataCodec
import fr.acinq.eclair.wire.ChannelCodecsSpec import fr.acinq.eclair.wire.ChannelCodecsSpec
import fr.acinq.eclair.{CltvExpiry, TestConstants}
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
import org.sqlite.SQLiteException
import scodec.bits.ByteVector import scodec.bits.ByteVector
class SqliteChannelsDbSpec extends AnyFunSuite { class SqliteChannelsDbSpec extends AnyFunSuite {
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db1 = new SqliteChannelsDb(sqlite) val db1 = dbs.channels()
val db2 = new SqliteChannelsDb(sqlite) val db2 = dbs.channels()
}
} }
test("add/remove/list channels") { test("add/remove/list channels") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteChannelsDb(sqlite) val db = dbs.channels()
new SqlitePendingRelayDb(sqlite) // needed by db.removeChannel dbs.pendingRelay() // needed by db.removeChannel
val channel = ChannelCodecsSpec.normal val channel = ChannelCodecsSpec.normal
@ -47,7 +50,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite {
val paymentHash2 = ByteVector32(ByteVector.fill(32)(1)) val paymentHash2 = ByteVector32(ByteVector.fill(32)(1))
val cltvExpiry2 = CltvExpiry(656) val cltvExpiry2 = CltvExpiry(656)
intercept[SQLiteException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel intercept[SQLException](db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)) // no related channel
assert(db.listLocalChannels().toSet === Set.empty) assert(db.listLocalChannels().toSet === Set.empty)
db.addOrUpdateChannel(channel) db.addOrUpdateChannel(channel)
@ -57,16 +60,19 @@ class SqliteChannelsDbSpec extends AnyFunSuite {
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil) assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1) db.addHtlcInfo(channel.channelId, commitNumber, paymentHash1, cltvExpiry1)
db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2) db.addHtlcInfo(channel.channelId, commitNumber, paymentHash2, cltvExpiry2)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == List((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2))) assert(db.listHtlcInfos(channel.channelId, commitNumber).toList.toSet == Set((paymentHash1, cltvExpiry1), (paymentHash2, cltvExpiry2)))
assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil) assert(db.listHtlcInfos(channel.channelId, 43).toList == Nil)
db.removeChannel(channel.channelId) db.removeChannel(channel.channelId)
assert(db.listLocalChannels() === Nil) assert(db.listLocalChannels() === Nil)
assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil) assert(db.listHtlcInfos(channel.channelId, commitNumber).toList == Nil)
} }
}
test("migrate channel database v1 -> v2") { test("migrate channel database v1 -> v2") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
val sqlite = dbs.connection
// create a v1 channels database // create a v1 channels database
using(sqlite.createStatement()) { statement => using(sqlite.createStatement()) { statement =>
@ -94,3 +100,4 @@ class SqliteChannelsDbSpec extends AnyFunSuite {
assert(db.listLocalChannels() === List(channel)) assert(db.listLocalChannels() === List(channel))
} }
} }
}

View file

@ -16,13 +16,11 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import java.sql.Connection
import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi} import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi}
import fr.acinq.eclair.FeatureSupport.Optional import fr.acinq.eclair.FeatureSupport.Optional
import fr.acinq.eclair.Features.VariableLengthOnion import fr.acinq.eclair.Features.VariableLengthOnion
import fr.acinq.eclair.db.sqlite.SqliteNetworkDb import fr.acinq.eclair.TestConstants.{TestDatabases, TestSqliteDatabases}
import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.db.sqlite.SqliteUtils._
import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.router.Router.PublicChannel
@ -34,19 +32,23 @@ import scala.collection.{SortedMap, mutable}
class SqliteNetworkDbSpec extends AnyFunSuite { class SqliteNetworkDbSpec extends AnyFunSuite {
import TestConstants.forAllDbs
val shortChannelIds = (42 to (5000 + 42)).map(i => ShortChannelId(i)) val shortChannelIds = (42 to (5000 + 42)).map(i => ShortChannelId(i))
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db1 = new SqliteNetworkDb(sqlite) val db1 = dbs.network()
val db2 = new SqliteNetworkDb(sqlite) val db2 = dbs.network()
}
} }
test("migration test 1->2") { test("migration test 1->2") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
using(sqlite.createStatement()) { statement => using(dbs.connection.createStatement()) { statement =>
getVersion(statement, "network", 1) // this will set version to 1 dbs.getVersion(statement, "network", 1) // this will set version to 1
statement.execute("PRAGMA foreign_keys = ON") statement.execute("PRAGMA foreign_keys = ON")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
@ -55,33 +57,34 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
} }
using(sqlite.createStatement()) { statement => using(dbs.connection.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 1) assert(dbs.getVersion(statement, "network", 2) == 1)
} }
// first round: this will trigger a migration // first round: this will trigger a migration
simpleTest(sqlite) simpleTest(dbs)
using(sqlite.createStatement()) { statement => using(dbs.connection.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 2) assert(dbs.getVersion(statement, "network", 2) == 2)
} }
using(sqlite.createStatement()) { statement => using(dbs.connection.createStatement()) { statement =>
statement.executeUpdate("DELETE FROM nodes") statement.executeUpdate("DELETE FROM nodes")
statement.executeUpdate("DELETE FROM channels") statement.executeUpdate("DELETE FROM channels")
} }
// second round: no migration // second round: no migration
simpleTest(sqlite) simpleTest(dbs)
using(sqlite.createStatement()) { statement => using(dbs.connection.createStatement()) { statement =>
assert(getVersion(statement, "network", 2) == 2) assert(dbs.getVersion(statement, "network", 2) == 2)
}
} }
} }
test("add/remove/list nodes") { test("add/remove/list nodes") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteNetworkDb(sqlite) val db = dbs.network()
val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty) val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty)
val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional)))) val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(Set(ActivatedFeature(VariableLengthOnion, Optional))))
@ -103,10 +106,11 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000))) assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000)))
} }
}
test("correctly handle txids that start with 0") { test("correctly handle txids that start with 0") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteNetworkDb(sqlite) val db = dbs.network()
val sig = ByteVector64.Zeroes val sig = ByteVector64.Zeroes
val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig)
val c_shrunk = shrink(c) val c_shrunk = shrink(c)
@ -114,13 +118,14 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
db.addChannel(c, txid, Satoshi(42)) db.addChannel(c, txid, Satoshi(42))
assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c_shrunk, txid, Satoshi(42), None, None, None))) assert(db.listChannels() === SortedMap(c.shortChannelId -> PublicChannel(c_shrunk, txid, Satoshi(42), None, None, None)))
} }
}
def shrink(c: ChannelAnnouncement) = c.copy(bitcoinKey1 = null, bitcoinKey2 = null, bitcoinSignature1 = null, bitcoinSignature2 = null, nodeSignature1 = null, nodeSignature2 = null, chainHash = null, features = null) def shrink(c: ChannelAnnouncement) = c.copy(bitcoinKey1 = null, bitcoinKey2 = null, bitcoinSignature1 = null, bitcoinSignature2 = null, nodeSignature1 = null, nodeSignature2 = null, chainHash = null, features = null)
def shrink(c: ChannelUpdate) = c.copy(signature = null, chainHash = null) def shrink(c: ChannelUpdate) = c.copy(signature = null, chainHash = null)
def simpleTest(sqlite: Connection) = { def simpleTest(dbs: TestDatabases) = {
val db = new SqliteNetworkDb(sqlite) val db = dbs.network()
def sig = Crypto.sign(randomBytes32, randomKey) def sig = Crypto.sign(randomBytes32, randomKey)
@ -185,28 +190,33 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
} }
test("add/remove/list channels and channel_updates") { test("add/remove/list channels and channel_updates") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
simpleTest(sqlite) simpleTest(dbs)
}
} }
test("creating a table that already exists but with different column types is ignored") { test("creating a table that already exists but with different column types is ignored") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
using(sqlite.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid STRING NOT NULL)") using(dbs.connection.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid VARCHAR NOT NULL)")
} }
// column type is STRING // column type is VARCHAR
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING") val rs = dbs.connection.getMetaData.getColumns(null, null, "test", null)
assert(rs.next())
assert(rs.getString("TYPE_NAME").toLowerCase == "varchar")
// insert and read back random values // insert and read back random values
val txids = for (_ <- 0 until 1000) yield randomBytes32 val txids = for (_ <- 0 until 1000) yield randomBytes32
txids.foreach { txid => txids.foreach { txid =>
using(sqlite.prepareStatement("INSERT OR IGNORE INTO test VALUES (?)")) { statement => using(dbs.connection.prepareStatement("INSERT INTO test VALUES (?)")) { statement =>
statement.setString(1, txid.toHex) statement.setString(1, txid.toHex)
statement.executeUpdate() statement.executeUpdate()
} }
} }
val check = using(sqlite.createStatement()) { statement => val check = using(dbs.connection.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT txid FROM test") val rs = statement.executeQuery("SELECT txid FROM test")
val q = new mutable.Queue[ByteVector32]() val q = new mutable.Queue[ByteVector32]()
while (rs.next()) { while (rs.next()) {
@ -218,17 +228,20 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
assert(txids.toSet == check.toSet) assert(txids.toSet == check.toSet)
using(sqlite.createStatement(), inTransaction = true) { statement => using(dbs.connection.createStatement(), inTransaction = true) { statement =>
statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)") statement.execute("CREATE TABLE IF NOT EXISTS test (txid TEXT NOT NULL)")
} }
// column type has not changed // column type has not changed
assert(sqlite.getMetaData.getColumns(null, null, "test", null).getString("TYPE_NAME") == "STRING") val rs1 = dbs.connection.getMetaData.getColumns(null, null, "test", null)
assert(rs1.next())
assert(rs1.getString("TYPE_NAME").toLowerCase == "varchar")
}
} }
test("remove many channels") { test("remove many channels") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteNetworkDb(sqlite) val db = dbs.network()
val sig = Crypto.sign(randomBytes32, randomKey) val sig = Crypto.sign(randomBytes32, randomKey)
val priv = randomKey val priv = randomKey
val pub = priv.publicKey val pub = priv.publicKey
@ -246,10 +259,11 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
db.removeChannels(toDelete) db.removeChannels(toDelete)
assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete)) assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete))
} }
}
test("prune many channels") { test("prune many channels") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqliteNetworkDb(sqlite) val db = dbs.network()
db.addToPruned(shortChannelIds) db.addToPruned(shortChannelIds)
shortChannelIds.foreach { id => assert(db.isPruned((id))) } shortChannelIds.foreach { id => assert(db.isPruned((id))) }
@ -257,3 +271,4 @@ class SqliteNetworkDbSpec extends AnyFunSuite {
assert(!db.isPruned(ShortChannelId(5))) assert(!db.isPruned(ShortChannelId(5)))
} }
} }
}

View file

@ -20,9 +20,9 @@ import java.util.UUID
import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.Crypto.PrivateKey
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} import fr.acinq.bitcoin.{Block, ByteVector32, Crypto}
import fr.acinq.eclair.TestConstants.{TestSqliteDatabases, forAllDbs}
import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb
import fr.acinq.eclair.db.sqlite.SqliteUtils._
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop} import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop}
import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer} import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer}
@ -36,13 +36,17 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
import SqlitePaymentsDbSpec._ import SqlitePaymentsDbSpec._
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db1 = new SqlitePaymentsDb(sqlite) val db1 = dbs.payments()
val db2 = new SqlitePaymentsDb(sqlite) val db2 = dbs.payments()
}
} }
test("handle version migration 1->4") { test("handle version migration 1->4") {
val connection = TestConstants.sqliteInMemory() import fr.acinq.eclair.db.sqlite.SqliteUtils._
forAllDbs {
case dbs: TestSqliteDatabases =>
val connection = dbs.connection
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 1) getVersion(statement, "payments", 1)
@ -93,9 +97,13 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1))
assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1))
} }
}
test("handle version migration 2->4") { test("handle version migration 2->4") {
val connection = TestConstants.sqliteInMemory() import fr.acinq.eclair.db.sqlite.SqliteUtils._
forAllDbs {
case dbs: TestSqliteDatabases =>
val connection = dbs.connection
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 2) getVersion(statement, "payments", 2)
@ -214,9 +222,13 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3)) assert(postMigrationDb.listIncomingPayments(1, System.currentTimeMillis) === Seq(pr1, pr2, pr3))
assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2)) assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2))
} }
}
test("handle version migration 3->4") { test("handle version migration 3->4") {
val connection = TestConstants.sqliteInMemory() forAllDbs {
case dbs: TestSqliteDatabases =>
import fr.acinq.eclair.db.sqlite.SqliteUtils._
val connection = dbs.connection
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
getVersion(statement, "payments", 3) getVersion(statement, "payments", 3)
@ -300,17 +312,12 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
using(connection.createStatement()) { statement => using(connection.createStatement()) { statement =>
assert(getVersion(statement, "payments", 4) == 4) // version still to 4 assert(getVersion(statement, "payments", 4) == 4) // version still to 4
} }
}
val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, randomBytes32, PaymentType.SwapOut, 50 msat, 100 msat, carol, 1100, Some(invoice1), OutgoingPaymentStatus.Pending)
postMigrationDb.addOutgoingPayment(ps4)
postMigrationDb.updateOutgoingPayment(PaymentSent(parentId, paymentHash1, preimage1, ps2.recipientAmount, ps2.recipientNodeId, Seq(PaymentSent.PartialPayment(id2, ps2.amount, 15 msat, randomBytes32, Some(Seq(hop_ab)), 1105))))
assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Seq(HopSummary(hop_ab)), 1105)), ps3, ps4))
} }
test("add/retrieve/update incoming payments") { test("add/retrieve/update incoming payments") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqlitePaymentsDb(sqlite) val db = dbs.payments()
// can't receive a payment without an invoice associated with it // can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat)) assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat))
@ -359,9 +366,11 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2)) assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2)) assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2))
} }
}
test("add/retrieve/update outgoing payments") { test("add/retrieve/update outgoing payments") {
val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) forAllDbs { dbs =>
val db = dbs.payments()
val parentId = UUID.randomUUID() val parentId = UUID.randomUUID()
val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0) val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0)
@ -386,7 +395,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil) assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil)
val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300) val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300)
val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 300) val s4 = s2.copy(id = UUID.randomUUID(), paymentType = PaymentType.Standard, createdAt = 301)
db.addOutgoingPayment(s3) db.addOutgoingPayment(s3)
db.addOutgoingPayment(s4) db.addOutgoingPayment(s4)
@ -414,6 +423,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite {
// can't update again once it's in a final state // can't update again once it's in a final state
assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil))) assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil)))
} }
}
test("high level payments overview") { test("high level payments overview") {
val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory())

View file

@ -16,28 +16,26 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import java.sql.DriverManager
import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.db.sqlite.SqlitePeersDb
import fr.acinq.eclair.randomKey
import fr.acinq.eclair.wire.{NodeAddress, Tor2, Tor3} import fr.acinq.eclair.wire.{NodeAddress, Tor2, Tor3}
import fr.acinq.eclair.{TestConstants, randomKey}
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
class SqlitePeersDbSpec extends AnyFunSuite { class SqlitePeersDbSpec extends AnyFunSuite {
def inmem = DriverManager.getConnection("jdbc:sqlite::memory:") import TestConstants.forAllDbs
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = inmem forAllDbs { dbs =>
val db1 = new SqlitePeersDb(sqlite) val db1 = dbs.peers()
val db2 = new SqlitePeersDb(sqlite) val db2 = dbs.peers()
}
} }
test("add/remove/get/list peers") { test("add/remove/list peers") {
val sqlite = inmem forAllDbs { dbs =>
val db = new SqlitePeersDb(sqlite) val db = dbs.peers()
case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress) case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress)
@ -61,5 +59,6 @@ class SqlitePeersDbSpec extends AnyFunSuite {
assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress)) assert(db.getPeer(peer_1.nodeId) === Some(peer_1_bis.nodeAddress))
assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3)) assert(db.listPeers().map(p => TestCase(p._1, p._2)).toSet === Set(peer_1_bis, peer_3))
} }
}
} }

View file

@ -17,23 +17,25 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC}
import fr.acinq.eclair.db.sqlite.SqlitePendingRelayDb
import fr.acinq.eclair.{TestConstants, randomBytes32}
import fr.acinq.eclair.wire.FailureMessageCodecs import fr.acinq.eclair.wire.FailureMessageCodecs
import fr.acinq.eclair.{TestConstants, randomBytes32}
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
class SqlitePendingRelayDbSpec extends AnyFunSuite { class SqlitePendingRelayDbSpec extends AnyFunSuite {
import TestConstants.forAllDbs
test("init sqlite 2 times in a row") { test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db1 = new SqlitePendingRelayDb(sqlite) val db1 = dbs.pendingRelay()
val db2 = new SqlitePendingRelayDb(sqlite) val db2 = dbs.pendingRelay()
}
} }
test("add/remove/list messages") { test("add/remove/list messages") {
val sqlite = TestConstants.sqliteInMemory() forAllDbs { dbs =>
val db = new SqlitePendingRelayDb(sqlite) val db = dbs.pendingRelay()
val channelId1 = randomBytes32 val channelId1 = randomBytes32
val channelId2 = randomBytes32 val channelId2 = randomBytes32
@ -58,5 +60,6 @@ class SqlitePendingRelayDbSpec extends AnyFunSuite {
db.removePendingRelay(channelId1, msg1.id) db.removePendingRelay(channelId1, msg1.id)
assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id))) assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id)))
} }
}
} }

View file

@ -16,10 +16,11 @@
package fr.acinq.eclair.db package fr.acinq.eclair.db
import java.sql.SQLException
import fr.acinq.eclair.TestConstants import fr.acinq.eclair.TestConstants
import fr.acinq.eclair.db.sqlite.SqliteUtils.using import fr.acinq.eclair.db.sqlite.SqliteUtils.using
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
import org.sqlite.SQLiteException
class SqliteUtilsSpec extends AnyFunSuite { class SqliteUtilsSpec extends AnyFunSuite {
@ -41,7 +42,7 @@ class SqliteUtilsSpec extends AnyFunSuite {
assert(!results.next()) assert(!results.next())
} }
assertThrows[SQLiteException](using(conn.createStatement(), inTransaction = true) { statement => assertThrows[SQLException](using(conn.createStatement(), inTransaction = true) { statement =>
statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)")
statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation) statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation)
}) })

View file

@ -1 +1 @@
{"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"]} {"version":"1.0.0-SNAPSHOT-e3f1ec0","nodeId":"03af0ed6052cf28d670665549bc86f4b721c9fdb309d40c58f5811f63966e005d0","alias":"alice","color":"#000102","features":{"activated":[{"name":"option_data_loss_protect","support":"mandatory"},{"name":"gossip_queries_ex","support":"optional"}],"unknown":[]},"chainHash":"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f","network":"regtest","blockHeight":9999,"publicAddresses":["localhost:9731"],"instanceId":"01234567-0123-4567-89ab-0123456789ab"}

View file

@ -181,7 +181,8 @@ class ApiServiceSpec extends AnyFunSuiteLike with ScalatestRouteTest with RouteT
chainHash = ByteVector32(hex"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"), chainHash = ByteVector32(hex"06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"),
network = "regtest", network = "regtest",
blockHeight = 9999, blockHeight = 9999,
publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil,
instanceId = "01234567-0123-4567-89ab-0123456789ab"
)) ))
Post("/getinfo") ~> Post("/getinfo") ~>