Add wallet function to bump fee with RBF (#2392)

* Add wallet function to bump fee

* Bump sequence number

* Respond to review

* Fix test
This commit is contained in:
Ben Carman 2020-12-20 07:43:18 -06:00 committed by GitHub
parent 0264c2daf8
commit fcfc3d076f
8 changed files with 195 additions and 12 deletions

View file

@ -365,6 +365,10 @@ trait WalletApi extends StartStopAsync[WalletApi] {
amounts: Vector[CurrencyUnit], amounts: Vector[CurrencyUnit],
feeRate: FeeUnit)(implicit ec: ExecutionContext): Future[Transaction] feeRate: FeeUnit)(implicit ec: ExecutionContext): Future[Transaction]
def bumpFeeRBF(
txId: DoubleSha256DigestBE,
newFeeRate: FeeUnit): Future[Transaction]
def makeOpReturnCommitment( def makeOpReturnCommitment(
message: String, message: String,
hashMessage: Boolean, hashMessage: Boolean,

View file

@ -1,7 +1,7 @@
package org.bitcoins.core.wallet.builder package org.bitcoins.core.wallet.builder
import org.bitcoins.core.currency.{CurrencyUnit, Satoshis} import org.bitcoins.core.currency.{CurrencyUnit, Satoshis}
import org.bitcoins.core.number.Int64 import org.bitcoins.core.number.{Int64, UInt32}
import org.bitcoins.core.policy.Policy import org.bitcoins.core.policy.Policy
import org.bitcoins.core.protocol.script.{ScriptPubKey, ScriptSignature} import org.bitcoins.core.protocol.script.{ScriptPubKey, ScriptSignature}
import org.bitcoins.core.protocol.transaction._ import org.bitcoins.core.protocol.transaction._
@ -66,8 +66,10 @@ abstract class FinalizerFactory[T <: RawTxFinalizer] {
outputs: Seq[TransactionOutput], outputs: Seq[TransactionOutput],
utxos: Seq[InputSigningInfo[InputInfo]], utxos: Seq[InputSigningInfo[InputInfo]],
feeRate: FeeUnit, feeRate: FeeUnit,
changeSPK: ScriptPubKey): RawTxBuilderWithFinalizer[T] = { changeSPK: ScriptPubKey,
val inputs = InputUtil.calcSequenceForInputs(utxos) defaultSequence: UInt32 = Policy.sequence): RawTxBuilderWithFinalizer[
T] = {
val inputs = InputUtil.calcSequenceForInputs(utxos, defaultSequence)
val lockTime = TxUtil.calcLockTime(utxos).get val lockTime = TxUtil.calcLockTime(utxos).get
val builder = RawTxBuilder().setLockTime(lockTime) ++= outputs ++= inputs val builder = RawTxBuilder().setLockTime(lockTime) ++= outputs ++= inputs
val finalizer = val finalizer =

View file

@ -9,6 +9,8 @@ import org.bitcoins.core.script.control.OP_RETURN
import org.bitcoins.core.wallet.fee._ import org.bitcoins.core.wallet.fee._
import org.bitcoins.core.wallet.utxo.TxoState import org.bitcoins.core.wallet.utxo.TxoState
import org.bitcoins.crypto.CryptoUtil import org.bitcoins.crypto.CryptoUtil
import org.bitcoins.testkit.Implicits.GeneratorOps
import org.bitcoins.testkit.core.gen.FeeUnitGen
import org.bitcoins.testkit.wallet.BitcoinSWalletTest import org.bitcoins.testkit.wallet.BitcoinSWalletTest
import org.bitcoins.testkit.wallet.BitcoinSWalletTest.RandomFeeProvider import org.bitcoins.testkit.wallet.BitcoinSWalletTest.RandomFeeProvider
import org.bitcoins.testkit.wallet.FundWalletUtil.FundedWallet import org.bitcoins.testkit.wallet.FundWalletUtil.FundedWallet
@ -255,6 +257,34 @@ class WalletSendingTest extends BitcoinSWalletTest {
} }
} }
it should "correctly bump the fee rate of a transaction" in { fundedWallet =>
val wallet = fundedWallet.wallet
val feeRate = FeeUnitGen.satsPerByte.sampleSome
for {
tx <- wallet.sendToAddress(testAddress, amountToSend, feeRate)
firstBal <- wallet.getBalance()
newFeeRate = SatoshisPerByte(feeRate.currencyUnit + Satoshis.one)
bumpedTx <- wallet.bumpFeeRBF(tx.txIdBE, newFeeRate)
txDb1Opt <- wallet.outgoingTxDAO.findByTxId(tx.txIdBE)
txDb2Opt <- wallet.outgoingTxDAO.findByTxId(bumpedTx.txIdBE)
secondBal <- wallet.getBalance()
} yield {
assert(txDb1Opt.isDefined)
assert(txDb2Opt.isDefined)
val txDb1 = txDb1Opt.get
val txDb2 = txDb2Opt.get
assert(txDb1.actualFee < txDb2.actualFee)
assert(firstBal - secondBal == txDb2.actualFee - txDb1.actualFee)
}
}
it should "fail to send from outpoints when already spent" in { it should "fail to send from outpoints when already spent" in {
fundedWallet => fundedWallet =>
val wallet = fundedWallet.wallet val wallet = fundedWallet.wallet

View file

@ -58,4 +58,25 @@ class AddressDAOTest extends WalletDAOFixture {
} yield assert(readAddress.contains(createdAddress)) } yield assert(readAddress.contains(createdAddress))
} }
it should "find by script pub key" in { daos =>
val addressDAO = daos.addressDAO
val addr1 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb)
val addr2 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb,
addressIndex = 1)
val addr3 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb,
addressIndex = 2)
val spks = Vector(addr1.scriptPubKey, addr2.scriptPubKey)
for {
created1 <- addressDAO.create(addr1)
created2 <- addressDAO.create(addr2)
created3 <- addressDAO.create(addr3)
found <- addressDAO.findByScriptPubKeys(spks)
} yield {
assert(found.contains(created1))
assert(found.contains(created2))
assert(!found.contains(created3))
}
}
} }

