Add CRUDAction.{updateAllAction, updateAction} (#3872)

* Add CRUDAction.{updateAllAction, updateAction}

* Move updateDLCOracleSigs into DLCActionBuilder and make it an action

* Update DLCTransactionProcessing.calculateAndSetOutcome() to use actions
This commit is contained in:
Chris Stewart 2021-12-04 05:29:51 -06:00 committed by GitHub
parent 9c9a0a618f
commit 2d5732375f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 127 additions and 72 deletions

View File

@ -75,30 +75,13 @@ abstract class CRUD[T, PrimaryKeyType](implicit
/** Update the corresponding record in the database */
def update(t: T): Future[T] = {
updateAll(Vector(t)).map { ts =>
ts.headOption match {
case Some(updated) => updated
case None => throw UpdateFailedException("Update failed for: " + t)
}
}
val action = updateAction(t).transactionally
safeDatabase.run(action)
}
def updateAll(ts: Vector[T]): Future[Vector[T]] = {
if (ts.isEmpty) {
Future.successful(ts)
} else {
val actions = ts.map(t => find(t).update(t))
for {
numUpdated <- safeDatabase.runVec(
DBIO.sequence(actions).transactionally)
tsUpdated <- {
if (numUpdated.sum == ts.length) Future.successful(ts)
else
Future.failed(new RuntimeException(
s"Unexpected number of updates completed ${numUpdated.sum} of ${ts.length}"))
}
} yield tsUpdated
}
val actions = updateAllAction(ts).transactionally
safeDatabase.runVec(actions)
}
/** delete the corresponding record in the database
@ -165,15 +148,6 @@ abstract class CRUD[T, PrimaryKeyType](implicit
protected def findByPrimaryKeys(
ids: Vector[PrimaryKeyType]): Query[Table[T], T, Seq]
/** return the row that corresponds with this record
*
* @param t - the row to find
* @return query - the sql query to find this record
*/
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
/** Finds all elements in the table */
def findAll(): Future[Vector[T]] =
safeDatabase.run(table.result).map(_.toVector)

View File

@ -18,4 +18,43 @@ abstract class CRUDAction[T, PrimaryKeyType](implicit
createAllAction(Vector(t))
.map(_.head)
}
def updateAction(t: T): DBIOAction[T, NoStream, Effect.Write] = {
updateAllAction(Vector(t)).map(_.head)
}
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
/** Updates all of the given ts.
* Returns all ts that actually existed in the database and got updated
* This method discards things that did not exist in the database,
* thus could not be updated
*/
def updateAllAction(
ts: Vector[T]): DBIOAction[Vector[T], NoStream, Effect.Write] = {
val updateActions: Vector[DBIOAction[Option[T], NoStream, Effect.Write]] = {
ts.map { t =>
find(t).update(t).flatMap { rowsUpdated =>
if (rowsUpdated == 0) {
DBIO.successful(None)
} else if (rowsUpdated == 1) {
DBIO.successful(Some(t))
} else {
DBIO.failed(new RuntimeException(
s"Updated more rows that we intended to update, updated=$rowsUpdated"))
}
}
}
}
val sequencedA: DBIOAction[Vector[Option[T]], NoStream, Effect.Write] = {
DBIO.sequence(updateActions)
}
//discard all rows that did not exist,
//thus cannot be updated
sequencedA.map(_.flatten)
}
}

View File

@ -72,7 +72,8 @@ abstract class DLCWallet
dlcOfferDAO = dlcOfferDAO,
dlcAcceptDAO = dlcAcceptDAO,
dlcSigsDAO = dlcSigsDAO,
dlcRefundSigDAO = dlcRefundSigDAO
dlcRefundSigDAO = dlcRefundSigDAO,
oracleNonceDAO = oracleNonceDAO
)
}
@ -198,25 +199,10 @@ abstract class DLCWallet
require(outcomeAndSigByNonce.forall(t => t._1 == t._2._2.rx),
"nonces out of order")
val updateOracleSigsA =
actionBuilder.updateDLCOracleSigsAction(outcomeAndSigByNonce)
for {
nonceDbs <- oracleNonceDAO.findByNonces(
outcomeAndSigByNonce.keys.toVector)
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
"Didn't receive all nonce dbs")
updated = nonceDbs.map { db =>
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
}
updates <- oracleNonceDAO.updateAll(updated)
announcementIds = updates.map(_.announcementId).distinct
announcementDbs <- dlcAnnouncementDAO.findByAnnouncementIds(
announcementIds)
updatedDbs = announcementDbs.map(_.copy(used = true))
_ <- dlcAnnouncementDAO.updateAll(updatedDbs)
updates <- safeDatabase.runVec(updateOracleSigsA)
} yield updates
}

