Re-wrote CETCalculator.splitIntoRanges (#2621)

Added some docs

Responded to review
This commit is contained in:
Nadav Kohen 2021-02-05 13:24:24 -06:00 committed by GitHub
parent b481b9a087
commit e417ac94a4
3 changed files with 211 additions and 164 deletions

View File

@ -33,15 +33,13 @@ class CETCalculatorTest extends BitcoinSUnitTest {
))
val expected = Vector(
ZeroPayoutRange(0, 20),
ConstantPayoutRange(0, 20), // 0
VariablePayoutRange(21, 39),
ConstantPayoutRange(40, 50),
VariablePayoutRange(51, 69),
ZeroPayoutRange(70, 70),
VariablePayoutRange(71, 79),
VariablePayoutRange(51, 79),
ConstantPayoutRange(80, 90),
VariablePayoutRange(91, 98),
MaxPayoutRange(99, 108),
ConstantPayoutRange(99, 108), // totalCollateral
VariablePayoutRange(109, 110)
)

View File

@ -47,15 +47,6 @@ object CETCalculator {
def indexTo: Long
}
/** This range contains payouts all <= 0
* (Note that interpolated functions are allowed
* to be negative, but we set all negative values to 0).
*/
case class ZeroPayoutRange(indexFrom: Long, indexTo: Long) extends CETRange
/** This range contains payouts all == totalCollateral */
case class MaxPayoutRange(indexFrom: Long, indexTo: Long) extends CETRange
/** This range contains payouts that all vary at every step and cannot be compressed */
case class VariablePayoutRange(indexFrom: Long, indexTo: Long)
extends CETRange
@ -66,21 +57,10 @@ object CETCalculator {
case class ConstantPayoutRange(indexFrom: Long, indexTo: Long)
extends CETRange
object CETRange {
/** Creates a new CETRange with a single element range */
def apply(
index: Long,
value: Satoshis,
totalCollateral: Satoshis): CETRange = {
if (value <= Satoshis.zero) {
ZeroPayoutRange(index, index)
} else if (value >= totalCollateral) {
MaxPayoutRange(index, index)
} else {
VariablePayoutRange(index, index)
}
}
/** A CETRange with a single element range */
case class SingletonPayoutRange(index: Long) extends CETRange {
override def indexFrom: Long = index
override def indexTo: Long = index
}
/** Goes between from and to (inclusive) and evaluates function to split the
@ -92,136 +72,183 @@ object CETCalculator {
totalCollateral: Satoshis,
function: DLCPayoutCurve,
rounding: RoundingIntervals): Vector[CETRange] = {
var componentStart = from
val Indexed(firstCurrentFunc, firstComponentIndex) =
function.componentFor(from)
var (currentFunc, componentIndex) = (firstCurrentFunc, firstComponentIndex)
var prevFunc = currentFunc
require(from <= to, s"from ($from) cannot be greater than to ($to).")
val rangeBuilder = Vector.newBuilder[CETRange]
var currentRange: CETRange =
CETRange(from, currentFunc(from, rounding), totalCollateral)
val Indexed(firstFunc, firstFuncIndex) = function.componentFor(from)
val firstValue = firstFunc(from, rounding, totalCollateral)
val firstRange = SingletonPayoutRange(from)
var num = from
def newRange(value: Satoshis): Unit = {
rangeBuilder += currentRange
currentRange = CETRange(num, value, totalCollateral)
}
def updateComponent(): Unit = {
componentStart = num
prevFunc = currentFunc
componentIndex = componentIndex + 1
currentFunc = function.functionComponents(componentIndex)
}
@tailrec
def processConstantComponents(): Unit = {
currentFunc match {
case OutcomePayoutConstant(_, rightEndpoint) =>
val componentEnd = rightEndpoint.outcome - 1
val funcValue = rightEndpoint.roundedPayout
if (funcValue <= Satoshis.zero) {
currentRange match {
case ZeroPayoutRange(indexFrom, _) =>
currentRange = ZeroPayoutRange(indexFrom, componentEnd)
case _: MaxPayoutRange | _: VariablePayoutRange |
_: ConstantPayoutRange =>
rangeBuilder += currentRange
currentRange = ZeroPayoutRange(componentStart, componentEnd)
}
} else if (funcValue >= totalCollateral) {
currentRange match {
case MaxPayoutRange(indexFrom, _) =>
currentRange = MaxPayoutRange(indexFrom, componentEnd)
case _: ZeroPayoutRange | _: VariablePayoutRange |
_: ConstantPayoutRange =>
rangeBuilder += currentRange
currentRange = MaxPayoutRange(componentStart, componentEnd)
}
} else if (num != from && funcValue == prevFunc(num - 1, rounding)) {
currentRange match {
case VariablePayoutRange(indexFrom, indexTo) =>
rangeBuilder += VariablePayoutRange(indexFrom, indexTo - 1)
currentRange = ConstantPayoutRange(indexTo, componentEnd)
case ConstantPayoutRange(indexFrom, _) =>
currentRange = ConstantPayoutRange(indexFrom, componentEnd)
case _: ZeroPayoutRange | _: MaxPayoutRange =>
throw new RuntimeException("Something has gone horribly wrong.")
}
} else {
rangeBuilder += currentRange
currentRange = ConstantPayoutRange(componentStart, componentEnd)
}
num = componentEnd + 1
if (num != to) {
updateComponent()
processConstantComponents()
}
case _: DLCPayoutCurvePiece => ()
}
}
processConstantComponents()
while (num <= to) {
if (num == currentFunc.rightEndpoint.outcome && num != to) {
updateComponent()
processConstantComponents()
}
val value = currentFunc(num, rounding)
if (value <= Satoshis.zero) {
currentRange match {
case ZeroPayoutRange(indexFrom, _) =>
currentRange = ZeroPayoutRange(indexFrom, num)
case _: MaxPayoutRange | _: VariablePayoutRange |
_: ConstantPayoutRange =>
newRange(value)
}
} else if (value >= totalCollateral) {
currentRange match {
case MaxPayoutRange(indexFrom, _) =>
currentRange = MaxPayoutRange(indexFrom, num)
case _: ZeroPayoutRange | _: VariablePayoutRange |
_: ConstantPayoutRange =>
newRange(value)
}
} else if (
num != from &&
(num - 1 >= componentStart && value == currentFunc(num - 1,
rounding)) ||
(num - 1 < componentStart && value == prevFunc(num - 1, rounding))
) {
currentRange match {
case VariablePayoutRange(indexFrom, indexTo) =>
rangeBuilder += VariablePayoutRange(indexFrom, indexTo - 1)
currentRange = ConstantPayoutRange(num - 1, num)
case ConstantPayoutRange(indexFrom, _) =>
currentRange = ConstantPayoutRange(indexFrom, num)
case _: ZeroPayoutRange | _: MaxPayoutRange =>
throw new RuntimeException("Something has gone horribly wrong.")
}
val (currentFunc, currentFuncIndex) =
if (from == firstFunc.rightEndpoint.outcome && from != to) {
(function.functionComponents(firstFuncIndex + 1), firstFuncIndex + 1)
} else {
currentRange match {
case VariablePayoutRange(indexFrom, _) =>
currentRange = VariablePayoutRange(indexFrom, num)
case _: ZeroPayoutRange | _: MaxPayoutRange |
_: ConstantPayoutRange =>
newRange(value)
(firstFunc, firstFuncIndex)
}
splitIntoRangesLoop(
currentOutcome = from + 1,
to = to,
totalCollateral = totalCollateral,
function = function,
rounding = rounding,
cetRangesSoFar = Vector.empty,
currentCETRange = firstRange,
prevValue = firstValue,
currentFunc = currentFunc,
currentFuncIndex = currentFuncIndex
)
}
/** Helper case class to return information from processConstantsForSplitIntoRangesLoop */
private case class SplitIntoRangesLoopResult(
nextCETRangesSoFar: Vector[CETRange],
nextCETRange: CETRange,
nextFunc: DLCPayoutCurvePiece,
nextFuncIndex: Int)
/** Processes the constant range [currentOutcome, constantTo].
*
* Sadly the redundant call to splitIntoRangesLoop must be repeated
* in all places that call this function otherwise scala cannot tell
* that it is tail recursive.
*/
private def processConstantsForSplitIntoRangesLoop(
constantTo: Long,
value: Satoshis,
prevValue: Satoshis,
currentCETRange: CETRange,
cetRangesSoFar: Vector[CETRange],
currentOutcome: Long,
function: DLCPayoutCurve,
currentFunc: DLCPayoutCurvePiece,
to: Long,
currentFuncIndex: Int): SplitIntoRangesLoopResult = {
val (nextCETRangesSoFar, nextCETRange) =
if (value == prevValue) {
currentCETRange match {
case VariablePayoutRange(indexFrom, indexTo) =>
val newCETRangesSoFar =
cetRangesSoFar.:+(VariablePayoutRange(indexFrom, indexTo - 1))
val newCurrentCETRange =
ConstantPayoutRange(indexTo, constantTo)
(newCETRangesSoFar, newCurrentCETRange)
case _: ConstantPayoutRange | _: SingletonPayoutRange =>
val newCurrentCETRange =
ConstantPayoutRange(currentCETRange.indexFrom, constantTo)
(cetRangesSoFar, newCurrentCETRange)
}
} else if (constantTo > currentOutcome) {
val newCETRangesSoFar = cetRangesSoFar.:+(currentCETRange)
val newCurrentCETRange =
ConstantPayoutRange(currentOutcome, constantTo)
(newCETRangesSoFar, newCurrentCETRange)
} else {
currentCETRange match {
case _: VariablePayoutRange | _: SingletonPayoutRange =>
val newCurrentCETRange =
VariablePayoutRange(currentCETRange.indexFrom, currentOutcome)
(cetRangesSoFar, newCurrentCETRange)
case _: ConstantPayoutRange =>
val newCETRangesSoFar = cetRangesSoFar.:+(currentCETRange)
val newCurrentCETRange = SingletonPayoutRange(currentOutcome)
(newCETRangesSoFar, newCurrentCETRange)
}
}
num += 1
val (nextFunc, nextFuncIndex) =
if (
constantTo + 1 == currentFunc.rightEndpoint.outcome && constantTo + 1 != to
) {
(function.functionComponents(currentFuncIndex + 1),
currentFuncIndex + 1)
} else {
(currentFunc, currentFuncIndex)
}
SplitIntoRangesLoopResult(nextCETRangesSoFar,
nextCETRange,
nextFunc,
nextFuncIndex)
}
@tailrec
private def splitIntoRangesLoop(
currentOutcome: Long,
to: Long,
totalCollateral: Satoshis,
function: DLCPayoutCurve,
rounding: RoundingIntervals,
cetRangesSoFar: Vector[CETRange],
currentCETRange: CETRange,
prevValue: Satoshis,
currentFunc: DLCPayoutCurvePiece,
currentFuncIndex: Int): Vector[CETRange] = {
if (currentOutcome > to) {
cetRangesSoFar.:+(currentCETRange)
} else {
val value = currentFunc(currentOutcome, rounding, totalCollateral)
def processConstants(constantTo: Long): SplitIntoRangesLoopResult = {
processConstantsForSplitIntoRangesLoop(constantTo,
value,
prevValue,
currentCETRange,
cetRangesSoFar,
currentOutcome,
function,
currentFunc,
to,
currentFuncIndex)
}
currentFunc match {
case constant: OutcomePayoutConstant =>
val rightEndpoint = constant.rightEndpoint.outcome
val componentEnd = if (rightEndpoint == to) {
rightEndpoint
} else {
math.min(rightEndpoint - 1, to)
}
val SplitIntoRangesLoopResult(nextCETRangesSoFar,
nextCETRange,
nextFunc,
nextFuncIndex) =
processConstants(constantTo = componentEnd)
splitIntoRangesLoop(
currentOutcome = componentEnd + 1,
to = to,
totalCollateral = totalCollateral,
function = function,
rounding = rounding,
cetRangesSoFar = nextCETRangesSoFar,
currentCETRange = nextCETRange,
prevValue = value,
currentFunc = nextFunc,
currentFuncIndex = nextFuncIndex
)
case _: DLCPayoutCurvePiece =>
val SplitIntoRangesLoopResult(nextCETRangesSoFar,
nextCETRange,
nextFunc,
nextFuncIndex) =
processConstants(constantTo = currentOutcome)
splitIntoRangesLoop(
currentOutcome = currentOutcome + 1,
to = to,
totalCollateral = totalCollateral,
function = function,
rounding = rounding,
cetRangesSoFar = nextCETRangesSoFar,
currentCETRange = nextCETRange,
prevValue = value,
currentFunc = nextFunc,
currentFuncIndex = nextFuncIndex
)
}
}
rangeBuilder += currentRange
rangeBuilder.result()
}
/** Searches for an outcome which contains a prefix of digits */
@ -402,26 +429,24 @@ object CETCalculator {
ranges.flatMap { range =>
range match {
case ZeroPayoutRange(indexFrom, indexTo) =>
groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map {
decomp =>
CETOutcome(decomp, payout = Satoshis.zero)
}
case MaxPayoutRange(indexFrom, indexTo) =>
groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map {
decomp =>
CETOutcome(decomp, payout = totalCollateral)
}
case ConstantPayoutRange(indexFrom, indexTo) =>
groupByIgnoringDigits(indexFrom, indexTo, base, numDigits).map {
decomp =>
CETOutcome(decomp, payout = function(indexFrom, rounding))
CETOutcome(decomp,
payout =
function(indexFrom, rounding, totalCollateral))
}
case VariablePayoutRange(indexFrom, indexTo) =>
indexFrom.to(indexTo).map { num =>
val decomp = NumberUtil.decompose(num, base, numDigits)
CETOutcome(decomp, payout = function(num))
CETOutcome(decomp,
payout = function(num, rounding, totalCollateral))
}
case SingletonPayoutRange(index) =>
val decomp = NumberUtil.decompose(index, base, numDigits)
Vector(
CETOutcome(decomp,
payout = function(index, rounding, totalCollateral)))
}
}
}

View File

@ -65,10 +65,24 @@ case class DLCPayoutCurve(points: Vector[OutcomePayoutPoint]) {
func(outcome, rounding)
}
def getPayout(
outcome: Long,
rounding: RoundingIntervals,
totalCollateral: Satoshis): Satoshis = {
val Indexed(func, _) = componentFor(outcome)
func(outcome, rounding, totalCollateral)
}
def apply(outcome: Long): Satoshis = getPayout(outcome)
def apply(outcome: Long, rounding: RoundingIntervals): Satoshis =
getPayout(outcome, rounding)
def apply(
outcome: Long,
rounding: RoundingIntervals,
totalCollateral: Satoshis): Satoshis =
getPayout(outcome, rounding, totalCollateral)
}
object DLCPayoutCurve {
@ -188,6 +202,16 @@ sealed trait DLCPayoutCurvePiece {
rounding.round(outcome, apply(outcome))
}
def apply(
outcome: Long,
rounding: RoundingIntervals,
totalCollateral: Satoshis): Satoshis = {
val rounded = rounding.round(outcome, apply(outcome)).toLong
val modified = math.min(math.max(rounded, 0), totalCollateral.toLong)
Satoshis(modified)
}
/** Returns the largest Long less than or equal to bd (floor function) */
protected def bigDecimalSats(bd: BigDecimal): Satoshis = {
Satoshis(