View file

@ -3,6 +3,7 @@ package org.bitcoins.wallet.models
import org.bitcoins.core.api.wallet.db.{ import org.bitcoins.core.api.wallet.db.{
LegacySpendingInfo, LegacySpendingInfo,
NestedSegwitV0SpendingInfo, NestedSegwitV0SpendingInfo,
ScriptPubKeyDb,
SegwitV0SpendingInfo SegwitV0SpendingInfo
} }
import org.bitcoins.core.protocol.script.ScriptSignature import org.bitcoins.core.protocol.script.ScriptSignature
@ -204,4 +205,28 @@ class SpendingInfoDAOTest extends WalletDAOFixture {
case Some(other) => fail(s"did not get a nested segwit UTXO: $other") case Some(other) => fail(s"did not get a nested segwit UTXO: $other")
} }
} }
it should "find incoming outputs dbs being spent, given a TX" in { daos =>
val utxoDAO = daos.utxoDAO
for {
created <- WalletTestUtil.insertNestedSegWitUTXO(daos)
db <- utxoDAO.read(created.id.get)
account <- daos.accountDAO.create(WalletTestUtil.firstAccountDb)
addr <- daos.addressDAO.create(getAddressDb(account))
// Add another utxo
u2 = WalletTestUtil.sampleSegwitUTXO(addr.scriptPubKey)
_ <- insertDummyIncomingTransaction(daos, u2)
_ <- utxoDAO.create(u2)
dbs <- utxoDAO.findDbsForTx(created.txid)
} yield {
assert(dbs.size == 1)
assert(db.isDefined)
assert(dbs == Vector(db.get))
}
}
} }

View file

