DLC Adaptor Point Computation Memoization (#3110)

* A first attempt

* Made memoization conditional on CryptoContext

* Added scaladocs

* Responded to review and added a unit test
This commit is contained in:
Nadav Kohen 2021-05-21 17:40:28 -05:00 committed by GitHub
parent 9c9e27a8f5
commit 87f353b08f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 175 additions and 26 deletions

View file

@ -0,0 +1,55 @@
package org.bitcoins.core.protocol.dlc
import org.bitcoins.core.protocol.dlc.compute.DLCAdaptorPointComputer.AdditionTrieNode
import org.bitcoins.crypto.{CryptoUtil, ECPublicKey}
import org.bitcoins.testkitcore.util.BitcoinSUnitTest
import scala.annotation.tailrec
class AdditionTrieNodeTest extends BitcoinSUnitTest {
behavior of "AdditionTrieNode"
val base: Int = 2
val nonces: Int = 10
@tailrec
private def computeAllPrefixes(accum: Vector[Vector[Vector[Int]]] =
Vector.empty): Vector[Vector[Int]] = {
if (accum.length == nonces) {
accum.flatten
} else {
val newPrefixes = if (accum.isEmpty) {
0.until(base).toVector.map(Vector(_))
} else {
accum.last.flatMap { prefix =>
0.until(base).toVector.map(digit => prefix.:+(digit))
}
}
computeAllPrefixes(accum.:+(newPrefixes))
}
}
val allPrefixes: Vector[Vector[Int]] = {
computeAllPrefixes()
}
val preComputeTable: Vector[Vector[ECPublicKey]] =
Vector.fill(nonces)(Vector.fill(base)(ECPublicKey.freshPublicKey))
def newTrie(): AdditionTrieNode = {
AdditionTrieNode.makeRoot(preComputeTable)
}
it should "correctly compute all elements" in {
val trie = newTrie()
allPrefixes.foreach { prefix =>
val ptsToAdd = prefix.zipWithIndex.map {
case (outcomeIndex, nonceIndex) =>
preComputeTable(nonceIndex)(outcomeIndex)
}
val expected = CryptoUtil.combinePubKeys(ptsToAdd)
assert(trie.computeSum(prefix) == expected)
}
}
}

View file

@ -10,12 +10,7 @@ import org.bitcoins.core.protocol.tlv.{
SignedNumericOutcome,
UnsignedNumericOutcome
}
import org.bitcoins.crypto.{
CryptoUtil,
ECPublicKey,
FieldElement,
SchnorrPublicKey
}
import org.bitcoins.crypto._
import scodec.bits.ByteVector
/** Responsible for optimized computation of DLC adaptor point batches. */
@ -50,6 +45,79 @@ object DLCAdaptorPointComputer {
nonce.add(pubKey.publicKey.tweakMultiply(FieldElement(hash)))
}
/** This trie is used for computing adaptor points for a single oracle corresponding
* to digit prefixes while memoizing partial sums.
*
* For example the point corresponding to 0110 and 01111010 both begin with
* the 011 sub-sum.
*
* This trie stores all already computed sub-sums and new points are computed
* by extending this Trie.
*
* Note that this method should not be used if you have access to LibSecp256k1CryptoRuntime
* because calling CryptoUtil.combinePubKeys will outperform memoization in that case.
*/
case class AdditionTrieNode(
preComputeTable: Vector[Vector[ECPublicKey]], // Nonce -> Outcome -> Point
depth: Int = 0,
private var children: Vector[AdditionTrieNode] = Vector.empty,
private var pointOpt: Option[SecpPoint] = None) {
/** Populates children field with base empty nodes.
*
* To avoid unnecessary computation (and recursion),
* this should be called lazily only when children are needed.
*/
def initChildren(): Unit = {
children = 0
.until(base)
.toVector
.map(_ => AdditionTrieNode(preComputeTable, depth + 1))
}
/** Uses the preComputeTable to calculate the adaptor point
* for the given digit prefix.
*
* This is done by traversing (and where need be extending) the
* Trie according to the digits until the point corresponding to
* the input digits is reached.
*/
def computeSum(digits: Vector[Int]): ECPublicKey = {
val point = pointOpt.get
if (digits.isEmpty) { // Then we have arrived at our result
point match {
case SecpPointInfinity =>
throw new IllegalArgumentException(
"Sum cannot be point at infinity.")
case point: SecpPointFinite => point.toPublicKey
}
} else {
val digit = digits.head
if (children.isEmpty) initChildren()
val child = children(digit)
// If child is not defined, extend the trie
child.pointOpt match {
case Some(_) => ()
case None =>
val pointToAdd = preComputeTable(depth)(digit).toPoint
child.pointOpt = Some(point.add(pointToAdd))
}
// Then move down and continue computation
child.computeSum(digits.tail)
}
}
}
object AdditionTrieNode {
/** Creates a fresh AdditionTreeNode for a given preComputeTable */
def makeRoot(
preComputeTable: Vector[Vector[ECPublicKey]]): AdditionTrieNode = {
AdditionTrieNode(preComputeTable, pointOpt = Some(SecpPointInfinity))
}
}
/** Efficiently computes all adaptor points, in order, for a given ContractInfo.
* @see https://medium.com/crypto-garage/optimizing-numeric-outcome-dlc-creation-6d6091ac0e47
*/
@ -77,33 +145,59 @@ object DLCAdaptorPointComputer {
}
}
lazy val additionTries = preComputeTable.map { table =>
AdditionTrieNode.makeRoot(table)
}
val oraclesAndOutcomes = contractInfo.allOutcomes.map(_.oraclesAndOutcomes)
oraclesAndOutcomes.map { oracleAndOutcome =>
// For the given oracleAndOutcome, look up the point in the preComputeTable
val subSigPoints = oracleAndOutcome.flatMap { case (info, outcome) =>
val oracleIndex =
contractInfo.oracleInfo.singleOracleInfos.indexOf(info)
val outcomeIndices = outcome match {
case outcome: EnumOutcome =>
Vector(
contractInfo.contractDescriptor
.asInstanceOf[EnumContractDescriptor]
.keys
.indexOf(outcome)
)
case UnsignedNumericOutcome(digits) => digits
case _: SignedNumericOutcome =>
throw new UnsupportedOperationException(
"Signed numeric outcomes not supported!")
}
val subSigPoints = CryptoUtil.cryptoContext match {
case CryptoContext.LibSecp256k1 =>
oracleAndOutcome.flatMap { case (info, outcome) =>
val oracleIndex =
contractInfo.oracleInfo.singleOracleInfos.indexOf(info)
val outcomeIndices = outcome match {
case outcome: EnumOutcome =>
Vector(
contractInfo.contractDescriptor
.asInstanceOf[EnumContractDescriptor]
.keys
.indexOf(outcome)
)
case UnsignedNumericOutcome(digits) => digits
case _: SignedNumericOutcome =>
throw new UnsupportedOperationException(
"Signed numeric outcomes not supported!")
}
outcomeIndices.zipWithIndex.map { case (outcomeIndex, nonceIndex) =>
preComputeTable(oracleIndex)(nonceIndex)(outcomeIndex)
}
outcomeIndices.zipWithIndex.map { case (outcomeIndex, nonceIndex) =>
preComputeTable(oracleIndex)(nonceIndex)(outcomeIndex)
}
}
case CryptoContext.BouncyCastle | CryptoContext.BCrypto =>
oracleAndOutcome.map { case (info, outcome) =>
val oracleIndex =
contractInfo.oracleInfo.singleOracleInfos.indexOf(info)
outcome match {
case outcome: EnumOutcome =>
val outcomeIndex = contractInfo.contractDescriptor
.asInstanceOf[EnumContractDescriptor]
.keys
.indexOf(outcome)
preComputeTable(oracleIndex)(0)(outcomeIndex)
case UnsignedNumericOutcome(digits) =>
additionTries(oracleIndex).computeSum(digits)
case _: SignedNumericOutcome =>
throw new UnsupportedOperationException(
"Signed numeric outcomes not supported!")
}
}
}
// TODO: Memoization of sub-combinations for further optimization!
CryptoUtil.combinePubKeys(subSigPoints)
}
}