Reworking/refactoring acceptDLCOffer (#4048)

* Part 1 of reworking/refactoring acceptDLCOffer

* scalafmt

* WIP

* Move offer creation into initDLCForAccept

* Refactor method name to getDlcDbOfferDbAccountDb

* Push to github to force re-run of CI
This commit is contained in:
Chris Stewart 2022-02-07 11:41:17 -06:00 committed by GitHub
parent 49d4d7f179
commit 5aeecdb893
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 85 deletions

View File

@ -451,79 +451,82 @@ abstract class DLCWallet
} yield offer
}
private def initDLCForAccept(offer: DLCOffer): Future[(DLCDb, AccountDb)] = {
/** Retrieves the [[DLCDb]] and [[AccountDb]] for the given tempContractId
* else returns None
*/
private def getDlcDbOfferDbAccountDb(tempContractId: Sha256Digest): Future[
Option[(DLCDb, DLCOfferDb, AccountDb)]] = {
val result: Future[Option[(DLCDb, DLCOfferDb, AccountDb)]] = for {
dlcDbOpt <- dlcDAO.findByTempContractId(tempContractId)
dlcOfferDbOpt <- dlcDbOpt match {
case Some(dlcDb) => dlcOfferDAO.findByDLCId(dlcDb.dlcId)
case None => Future.successful(None)
}
accountOpt <- dlcDbOpt match {
case Some(dlcDb) => accountDAO.findByAccount(dlcDb.account)
case None => Future.successful(None)
}
} yield {
for {
dlcDb <- dlcDbOpt
dlcOfferDb <- dlcOfferDbOpt
account <- accountOpt
} yield (dlcDb, dlcOfferDb, account)
}
result
}
private def initDLCForAccept(
offer: DLCOffer): Future[(DLCDb, DLCOfferDb, AccountDb)] = {
logger.info(
s"Initializing DLC from received offer with tempContractId ${offer.tempContractId.hex}")
dlcDAO.findByTempContractId(offer.tempContractId).flatMap {
case Some(dlcDb) =>
accountDAO
.findByAccount(dlcDb.account)
.map(account => (dlcDb, account.get))
val dlcId = calcDLCId(offer.fundingInputs.map(_.outPoint))
val contractInfo = offer.contractInfo
val dlcOfferDb = DLCOfferDbHelper.fromDLCOffer(dlcId, offer)
val announcements =
offer.contractInfo.oracleInfos.head.singleOracleInfos
.map(_.announcement)
val chainType = HDChainType.External
//filter announcements that we already have in the db
val groupedAnnouncementsF: Future[AnnouncementGrouping] = {
groupByExistingAnnouncements(announcements)
}
getDlcDbOfferDbAccountDb(offer.tempContractId).flatMap {
case Some((dlcDb, dlcOffer, account)) =>
Future.successful((dlcDb, dlcOffer, account))
case None =>
val announcements =
offer.contractInfo.oracleInfos.head.singleOracleInfos
.map(_.announcement)
//filter announcements that we already have in the db
val groupedAnnouncementsF: Future[AnnouncementGrouping] = {
groupByExistingAnnouncements(announcements)
}
val contractInfo = offer.contractInfo
val dlcId = calcDLCId(offer.fundingInputs.map(_.outPoint))
val chainType = HDChainType.External
for {
account <- getDefaultAccountForType(AddressType.SegWit)
nextIndex <- getNextAvailableIndex(account, chainType)
dlc =
DLCDb(
dlcId = dlcId,
tempContractId = offer.tempContractId,
contractIdOpt = None,
protocolVersion = 0,
state = DLCState.Accepted,
isInitiator = false,
account = account.hdAccount,
changeIndex = chainType,
keyIndex = nextIndex,
feeRate = offer.feeRate,
fundOutputSerialId = offer.fundOutputSerialId,
lastUpdated = TimeUtil.now,
fundingOutPointOpt = None,
fundingTxIdOpt = None,
closingTxIdOpt = None,
aggregateSignatureOpt = None,
serializationVersion = contractInfo.serializationVersion
)
dlc = buildDlcDb(offer,
dlcId,
account,
chainType,
nextIndex,
contractInfo)
contractDataDb = {
val oracleParamsOpt =
OracleInfo.getOracleParamsOpt(contractInfo.oracleInfos.head)
DLCContractDataDb(
dlcId = dlcId,
oracleThreshold = contractInfo.oracleInfos.head.threshold,
oracleParamsTLVOpt = oracleParamsOpt,
contractDescriptorTLV =
contractInfo.contractDescriptors.head.toTLV,
contractMaturity = offer.timeouts.contractMaturity,
contractTimeout = offer.timeouts.contractTimeout,
totalCollateral = contractInfo.totalCollateral
)
}
contractDataDb = buildContractDataDb(contractInfo, dlcId, offer)
_ <- writeDLCKeysToAddressDb(account, chainType, nextIndex)
groupedAnnouncements <- groupedAnnouncementsF
writtenDLCAction = dlcDAO.createAction(dlc)
dlcDbAction = dlcDAO.createAction(dlc)
dlcOfferAction = dlcOfferDAO.createAction(dlcOfferDb)
contractAction = contractDataDAO.createAction(contractDataDb)
createdDbsAction = announcementDAO.createAllAction(
createdAnnouncementsAction = announcementDAO.createAllAction(
groupedAnnouncements.newAnnouncements)
zipped = writtenDLCAction.zip(createdDbsAction)
actions = zipped.flatMap { dlcDb =>
contractAction.map(_ => dlcDb)
zipped = {
for {
dlcDb <- dlcDbAction
ann <- createdAnnouncementsAction
//we don't need the contract data db, so don't return it
_ <- contractAction
offer <- dlcOfferAction
} yield (dlcDb, ann, offer)
}
(writtenDLC, createdDbs) <- safeDatabase.run(actions)
(writtenDLC, createdDbs, offerDb) <- safeDatabase.run(zipped)
announcementDataDbs =
createdDbs ++ groupedAnnouncements.existingAnnouncements
@ -553,10 +556,55 @@ abstract class DLCWallet
_ <- safeDatabase.run(
DBIOAction.seq(createNonceAction, createAnnouncementAction))
} yield (writtenDLC, account)
} yield (writtenDLC, offerDb, account)
}
}
private def buildContractDataDb(
contractInfo: ContractInfo,
dlcId: Sha256Digest,
offer: DLCOffer): DLCContractDataDb = {
val oracleParamsOpt =
OracleInfo.getOracleParamsOpt(contractInfo.oracleInfos.head)
DLCContractDataDb(
dlcId = dlcId,
oracleThreshold = contractInfo.oracleInfos.head.threshold,
oracleParamsTLVOpt = oracleParamsOpt,
contractDescriptorTLV = contractInfo.contractDescriptors.head.toTLV,
contractMaturity = offer.timeouts.contractMaturity,
contractTimeout = offer.timeouts.contractTimeout,
totalCollateral = contractInfo.totalCollateral
)
}
private def buildDlcDb(
offer: DLCOffer,
dlcId: Sha256Digest,
account: AccountDb,
chainType: HDChainType,
nextIndex: Int,
contractInfo: ContractInfo): DLCDb = {
DLCDb(
dlcId = dlcId,
tempContractId = offer.tempContractId,
contractIdOpt = None,
protocolVersion = 0,
state = DLCState.Accepted,
isInitiator = false,
account = account.hdAccount,
changeIndex = chainType,
keyIndex = nextIndex,
feeRate = offer.feeRate,
fundOutputSerialId = offer.fundOutputSerialId,
lastUpdated = TimeUtil.now,
fundingOutPointOpt = None,
fundingTxIdOpt = None,
closingTxIdOpt = None,
aggregateSignatureOpt = None,
serializationVersion = contractInfo.serializationVersion
)
}
/** Creates a DLCAccept from the default Segwit account from a given offer, if one has already been
* created with the given parameters then that one will be returned instead.
*
@ -571,15 +619,31 @@ abstract class DLCWallet
logger.debug(s"Checking if Accept (${dlcId.hex}) has already been made")
for {
(dlc, account) <- initDLCForAccept(offer)
dlcAcceptOpt <- findDLCAccept(dlcId, offer)
dlcAccept <- {
dlcAcceptOpt match {
case Some(accept) => Future.successful(accept)
case None => createNewDLCAccept(collateral, offer)
}
}
status <- findDLC(dlcId)
_ <- dlcConfig.walletCallbacks.executeOnDLCStateChange(logger, status.get)
} yield dlcAccept
}
/** Checks if an accept message is in the database with the given dlcId */
private def findDLCAccept(
dlcId: Sha256Digest,
offer: DLCOffer): Future[Option[DLCAccept]] = {
val resultNestedF: Future[Option[Future[DLCAccept]]] = for {
dlcAcceptDbs <- dlcAcceptDAO.findByDLCId(dlcId)
dlcAccept <- dlcAcceptDbs.headOption match {
case Some(dlcAcceptDb) =>
dlcAcceptFOpt = {
dlcAcceptDbs.headOption.map { case dlcAcceptDb =>
logger.debug(
s"DLC Accept (${dlcId.hex}) has already been made, returning accept")
for {
fundingInputs <-
dlcInputsDAO.findByDLCId(dlc.dlcId, isInitiator = false)
dlcInputsDAO.findByDLCId(dlcId, isInitiator = false)
prevTxs <-
transactionDAO.findByTxIdBEs(fundingInputs.map(_.outPoint.txIdBE))
outcomeSigsDbs <- dlcSigsDAO.findByDLCId(dlcId)
@ -595,22 +659,28 @@ abstract class DLCWallet
},
refundSigsDb.get.accepterSig)
}
case None =>
createNewDLCAccept(dlc, account, collateral, offer)
}
}
status <- findDLC(dlcId)
_ <- dlcConfig.walletCallbacks.executeOnDLCStateChange(logger, status.get)
} yield dlcAccept
} yield {
dlcAcceptFOpt
}
resultNestedF.flatMap {
case Some(f) => f.map(Some(_))
case None => Future.successful(None)
}
}
private def createNewDLCAccept(
dlc: DLCDb,
account: AccountDb,
collateral: CurrencyUnit,
offer: DLCOffer): Future[DLCAccept] = {
logger.info(
s"Creating DLC Accept for tempContractId ${offer.tempContractId.hex}")
val dlcDbAccountDbF: Future[(DLCDb, DLCOfferDb, AccountDb)] =
initDLCForAccept(offer)
for {
(dlc, _, account) <- dlcDbAccountDbF
(txBuilder, spendingInfos) <- fundRawTransactionInternal(
destinations = Vector(TransactionOutput(collateral, EmptyScriptPubKey)),
feeRate = offer.feeRate,
@ -618,7 +688,6 @@ abstract class DLCWallet
fromTagOpt = None,
markAsReserved = true
)
serialIds = DLCMessage.genSerialIds(
spendingInfos.size,
offer.fundingInputs.map(_.inputSerialId))
@ -703,8 +772,6 @@ abstract class DLCWallet
refundSigsDb =
DLCRefundSigsDb(dlc.dlcId, refundSig, None)
dlcOfferDb = DLCOfferDbHelper.fromDLCOffer(dlc.dlcId, offer)
offerInputs = offer.fundingInputs.zipWithIndex.map {
case (funding, idx) =>
DLCFundingInputDb(
@ -752,7 +819,6 @@ abstract class DLCWallet
_ <- remoteTxDAO.upsertAll(offerPrevTxs)
actions = actionBuilder.buildCreateAcceptAction(
dlcOfferDb = dlcOfferDb,
dlcAcceptDb = dlcAcceptDb,
offerInputs = offerInputs,
acceptInputs = acceptInputs,

View File

@ -56,7 +56,6 @@ case class DLCActionBuilder(dlcWalletDAOs: DLCWalletDAOs) {
* offer table, accept table, cet sigs table, inputs table, and refund table
*/
def buildCreateAcceptAction(
dlcOfferDb: DLCOfferDb,
dlcAcceptDb: DLCAcceptDb,
offerInputs: Vector[DLCFundingInputDb],
acceptInputs: Vector[DLCFundingInputDb],
@ -66,15 +65,10 @@ case class DLCActionBuilder(dlcWalletDAOs: DLCWalletDAOs) {
NoStream,
Effect.Write with Effect.Transactional] = {
val inputAction = dlcInputsDAO.createAllAction(offerInputs ++ acceptInputs)
val offerAction = dlcOfferDAO.createAction(dlcOfferDb)
val acceptAction = dlcAcceptDAO.createAction(dlcAcceptDb)
val sigsAction = dlcSigsDAO.createAllAction(cetSigsDb)
val refundSigAction = dlcRefundSigDAO.createAction(refundSigsDb)
val actions = Vector(inputAction,
offerAction,
acceptAction,
sigsAction,
refundSigAction)
val actions = Vector(inputAction, acceptAction, sigsAction, refundSigAction)
val allActions = DBIO
.sequence(actions)

View File

@ -39,7 +39,10 @@ trait EmbeddedPg extends BeforeAndAfterAll { this: Suite =>
override def afterAll(): Unit = {
super.afterAll()
val _ = pg.foreach(_.close())
val _ = pg.foreach { p =>
p.close()
}
()
}