@ -1,7 +1,5 @@
package org.bitcoins.wallet package org.bitcoins.wallet
import java.time.Instant
import org.bitcoins.commons.jsonmodels.wallet.SyncHeightDescriptor import org.bitcoins.commons.jsonmodels.wallet.SyncHeightDescriptor
import org.bitcoins.core.api.chain.ChainQueryApi import org.bitcoins.core.api.chain.ChainQueryApi
import org.bitcoins.core.api.feeprovider.FeeRateApi import org.bitcoins.core.api.feeprovider.FeeRateApi
@ -14,6 +12,7 @@ import org.bitcoins.core.crypto.ExtPublicKey
import org.bitcoins.core.currency._ import org.bitcoins.core.currency._
import org.bitcoins.core.gcs.{GolombFilter, SimpleFilterMatcher} import org.bitcoins.core.gcs.{GolombFilter, SimpleFilterMatcher}
import org.bitcoins.core.hd._ import org.bitcoins.core.hd._
import org.bitcoins.core.number.UInt32
import org.bitcoins.core.protocol.BitcoinAddress import org.bitcoins.core.protocol.BitcoinAddress
import org.bitcoins.core.protocol.blockchain.ChainParams import org.bitcoins.core.protocol.blockchain.ChainParams
import org.bitcoins.core.protocol.script.ScriptPubKey import org.bitcoins.core.protocol.script.ScriptPubKey
@ -33,20 +32,16 @@ import org.bitcoins.core.wallet.utxo.TxoState.{
PendingConfirmationsReceived PendingConfirmationsReceived
} }
import org.bitcoins.core.wallet.utxo._ import org.bitcoins.core.wallet.utxo._
import org.bitcoins.crypto.{ import org.bitcoins.crypto._
AesPassword,
CryptoUtil,
DoubleSha256Digest,
ECPublicKey
}
import org.bitcoins.keymanager.bip39.{BIP39KeyManager, BIP39LockedKeyManager} import org.bitcoins.keymanager.bip39.{BIP39KeyManager, BIP39LockedKeyManager}
import org.bitcoins.wallet.config.WalletAppConfig import org.bitcoins.wallet.config.WalletAppConfig
import org.bitcoins.wallet.internal._ import org.bitcoins.wallet.internal._
import org.bitcoins.wallet.models._ import org.bitcoins.wallet.models._
import scodec.bits.ByteVector import scodec.bits.ByteVector
import java.time.Instant
import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success} import scala.util.{Failure, Random, Success}
abstract class Wallet abstract class Wallet
extends AnyHDWalletApi extends AnyHDWalletApi
@ -499,6 +494,88 @@ abstract class Wallet
} yield tx } yield tx
} }
override def bumpFeeRBF(
txId: DoubleSha256DigestBE,
newFeeRate: FeeUnit): Future[Transaction] = {
for {
txDbOpt <- transactionDAO.findByTxId(txId)
tx <- txDbOpt match {
case Some(db) => Future.successful(db.transaction)
case None =>
Future.failed(
new RuntimeException(s"Unable to find transaction ${txId.hex}"))
}
outPoints = tx.inputs.map(_.previousOutput).toVector
spks = tx.outputs.map(_.scriptPubKey).toVector
utxos <- spendingInfoDAO.findByOutPoints(outPoints)
_ = require(utxos.nonEmpty, "Can only bump fee for our own transaction")
_ = require(utxos.size == tx.inputs.size,
"Can only bump fee for a transaction we own all the inputs")
spendingInfos <- FutureUtil.sequentially(utxos) { utxo =>
transactionDAO
.findByOutPoint(utxo.outPoint)
.map(txDbOpt =>
utxo.toUTXOInfo(keyManager = keyManager, txDbOpt.get.transaction))
}
_ = {
val inputAmount = utxos.foldLeft(CurrencyUnits.zero)(_ + _.output.value)
val oldFeeRate = newFeeRate match {
case _: SatoshisPerByte =>
SatoshisPerByte.calc(inputAmount, tx)
case _: SatoshisPerKiloByte =>
SatoshisPerKiloByte.calc(inputAmount, tx)
case _: SatoshisPerVirtualByte =>
SatoshisPerVirtualByte.calc(inputAmount, tx)
case _: SatoshisPerKW =>
SatoshisPerKW.calc(inputAmount, tx)
}
require(
oldFeeRate.currencyUnit < newFeeRate.currencyUnit,
s"Cannot bump to a lower fee ${oldFeeRate.currencyUnit} < ${newFeeRate.currencyUnit}")
}
myAddrs <- addressDAO.findByScriptPubKeys(spks)
_ = require(myAddrs.nonEmpty, "Must have an output we own")
changeSpks = myAddrs.flatMap { db =>
if (db.path.chain.chainType == HDChainType.Change) {
Some(db.scriptPubKey)
} else None
}
changeSpk =
if (changeSpks.nonEmpty) {
// Pick a random change spk
Random.shuffle(changeSpks).head
} else {
// If none are explicit change, pick a random one we own
Random.shuffle(myAddrs.map(_.scriptPubKey)).head
}
// Mark old outputs as replaced
oldUtxos <- spendingInfoDAO.findDbsForTx(txId)
_ <- spendingInfoDAO.updateAll(
oldUtxos.map(_.copyWithState(TxoState.DoesNotExist)))
sequence = tx.inputs.head.sequence + UInt32.one
outputs = tx.outputs.filterNot(_.scriptPubKey == changeSpk)
txBuilder = StandardNonInteractiveFinalizer.txBuilderFrom(outputs,
spendingInfos,
newFeeRate,
changeSpk,
sequence)
amount = outputs.foldLeft(CurrencyUnits.zero)(_ + _.value)
tx <-
finishSend(txBuilder, spendingInfos, amount, newFeeRate, Vector.empty)
} yield tx
}
override def sendWithAlgo( override def sendWithAlgo(
address: BitcoinAddress, address: BitcoinAddress,
amount: CurrencyUnit, amount: CurrencyUnit,

View file

@ -280,6 +280,21 @@ case class AddressDAO()(implicit
}) })
} }
def findByScriptPubKeys(
spks: Vector[ScriptPubKey]): Future[Vector[AddressDb]] = {
val query = table
.join(spkTable)
.on(_.scriptPubKeyId === _.id)
.filter(_._2.scriptPubKey.inSet(spks))
safeDatabase
.runVec(query.result.transactionally)
.map(res =>
res.map {
case (addrRec, spkRec) => addrRec.toAddressDb(spkRec.scriptPubKey)
})
}
private def findMostRecentForChain(account: HDAccount, chain: HDChainType) = { private def findMostRecentForChain(account: HDAccount, chain: HDChainType) = {
addressesForAccountQuery(account.index) addressesForAccountQuery(account.index)
.filter(_._1.purpose === account.purpose) .filter(_._1.purpose === account.purpose)

View file

@ -277,6 +277,15 @@ case class SpendingInfoDAO()(implicit
} }
/**
* Fetches all the incoming TXOs in our DB that are in
* the transaction with the given TXID
*/
def findDbsForTx(txid: DoubleSha256DigestBE): Future[Vector[UTXORecord]] = {
val query = table.filter(_.txid === txid)
safeDatabase.runVec(query.result)
}
/** /**
* Fetches all the incoming TXOs in our DB that are in * Fetches all the incoming TXOs in our DB that are in
* the transaction with the given TXID * the transaction with the given TXID