mirror of
https://github.com/bitcoin-s/bitcoin-s.git
synced 2024-11-19 09:52:09 +01:00
Add CRUDAction.{updateAllAction, updateAction} (#3872)
* Add CRUDAction.{updateAllAction, updateAction} * Move updateDLCOracleSigs into DLCActionBuilder and make it an action * Update DLCTransactionProcessing.calculateAndSetOutcome() to use actions
This commit is contained in:
parent
9c9a0a618f
commit
2d5732375f
@ -75,30 +75,13 @@ abstract class CRUD[T, PrimaryKeyType](implicit
|
||||
|
||||
/** Update the corresponding record in the database */
|
||||
def update(t: T): Future[T] = {
|
||||
updateAll(Vector(t)).map { ts =>
|
||||
ts.headOption match {
|
||||
case Some(updated) => updated
|
||||
case None => throw UpdateFailedException("Update failed for: " + t)
|
||||
}
|
||||
}
|
||||
val action = updateAction(t).transactionally
|
||||
safeDatabase.run(action)
|
||||
}
|
||||
|
||||
def updateAll(ts: Vector[T]): Future[Vector[T]] = {
|
||||
if (ts.isEmpty) {
|
||||
Future.successful(ts)
|
||||
} else {
|
||||
val actions = ts.map(t => find(t).update(t))
|
||||
for {
|
||||
numUpdated <- safeDatabase.runVec(
|
||||
DBIO.sequence(actions).transactionally)
|
||||
tsUpdated <- {
|
||||
if (numUpdated.sum == ts.length) Future.successful(ts)
|
||||
else
|
||||
Future.failed(new RuntimeException(
|
||||
s"Unexpected number of updates completed ${numUpdated.sum} of ${ts.length}"))
|
||||
}
|
||||
} yield tsUpdated
|
||||
}
|
||||
val actions = updateAllAction(ts).transactionally
|
||||
safeDatabase.runVec(actions)
|
||||
}
|
||||
|
||||
/** delete the corresponding record in the database
|
||||
@ -165,15 +148,6 @@ abstract class CRUD[T, PrimaryKeyType](implicit
|
||||
protected def findByPrimaryKeys(
|
||||
ids: Vector[PrimaryKeyType]): Query[Table[T], T, Seq]
|
||||
|
||||
/** return the row that corresponds with this record
|
||||
*
|
||||
* @param t - the row to find
|
||||
* @return query - the sql query to find this record
|
||||
*/
|
||||
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
|
||||
|
||||
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
|
||||
|
||||
/** Finds all elements in the table */
|
||||
def findAll(): Future[Vector[T]] =
|
||||
safeDatabase.run(table.result).map(_.toVector)
|
||||
|
@ -18,4 +18,43 @@ abstract class CRUDAction[T, PrimaryKeyType](implicit
|
||||
createAllAction(Vector(t))
|
||||
.map(_.head)
|
||||
}
|
||||
|
||||
def updateAction(t: T): DBIOAction[T, NoStream, Effect.Write] = {
|
||||
updateAllAction(Vector(t)).map(_.head)
|
||||
}
|
||||
|
||||
protected def find(t: T): Query[Table[_], T, Seq] = findAll(Vector(t))
|
||||
|
||||
protected def findAll(ts: Vector[T]): Query[Table[_], T, Seq]
|
||||
|
||||
/** Updates all of the given ts.
|
||||
* Returns all ts that actually existed in the database and got updated
|
||||
* This method discards things that did not exist in the database,
|
||||
* thus could not be updated
|
||||
*/
|
||||
def updateAllAction(
|
||||
ts: Vector[T]): DBIOAction[Vector[T], NoStream, Effect.Write] = {
|
||||
val updateActions: Vector[DBIOAction[Option[T], NoStream, Effect.Write]] = {
|
||||
ts.map { t =>
|
||||
find(t).update(t).flatMap { rowsUpdated =>
|
||||
if (rowsUpdated == 0) {
|
||||
DBIO.successful(None)
|
||||
} else if (rowsUpdated == 1) {
|
||||
DBIO.successful(Some(t))
|
||||
} else {
|
||||
DBIO.failed(new RuntimeException(
|
||||
s"Updated more rows that we intended to update, updated=$rowsUpdated"))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
val sequencedA: DBIOAction[Vector[Option[T]], NoStream, Effect.Write] = {
|
||||
DBIO.sequence(updateActions)
|
||||
}
|
||||
|
||||
//discard all rows that did not exist,
|
||||
//thus cannot be updated
|
||||
sequencedA.map(_.flatten)
|
||||
}
|
||||
}
|
||||
|
@ -72,7 +72,8 @@ abstract class DLCWallet
|
||||
dlcOfferDAO = dlcOfferDAO,
|
||||
dlcAcceptDAO = dlcAcceptDAO,
|
||||
dlcSigsDAO = dlcSigsDAO,
|
||||
dlcRefundSigDAO = dlcRefundSigDAO
|
||||
dlcRefundSigDAO = dlcRefundSigDAO,
|
||||
oracleNonceDAO = oracleNonceDAO
|
||||
)
|
||||
}
|
||||
|
||||
@ -198,25 +199,10 @@ abstract class DLCWallet
|
||||
|
||||
require(outcomeAndSigByNonce.forall(t => t._1 == t._2._2.rx),
|
||||
"nonces out of order")
|
||||
|
||||
val updateOracleSigsA =
|
||||
actionBuilder.updateDLCOracleSigsAction(outcomeAndSigByNonce)
|
||||
for {
|
||||
nonceDbs <- oracleNonceDAO.findByNonces(
|
||||
outcomeAndSigByNonce.keys.toVector)
|
||||
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
|
||||
"Didn't receive all nonce dbs")
|
||||
|
||||
updated = nonceDbs.map { db =>
|
||||
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
|
||||
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
|
||||
}
|
||||
|
||||
updates <- oracleNonceDAO.updateAll(updated)
|
||||
|
||||
announcementIds = updates.map(_.announcementId).distinct
|
||||
announcementDbs <- dlcAnnouncementDAO.findByAnnouncementIds(
|
||||
announcementIds)
|
||||
updatedDbs = announcementDbs.map(_.copy(used = true))
|
||||
_ <- dlcAnnouncementDAO.updateAll(updatedDbs)
|
||||
updates <- safeDatabase.runVec(updateOracleSigsA)
|
||||
} yield updates
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ import org.bitcoins.core.protocol.transaction.{Transaction, WitnessTransaction}
|
||||
import org.bitcoins.core.util.FutureUtil
|
||||
import org.bitcoins.core.wallet.utxo.AddressTag
|
||||
import org.bitcoins.crypto.DoubleSha256DigestBE
|
||||
import org.bitcoins.db.SafeDatabase
|
||||
import org.bitcoins.dlc.wallet.DLCWallet
|
||||
import org.bitcoins.wallet.internal.TransactionProcessing
|
||||
|
||||
@ -21,6 +22,9 @@ import scala.concurrent._
|
||||
private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
|
||||
self: DLCWallet =>
|
||||
|
||||
import dlcDAO.profile.api._
|
||||
private lazy val safeDatabase: SafeDatabase = dlcDAO.safeDatabase
|
||||
|
||||
/** Calculates the new state of the DLCDb based on the closing transaction,
|
||||
* will delete old CET sigs that are no longer needed after execution
|
||||
*/
|
||||
@ -76,8 +80,8 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
|
||||
val dlcId = dlcDb.dlcId
|
||||
|
||||
for {
|
||||
(_, contractData, offerDb, fundingInputDbs, _) <- getDLCOfferData(dlcId)
|
||||
acceptDbOpt <- dlcAcceptDAO.findByDLCId(dlcId)
|
||||
(_, contractData, offerDb, acceptDb, fundingInputDbs, _) <-
|
||||
getDLCFundingData(dlcId)
|
||||
txIds = fundingInputDbs.map(_.outPoint.txIdBE)
|
||||
remotePrevTxs <- remoteTxDAO.findByTxIdBEs(txIds)
|
||||
localPrevTxs <- transactionDAO.findByTxIdBEs(txIds)
|
||||
@ -110,14 +114,12 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
|
||||
|
||||
val offer =
|
||||
offerDb.toDLCOffer(contractInfo, fundingInputs, dlcDb, contractData)
|
||||
val accept = acceptDbOpt
|
||||
.map(
|
||||
_.toDLCAccept(dlcDb.tempContractId,
|
||||
fundingInputs,
|
||||
sigDbs.map(dbSig =>
|
||||
(dbSig.sigPoint, dbSig.accepterSig)),
|
||||
acceptRefundSigOpt.head))
|
||||
.head
|
||||
val accept =
|
||||
acceptDb.toDLCAccept(
|
||||
dlcDb.tempContractId,
|
||||
fundingInputs,
|
||||
sigDbs.map(dbSig => (dbSig.sigPoint, dbSig.accepterSig)),
|
||||
acceptRefundSigOpt.head)
|
||||
|
||||
val sign: DLCSign = {
|
||||
val cetSigs: CETSignatures =
|
||||
@ -187,9 +189,11 @@ private[bitcoins] trait DLCTransactionProcessing extends TransactionProcessing {
|
||||
|
||||
}
|
||||
}
|
||||
_ <- oracleNonceDAO.updateAll(updatedNonces)
|
||||
|
||||
_ <- dlcAnnouncementDAO.updateAll(updatedAnnouncements)
|
||||
updateNonceA = oracleNonceDAO.updateAllAction(updatedNonces)
|
||||
updateAnnouncementA = dlcAnnouncementDAO.updateAllAction(
|
||||
updatedAnnouncements)
|
||||
actions = DBIO.seq(updateNonceA, updateAnnouncementA).transactionally
|
||||
_ <- safeDatabase.run(actions)
|
||||
} yield dlcDb.copy(aggregateSignatureOpt = Some(sig))
|
||||
} else {
|
||||
Future.successful(dlcDb)
|
||||
|
@ -77,9 +77,16 @@ case class DLCAnnouncementDAO()(implicit
|
||||
|
||||
def findByAnnouncementIds(
|
||||
ids: Vector[Long]): Future[Vector[DLCAnnouncementDb]] = {
|
||||
val query = table.filter(_.announcementId.inSet(ids))
|
||||
val action = findByAnnouncementIdsAction(ids)
|
||||
safeDatabase.runVec(action)
|
||||
}
|
||||
|
||||
safeDatabase.runVec(query.result)
|
||||
def findByAnnouncementIdsAction(ids: Vector[Long]): DBIOAction[
|
||||
Vector[DLCAnnouncementDb],
|
||||
NoStream,
|
||||
Effect.Read] = {
|
||||
val query = table.filter(_.announcementId.inSet(ids))
|
||||
query.result.map(_.toVector)
|
||||
}
|
||||
|
||||
override def findByDLCIdAction(dlcId: Sha256Digest): DBIOAction[
|
||||
|
@ -69,11 +69,18 @@ case class OracleNonceDAO()(implicit
|
||||
findByNonces(Vector(nonce)).map(_.headOption)
|
||||
}
|
||||
|
||||
def findByNoncesAction(nonces: Vector[SchnorrNonce]): DBIOAction[
|
||||
Vector[OracleNonceDb],
|
||||
NoStream,
|
||||
Effect.Read] = {
|
||||
val query = table.filter(_.nonce.inSet(nonces))
|
||||
query.result.map(_.toVector)
|
||||
}
|
||||
|
||||
def findByNonces(
|
||||
nonces: Vector[SchnorrNonce]): Future[Vector[OracleNonceDb]] = {
|
||||
val query = table.filter(_.nonce.inSet(nonces))
|
||||
|
||||
safeDatabase.runVec(query.result)
|
||||
val action = findByNoncesAction(nonces)
|
||||
safeDatabase.runVec(action)
|
||||
}
|
||||
|
||||
def findByAnnouncementId(id: Long): Future[Vector[OracleNonceDb]] = {
|
||||
|
@ -1,7 +1,7 @@
|
||||
package org.bitcoins.dlc.wallet.util
|
||||
|
||||
import org.bitcoins.core.api.dlc.wallet.db.DLCDb
|
||||
import org.bitcoins.crypto.Sha256Digest
|
||||
import org.bitcoins.crypto.{SchnorrDigitalSignature, SchnorrNonce, Sha256Digest}
|
||||
import org.bitcoins.dlc.wallet.models.{
|
||||
DLCAcceptDAO,
|
||||
DLCAcceptDb,
|
||||
@ -17,7 +17,9 @@ import org.bitcoins.dlc.wallet.models.{
|
||||
DLCOfferDAO,
|
||||
DLCOfferDb,
|
||||
DLCRefundSigsDAO,
|
||||
DLCRefundSigsDb
|
||||
DLCRefundSigsDb,
|
||||
OracleNonceDAO,
|
||||
OracleNonceDb
|
||||
}
|
||||
|
||||
import scala.concurrent.ExecutionContext
|
||||
@ -31,7 +33,8 @@ case class DLCActionBuilder(
|
||||
dlcOfferDAO: DLCOfferDAO,
|
||||
dlcAcceptDAO: DLCAcceptDAO,
|
||||
dlcSigsDAO: DLCCETSignaturesDAO,
|
||||
dlcRefundSigDAO: DLCRefundSigsDAO) {
|
||||
dlcRefundSigDAO: DLCRefundSigsDAO,
|
||||
oracleNonceDAO: OracleNonceDAO) {
|
||||
|
||||
//idk if it matters which profile api i import, but i need access to transactionally
|
||||
import dlcDAO.profile.api._
|
||||
@ -153,4 +156,39 @@ case class DLCActionBuilder(
|
||||
|
||||
combined
|
||||
}
|
||||
|
||||
/** Updates various tables in our database with oracle attestations
|
||||
* that are published by the oracle
|
||||
*/
|
||||
def updateDLCOracleSigsAction(
|
||||
outcomeAndSigByNonce: Map[
|
||||
SchnorrNonce,
|
||||
(String, SchnorrDigitalSignature)])(implicit
|
||||
ec: ExecutionContext): DBIOAction[
|
||||
Vector[OracleNonceDb],
|
||||
NoStream,
|
||||
Effect.Write with Effect.Read with Effect.Transactional] = {
|
||||
val updateAction = for {
|
||||
nonceDbs <- oracleNonceDAO.findByNoncesAction(
|
||||
outcomeAndSigByNonce.keys.toVector)
|
||||
_ = assert(nonceDbs.size == outcomeAndSigByNonce.keys.size,
|
||||
"Didn't receive all nonce dbs")
|
||||
|
||||
updated = nonceDbs.map { db =>
|
||||
val (outcome, sig) = outcomeAndSigByNonce(db.nonce)
|
||||
db.copy(outcomeOpt = Some(outcome), signatureOpt = Some(sig))
|
||||
}
|
||||
|
||||
updateNonces <- oracleNonceDAO.updateAllAction(updated)
|
||||
|
||||
announcementDbs <- {
|
||||
val announcementIds = updateNonces.map(_.announcementId).distinct
|
||||
dlcAnnouncementDAO.findByAnnouncementIdsAction(announcementIds)
|
||||
}
|
||||
updatedDbs = announcementDbs.map(_.copy(used = true))
|
||||
_ <- dlcAnnouncementDAO.updateAllAction(updatedDbs)
|
||||
} yield updateNonces
|
||||
|
||||
updateAction.transactionally
|
||||
}
|
||||
}
|
||||
|
@ -626,7 +626,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
|
||||
val totalIncoming = outputsWithIndex.map(_.output.value).sum
|
||||
|
||||
val spks = outputsWithIndex.map(_.output.scriptPubKey)
|
||||
val spksInDbF = addressDAO.findByScriptPubKeys(spks.toVector)
|
||||
val spksInDbF = addressDAO.findByScriptPubKeys(spks)
|
||||
|
||||
val ourOutputsF = for {
|
||||
spksInDb <- spksInDbF
|
||||
@ -658,7 +658,7 @@ private[bitcoins] trait TransactionProcessing extends WalletLogger {
|
||||
.fromScriptPubKey(out.output.scriptPubKey, networkParameters)
|
||||
tagsToUse.map(tag => AddressTagDb(address, tag))
|
||||
}
|
||||
created <- addressTagDAO.createAll(newTagDbs.toVector)
|
||||
created <- addressTagDAO.createAll(newTagDbs)
|
||||
} yield created
|
||||
|
||||
for {
|
||||
|
Loading…
Reference in New Issue
Block a user