2021 01 15 dlc refactors (#2518)

* Kill DigitDecomp.isSigned

* Create TLVEndpoint, TLVMidpoint ADT. Also add helper method OutcomePayoutPoint.toTlvPoint
This commit is contained in:
Chris Stewart 2021-01-16 07:55:42 -06:00 committed by GitHub
parent f3e81d027d
commit abc1fdd23f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 59 deletions

View file

@ -142,11 +142,13 @@ case class OracleRoutes(oracle: DLCOracle)(implicit
} }
outcomes.map(num => Num(num.toDouble)) outcomes.map(num => Num(num.toDouble))
case decomp: DigitDecompositionEventDescriptorV0TLV => case decomp: DigitDecompositionEventDescriptorV0TLV =>
val sign = if (decomp.isSigned) { val sign = decomp match {
Vector(Str("+"), Str("-")) case _: UnsignedDigitDecompositionEventDescriptor =>
} else { Vector.empty
Vector.empty case _: SignedDigitDecompositionEventDescriptor =>
Vector(Str("+"), Str("-"))
} }
val digits = 0.until(decomp.numDigits.toInt).map { _ => val digits = 0.until(decomp.numDigits.toInt).map { _ =>
0 0
.until(decomp.base.toInt) .until(decomp.base.toInt)

View file

@ -293,12 +293,7 @@ object DLCMessage {
} }
override lazy val toTLV: ContractInfoV1TLV = { override lazy val toTLV: ContractInfoV1TLV = {
val tlvPoints = outcomeValueFunc.points.map { point => val tlvPoints = outcomeValueFunc.points.map(_.toTlvPoint)
TLVPoint(point.outcome,
point.roundedPayout,
point.extraPrecision,
point.isEndpoint)
}
ContractInfoV1TLV(base, numDigits, totalCollateral, tlvPoints) ContractInfoV1TLV(base, numDigits, totalCollateral, tlvPoints)
} }
@ -312,6 +307,7 @@ object DLCMessage {
val points = tlv.points.map { point => val points = tlv.points.map { point =>
val payoutWithPrecision = val payoutWithPrecision =
point.value.toLong + (BigDecimal(point.extraPrecision) / (1 << 16)) point.value.toLong + (BigDecimal(point.extraPrecision) / (1 << 16))
OutcomePayoutPoint(point.outcome, payoutWithPrecision, point.isEndpoint) OutcomePayoutPoint(point.outcome, payoutWithPrecision, point.isEndpoint)
} }

View file

@ -1,6 +1,7 @@
package org.bitcoins.core.protocol.dlc package org.bitcoins.core.protocol.dlc
import org.bitcoins.core.currency.Satoshis import org.bitcoins.core.currency.Satoshis
import org.bitcoins.core.protocol.tlv.TLVPoint
import org.bitcoins.core.util.{Indexed, NumberUtil} import org.bitcoins.core.util.{Indexed, NumberUtil}
import scala.math.BigDecimal.RoundingMode import scala.math.BigDecimal.RoundingMode
@ -15,15 +16,26 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
/** These points (and their indices in this.points) represent the endpoints /** These points (and their indices in this.points) represent the endpoints
* between which interpolation happens. * between which interpolation happens.
* In other words these endpoints define the pieces of the piecewise function. * In other words these endpoints define the pieces of the piecewise function.
*
* It's important to note that the index returned here is relative to the _entire_
* set of points, not the index relative to the set of endpoints.
*/ */
lazy val endpoints: Vector[Indexed[OutcomePayoutPoint]] = lazy val endpoints: Vector[Indexed[OutcomePayoutEndpoint]] = {
Indexed(points).filter(_.element.isEndpoint) val endpoints = points.zipWithIndex.collect {
case (o: OutcomePayoutEndpoint, idx) => (o, idx)
}
Indexed.fromGivenIndex(endpoints)
}
/** This Vector contains the function pieces between the endpoints */ /** This Vector contains the function pieces between the endpoints */
lazy val functionComponents: Vector[DLCPayoutCurveComponent] = { lazy val functionComponents: Vector[DLCPayoutCurveComponent] = {
endpoints.init.zip(endpoints.tail).map { // All pairs of adjacent endpoints val zipped: Vector[
(Indexed[OutcomePayoutEndpoint], Indexed[OutcomePayoutEndpoint])] =
endpoints.init.zip(endpoints.tail)
zipped.map { // All pairs of adjacent endpoints
case (Indexed(_, index), Indexed(_, nextIndex)) => case (Indexed(_, index), Indexed(_, nextIndex)) =>
DLCPayoutCurveComponent(points.slice(index, nextIndex + 1)) val slice = points.slice(index, nextIndex + 1)
DLCPayoutCurveComponent(slice)
} }
} }
@ -70,7 +82,13 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
sealed trait OutcomePayoutPoint { sealed trait OutcomePayoutPoint {
def outcome: Long def outcome: Long
def payout: BigDecimal def payout: BigDecimal
def isEndpoint: Boolean
def isEndPoint: Boolean = {
this match {
case _: OutcomePayoutEndpoint => true
case _: OutcomePayoutMidpoint => false
}
}
def roundedPayout: Satoshis = { def roundedPayout: Satoshis = {
Satoshis(payout.setScale(0, RoundingMode.FLOOR).toLongExact) Satoshis(payout.setScale(0, RoundingMode.FLOOR).toLongExact)
@ -89,6 +107,22 @@ sealed trait OutcomePayoutPoint {
case OutcomePayoutMidpoint(_, _) => OutcomePayoutMidpoint(outcome, payout) case OutcomePayoutMidpoint(_, _) => OutcomePayoutMidpoint(outcome, payout)
} }
} }
/** Converts our internal representation to a TLV that can be sent over the wire */
def toTlvPoint: TLVPoint = {
this match {
case _: OutcomePayoutEndpoint =>
TLVPoint(outcome = outcome,
value = roundedPayout,
extraPrecision = extraPrecision,
isEndpoint = true)
case _: OutcomePayoutMidpoint =>
TLVPoint(outcome = outcome,
value = roundedPayout,
extraPrecision = extraPrecision,
isEndpoint = false)
}
}
} }
object OutcomePayoutPoint { object OutcomePayoutPoint {
@ -114,7 +148,6 @@ object OutcomePayoutPoint {
case class OutcomePayoutEndpoint(outcome: Long, payout: BigDecimal) case class OutcomePayoutEndpoint(outcome: Long, payout: BigDecimal)
extends OutcomePayoutPoint { extends OutcomePayoutPoint {
override val isEndpoint: Boolean = true
def toMidpoint: OutcomePayoutMidpoint = OutcomePayoutMidpoint(outcome, payout) def toMidpoint: OutcomePayoutMidpoint = OutcomePayoutMidpoint(outcome, payout)
} }
@ -128,7 +161,6 @@ object OutcomePayoutEndpoint {
case class OutcomePayoutMidpoint(outcome: Long, payout: BigDecimal) case class OutcomePayoutMidpoint(outcome: Long, payout: BigDecimal)
extends OutcomePayoutPoint { extends OutcomePayoutPoint {
override val isEndpoint: Boolean = false
def toEndpoint: OutcomePayoutEndpoint = OutcomePayoutEndpoint(outcome, payout) def toEndpoint: OutcomePayoutEndpoint = OutcomePayoutEndpoint(outcome, payout)
} }
@ -179,9 +211,9 @@ sealed trait DLCPayoutCurveComponent {
object DLCPayoutCurveComponent { object DLCPayoutCurveComponent {
def apply(points: Vector[OutcomePayoutPoint]): DLCPayoutCurveComponent = { def apply(points: Vector[OutcomePayoutPoint]): DLCPayoutCurveComponent = {
require(points.head.isEndpoint && points.last.isEndpoint, require(points.head.isEndPoint && points.last.isEndPoint,
s"First and last points must be endpoints, $points") s"First and last points must be endpoints, $points")
require(points.tail.init.forall(!_.isEndpoint), require(points.tail.init.forall(!_.isEndPoint),
s"Endpoint detected in middle, $points") s"Endpoint detected in middle, $points")
points match { points match {
@ -320,9 +352,10 @@ case class OutcomePayoutCubic(
/** A polynomial interpolating points and defining a piece of a larger payout curve */ /** A polynomial interpolating points and defining a piece of a larger payout curve */
case class OutcomePayoutPolynomial(points: Vector[OutcomePayoutPoint]) case class OutcomePayoutPolynomial(points: Vector[OutcomePayoutPoint])
extends DLCPayoutCurveComponent { extends DLCPayoutCurveComponent {
require(points.head.isEndpoint && points.last.isEndpoint, require(points.head.isInstanceOf[OutcomePayoutEndpoint] && points.last
.isInstanceOf[OutcomePayoutEndpoint],
s"First and last points must be endpoints, $points") s"First and last points must be endpoints, $points")
require(points.tail.init.forall(!_.isEndpoint), require(points.tail.init.forall(!_.isInstanceOf[OutcomePayoutEndpoint]),
s"Endpoint detected in middle, $points") s"Endpoint detected in middle, $points")
override lazy val leftEndpoint: OutcomePayoutEndpoint = override lazy val leftEndpoint: OutcomePayoutEndpoint =

View file

@ -603,9 +603,6 @@ sealed trait DigitDecompositionEventDescriptorV0TLV
require(numDigits > UInt16.zero, require(numDigits > UInt16.zero,
s"Number of digits must be positive, got $numDigits") s"Number of digits must be positive, got $numDigits")
/** Whether the outcome can be negative */
def isSigned: Boolean
/** The number of digits that the oracle will sign */ /** The number of digits that the oracle will sign */
def numDigits: UInt16 def numDigits: UInt16
@ -613,22 +610,32 @@ sealed trait DigitDecompositionEventDescriptorV0TLV
private lazy val maxDigit: NormalizedString = (base.toInt - 1).toString private lazy val maxDigit: NormalizedString = (base.toInt - 1).toString
override lazy val max: Vector[NormalizedString] = if (isSigned) { override lazy val max: Vector[NormalizedString] = {
NormalizedString("+") +: Vector.fill(numDigits.toInt)(maxDigit) this match {
} else { case _: SignedDigitDecompositionEventDescriptor =>
Vector.fill(numDigits.toInt)(maxDigit) NormalizedString("+") +: Vector.fill(numDigits.toInt)(maxDigit)
case _: UnsignedDigitDecompositionEventDescriptor =>
Vector.fill(numDigits.toInt)(maxDigit)
}
} }
override lazy val minNum: BigInt = if (isSigned) { override lazy val minNum: BigInt = {
-maxNum this match {
} else { case _: SignedDigitDecompositionEventDescriptor =>
0 -maxNum
case _: UnsignedDigitDecompositionEventDescriptor =>
0
}
} }
override lazy val min: Vector[NormalizedString] = if (isSigned) { override lazy val min: Vector[NormalizedString] = {
NormalizedString("-") +: Vector.fill(numDigits.toInt)(maxDigit) this match {
} else { case _: SignedDigitDecompositionEventDescriptor =>
Vector.fill(numDigits.toInt)("0") NormalizedString("-") +: Vector.fill(numDigits.toInt)(maxDigit)
case _: UnsignedDigitDecompositionEventDescriptor =>
Vector.fill(numDigits.toInt)("0")
}
} }
override lazy val step: UInt16 = UInt16.one override lazy val step: UInt16 = UInt16.one
@ -637,16 +644,26 @@ sealed trait DigitDecompositionEventDescriptorV0TLV
DigitDecompositionEventDescriptorV0TLV.tpe DigitDecompositionEventDescriptorV0TLV.tpe
override lazy val value: ByteVector = { override lazy val value: ByteVector = {
base.bytes ++ val start = base.bytes
boolBytes(isSigned) ++ val signByte = this match {
strBytes(unit) ++ case _: UnsignedDigitDecompositionEventDescriptor =>
boolBytes(false)
case _: SignedDigitDecompositionEventDescriptor =>
boolBytes(true)
}
val end = strBytes(unit) ++
precision.bytes ++ precision.bytes ++
numDigits.bytes numDigits.bytes
start ++ signByte ++ end
} }
override def noncesNeeded: Int = { override def noncesNeeded: Int = {
if (isSigned) numDigits.toInt + 1 this match {
else numDigits.toInt case _: SignedDigitDecompositionEventDescriptor =>
numDigits.toInt + 1
case _: UnsignedDigitDecompositionEventDescriptor =>
numDigits.toInt
}
} }
} }
@ -656,9 +673,7 @@ case class SignedDigitDecompositionEventDescriptor(
numDigits: UInt16, numDigits: UInt16,
unit: NormalizedString, unit: NormalizedString,
precision: Int32) precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV { extends DigitDecompositionEventDescriptorV0TLV
override val isSigned: Boolean = true
}
/** Represents a large range event that is unsigned */ /** Represents a large range event that is unsigned */
case class UnsignedDigitDecompositionEventDescriptor( case class UnsignedDigitDecompositionEventDescriptor(
@ -666,9 +681,7 @@ case class UnsignedDigitDecompositionEventDescriptor(
numDigits: UInt16, numDigits: UInt16,
unit: NormalizedString, unit: NormalizedString,
precision: Int32) precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV { extends DigitDecompositionEventDescriptorV0TLV
override val isSigned: Boolean = false
}
object DigitDecompositionEventDescriptorV0TLV object DigitDecompositionEventDescriptorV0TLV
extends TLVFactory[DigitDecompositionEventDescriptorV0TLV] { extends TLVFactory[DigitDecompositionEventDescriptorV0TLV] {
@ -895,7 +908,11 @@ object TLVPoint extends Factory[TLVPoint] {
val outcome = BigSizeUInt(bytes.tail) val outcome = BigSizeUInt(bytes.tail)
val value = UInt64(bytes.drop(1 + outcome.byteSize).take(8)) val value = UInt64(bytes.drop(1 + outcome.byteSize).take(8))
val extraPrecision = UInt16(bytes.drop(9 + outcome.byteSize).take(2)).toInt val extraPrecision = UInt16(bytes.drop(9 + outcome.byteSize).take(2)).toInt
TLVPoint(outcome.toLong, Satoshis(value.toLong), extraPrecision, isEndpoint)
TLVPoint(outcome = outcome.toLong,
value = Satoshis(value.toLong),
extraPrecision = extraPrecision,
isEndpoint = isEndpoint)
} }
} }

View file

@ -1,10 +1,20 @@
package org.bitcoins.core.util package org.bitcoins.core.util
case class Indexed[T](element: T, index: Int) case class Indexed[+T](element: T, index: Int)
object Indexed { object Indexed {
def apply[T](vec: Vector[T]): Vector[Indexed[T]] = { def apply[T](vec: Vector[T]): Vector[Indexed[T]] = {
vec.zipWithIndex.map { case (elem, index) => Indexed(elem, index) } vec.zipWithIndex.map { case (elem, index) => Indexed(elem, index) }
} }
/** Takes in a given vector of T's with their corresponding index
* and returns a Vector[Indexed[T]].
*
* This is useful in situations where you want to preserve the initial
* index in a set of elements, but have performed subsequent collection operations (like .filter, .filterNot, .collect etc)
*/
def fromGivenIndex[T](vec: Vector[(T, Int)]): Vector[Indexed[T]] = {
vec.map { case (t, idx) => Indexed(t, idx) }
}
} }

View file

@ -300,12 +300,14 @@ case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
oracleEventTLV: OracleEventTLV, oracleEventTLV: OracleEventTLV,
num: Long): Future[OracleEvent] = { num: Long): Future[OracleEvent] = {
val eventDescriptorTLV = oracleEventTLV.eventDescriptor match { val eventDescriptorTLV: DigitDecompositionEventDescriptorV0TLV = {
case _: EnumEventDescriptorV0TLV | _: RangeEventDescriptorV0TLV => oracleEventTLV.eventDescriptor match {
throw new IllegalArgumentException( case _: EnumEventDescriptorV0TLV | _: RangeEventDescriptorV0TLV =>
"Must have a DigitDecomposition event descriptor use signEvent instead") throw new IllegalArgumentException(
case decomp: DigitDecompositionEventDescriptorV0TLV => "Must have a DigitDecomposition event descriptor use signEvent instead")
decomp case decomp: DigitDecompositionEventDescriptorV0TLV =>
decomp
}
} }
// Make this a vec so it is easier to add on // Make this a vec so it is easier to add on
@ -338,9 +340,12 @@ case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
eventDescriptorTLV.base.toInt, eventDescriptorTLV.base.toInt,
eventDescriptorTLV.numDigits.toInt) eventDescriptorTLV.numDigits.toInt)
val nonces = val nonces = eventDescriptorTLV match {
if (eventDescriptorTLV.isSigned) oracleEventTLV.nonces.tail case _: UnsignedDigitDecompositionEventDescriptor =>
else oracleEventTLV.nonces oracleEventTLV.nonces
case _: SignedDigitDecompositionEventDescriptor =>
oracleEventTLV.nonces.tail
}
val digitSigFs = nonces.zipWithIndex.map { val digitSigFs = nonces.zipWithIndex.map {
case (nonce, index) => case (nonce, index) =>

View file

@ -60,7 +60,12 @@ trait EventDbUtil {
Vector.empty Vector.empty
} }
val digitNonces = if (decomp.isSigned) nonces.tail else nonces val digitNonces = decomp match {
case _: UnsignedDigitDecompositionEventDescriptor =>
nonces
case _: SignedDigitDecompositionEventDescriptor =>
nonces.tail
}
val digitDbs = digitNonces.flatMap { nonce => val digitDbs = digitNonces.flatMap { nonce =>
0.until(decomp.base.toInt).map { num => 0.until(decomp.base.toInt).map { num =>