Implement basic TLV functionality (#1847)

* Defined BigSizeUInt as in BOLT 1

* Introduced UInt16 and TLV types

* Responded to review

* Responded to review
This commit is contained in:
Nadav Kohen 2020-08-20 13:42:16 -06:00 committed by GitHub
parent 2f8dcd1e57
commit e8b195f477
8 changed files with 679 additions and 1 deletions

View file

@ -0,0 +1,212 @@
package org.bitcoins.core.number
import org.bitcoins.testkit.core.gen.NumberGenerator
import org.bitcoins.testkit.util.BitcoinSUnitTest
import org.scalacheck.Gen
import scodec.bits.ByteVector
import scala.util.Try
class UInt16Test extends BitcoinSUnitTest {
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
generatorDrivenConfigNewCode
behavior of "UInt16"
it must "create the number zero as an unsigned 16 bit integer" in {
val zero = UInt16(ByteVector(0x0.toByte))
assert(zero.toInt == 0)
}
it must "create the max number for an unsigned byte" in {
val maxByteValue = UInt16(ByteVector(0xff.toByte))
assert(maxByteValue.toInt == 255)
}
it must "create the number 256" in {
val uInt16 = UInt16(ByteVector(0x01.toByte, 0x00.toByte))
assert(uInt16.toInt == 256)
}
it must "create the number 65535" in {
val uInt16 = UInt16(ByteVector(0xff.toByte, 0xff.toByte))
assert(uInt16.toInt == 65535)
}
it must "have the correct maximum number representation for UInt16" in {
assert(UInt16.max.toInt == 65535)
assert(UInt16.max.hex == "ffff")
}
it must "fail to create the number 65536" in {
assertThrows[IllegalArgumentException] {
UInt16(ByteVector(0x01.toByte, 0x0.toByte, 0x0.toByte))
}
assertThrows[IllegalArgumentException] {
UInt16(65536)
}
}
it must "throw an exception if we try and create a UInt16 with a negative number" in {
assertThrows[IllegalArgumentException] {
UInt16(-1)
}
}
it must "throw an exception if we try and create a UInt16 with more than 2 bytes" in {
assertThrows[IllegalArgumentException] {
UInt16(ByteVector(0.toByte, 0.toByte, 0.toByte))
}
}
it must "have the correct representation for 0" in {
assert(UInt16.zero.toInt == 0)
}
it must "have the correct representation for 1" in {
assert(UInt16.one.toInt == 1)
}
it must "have the correct minimum number for a UInt16" in {
assert(UInt16.min.toInt == 0)
}
it must "have serialization symmetry" in {
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
assert(UInt16(uInt16.hex) == uInt16)
assert(UInt16(uInt16.hex).hex == uInt16.hex)
}
}
it must "add zero correctly" in {
forAll(NumberGenerator.uInt16) { num: UInt16 =>
assert(num + UInt16.zero == num)
}
}
it must "Negative numbers in UInt16 throw an exception" in {
forAll(NumberGenerator.negativeInts) { num =>
val uint16 = Try(UInt16(num))
assert(uint16.isFailure)
}
}
it must "add two uint16s and get the mathematical sum of the two numbers" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
val result = num1.toInt + num2.toInt
if (result <= UInt16.max.toInt) {
assert(num1 + num2 == UInt16(result.toInt))
} else {
assert(Try(num1 + num2).isFailure)
}
}
}
it must "subtract zero correctly" in {
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
assert(uInt16 - UInt16.zero == uInt16)
}
}
it must "subtract from zero correctly" in {
forAll(NumberGenerator.uInt16) { num =>
if (num == UInt16.zero) {
assert(UInt16.zero - num == UInt16.zero)
} else {
assert(Try(UInt16.zero - num).isFailure)
}
}
}
it must "subtract a uint16 from another uint16 and get the correct result" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
val result = num1.toInt - num2.toInt
if (result >= 0) {
assert(num1 - num2 == UInt16(result))
} else {
assert(Try(num1 - num2).isFailure)
}
}
}
it must "multiplying by zero correctly" in {
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
assert(uInt16 * UInt16.zero == UInt16.zero)
}
}
it must "multiply by one correctly" in {
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
assert(uInt16 * UInt16.one == uInt16)
}
}
it must "multiply two UInt16s correctly" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
val bigInt1 = num1.toBigInt
val bigInt2 = num2.toBigInt
if (bigInt1 * bigInt2 <= UInt16.max.toInt) {
assert(
num1 * num2 ==
UInt16(num1.toInt * num2.toInt))
} else {
assert(Try(num1 * num2).isFailure)
}
}
}
it must "compare UInt16s correctly" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
if (num1.toInt < num2.toInt) assert(num1 < num2)
else assert(num1 >= num2)
if (num1.toInt <= num2.toInt) assert(num1 <= num2)
else assert(num1 > num2)
if (num1.toInt == num2.toInt) assert(num1 == num2)
else assert(num1 != num2)
}
}
it must "| correctly" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
assert(UInt16(num1.toInt | num2.toInt) == (num1 | num2))
}
}
it must "& correctly" in {
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
(num1: UInt16, num2: UInt16) =>
assert(UInt16(num1.toInt & num2.toInt) == (num1 & num2))
}
}
it must "<< correctly" in {
forAll(NumberGenerator.uInt16, Gen.choose(0, 16)) {
case (u16, shift) =>
val r = Try(u16 << shift)
val expected = (u16.toInt << shift) & 0xffff
if (r.isSuccess && expected <= UInt16.max.toInt) {
assert(r.get == UInt16(expected))
} else {
assert(r.isFailure)
}
}
}
it must ">> correctly" in {
forAll(NumberGenerator.uInt16, Gen.choose(0, 100)) {
case (u16, shift) =>
val r = u16 >> shift
val expected = if (shift >= 32) 0 else u16.toInt >> shift
assert(r == UInt16(expected))
}
}
}

