1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-20 02:27:32 +01:00

Support additional TLV records in SendPayentRequest (#1367)

* Support additional user defined TLVs when sending a payment (both single-part and MPP)

* Allow encoding and decoding of even TLV types above the high range
This commit is contained in:
araspitzu 2020-04-09 10:36:30 +02:00 committed by GitHub
parent 7866be11c3
commit 932f04851a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 57 additions and 22 deletions

View File

@ -304,6 +304,7 @@ object MultiPartPaymentLifecycle {
* @param assistedRoutes routing hints (usually from a Bolt 11 invoice).
* @param routeParams parameters to fine-tune the routing algorithm.
* @param additionalTlvs when provided, additional tlvs that will be added to the onion sent to the target node.
* @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node.
*/
case class SendMultiPartPayment(paymentSecret: ByteVector32,
targetNodeId: PublicKey,
@ -312,7 +313,8 @@ object MultiPartPaymentLifecycle {
maxAttempts: Int,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None,
additionalTlvs: Seq[OnionTlv] = Nil) {
additionalTlvs: Seq[OnionTlv] = Nil,
userCustomTlvs: Seq[GenericTlv] = Nil) {
require(totalAmount > 0.msat, s"total amount must be > 0")
}
@ -416,7 +418,7 @@ object MultiPartPaymentLifecycle {
private def createChildPayment(nodeParams: NodeParams, request: SendMultiPartPayment, childAmount: MilliSatoshi, channel: OutgoingChannel): SendPayment = {
SendPayment(
request.targetNodeId,
Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs),
Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs),
request.maxAttempts,
request.assistedRoutes,
request.routeParams,

View File

@ -54,13 +54,14 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR
case Some(invoice) if invoice.features.allowMultiPart && Features.hasFeature(nodeParams.features, Features.BasicMultiPartPayment) =>
invoice.paymentSecret match {
case Some(paymentSecret) =>
spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams)
spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs)
case None =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(PaymentSecretMissing) :: Nil)
}
case _ =>
// NB: we only generate legacy payment onions for now for maximum compatibility.
spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.maxAttempts, r.assistedRoutes, r.routeParams)
val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret)
val finalPayload = Onion.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.userCustomTlvs)
spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams)
}
case r: SendTrampolinePaymentRequest =>
@ -201,6 +202,7 @@ object PaymentInitiator {
* @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param assistedRoutes (optional) routing hints (usually from a Bolt 11 invoice).
* @param routeParams (optional) parameters to fine-tune the routing algorithm.
* @param userCustomTlvs (optional) user-defined custom tlvs that will be added to the onion sent to the target node.
*/
case class SendPaymentRequest(recipientAmount: MilliSatoshi,
paymentHash: ByteVector32,
@ -210,7 +212,8 @@ object PaymentInitiator {
paymentRequest: Option[PaymentRequest] = None,
externalId: Option[String] = None,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None) {
routeParams: Option[RouteParams] = None,
userCustomTlvs: Seq[GenericTlv] = Nil) {
// We add one block in order to not have our htlcs fail when a new block has just been found.
def finalExpiry(currentBlockHeight: Long) = finalExpiryDelta.toCltvExpiry(currentBlockHeight + 1)
}

View File

@ -276,17 +276,14 @@ object Onion {
NodeRelayPayload(TlvStream(tlvs2))
}
def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None): FinalPayload = paymentSecret match {
// We try to use the legacy format as much as possible for maximum compatibility, but when we have a payment secret we need to use TLV to include it.
case Some(paymentSecret) => FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount)))
def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload = paymentSecret match {
case Some(paymentSecret) => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount)), userCustomTlvs))
case None if userCustomTlvs.nonEmpty => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry)), userCustomTlvs))
case None => FinalLegacyPayload(amount, expiry)
}
def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, totalAmount)))
def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv]): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs))
def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs, userCustomTlvs))
/** Create a trampoline outer payload. */
def createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload = {

View File

@ -26,8 +26,9 @@ import scodec.{Attempt, Codec, Err}
/**
* Created by t-bast on 20/06/2019.
*/
object TlvCodecs {
// high range types are greater than or equal 2^16, see https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#type-length-value-format
private val TLV_TYPE_HIGH_RANGE = 65536
/**
* Truncated uint64 (0 to 8 bytes unsigned integer).
@ -104,7 +105,7 @@ object TlvCodecs {
val ltu16: Codec[Int] = variableSizeBytes(uint8, tu16)
private def validateGenericTlv(g: GenericTlv): Attempt[GenericTlv] = {
if (g.tag.toBigInt % 2 == 0) {
if (g.tag < TLV_TYPE_HIGH_RANGE && g.tag.toBigInt % 2 == 0) {
Attempt.Failure(Err("unknown even tlv type"))
} else {
Attempt.Successful(g)

View File

@ -31,9 +31,11 @@ import fr.acinq.eclair.payment.send.PaymentInitiator._
import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator}
import fr.acinq.eclair.router.{NodeHop, RouteParams}
import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
import fr.acinq.eclair.wire.{Onion, OnionCodecs, OnionTlv, TrampolineFeeInsufficient}
import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload}
import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv}
import fr.acinq.eclair.wire._
import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, NodeParams, TestConstants, randomBytes32, randomKey}
import fr.acinq.eclair.UInt64.Conversions._
import org.scalatest.{Outcome, Tag, fixture}
import scodec.bits.HexStringSyntax
@ -69,6 +71,19 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun
withFixture(test.toNoArgTest(FixtureParam(nodeParams, initiator, payFsm, multiPartPayFsm, sender, eventListener)))
}
test("forward payment with user custom tlv records") { f =>
import f._
val keySendTlvRecords = Seq(GenericTlv(5482373484L, paymentPreimage))
val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), userCustomTlvs = keySendTlvRecords)
sender.send(initiator, req)
sender.expectMsgType[UUID]
payFsm.expectMsgType[SendPaymentConfig]
val FinalTlvPayload(tlvs) = payFsm.expectMsgType[SendPayment].finalPayload
assert(tlvs.get[AmountToForward].get.amount == finalAmount)
assert(tlvs.get[OutgoingCltv].get.cltv == req.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1))
assert(tlvs.unknown == keySendTlvRecords)
}
test("reject payment with unknown mandatory feature") { f =>
import f._
val unknownFeature = 42
@ -105,14 +120,14 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun
payFsm.expectMsg(SendPayment(e, FinalLegacyPayload(finalAmount, Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)), 3))
}
test("forward legacy payment when multi-part deactivated", Tag("mpp_disabled")) { f =>
test("forward single-part payment when multi-part deactivated", Tag("mpp_disabled")) { f =>
import f._
val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, randomKey, "Some MPP invoice", features = Some(Features(VariableLengthOnion.optional, PaymentSecret.optional, BasicMultiPartPayment.optional)))
val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), Some(pr))
sender.send(initiator, req)
val id = sender.expectMsgType[UUID]
payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(pr), storeInDb = true, publishEvent = true, Nil))
payFsm.expectMsg(SendPayment(c, FinalLegacyPayload(finalAmount, req.finalExpiry(nodeParams.currentBlockHeight)), 1))
payFsm.expectMsg(SendPayment(c, FinalTlvPayload(TlvStream(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(req.finalExpiry(nodeParams.currentBlockHeight)), OnionTlv.PaymentData(pr.paymentSecret.get, finalAmount))), 1))
}
test("forward multi-part payment") { f =>

View File

@ -182,6 +182,15 @@ class OnionCodecsSpec extends FunSuite {
}
}
test("encode/decode variable-length (tlv) final per-hop payload with custom user records") {
val tlvs = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(5482373484L, hex"16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828")))
val bin = hex"31 02020231 04012a ff0000000146c6616c2016c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"
val encoded = finalPerHopPayloadCodec.encode(FinalTlvPayload(tlvs)).require.bytes
assert(encoded === bin)
assert(finalPerHopPayloadCodec.decode(bin.bits).require.value == FinalTlvPayload(tlvs))
}
test("decode multi-part final per-hop payload") {
val notMultiPart = finalPerHopPayloadCodec.decode(hex"07 02020231 04012a".bits).require.value
assert(notMultiPart.totalAmount === 561.msat)

View File

@ -229,9 +229,7 @@ class TlvCodecsSpec extends FunSuite {
hex"12 00",
hex"0a 00",
hex"fd0102 00",
hex"fe01000002 00",
hex"01020101 0a0101",
hex"ff0100000000000002 00",
// Invalid TestTlv1.
hex"01 01 00", // not minimally-encoded
hex"01 02 0001", // not minimally-encoded
@ -308,6 +306,16 @@ class TlvCodecsSpec extends FunSuite {
assert(lengthPrefixedTestTlvStreamCodec.encode(stream).require.toByteVector === hex"0f 01012a 0b012b 0d012a fd00fe02002a")
}
test("encode/decode custom even tlv records") {
val lowRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(124, hex"2a")))
val highRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(67876545678L, hex"2b")))
assert(testTlvStreamCodec.encode(lowRangeEven).isFailure)
assert(testTlvStreamCodec.encode(highRangeEven).isSuccessful)
assert(testTlvStreamCodec.decode(hex"7c 01 2a".toBitVector).isFailure) // lowRangeEven
assert(testTlvStreamCodec.decode(testTlvStreamCodec.encode(highRangeEven).require).isSuccessful)
}
test("encode invalid tlv stream") {
val testCases = Seq(
// Unknown even type.