Make DLCWallet.listDLCs use DBIOActions (#4555)

* Make DLCWallet.listDLCs use DBIOActions

* Add comment

Co-authored-by: Chris Stewart <stewart.chris1234@gmail.com>
This commit is contained in:
benthecarman 2022-07-31 14:09:23 -05:00 committed by GitHub
parent 936a65edd4
commit 603b7e0aea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 244 additions and 144 deletions

View File

@ -36,12 +36,13 @@ import org.bitcoins.dlc.wallet.util.{
DLCAcceptUtil,
DLCActionBuilder,
DLCStatusBuilder,
DLCTxUtil
DLCTxUtil,
IntermediaryDLCStatus
}
import org.bitcoins.wallet.config.WalletAppConfig
import org.bitcoins.wallet.{Wallet, WalletLogger}
import scodec.bits.ByteVector
import slick.dbio.{DBIO, DBIOAction}
import slick.dbio._
import java.net.InetSocketAddress
import scala.concurrent.Future
@ -55,6 +56,8 @@ abstract class DLCWallet
implicit val dlcConfig: DLCAppConfig
import dlcConfig.profile.api._
private[bitcoins] val announcementDAO: OracleAnnouncementDataDAO =
OracleAnnouncementDataDAO()
private[bitcoins] val oracleNonceDAO: OracleNonceDAO = OracleNonceDAO()
@ -101,6 +104,7 @@ abstract class DLCWallet
}
private lazy val safeDatabase: SafeDatabase = dlcDAO.safeDatabase
private lazy val walletDatabase: SafeDatabase = addressDAO.safeDatabase
/** Updates the contract Id in the wallet database for the given offer and accept */
private def updateDLCContractIds(
@ -1683,16 +1687,10 @@ abstract class DLCWallet
s"Created DLC refund transaction ${refundTx.txIdBE.hex} for contract ${contractId.toHex}")
_ <- updateDLCState(contractId, DLCState.Refunded)
updatedDlcDb <- updateClosingTxId(contractId, refundTx.txIdBE)
_ <- updateClosingTxId(contractId, refundTx.txIdBE)
_ <- processTransaction(refundTx, blockHashOpt = None)
closingTxOpt <- getClosingTxOpt(updatedDlcDb)
dlcAcceptOpt <- dlcAcceptDAO.findByDLCId(updatedDlcDb.dlcId)
status <- buildDLCStatus(updatedDlcDb,
contractData,
offerDbOpt.get,
dlcAcceptOpt,
closingTxOpt)
status <- findDLC(dlcDb.dlcId)
_ <- dlcConfig.walletCallbacks.executeOnDLCStateChange(logger, status.get)
} yield refundTx
}
@ -1720,29 +1718,40 @@ abstract class DLCWallet
private def listDLCs(
contactIdOpt: Option[InetSocketAddress]): Future[Vector[DLCStatus]] = {
for {
val dlcAction = for {
dlcs <- contactIdOpt match {
case Some(contactId) =>
dlcDAO.findByContactId(
dlcDAO.findByContactIdAction(
contactId.getHostString + ":" + contactId.getPort)
case None => dlcDAO.findAll()
case None => dlcDAO.findAllAction()
}
ids = dlcs.map(_.dlcId)
dlcFs = ids.map(findDLC)
dlcs <- Future.sequence(dlcFs)
dlcFs = ids.map(findDLCAction)
dlcs <- DBIO.sequence(dlcFs)
} yield {
dlcs.collect { case Some(dlc) =>
dlc
}
}
safeDatabase.run(dlcAction).flatMap { intermediaries =>
val actions = intermediaries.map { intermediary =>
getWalletDLCDbsAction(intermediary).map {
case (closingTxOpt, payoutAddrOpt) =>
intermediary.complete(payoutAddrOpt, closingTxOpt)
}
}
walletDatabase.run(DBIO.sequence(actions))
}
}
private def getClosingTxOpt(dlcDb: DLCDb): Future[Option[TransactionDb]] = {
val result =
dlcDb.closingTxIdOpt.map(txid => transactionDAO.findByTxId(txid))
result match {
case None => Future.successful(None)
case Some(r) => r
private def getClosingTxOptAction(dlcDb: DLCDb): DBIOAction[
Option[TransactionDb],
NoStream,
Effect.Read] = {
dlcDb.closingTxIdOpt match {
case None => DBIOAction.successful(None)
case Some(txid) => transactionDAO.findByTxIdAction(txid)
}
}
@ -1757,7 +1766,7 @@ abstract class DLCWallet
dlcDbOpt <- dlcDAO.findByTempContractId(tempContractId)
dlcStatusOpt <- dlcDbOpt match {
case None => Future.successful(None)
case Some(dlcDb) => findDLCStatus(dlcDb)
case Some(dlcDb) => findDLC(dlcDb.dlcId)
}
} yield dlcStatusOpt
@ -1769,66 +1778,88 @@ abstract class DLCWallet
}
override def findDLC(dlcId: Sha256Digest): Future[Option[DLCStatus]] = {
val intermediaryF = safeDatabase.run(findDLCAction(dlcId))
intermediaryF.flatMap {
case None => Future.successful(None)
case Some(intermediary) =>
val action = getWalletDLCDbsAction(intermediary)
walletDatabase.run(action).map { case (closingTxOpt, payoutAddress) =>
val res = intermediary.complete(payoutAddress, closingTxOpt)
Some(res)
}
}
}
private def getWalletDLCDbsAction(intermediary: IntermediaryDLCStatus) = {
val dlcDb = intermediary.dlcDb
for {
closingTxOpt <- getClosingTxOptAction(dlcDb)
payoutAddress <- getPayoutAddressAction(dlcDb,
intermediary.offerDb,
intermediary.acceptDbOpt)
} yield (closingTxOpt, payoutAddress)
}
private def findDLCAction(dlcId: Sha256Digest): DBIOAction[
Option[IntermediaryDLCStatus],
NoStream,
Effect.Read] = {
val start = System.currentTimeMillis()
val dlcOptF = for {
dlcDbOpt <- dlcDAO.read(dlcId)
val dlcOptA = for {
dlcDbOpt <- dlcDAO.findByPrimaryKeyAction(dlcId)
dlcStatusOpt <- dlcDbOpt match {
case None => Future.successful(None)
case Some(dlcDb) => findDLCStatus(dlcDb)
case None => DBIO.successful(None)
case Some(dlcDb) => findDLCStatusAction(dlcDb)
}
} yield dlcStatusOpt
dlcOptF.foreach(_ =>
dlcOptA.map { res =>
logger.debug(
s"Done finding dlc=$dlcId, it took=${System.currentTimeMillis() - start}ms"))
dlcOptF
s"Done finding dlc=$dlcId, it took=${System.currentTimeMillis() - start}ms")
res
}
}
private def findDLCStatus(dlcDb: DLCDb): Future[Option[DLCStatus]] = {
private def findDLCStatusAction(dlcDb: DLCDb): DBIOAction[
Option[IntermediaryDLCStatus],
NoStream,
Effect.Read] = {
val dlcId = dlcDb.dlcId
val contractDataOptF = contractDataDAO.read(dlcId)
val offerDbOptF = dlcOfferDAO.read(dlcId)
val acceptDbOptF = dlcAcceptDAO.read(dlcId)
val closingTxOptF: Future[Option[TransactionDb]] = getClosingTxOpt(dlcDb)
val contractDataOptA = contractDataDAO.findByPrimaryKeyAction(dlcId)
val offerDbOptA = dlcOfferDAO.findByPrimaryKeyAction(dlcId)
val acceptDbOptA = dlcAcceptDAO.findByPrimaryKeyAction(dlcId)
val dlcOptF: Future[Option[DLCStatus]] = for {
contractDataOpt <- contractDataOptF
offerDbOpt <- offerDbOptF
acceptDbOpt <- acceptDbOptF
closingTxOpt <- closingTxOptF
for {
contractDataOpt <- contractDataOptA
offerDbOpt <- offerDbOptA
acceptDbOpt <- acceptDbOptA
result <- {
(contractDataOpt, offerDbOpt) match {
case (Some(contractData), Some(offerDb)) =>
buildDLCStatus(dlcDb,
contractData,
offerDb,
acceptDbOpt,
closingTxOpt)
case (_, _) => Future.successful(None)
buildDLCStatusAction(dlcDb, contractData, offerDb, acceptDbOpt)
case (_, _) => DBIO.successful(None)
}
}
} yield result
dlcOptF
}
/** Helper method to assemble a [[DLCStatus]] */
private def buildDLCStatus(
private def buildDLCStatusAction(
dlcDb: DLCDb,
contractData: DLCContractDataDb,
offerDb: DLCOfferDb,
acceptDbOpt: Option[DLCAcceptDb],
closingTxOpt: Option[TransactionDb]): Future[Option[DLCStatus]] = {
acceptDbOpt: Option[DLCAcceptDb]): DBIOAction[
Option[IntermediaryDLCStatus],
NoStream,
Effect.Read] = {
val dlcId = dlcDb.dlcId
val aggregatedF: Future[(
Vector[DLCAnnouncementDb],
Vector[OracleAnnouncementDataDb],
Vector[OracleNonceDb])] =
dlcDataManagement.getDLCAnnouncementDbs(dlcId)
val aggregatedA =
dlcDataManagement.getDLCAnnouncementDbsAction(dlcId)
val contractInfoAndAnnouncementsF: Future[
(ContractInfo, Vector[(OracleAnnouncementV0TLV, Long)])] = {
aggregatedF.map { case (announcements, announcementData, nonceDbs) =>
val contractInfoAndAnnouncementsA = {
aggregatedA.map { case (announcements, announcementData, nonceDbs) =>
val contractInfo = dlcDataManagement.getContractInfo(contractData,
announcements,
announcementData,
@ -1841,67 +1872,38 @@ abstract class DLCWallet
}
}
val statusF: Future[DLCStatus] = for {
(contractInfo, announcementsWithId) <- contractInfoAndAnnouncementsF
(announcementIds, _, nonceDbs) <- aggregatedF
payoutAddress <- getPayoutAddress(dlcDb, offerDb, acceptDbOpt)
status <- {
dlcDb.state match {
case _: DLCState.InProgressState =>
val inProgress = DLCStatusBuilder.buildInProgressDLCStatus(
dlcDb = dlcDb,
contractInfo = contractInfo,
contractData = contractData,
offerDb = offerDb,
payoutAddress = payoutAddress)
Future.successful(inProgress)
case _: DLCState.ClosedState =>
(acceptDbOpt, closingTxOpt) match {
case (Some(acceptDb), Some(closingTx)) =>
val status = DLCStatusBuilder.buildClosedDLCStatus(
dlcDb = dlcDb,
contractInfo = contractInfo,
contractData = contractData,
announcementsWithId = announcementsWithId,
announcementIds = announcementIds,
nonceDbs = nonceDbs,
offerDb = offerDb,
acceptDb = acceptDb,
closingTx = closingTx.transaction,
payoutAddress = payoutAddress
)
Future.successful(status)
case (None, None) =>
Future.failed(new RuntimeException(
s"Could not find acceptDb or closingTx for closing state=${dlcDb.state} dlcId=$dlcId"))
case (Some(_), None) =>
Future.failed(new RuntimeException(
s"Could not find closingTx for state=${dlcDb.state} dlcId=$dlcId"))
case (None, Some(_)) =>
Future.failed(new RuntimeException(
s"Cannot find acceptDb for dlcId=$dlcId. This likely means we have data corruption"))
}
}
}
} yield status
val statusA = for {
(contractInfo, announcementsWithId) <- contractInfoAndAnnouncementsA
(announcementIds, _, nonceDbs) <- aggregatedA
} yield IntermediaryDLCStatus(dlcDb,
contractInfo,
contractData,
offerDb,
acceptDbOpt,
nonceDbs,
announcementsWithId,
announcementIds)
statusF.map(Some(_))
statusA.map(Some(_))
}
private def getPayoutAddress(
private def getPayoutAddressAction(
dlcDb: DLCDb,
offerDb: DLCOfferDb,
acceptDbOpt: Option[DLCAcceptDb]): Future[Option[PayoutAddress]] = {
acceptDbOpt: Option[DLCAcceptDb]): DBIOAction[
Option[PayoutAddress],
NoStream,
Effect.Read] = {
val addressOpt = if (dlcDb.isInitiator) {
Some(offerDb.payoutAddress)
} else {
acceptDbOpt.map(_.payoutAddress)
}
addressOpt match {
case None => Future.successful(None)
case None => DBIOAction.successful(None)
case Some(address) =>
for {
isExternal <- addressDAO.findAddress(address).map(_.isEmpty)
isExternal <- addressDAO.findAddressAction(address).map(_.isEmpty)
} yield Some(PayoutAddress(address, isExternal))
}
}

View File

@ -22,7 +22,7 @@ import org.bitcoins.dlc.wallet.util.DLCActionBuilder
import org.bitcoins.keymanager.bip39.BIP39KeyManager
import org.bitcoins.wallet.models.TransactionDAO
import scodec.bits._
import slick.dbio.{DBIOAction, Effect, NoStream}
import slick.dbio._
import scala.concurrent._
import scala.util.Try
@ -55,29 +55,42 @@ case class DLCDataManagement(dlcWalletDAOs: DLCWalletDAOs)(implicit
dataF.map(data => data.map(_.offer))
}
private[wallet] def getDLCAnnouncementDbs(dlcId: Sha256Digest): Future[(
Vector[DLCAnnouncementDb],
Vector[OracleAnnouncementDataDb],
Vector[OracleNonceDb])] = {
val announcementsF = dlcAnnouncementDAO.findByDLCId(dlcId)
val announcementIdsF: Future[Vector[Long]] = for {
announcements <- announcementsF
announcementIds = announcements.map(_.announcementId)
} yield announcementIds
val announcementDataF =
announcementIdsF.flatMap(ids => announcementDAO.findByIds(ids))
val noncesDbF =
announcementIdsF.flatMap(ids => oracleNonceDAO.findByAnnouncementIds(ids))
private[wallet] def getDLCAnnouncementDbsAction(
dlcId: Sha256Digest): DBIOAction[
(
Vector[DLCAnnouncementDb],
Vector[OracleAnnouncementDataDb],
Vector[OracleNonceDb]
),
NoStream,
Effect.Read] = {
val announcementsA = dlcAnnouncementDAO
.findByDLCIdAction(dlcId)
val announcementIdsA = announcementsA
.map(_.map(_.announcementId))
val announcementDataA =
announcementIdsA.flatMap(ids => announcementDAO.findByIdsAction(ids))
val noncesDbA =
announcementIdsA.flatMap(ids =>
oracleNonceDAO.findByAnnouncementIdsAction(ids))
for {
announcements <- announcementsF
announcementData <- announcementDataF
nonceDbs <- noncesDbF
announcements <- announcementsA
announcementData <- announcementDataA
nonceDbs <- noncesDbA
} yield {
(announcements, announcementData, nonceDbs)
}
}
private[wallet] def getDLCAnnouncementDbs(dlcId: Sha256Digest)(implicit
ec: ExecutionContext): Future[(
Vector[DLCAnnouncementDb],
Vector[OracleAnnouncementDataDb],
Vector[OracleNonceDb])] = {
safeDatabase.run(getDLCAnnouncementDbsAction(dlcId))
}
/** Fetches the oracle announcements of the oracles
* that were used for execution in a DLC
*/

View File

@ -124,9 +124,13 @@ case class DLCDAO()(implicit
}
def findByContactId(contactId: String): Future[Vector[DLCDb]] = {
safeDatabase.run(findByContactIdAction(contactId))
}
def findByContactIdAction(
contactId: String): DBIOAction[Vector[DLCDb], NoStream, Effect.Read] = {
val peer: Option[String] = Some(contactId)
val action = table.filter(_.peerOpt === peer).result
safeDatabase.runVec(action)
table.filter(_.peerOpt === peer).result.map(_.toVector)
}
def updateDLCContactMapping(

View File

@ -37,9 +37,14 @@ case class OracleAnnouncementDataDAO()(implicit
}
def findByIds(ids: Vector[Long]): Future[Vector[OracleAnnouncementDataDb]] = {
val query = table.filter(_.id.inSet(ids))
safeDatabase.run(findByIdsAction(ids))
}
safeDatabase.runVec(query.result)
def findByIdsAction(ids: Vector[Long]): DBIOAction[
Vector[OracleAnnouncementDataDb],
NoStream,
Effect.Read] = {
table.filter(_.id.inSet(ids)).result.map(_.toVector)
}
def findById(id: Long): Future[Option[OracleAnnouncementDataDb]] = {

View File

@ -89,9 +89,14 @@ case class OracleNonceDAO()(implicit
def findByAnnouncementIds(
ids: Vector[Long]): Future[Vector[OracleNonceDb]] = {
val query = table.filter(_.announcementId.inSet(ids))
safeDatabase.run(findByAnnouncementIdsAction(ids))
}
safeDatabase.runVec(query.result)
def findByAnnouncementIdsAction(ids: Vector[Long]): DBIOAction[
Vector[OracleNonceDb],
NoStream,
Effect.Read] = {
table.filter(_.announcementId.inSet(ids)).result.map(_.toVector)
}
class OracleNoncesTable(tag: Tag)

View File

@ -1,6 +1,7 @@
package org.bitcoins.dlc.wallet.util
import org.bitcoins.core.api.dlc.wallet.db.DLCDb
import org.bitcoins.core.api.wallet.db.TransactionDb
import org.bitcoins.core.dlc.accounting.DLCAccounting
import org.bitcoins.core.protocol.dlc.models.DLCStatus._
import org.bitcoins.core.protocol.dlc.models._
@ -10,6 +11,63 @@ import org.bitcoins.crypto.SchnorrDigitalSignature
import org.bitcoins.dlc.wallet.accounting.{AccountingUtil, DLCAccountingDbs}
import org.bitcoins.dlc.wallet.models._
/** Creates a case class that represents all DLC data from dlcdb.sqlite
* Unfortunately we have to read some data from walletdb.sqlite to build a
* full [[DLCStatus]]
* @see https://github.com/bitcoin-s/bitcoin-s/pull/4555#issuecomment-1200113188
*/
case class IntermediaryDLCStatus(
dlcDb: DLCDb,
contractInfo: ContractInfo,
contractData: DLCContractDataDb,
offerDb: DLCOfferDb,
acceptDbOpt: Option[DLCAcceptDb],
nonceDbs: Vector[OracleNonceDb],
announcementsWithId: Vector[(OracleAnnouncementV0TLV, Long)],
announcementIds: Vector[DLCAnnouncementDb]
) {
def complete(
payoutAddressOpt: Option[PayoutAddress],
closingTxOpt: Option[TransactionDb]): DLCStatus = {
val dlcId = dlcDb.dlcId
dlcDb.state match {
case _: DLCState.InProgressState =>
DLCStatusBuilder.buildInProgressDLCStatus(dlcDb = dlcDb,
contractInfo = contractInfo,
contractData = contractData,
offerDb = offerDb,
payoutAddress =
payoutAddressOpt)
case _: DLCState.ClosedState =>
(acceptDbOpt, closingTxOpt) match {
case (Some(acceptDb), Some(closingTx)) =>
DLCStatusBuilder.buildClosedDLCStatus(
dlcDb = dlcDb,
contractInfo = contractInfo,
contractData = contractData,
announcementsWithId = announcementsWithId,
announcementIds = announcementIds,
nonceDbs = nonceDbs,
offerDb = offerDb,
acceptDb = acceptDb,
closingTx = closingTx.transaction,
payoutAddress = payoutAddressOpt
)
case (None, None) =>
throw new RuntimeException(
s"Could not find acceptDb or closingTx for closing state=${dlcDb.state} dlcId=$dlcId")
case (Some(_), None) =>
throw new RuntimeException(
s"Could not find closingTx for state=${dlcDb.state} dlcId=$dlcId")
case (None, Some(_)) =>
throw new RuntimeException(
s"Cannot find acceptDb for dlcId=$dlcId. This likely means we have data corruption")
}
}
}
}
object DLCStatusBuilder {
/** Helper method to convert a bunch of indepdendent datastructures into a in progress dlc status */
@ -21,7 +79,7 @@ object DLCStatusBuilder {
payoutAddress: Option[PayoutAddress]): DLCStatus = {
require(
dlcDb.state.isInstanceOf[DLCState.InProgressState],
s"Cannot have divergent states beteween dlcDb and the parameter state, got= dlcDb.state=${dlcDb.state} state=${dlcDb.state}"
s"Cannot have divergent states between dlcDb and the parameter state, got= dlcDb.state=${dlcDb.state} state=${dlcDb.state}"
)
val dlcId = dlcDb.dlcId

View File

@ -165,12 +165,18 @@ case class AddressDAO()(implicit
}
def findAddress(addr: BitcoinAddress): Future[Option[AddressDb]] = {
val query = table
safeDatabase.run(findAddressAction(addr))
}
def findAddressAction(addr: BitcoinAddress): DBIOAction[
Option[AddressDb],
NoStream,
Effect.Read] = {
table
.join(spkTable)
.on(_.scriptPubKeyId === _.id)
.filter(_._1.address === addr)
safeDatabase
.run(query.result)
.result
.map(_.headOption)
.map(res =>
res.map { case (addrRec, spkRec) =>

View File

@ -70,20 +70,27 @@ trait TxDAO[DbEntryType <: TxDB]
safeDatabase.runVec(q.result)
}
def findByTxId(txIdBE: DoubleSha256DigestBE): Future[Option[DbEntryType]] = {
val q = table
def findByTxIdAction(txIdBE: DoubleSha256DigestBE): DBIOAction[
Option[DbEntryType],
NoStream,
Effect.Read] = {
table
.filter(_.txIdBE === txIdBE)
.result
.map {
case h +: Vector() =>
Some(h)
case Vector() =>
None
case txs: Vector[DbEntryType] =>
// yikes, we should not have more the one transaction per id
throw new RuntimeException(
s"More than one transaction per id=${txIdBE.hex}, got=$txs")
}
}
safeDatabase.run(q.result).map {
case h +: Vector() =>
Some(h)
case Vector() =>
None
case txs: Vector[DbEntryType] =>
// yikes, we should not have more the one transaction per id
throw new RuntimeException(
s"More than one transaction per id=${txIdBE.hex}, got=$txs")
}
def findByTxId(txIdBE: DoubleSha256DigestBE): Future[Option[DbEntryType]] = {
safeDatabase.run(findByTxIdAction(txIdBE))
}
def findByTxId(txId: DoubleSha256Digest): Future[Option[DbEntryType]] =