Add CRUDAction.{updateAllAction, updateAction} (#3872)

* Add CRUDAction.{updateAllAction, updateAction}

* Move updateDLCOracleSigs into DLCActionBuilder and make it an action

* Update DLCTransactionProcessing.calculateAndSetOutcome() to use actions
This commit is contained in:
Chris Stewart 2021-12-04 05:29:51 -06:00 committed by GitHub
parent 9c9a0a618f
commit 2d5732375f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 127 additions and 72 deletions

View File

@ -75,30 +75,13 @@ abstract class CRUD[T, PrimaryKeyType](implicit
/** Update the corresponding record in the database */ /** Update the corresponding record in the database */
def update(t: T): Future[T] = { def update(t: T): Future[T] = {
updateAll(Vector(t)).map { ts => val action = updateAction(t).transactionally
ts.headOption match { safeDatabase.run(action)
case Some(updated) => updated
case None => throw UpdateFailedException("Update failed for: " + t)
}
}
} }
def updateAll(ts: Vector[T]): Future[Vector[T]] = { def updateAll(ts: Vector[T]): Future[Vector[T]] = {
if (ts.isEmpty) { val actions = updateAllAction(ts).transactionally
Future.successful(ts) safeDatabase.runVec(actions)
} else {
val actions = ts.map(t => find(t).update(t))
for {
numUpdated <- safeDatabase.runVec(
DBIO.sequence(actions).transactionally)
tsUpdated <- {
if (numUpdated.sum == ts.length) Future.successful(ts)
else
Future.failed(new RuntimeException(
s"Unexpected number of updates completed ${numUpdated.sum} of ${ts.length}"))
}
} yield tsUpdated
}
} }
/** delete the corresponding record in the database /** delete the corresponding record in the database
@ -165,15 +148,6 @@ abstract class CRUD[T, PrimaryKeyType](implicit
protected def findByPrimaryKeys( protected def findByPrimaryKeys(
ids: Vector[PrimaryKeyType]): Query[Table[T], T, Seq] ids: Vector[PrimaryKeyType]): Query[Table[T], T, Seq]
/** return the row that corresponds with this record
*
* @param t - the row to find
* @return query - the sql query to find this record
*/
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
/** Finds all elements in the table */ /** Finds all elements in the table */
def findAll(): Future[Vector[T]] = def findAll(): Future[Vector[T]] =
safeDatabase.run(table.result).map(_.toVector) safeDatabase.run(table.result).map(_.toVector)

View File

@ -18,4 +18,43 @@ abstract class CRUDAction[T, PrimaryKeyType](implicit
createAllAction(Vector(t)) createAllAction(Vector(t))
.map(_.head) .map(_.head)
} }
def updateAction(t: T): DBIOAction[T, NoStream, Effect.Write] = {
updateAllAction(Vector(t)).map(_.head)
}
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
/** Updates all of the given ts.
* Returns all ts that actually existed in the database and got updated
* This method discards things that did not exist in the database,
* thus could not be updated
*/
def updateAllAction(
ts: Vector[T]): DBIOAction[Vector[T], NoStream, Effect.Write] = {
val updateActions: Vector[DBIOAction[Option[T], NoStream, Effect.Write]] = {
ts.map { t =>
find(t).update(t).flatMap { rowsUpdated =>
if (rowsUpdated == 0) {
DBIO.successful(None)
} else if (rowsUpdated == 1) {
DBIO.successful(Some(t))
} else {
DBIO.failed(new RuntimeException(
s"Updated more rows that we intended to update, updated=$rowsUpdated"))
}
}
}
}
val sequencedA: DBIOAction[Vector[Option[T]], NoStream, Effect.Write] = {
DBIO.sequence(updateActions)
}
//discard all rows that did not exist,
//thus cannot be updated
sequencedA.map(_.flatten)
}
} }

View File

@ -72,7 +72,8 @@ abstract class DLCWallet
dlcOfferDAO = dlcOfferDAO, dlcOfferDAO = dlcOfferDAO,
dlcAcceptDAO = dlcAcceptDAO, dlcAcceptDAO = dlcAcceptDAO,
dlcSigsDAO = dlcSigsDAO, dlcSigsDAO = dlcSigsDAO,
dlcRefundSigDAO = dlcRefundSigDAO dlcRefundSigDAO = dlcRefundSigDAO,
oracleNonceDAO = oracleNonceDAO
) )
} }
@ -198,25 +199,10 @@ abstract class DLCWallet
require(outcomeAndSigByNonce.forall(t => t._1 == t._2._2.rx), require(outcomeAndSigByNonce.forall(t => t._1 == t._2._2.rx),
"nonces out of order") "nonces out of order")
val updateOracleSigsA =
actionBuilder.updateDLCOracleSigsAction(outcomeAndSigByNonce)
for { for {
nonceDbs <- oracleNonceDAO.findByNonces( updates <- safeDatabase.runVec(updateOracleSigsA)
outcomeAndSigByNonce.keys.toVector)
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
"Didn't receive all nonce dbs")
updated = nonceDbs.map { db =>
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
}
updates <- oracleNonceDAO.updateAll(updated)
announcementIds = updates.map(_.announcementId).distinct
announcementDbs <- dlcAnnouncementDAO.findByAnnouncementIds(
announcementIds)
updatedDbs = announcementDbs.map(_.copy(used = true))
_ <- dlcAnnouncementDAO.updateAll(updatedDbs)
} yield updates } yield updates
} }

