Add OrderedTLVPoints as this is an invariant in the codebase in DLCPayoutCurve (#4874)

This commit is contained in:
Chris Stewart 2022-11-02 17:20:55 -05:00 committed by GitHub
parent 17fc49c772
commit 92613709aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 10 deletions

View File

@ -149,7 +149,7 @@ object DLCParsingTestVector extends TestVectorParser[DLCParsingTestVector] {
"length" -> Element(tlv.length),
"numPieces" -> Element(UInt16(pieces.length)),
"endpointsAndPieces" -> MultiElement(
endpoints
endpoints.toVector
.zip(pieces)
.flatMap { case (leftEndpoint, piece) =>
Vector(leftEndpoint, piece)

View File

@ -1,5 +1,6 @@
package org.bitcoins
import org.bitcoins.core.protocol.tlv.TLVPoint
import org.bitcoins.core.protocol.transaction.{
TransactionInput,
TransactionOutput
@ -88,4 +89,12 @@ package object core {
}
}
}
implicit val tlvPointOrdering: Ordering[TLVPoint] = {
new Ordering[TLVPoint] {
override def compare(point1: TLVPoint, point2: TLVPoint): Int = {
point1.outcome.compare(point2.outcome)
}
}
}
}

View File

@ -2,6 +2,7 @@ package org.bitcoins.core.protocol.dlc.models
import org.bitcoins.core.currency.{CurrencyUnit, Satoshis}
import org.bitcoins.core.protocol.tlv._
import org.bitcoins.core.util.sorted.OrderedTLVPoints
import org.bitcoins.core.util.{Indexed, NumberUtil}
import scala.math.BigDecimal.RoundingMode
@ -22,9 +23,10 @@ case class DLCPayoutCurve(
override def toTLV: PayoutFunctionV0TLV = {
val tlvEndpoints = endpoints.map(_.toTLVPoint)
val orderedTlvPoints = OrderedTLVPoints(tlvEndpoints)
val tlvPieces = pieces.map(_.toTLV)
PayoutFunctionV0TLV(tlvEndpoints, tlvPieces, serializationVersion)
PayoutFunctionV0TLV(orderedTlvPoints, tlvPieces, serializationVersion)
}
private lazy val endpointOutcomes = endpoints.map(_.outcome)
@ -84,7 +86,7 @@ object DLCPayoutCurve
override def fromTLV(tlv: PayoutFunctionV0TLV): DLCPayoutCurve = {
val pieces =
tlv.endpoints.init.zip(tlv.endpoints.tail).zip(tlv.pieces).map {
tlv.endpoints.toVector.init.zip(tlv.endpoints.tail).zip(tlv.pieces).map {
case ((leftEndpoint, rightEndpoint), tlvPiece) =>
DLCPayoutCurvePiece.fromTLV(leftEndpoint, tlvPiece, rightEndpoint)
}

View File

@ -24,7 +24,8 @@ import org.bitcoins.core.psbt.InputPSBTRecord.PartialSignature
import org.bitcoins.core.util.sorted.{
OrderedAnnouncements,
OrderedNonces,
OrderedSchnorrSignatures
OrderedSchnorrSignatures,
OrderedTLVPoints
}
import org.bitcoins.core.wallet.fee.SatoshisPerVirtualByte
import org.bitcoins.crypto._
@ -1293,7 +1294,7 @@ case class OldPayoutFunctionV0TLV(points: Vector[OldTLVPoint])
/** @see https://github.com/discreetlogcontracts/dlcspecs/blob/8ee4bbe816c9881c832b1ce320b9f14c72e3506f/NumericOutcome.md#curve-serialization */
case class PayoutFunctionV0TLV(
endpoints: Vector[TLVPoint],
endpoints: OrderedTLVPoints,
pieces: Vector[PayoutCurvePieceTLV],
serializationVersion: DLCSerializationVersion)
extends DLCSetupPieceTLV {
@ -1305,20 +1306,20 @@ case class PayoutFunctionV0TLV(
override val value: ByteVector = {
u16PrefixedList[(TLVPoint, PayoutCurvePieceTLV)](
endpoints.init.zip(pieces),
endpoints.toVector.init.zip(pieces),
{ case (leftEndpoint: TLVPoint, piece: PayoutCurvePieceTLV) =>
leftEndpoint.bytes ++ piece.bytes
}) ++ endpoints.last.bytes
}) ++ endpoints.toVector.last.bytes
}
def piecewisePolynomialEndpoints: Vector[PiecewisePolynomialEndpoint] = {
endpoints.map(e => PiecewisePolynomialEndpoint(e.outcome, e.value))
endpoints.toVector.map(e => PiecewisePolynomialEndpoint(e.outcome, e.value))
}
override val byteSize: Long = {
serializationVersion match {
case DLCSerializationVersion.Alpha =>
val old = OldPayoutFunctionV0TLV(endpoints.map(p =>
val old = OldPayoutFunctionV0TLV(endpoints.toVector.map(p =>
OldTLVPoint(p.outcome, p.value, p.extraPrecision, true)))
old.byteSize
case DLCSerializationVersion.Beta =>
@ -1340,10 +1341,14 @@ object PayoutFunctionV0TLV extends TLVFactory[PayoutFunctionV0TLV] {
(leftEndpoint, piece)
}
val rightEndpoint = iter.take(TLVPoint)
//we assume that points are in ordered when they are serialized
//if they are not, the person that serialized them is not following
//the spec
val endpoints = endpointsAndPieces.map(_._1).:+(rightEndpoint)
val orderedEndpoints = OrderedTLVPoints(endpoints)
val pieces = endpointsAndPieces.map(_._2)
PayoutFunctionV0TLV(endpoints,
PayoutFunctionV0TLV(orderedEndpoints,
pieces,
serializationVersion = DLCSerializationVersion.Beta)
}

View File

@ -0,0 +1,19 @@
package org.bitcoins.core.util.sorted
import org.bitcoins.core.protocol.tlv.TLVPoint
case class OrderedTLVPoints(private val vec: Vector[TLVPoint])
extends SortedVec[TLVPoint, TLVPoint](vec,
org.bitcoins.core.tlvPointOrdering)
object OrderedTLVPoints extends SortedVecFactory[TLVPoint, OrderedTLVPoints] {
override def apply(point: TLVPoint): OrderedTLVPoints = {
OrderedTLVPoints(Vector(point))
}
override def fromUnsorted(vec: Vector[TLVPoint]): OrderedTLVPoints = {
val sorted = vec.sorted(org.bitcoins.core.tlvPointOrdering)
OrderedTLVPoints(sorted)
}
}