View file

@ -0,0 +1,77 @@
package org.bitcoins.core.protocol
import org.bitcoins.core.number.UInt64
import org.bitcoins.testkit.core.gen.NumberGenerator
import org.bitcoins.testkit.util.BitcoinSUnitTest
import scodec.bits.ByteVector
import scala.util.{Failure, Success, Try}
class BigSizeUIntTest extends BitcoinSUnitTest {
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
generatorDrivenConfigNewCode
behavior of "BigSizeUInt"
it must "have serialization symmetry" in {
forAll(NumberGenerator.bigSizeUInt) { num =>
assert(BigSizeUInt(num.bytes) == num)
}
}
it must "fail to parse an empty ByteVector" in {
assertThrows[IllegalArgumentException] {
BigSizeUInt(ByteVector.empty)
}
}
it must "pass encoding tests" in {
val bufferedSource =
io.Source.fromURL(getClass.getResource("/bigsize_encoding.json"))
try {
val builder = new StringBuilder
bufferedSource.getLines().foreach(builder.append)
val tests = ujson.read(builder.result()).arr.toVector
tests.foreach { test =>
val obj = test.obj
val name = obj("name").str
val num = BigInt(obj("value").str)
val bytes = ByteVector.fromValidHex(obj("bytes").str)
assert(BigSizeUInt(num).bytes == bytes, name)
}
} finally {
bufferedSource.close()
}
}
it must "pass decoding tests" in {
val bufferedSource =
io.Source.fromURL(getClass.getResource("/bigsize_decoding.json"))
try {
val builder = new StringBuilder
bufferedSource.getLines().foreach(builder.append)
val tests = ujson.read(builder.result()).arr.toVector
tests.foreach { test =>
val obj = test.obj
val name = obj("name").str
val numStr = obj("value").str
val bytes = ByteVector.fromValidHex(obj("bytes").str)
if (numStr.nonEmpty) {
assert(BigSizeUInt(bytes).num == UInt64(BigInt(numStr)), name)
} else {
Try {
assertThrows[IllegalArgumentException] {
BigSizeUInt(bytes)
}
} match {
case Failure(err) => fail(obj("exp_error").str, err)
case Success(success) => success
}
}
}
} finally {
bufferedSource.close()
}
}
}

View file

