OutcomePayoutPoint now has the correct types and deffers rounding due to extra_precision in serialized points (#2441)

This commit is contained in:
Nadav Kohen 2020-12-29 13:55:39 -06:00 committed by GitHub
parent f7c6fc141d
commit ff60c6e03e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 104 additions and 60 deletions

View file

@ -13,7 +13,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
behavior of "DLCPayoutCurve"
private val numGen = Gen.choose[Double](0, 1000).map(BigDecimal(_))
private val numGen = Gen.choose[Long](0, 10000)
private val intGen = Gen.choose[Int](0, 1000)
def nPoints(n: Int): Gen[Vector[OutcomePayoutPoint]] = {
@ -32,7 +32,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on lines and degree 1 polynomials" ignore {
it should "agree on lines and degree 1 polynomials" in {
forAll(nPoints(2), Gen.listOfN(1000, numGen)) {
case (Vector(point1: OutcomePayoutEndpoint,
point2: OutcomePayoutEndpoint),
@ -48,7 +48,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on lines and y = mx + b" ignore {
it should "agree on lines and y = mx + b" in {
val twoNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -78,7 +78,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on quadratics and degree 2 polynomials" ignore {
it should "agree on quadratics and degree 2 polynomials" in {
forAll(nPoints(3), Gen.listOfN(1000, numGen)) {
case (Vector(point1: OutcomePayoutEndpoint,
point2: OutcomePayoutMidpoint,
@ -96,7 +96,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on quadratics and y = ax^2 + bx + c" ignore {
it should "agree on quadratics and y = ax^2 + bx + c" in {
val threeNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -125,7 +125,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on degenerate quadratics and lines" ignore {
it should "agree on degenerate quadratics and lines" in {
val threeNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -155,7 +155,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on cubics and degree 3 polynomials" ignore {
it should "agree on cubics and degree 3 polynomials" in {
forAll(nPoints(4), Gen.listOfN(1000, numGen)) {
case (Vector(point1: OutcomePayoutEndpoint,
point2: OutcomePayoutMidpoint,
@ -174,7 +174,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on cubics and y = ax^3 + bx^2 + cx + d" ignore {
it should "agree on cubics and y = ax^3 + bx^2 + cx + d" in {
val fourNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -211,7 +211,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on degenerate cubics and lines" ignore {
it should "agree on degenerate cubics and lines" in {
val fourNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -243,7 +243,7 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "agree on degenerate cubics and quadratics" ignore {
it should "agree on degenerate cubics and quadratics" in {
val fourNums = for {
num1 <- intGen
num2 <- intGen.suchThat(_ != num1)
@ -275,24 +275,24 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
}
}
it should "parse points into component functions correctly and compute outputs" ignore {
it should "parse points into component functions correctly and compute outputs" in {
val point0 = OutcomePayoutEndpoint(0, Satoshis.zero)
val point1 = OutcomePayoutEndpoint(1, Satoshis.one)
val point1 = OutcomePayoutEndpoint(10, Satoshis(100))
val line = DLCPayoutCurve(Vector(point0, point1))
val lineFunc = line.functionComponents
assert(lineFunc == Vector(OutcomePayoutLine(point0, point1)))
val point2 = OutcomePayoutMidpoint(2, Satoshis.zero)
val point3 = OutcomePayoutEndpoint(3, Satoshis(3))
val point2 = OutcomePayoutMidpoint(20, Satoshis.zero)
val point3 = OutcomePayoutEndpoint(30, Satoshis(300))
val quad = DLCPayoutCurve(Vector(point1, point2, point3))
val quadFunc = quad.functionComponents
assert(quadFunc == Vector(OutcomePayoutQuadratic(point1, point2, point3)))
val point4 = OutcomePayoutMidpoint(4, Satoshis(6))
val point5 = OutcomePayoutMidpoint(5, Satoshis(5))
val point6 = OutcomePayoutEndpoint(6, Satoshis(7))
val point4 = OutcomePayoutMidpoint(40, Satoshis(600))
val point5 = OutcomePayoutMidpoint(50, Satoshis(500))
val point6 = OutcomePayoutEndpoint(60, Satoshis(700))
val cubicPoints = Vector(point3, point4, point5, point6)
val cubic = DLCPayoutCurve(cubicPoints)
@ -305,12 +305,12 @@ class DLCPayoutCurveTest extends BitcoinSUnitTest {
val allFuncs = func.functionComponents
assert(allFuncs == lineFunc ++ quadFunc ++ cubicFunc)
forAll(Gen.choose[Double](0, 6)) { outcome =>
forAll(Gen.choose[Long](0, 60)) { outcome =>
val value = func(outcome)
if (0 <= outcome && outcome < 1) {
if (0 <= outcome && outcome < 10) {
assert(value == line(outcome))
} else if (1 <= outcome && outcome < 3) {
} else if (10 <= outcome && outcome < 30) {
assert(value == quad(outcome))
} else {
assert(value == cubic(outcome))

View file

@ -95,8 +95,8 @@ object CETCalculator {
def processConstantComponents(): Unit = {
currentFunc match {
case OutcomePayoutConstant(_, rightEndpoint) =>
val componentEnd = rightEndpoint.outcome.toLongExact - 1
val funcValue = rightEndpoint.payout
val componentEnd = rightEndpoint.outcome - 1
val funcValue = rightEndpoint.roundedPayout
if (funcValue <= Satoshis.zero) {
currentRange match {

View file

@ -281,7 +281,7 @@ object DLCMessage {
s"Input total collateral ($totalCollateral) did not match ${this.totalCollateral}")
val flippedFunc = DLCPayoutCurve(outcomeValueFunc.points.map { point =>
point.copy(payout = (totalCollateral - point.payout).satoshis)
point.copy(payout = totalCollateral.toLong - point.payout)
})
MultiNonceContractInfo(
@ -294,7 +294,10 @@ object DLCMessage {
override lazy val toTLV: ContractInfoV1TLV = {
val tlvPoints = outcomeValueFunc.points.map { point =>
TLVPoint(point.outcome.toLongExact, point.payout, point.isEndpoint)
TLVPoint(point.outcome,
point.roundedPayout,
point.extraPrecision,
point.isEndpoint)
}
ContractInfoV1TLV(base, numDigits, totalCollateral, tlvPoints)
@ -307,7 +310,9 @@ object DLCMessage {
override def fromTLV(tlv: ContractInfoV1TLV): MultiNonceContractInfo = {
val points = tlv.points.map { point =>
OutcomePayoutPoint(point.outcome, point.value, point.isEndpoint)
val payoutWithPrecision =
point.value.toLong + (BigDecimal(point.extraPrecision) / (1 << 16))
OutcomePayoutPoint(point.outcome, payoutWithPrecision, point.isEndpoint)
}
MultiNonceContractInfo(DLCPayoutCurve(points),

View file

@ -32,7 +32,7 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
/** Returns the function component on which the given oracle outcome is
* defined, along with its index
*/
def componentFor(outcome: BigDecimal): Indexed[DLCPayoutCurveComponent] = {
def componentFor(outcome: Long): Indexed[DLCPayoutCurveComponent] = {
val endpointIndex = NumberUtil.search(outcomes, outcome)
val Indexed(endpoint, _) = endpoints(endpointIndex)
@ -45,19 +45,19 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
}
}
def getPayout(outcome: BigDecimal): Satoshis = {
def getPayout(outcome: Long): Satoshis = {
val Indexed(func, _) = componentFor(outcome)
func(outcome)
}
def getPayout(outcome: BigDecimal, rounding: RoundingIntervals): Satoshis = {
def getPayout(outcome: Long, rounding: RoundingIntervals): Satoshis = {
val Indexed(func, _) = componentFor(outcome)
func(outcome, rounding)
}
def apply(outcome: BigDecimal): Satoshis = getPayout(outcome)
def apply(outcome: Long): Satoshis = getPayout(outcome)
def apply(outcome: BigDecimal, rounding: RoundingIntervals): Satoshis =
def apply(outcome: Long, rounding: RoundingIntervals): Satoshis =
getPayout(outcome, rounding)
}
@ -68,13 +68,22 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
* isEndpoint: True if this point defines a boundary between pieces in the curve
*/
sealed trait OutcomePayoutPoint {
def outcome: BigDecimal
def payout: Satoshis
def outcome: Long
def payout: BigDecimal
def isEndpoint: Boolean
def roundedPayout: Satoshis = {
Satoshis(payout.setScale(0, RoundingMode.FLOOR).toLongExact)
}
def extraPrecision: Int = {
val shifted = (payout - roundedPayout.toLong) * (1 << 16)
shifted.setScale(0, RoundingMode.FLOOR).toIntExact
}
def copy(
outcome: BigDecimal = this.outcome,
payout: Satoshis = this.payout): OutcomePayoutPoint = {
outcome: Long = this.outcome,
payout: BigDecimal = this.payout): OutcomePayoutPoint = {
this match {
case OutcomePayoutEndpoint(_, _) => OutcomePayoutEndpoint(outcome, payout)
case OutcomePayoutMidpoint(_, _) => OutcomePayoutMidpoint(outcome, payout)
@ -85,8 +94,8 @@ sealed trait OutcomePayoutPoint {
object OutcomePayoutPoint {
def apply(
outcome: BigDecimal,
payout: Satoshis,
outcome: Long,
payout: BigDecimal,
isEndpoint: Boolean): OutcomePayoutPoint = {
if (isEndpoint) {
OutcomePayoutEndpoint(outcome, payout)
@ -94,22 +103,43 @@ object OutcomePayoutPoint {
OutcomePayoutMidpoint(outcome, payout)
}
}
def apply(
outcome: Long,
payout: Satoshis,
isEndpoint: Boolean): OutcomePayoutPoint = {
OutcomePayoutPoint(outcome, payout.toLong, isEndpoint)
}
}
case class OutcomePayoutEndpoint(outcome: BigDecimal, payout: Satoshis)
case class OutcomePayoutEndpoint(outcome: Long, payout: BigDecimal)
extends OutcomePayoutPoint {
override val isEndpoint: Boolean = true
def toMidpoint: OutcomePayoutMidpoint = OutcomePayoutMidpoint(outcome, payout)
}
case class OutcomePayoutMidpoint(outcome: BigDecimal, payout: Satoshis)
object OutcomePayoutEndpoint {
def apply(outcome: Long, payout: Satoshis): OutcomePayoutEndpoint = {
OutcomePayoutEndpoint(outcome, payout.toLong)
}
}
case class OutcomePayoutMidpoint(outcome: Long, payout: BigDecimal)
extends OutcomePayoutPoint {
override val isEndpoint: Boolean = false
def toEndpoint: OutcomePayoutEndpoint = OutcomePayoutEndpoint(outcome, payout)
}
object OutcomePayoutMidpoint {
def apply(outcome: Long, payout: Satoshis): OutcomePayoutMidpoint = {
OutcomePayoutMidpoint(outcome, payout.toLong)
}
}
/** A single piece of a larger piecewise function defined between left and right endpoints */
sealed trait DLCPayoutCurveComponent {
def leftEndpoint: OutcomePayoutEndpoint
@ -131,15 +161,18 @@ sealed trait DLCPayoutCurveComponent {
s"Points must be ascending: $this")
}
def apply(outcome: BigDecimal): Satoshis
def apply(outcome: Long): Satoshis
def apply(outcome: BigDecimal, rounding: RoundingIntervals): Satoshis = {
def apply(outcome: Long, rounding: RoundingIntervals): Satoshis = {
rounding.round(outcome, apply(outcome))
}
/** Returns the largest Long less than or equal to bd (floor function) */
protected def bigDecimalSats(bd: BigDecimal): Satoshis = {
Satoshis(bd.setScale(0, RoundingMode.FLOOR).toLongExact)
Satoshis(
bd.setScale(6, RoundingMode.HALF_UP)
.setScale(0, RoundingMode.FLOOR)
.toLongExact)
}
}
@ -181,7 +214,8 @@ case class OutcomePayoutConstant(
override lazy val midpoints: Vector[OutcomePayoutMidpoint] = Vector.empty
override def apply(outcome: BigDecimal): Satoshis = leftEndpoint.payout
override def apply(outcome: Long): Satoshis =
bigDecimalSats(leftEndpoint.payout)
}
/** A Line between left and right endpoints defining a piece of a larger payout curve */
@ -192,12 +226,12 @@ case class OutcomePayoutLine(
override lazy val midpoints: Vector[OutcomePayoutMidpoint] = Vector.empty
lazy val slope: BigDecimal = {
(rightEndpoint.payout.toLong - leftEndpoint.payout.toLong) / (rightEndpoint.outcome - leftEndpoint.outcome)
(rightEndpoint.payout - leftEndpoint.payout) / (rightEndpoint.outcome - leftEndpoint.outcome)
}
override def apply(outcome: BigDecimal): Satoshis = {
override def apply(outcome: Long): Satoshis = {
val value =
(outcome - leftEndpoint.outcome) * slope + leftEndpoint.payout.toLong
(outcome - leftEndpoint.outcome) * slope + leftEndpoint.payout
bigDecimalSats(value)
}
@ -220,14 +254,13 @@ case class OutcomePayoutQuadratic(
private lazy val (x10, x20, x21) = (-x01, -x02, -x12)
private lazy val (y0, y1, y2) = (leftEndpoint.payout.toLong,
midpoint.payout.toLong,
rightEndpoint.payout.toLong)
private lazy val (y0, y1, y2) =
(leftEndpoint.payout, midpoint.payout, rightEndpoint.payout)
private lazy val (c0, c1, c2) =
(y0 / (x01 * x02), y1 / (x10 * x12), y2 / (x20 * x21))
override def apply(outcome: BigDecimal): Satoshis = {
override def apply(outcome: Long): Satoshis = {
val x0 = outcome - leftEndpoint.outcome
val x1 = outcome - midpoint.outcome
val x2 = outcome - rightEndpoint.outcome
@ -260,10 +293,10 @@ case class OutcomePayoutCubic(
private lazy val (x10, x20, x30, x21, x31, x32) =
(-x01, -x02, -x03, -x12, -x13, -x23)
private lazy val (y0, y1, y2, y3) = (leftEndpoint.payout.toLong,
leftMidpoint.payout.toLong,
rightMidpoint.payout.toLong,
rightEndpoint.payout.toLong)
private lazy val (y0, y1, y2, y3) = (leftEndpoint.payout,
leftMidpoint.payout,
rightMidpoint.payout,
rightEndpoint.payout)
private lazy val (c0, c1, c2, c3) =
(y0 / (x01 * x02 * x03),
@ -271,7 +304,7 @@ case class OutcomePayoutCubic(
y2 / (x20 * x21 * x23),
y3 / (x30 * x31 * x32))
override def apply(outcome: BigDecimal): Satoshis = {
override def apply(outcome: Long): Satoshis = {
val x0 = outcome - leftEndpoint.outcome
val x1 = outcome - leftMidpoint.outcome
val x2 = outcome - rightMidpoint.outcome
@ -317,13 +350,13 @@ case class OutcomePayoutPolynomial(points: Vector[OutcomePayoutPoint])
}
}
yi.toLong / denom
yi / denom
}
}
override def apply(outcome: BigDecimal): Satoshis = {
override def apply(outcome: Long): Satoshis = {
points.find(_.outcome == outcome) match {
case Some(point) => point.payout
case Some(point) => bigDecimalSats(point.payout)
case None =>
val allProd = points.foldLeft(BigDecimal(1)) {
case (prodSoFar, point) =>

View file

@ -850,7 +850,11 @@ object ContractInfoV0TLV extends TLVFactory[ContractInfoV0TLV] {
}
}
case class TLVPoint(outcome: Long, value: Satoshis, isEndpoint: Boolean)
case class TLVPoint(
outcome: Long,
value: Satoshis,
extraPrecision: Int,
isEndpoint: Boolean)
extends NetworkElement {
lazy val leadingByte: Byte = if (isEndpoint) {
@ -862,7 +866,8 @@ case class TLVPoint(outcome: Long, value: Satoshis, isEndpoint: Boolean)
override def bytes: ByteVector = {
ByteVector(leadingByte) ++
BigSizeUInt(outcome).bytes ++
UInt64(value.toLong).bytes
UInt64(value.toLong).bytes ++
UInt16(extraPrecision).bytes
}
}
@ -879,7 +884,8 @@ object TLVPoint extends Factory[TLVPoint] {
val outcome = BigSizeUInt(bytes.tail)
val value = UInt64(bytes.drop(1 + outcome.byteSize).take(8))
TLVPoint(outcome.toLong, Satoshis(value.toLong), isEndpoint)
val extraPrecision = UInt16(bytes.drop(9 + outcome.byteSize).take(2)).toInt
TLVPoint(outcome.toLong, Satoshis(value.toLong), extraPrecision, isEndpoint)
}
}