Add TLVs defined in BOLT 4 (#4380)

* Add TLVs defined in BOLT 4

* Fix test case
This commit is contained in:
benthecarman 2022-06-13 11:58:11 -05:00 committed by GitHub
parent b021649ac4
commit 344a8fd759
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 182 additions and 4 deletions

View file

@ -39,7 +39,7 @@ class LnMessageTest extends BitcoinSUnitTest {
Try(LnMessageFactory(InitTLV)(hex"00100000000001012a030104")).isSuccess)
assert(Try(LnMessageFactory(InitTLV)(hex"00100000000001")).isFailure)
assert(Try(LnMessageFactory(InitTLV)(hex"00100000000002012a")).isFailure)
assert(Try(LnMessageFactory(InitTLV)(hex"001000000000ca012a")).isFailure)
assert(
Try(LnMessageFactory(InitTLV)(hex"001000000000010101010102")).isFailure)
}

View file

@ -47,6 +47,34 @@ class TLVTest extends BitcoinSUnitTest {
}
}
"AmtToForwardTLV" must "have serialization symmetry" in {
forAll(TLVGen.amtToForwardTLV) { tlv =>
assert(AmtToForwardTLV(tlv.bytes) == tlv)
assert(TLV(tlv.bytes) == tlv)
}
}
"OutgoingCLTVValueTLV" must "have serialization symmetry" in {
forAll(TLVGen.outgoingCLTVValueTLV) { tlv =>
assert(OutgoingCLTVValueTLV(tlv.bytes) == tlv)
assert(TLV(tlv.bytes) == tlv)
}
}
"ShortChannelIdTLV" must "have serialization symmetry" in {
forAll(TLVGen.shortChannelIdTLV) { tlv =>
assert(ShortChannelIdTLV(tlv.bytes) == tlv)
assert(TLV(tlv.bytes) == tlv)
}
}
"PaymentDataTLV" must "have serialization symmetry" in {
forAll(TLVGen.paymentDataTLV) { tlv =>
assert(PaymentDataTLV(tlv.bytes) == tlv)
assert(TLV(tlv.bytes) == tlv)
}
}
"PingTLV" must "have serialization symmetry" in {
forAll(TLVGen.pingTLV) { ping =>
assert(PingTLV(ping.bytes) == ping)

View file

@ -66,6 +66,7 @@ sealed abstract class Number[T <: Number[T]]
def &(num: T): T = apply(underlying & num.underlying)
def unary_- : T = apply(-underlying)
def truncatedBytes: ByteVector = bytes.dropWhile(_ == 0x00)
}
/** Represents a signed number in our number system
@ -349,6 +350,10 @@ object UInt16
UInt16(bytes.toLong(signed = false, ordering = ByteOrdering.BigEndian))
}
def fromTruncatedBytes(bytes: ByteVector): UInt16 = {
fromBytes(bytes.padLeft(2))
}
def apply(long: Long): UInt16 = {
checkCached(long)
}
@ -393,6 +398,10 @@ object UInt32
UInt32(bytes.toLong(signed = false, ordering = ByteOrdering.BigEndian))
}
def fromTruncatedBytes(bytes: ByteVector): UInt32 = {
fromBytes(bytes.padLeft(4))
}
def apply(int: Int): UInt32 = {
apply(int.toLong)
}
@ -440,6 +449,10 @@ object UInt64
UInt64(NumberUtil.toUnsignedInt(bytes))
}
def fromTruncatedBytes(bytes: ByteVector): UInt64 = {
fromBytes(bytes.padLeft(8))
}
def apply(long: Long): UInt64 = {
checkCached(long)
}

View file

@ -43,7 +43,8 @@ object LnMessage extends Factory[LnMessage[TLV]] {
throw new IllegalArgumentException(s"Parsed unknown TLV $unknown")
case _: DLCSetupTLV | _: DLCSetupPieceTLV | _: InitTLV | _: DLCOracleTLV |
_: ErrorTLV | _: PingTLV | _: PongTLV | _: ContractInfoV0TLV |
_: ContractInfoV1TLV | _: PayoutCurvePieceTLV =>
_: ContractInfoV1TLV | _: PayoutCurvePieceTLV | _: AmtToForwardTLV |
_: OutgoingCLTVValueTLV | _: ShortChannelIdTLV | _: PaymentDataTLV =>
()
}

View file

@ -9,6 +9,9 @@ import org.bitcoins.core.protocol.dlc.models.{
OutcomePayoutPoint,
PiecewisePolynomialEndpoint
}
import org.bitcoins.core.protocol.ln.PaymentSecret
import org.bitcoins.core.protocol.ln.channel.ShortChannelId
import org.bitcoins.core.protocol.ln.currency.MilliSatoshis
import org.bitcoins.core.protocol.script._
import org.bitcoins.core.protocol.tlv.TLV.{
DecodeTLVResult,
@ -163,6 +166,10 @@ object TLV extends TLVParentFactory[TLV] {
Vector(
InitTLV,
ErrorTLV,
AmtToForwardTLV,
OutgoingCLTVValueTLV,
ShortChannelIdTLV,
PaymentDataTLV,
PingTLV,
PongTLV,
OracleEventV0TLV,
@ -297,6 +304,94 @@ object UnknownTLV extends Factory[UnknownTLV] {
}
}
case class AmtToForwardTLV(amt: MilliSatoshis) extends TLV {
override val tpe: BigSizeUInt = AmtToForwardTLV.tpe
override val value: ByteVector = {
amt.toUInt64.truncatedBytes
}
}
object AmtToForwardTLV extends TLVFactory[AmtToForwardTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(2)
override def fromTLVValue(value: ByteVector): AmtToForwardTLV = {
val uint64 = UInt64.fromTruncatedBytes(value)
val msat = MilliSatoshis(uint64.toBigInt)
AmtToForwardTLV(msat)
}
override val typeName: String = "AmtToForwardTLV"
}
case class OutgoingCLTVValueTLV(cltv: UInt32) extends TLV {
override val tpe: BigSizeUInt = OutgoingCLTVValueTLV.tpe
override val value: ByteVector = {
cltv.truncatedBytes
}
}
object OutgoingCLTVValueTLV extends TLVFactory[OutgoingCLTVValueTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(4)
override def fromTLVValue(value: ByteVector): OutgoingCLTVValueTLV = {
val iter = ValueIterator(value)
val cltv = UInt32.fromTruncatedBytes(iter.current)
OutgoingCLTVValueTLV(cltv)
}
override val typeName: String = "OutgoingCLTVValueTLV"
}
case class ShortChannelIdTLV(scid: ShortChannelId) extends TLV {
override val tpe: BigSizeUInt = ShortChannelIdTLV.tpe
override val value: ByteVector = scid.bytes
}
object ShortChannelIdTLV extends TLVFactory[ShortChannelIdTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(6)
override def fromTLVValue(value: ByteVector): ShortChannelIdTLV = {
val iter = ValueIterator(value)
val scid = iter.take(ShortChannelId, 8)
ShortChannelIdTLV(scid)
}
override val typeName: String = "ShortChannelIdTLV"
}
case class PaymentDataTLV(paymentSecret: PaymentSecret, msats: MilliSatoshis)
extends TLV {
override val tpe: BigSizeUInt = PaymentDataTLV.tpe
override val value: ByteVector = {
paymentSecret.bytes ++ msats.toUInt64.truncatedBytes
}
}
object PaymentDataTLV extends TLVFactory[PaymentDataTLV] {
override val tpe: BigSizeUInt = BigSizeUInt(8)
override def fromTLVValue(value: ByteVector): PaymentDataTLV = {
val iter = ValueIterator(value)
val secret = iter.take(PaymentSecret, 32)
val uint64 = UInt64.fromTruncatedBytes(iter.current)
val msat = MilliSatoshis(uint64.toBigInt)
PaymentDataTLV(secret, msat)
}
override val typeName: String = "PaymentDataTLV"
}
/** @see https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#the-init-message */
case class InitTLV(
globalFeatureBytes: ByteVector,

View file

@ -32,7 +32,9 @@ class DLCDataHandler(dlcWalletApi: DLCWalletApi, connectionHandler: ActorRef)
private def handleTLVMessage(lnMessage: LnMessage[TLV]): Future[Unit] = {
lnMessage.tlv match {
case msg @ (_: UnknownTLV | _: DLCOracleTLV | _: DLCSetupPieceTLV) =>
case msg @ (_: UnknownTLV | _: DLCOracleTLV | _: DLCSetupPieceTLV |
_: ShortChannelIdTLV | _: OutgoingCLTVValueTLV | _: AmtToForwardTLV |
_: PaymentDataTLV) =>
log.error(s"Received unhandled message $msg")
Future.unit
case _: InitTLV =>

View file

@ -551,7 +551,8 @@ object DLCParsingTestVector extends TestVectorParser[DLCParsingTestVector] {
DLCMessageTestVector(LnMessage(tlv), "oracle_attestment_v0", fields)
case _: UnknownTLV | _: ErrorTLV | _: PingTLV | _: PongTLV | _: InitTLV |
_: SendOfferTLV =>
_: SendOfferTLV | _: AmtToForwardTLV | _: OutgoingCLTVValueTLV |
_: PaymentDataTLV | _: ShortChannelIdTLV =>
throw new IllegalArgumentException(
s"DLCParsingTestVector is only defined for DLC messages and TLVs, got $tlv")
}

View file

@ -13,6 +13,7 @@ import org.bitcoins.core.util.sorted._
import org.bitcoins.core.wallet.fee.SatoshisPerVirtualByte
import org.bitcoins.crypto.ECPrivateKey
import org.bitcoins.testkitcore.dlc.DLCTestUtil
import org.bitcoins.testkitcore.gen.ln._
import org.scalacheck.Gen
trait TLVGen {
@ -53,6 +54,39 @@ trait TLVGen {
}
}
def amtToForwardTLV: Gen[AmtToForwardTLV] = {
for {
msat <- LnCurrencyUnitGen.milliSatoshis
} yield {
AmtToForwardTLV(msat)
}
}
def outgoingCLTVValueTLV: Gen[OutgoingCLTVValueTLV] = {
for {
uint32 <- NumberGenerator.uInt32s
} yield {
OutgoingCLTVValueTLV(uint32)
}
}
def shortChannelIdTLV: Gen[ShortChannelIdTLV] = {
for {
scid <- LnRouteGen.shortChannelId
} yield {
ShortChannelIdTLV(scid)
}
}
def paymentDataTLV: Gen[PaymentDataTLV] = {
for {
secretTag <- LnInvoiceGen.secret
msat <- LnCurrencyUnitGen.milliSatoshis
} yield {
PaymentDataTLV(secretTag.secret, msat)
}
}
def pingTLV: Gen[PingTLV] = {
for {
num <- NumberGenerator.uInt16
@ -568,6 +602,10 @@ trait TLVGen {
errorTLV,
pingTLV,
pongTLV,
amtToForwardTLV,
outgoingCLTVValueTLV,
shortChannelIdTLV,
paymentDataTLV,
oracleEventV0TLV,
eventDescriptorTLV,
oracleAnnouncementV0TLV,