@ -0,0 +1,44 @@
package org.bitcoins.core.protocol.tlv
import org.bitcoins.testkit.core.gen.TLVGen
import org.bitcoins.testkit.util.BitcoinSUnitTest
class TLVTest extends BitcoinSUnitTest {
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
generatorDrivenConfigNewCode
"TLV" must "have serizliation symmetry" in {
forAll(TLVGen.tlv) { tlv =>
assert(TLV(tlv.bytes) == tlv)
}
}
"UnknownTLV" must "have serialization symmetry" in {
forAll(TLVGen.unknownTLV) { unknown =>
assert(UnknownTLV(unknown.bytes) == unknown)
assert(TLV(unknown.bytes) == unknown)
}
}
"ErrorTLV" must "have serialization symmetry" in {
forAll(TLVGen.errorTLV) { error =>
assert(ErrorTLV(error.bytes) == error)
assert(TLV(error.bytes) == error)
}
}
"PingTLV" must "have serialization symmetry" in {
forAll(TLVGen.pingTLV) { ping =>
assert(PingTLV(ping.bytes) == ping)
assert(TLV(ping.bytes) == ping)
}
}
"PongTLV" must "have serialization symmetry" in {
forAll(TLVGen.pongTLV) { pong =>
assert(PongTLV(pong.bytes) == pong)
assert(TLV(pong.bytes) == pong)
}
}
}

View file

@ -128,6 +128,16 @@ sealed abstract class UInt8 extends UnsignedNumber[UInt8] {
}
}
/**
* Represents a uint16_t in C
*/
sealed abstract class UInt16 extends UnsignedNumber[UInt16] {
override def apply: A => UInt16 = UInt16(_)
override def hex: String = BytesUtil.encodeHex(toInt.toShort)
override def andMask = 0xffffL
}
/**
* Represents a uint32_t in C
*/
@ -310,6 +320,57 @@ object UInt8
}
}
object UInt16
extends Factory[UInt16]
with NumberObject[UInt16]
with Bounded[UInt16] {
private case class UInt16Impl(underlying: BigInt) extends UInt16 {
require(isInBound(underlying),
s"Cannot create ${super.getClass.getSimpleName} from $underlying")
}
/** Cache from 0 to 128 UInt16 */
private val cached: Vector[UInt16] = {
0.until(128).map(i => UInt16(BigInt(i))).toVector
}
lazy val zero = cached(0)
lazy val one = cached(1)
private lazy val minUnderlying: A = 0
private lazy val maxUnderlying: A = BigInt(65535L)
lazy val min = zero
lazy val max = UInt16(maxUnderlying)
override def isInBound(num: A): Boolean =
num <= maxUnderlying && num >= minUnderlying
override def fromBytes(bytes: ByteVector): UInt16 = {
require(
bytes.size <= 2,
"UInt16 byte array was too large, got: " + BytesUtil.encodeHex(bytes))
UInt16(bytes.toLong(signed = false, ordering = ByteOrdering.BigEndian))
}
def apply(long: Long): UInt16 = {
checkCached(long)
}
def apply(bigInt: BigInt): UInt16 = {
UInt16Impl(bigInt)
}
/** Checks if we have the number cached, if not allocates a new object to represent the number */
private def checkCached(long: Long): UInt16 = {
if (long < 128 && long >= 0) cached(long.toInt)
else {
UInt16(BigInt(long))
}
}
}
object UInt32
extends Factory[UInt32]
with NumberObject[UInt32]

View file