View File

@ -10,6 +10,7 @@ import org.bitcoins.core.protocol.transaction.{Transaction, WitnessTransaction}
import org.bitcoins.core.util.FutureUtil
import org.bitcoins.core.wallet.utxo.AddressTag
import org.bitcoins.crypto.DoubleSha256DigestBE
import org.bitcoins.db.SafeDatabase
import org.bitcoins.dlc.wallet.DLCWallet
import org.bitcoins.wallet.internal.TransactionProcessing
@ -21,6 +22,9 @@ import scala.concurrent._
private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
self: DLCWallet =>
import dlcDAO.profile.api._
private lazy val safeDatabase: SafeDatabase = dlcDAO.safeDatabase
/** Calculates the new state of the DLCDb based on the closing transaction,
* will delete old CET sigs that are no longer needed after execution
*/
@ -76,8 +80,8 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
val dlcId = dlcDb.dlcId
for {
(_, contractData, offerDb, fundingInputDbs, _) <- getDLCOfferData(dlcId)
acceptDbOpt <- dlcAcceptDAO.findByDLCId(dlcId)
(_, contractData, offerDb, acceptDb, fundingInputDbs, _) <-
getDLCFundingData(dlcId)
txIds = fundingInputDbs.map(_.outPoint.txIdBE)
remotePrevTxs <- remoteTxDAO.findByTxIdBEs(txIds)
localPrevTxs <- transactionDAO.findByTxIdBEs(txIds)
@ -110,14 +114,12 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
val offer =
offerDb.toDLCOffer(contractInfo, fundingInputs, dlcDb, contractData)
val accept = acceptDbOpt
.map(
_.toDLCAccept(dlcDb.tempContractId,
fundingInputs,
sigDbs.map(dbSig =>
(dbSig.sigPoint, dbSig.accepterSig)),
acceptRefundSigOpt.head))
.head
val accept =
acceptDb.toDLCAccept(
dlcDb.tempContractId,
fundingInputs,
sigDbs.map(dbSig => (dbSig.sigPoint, dbSig.accepterSig)),
acceptRefundSigOpt.head)
val sign: DLCSign = {
val cetSigs: CETSignatures =
@ -187,9 +189,11 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
}
}
_ <- oracleNonceDAO.updateAll(updatedNonces)
_ <- dlcAnnouncementDAO.updateAll(updatedAnnouncements)
updateNonceA = oracleNonceDAO.updateAllAction(updatedNonces)
updateAnnouncementA = dlcAnnouncementDAO.updateAllAction(
updatedAnnouncements)
actions = DBIO.seq(updateNonceA, updateAnnouncementA).transactionally
_ <- safeDatabase.run(actions)
} yield dlcDb.copy(aggregateSignatureOpt = Some(sig))
} else {
Future.successful(dlcDb)

View File

@ -77,9 +77,16 @@ case class DLCAnnouncementDAO()(implicit
def findByAnnouncementIds(
ids: Vector[Long]): Future[Vector[DLCAnnouncementDb]] = {
val query = table.filter(_.announcementId.inSet(ids))
val action = findByAnnouncementIdsAction(ids)
safeDatabase.runVec(action)
}
safeDatabase.runVec(query.result)
def findByAnnouncementIdsAction(ids: Vector[Long]): DBIOAction[
Vector[DLCAnnouncementDb],
NoStream,
Effect.Read] = {
val query = table.filter(_.announcementId.inSet(ids))
query.result.map(_.toVector)
}
override def findByDLCIdAction(dlcId: Sha256Digest): DBIOAction[

View File

@ -69,11 +69,18 @@ case class OracleNonceDAO()(implicit
findByNonces(Vector(nonce)).map(_.headOption)
}
def findByNoncesAction(nonces: Vector[SchnorrNonce]): DBIOAction[
Vector[OracleNonceDb],
NoStream,
Effect.Read] = {
val query = table.filter(_.nonce.inSet(nonces))
query.result.map(_.toVector)
}
def findByNonces(
nonces: Vector[SchnorrNonce]): Future[Vector[OracleNonceDb]] = {
val query = table.filter(_.nonce.inSet(nonces))
safeDatabase.runVec(query.result)
val action = findByNoncesAction(nonces)
safeDatabase.runVec(action)
}
def findByAnnouncementId(id: Long): Future[Vector[OracleNonceDb]] = {

View File

@ -1,7 +1,7 @@
package org.bitcoins.dlc.wallet.util
import org.bitcoins.core.api.dlc.wallet.db.DLCDb
import org.bitcoins.crypto.Sha256Digest
import org.bitcoins.crypto.{SchnorrDigitalSignature, SchnorrNonce, Sha256Digest}
import org.bitcoins.dlc.wallet.models.{
DLCAcceptDAO,
DLCAcceptDb,
@ -17,7 +17,9 @@ import org.bitcoins.dlc.wallet.models.{
DLCOfferDAO,
DLCOfferDb,
DLCRefundSigsDAO,
DLCRefundSigsDb
DLCRefundSigsDb,
OracleNonceDAO,
OracleNonceDb
}
import scala.concurrent.ExecutionContext
@ -31,7 +33,8 @@ case class DLCActionBuilder(
dlcOfferDAO: DLCOfferDAO,
dlcAcceptDAO: DLCAcceptDAO,
dlcSigsDAO: DLCCETSignaturesDAO,
dlcRefundSigDAO: DLCRefundSigsDAO) {
dlcRefundSigDAO: DLCRefundSigsDAO,
oracleNonceDAO: OracleNonceDAO) {
//idk if it matters which profile api i import, but i need access to transactionally
import dlcDAO.profile.api._
@ -153,4 +156,39 @@ case class DLCActionBuilder(
combined
}
/** Updates various tables in our database with oracle attestations
* that are published by the oracle
*/
def updateDLCOracleSigsAction(
outcomeAndSigByNonce: Map[
SchnorrNonce,
(String, SchnorrDigitalSignature)])(implicit
ec: ExecutionContext): DBIOAction[
Vector[OracleNonceDb],
NoStream,
Effect.Write with Effect.Read with Effect.Transactional] = {
val updateAction = for {
nonceDbs <- oracleNonceDAO.findByNoncesAction(
outcomeAndSigByNonce.keys.toVector)
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
"Didn't receive all nonce dbs")
updated = nonceDbs.map { db =>
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
}
updateNonces <- oracleNonceDAO.updateAllAction(updated)
announcementDbs <- {
val announcementIds = updateNonces.map(_.announcementId).distinct
dlcAnnouncementDAO.findByAnnouncementIdsAction(announcementIds)
}
updatedDbs = announcementDbs.map(_.copy(used = true))
_ <- dlcAnnouncementDAO.updateAllAction(updatedDbs)
} yield updateNonces
updateAction.transactionally
}
}

View File

@ -626,7 +626,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
val totalIncoming = outputsWithIndex.map(_.output.value).sum
val spks = outputsWithIndex.map(_.output.scriptPubKey)
val spksInDbF = addressDAO.findByScriptPubKeys(spks.toVector)
val spksInDbF = addressDAO.findByScriptPubKeys(spks)
val ourOutputsF = for {
spksInDb <- spksInDbF
@ -658,7 +658,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
.fromScriptPubKey(out.output.scriptPubKey, networkParameters)
tagsToUse.map(tag => AddressTagDb(address, tag))
}
created <- addressTagDAO.createAll(newTagDbs.toVector)
created <- addressTagDAO.createAll(newTagDbs)
} yield created
for {