mirror of
https://github.com/bitcoin-s/bitcoin-s.git
synced 2025-02-22 22:36:34 +01:00
Implement basic TLV functionality (#1847)
* Defined BigSizeUInt as in BOLT 1 * Introduced UInt16 and TLV types * Responded to review * Responded to review
This commit is contained in:
parent
2f8dcd1e57
commit
e8b195f477
8 changed files with 679 additions and 1 deletions
|
@ -0,0 +1,212 @@
|
|||
package org.bitcoins.core.number
|
||||
|
||||
import org.bitcoins.testkit.core.gen.NumberGenerator
|
||||
import org.bitcoins.testkit.util.BitcoinSUnitTest
|
||||
import org.scalacheck.Gen
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
import scala.util.Try
|
||||
|
||||
class UInt16Test extends BitcoinSUnitTest {
|
||||
|
||||
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
|
||||
generatorDrivenConfigNewCode
|
||||
|
||||
behavior of "UInt16"
|
||||
|
||||
it must "create the number zero as an unsigned 16 bit integer" in {
|
||||
val zero = UInt16(ByteVector(0x0.toByte))
|
||||
assert(zero.toInt == 0)
|
||||
}
|
||||
|
||||
it must "create the max number for an unsigned byte" in {
|
||||
val maxByteValue = UInt16(ByteVector(0xff.toByte))
|
||||
assert(maxByteValue.toInt == 255)
|
||||
}
|
||||
|
||||
it must "create the number 256" in {
|
||||
val uInt16 = UInt16(ByteVector(0x01.toByte, 0x00.toByte))
|
||||
assert(uInt16.toInt == 256)
|
||||
}
|
||||
|
||||
it must "create the number 65535" in {
|
||||
val uInt16 = UInt16(ByteVector(0xff.toByte, 0xff.toByte))
|
||||
assert(uInt16.toInt == 65535)
|
||||
}
|
||||
|
||||
it must "have the correct maximum number representation for UInt16" in {
|
||||
assert(UInt16.max.toInt == 65535)
|
||||
assert(UInt16.max.hex == "ffff")
|
||||
}
|
||||
|
||||
it must "fail to create the number 65536" in {
|
||||
assertThrows[IllegalArgumentException] {
|
||||
UInt16(ByteVector(0x01.toByte, 0x0.toByte, 0x0.toByte))
|
||||
}
|
||||
|
||||
assertThrows[IllegalArgumentException] {
|
||||
UInt16(65536)
|
||||
}
|
||||
}
|
||||
|
||||
it must "throw an exception if we try and create a UInt16 with a negative number" in {
|
||||
assertThrows[IllegalArgumentException] {
|
||||
UInt16(-1)
|
||||
}
|
||||
}
|
||||
|
||||
it must "throw an exception if we try and create a UInt16 with more than 2 bytes" in {
|
||||
assertThrows[IllegalArgumentException] {
|
||||
UInt16(ByteVector(0.toByte, 0.toByte, 0.toByte))
|
||||
}
|
||||
}
|
||||
|
||||
it must "have the correct representation for 0" in {
|
||||
assert(UInt16.zero.toInt == 0)
|
||||
}
|
||||
|
||||
it must "have the correct representation for 1" in {
|
||||
assert(UInt16.one.toInt == 1)
|
||||
}
|
||||
|
||||
it must "have the correct minimum number for a UInt16" in {
|
||||
assert(UInt16.min.toInt == 0)
|
||||
}
|
||||
|
||||
it must "have serialization symmetry" in {
|
||||
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
|
||||
assert(UInt16(uInt16.hex) == uInt16)
|
||||
assert(UInt16(uInt16.hex).hex == uInt16.hex)
|
||||
}
|
||||
}
|
||||
|
||||
it must "add zero correctly" in {
|
||||
forAll(NumberGenerator.uInt16) { num: UInt16 =>
|
||||
assert(num + UInt16.zero == num)
|
||||
}
|
||||
}
|
||||
|
||||
it must "Negative numbers in UInt16 throw an exception" in {
|
||||
forAll(NumberGenerator.negativeInts) { num =>
|
||||
val uint16 = Try(UInt16(num))
|
||||
assert(uint16.isFailure)
|
||||
}
|
||||
}
|
||||
|
||||
it must "add two uint16s and get the mathematical sum of the two numbers" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
val result = num1.toInt + num2.toInt
|
||||
if (result <= UInt16.max.toInt) {
|
||||
assert(num1 + num2 == UInt16(result.toInt))
|
||||
} else {
|
||||
assert(Try(num1 + num2).isFailure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it must "subtract zero correctly" in {
|
||||
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
|
||||
assert(uInt16 - UInt16.zero == uInt16)
|
||||
}
|
||||
}
|
||||
|
||||
it must "subtract from zero correctly" in {
|
||||
forAll(NumberGenerator.uInt16) { num =>
|
||||
if (num == UInt16.zero) {
|
||||
assert(UInt16.zero - num == UInt16.zero)
|
||||
} else {
|
||||
assert(Try(UInt16.zero - num).isFailure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it must "subtract a uint16 from another uint16 and get the correct result" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
val result = num1.toInt - num2.toInt
|
||||
if (result >= 0) {
|
||||
assert(num1 - num2 == UInt16(result))
|
||||
} else {
|
||||
assert(Try(num1 - num2).isFailure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it must "multiplying by zero correctly" in {
|
||||
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
|
||||
assert(uInt16 * UInt16.zero == UInt16.zero)
|
||||
}
|
||||
}
|
||||
|
||||
it must "multiply by one correctly" in {
|
||||
forAll(NumberGenerator.uInt16) { uInt16: UInt16 =>
|
||||
assert(uInt16 * UInt16.one == uInt16)
|
||||
}
|
||||
}
|
||||
|
||||
it must "multiply two UInt16s correctly" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
val bigInt1 = num1.toBigInt
|
||||
val bigInt2 = num2.toBigInt
|
||||
if (bigInt1 * bigInt2 <= UInt16.max.toInt) {
|
||||
assert(
|
||||
num1 * num2 ==
|
||||
UInt16(num1.toInt * num2.toInt))
|
||||
} else {
|
||||
assert(Try(num1 * num2).isFailure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it must "compare UInt16s correctly" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
if (num1.toInt < num2.toInt) assert(num1 < num2)
|
||||
else assert(num1 >= num2)
|
||||
|
||||
if (num1.toInt <= num2.toInt) assert(num1 <= num2)
|
||||
else assert(num1 > num2)
|
||||
|
||||
if (num1.toInt == num2.toInt) assert(num1 == num2)
|
||||
else assert(num1 != num2)
|
||||
}
|
||||
}
|
||||
|
||||
it must "| correctly" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
assert(UInt16(num1.toInt | num2.toInt) == (num1 | num2))
|
||||
}
|
||||
}
|
||||
|
||||
it must "& correctly" in {
|
||||
forAll(NumberGenerator.uInt16, NumberGenerator.uInt16) {
|
||||
(num1: UInt16, num2: UInt16) =>
|
||||
assert(UInt16(num1.toInt & num2.toInt) == (num1 & num2))
|
||||
}
|
||||
}
|
||||
|
||||
it must "<< correctly" in {
|
||||
forAll(NumberGenerator.uInt16, Gen.choose(0, 16)) {
|
||||
case (u16, shift) =>
|
||||
val r = Try(u16 << shift)
|
||||
val expected = (u16.toInt << shift) & 0xffff
|
||||
if (r.isSuccess && expected <= UInt16.max.toInt) {
|
||||
assert(r.get == UInt16(expected))
|
||||
} else {
|
||||
assert(r.isFailure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it must ">> correctly" in {
|
||||
forAll(NumberGenerator.uInt16, Gen.choose(0, 100)) {
|
||||
case (u16, shift) =>
|
||||
val r = u16 >> shift
|
||||
val expected = if (shift >= 32) 0 else u16.toInt >> shift
|
||||
assert(r == UInt16(expected))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package org.bitcoins.core.protocol
|
||||
|
||||
import org.bitcoins.core.number.UInt64
|
||||
import org.bitcoins.testkit.core.gen.NumberGenerator
|
||||
import org.bitcoins.testkit.util.BitcoinSUnitTest
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
class BigSizeUIntTest extends BitcoinSUnitTest {
|
||||
|
||||
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
|
||||
generatorDrivenConfigNewCode
|
||||
|
||||
behavior of "BigSizeUInt"
|
||||
|
||||
it must "have serialization symmetry" in {
|
||||
forAll(NumberGenerator.bigSizeUInt) { num =>
|
||||
assert(BigSizeUInt(num.bytes) == num)
|
||||
}
|
||||
}
|
||||
|
||||
it must "fail to parse an empty ByteVector" in {
|
||||
assertThrows[IllegalArgumentException] {
|
||||
BigSizeUInt(ByteVector.empty)
|
||||
}
|
||||
}
|
||||
|
||||
it must "pass encoding tests" in {
|
||||
val bufferedSource =
|
||||
io.Source.fromURL(getClass.getResource("/bigsize_encoding.json"))
|
||||
try {
|
||||
val builder = new StringBuilder
|
||||
bufferedSource.getLines().foreach(builder.append)
|
||||
val tests = ujson.read(builder.result()).arr.toVector
|
||||
tests.foreach { test =>
|
||||
val obj = test.obj
|
||||
val name = obj("name").str
|
||||
val num = BigInt(obj("value").str)
|
||||
val bytes = ByteVector.fromValidHex(obj("bytes").str)
|
||||
assert(BigSizeUInt(num).bytes == bytes, name)
|
||||
}
|
||||
} finally {
|
||||
bufferedSource.close()
|
||||
}
|
||||
}
|
||||
|
||||
it must "pass decoding tests" in {
|
||||
val bufferedSource =
|
||||
io.Source.fromURL(getClass.getResource("/bigsize_decoding.json"))
|
||||
try {
|
||||
val builder = new StringBuilder
|
||||
bufferedSource.getLines().foreach(builder.append)
|
||||
val tests = ujson.read(builder.result()).arr.toVector
|
||||
tests.foreach { test =>
|
||||
val obj = test.obj
|
||||
val name = obj("name").str
|
||||
val numStr = obj("value").str
|
||||
val bytes = ByteVector.fromValidHex(obj("bytes").str)
|
||||
if (numStr.nonEmpty) {
|
||||
assert(BigSizeUInt(bytes).num == UInt64(BigInt(numStr)), name)
|
||||
} else {
|
||||
Try {
|
||||
assertThrows[IllegalArgumentException] {
|
||||
BigSizeUInt(bytes)
|
||||
}
|
||||
} match {
|
||||
case Failure(err) => fail(obj("exp_error").str, err)
|
||||
case Success(success) => success
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
bufferedSource.close()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package org.bitcoins.core.protocol.tlv
|
||||
|
||||
import org.bitcoins.testkit.core.gen.TLVGen
|
||||
import org.bitcoins.testkit.util.BitcoinSUnitTest
|
||||
|
||||
class TLVTest extends BitcoinSUnitTest {
|
||||
|
||||
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
|
||||
generatorDrivenConfigNewCode
|
||||
|
||||
"TLV" must "have serizliation symmetry" in {
|
||||
forAll(TLVGen.tlv) { tlv =>
|
||||
assert(TLV(tlv.bytes) == tlv)
|
||||
}
|
||||
}
|
||||
|
||||
"UnknownTLV" must "have serialization symmetry" in {
|
||||
forAll(TLVGen.unknownTLV) { unknown =>
|
||||
assert(UnknownTLV(unknown.bytes) == unknown)
|
||||
assert(TLV(unknown.bytes) == unknown)
|
||||
}
|
||||
}
|
||||
|
||||
"ErrorTLV" must "have serialization symmetry" in {
|
||||
forAll(TLVGen.errorTLV) { error =>
|
||||
assert(ErrorTLV(error.bytes) == error)
|
||||
assert(TLV(error.bytes) == error)
|
||||
}
|
||||
}
|
||||
|
||||
"PingTLV" must "have serialization symmetry" in {
|
||||
forAll(TLVGen.pingTLV) { ping =>
|
||||
assert(PingTLV(ping.bytes) == ping)
|
||||
assert(TLV(ping.bytes) == ping)
|
||||
}
|
||||
}
|
||||
|
||||
"PongTLV" must "have serialization symmetry" in {
|
||||
forAll(TLVGen.pongTLV) { pong =>
|
||||
assert(PongTLV(pong.bytes) == pong)
|
||||
assert(TLV(pong.bytes) == pong)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -128,6 +128,16 @@ sealed abstract class UInt8 extends UnsignedNumber[UInt8] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a uint16_t in C
|
||||
*/
|
||||
sealed abstract class UInt16 extends UnsignedNumber[UInt16] {
|
||||
override def apply: A => UInt16 = UInt16(_)
|
||||
override def hex: String = BytesUtil.encodeHex(toInt.toShort)
|
||||
|
||||
override def andMask = 0xffffL
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a uint32_t in C
|
||||
*/
|
||||
|
@ -310,6 +320,57 @@ object UInt8
|
|||
}
|
||||
}
|
||||
|
||||
object UInt16
|
||||
extends Factory[UInt16]
|
||||
with NumberObject[UInt16]
|
||||
with Bounded[UInt16] {
|
||||
|
||||
private case class UInt16Impl(underlying: BigInt) extends UInt16 {
|
||||
require(isInBound(underlying),
|
||||
s"Cannot create ${super.getClass.getSimpleName} from $underlying")
|
||||
}
|
||||
|
||||
/** Cache from 0 to 128 UInt16 */
|
||||
private val cached: Vector[UInt16] = {
|
||||
0.until(128).map(i => UInt16(BigInt(i))).toVector
|
||||
}
|
||||
|
||||
lazy val zero = cached(0)
|
||||
lazy val one = cached(1)
|
||||
|
||||
private lazy val minUnderlying: A = 0
|
||||
private lazy val maxUnderlying: A = BigInt(65535L)
|
||||
|
||||
lazy val min = zero
|
||||
lazy val max = UInt16(maxUnderlying)
|
||||
|
||||
override def isInBound(num: A): Boolean =
|
||||
num <= maxUnderlying && num >= minUnderlying
|
||||
|
||||
override def fromBytes(bytes: ByteVector): UInt16 = {
|
||||
require(
|
||||
bytes.size <= 2,
|
||||
"UInt16 byte array was too large, got: " + BytesUtil.encodeHex(bytes))
|
||||
UInt16(bytes.toLong(signed = false, ordering = ByteOrdering.BigEndian))
|
||||
}
|
||||
|
||||
def apply(long: Long): UInt16 = {
|
||||
checkCached(long)
|
||||
}
|
||||
|
||||
def apply(bigInt: BigInt): UInt16 = {
|
||||
UInt16Impl(bigInt)
|
||||
}
|
||||
|
||||
/** Checks if we have the number cached, if not allocates a new object to represent the number */
|
||||
private def checkCached(long: Long): UInt16 = {
|
||||
if (long < 128 && long >= 0) cached(long.toInt)
|
||||
else {
|
||||
UInt16(BigInt(long))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object UInt32
|
||||
extends Factory[UInt32]
|
||||
with NumberObject[UInt32]
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
package org.bitcoins.core.protocol
|
||||
|
||||
import org.bitcoins.core.number.UInt64
|
||||
import org.bitcoins.crypto.{Factory, NetworkElement}
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
case class BigSizeUInt(num: UInt64) extends NetworkElement {
|
||||
override val byteSize: Long = BigSizeUInt.calcSizeForNum(num)
|
||||
|
||||
override def bytes: ByteVector = {
|
||||
byteSize match {
|
||||
case 1 => num.bytes.takeRight(1)
|
||||
case 3 => 0xfd.toByte +: num.bytes.takeRight(2)
|
||||
case 5 => 0xfe.toByte +: num.bytes.takeRight(4)
|
||||
case _ => 0xff.toByte +: num.bytes
|
||||
}
|
||||
}
|
||||
|
||||
def toLong: Long = num.toLong
|
||||
|
||||
def toInt: Int = {
|
||||
val l = toLong
|
||||
require(Int.MinValue <= l && l <= Int.MaxValue,
|
||||
"Cannot convert BigSizeUInt toInt, got: " + this)
|
||||
l.toInt
|
||||
}
|
||||
|
||||
override def toString: String = s"BigSizeUInt(${num.toLong})"
|
||||
}
|
||||
|
||||
object BigSizeUInt extends Factory[BigSizeUInt] {
|
||||
|
||||
def calcSizeForNum(num: UInt64): Int = {
|
||||
if (num.toBigInt < 0xfd) { // can be represented with one byte
|
||||
1
|
||||
} else if (num.toBigInt < 0x10000) { // can be represented with two bytes
|
||||
3
|
||||
} else if (num.toBigInt < 0x100000000L) { // can be represented with 4 bytes
|
||||
5
|
||||
} else {
|
||||
9
|
||||
}
|
||||
}
|
||||
|
||||
def apply(num: Long): BigSizeUInt = {
|
||||
BigSizeUInt(UInt64(num))
|
||||
}
|
||||
|
||||
def apply(num: BigInt): BigSizeUInt = {
|
||||
BigSizeUInt(UInt64(num))
|
||||
}
|
||||
|
||||
override def fromBytes(bytes: ByteVector): BigSizeUInt = {
|
||||
require(bytes.nonEmpty, "Cannot parse a BigSizeUInt from empty byte vector")
|
||||
|
||||
val prefixNum = UInt64(bytes.take(1)).toInt
|
||||
|
||||
val (bigSizeUInt, expectedSize) = if (prefixNum < 253) { // 8 bit number
|
||||
(BigSizeUInt(prefixNum), 1)
|
||||
} else if (prefixNum == 253) { // 16 bit number
|
||||
(BigSizeUInt(UInt64(bytes.slice(1, 3))), 3)
|
||||
} else if (prefixNum == 254) { // 32 bit number
|
||||
(BigSizeUInt(UInt64(bytes.slice(1, 5))), 5)
|
||||
} else { // 64 bit number
|
||||
(BigSizeUInt(UInt64(bytes.slice(1, 9))), 9)
|
||||
}
|
||||
|
||||
require(
|
||||
bigSizeUInt.byteSize == expectedSize,
|
||||
s"Length prefix $prefixNum did not match bytes ${bigSizeUInt.bytes.tail}")
|
||||
|
||||
bigSizeUInt
|
||||
}
|
||||
|
||||
def calcFor(bytes: ByteVector): BigSizeUInt = {
|
||||
BigSizeUInt(bytes.length)
|
||||
}
|
||||
}
|
148
core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala
Normal file
148
core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala
Normal file
|
@ -0,0 +1,148 @@
|
|||
package org.bitcoins.core.protocol.tlv
|
||||
|
||||
import org.bitcoins.core.number.UInt16
|
||||
import org.bitcoins.core.protocol.BigSizeUInt
|
||||
import org.bitcoins.core.protocol.tlv.TLV.DecodeTLVResult
|
||||
import org.bitcoins.crypto.{Factory, NetworkElement}
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
sealed trait TLV extends NetworkElement {
|
||||
def tpe: BigSizeUInt
|
||||
def value: ByteVector
|
||||
|
||||
def length: BigSizeUInt = {
|
||||
BigSizeUInt.calcFor(value)
|
||||
}
|
||||
|
||||
override def bytes: ByteVector = {
|
||||
tpe.bytes ++ length.bytes ++ value
|
||||
}
|
||||
}
|
||||
|
||||
object TLV extends Factory[TLV] {
|
||||
|
||||
case class DecodeTLVResult(
|
||||
tpe: BigSizeUInt,
|
||||
length: BigSizeUInt,
|
||||
value: ByteVector)
|
||||
|
||||
def decodeTLV(bytes: ByteVector): DecodeTLVResult = {
|
||||
val tpe = BigSizeUInt(bytes)
|
||||
val length = BigSizeUInt(bytes.drop(tpe.byteSize))
|
||||
val prefixSize = tpe.byteSize + length.byteSize
|
||||
|
||||
require(
|
||||
bytes.length >= prefixSize + length.num.toLong,
|
||||
s"Length specified was $length but not enough bytes in ${bytes.drop(prefixSize)}")
|
||||
|
||||
val value = bytes.drop(prefixSize).take(length.num.toLong)
|
||||
|
||||
DecodeTLVResult(tpe, length, value)
|
||||
}
|
||||
|
||||
private val allFactories: Vector[TLVFactory[TLV]] =
|
||||
Vector(ErrorTLV, PingTLV, PongTLV)
|
||||
|
||||
val knownTypes: Vector[BigSizeUInt] = allFactories.map(_.tpe)
|
||||
|
||||
def fromBytes(bytes: ByteVector): TLV = {
|
||||
val DecodeTLVResult(tpe, _, value) = decodeTLV(bytes)
|
||||
|
||||
allFactories.find(_.tpe == tpe) match {
|
||||
case Some(tlvFactory) => tlvFactory.fromTLVValue(value)
|
||||
case None => UnknownTLV(tpe, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
|
||||
def tpe: BigSizeUInt
|
||||
def fromTLVValue(value: ByteVector): T
|
||||
|
||||
override def fromBytes(bytes: ByteVector): T = {
|
||||
val DecodeTLVResult(tpe, _, value) = TLV.decodeTLV(bytes)
|
||||
|
||||
require(tpe == this.tpe, s"Invalid type $tpe when expecting ${this.tpe}")
|
||||
|
||||
fromTLVValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
case class UnknownTLV(tpe: BigSizeUInt, value: ByteVector) extends TLV {
|
||||
require(!TLV.knownTypes.contains(tpe), s"Type $tpe is known")
|
||||
}
|
||||
|
||||
object UnknownTLV extends Factory[UnknownTLV] {
|
||||
|
||||
override def fromBytes(bytes: ByteVector): UnknownTLV = {
|
||||
val DecodeTLVResult(tpe, _, value) = TLV.decodeTLV(bytes)
|
||||
|
||||
UnknownTLV(tpe, value)
|
||||
}
|
||||
}
|
||||
|
||||
/** @see [[https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#the-error-message]] */
|
||||
case class ErrorTLV(id: ByteVector, data: ByteVector) extends TLV {
|
||||
require(id.length == 32, s"ID associated with error is incorrect length: $id")
|
||||
|
||||
override val tpe: BigSizeUInt = ErrorTLV.tpe
|
||||
|
||||
override val value: ByteVector = {
|
||||
id ++ UInt16(data.length).bytes ++ data
|
||||
}
|
||||
}
|
||||
|
||||
object ErrorTLV extends TLVFactory[ErrorTLV] {
|
||||
override val tpe: BigSizeUInt = BigSizeUInt(17)
|
||||
|
||||
override def fromTLVValue(value: ByteVector): ErrorTLV = {
|
||||
val id = value.take(32)
|
||||
val len = UInt16(value.drop(32).take(2))
|
||||
val data = value.drop(32 + 2).take(len.toInt)
|
||||
|
||||
ErrorTLV(id, data)
|
||||
}
|
||||
}
|
||||
|
||||
case class PingTLV(numPongBytes: UInt16, ignored: ByteVector) extends TLV {
|
||||
override val tpe: BigSizeUInt = PingTLV.tpe
|
||||
|
||||
override val value: ByteVector = {
|
||||
numPongBytes.bytes ++ UInt16(ignored.length).bytes ++ ignored
|
||||
}
|
||||
}
|
||||
|
||||
object PingTLV extends TLVFactory[PingTLV] {
|
||||
override val tpe: BigSizeUInt = BigSizeUInt(18)
|
||||
|
||||
override def fromTLVValue(value: ByteVector): PingTLV = {
|
||||
val numPongBytes = UInt16(value.take(2))
|
||||
val numIgnored = UInt16(value.slice(2, 4))
|
||||
val ignored = value.drop(4).take(numIgnored.toLong)
|
||||
|
||||
PingTLV(numPongBytes, ignored)
|
||||
}
|
||||
}
|
||||
|
||||
case class PongTLV(ignored: ByteVector) extends TLV {
|
||||
override val tpe: BigSizeUInt = PongTLV.tpe
|
||||
|
||||
override val value: ByteVector = {
|
||||
UInt16(ignored.length).bytes ++ ignored
|
||||
}
|
||||
}
|
||||
|
||||
object PongTLV extends TLVFactory[PongTLV] {
|
||||
override val tpe: BigSizeUInt = BigSizeUInt(19)
|
||||
|
||||
override def fromTLVValue(value: ByteVector): PongTLV = {
|
||||
val numIgnored = UInt16(value.take(2))
|
||||
val ignored = value.drop(2).take(numIgnored.toLong)
|
||||
|
||||
PongTLV.forIgnored(ignored)
|
||||
}
|
||||
|
||||
def forIgnored(ignored: ByteVector): PongTLV = {
|
||||
new PongTLV(ignored)
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
package org.bitcoins.testkit.core.gen
|
||||
|
||||
import org.bitcoins.core.number._
|
||||
import org.bitcoins.core.protocol.CompactSizeUInt
|
||||
import org.bitcoins.core.protocol.{BigSizeUInt, CompactSizeUInt}
|
||||
import org.bitcoins.core.script.constant.ScriptNumber
|
||||
import org.bitcoins.core.util.NumberUtil
|
||||
import org.scalacheck.Arbitrary.arbitrary
|
||||
|
@ -49,6 +49,8 @@ trait NumberGenerator {
|
|||
|
||||
def uInt8s: Gen[Seq[UInt8]] = Gen.listOf(uInt8)
|
||||
|
||||
def uInt16: Gen[UInt16] = Gen.choose(0, 65535).map(UInt16(_))
|
||||
|
||||
/**
|
||||
* Generates a number in the range 0 <= x <= 2 ^^32 - 1
|
||||
* then wraps it in a UInt32
|
||||
|
@ -96,6 +98,8 @@ trait NumberGenerator {
|
|||
|
||||
def compactSizeUInts: Gen[CompactSizeUInt] = uInt64s.map(CompactSizeUInt(_))
|
||||
|
||||
def bigSizeUInt: Gen[BigSizeUInt] = uInt64.map(BigSizeUInt.apply)
|
||||
|
||||
/** Generates an arbitrary [[scala.Byte Byte]] in Scala */
|
||||
def byte: Gen[Byte] = arbitrary[Byte]
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package org.bitcoins.testkit.core.gen
|
||||
|
||||
import org.bitcoins.core.protocol.BigSizeUInt
|
||||
import org.bitcoins.core.protocol.tlv._
|
||||
import org.scalacheck.Gen
|
||||
|
||||
trait TLVGen {
|
||||
|
||||
def unknownTpe: Gen[BigSizeUInt] = {
|
||||
NumberGenerator.bigSizeUInt.suchThat(num => !TLV.knownTypes.contains(num))
|
||||
}
|
||||
|
||||
def unknownTLV: Gen[UnknownTLV] = {
|
||||
for {
|
||||
tpe <- unknownTpe
|
||||
value <- NumberGenerator.bytevector
|
||||
} yield {
|
||||
UnknownTLV(tpe, value)
|
||||
}
|
||||
}
|
||||
|
||||
def errorTLV: Gen[ErrorTLV] = {
|
||||
for {
|
||||
id <- NumberGenerator.bytevector(32)
|
||||
data <- NumberGenerator.bytevector
|
||||
} yield {
|
||||
ErrorTLV(id, data)
|
||||
}
|
||||
}
|
||||
|
||||
def pingTLV: Gen[PingTLV] = {
|
||||
for {
|
||||
num <- NumberGenerator.uInt16
|
||||
bytes <- NumberGenerator.bytevector
|
||||
} yield {
|
||||
PingTLV(num, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
def pongTLV: Gen[PongTLV] = {
|
||||
NumberGenerator.bytevector.map(PongTLV.forIgnored)
|
||||
}
|
||||
|
||||
def tlv: Gen[TLV] = {
|
||||
Gen.oneOf(
|
||||
unknownTLV,
|
||||
errorTLV,
|
||||
pingTLV,
|
||||
pongTLV
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
object TLVGen extends TLVGen
|
Loading…
Add table
Reference in a new issue