@ -0,0 +1,78 @@
package org.bitcoins.core.protocol
import org.bitcoins.core.number.UInt64
import org.bitcoins.crypto.{Factory, NetworkElement}
import scodec.bits.ByteVector
case class BigSizeUInt(num: UInt64) extends NetworkElement {
override val byteSize: Long = BigSizeUInt.calcSizeForNum(num)
override def bytes: ByteVector = {
byteSize match {
case 1 => num.bytes.takeRight(1)
case 3 => 0xfd.toByte +: num.bytes.takeRight(2)
case 5 => 0xfe.toByte +: num.bytes.takeRight(4)
case _ => 0xff.toByte +: num.bytes
}
}
def toLong: Long = num.toLong
def toInt: Int = {
val l = toLong
require(Int.MinValue <= l && l <= Int.MaxValue,
"Cannot convert BigSizeUInt toInt, got: " + this)
l.toInt
}
override def toString: String = s"BigSizeUInt(${num.toLong})"
}
object BigSizeUInt extends Factory[BigSizeUInt] {
def calcSizeForNum(num: UInt64): Int = {
if (num.toBigInt < 0xfd) { // can be represented with one byte
1
} else if (num.toBigInt < 0x10000) { // can be represented with two bytes
3
} else if (num.toBigInt < 0x100000000L) { // can be represented with 4 bytes
5
} else {
9
}
}
def apply(num: Long): BigSizeUInt = {
BigSizeUInt(UInt64(num))
}
def apply(num: BigInt): BigSizeUInt = {
BigSizeUInt(UInt64(num))
}
override def fromBytes(bytes: ByteVector): BigSizeUInt = {
require(bytes.nonEmpty, "Cannot parse a BigSizeUInt from empty byte vector")
val prefixNum = UInt64(bytes.take(1)).toInt
val (bigSizeUInt, expectedSize) = if (prefixNum < 253) { // 8 bit number
(BigSizeUInt(prefixNum), 1)
} else if (prefixNum == 253) { // 16 bit number
(BigSizeUInt(UInt64(bytes.slice(1, 3))), 3)
} else if (prefixNum == 254) { // 32 bit number
(BigSizeUInt(UInt64(bytes.slice(1, 5))), 5)
} else { // 64 bit number
(BigSizeUInt(UInt64(bytes.slice(1, 9))), 9)
}
require(
bigSizeUInt.byteSize == expectedSize,
s"Length prefix $prefixNum did not match bytes ${bigSizeUInt.bytes.tail}")
bigSizeUInt
}
def calcFor(bytes: ByteVector): BigSizeUInt = {
BigSizeUInt(bytes.length)
}
}

View file

@ -0,0 +1,148 @@
package org.bitcoins.core.protocol.tlv
import org.bitcoins.core.number.UInt16
import org.bitcoins.core.protocol.BigSizeUInt
import org.bitcoins.core.protocol.tlv.TLV.DecodeTLVResult
import org.bitcoins.crypto.{Factory, NetworkElement}
import scodec.bits.ByteVector
sealed trait TLV extends NetworkElement {
def tpe: BigSizeUInt
def value: ByteVector
def length: BigSizeUInt = {
BigSizeUInt.calcFor(value)
}
override def bytes: ByteVector = {
tpe.bytes ++ length.bytes ++ value
}
}
object TLV extends Factory[TLV] {
case class DecodeTLVResult(
tpe: BigSizeUInt,
length: BigSizeUInt,
value: ByteVector)
def decodeTLV(bytes: ByteVector): DecodeTLVResult = {
val tpe = BigSizeUInt(bytes)
val length = BigSizeUInt(bytes.drop(tpe.byteSize))
val prefixSize = tpe.byteSize + length.byteSize
require(
bytes.length >= prefixSize + length.num.toLong,
s"Length specified was $length but not enough bytes in ${bytes.drop(prefixSize)}")
val value = bytes.drop(prefixSize).take(length.num.toLong)
DecodeTLVResult(tpe, length, value)
}
private val allFactories: Vector[TLVFactory[TLV]] =
Vector(ErrorTLV, PingTLV, PongTLV)
val knownTypes: Vector[BigSizeUInt] = allFactories.map(_.tpe)
def fromBytes(bytes: ByteVector): TLV = {
val DecodeTLVResult(tpe, _, value) = decodeTLV(bytes)
allFactories.find(_.tpe == tpe) match {
case Some(tlvFactory) => tlvFactory.fromTLVValue(value)
case None => UnknownTLV(tpe, value)
}
}
}
sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
def tpe: BigSizeUInt
def fromTLVValue(value: ByteVector): T
override def fromBytes(bytes: ByteVector): T = {
val DecodeTLVResult(tpe, _, value) = TLV.decodeTLV(bytes)
require(tpe == this.tpe, s"Invalid type $tpe when expecting ${this.tpe}")
fromTLVValue(value)
}
}
case class UnknownTLV(tpe: BigSizeUInt, value: ByteVector) extends TLV {
require(!TLV.knownTypes.contains(tpe), s"Type $tpe is known")
}
object UnknownTLV extends Factory[UnknownTLV] {
override def fromBytes(bytes: ByteVector): UnknownTLV = {
val DecodeTLVResult(tpe, _, value) = TLV.decodeTLV(bytes)
UnknownTLV(tpe, value)
}
}
/** @see [[https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#the-error-message]] */
case class ErrorTLV(id: ByteVector, data: ByteVector) extends TLV {
require(id.length == 32, s"ID associated with error is incorrect length: $id")
override val tpe: BigSizeUInt = ErrorTLV.tpe
override val value: ByteVector = {
id ++ UInt16(data.length).bytes ++ data
}
}
object ErrorTLV extends TLVFactory[ErrorTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(17)
override def fromTLVValue(value: ByteVector): ErrorTLV = {
val id = value.take(32)
val len = UInt16(value.drop(32).take(2))
val data = value.drop(32 + 2).take(len.toInt)
ErrorTLV(id, data)
}
}
case class PingTLV(numPongBytes: UInt16, ignored: ByteVector) extends TLV {
override val tpe: BigSizeUInt = PingTLV.tpe
override val value: ByteVector = {
numPongBytes.bytes ++ UInt16(ignored.length).bytes ++ ignored
}
}
object PingTLV extends TLVFactory[PingTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(18)
override def fromTLVValue(value: ByteVector): PingTLV = {
val numPongBytes = UInt16(value.take(2))
val numIgnored = UInt16(value.slice(2, 4))
val ignored = value.drop(4).take(numIgnored.toLong)
PingTLV(numPongBytes, ignored)
}
}
case class PongTLV(ignored: ByteVector) extends TLV {
override val tpe: BigSizeUInt = PongTLV.tpe
override val value: ByteVector = {
UInt16(ignored.length).bytes ++ ignored
}
}
object PongTLV extends TLVFactory[PongTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(19)
override def fromTLVValue(value: ByteVector): PongTLV = {
val numIgnored = UInt16(value.take(2))
val ignored = value.drop(2).take(numIgnored.toLong)
PongTLV.forIgnored(ignored)
}
def forIgnored(ignored: ByteVector): PongTLV = {
new PongTLV(ignored)
}
}

