diff --git a/core-test/src/test/scala/org/bitcoins/core/protocol/dlc/CETCalculatorTest.scala b/core-test/src/test/scala/org/bitcoins/core/protocol/dlc/CETCalculatorTest.scala index 15800ea0fe..63e209403c 100644 --- a/core-test/src/test/scala/org/bitcoins/core/protocol/dlc/CETCalculatorTest.scala +++ b/core-test/src/test/scala/org/bitcoins/core/protocol/dlc/CETCalculatorTest.scala @@ -33,15 +33,13 @@ class CETCalculatorTest extends BitcoinSUnitTest { )) val expected = Vector( - ZeroPayoutRange(0, 20), + ConstantPayoutRange(0, 20), // 0 VariablePayoutRange(21, 39), ConstantPayoutRange(40, 50), - VariablePayoutRange(51, 69), - ZeroPayoutRange(70, 70), - VariablePayoutRange(71, 79), + VariablePayoutRange(51, 79), ConstantPayoutRange(80, 90), VariablePayoutRange(91, 98), - MaxPayoutRange(99, 108), + ConstantPayoutRange(99, 108), // totalCollateral VariablePayoutRange(109, 110) ) diff --git a/core/src/main/scala/org/bitcoins/core/protocol/dlc/CETCalculator.scala b/core/src/main/scala/org/bitcoins/core/protocol/dlc/CETCalculator.scala index 084a63cac3..d32e4677c5 100644 --- a/core/src/main/scala/org/bitcoins/core/protocol/dlc/CETCalculator.scala +++ b/core/src/main/scala/org/bitcoins/core/protocol/dlc/CETCalculator.scala @@ -47,15 +47,6 @@ object CETCalculator { def indexTo: Long } - /** This range contains payouts all <= 0 - * (Note that interpolated functions are allowed - * to be negative, but we set all negative values to 0). - */ - case class ZeroPayoutRange(indexFrom: Long, indexTo: Long) extends CETRange - - /** This range contains payouts all == totalCollateral */ - case class MaxPayoutRange(indexFrom: Long, indexTo: Long) extends CETRange - /** This range contains payouts that all vary at every step and cannot be compressed */ case class VariablePayoutRange(indexFrom: Long, indexTo: Long) extends CETRange @@ -66,21 +57,10 @@ object CETCalculator { case class ConstantPayoutRange(indexFrom: Long, indexTo: Long) extends CETRange - object CETRange { - - /** Creates a new CETRange with a single element range */ - def apply( - index: Long, - value: Satoshis, - totalCollateral: Satoshis): CETRange = { - if (value <= Satoshis.zero) { - ZeroPayoutRange(index, index) - } else if (value >= totalCollateral) { - MaxPayoutRange(index, index) - } else { - VariablePayoutRange(index, index) - } - } + /** A CETRange with a single element range */ + case class SingletonPayoutRange(index: Long) extends CETRange { + override def indexFrom: Long = index + override def indexTo: Long = index } /** Goes between from and to (inclusive) and evaluates function to split the @@ -92,136 +72,183 @@ object CETCalculator { totalCollateral: Satoshis, function: DLCPayoutCurve, rounding: RoundingIntervals): Vector[CETRange] = { - var componentStart = from - val Indexed(firstCurrentFunc, firstComponentIndex) = - function.componentFor(from) - var (currentFunc, componentIndex) = (firstCurrentFunc, firstComponentIndex) - var prevFunc = currentFunc + require(from <= to, s"from ($from) cannot be greater than to ($to).") - val rangeBuilder = Vector.newBuilder[CETRange] - var currentRange: CETRange = - CETRange(from, currentFunc(from, rounding), totalCollateral) + val Indexed(firstFunc, firstFuncIndex) = function.componentFor(from) + val firstValue = firstFunc(from, rounding, totalCollateral) + val firstRange = SingletonPayoutRange(from) - var num = from - - def newRange(value: Satoshis): Unit = { - rangeBuilder += currentRange - currentRange = CETRange(num, value, totalCollateral) - } - - def updateComponent(): Unit = { - componentStart = num - prevFunc = currentFunc - componentIndex = componentIndex + 1 - currentFunc = function.functionComponents(componentIndex) - } - - @tailrec - def processConstantComponents(): Unit = { - currentFunc match { - case OutcomePayoutConstant(_, rightEndpoint) => - val componentEnd = rightEndpoint.outcome - 1 - val funcValue = rightEndpoint.roundedPayout - - if (funcValue <= Satoshis.zero) { - currentRange match { - case ZeroPayoutRange(indexFrom, _) => - currentRange = ZeroPayoutRange(indexFrom, componentEnd) - case _: MaxPayoutRange | _: VariablePayoutRange | - _: ConstantPayoutRange => - rangeBuilder += currentRange - currentRange = ZeroPayoutRange(componentStart, componentEnd) - } - } else if (funcValue >= totalCollateral) { - currentRange match { - case MaxPayoutRange(indexFrom, _) => - currentRange = MaxPayoutRange(indexFrom, componentEnd) - case _: ZeroPayoutRange | _: VariablePayoutRange | - _: ConstantPayoutRange => - rangeBuilder += currentRange - currentRange = MaxPayoutRange(componentStart, componentEnd) - } - } else if (num != from && funcValue == prevFunc(num - 1, rounding)) { - currentRange match { - case VariablePayoutRange(indexFrom, indexTo) => - rangeBuilder += VariablePayoutRange(indexFrom, indexTo - 1) - currentRange = ConstantPayoutRange(indexTo, componentEnd) - case ConstantPayoutRange(indexFrom, _) => - currentRange = ConstantPayoutRange(indexFrom, componentEnd) - case _: ZeroPayoutRange | _: MaxPayoutRange => - throw new RuntimeException("Something has gone horribly wrong.") - } - } else { - rangeBuilder += currentRange - currentRange = ConstantPayoutRange(componentStart, componentEnd) - } - - num = componentEnd + 1 - if (num != to) { - updateComponent() - processConstantComponents() - } - case _: DLCPayoutCurvePiece => () - } - } - - processConstantComponents() - - while (num <= to) { - if (num == currentFunc.rightEndpoint.outcome && num != to) { - updateComponent() - - processConstantComponents() - } - - val value = currentFunc(num, rounding) - if (value <= Satoshis.zero) { - currentRange match { - case ZeroPayoutRange(indexFrom, _) => - currentRange = ZeroPayoutRange(indexFrom, num) - case _: MaxPayoutRange | _: VariablePayoutRange | - _: ConstantPayoutRange => - newRange(value) - } - } else if (value >= totalCollateral) { - currentRange match { - case MaxPayoutRange(indexFrom, _) => - currentRange = MaxPayoutRange(indexFrom, num) - case _: ZeroPayoutRange | _: VariablePayoutRange | - _: ConstantPayoutRange => - newRange(value) - } - } else if ( - num != from && - (num - 1 >= componentStart && value == currentFunc(num - 1, - rounding)) || - (num - 1 < componentStart && value == prevFunc(num - 1, rounding)) - ) { - currentRange match { - case VariablePayoutRange(indexFrom, indexTo) => - rangeBuilder += VariablePayoutRange(indexFrom, indexTo - 1) - currentRange = ConstantPayoutRange(num - 1, num) - case ConstantPayoutRange(indexFrom, _) => - currentRange = ConstantPayoutRange(indexFrom, num) - case _: ZeroPayoutRange | _: MaxPayoutRange => - throw new RuntimeException("Something has gone horribly wrong.") - } + val (currentFunc, currentFuncIndex) = + if (from == firstFunc.rightEndpoint.outcome && from != to) { + (function.functionComponents(firstFuncIndex + 1), firstFuncIndex + 1) } else { - currentRange match { - case VariablePayoutRange(indexFrom, _) => - currentRange = VariablePayoutRange(indexFrom, num) - case _: ZeroPayoutRange | _: MaxPayoutRange | - _: ConstantPayoutRange => - newRange(value) + (firstFunc, firstFuncIndex) + } + + splitIntoRangesLoop( + currentOutcome = from + 1, + to = to, + totalCollateral = totalCollateral, + function = function, + rounding = rounding, + cetRangesSoFar = Vector.empty, + currentCETRange = firstRange, + prevValue = firstValue, + currentFunc = currentFunc, + currentFuncIndex = currentFuncIndex + ) + } + + /** Helper case class to return information from processConstantsForSplitIntoRangesLoop */ + private case class SplitIntoRangesLoopResult( + nextCETRangesSoFar: Vector[CETRange], + nextCETRange: CETRange, + nextFunc: DLCPayoutCurvePiece, + nextFuncIndex: Int) + + /** Processes the constant range [currentOutcome, constantTo]. + * + * Sadly the redundant call to splitIntoRangesLoop must be repeated + * in all places that call this function otherwise scala cannot tell + * that it is tail recursive. + */ + private def processConstantsForSplitIntoRangesLoop( + constantTo: Long, + value: Satoshis, + prevValue: Satoshis, + currentCETRange: CETRange, + cetRangesSoFar: Vector[CETRange], + currentOutcome: Long, + function: DLCPayoutCurve, + currentFunc: DLCPayoutCurvePiece, + to: Long, + currentFuncIndex: Int): SplitIntoRangesLoopResult = { + val (nextCETRangesSoFar, nextCETRange) = + if (value == prevValue) { + currentCETRange match { + case VariablePayoutRange(indexFrom, indexTo) => + val newCETRangesSoFar = + cetRangesSoFar.:+(VariablePayoutRange(indexFrom, indexTo - 1)) + val newCurrentCETRange = + ConstantPayoutRange(indexTo, constantTo) + (newCETRangesSoFar, newCurrentCETRange) + case _: ConstantPayoutRange | _: SingletonPayoutRange => + val newCurrentCETRange = + ConstantPayoutRange(currentCETRange.indexFrom, constantTo) + (cetRangesSoFar, newCurrentCETRange) + } + } else if (constantTo > currentOutcome) { + val newCETRangesSoFar = cetRangesSoFar.:+(currentCETRange) + val newCurrentCETRange = + ConstantPayoutRange(currentOutcome, constantTo) + (newCETRangesSoFar, newCurrentCETRange) + } else { + currentCETRange match { + case _: VariablePayoutRange | _: SingletonPayoutRange => + val newCurrentCETRange = + VariablePayoutRange(currentCETRange.indexFrom, currentOutcome) + (cetRangesSoFar, newCurrentCETRange) + case _: ConstantPayoutRange => + val newCETRangesSoFar = cetRangesSoFar.:+(currentCETRange) + val newCurrentCETRange = SingletonPayoutRange(currentOutcome) + (newCETRangesSoFar, newCurrentCETRange) } } - num += 1 + val (nextFunc, nextFuncIndex) = + if ( + constantTo + 1 == currentFunc.rightEndpoint.outcome && constantTo + 1 != to + ) { + (function.functionComponents(currentFuncIndex + 1), + currentFuncIndex + 1) + } else { + (currentFunc, currentFuncIndex) + } + + SplitIntoRangesLoopResult(nextCETRangesSoFar, + nextCETRange, + nextFunc, + nextFuncIndex) + } + + @tailrec + private def splitIntoRangesLoop( + currentOutcome: Long, + to: Long, + totalCollateral: Satoshis, + function: DLCPayoutCurve, + rounding: RoundingIntervals, + cetRangesSoFar: Vector[CETRange], + currentCETRange: CETRange, + prevValue: Satoshis, + currentFunc: DLCPayoutCurvePiece, + currentFuncIndex: Int): Vector[CETRange] = { + if (currentOutcome > to) { + cetRangesSoFar.:+(currentCETRange) + } else { + val value = currentFunc(currentOutcome, rounding, totalCollateral) + + def processConstants(constantTo: Long): SplitIntoRangesLoopResult = { + processConstantsForSplitIntoRangesLoop(constantTo, + value, + prevValue, + currentCETRange, + cetRangesSoFar, + currentOutcome, + function, + currentFunc, + to, + currentFuncIndex) + } + + currentFunc match { + case constant: OutcomePayoutConstant => + val rightEndpoint = constant.rightEndpoint.outcome + val componentEnd = if (rightEndpoint == to) { + rightEndpoint + } else { + math.min(rightEndpoint - 1, to) + } + + val SplitIntoRangesLoopResult(nextCETRangesSoFar, + nextCETRange, + nextFunc, + nextFuncIndex) = + processConstants(constantTo = componentEnd) + + splitIntoRangesLoop( + currentOutcome = componentEnd + 1, + to = to, + totalCollateral = totalCollateral, + function = function, + rounding = rounding, + cetRangesSoFar = nextCETRangesSoFar, + currentCETRange = nextCETRange, + prevValue = value, + currentFunc = nextFunc, + currentFuncIndex = nextFuncIndex + ) + case _: DLCPayoutCurvePiece => + val SplitIntoRangesLoopResult(nextCETRangesSoFar, + nextCETRange, + nextFunc, + nextFuncIndex) = + processConstants(constantTo = currentOutcome) + + splitIntoRangesLoop( + currentOutcome = currentOutcome + 1, + to = to, + totalCollateral = totalCollateral, + function = function, + rounding = rounding, + cetRangesSoFar = nextCETRangesSoFar, + currentCETRange = nextCETRange, + prevValue = value, + currentFunc = nextFunc, + currentFuncIndex = nextFuncIndex + ) + } } - - rangeBuilder += currentRange - - rangeBuilder.result() } /** Searches for an outcome which contains a prefix of digits */ @@ -402,26 +429,24 @@ object CETCalculator { ranges.flatMap { range => range match { - case ZeroPayoutRange(indexFrom, indexTo) => - groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map { - decomp => - CETOutcome(decomp, payout = Satoshis.zero) - } - case MaxPayoutRange(indexFrom, indexTo) => - groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map { - decomp => - CETOutcome(decomp, payout = totalCollateral) - } case ConstantPayoutRange(indexFrom, indexTo) => groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map { decomp => - CETOutcome(decomp, payout = function(indexFrom, rounding)) + CETOutcome(decomp, + payout = + function(indexFrom, rounding, totalCollateral)) } case VariablePayoutRange(indexFrom, indexTo) => indexFrom.to(indexTo).map { num => val decomp = NumberUtil.decompose(num, base, numDigits) - CETOutcome(decomp, payout = function(num)) + CETOutcome(decomp, + payout = function(num, rounding, totalCollateral)) } + case SingletonPayoutRange(index) => + val decomp = NumberUtil.decompose(index, base, numDigits) + Vector( + CETOutcome(decomp, + payout = function(index, rounding, totalCollateral))) } } } diff --git a/core/src/main/scala/org/bitcoins/core/protocol/dlc/DLCPayoutCurve.scala b/core/src/main/scala/org/bitcoins/core/protocol/dlc/DLCPayoutCurve.scala index 8f323c43cd..510dfbe0d8 100644 --- a/core/src/main/scala/org/bitcoins/core/protocol/dlc/DLCPayoutCurve.scala +++ b/core/src/main/scala/org/bitcoins/core/protocol/dlc/DLCPayoutCurve.scala @@ -65,10 +65,24 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) { func(outcome, rounding) } + def getPayout( + outcome: Long, + rounding: RoundingIntervals, + totalCollateral: Satoshis): Satoshis = { + val Indexed(func, _) = componentFor(outcome) + func(outcome, rounding, totalCollateral) + } + def apply(outcome: Long): Satoshis = getPayout(outcome) def apply(outcome: Long, rounding: RoundingIntervals): Satoshis = getPayout(outcome, rounding) + + def apply( + outcome: Long, + rounding: RoundingIntervals, + totalCollateral: Satoshis): Satoshis = + getPayout(outcome, rounding, totalCollateral) } object DLCPayoutCurve { @@ -188,6 +202,16 @@ sealed trait DLCPayoutCurvePiece { rounding.round(outcome, apply(outcome)) } + def apply( + outcome: Long, + rounding: RoundingIntervals, + totalCollateral: Satoshis): Satoshis = { + val rounded = rounding.round(outcome, apply(outcome)).toLong + val modified = math.min(math.max(rounded, 0), totalCollateral.toLong) + + Satoshis(modified) + } + /** Returns the largest Long less than or equal to bd (floor function) */ protected def bigDecimalSats(bd: BigDecimal): Satoshis = { Satoshis(