diff --git a/core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala b/core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala index 2bc1d53f0a..ebf4621274 100644 --- a/core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala +++ b/core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala @@ -17,7 +17,7 @@ import org.bitcoins.core.wallet.fee.SatoshisPerVirtualByte import org.bitcoins.crypto._ import scodec.bits.ByteVector -sealed trait TLV extends NetworkElement { +sealed trait TLV extends NetworkElement with TLVUtil { def tpe: BigSizeUInt def value: ByteVector @@ -32,6 +32,61 @@ sealed trait TLV extends NetworkElement { def sha256: Sha256Digest = CryptoUtil.sha256(bytes) } +trait TLVUtil { + + protected def boolBytes(bool: Boolean): ByteVector = { + if (bool) { + ByteVector(TRUE_BYTE) + } else { + ByteVector(FALSE_BYTE) + } + } + + protected def strBytes(str: NormalizedString): ByteVector = { + TLV.getStringBytes(str) + } + + protected def satBytes(sats: Satoshis): ByteVector = { + UInt64(sats.toLong).bytes + } + + protected def u16Prefix(bytes: ByteVector): ByteVector = { + UInt16(bytes.length).bytes ++ bytes + } + + protected def u16PrefixedList[T]( + vec: Vector[T], + serialize: T => ByteVector): ByteVector = { + vec.foldLeft(UInt16(vec.length).bytes) { + case (accum, elem) => + accum ++ serialize(elem) + } + } + + protected def u16PrefixedList[T <: NetworkElement]( + vec: Vector[T]): ByteVector = { + u16PrefixedList[T](vec, { elem: NetworkElement => elem.bytes }) + } + + protected def bigSizePrefix(bytes: ByteVector): ByteVector = { + BigSizeUInt(bytes.length).bytes ++ bytes + } + + protected def bigSizePrefixedList[T]( + vec: Vector[T], + serialize: T => ByteVector): ByteVector = { + vec.foldLeft(BigSizeUInt(vec.length).bytes) { + case (accum, elem) => + accum ++ serialize(elem) + } + } + + protected def bigSizePrefixedList[T <: NetworkElement]( + vec: Vector[T]): ByteVector = { + bigSizePrefixedList[T](vec, { elem: NetworkElement => elem.bytes }) + } +} + trait TLVSerializable[+T <: TLV] extends NetworkElement { def toTLV: T @@ -163,12 +218,75 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] { bytes } + /** IMPORTANT: This only works for factories which read off of + * the front of a ByteVector without consuming the whole thing. + * If this is not the case, you must specify how many bytes. + */ + def take[E <: NetworkElement](factory: Factory[E]): E = { + val elem = factory(current) + skip(elem) + elem + } + + def take[E <: NetworkElement](factory: Factory[E], byteSize: Int): E = { + val bytes = take(byteSize) + factory(bytes) + } + def takeBits(numBits: Int): ByteVector = { require(numBits % 8 == 0, s"Must take a round byte number of bits, got $numBits") take(numBytes = numBits / 8) } + def takeBigSize(): BigSizeUInt = { + take(BigSizeUInt) + } + + def takeBigSizePrefixed[E](takeFunc: Int => E): E = { + val len = takeBigSize() + takeFunc(len.toInt) + } + + def takeBigSizePrefixedList[E](takeFunc: () => E): Vector[E] = { + val len = takeBigSize() + 0.until(len.toInt).toVector.map { _ => + takeFunc() + } + } + + def takeU16(): UInt16 = { + UInt16(takeBits(16)) + } + + def takeU16Prefixed[E](takeFunc: Int => E): E = { + val len = takeU16() + takeFunc(len.toInt) + } + + def takeU16PrefixedList[E](takeFunc: () => E): Vector[E] = { + val len = takeU16() + 0.until(len.toInt).toVector.map { _ => + takeFunc() + } + } + + def takeI32(): Int32 = { + Int32(takeBits(32)) + } + + def takeU32(): UInt32 = { + UInt32(takeBits(32)) + } + + def takeU64(): UInt64 = { + UInt64(takeBits(64)) + } + + def takeSats(): Satoshis = { + Satoshis(takeU64().toLong) + } + def takeBoolean(): Boolean = { take(1).head match { case FALSE_BYTE => false @@ -180,22 +298,15 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] { } def takeString(): NormalizedString = { - val size = BigSizeUInt(current) - skip(size.byteSize) + val size = takeBigSize() val strBytes = take(size.toInt) NormalizedString(strBytes) } def takeSPK(): ScriptPubKey = { - val len = UInt16(takeBits(16)).toInt + val len = takeU16().toInt ScriptPubKey.fromAsmBytes(take(len)) } - - def takePoint(): TLVPoint = { - val point = TLVPoint(current) - skip(point) - point - } } } @@ -264,7 +375,7 @@ case class ErrorTLV(id: ByteVector, data: ByteVector) extends TLV { override val tpe: BigSizeUInt = ErrorTLV.tpe override val value: ByteVector = { - id ++ UInt16(data.length).bytes ++ data + id ++ u16Prefix(data) } } @@ -272,9 +383,10 @@ object ErrorTLV extends TLVFactory[ErrorTLV] { override val tpe: BigSizeUInt = BigSizeUInt(17) override def fromTLVValue(value: ByteVector): ErrorTLV = { - val id = value.take(32) - val len = UInt16(value.drop(32).take(2)) - val data = value.drop(32 + 2).take(len.toInt) + val iter = ValueIterator(value) + + val id = iter.take(32) + val data = iter.takeU16Prefixed(iter.take) ErrorTLV(id, data) } @@ -284,7 +396,7 @@ case class PingTLV(numPongBytes: UInt16, ignored: ByteVector) extends TLV { override val tpe: BigSizeUInt = PingTLV.tpe override val value: ByteVector = { - numPongBytes.bytes ++ UInt16(ignored.length).bytes ++ ignored + numPongBytes.bytes ++ u16Prefix(ignored) } } @@ -292,9 +404,10 @@ object PingTLV extends TLVFactory[PingTLV] { override val tpe: BigSizeUInt = BigSizeUInt(18) override def fromTLVValue(value: ByteVector): PingTLV = { - val numPongBytes = UInt16(value.take(2)) - val numIgnored = UInt16(value.slice(2, 4)) - val ignored = value.drop(4).take(numIgnored.toLong) + val iter = ValueIterator(value) + + val numPongBytes = iter.takeU16() + val ignored = iter.takeU16Prefixed(iter.take) PingTLV(numPongBytes, ignored) } @@ -304,7 +417,7 @@ case class PongTLV(ignored: ByteVector) extends TLV { override val tpe: BigSizeUInt = PongTLV.tpe override val value: ByteVector = { - UInt16(ignored.length).bytes ++ ignored + u16Prefix(ignored) } } @@ -312,8 +425,9 @@ object PongTLV extends TLVFactory[PongTLV] { override val tpe: BigSizeUInt = BigSizeUInt(19) override def fromTLVValue(value: ByteVector): PongTLV = { - val numIgnored = UInt16(value.take(2)) - val ignored = value.drop(2).take(numIgnored.toLong) + val iter = ValueIterator(value) + + val ignored = iter.takeU16Prefixed(iter.take) PongTLV.forIgnored(ignored) } @@ -347,12 +461,7 @@ case class EnumEventDescriptorV0TLV(outcomes: Vector[NormalizedString]) override def tpe: BigSizeUInt = EnumEventDescriptorV0TLV.tpe override val value: ByteVector = { - val starting = UInt16(outcomes.size).bytes - - outcomes.foldLeft(starting) { (accum, outcome) => - val outcomeBytes = TLV.getStringBytes(outcome) - accum ++ outcomeBytes - } + u16PrefixedList(outcomes, TLV.getStringBytes) } override def noncesNeeded: Int = 1 @@ -365,22 +474,9 @@ object EnumEventDescriptorV0TLV extends TLVFactory[EnumEventDescriptorV0TLV] { override def fromTLVValue(value: ByteVector): EnumEventDescriptorV0TLV = { val iter = ValueIterator(value) - val count = UInt16(iter.takeBits(16)) + val outcomes = iter.takeU16PrefixedList(() => iter.takeString()) - val builder = Vector.newBuilder[NormalizedString] - - while (iter.index < value.length) { - val str = iter.takeString() - builder.+=(str) - } - - val result = builder.result() - - require( - count.toInt == result.size, - s"Did not parse the expected number of outcomes, ${count.toInt} != ${result.size}") - - EnumEventDescriptorV0TLV(result) + EnumEventDescriptorV0TLV(outcomes) } val dummy: EnumEventDescriptorV0TLV = EnumEventDescriptorV0TLV( @@ -471,7 +567,7 @@ case class RangeEventDescriptorV0TLV( override val value: ByteVector = { start.bytes ++ count.bytes ++ step.bytes ++ - TLV.getStringBytes(unit) ++ precision.bytes + strBytes(unit) ++ precision.bytes } override def noncesNeeded: Int = 1 @@ -484,12 +580,11 @@ object RangeEventDescriptorV0TLV extends TLVFactory[RangeEventDescriptorV0TLV] { override def fromTLVValue(value: ByteVector): RangeEventDescriptorV0TLV = { val iter = ValueIterator(value) - val start = Int32(iter.takeBits(32)) - val count = UInt32(iter.takeBits(32)) - val step = UInt16(iter.takeBits(16)) - + val start = iter.takeI32() + val count = iter.takeU32() + val step = iter.takeU16() val unit = iter.takeString() - val precision = Int32(iter.takeBits(32)) + val precision = iter.takeI32() RangeEventDescriptorV0TLV(start, count, step, unit, precision) } @@ -535,13 +630,11 @@ sealed trait DigitDecompositionEventDescriptorV0TLV DigitDecompositionEventDescriptorV0TLV.tpe override lazy val value: ByteVector = { - val isSignedByte = - if (isSigned) ByteVector(TRUE_BYTE) else ByteVector(FALSE_BYTE) - - val numDigitBytes = numDigits.bytes - val unitBytes = TLV.getStringBytes(unit) - - base.bytes ++ isSignedByte ++ unitBytes ++ precision.bytes ++ numDigitBytes + base.bytes ++ + boolBytes(isSigned) ++ + strBytes(unit) ++ + precision.bytes ++ + numDigits.bytes } override def noncesNeeded: Int = { @@ -579,12 +672,11 @@ object DigitDecompositionEventDescriptorV0TLV value: ByteVector): DigitDecompositionEventDescriptorV0TLV = { val iter = ValueIterator(value) - val base = UInt16(iter.takeBits(16)) + val base = iter.takeU16() val isSigned = iter.takeBoolean() - val unit = iter.takeString() - val precision = Int32(iter.takeBits(32)) - val numDigits = UInt16(iter.takeBits(16)) + val precision = iter.takeI32() + val numDigits = iter.takeU16() DigitDecompositionEventDescriptorV0TLV(base, isSigned, @@ -631,12 +723,10 @@ case class OracleEventV0TLV( override def tpe: BigSizeUInt = OracleEventV0TLV.tpe override val value: ByteVector = { - val eventIdBytes = TLV.getStringBytes(eventId) - - val numNonces = UInt16(nonces.size) - val noncesBytes = nonces.foldLeft(numNonces.bytes)(_ ++ _.bytes) - - noncesBytes ++ eventMaturityEpoch.bytes ++ eventDescriptor.bytes ++ eventIdBytes + u16PrefixedList(nonces) ++ + eventMaturityEpoch.bytes ++ + eventDescriptor.bytes ++ + strBytes(eventId) } /** Gets the maturation of the event since epoch */ @@ -651,23 +741,9 @@ object OracleEventV0TLV extends TLVFactory[OracleEventV0TLV] { override def fromTLVValue(value: ByteVector): OracleEventV0TLV = { val iter = ValueIterator(value) - val numNonces = UInt16(iter.takeBits(16)) - val builder = Vector.newBuilder[SchnorrNonce] - - for (_ <- 0 until numNonces.toInt) { - val nonceBytes = iter.take(32) - builder.+=(SchnorrNonce(nonceBytes)) - } - - val nonces = builder.result() - - require( - numNonces.toInt == nonces.size, - s"Did not parse the expected number of nonces expected ${numNonces.toInt}, got ${nonces.size}") - - val eventMaturity = UInt32(iter.takeBits(32)) - val eventDescriptor = EventDescriptorTLV(iter.current) - iter.skip(eventDescriptor.byteSize) + val nonces = iter.takeU16PrefixedList(() => iter.take(SchnorrNonce, 32)) + val eventMaturity = iter.takeU32() + val eventDescriptor = iter.take(EventDescriptorTLV) val eventId = iter.takeString() OracleEventV0TLV(nonces, eventMaturity, eventDescriptor, eventId) @@ -712,9 +788,9 @@ object OracleAnnouncementV0TLV extends TLVFactory[OracleAnnouncementV0TLV] { override def fromTLVValue(value: ByteVector): OracleAnnouncementV0TLV = { val iter = ValueIterator(value) - val sig = SchnorrDigitalSignature(iter.take(64)) - val publicKey = SchnorrPublicKey(iter.take(32)) - val eventTLV = OracleEventV0TLV(iter.current) + val sig = iter.take(SchnorrDigitalSignature, 64) + val publicKey = iter.take(SchnorrPublicKey, 32) + val eventTLV = iter.take(OracleEventV0TLV) OracleAnnouncementV0TLV(sig, publicKey, eventTLV) } @@ -747,13 +823,13 @@ case class ContractInfoV0TLV(outcomes: Vector[(String, Satoshis)]) override val tpe: BigSizeUInt = ContractInfoV0TLV.tpe override val value: ByteVector = { - outcomes.foldLeft(BigSizeUInt(outcomes.length).bytes) { - case (bytes, (outcome, amt)) => - val outcomeBytes = CryptoUtil.serializeForHash(outcome) - bytes ++ BigSizeUInt - .calcFor(outcomeBytes) - .bytes ++ outcomeBytes ++ amt.toUInt64.bytes - } + bigSizePrefixedList[(String, Satoshis)]( + outcomes, + { + case (outcome, amt) => + val outcomeBytes = CryptoUtil.serializeForHash(outcome) + bigSizePrefix(outcomeBytes) ++ satBytes(amt) + }) } } @@ -763,15 +839,9 @@ object ContractInfoV0TLV extends TLVFactory[ContractInfoV0TLV] { override def fromTLVValue(value: ByteVector): ContractInfoV0TLV = { val iter = ValueIterator(value) - val numOutcomes = BigSizeUInt(iter.current) - iter.skip(numOutcomes) - - val outcomes = 0.until(numOutcomes.toInt).toVector.map { _ => - val outcomeLen = BigSizeUInt(iter.current) - iter.skip(outcomeLen) - val outcome = - new String(iter.take(outcomeLen.toInt).toArray, StandardCharsets.UTF_8) - val amt = Satoshis(UInt64(iter.takeBits(64))) + val outcomes = iter.takeBigSizePrefixedList { () => + val outcome = iter.takeString().normStr + val amt = iter.takeSats() outcome -> amt } @@ -790,8 +860,9 @@ case class TLVPoint(outcome: Long, value: Satoshis, isEndpoint: Boolean) } override def bytes: ByteVector = { - ByteVector(leadingByte) ++ BigSizeUInt(outcome).bytes ++ UInt64( - value.toLong).bytes + ByteVector(leadingByte) ++ + BigSizeUInt(outcome).bytes ++ + UInt64(value.toLong).bytes } } @@ -822,9 +893,10 @@ case class ContractInfoV1TLV( override val tpe: BigSizeUInt = ContractInfoV1TLV.tpe override val value: ByteVector = { - BigSizeUInt(base).bytes ++ UInt16(numDigits).bytes ++ UInt64( - totalCollateral.toLong).bytes ++ BigSizeUInt( - points.length).bytes ++ points.foldLeft(ByteVector.empty)(_ ++ _.bytes) + BigSizeUInt(base).bytes ++ + UInt16(numDigits).bytes ++ + satBytes(totalCollateral) ++ + bigSizePrefixedList(points) } } @@ -834,20 +906,12 @@ object ContractInfoV1TLV extends TLVFactory[ContractInfoV1TLV] { override def fromTLVValue(value: ByteVector): ContractInfoV1TLV = { val iter = ValueIterator(value) - val base = BigSizeUInt(iter.current) - iter.skip(base) - val numDigits = UInt16(iter.takeBits(16)) - val totalCollateral = UInt64(iter.takeBits(64)) - val numPoints = BigSizeUInt(iter.current) - iter.skip(numPoints) - val points = (0L until numPoints.toLong).toVector.map { _ => - iter.takePoint() - } + val base = iter.takeBigSize() + val numDigits = iter.takeU16() + val totalCollateral = iter.takeSats() + val points = iter.takeBigSizePrefixedList(() => iter.take(TLVPoint)) - ContractInfoV1TLV(base.toInt, - numDigits.toInt, - Satoshis(totalCollateral.toLong), - points) + ContractInfoV1TLV(base.toInt, numDigits.toInt, totalCollateral, points) } } @@ -874,9 +938,10 @@ object OracleInfoV0TLV extends TLVFactory[OracleInfoV0TLV] { override val tpe: BigSizeUInt = BigSizeUInt(42770) override def fromTLVValue(value: ByteVector): OracleInfoV0TLV = { - val (pubKeyBytes, rBytes) = value.splitAt(32) - val pubKey = SchnorrPublicKey(pubKeyBytes) - val rValue = SchnorrNonce(rBytes) + val iter = ValueIterator(value) + + val pubKey = iter.take(SchnorrPublicKey, 32) + val rValue = iter.take(SchnorrNonce, 32) OracleInfoV0TLV(pubKey, rValue) } @@ -889,7 +954,7 @@ case class OracleInfoV1TLV( override val tpe: BigSizeUInt = OracleInfoV1TLV.tpe override val value: ByteVector = { - nonces.foldLeft(pubKey.bytes)(_ ++ _.bytes) + pubKey.bytes ++ u16PrefixedList(nonces) } } @@ -897,16 +962,10 @@ object OracleInfoV1TLV extends TLVFactory[OracleInfoV1TLV] { override val tpe: BigSizeUInt = BigSizeUInt(42786) override def fromTLVValue(value: ByteVector): OracleInfoV1TLV = { - require( - value.length >= 64 && value.length % 32 == 0, - s"Expected multiple of 32 bytes with at least one nonce, got $value") - val iter = ValueIterator(value) - val pubKey = SchnorrPublicKey(iter.take(32)) - val nonces = (0L until iter.current.length / 32).toVector.map { _ => - SchnorrNonce(iter.take(32)) - } + val pubKey = iter.take(SchnorrPublicKey, 32) + val nonces = iter.takeU16PrefixedList(() => iter.take(SchnorrNonce, 32)) OracleInfoV1TLV(pubKey, nonces) } @@ -943,8 +1002,7 @@ case class FundingInputV0TLV( val redeemScript = redeemScriptOpt.getOrElse(EmptyScriptPubKey) - UInt16(prevTx.byteSize).bytes ++ - prevTx.bytes ++ + u16Prefix(prevTx.bytes) ++ prevTxVout.bytes ++ sequence.bytes ++ maxWitnessLen.bytes ++ @@ -958,11 +1016,10 @@ object FundingInputV0TLV extends TLVFactory[FundingInputV0TLV] { override def fromTLVValue(value: ByteVector): FundingInputV0TLV = { val iter = ValueIterator(value) - val prevTxLen = UInt16(iter.takeBits(16)) - val prevTx = Transaction(iter.take(prevTxLen.toInt)) - val prevTxVout = UInt32(iter.takeBits(32)) - val sequence = UInt32(iter.takeBits(32)) - val maxWitnessLen = UInt16(iter.takeBits(16)) + val prevTx = iter.takeU16Prefixed(iter.take(Transaction, _)) + val prevTxVout = iter.takeU32() + val sequence = iter.takeU32() + val maxWitnessLen = iter.takeU16() val redeemScript = iter.takeSPK() val redeemScriptOpt = redeemScript match { case EmptyScriptPubKey => None @@ -987,7 +1044,7 @@ case class CETSignaturesV0TLV(sigs: Vector[ECAdaptorSignature]) override val tpe: BigSizeUInt = CETSignaturesV0TLV.tpe override val value: ByteVector = { - sigs.foldLeft(BigSizeUInt(sigs.length).bytes)(_ ++ _.bytes) + bigSizePrefixedList(sigs) } } @@ -997,12 +1054,8 @@ object CETSignaturesV0TLV extends TLVFactory[CETSignaturesV0TLV] { override def fromTLVValue(value: ByteVector): CETSignaturesV0TLV = { val iter = ValueIterator(value) - val numSigs = BigSizeUInt(iter.current) - iter.skip(numSigs) - - val sigs = 0.until(numSigs.toInt).toVector.map { _ => - ECAdaptorSignature(iter.take(162)) - } + val sigs = + iter.takeBigSizePrefixedList(() => iter.take(ECAdaptorSignature, 162)) CETSignaturesV0TLV(sigs) } @@ -1015,14 +1068,11 @@ case class FundingSignaturesV0TLV(witnesses: Vector[ScriptWitnessV0]) override val tpe: BigSizeUInt = FundingSignaturesV0TLV.tpe override val value: ByteVector = { - witnesses.foldLeft(UInt16(witnesses.length).bytes) { - case (bytes, witness) => - witness.stack.reverse.foldLeft( - bytes ++ UInt16(witness.stack.length).bytes) { - case (bytes, stackElem) => - bytes ++ UInt16(stackElem.length).bytes ++ stackElem - } - } + u16PrefixedList( + witnesses, + { witness: ScriptWitnessV0 => + u16PrefixedList[ByteVector](witness.stack.toVector.reverse, u16Prefix) + }) } } @@ -1032,13 +1082,10 @@ object FundingSignaturesV0TLV extends TLVFactory[FundingSignaturesV0TLV] { override def fromTLVValue(value: ByteVector): FundingSignaturesV0TLV = { val iter = ValueIterator(value) - val numWitnesses = UInt16(iter.takeBits(16)) - val witnesses = (0 until numWitnesses.toInt).toVector.map { _ => - val numStackElements = UInt16(iter.takeBits(16)) - val stack = (0 until numStackElements.toInt).toVector.map { _ => - val stackElemLength = UInt16(iter.takeBits(16)) - iter.take(stackElemLength.toInt) - } + val witnesses = iter.takeU16PrefixedList { () => + val stack = + iter.takeU16PrefixedList(() => iter.takeU16Prefixed(iter.take)) + ScriptWitness(stack.reverse) match { case EmptyScriptWitness => throw new IllegalArgumentException(s"Invalid witness: $stack") @@ -1073,11 +1120,10 @@ case class DLCOfferTLV( oracleInfo.bytes ++ fundingPubKey.bytes ++ TLV.encodeScript(payoutSPK) ++ - totalCollateralSatoshis.toUInt64.bytes ++ - UInt16(fundingInputs.length).bytes ++ - fundingInputs.foldLeft(ByteVector.empty)(_ ++ _.bytes) ++ + satBytes(totalCollateralSatoshis) ++ + u16PrefixedList(fundingInputs) ++ TLV.encodeScript(changeSPK) ++ - feeRate.currencyUnit.satoshis.toUInt64.bytes ++ + satBytes(feeRate.currencyUnit.satoshis) ++ contractMaturityBound.toUInt32.bytes ++ contractTimeout.toUInt32.bytes } @@ -1090,24 +1136,18 @@ object DLCOfferTLV extends TLVFactory[DLCOfferTLV] { val iter = ValueIterator(value) val contractFlags = iter.take(1).head - val chainHash = DoubleSha256Digest(iter.take(32)) - val contractInfo = ContractInfoTLV.fromBytes(iter.current) - iter.skip(contractInfo) - val oracleInfo = OracleInfoTLV.fromBytes(iter.current) - iter.skip(oracleInfo) - val fundingPubKey = ECPublicKey(iter.take(33)) + val chainHash = iter.take(DoubleSha256Digest, 32) + val contractInfo = iter.take(ContractInfoTLV) + val oracleInfo = iter.take(OracleInfoTLV) + val fundingPubKey = iter.take(ECPublicKey, 33) val payoutSPK = iter.takeSPK() - val totalCollateralSatoshis = Satoshis(UInt64(iter.takeBits(64))) - val numFundingInputs = UInt16(iter.takeBits(16)) - val fundingInputs = (0 until numFundingInputs.toInt).toVector.map { _ => - val fundingInput = FundingInputV0TLV.fromBytes(iter.current) - iter.skip(fundingInput) - fundingInput - } + val totalCollateralSatoshis = iter.takeSats() + val fundingInputs = + iter.takeU16PrefixedList(() => iter.take(FundingInputV0TLV)) val changeSPK = iter.takeSPK() - val feeRate = SatoshisPerVirtualByte(Satoshis(UInt64(iter.takeBits(64)))) - val contractMaturityBound = BlockTimeStamp(UInt32(iter.takeBits(32))) - val contractTimeout = BlockTimeStamp(UInt32(iter.takeBits(32))) + val feeRate = SatoshisPerVirtualByte(iter.takeSats()) + val contractMaturityBound = BlockTimeStamp(iter.takeU32()) + val contractTimeout = BlockTimeStamp(iter.takeU32()) DLCOfferTLV( contractFlags, @@ -1140,11 +1180,10 @@ case class DLCAcceptTLV( override val value: ByteVector = { tempContractId.bytes ++ - totalCollateralSatoshis.toUInt64.bytes ++ + satBytes(totalCollateralSatoshis) ++ fundingPubKey.bytes ++ TLV.encodeScript(payoutSPK) ++ - UInt16(fundingInputs.length).bytes ++ - fundingInputs.foldLeft(ByteVector.empty)(_ ++ _.bytes) ++ + u16PrefixedList(fundingInputs) ++ TLV.encodeScript(changeSPK) ++ cetSignatures.bytes ++ refundSignature.toRawRS @@ -1157,19 +1196,14 @@ object DLCAcceptTLV extends TLVFactory[DLCAcceptTLV] { override def fromTLVValue(value: ByteVector): DLCAcceptTLV = { val iter = ValueIterator(value) - val tempContractId = Sha256Digest(iter.take(32)) - val totalCollateralSatoshis = Satoshis(UInt64(iter.takeBits(64))) - val fundingPubKey = ECPublicKey(iter.take(33)) + val tempContractId = iter.take(Sha256Digest, 32) + val totalCollateralSatoshis = iter.takeSats() + val fundingPubKey = iter.take(ECPublicKey, 33) val payoutSPK = iter.takeSPK() - val numFundingInputs = UInt16(iter.takeBits(16)) - val fundingInputs = (0 until numFundingInputs.toInt).toVector.map { _ => - val fundingInput = FundingInputV0TLV.fromBytes(iter.current) - iter.skip(fundingInput) - fundingInput - } + val fundingInputs = + iter.takeU16PrefixedList(() => iter.take(FundingInputV0TLV)) val changeSPK = iter.takeSPK() - val cetSignatures = CETSignaturesV0TLV.fromBytes(iter.current) - iter.skip(cetSignatures) + val cetSignatures = iter.take(CETSignaturesV0TLV) val refundSignature = ECDigitalSignature.fromRS(iter.take(64)) DLCAcceptTLV(tempContractId, @@ -1206,11 +1240,9 @@ object DLCSignTLV extends TLVFactory[DLCSignTLV] { val iter = ValueIterator(value) val contractId = iter.take(32) - val cetSignatures = CETSignaturesV0TLV.fromBytes(iter.current) - iter.skip(cetSignatures) + val cetSignatures = iter.take(CETSignaturesV0TLV) val refundSignature = ECDigitalSignature.fromRS(iter.take(64)) - val fundingSignatures = FundingSignaturesV0TLV.fromBytes(iter.current) - iter.skip(fundingSignatures) + val fundingSignatures = iter.take(FundingSignaturesV0TLV) DLCSignTLV(contractId, cetSignatures, refundSignature, fundingSignatures) }