View file

@ -1,7 +1,7 @@
package org.bitcoins.testkit.core.gen
import org.bitcoins.core.number._
import org.bitcoins.core.protocol.CompactSizeUInt
import org.bitcoins.core.protocol.{BigSizeUInt, CompactSizeUInt}
import org.bitcoins.core.script.constant.ScriptNumber
import org.bitcoins.core.util.NumberUtil
import org.scalacheck.Arbitrary.arbitrary
@ -49,6 +49,8 @@ trait NumberGenerator {
def uInt8s: Gen[Seq[UInt8]] = Gen.listOf(uInt8)
def uInt16: Gen[UInt16] = Gen.choose(0, 65535).map(UInt16(_))
/**
* Generates a number in the range 0 <= x <= 2 ^^32 - 1
* then wraps it in a UInt32
@ -96,6 +98,8 @@ trait NumberGenerator {
def compactSizeUInts: Gen[CompactSizeUInt] = uInt64s.map(CompactSizeUInt(_))
def bigSizeUInt: Gen[BigSizeUInt] = uInt64.map(BigSizeUInt.apply)
/** Generates an arbitrary [[scala.Byte Byte]] in Scala */
def byte: Gen[Byte] = arbitrary[Byte]

View file

@ -0,0 +1,54 @@
package org.bitcoins.testkit.core.gen
import org.bitcoins.core.protocol.BigSizeUInt
import org.bitcoins.core.protocol.tlv._
import org.scalacheck.Gen
trait TLVGen {
def unknownTpe: Gen[BigSizeUInt] = {
NumberGenerator.bigSizeUInt.suchThat(num => !TLV.knownTypes.contains(num))
}
def unknownTLV: Gen[UnknownTLV] = {
for {
tpe <- unknownTpe
value <- NumberGenerator.bytevector
} yield {
UnknownTLV(tpe, value)
}
}
def errorTLV: Gen[ErrorTLV] = {
for {
id <- NumberGenerator.bytevector(32)
data <- NumberGenerator.bytevector
} yield {
ErrorTLV(id, data)
}
}
def pingTLV: Gen[PingTLV] = {
for {
num <- NumberGenerator.uInt16
bytes <- NumberGenerator.bytevector
} yield {
PingTLV(num, bytes)
}
}
def pongTLV: Gen[PongTLV] = {
NumberGenerator.bytevector.map(PongTLV.forIgnored)
}
def tlv: Gen[TLV] = {
Gen.oneOf(
unknownTLV,
errorTLV,
pingTLV,
pongTLV
)
}
}
object TLVGen extends TLVGen