Fix TLV parsing for non-standard strings (#2312)

* Fix TLV parsing for non-standard strings

* Create function

* Fix oracle migrations

* Forced all TLV strings to be normalized implicitly

* Removed redundant normalization

* Fix oracle

* Bump migration test

* Fix 2.12.12 compile

* Use NetworkElement & StringFactory

Co-authored-by: nkohen <nadavk25@gmail.com>
This commit is contained in:
Ben Carman 2020-12-03 12:30:17 -06:00 committed by GitHub
parent fd08c98be0
commit b3d70f559a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 130 additions and 50 deletions

View file

@ -128,7 +128,7 @@ case class OracleRoutes(oracle: DLCOracle)(implicit
case Some(event: OracleEvent) =>
val outcomesJson = event.eventDescriptorTLV match {
case enum: EnumEventDescriptorV0TLV =>
enum.outcomes.map(Str)
enum.outcomes.map(outcome => Str(outcome.normStr))
case range: RangeEventDescriptorV0TLV =>
val outcomes: Vector[Long] = {
val startL = range.start.toLong

View file

@ -501,9 +501,12 @@ lazy val walletTest = project
.dependsOn(core % testAndCompile, testkit, wallet)
.enablePlugins(FlywayPlugin)
lazy val oracleDbSettings = dbFlywaySettings("oracle")
lazy val dlcOracle = project
.in(file("dlc-oracle"))
.settings(CommonSettings.prodSettings: _*)
.settings(oracleDbSettings: _*)
.settings(
name := "bitcoin-s-dlc-oracle",
libraryDependencies ++= Deps.dlcOracle

View file

@ -90,6 +90,13 @@ object TLV extends TLVParentFactory[TLV] {
case None => UnknownTLV(tpe, value)
}
}
def getStringBytes(str: NormalizedString): ByteVector = {
val strBytes = str.bytes
val size = BigSizeUInt(strBytes.size)
size.bytes ++ strBytes
}
}
sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
@ -141,11 +148,11 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
}
}
def takeString(): String = {
def takeString(): NormalizedString = {
val size = BigSizeUInt(current)
skip(size.byteSize)
val strBytes = take(size.toInt)
new String(strBytes.toArray, StandardCharsets.UTF_8)
NormalizedString(strBytes)
}
def takeSPK(): ScriptPubKey = {
@ -155,6 +162,51 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
}
}
case class NormalizedString(private val str: String) extends NetworkElement {
val normStr: String = CryptoUtil.normalize(str)
override def equals(other: Any): Boolean = {
other match {
case otherStr: String =>
normStr == otherStr
case _ => other.equals(str)
}
}
override def toString: String = normStr
override def bytes: ByteVector = CryptoUtil.serializeForHash(normStr)
}
object NormalizedString extends StringFactory[NormalizedString] {
def apply(bytes: ByteVector): NormalizedString = {
NormalizedString(new String(bytes.toArray, StandardCharsets.UTF_8))
}
import scala.language.implicitConversions
implicit def stringToNormalized(str: String): NormalizedString =
NormalizedString(str)
implicit def normalizedToString(normalized: NormalizedString): String =
normalized.normStr
// If other kinds of Iterables are needed, there's a fancy thing to do
// that is done all over the Seq code using params and an implicit CanBuildFrom
implicit def stringVecToNormalized(
strs: Vector[String]): Vector[NormalizedString] =
strs.map(apply)
implicit def normalizedVecToString(
strs: Vector[NormalizedString]): Vector[String] =
strs.map(_.normStr)
override def fromString(string: String): NormalizedString =
NormalizedString(string)
}
case class UnknownTLV(tpe: BigSizeUInt, value: ByteVector) extends TLV {
require(!TLV.knownTypes.contains(tpe), s"Type $tpe is known")
}
@ -253,7 +305,7 @@ object EventDescriptorTLV extends TLVParentFactory[EventDescriptorTLV] {
* @param outcomes The set of possible outcomes
* @see https://github.com/discreetlogcontracts/dlcspecs/blob/master/Oracle.md#simple-enumeration
*/
case class EnumEventDescriptorV0TLV(outcomes: Vector[String])
case class EnumEventDescriptorV0TLV(outcomes: Vector[NormalizedString])
extends EventDescriptorTLV {
override def tpe: BigSizeUInt = EnumEventDescriptorV0TLV.tpe
@ -261,8 +313,8 @@ case class EnumEventDescriptorV0TLV(outcomes: Vector[String])
val starting = UInt16(outcomes.size).bytes
outcomes.foldLeft(starting) { (accum, outcome) =>
val outcomeBytes = CryptoUtil.serializeForHash(outcome)
accum ++ UInt16(outcomeBytes.length).bytes ++ outcomeBytes
val outcomeBytes = TLV.getStringBytes(outcome)
accum ++ outcomeBytes
}
}
@ -278,19 +330,18 @@ object EnumEventDescriptorV0TLV extends TLVFactory[EnumEventDescriptorV0TLV] {
val count = UInt16(iter.takeBits(16))
val builder = Vector.newBuilder[String]
val builder = Vector.newBuilder[NormalizedString]
while (iter.index < value.length) {
val len = UInt16(iter.takeBits(16))
val outcomeBytes = iter.take(len.toInt)
val str = new String(outcomeBytes.toArray, StandardCharsets.UTF_8)
val str = iter.takeString()
builder.+=(str)
}
val result = builder.result()
require(count.toInt == result.size,
"Did not parse the expected number of outcomes")
require(
count.toInt == result.size,
s"Did not parse the expected number of outcomes, ${count.toInt} != ${result.size}")
EnumEventDescriptorV0TLV(result)
}
@ -299,12 +350,12 @@ object EnumEventDescriptorV0TLV extends TLVFactory[EnumEventDescriptorV0TLV] {
sealed trait NumericEventDescriptorTLV extends EventDescriptorTLV {
/** The minimum valid value in the oracle can sign */
def min: Vector[String]
def min: Vector[NormalizedString]
def minNum: BigInt
/** The maximum valid value in the oracle can sign */
def max: Vector[String]
def max: Vector[NormalizedString]
def maxNum: BigInt
@ -320,7 +371,7 @@ sealed trait NumericEventDescriptorTLV extends EventDescriptorTLV {
def base: UInt16
/** The unit of the outcome value */
def unit: String
def unit: NormalizedString
/** The precision of the outcome representing the base exponent
* by which to multiply the number represented by the composition
@ -361,29 +412,26 @@ case class RangeEventDescriptorV0TLV(
start: Int32,
count: UInt32,
step: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends NumericEventDescriptorTLV {
override val minNum: BigInt = BigInt(start.toInt)
override val min: Vector[String] = Vector(minNum.toString)
override val min: Vector[NormalizedString] = Vector(minNum.toString)
override val maxNum: BigInt =
start.toLong + (step.toLong * (count.toLong - 1))
override val max: Vector[String] = Vector(maxNum.toString)
override val max: Vector[NormalizedString] = Vector(maxNum.toString)
override val base: UInt16 = UInt16(10)
override val tpe: BigSizeUInt = RangeEventDescriptorV0TLV.tpe
override val value: ByteVector = {
val unitSize = BigSizeUInt(unit.length)
val unitBytes = CryptoUtil.serializeForHash(unit)
start.bytes ++ count.bytes ++ step.bytes ++
unitSize.bytes ++ unitBytes ++ precision.bytes
TLV.getStringBytes(unit) ++ precision.bytes
}
override def noncesNeeded: Int = 1
@ -420,10 +468,10 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
override lazy val maxNum: BigInt = base.toBigInt.pow(numDigits.toInt) - 1
private lazy val maxDigit = (base.toInt - 1).toString
private lazy val maxDigit: NormalizedString = (base.toInt - 1).toString
override lazy val max: Vector[String] = if (isSigned) {
"+" +: Vector.fill(numDigits.toInt)(maxDigit)
override lazy val max: Vector[NormalizedString] = if (isSigned) {
NormalizedString("+") +: Vector.fill(numDigits.toInt)(maxDigit)
} else {
Vector.fill(numDigits.toInt)(maxDigit)
}
@ -434,8 +482,8 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
0
}
override lazy val min: Vector[String] = if (isSigned) {
"-" +: Vector.fill(numDigits.toInt)(maxDigit)
override lazy val min: Vector[NormalizedString] = if (isSigned) {
NormalizedString("-") +: Vector.fill(numDigits.toInt)(maxDigit)
} else {
Vector.fill(numDigits.toInt)("0")
}
@ -450,10 +498,9 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
if (isSigned) ByteVector(TRUE_BYTE) else ByteVector(FALSE_BYTE)
val numDigitBytes = numDigits.bytes
val unitSize = BigSizeUInt(unit.length)
val unitBytes = CryptoUtil.serializeForHash(unit)
val unitBytes = TLV.getStringBytes(unit)
base.bytes ++ isSignedByte ++ unitSize.bytes ++ unitBytes ++ precision.bytes ++ numDigitBytes
base.bytes ++ isSignedByte ++ unitBytes ++ precision.bytes ++ numDigitBytes
}
override def noncesNeeded: Int = {
@ -466,7 +513,7 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
case class SignedDigitDecompositionEventDescriptor(
base: UInt16,
numDigits: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV {
override val isSigned: Boolean = true
@ -476,7 +523,7 @@ case class SignedDigitDecompositionEventDescriptor(
case class UnsignedDigitDecompositionEventDescriptor(
base: UInt16,
numDigits: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV {
override val isSigned: Boolean = false
@ -509,7 +556,7 @@ object DigitDecompositionEventDescriptorV0TLV
base: UInt16,
isSigned: Boolean,
numDigits: Int,
unit: String,
unit: NormalizedString,
precision: Int32): DigitDecompositionEventDescriptorV0TLV = {
if (isSigned) {
SignedDigitDecompositionEventDescriptor(base,
@ -534,7 +581,7 @@ case class OracleEventV0TLV(
nonces: Vector[SchnorrNonce],
eventMaturityEpoch: UInt32,
eventDescriptor: EventDescriptorTLV,
eventURI: String
eventId: NormalizedString
) extends OracleEventTLV {
require(eventDescriptor.noncesNeeded == nonces.size,
@ -543,11 +590,12 @@ case class OracleEventV0TLV(
override def tpe: BigSizeUInt = OracleEventV0TLV.tpe
override val value: ByteVector = {
val uriBytes = CryptoUtil.serializeForHash(eventURI)
val eventIdBytes = TLV.getStringBytes(eventId)
val numNonces = UInt16(nonces.size)
val noncesBytes = nonces.foldLeft(numNonces.bytes)(_ ++ _.bytes)
noncesBytes ++ eventMaturityEpoch.bytes ++ eventDescriptor.bytes ++ uriBytes
noncesBytes ++ eventMaturityEpoch.bytes ++ eventDescriptor.bytes ++ eventIdBytes
}
/** Gets the maturation of the event since epoch */
@ -579,9 +627,9 @@ object OracleEventV0TLV extends TLVFactory[OracleEventV0TLV] {
val eventMaturity = UInt32(iter.takeBits(32))
val eventDescriptor = EventDescriptorTLV(iter.current)
iter.skip(eventDescriptor.byteSize)
val eventURI = new String(iter.current.toArray, StandardCharsets.UTF_8)
val eventId = iter.takeString()
OracleEventV0TLV(nonces, eventMaturity, eventDescriptor, eventURI)
OracleEventV0TLV(nonces, eventMaturity, eventDescriptor, eventId)
}
}

View file

@ -122,14 +122,14 @@ class DbManagementTest extends BitcoinSAsyncTest with EmbeddedPg {
val result = oracleAppConfig.migrate()
oracleAppConfig.driver match {
case SQLite =>
val expected = 2
val expected = 3
assert(result == expected)
val flywayInfo = oracleAppConfig.info()
assert(flywayInfo.applied().length == expected)
assert(flywayInfo.pending().length == 0)
case PostgreSQL =>
val expected = 2
val expected = 3
assert(result == expected)
val flywayInfo = oracleAppConfig.info()

View file

@ -0,0 +1,2 @@
-- Fix dummy event descriptor to be a parsable one
UPDATE events SET event_descriptor_tlv = 'fdd8060800010564756d6d79' WHERE event_descriptor_tlv = 'fdd806090001000564756d6d79';

View file

@ -0,0 +1,2 @@
-- Fix dummy event descriptor to be a parsable one
UPDATE `events` SET `event_descriptor_tlv` = 'fdd8060800010564756d6d79' WHERE `event_descriptor_tlv` = 'fdd806090001000564756d6d79';

View file

@ -1,7 +1,5 @@
package org.bitcoins.dlc.oracle
import java.time.Instant
import org.bitcoins.commons.jsonmodels.dlc.SigningVersion
import org.bitcoins.core.config.BitcoinNetwork
import org.bitcoins.core.crypto.ExtKeyVersion.SegWitMainNetPriv
@ -18,6 +16,7 @@ import org.bitcoins.dlc.oracle.storage._
import org.bitcoins.dlc.oracle.util.EventDbUtil
import org.bitcoins.keymanager.{DecryptedMnemonic, WalletStorage}
import java.time.Instant
import scala.concurrent.{ExecutionContext, Future}
case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
@ -273,7 +272,9 @@ case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
s"No event saved with nonce ${nonce.hex} $outcome"))
}
eventOutcomeOpt <- eventOutcomeDAO.read((nonce, outcome.outcomeString))
hash = eventDb.signingVersion.calcOutcomeHash(eventDb.eventDescriptorTLV,
outcome.outcomeString)
eventOutcomeOpt <- eventOutcomeDAO.find(nonce, hash)
eventOutcomeDb <- eventOutcomeOpt match {
case Some(value) => Future.successful(value)
case None =>

View file

@ -62,10 +62,11 @@ case class DLCOracleAppConfig(
}
logger.info(s"Applied $numMigrations to the dlc oracle project")
if (migrationsApplied() == 2) {
logger.debug(s"Doing V2 Migration")
val migrations = migrationsApplied()
if (migrations == 2 || migrations == 3) { // For V2/V3 migrations
logger.debug(s"Doing V2/V3 Migration")
val dummyMigrationTLV = EventDescriptorTLV("fdd806090001000564756d6d79")
val dummyMigrationTLV = EventDescriptorTLV("fdd8060800010564756d6d79")
val eventDAO = EventDAO()(ec, appConfig)
for {

View file

@ -49,6 +49,15 @@ case class EventOutcomeDAO()(implicit
safeDatabase.runVec(query.result.transactionally)
}
def find(
nonce: SchnorrNonce,
hash: ByteVector): Future[Option[EventOutcomeDb]] = {
val query =
table.filter(item => item.nonce === nonce && item.hashedMessage === hash)
safeDatabase.run(query.result.transactionally).map(_.headOption)
}
class EventOutcomeTable(tag: Tag)
extends Table[EventOutcomeDb](tag, schemaName, "event_outcomes") {

View file

@ -2,6 +2,9 @@ package org.bitcoins.testkit.core.gen
import org.scalacheck.Gen
import java.nio.charset.StandardCharsets
import scala.util.{Failure, Success, Try}
/**
* Created by chris on 6/20/16.
*/
@ -55,6 +58,17 @@ trait StringGenerators {
randomString <- genString(randomNum)
} yield randomString
def genUTF8String: Gen[String] = {
for {
bytes <- NumberGenerator.bytes
str <- Try(new String(bytes.toArray, StandardCharsets.UTF_8)) match {
case Failure(_) =>
genUTF8String
case Success(value) =>
Gen.const(value)
}
} yield str
}
}
object StringGenerators extends StringGenerators

View file

@ -44,7 +44,7 @@ trait TLVGen {
def enumEventDescriptorV0TLV: Gen[EnumEventDescriptorV0TLV] = {
for {
numOutcomes <- Gen.choose(2, 10)
outcomes <- Gen.listOfN(numOutcomes, StringGenerators.genString)
outcomes <- Gen.listOfN(numOutcomes, StringGenerators.genUTF8String)
} yield EnumEventDescriptorV0TLV(outcomes.toVector)
}
@ -53,7 +53,7 @@ trait TLVGen {
start <- NumberGenerator.int32s
count <- NumberGenerator.uInt32s
step <- NumberGenerator.uInt16
unit <- StringGenerators.genString
unit <- StringGenerators.genUTF8String
precision <- NumberGenerator.int32s
} yield RangeEventDescriptorV0TLV(start, count, step, unit, precision)
}
@ -64,7 +64,7 @@ trait TLVGen {
base <- NumberGenerator.uInt16
isSigned <- NumberGenerator.bool
numDigits <- Gen.choose(2, 20)
unit <- StringGenerators.genString
unit <- StringGenerators.genUTF8String
precision <- NumberGenerator.int32s
} yield DigitDecompositionEventDescriptorV0TLV(base,
isSigned,
@ -81,7 +81,7 @@ trait TLVGen {
def oracleEventV0TLV: Gen[OracleEventV0TLV] = {
for {
maturity <- NumberGenerator.uInt32s
uri <- StringGenerators.genString
uri <- StringGenerators.genUTF8String
desc <- eventDescriptorTLV
nonces <-
Gen