From 932f04851a3f1e1a24e5196f1d51760dbdf59d41 Mon Sep 17 00:00:00 2001 From: araspitzu Date: Thu, 9 Apr 2020 10:36:30 +0200 Subject: [PATCH] Support additional TLV records in SendPayentRequest (#1367) * Support additional user defined TLVs when sending a payment (both single-part and MPP) * Allow encoding and decoding of even TLV types above the high range --- .../send/MultiPartPaymentLifecycle.scala | 6 +++-- .../payment/send/PaymentInitiator.scala | 11 +++++---- .../scala/fr/acinq/eclair/wire/Onion.scala | 13 ++++------- .../fr/acinq/eclair/wire/TlvCodecs.scala | 5 ++-- .../eclair/payment/PaymentInitiatorSpec.scala | 23 +++++++++++++++---- .../acinq/eclair/wire/OnionCodecsSpec.scala | 9 ++++++++ .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 12 ++++++++-- 7 files changed, 57 insertions(+), 22 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index dad62410c..e556a395f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -304,6 +304,7 @@ object MultiPartPaymentLifecycle { * @param assistedRoutes routing hints (usually from a Bolt 11 invoice). * @param routeParams parameters to fine-tune the routing algorithm. * @param additionalTlvs when provided, additional tlvs that will be added to the onion sent to the target node. + * @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node. */ case class SendMultiPartPayment(paymentSecret: ByteVector32, targetNodeId: PublicKey, @@ -312,7 +313,8 @@ object MultiPartPaymentLifecycle { maxAttempts: Int, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, routeParams: Option[RouteParams] = None, - additionalTlvs: Seq[OnionTlv] = Nil) { + additionalTlvs: Seq[OnionTlv] = Nil, + userCustomTlvs: Seq[GenericTlv] = Nil) { require(totalAmount > 0.msat, s"total amount must be > 0") } @@ -416,7 +418,7 @@ object MultiPartPaymentLifecycle { private def createChildPayment(nodeParams: NodeParams, request: SendMultiPartPayment, childAmount: MilliSatoshi, channel: OutgoingChannel): SendPayment = { SendPayment( request.targetNodeId, - Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs), + Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs), request.maxAttempts, request.assistedRoutes, request.routeParams, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 670934b46..7a5f37287 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -54,13 +54,14 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR case Some(invoice) if invoice.features.allowMultiPart && Features.hasFeature(nodeParams.features, Features.BasicMultiPartPayment) => invoice.paymentSecret match { case Some(paymentSecret) => - spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams) + spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs) case None => sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(PaymentSecretMissing) :: Nil) } case _ => - // NB: we only generate legacy payment onions for now for maximum compatibility. - spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.maxAttempts, r.assistedRoutes, r.routeParams) + val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret) + val finalPayload = Onion.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.userCustomTlvs) + spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams) } case r: SendTrampolinePaymentRequest => @@ -201,6 +202,7 @@ object PaymentInitiator { * @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB). * @param assistedRoutes (optional) routing hints (usually from a Bolt 11 invoice). * @param routeParams (optional) parameters to fine-tune the routing algorithm. + * @param userCustomTlvs (optional) user-defined custom tlvs that will be added to the onion sent to the target node. */ case class SendPaymentRequest(recipientAmount: MilliSatoshi, paymentHash: ByteVector32, @@ -210,7 +212,8 @@ object PaymentInitiator { paymentRequest: Option[PaymentRequest] = None, externalId: Option[String] = None, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - routeParams: Option[RouteParams] = None) { + routeParams: Option[RouteParams] = None, + userCustomTlvs: Seq[GenericTlv] = Nil) { // We add one block in order to not have our htlcs fail when a new block has just been found. def finalExpiry(currentBlockHeight: Long) = finalExpiryDelta.toCltvExpiry(currentBlockHeight + 1) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index 7b30d6a88..1689c9f18 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -276,17 +276,14 @@ object Onion { NodeRelayPayload(TlvStream(tlvs2)) } - def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None): FinalPayload = paymentSecret match { - // We try to use the legacy format as much as possible for maximum compatibility, but when we have a payment secret we need to use TLV to include it. - case Some(paymentSecret) => FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount))) + def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload = paymentSecret match { + case Some(paymentSecret) => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount)), userCustomTlvs)) + case None if userCustomTlvs.nonEmpty => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry)), userCustomTlvs)) case None => FinalLegacyPayload(amount, expiry) } - def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32): FinalPayload = - FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, totalAmount))) - - def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv]): FinalPayload = - FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs)) + def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload = + FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs, userCustomTlvs)) /** Create a trampoline outer payload. */ def createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 2a8dbb757..4c7fd4160 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -26,8 +26,9 @@ import scodec.{Attempt, Codec, Err} /** * Created by t-bast on 20/06/2019. */ - object TlvCodecs { + // high range types are greater than or equal 2^16, see https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#type-length-value-format + private val TLV_TYPE_HIGH_RANGE = 65536 /** * Truncated uint64 (0 to 8 bytes unsigned integer). @@ -104,7 +105,7 @@ object TlvCodecs { val ltu16: Codec[Int] = variableSizeBytes(uint8, tu16) private def validateGenericTlv(g: GenericTlv): Attempt[GenericTlv] = { - if (g.tag.toBigInt % 2 == 0) { + if (g.tag < TLV_TYPE_HIGH_RANGE && g.tag.toBigInt % 2 == 0) { Attempt.Failure(Err("unknown even tlv type")) } else { Attempt.Successful(g) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 2d03af6a6..99397db55 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -31,9 +31,11 @@ import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator} import fr.acinq.eclair.router.{NodeHop, RouteParams} -import fr.acinq.eclair.wire.Onion.FinalLegacyPayload -import fr.acinq.eclair.wire.{Onion, OnionCodecs, OnionTlv, TrampolineFeeInsufficient} +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload} +import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} +import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, NodeParams, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.UInt64.Conversions._ import org.scalatest.{Outcome, Tag, fixture} import scodec.bits.HexStringSyntax @@ -69,6 +71,19 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun withFixture(test.toNoArgTest(FixtureParam(nodeParams, initiator, payFsm, multiPartPayFsm, sender, eventListener))) } + test("forward payment with user custom tlv records") { f => + import f._ + val keySendTlvRecords = Seq(GenericTlv(5482373484L, paymentPreimage)) + val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), userCustomTlvs = keySendTlvRecords) + sender.send(initiator, req) + sender.expectMsgType[UUID] + payFsm.expectMsgType[SendPaymentConfig] + val FinalTlvPayload(tlvs) = payFsm.expectMsgType[SendPayment].finalPayload + assert(tlvs.get[AmountToForward].get.amount == finalAmount) + assert(tlvs.get[OutgoingCltv].get.cltv == req.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1)) + assert(tlvs.unknown == keySendTlvRecords) + } + test("reject payment with unknown mandatory feature") { f => import f._ val unknownFeature = 42 @@ -105,14 +120,14 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun payFsm.expectMsg(SendPayment(e, FinalLegacyPayload(finalAmount, Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)), 3)) } - test("forward legacy payment when multi-part deactivated", Tag("mpp_disabled")) { f => + test("forward single-part payment when multi-part deactivated", Tag("mpp_disabled")) { f => import f._ val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, randomKey, "Some MPP invoice", features = Some(Features(VariableLengthOnion.optional, PaymentSecret.optional, BasicMultiPartPayment.optional))) val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), Some(pr)) sender.send(initiator, req) val id = sender.expectMsgType[UUID] payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(pr), storeInDb = true, publishEvent = true, Nil)) - payFsm.expectMsg(SendPayment(c, FinalLegacyPayload(finalAmount, req.finalExpiry(nodeParams.currentBlockHeight)), 1)) + payFsm.expectMsg(SendPayment(c, FinalTlvPayload(TlvStream(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(req.finalExpiry(nodeParams.currentBlockHeight)), OnionTlv.PaymentData(pr.paymentSecret.get, finalAmount))), 1)) } test("forward multi-part payment") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index 32b9c41ac..60c892a0f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -182,6 +182,15 @@ class OnionCodecsSpec extends FunSuite { } } + test("encode/decode variable-length (tlv) final per-hop payload with custom user records") { + val tlvs = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(5482373484L, hex"16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"))) + val bin = hex"31 02020231 04012a ff0000000146c6616c2016c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828" + + val encoded = finalPerHopPayloadCodec.encode(FinalTlvPayload(tlvs)).require.bytes + assert(encoded === bin) + assert(finalPerHopPayloadCodec.decode(bin.bits).require.value == FinalTlvPayload(tlvs)) + } + test("decode multi-part final per-hop payload") { val notMultiPart = finalPerHopPayloadCodec.decode(hex"07 02020231 04012a".bits).require.value assert(notMultiPart.totalAmount === 561.msat) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 4cb20d5cf..2b7ca741f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -229,9 +229,7 @@ class TlvCodecsSpec extends FunSuite { hex"12 00", hex"0a 00", hex"fd0102 00", - hex"fe01000002 00", hex"01020101 0a0101", - hex"ff0100000000000002 00", // Invalid TestTlv1. hex"01 01 00", // not minimally-encoded hex"01 02 0001", // not minimally-encoded @@ -308,6 +306,16 @@ class TlvCodecsSpec extends FunSuite { assert(lengthPrefixedTestTlvStreamCodec.encode(stream).require.toByteVector === hex"0f 01012a 0b012b 0d012a fd00fe02002a") } + test("encode/decode custom even tlv records") { + val lowRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(124, hex"2a"))) + val highRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(67876545678L, hex"2b"))) + + assert(testTlvStreamCodec.encode(lowRangeEven).isFailure) + assert(testTlvStreamCodec.encode(highRangeEven).isSuccessful) + assert(testTlvStreamCodec.decode(hex"7c 01 2a".toBitVector).isFailure) // lowRangeEven + assert(testTlvStreamCodec.decode(testTlvStreamCodec.encode(highRangeEven).require).isSuccessful) + } + test("encode invalid tlv stream") { val testCases = Seq( // Unknown even type.