View File

@ -10,6 +10,7 @@ import org.bitcoins.core.protocol.transaction.{Transaction, WitnessTransaction}
import org.bitcoins.core.util.FutureUtil import org.bitcoins.core.util.FutureUtil
import org.bitcoins.core.wallet.utxo.AddressTag import org.bitcoins.core.wallet.utxo.AddressTag
import org.bitcoins.crypto.DoubleSha256DigestBE import org.bitcoins.crypto.DoubleSha256DigestBE
import org.bitcoins.db.SafeDatabase
import org.bitcoins.dlc.wallet.DLCWallet import org.bitcoins.dlc.wallet.DLCWallet
import org.bitcoins.wallet.internal.TransactionProcessing import org.bitcoins.wallet.internal.TransactionProcessing
@ -21,6 +22,9 @@ import scala.concurrent._
private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing { private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
self: DLCWallet => self: DLCWallet =>
import dlcDAO.profile.api._
private lazy val safeDatabase: SafeDatabase = dlcDAO.safeDatabase
/** Calculates the new state of the DLCDb based on the closing transaction, /** Calculates the new state of the DLCDb based on the closing transaction,
* will delete old CET sigs that are no longer needed after execution * will delete old CET sigs that are no longer needed after execution
*/ */
@ -76,8 +80,8 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
val dlcId = dlcDb.dlcId val dlcId = dlcDb.dlcId
for { for {
(_, contractData, offerDb, fundingInputDbs, _) <- getDLCOfferData(dlcId) (_, contractData, offerDb, acceptDb, fundingInputDbs, _) <-
acceptDbOpt <- dlcAcceptDAO.findByDLCId(dlcId) getDLCFundingData(dlcId)
txIds = fundingInputDbs.map(_.outPoint.txIdBE) txIds = fundingInputDbs.map(_.outPoint.txIdBE)
remotePrevTxs <- remoteTxDAO.findByTxIdBEs(txIds) remotePrevTxs <- remoteTxDAO.findByTxIdBEs(txIds)
localPrevTxs <- transactionDAO.findByTxIdBEs(txIds) localPrevTxs <- transactionDAO.findByTxIdBEs(txIds)
@ -110,14 +114,12 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
val offer = val offer =
offerDb.toDLCOffer(contractInfo, fundingInputs, dlcDb, contractData) offerDb.toDLCOffer(contractInfo, fundingInputs, dlcDb, contractData)
val accept = acceptDbOpt val accept =
.map( acceptDb.toDLCAccept(
_.toDLCAccept(dlcDb.tempContractId, dlcDb.tempContractId,
fundingInputs, fundingInputs,
sigDbs.map(dbSig => sigDbs.map(dbSig => (dbSig.sigPoint, dbSig.accepterSig)),
(dbSig.sigPoint, dbSig.accepterSig)), acceptRefundSigOpt.head)
acceptRefundSigOpt.head))
.head
val sign: DLCSign = { val sign: DLCSign = {
val cetSigs: CETSignatures = val cetSigs: CETSignatures =
@ -187,9 +189,11 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
} }
} }
_ <- oracleNonceDAO.updateAll(updatedNonces) updateNonceA = oracleNonceDAO.updateAllAction(updatedNonces)
updateAnnouncementA = dlcAnnouncementDAO.updateAllAction(
_ <- dlcAnnouncementDAO.updateAll(updatedAnnouncements) updatedAnnouncements)
actions = DBIO.seq(updateNonceA, updateAnnouncementA).transactionally
_ <- safeDatabase.run(actions)
} yield dlcDb.copy(aggregateSignatureOpt = Some(sig)) } yield dlcDb.copy(aggregateSignatureOpt = Some(sig))
} else { } else {
Future.successful(dlcDb) Future.successful(dlcDb)

View File

@ -77,9 +77,16 @@ case class DLCAnnouncementDAO()(implicit
def findByAnnouncementIds( def findByAnnouncementIds(
ids: Vector[Long]): Future[Vector[DLCAnnouncementDb]] = { ids: Vector[Long]): Future[Vector[DLCAnnouncementDb]] = {
val query = table.filter(_.announcementId.inSet(ids)) val action = findByAnnouncementIdsAction(ids)
safeDatabase.runVec(action)
}
safeDatabase.runVec(query.result) def findByAnnouncementIdsAction(ids: Vector[Long]): DBIOAction[
Vector[DLCAnnouncementDb],
NoStream,
Effect.Read] = {
val query = table.filter(_.announcementId.inSet(ids))
query.result.map(_.toVector)
} }
override def findByDLCIdAction(dlcId: Sha256Digest): DBIOAction[ override def findByDLCIdAction(dlcId: Sha256Digest): DBIOAction[

View File

@ -69,11 +69,18 @@ case class OracleNonceDAO()(implicit
findByNonces(Vector(nonce)).map(_.headOption) findByNonces(Vector(nonce)).map(_.headOption)
} }
def findByNoncesAction(nonces: Vector[SchnorrNonce]): DBIOAction[
Vector[OracleNonceDb],
NoStream,
Effect.Read] = {
val query = table.filter(_.nonce.inSet(nonces))
query.result.map(_.toVector)
}
def findByNonces( def findByNonces(
nonces: Vector[SchnorrNonce]): Future[Vector[OracleNonceDb]] = { nonces: Vector[SchnorrNonce]): Future[Vector[OracleNonceDb]] = {
val query = table.filter(_.nonce.inSet(nonces)) val action = findByNoncesAction(nonces)
safeDatabase.runVec(action)
safeDatabase.runVec(query.result)
} }
def findByAnnouncementId(id: Long): Future[Vector[OracleNonceDb]] = { def findByAnnouncementId(id: Long): Future[Vector[OracleNonceDb]] = {

View File

@ -1,7 +1,7 @@
package org.bitcoins.dlc.wallet.util package org.bitcoins.dlc.wallet.util
import org.bitcoins.core.api.dlc.wallet.db.DLCDb import org.bitcoins.core.api.dlc.wallet.db.DLCDb
import org.bitcoins.crypto.Sha256Digest import org.bitcoins.crypto.{SchnorrDigitalSignature, SchnorrNonce, Sha256Digest}
import org.bitcoins.dlc.wallet.models.{ import org.bitcoins.dlc.wallet.models.{
DLCAcceptDAO, DLCAcceptDAO,
DLCAcceptDb, DLCAcceptDb,
@ -17,7 +17,9 @@ import org.bitcoins.dlc.wallet.models.{
DLCOfferDAO, DLCOfferDAO,
DLCOfferDb, DLCOfferDb,
DLCRefundSigsDAO, DLCRefundSigsDAO,
DLCRefundSigsDb DLCRefundSigsDb,
OracleNonceDAO,
OracleNonceDb
} }
import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext
@ -31,7 +33,8 @@ case class DLCActionBuilder(
dlcOfferDAO: DLCOfferDAO, dlcOfferDAO: DLCOfferDAO,
dlcAcceptDAO: DLCAcceptDAO, dlcAcceptDAO: DLCAcceptDAO,
dlcSigsDAO: DLCCETSignaturesDAO, dlcSigsDAO: DLCCETSignaturesDAO,
dlcRefundSigDAO: DLCRefundSigsDAO) { dlcRefundSigDAO: DLCRefundSigsDAO,
oracleNonceDAO: OracleNonceDAO) {
//idk if it matters which profile api i import, but i need access to transactionally //idk if it matters which profile api i import, but i need access to transactionally
import dlcDAO.profile.api._ import dlcDAO.profile.api._
@ -153,4 +156,39 @@ case class DLCActionBuilder(
combined combined
} }
/** Updates various tables in our database with oracle attestations
* that are published by the oracle
*/
def updateDLCOracleSigsAction(
outcomeAndSigByNonce: Map[
SchnorrNonce,
(String, SchnorrDigitalSignature)])(implicit
ec: ExecutionContext): DBIOAction[
Vector[OracleNonceDb],
NoStream,
Effect.Write with Effect.Read with Effect.Transactional] = {
val updateAction = for {
nonceDbs <- oracleNonceDAO.findByNoncesAction(
outcomeAndSigByNonce.keys.toVector)
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
"Didn't receive all nonce dbs")
updated = nonceDbs.map { db =>
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
}
updateNonces <- oracleNonceDAO.updateAllAction(updated)
announcementDbs <- {
val announcementIds = updateNonces.map(_.announcementId).distinct
dlcAnnouncementDAO.findByAnnouncementIdsAction(announcementIds)
}
updatedDbs = announcementDbs.map(_.copy(used = true))
_ <- dlcAnnouncementDAO.updateAllAction(updatedDbs)
} yield updateNonces
updateAction.transactionally
}
} }

View File

@ -626,7 +626,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
val totalIncoming = outputsWithIndex.map(_.output.value).sum val totalIncoming = outputsWithIndex.map(_.output.value).sum
val spks = outputsWithIndex.map(_.output.scriptPubKey) val spks = outputsWithIndex.map(_.output.scriptPubKey)
val spksInDbF = addressDAO.findByScriptPubKeys(spks.toVector) val spksInDbF = addressDAO.findByScriptPubKeys(spks)
val ourOutputsF = for { val ourOutputsF = for {
spksInDb <- spksInDbF spksInDb <- spksInDbF
@ -658,7 +658,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
.fromScriptPubKey(out.output.scriptPubKey, networkParameters) .fromScriptPubKey(out.output.scriptPubKey, networkParameters)
tagsToUse.map(tag => AddressTagDb(address, tag)) tagsToUse.map(tag => AddressTagDb(address, tag))
} }
created <- addressTagDAO.createAll(newTagDbs.toVector) created <- addressTagDAO.createAll(newTagDbs)
} yield created } yield created
for { for {