From ca831df241b2a723f7d9037db380509286eed3c6 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier <31281497+t-bast@users.noreply.github.com> Date: Fri, 16 Dec 2022 10:44:34 +0100 Subject: [PATCH] Add a payment recipient abstraction (#2480) We were previously directly creating onion payloads inside the various payment state machines and manipulating tlv fields. This was a layering violation that was somewhat ok because in most cases we only needed to create the onion payload for the recipient at the beginning of the payment flow and didn't need to modify it, except for a small change in the MPP case. This forced us to handle trampoline onions directly in the payment initiator and will not work for blinded payments, where we can only build the onion payload for the recipient after we've chosen the routes and how to split the amount. We clean this up by introducing payment recipients that abstract away the creation of onion payloads. This makes it much easier to integrate blinded payments. It also allows us to clean up the way we do trampoline payments and potentially support splitting across multiple trampoline routes (not included in this PR as this change isn't immediately needed). It also lets us simplify the MultiPartPaymentLifecycle FSM, by moving the logic of computing how much remains to be sent and what fee can be used to the route calculation component. --- docs/TrampolinePayments.md | 14 +- docs/release-notes/eclair-vnext.md | 1 + .../main/scala/fr/acinq/eclair/Eclair.scala | 25 +- .../acinq/eclair/json/JsonSerializers.scala | 46 +- .../acinq/eclair/payment/Bolt11Invoice.scala | 43 +- .../acinq/eclair/payment/Bolt12Invoice.scala | 5 +- .../fr/acinq/eclair/payment/Invoice.scala | 68 +-- .../acinq/eclair/payment/PaymentEvents.scala | 35 +- .../acinq/eclair/payment/PaymentPacket.scala | 149 ++----- .../payment/receive/MultiPartHandler.scala | 11 +- .../eclair/payment/relay/NodeRelay.scala | 21 +- .../send/MultiPartPaymentLifecycle.scala | 131 ++---- .../payment/send/PaymentInitiator.scala | 205 ++++----- .../payment/send/PaymentLifecycle.scala | 219 +++++---- .../acinq/eclair/payment/send/Recipient.scala | 170 +++++++ .../scala/fr/acinq/eclair/router/Graph.scala | 18 +- .../eclair/router/RouteCalculation.scala | 104 +++-- .../scala/fr/acinq/eclair/router/Router.scala | 90 ++-- .../eclair/wire/protocol/PaymentOnion.scala | 24 +- .../fr/acinq/eclair/EclairImplSpec.scala | 7 +- .../fr/acinq/eclair/channel/FuzzySpec.scala | 12 +- .../ChannelStateTestsHelperMethods.scala | 17 +- .../channel/states/f/ShutdownStateSpec.scala | 16 +- .../fr/acinq/eclair/db/AuditDbSpec.scala | 3 +- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 4 +- .../integration/PaymentIntegrationSpec.scala | 80 +++- .../ZeroConfAliasIntegrationSpec.scala | 4 +- .../eclair/payment/MultiPartHandlerSpec.scala | 76 ++-- .../MultiPartPaymentLifecycleSpec.scala | 291 ++++++------ .../eclair/payment/PaymentInitiatorSpec.scala | 216 ++++----- .../eclair/payment/PaymentLifecycleSpec.scala | 168 +++---- .../eclair/payment/PaymentPacketSpec.scala | 420 +++++++++--------- .../payment/PostRestartHtlcCleanerSpec.scala | 13 +- .../payment/relay/NodeRelayerSpec.scala | 72 ++- .../eclair/payment/relay/RelayerSpec.scala | 59 ++- .../eclair/router/BalanceEstimateSpec.scala | 4 +- .../acinq/eclair/router/BaseRouterSpec.scala | 11 +- .../router/BlindedRouteCreationSpec.scala | 10 +- .../eclair/router/RouteCalculationSpec.scala | 65 +-- .../fr/acinq/eclair/router/RouterSpec.scala | 199 ++++++--- .../acinq/eclair/api/handlers/Payment.scala | 18 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 22 +- 42 files changed, 1664 insertions(+), 1502 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala diff --git a/docs/TrampolinePayments.md b/docs/TrampolinePayments.md index 62ffa3b1c..90a7f8bd2 100644 --- a/docs/TrampolinePayments.md +++ b/docs/TrampolinePayments.md @@ -4,7 +4,7 @@ Eclair started supporting [trampoline payments](https://github.com/lightning/bol It is disabled by default, as it is still being reviewed for spec acceptance. However, if you want to experiment with it, here is what you can do. -First of all, you need to activate the feature for any node that will act as s trampoline node. Update your `eclair.conf` with the following values: +First of all, you need to activate the feature for any node that will act as a trampoline node. Update your `eclair.conf` with the following values: ```conf eclair.trampoline-payments-enable=true @@ -24,12 +24,12 @@ Where Bob is a trampoline node and Alice, Carol and Dave are "normal" nodes. Let's imagine that Dave has generated an MPP invoice for 400000 msat: `lntb1500n1pwxx94fp...`. Alice wants to pay that invoice using Bob as a trampoline. -To spice things up, Alice will use MPP between Bob and her, splitting the payment in two parts. +To spice things up, Alice will use MPP between Bob and herself, splitting the payment in two parts. Initiate the payment by sending the first part: ```sh -eclair-cli sendtoroute --amountMsat=150000 --nodeIds=$ALICE_ID,$BOB_ID --trampolineNodes=$BOB_ID,$DAVE_ID --trampolineFeesMsat=100000 --trampolineCltvExpiry=450 --finalCltvExpiry=16 --invoice=lntb1500n1pwxx94fp... +eclair-cli sendtoroute --amountMsat=150000 --nodeIds=$ALICE_ID,$BOB_ID --trampolineFeesMsat=10000 --trampolineCltvExpiry=450 --finalCltvExpiry=16 --invoice=lntb1500n1pwxx94fp... ``` Note the `trampolineFeesMsat` and `trampolineCltvExpiry`. At the moment you have to estimate those yourself. If the values you provide are too low, Bob will send an error and you can retry with higher values. In future versions, we will automatically fill those values for you. @@ -51,12 +51,15 @@ The `trampolineSecret` is also important: this is what prevents a malicious tram Now that you have those, you can send the second part: ```sh -eclair-cli sendtoroute --amountMsat=250000 --parentId=cd083b31-5939-46ac-bf90-8ac5b286a9e2 --trampolineSecret=9e13d1b602496871bb647b48e8ff8f15a91c07affb0a3599e995d470ac488715 --nodeIds=$ALICE_ID,$BOB_ID --trampolineNodes=$BOB_ID,$DAVE_ID --trampolineFeesMsat=100000 --trampolineCltvExpiry=450 --finalCltvExpiry=16 --invoice=lntb1500n1pwxx94fp... +eclair-cli sendtoroute --amountMsat=260000 --parentId=cd083b31-5939-46ac-bf90-8ac5b286a9e2 --trampolineSecret=9e13d1b602496871bb647b48e8ff8f15a91c07affb0a3599e995d470ac488715 --nodeIds=$ALICE_ID,$BOB_ID --trampolineFeesMsat=10000 --trampolineCltvExpiry=450 --finalCltvExpiry=16 --invoice=lntb1500n1pwxx94fp... ``` Note that Alice didn't need to know about Carol. Bob will find the route to Dave through Carol on his own. That's the magic of trampoline! -A couple gotchas: you need to make sure you specify the same `trampolineFeesMsat` and `trampolineCltvExpiry` as the first part. This is something we will improve if our users ask for a better API. +A couple gotchas: + +- you need to make sure you specify the same `trampolineFeesMsat` and `trampolineCltvExpiry` as the first part +- the total `amountMsat` sent need to cover the `trampolineFeesMsat` specified You can then check the status of the payment with the `getsentinfo` command: @@ -65,4 +68,3 @@ eclair-cli getsentinfo --id=cd083b31-5939-46ac-bf90-8ac5b286a9e2 ``` Once Dave accepts the payment you should see all the details about the payment success (preimage, route, fees, etc). - \ No newline at end of file diff --git a/docs/release-notes/eclair-vnext.md b/docs/release-notes/eclair-vnext.md index 765982205..4ed7fff8c 100644 --- a/docs/release-notes/eclair-vnext.md +++ b/docs/release-notes/eclair-vnext.md @@ -9,6 +9,7 @@ ### API changes - `audit` now accepts `--count` and `--skip` parameters to limit the number of retrieved items (#2474, #2487) +- `sendtoroute` removes the `--trampolineNodes` argument and implicitly uses a single trampoline hop (#2480) ### Miscellaneous improvements and bug fixes diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index c985d0e93..ed4f63a02 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -41,6 +41,7 @@ import fr.acinq.eclair.message.{OnionMessages, Postman} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveStandardPayment import fr.acinq.eclair.payment.relay.Relayer.{ChannelBalance, GetOutgoingChannels, OutgoingChannels, RelayFees} +import fr.acinq.eclair.payment.send.ClearRecipient import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.PreimageReceived import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.router.Router @@ -125,7 +126,7 @@ trait Eclair { def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], extraEdges: Seq[Invoice.ExtraEdge] = Seq.empty, includeLocalChannelCost: Boolean = false, ignoreNodeIds: Seq[PublicKey] = Seq.empty, ignoreShortChannelIds: Seq[ShortChannelId] = Seq.empty, maxFee_opt: Option[MilliSatoshi] = None)(implicit timeout: Timeout): Future[RouteResponse] - def sendToRoute(amount: MilliSatoshi, recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: Bolt11Invoice, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32] = None, trampolineFees_opt: Option[MilliSatoshi] = None, trampolineExpiryDelta_opt: Option[CltvExpiryDelta] = None, trampolineNodes_opt: Seq[PublicKey] = Nil)(implicit timeout: Timeout): Future[SendPaymentToRouteResponse] + def sendToRoute(recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: Bolt11Invoice, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32] = None, trampolineFees_opt: Option[MilliSatoshi] = None, trampolineExpiryDelta_opt: Option[CltvExpiryDelta] = None)(implicit timeout: Timeout): Future[SendPaymentToRouteResponse] def audit(from: TimestampSecond, to: TimestampSecond, paginated_opt: Option[Paginated])(implicit timeout: Timeout): Future[AuditResponse] @@ -312,30 +313,36 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { override def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], extraEdges: Seq[Invoice.ExtraEdge] = Seq.empty, includeLocalChannelCost: Boolean = false, ignoreNodeIds: Seq[PublicKey] = Seq.empty, ignoreShortChannelIds: Seq[ShortChannelId] = Seq.empty, maxFee_opt: Option[MilliSatoshi] = None)(implicit timeout: Timeout): Future[RouteResponse] = { getRouteParams(pathFindingExperimentName_opt) match { case Right(routeParams) => - val maxFee = maxFee_opt.getOrElse(routeParams.getMaxFee(amount)) + val target = ClearRecipient(targetNodeId, Features.empty, amount, CltvExpiry(appKit.nodeParams.currentBlockHeight), ByteVector32.Zeroes, extraEdges) + val routeParams1 = routeParams.copy( + includeLocalChannelCost = includeLocalChannelCost, + boundaries = routeParams.boundaries.copy( + maxFeeFlat = maxFee_opt.getOrElse(routeParams.boundaries.maxFeeFlat), + maxFeeProportional = maxFee_opt.map(_ => 0.0).getOrElse(routeParams.boundaries.maxFeeProportional) + ) + ) for { ignoredChannels <- getChannelDescs(ignoreShortChannelIds.toSet) ignore = Ignore(ignoreNodeIds.toSet, ignoredChannels) - response <- (appKit.router ? RouteRequest(sourceNodeId, targetNodeId, amount, maxFee, extraEdges, ignore = ignore, routeParams = routeParams.copy(includeLocalChannelCost = includeLocalChannelCost))).mapTo[RouteResponse] + response <- (appKit.router ? RouteRequest(sourceNodeId, target, routeParams1, ignore)).mapTo[RouteResponse] } yield response case Left(t) => Future.failed(t) } } - override def sendToRoute(amount: MilliSatoshi, recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: Bolt11Invoice, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32], trampolineFees_opt: Option[MilliSatoshi], trampolineExpiryDelta_opt: Option[CltvExpiryDelta], trampolineNodes_opt: Seq[PublicKey])(implicit timeout: Timeout): Future[SendPaymentToRouteResponse] = { - val recipientAmount = recipientAmount_opt.getOrElse(invoice.amount_opt.getOrElse(amount)) - val sendPayment = SendPaymentToRoute(amount, recipientAmount, invoice, route, externalId_opt, parentId_opt, trampolineSecret_opt, trampolineFees_opt.getOrElse(0 msat), trampolineExpiryDelta_opt.getOrElse(CltvExpiryDelta(0)), trampolineNodes_opt) + override def sendToRoute(recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: Bolt11Invoice, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32], trampolineFees_opt: Option[MilliSatoshi], trampolineExpiryDelta_opt: Option[CltvExpiryDelta])(implicit timeout: Timeout): Future[SendPaymentToRouteResponse] = { if (invoice.isExpired()) { Future.failed(new IllegalArgumentException("invoice has expired")) } else if (route.isEmpty) { Future.failed(new IllegalArgumentException("missing payment route")) } else if (externalId_opt.exists(_.length > externalIdMaxLength)) { Future.failed(new IllegalArgumentException(s"externalId is too long: cannot exceed $externalIdMaxLength characters")) - } else if (trampolineNodes_opt.nonEmpty && (trampolineFees_opt.isEmpty || trampolineExpiryDelta_opt.isEmpty)) { + } else if (trampolineFees_opt.nonEmpty && trampolineExpiryDelta_opt.isEmpty) { Future.failed(new IllegalArgumentException("trampoline payments must specify a trampoline fee and cltv delta")) - } else if (trampolineNodes_opt.nonEmpty && trampolineNodes_opt.length != 2) { - Future.failed(new IllegalArgumentException("trampoline payments currently only support paying a trampoline node via a single other trampoline node")) } else { + val recipientAmount = recipientAmount_opt.getOrElse(invoice.amount_opt.getOrElse(route.amount)) + val trampoline_opt = trampolineFees_opt.map(fees => TrampolineAttempt(trampolineSecret_opt.getOrElse(randomBytes32()), fees, trampolineExpiryDelta_opt.get)) + val sendPayment = SendPaymentToRoute(recipientAmount, invoice, route, externalId_opt, parentId_opt, trampoline_opt) (appKit.paymentInitiator ? sendPayment).mapTo[SendPaymentToRouteResponse] } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index dfd05bb85..4fb7ea15b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.payment.PaymentFailure.PaymentFailedSummary import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelRelayParams, Route} +import fr.acinq.eclair.router.Router.{HopRelayParams, NodeHop, Route} import fr.acinq.eclair.transactions.DirectedHtlc import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.blindedRouteCodec @@ -294,30 +294,48 @@ object ColorSerializer extends MinimalSerializer({ }) // @formatter:off -private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: ChannelRelayParams) -private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[ChannelHopJson]) -object RouteFullSerializer extends ConvertClassSerializer[Route](route => RouteFullJson(route.amount, route.hops.map(h => ChannelHopJson(h.nodeId, h.nextNodeId, h.params)))) +private sealed trait HopJson +private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: HopRelayParams) extends HopJson +private case class NodeHopJson(nodeId: PublicKey, nextNodeId: PublicKey, fee: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta) extends HopJson +private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[HopJson]) +object RouteFullSerializer extends ConvertClassSerializer[Route](route => { + val channelHops = route.hops.map(h => ChannelHopJson(h.nodeId, h.nextNodeId, h.params)) + val finalHop_opt = route.finalHop_opt.map { + case h: NodeHop => NodeHopJson(h.nodeId, h.nextNodeId, h.fee, h.cltvExpiryDelta) + } + RouteFullJson(route.amount, channelHops ++ finalHop_opt.toSeq) +}) private case class RouteNodeIdsJson(amount: MilliSatoshi, nodeIds: Seq[PublicKey]) object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { - val nodeIds = route.hops match { - case rest :+ last => rest.map(_.nodeId) :+ last.nodeId :+ last.nextNodeId - case Nil => Nil + val channelNodeIds = route.hops.headOption match { + case Some(hop) => Seq(hop.nodeId, hop.nextNodeId) ++ route.hops.tail.map(_.nextNodeId) + case None => Nil } - RouteNodeIdsJson(route.amount, nodeIds) + val finalNodeIds = route.finalHop_opt match { + case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId) + case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId) + case None => Nil + } + RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds) }) -private case class RouteShortChannelIdsJson(amount: MilliSatoshi, shortChannelIds: Seq[ShortChannelId]) -object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](route => RouteShortChannelIdsJson(route.amount, route.hops.map(_.shortChannelId))) +private case class RouteShortChannelIdsJson(amount: MilliSatoshi, shortChannelIds: Seq[ShortChannelId], finalHop: Option[String]) +object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](route => { + val hops = route.hops.map(_.shortChannelId) + val finalHop = route.finalHop_opt.map { + case _: NodeHop => "trampoline" + } + RouteShortChannelIdsJson(route.amount, hops, finalHop) +}) // @formatter:on // @formatter:off private case class PaymentFailureSummaryJson(amount: MilliSatoshi, route: Seq[PublicKey], message: String) -private case class PaymentFailedSummaryJson(paymentHash: ByteVector32, destination: PublicKey, totalAmount: MilliSatoshi, pathFindingExperiment: String, failures: Seq[PaymentFailureSummaryJson]) +private case class PaymentFailedSummaryJson(paymentHash: ByteVector32, destination: PublicKey, pathFindingExperiment: String, failures: Seq[PaymentFailureSummaryJson]) object PaymentFailedSummarySerializer extends ConvertClassSerializer[PaymentFailedSummary](p => PaymentFailedSummaryJson( p.cfg.paymentHash, p.cfg.recipientNodeId, - p.cfg.recipientAmount, p.pathFindingExperiment, p.paymentFailed.failures.map(f => { val route = f.route.map(_.nodeId) ++ f.route.lastOption.map(_.nextNodeId) @@ -512,8 +530,8 @@ object CustomTypeHints { )) val channelSources: CustomTypeHints = CustomTypeHints(Map( - classOf[ChannelRelayParams.FromAnnouncement] -> "announcement", - classOf[ChannelRelayParams.FromHint] -> "hint" + classOf[HopRelayParams.FromAnnouncement] -> "announcement", + classOf[HopRelayParams.FromHint] -> "hint" )) val channelStates: ShortTypeHints = ShortTypeHints( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala index 709937491..3ec8f132f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala @@ -549,53 +549,12 @@ object Bolt11Invoice { signature = bolt11Data.signature) } - private def readBoltData(input: String): Bolt11Data = { - val lowercaseInput = input.toLowerCase - val separatorIndex = lowercaseInput.lastIndexOf('1') - val hrp = lowercaseInput.take(separatorIndex) - if (!prefixes.values.exists(prefix => hrp.startsWith(prefix))) throw new RuntimeException("unknown prefix") - val data = string2Bits(lowercaseInput.slice(separatorIndex + 1, lowercaseInput.length - 6)) // 6 == checksum size - Codecs.bolt11DataCodec.decode(data).require.value - } - - /** - * Extracts the description from a serialized invoice that is **expected to be valid**. - * Throws an error if the invoice is not valid. - * - * @param input valid serialized invoice - * @return description as a String. If the description is a hash, returns the hash value as a String. - */ - def fastReadDescription(input: String): String = { - readBoltData(input).taggedFields.collectFirst { - case Bolt11Invoice.Description(d) => d - case Bolt11Invoice.DescriptionHash(h) => h.toString() - }.get - } - - /** - * Checks if a serialized invoice is expired. Timestamp is compared to the System's current time. - * - * @param input valid serialized invoice - * @return true if the invoice has expired, false otherwise. - */ - def fastHasExpired(input: String): Boolean = { - val bolt11Data = readBoltData(input) - val expiry_opt = bolt11Data.taggedFields.collectFirst { - case p: Bolt11Invoice.Expiry => p - } - val timestamp = bolt11Data.timestamp - expiry_opt match { - case Some(expiry) => timestamp + expiry.toLong <= TimestampSecond.now() - case None => timestamp + DEFAULT_EXPIRY_SECONDS <= TimestampSecond.now() - } - } - def toExtraEdges(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey): Seq[Invoice.ExtraEdge] = { // BOLT 11: "For each entry, the pubkey is the node ID of the start of the channel", and the last node is the destination val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId extraRoute.zip(nextNodeIds).map { case (extraHop, nextNodeId) => - Invoice.BasicEdge(extraHop.nodeId, nextNodeId, extraHop.shortChannelId, extraHop.feeBase, extraHop.feeProportionalMillionths, extraHop.cltvExpiryDelta) + Invoice.ExtraEdge(extraHop.nodeId, nextNodeId, extraHop.shortChannelId, extraHop.feeBase, extraHop.feeProportionalMillionths, extraHop.cltvExpiryDelta, 1 msat, None) } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala index 5905fc245..b74f6d218 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala @@ -24,7 +24,7 @@ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.{OfferCodecs, OfferTypes, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, randomBytes32} +import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64} import scodec.bits.ByteVector import java.util.concurrent.TimeUnit @@ -43,10 +43,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { override val amount_opt: Option[MilliSatoshi] = Some(amount) override val nodeId: Crypto.PublicKey = records.get[NodeId].get.publicKey override val paymentHash: ByteVector32 = records.get[PaymentHash].get.hash - override val paymentSecret: ByteVector32 = randomBytes32() - override val paymentMetadata: Option[ByteVector] = None override val description: Either[String, ByteVector32] = Left(records.get[Description].get.description) - override val extraEdges: Seq[Invoice.ExtraEdge] = Seq.empty // TODO: the blinded paths need to be converted to graph edges override val createdAt: TimestampSecond = records.get[CreatedAt].get.timestamp override val relativeExpiry: FiniteDuration = FiniteDuration(records.get[RelativeExpiry].map(_.seconds).getOrElse(DEFAULT_EXPIRY_SECONDS), TimeUnit.SECONDS) override val minFinalCltvExpiryDelta: CltvExpiryDelta = records.get[Cltv].map(_.minFinalCltvExpiry).getOrElse(DEFAULT_MIN_FINAL_EXPIRY_DELTA) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala index 0dccf730c..74bd49451 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala @@ -20,65 +20,45 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.wire.protocol.ChannelUpdate -import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond} -import scodec.bits.ByteVector +import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, ShortChannelId, TimestampSecond} import scala.concurrent.duration.FiniteDuration import scala.util.Try trait Invoice { - val amount_opt: Option[MilliSatoshi] - - val createdAt: TimestampSecond - - val nodeId: PublicKey - - val paymentHash: ByteVector32 - - val paymentSecret: ByteVector32 - - val paymentMetadata: Option[ByteVector] - - val description: Either[String, ByteVector32] - - val extraEdges: Seq[Invoice.ExtraEdge] - - val relativeExpiry: FiniteDuration - - val minFinalCltvExpiryDelta: CltvExpiryDelta - - val features: Features[InvoiceFeature] - - def isExpired(): Boolean = createdAt + relativeExpiry.toSeconds <= TimestampSecond.now() - + // @formatter:off + def nodeId: PublicKey + def amount_opt: Option[MilliSatoshi] + def createdAt: TimestampSecond + def paymentHash: ByteVector32 + def description: Either[String, ByteVector32] + def relativeExpiry: FiniteDuration + def minFinalCltvExpiryDelta: CltvExpiryDelta + def features: Features[InvoiceFeature] + def isExpired(now: TimestampSecond = TimestampSecond.now()): Boolean = createdAt + relativeExpiry.toSeconds <= now def toString: String + // @formatter:on } object Invoice { /** An extra edge that can be used to pay a given invoice and may not be part of the public graph. */ - sealed trait ExtraEdge { - // @formatter:off - def sourceNodeId: PublicKey - def feeBase: MilliSatoshi - def feeProportionalMillionths: Long - def cltvExpiryDelta: CltvExpiryDelta - def htlcMinimum: MilliSatoshi - def htlcMaximum_opt: Option[MilliSatoshi] - def relayFees: Relayer.RelayFees = Relayer.RelayFees(feeBase = feeBase, feeProportionalMillionths = feeProportionalMillionths) - // @formatter:on - } - - /** A normal graph edge, that should be handled exactly like public graph edges. */ - case class BasicEdge(sourceNodeId: PublicKey, + case class ExtraEdge(sourceNodeId: PublicKey, targetNodeId: PublicKey, shortChannelId: ShortChannelId, feeBase: MilliSatoshi, feeProportionalMillionths: Long, - cltvExpiryDelta: CltvExpiryDelta) extends ExtraEdge { - override val htlcMinimum: MilliSatoshi = 0 msat - override val htlcMaximum_opt: Option[MilliSatoshi] = None + cltvExpiryDelta: CltvExpiryDelta, + htlcMinimum: MilliSatoshi, + htlcMaximum_opt: Option[MilliSatoshi]) { + val relayFees = Relayer.RelayFees(feeBase, feeProportionalMillionths) - def update(u: ChannelUpdate): BasicEdge = copy(feeBase = u.feeBaseMsat, feeProportionalMillionths = u.feeProportionalMillionths, cltvExpiryDelta = u.cltvExpiryDelta) + def update(u: ChannelUpdate): ExtraEdge = copy( + feeBase = u.feeBaseMsat, + feeProportionalMillionths = u.feeProportionalMillionths, + cltvExpiryDelta = u.cltvExpiryDelta, + htlcMinimum = u.htlcMinimumMsat, + htlcMaximum_opt = Some(u.htlcMaximumMsat) + ) } def fromString(input: String): Try[Invoice] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 925445618..98a579785 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -19,9 +19,10 @@ package fr.acinq.eclair.payment import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.Invoice.{BasicEdge, ExtraEdge} +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig +import fr.acinq.eclair.payment.send.{ClearRecipient, Recipient} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop, Ignore} import fr.acinq.eclair.wire.protocol.{ChannelDisabled, ChannelUpdate, Node, TemporaryChannelFailure} @@ -234,21 +235,23 @@ object PaymentFailure { failures.foldLeft(ignore) { case (current, failure) => updateIgnored(failure, current) } } - /** Update the invoice routing hints based on more recent channel updates received. */ - def updateExtraEdges(failures: Seq[PaymentFailure], extraEdges: Seq[ExtraEdge]): Seq[ExtraEdge] = { - // We're only interested in the last channel update received per channel. - val updates = failures.foldLeft(Map.empty[ShortChannelId, ChannelUpdate]) { - case (current, failure) => failure match { - case RemoteFailure(_, _, Sphinx.DecryptedFailurePacket(_, f: Update)) => current.updated(f.update.shortChannelId, f.update) - case _ => current - } - } - extraEdges.map { - case edge: BasicEdge => updates.get(edge.shortChannelId) match { - case Some(u) => edge.update(u) - case None => edge - } - case edge => edge + /** Update the recipient routing hints based on more recent data received. */ + def updateExtraEdges(failures: Seq[PaymentFailure], recipient: Recipient): Recipient = { + recipient match { + case r: ClearRecipient => + // We're only interested in the last channel update received per channel. + val updates = failures.foldLeft(Map.empty[ShortChannelId, ChannelUpdate]) { + case (current, failure) => failure match { + case RemoteFailure(_, _, Sphinx.DecryptedFailurePacket(_, f: Update)) => current.updated(f.update.shortChannelId, f.update) + case _ => current + } + } + val extraEdges1 = r.extraEdges.map(edge => updates.get(edge.shortChannelId) match { + case Some(u) => edge.update(u) + case None => edge + }) + r.copy(extraEdges = extraEdges1) + case r => r } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index a6af9b4d2..a063d74c7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -22,15 +22,16 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractSharedSecret, Origin} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} +import fr.acinq.eclair.payment.send.Recipient +import fr.acinq.eclair.router.Router.Route import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey} import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} import java.util.UUID -import scala.util.Try +import scala.util.{Failure, Success} /** * Created by t-bast on 08/10/2019. @@ -196,7 +197,7 @@ object IncomingPaymentPacket { case innerPayload => // We merge contents from the outer and inner payloads. // We must use the inner payload's total amount and payment secret because the payment may be split between multiple trampoline payments (#reckless). - Right(FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, innerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata))) + Right(FinalPacket(add, FinalPayload.Standard.createPayload(outerPayload.amount, innerPayload.totalAmount, innerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata))) } } } @@ -214,99 +215,26 @@ object IncomingPaymentPacket { } +/** + * @param cmd command to send the HTLC for this payment. + * @param outgoingChannel channel to send the HTLC to. + * @param sharedSecrets shared secrets (used to decrypt the error in case of payment failure). + */ +case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[(ByteVector32, PublicKey)]) + /** Helpers to create outgoing payment packets. */ object OutgoingPaymentPacket { - /** - * Build an encrypted onion packet from onion payloads and node public keys. - */ - private def buildOnion(packetPayloadLength: Int, nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Try[Sphinx.PacketAndSecrets] = { - require(nodes.size == payloads.size) - val sessionKey = randomKey() - val payloadsBin: Seq[ByteVector] = payloads - .map(p => PaymentOnionCodecs.perHopPayloadCodec.encode(p.records)) - .map { - case Attempt.Successful(bits) => bits.bytes - case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") - } - Sphinx.create(sessionKey, packetPayloadLength, nodes, payloadsBin, Some(associatedData)) - } - - /** - * Build the onion payloads for each hop. - * - * @param hops the hops as computed by the router + extra routes from the invoice - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, payloads) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - a sequence of payloads that will be used to build the onion - */ - def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = { - hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) { - case ((amount, expiry, payloads), hop) => - val payload = hop match { - case hop: ChannelHop => IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, amount, expiry) - case hop: NodeHop => IntermediatePayload.NodeRelay.Standard(amount, expiry, hop.nextNodeId) - } - (amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, payload +: payloads) - } - } - - /** - * Build an encrypted onion packet with the given final payload. - * - * @param hops the hops as computed by the router + extra routes from the invoice, including ourselves in the first hop - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - the onion to include in the HTLC - */ - private def buildPacket(packetPayloadLength: Int, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) - val nodes = hops.map(_.nextNodeId) - // BOLT 2 requires that associatedData == paymentHash - buildOnion(packetPayloadLength, nodes, payloads, paymentHash).map(onion => (firstAmount, firstExpiry, onion)) - } - - def buildPaymentPacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = - buildPacket(PaymentOnionCodecs.paymentOnionPayloadLength, paymentHash, hops, finalPayload) - - def buildTrampolinePacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = - buildPacket(PaymentOnionCodecs.trampolineOnionPayloadLength, paymentHash, hops, finalPayload) - - /** - * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. - * The next-to-last trampoline node payload will contain instructions to convert to a legacy payment. - * - * @param invoice Bolt 11 invoice (features and routing hints will be provided to the next-to-last node). - * @param hops the trampoline hops (including ourselves in the first hop, and the non-trampoline final recipient in the last hop). - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) tuple where: - * - firstAmount is the amount for the trampoline node in the route - * - firstExpiry is the cltv expiry for the first trampoline node in the route - * - the trampoline onion to include in final payload of a normal onion - */ - def buildTrampolineToLegacyPacket(invoice: Bolt11Invoice, hops: Seq[NodeHop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { - // NB: the final payload will never reach the recipient, since the next-to-last node in the trampoline route will convert that to a non-trampoline payment. - // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. - val dummyFinalPayload = FinalPayload.Standard.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, randomBytes32(), None) - val (firstAmount, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](dummyFinalPayload))) { - case ((amount, expiry, payloads), hop) => - // The next-to-last node in the trampoline route must receive invoice data to indicate the conversion to a non-trampoline payment. - val payload = if (payloads.length == 1) { - IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) - } else { - IntermediatePayload.NodeRelay.Standard(amount, expiry, hop.nextNodeId) - } - (amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, payload +: payloads) - } - val nodes = hops.map(_.nextNodeId) - buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, nodes, payloads, invoice.paymentHash).map(onion => (firstAmount, firstExpiry, onion)) - } - // @formatter:off + case class NodePayload(nodeId: PublicKey, payload: PerHopPayload) + case class PaymentPayloads(amount: MilliSatoshi, expiry: CltvExpiry, payloads: Seq[NodePayload]) + + sealed trait OutgoingPaymentError extends Throwable + case class CannotCreateOnion(message: String) extends OutgoingPaymentError { override def getMessage: String = message } + case class InvalidRouteRecipient(expected: PublicKey, actual: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to $expected, got route to $actual" } + case class MissingTrampolineHop(trampolineNodeId: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to trampoline node $trampolineNodeId" } + case object EmptyRoute extends OutgoingPaymentError { override def getMessage: String = "route cannot be empty" } + sealed trait Upstream object Upstream { case class Local(id: UUID) extends Upstream @@ -317,15 +245,31 @@ object OutgoingPaymentPacket { } // @formatter:on - /** - * Build the command to add an HTLC with the given final payload and using the provided hops. - * - * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) - */ - def buildCommand(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, hops: Seq[ChannelHop], finalPayload: FinalPayload): Try[(CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)])] = { - buildPaymentPacket(paymentHash, hops, finalPayload).map { - case (firstAmount, firstExpiry, onion) => - CMD_ADD_HTLC(replyTo, firstAmount, paymentHash, firstExpiry, onion.packet, None, Origin.Hot(replyTo, upstream), commit = true) -> onion.sharedSecrets + /** Build an encrypted onion packet from onion payloads and node public keys. */ + def buildOnion(packetPayloadLength: Int, payloads: Seq[NodePayload], associatedData: ByteVector32): Either[OutgoingPaymentError, Sphinx.PacketAndSecrets] = { + val sessionKey = randomKey() + val nodeIds = payloads.map(_.nodeId) + val payloadsBin = payloads + .map(p => PaymentOnionCodecs.perHopPayloadCodec.encode(p.payload.records)) + .map { + case Attempt.Successful(bits) => bits.bytes + case Attempt.Failure(cause) => return Left(CannotCreateOnion(cause.message)) + } + Sphinx.create(sessionKey, packetPayloadLength, nodeIds, payloadsBin, Some(associatedData)) match { + case Failure(f) => Left(CannotCreateOnion(f.getMessage)) + case Success(packet) => Right(packet) + } + } + + /** Build the command to add an HTLC for the given recipient using the provided route. */ + def buildOutgoingPayment(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { + val outgoingChannel = route.hops.head.shortChannelId + for { + payment <- recipient.buildPayloads(paymentHash, route) + onion <- buildOnion(PaymentOnionCodecs.paymentOnionPayloadLength, payment.payloads, paymentHash) // BOLT 2 requires that associatedData == paymentHash + } yield { + val cmd = CMD_ADD_HTLC(replyTo, payment.amount, paymentHash, payment.expiry, onion.packet, None, Origin.Hot(replyTo, upstream), commit = true) + OutgoingPaymentPacket(cmd, outgoingChannel, onion.sharedSecrets) } } @@ -344,4 +288,5 @@ object OutgoingPaymentPacket { def buildHtlcFailure(nodeSecret: PrivateKey, cmd: CMD_FAIL_HTLC, add: UpdateAddHtlc): Either[CannotExtractSharedSecret, UpdateFailHtlc] = { buildHtlcFailure(nodeSecret, cmd.reason, add).map(encryptedReason => UpdateFailHtlc(add.channelId, cmd.id, encryptedReason)) } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 6242af708..f9033056c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -34,11 +34,11 @@ import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} +import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt @@ -339,8 +339,9 @@ object MultiPartHandler { Future.sequence(r.routes.map(route => { val pathId = randomBytes32() val dummyHops = route.dummyHops.map(h => { - val edge = Invoice.BasicEdge(nodeParams.nodeId, nodeParams.nodeId, ShortChannelId.toSelf, h.feeBase, h.feeProportionalMillionths, h.cltvExpiryDelta) - ChannelHop(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId, ChannelRelayParams.FromHint(edge)) + // We don't want to restrict HTLC size in dummy hops, so we use htlc_minimum_msat = 1 msat and htlc_maximum_msat = None. + val edge = Invoice.ExtraEdge(nodeParams.nodeId, nodeParams.nodeId, ShortChannelId.toSelf, h.feeBase, h.feeProportionalMillionths, h.cltvExpiryDelta, htlcMinimum = 1 msat, htlcMaximum_opt = None) + ChannelHop(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId, HopRelayParams.FromHint(edge)) }) if (route.nodes.length == 1) { val blindedRoute = if (dummyHops.isEmpty) { @@ -352,7 +353,7 @@ object MultiPartHandler { Future.successful((blindedRoute, paymentInfo, pathId)) } else { implicit val timeout: Timeout = 10.seconds - r.router.ask(Router.FinalizeRoute(r.amount, Router.PredefinedNodeRoute(route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { + r.router.ask(Router.FinalizeRoute(Router.PredefinedNodeRoute(r.amount, route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { val clearRoute = routeResponse.routes.head val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) val paymentInfo = aggregatePaymentInfo(r.amount, clearRoute.hops ++ dummyHops) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 5dcdfba6c..89dcded82 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -35,10 +35,10 @@ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode -import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle} +import fr.acinq.eclair.payment.send.{ClearRecipient, MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle} import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} -import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} +import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{BlockHeight, CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, UInt64, nodeFee, randomBytes32} @@ -232,7 +232,7 @@ class NodeRelay private(nodeParams: NodeParams, context.log.warn(s"rejecting async payment at block $blockHeight; was not triggered after waiting ${nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks} blocks") rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized stopping() - case WrappedCurrentBlockHeight(blockHeight) => + case WrappedCurrentBlockHeight(_) => Behaviors.same case CancelAsyncPayment => context.log.warn(s"payment sender canceled a waiting async payment") @@ -302,32 +302,33 @@ class NodeRelay private(nodeParams: NodeParams, }.toClassic private def relay(upstream: Upstream.Trampoline, payloadOut: IntermediatePayload.NodeRelay.Standard, packetOut: OnionRoutingPacket): ActorRef = { - val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, Nil) + val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true) val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient. val payFSM = payloadOut.invoiceFeatures match { case Some(features) => val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId)) val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay - if (Features(features).hasFeature(Features.BasicMultiPartPayment)) { + val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata) + if (recipient.features.hasFeature(Features.BasicMultiPartPayment)) { context.log.debug("sending the payment to non-trampoline recipient using MPP") - val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, payloadOut.paymentMetadata, extraEdges, routeParams) + val payment = SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) payFSM ! payment payFSM } else { context.log.debug("sending the payment to non-trampoline recipient without MPP") - val finalPayload = FinalPayload.Standard.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, payloadOut.paymentMetadata) - val payment = SendPaymentToNode(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, extraEdges, routeParams) + val payment = SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false) payFSM ! payment payFSM } case None => context.log.debug("sending the payment to the next trampoline node") - val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks - val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, None, routeParams = routeParams, additionalTlvs = Seq(OnionPaymentPayloadTlv.TrampolineOnion(packetOut))) + val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = Some(packetOut)) + val payment = SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) payFSM ! payment payFSM } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index d0bc976e6..68574d0d0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -19,10 +19,8 @@ package fr.acinq.eclair.payment.send import akka.actor.{ActorRef, FSM, Props, Status} import akka.event.Logging.MDC import fr.acinq.bitcoin.scalacompat.ByteVector32 -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel.{HtlcOverriddenByLocalCommit, HtlcsTimedoutDownstream, HtlcsWillTimeoutUpstream} import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} -import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment.PaymentSent.PartialPayment @@ -30,10 +28,7 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload -import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli} -import scodec.bits.ByteVector +import fr.acinq.eclair.{FSMDiagnosticActorLogging, Logs, MilliSatoshiLong, NodeParams, TimestampMilli} import java.util.UUID import java.util.concurrent.TimeUnit @@ -62,29 +57,26 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, when(WAIT_FOR_PAYMENT_REQUEST) { case Event(r: SendMultiPartPayment, _) => val routeParams = r.routeParams.copy(randomize = false) // we don't randomize the first attempt, regardless of configuration choices - val maxFee = routeParams.getMaxFee(r.totalAmount) - log.debug("sending {} with maximum fee {}", r.totalAmount, maxFee) - val d = PaymentProgress(r, r.maxAttempts, Map.empty, Ignore.empty, Nil) - router ! createRouteRequest(nodeParams, r.totalAmount, maxFee, routeParams, d, cfg) + log.debug("sending {} with maximum fee {}", r.recipient.totalAmount, r.routeParams.getMaxFee(r.recipient.totalAmount)) + val d = PaymentProgress(r, r.maxAttempts, Map.empty, Ignore.empty, retryRouteRequest = false, failures = Nil) + router ! createRouteRequest(nodeParams, routeParams, d, cfg) goto(WAIT_FOR_ROUTES) using d } when(WAIT_FOR_ROUTES) { case Event(RouteResponse(routes), d: PaymentProgress) => log.info("{} routes found (attempt={}/{})", routes.length, d.request.maxAttempts - d.remainingAttempts + 1, d.request.maxAttempts) - // We may have already succeeded sending parts of the payment and only need to take care of the rest. - val (toSend, maxFee) = remainingToSend(d.request, d.pending.values, d.request.routeParams.includeLocalChannelCost) - if (routes.map(_.amount).sum == toSend) { + if (!d.retryRouteRequest) { val childPayments = routes.map(route => (UUID.randomUUID(), route)).toMap childPayments.foreach { case (childId, route) => spawnChildPaymentFsm(childId) ! createChildPayment(self, route, d.request) } goto(PAYMENT_IN_PROGRESS) using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), pending = d.pending ++ childPayments) } else { // If a child payment failed while we were waiting for routes, the routes we received don't cover the whole // remaining amount. In that case we discard these routes and send a new request to the router. - log.info("discarding routes, another child payment failed so we need to recompute them (amount = {}, maximum fee = {})", toSend, maxFee) + log.info("discarding routes, another child payment failed so we need to recompute them ({} payments still pending for {})", d.pending.size, d.pending.values.map(_.amount).sum) val routeParams = d.request.routeParams.copy(randomize = true) // we randomize route selection when we retry - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg) - stay() + router ! createRouteRequest(nodeParams, routeParams, d, cfg) + stay() using d.copy(retryRouteRequest = false) } case Event(Status.Failure(t), d: PaymentProgress) => @@ -93,20 +85,19 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, // Channels are mostly ignored for temporary reasons, likely because they didn't have enough balance to forward // the payment. When we're retrying an MPP split, it may make sense to retry those ignored channels because with // a different split, they may have enough balance to forward the payment. - val (toSend, maxFee) = remainingToSend(d.request, d.pending.values, d.request.routeParams.includeLocalChannelCost) if (d.ignore.channels.nonEmpty) { - log.debug("retry sending {} with maximum fee {} without ignoring channels ({})", toSend, maxFee, d.ignore.channels.map(_.shortChannelId).mkString(",")) + log.debug("retry sending payment without ignoring channels {} ({} payments still pending for {})", d.ignore.channels.map(_.shortChannelId).mkString(","), d.pending.size, d.pending.values.map(_.amount).sum) val routeParams = d.request.routeParams.copy(randomize = true) // we randomize route selection when we retry - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg).copy(ignore = d.ignore.emptyChannels()) + router ! createRouteRequest(nodeParams, routeParams, d, cfg).copy(ignore = d.ignore.emptyChannels()) retriedFailedChannels = true - stay() using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), ignore = d.ignore.emptyChannels()) + stay() using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), ignore = d.ignore.emptyChannels(), retryRouteRequest = false) } else { - val failure = LocalFailure(toSend, Nil, t) + val failure = LocalFailure(d.request.recipient.totalAmount, Nil, t) Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure)).increment() if (cfg.storeInDb && d.pending.isEmpty && d.failures.isEmpty) { // In cases where we fail early (router error during the first attempt), the DB won't have an entry for that // payment, which may be confusing for users. - val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, cfg.recipientAmount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) + val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, d.request.recipient.totalAmount, d.request.recipient.totalAmount, d.request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) nodeParams.db.payments.addOutgoingPayment(dummyPayment) nodeParams.db.payments.updateOutgoingPayment(PaymentFailed(id, paymentHash, failure :: Nil)) } @@ -118,8 +109,8 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, gotoAbortedOrStop(PaymentAborted(d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) } else { val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) - val extraEdges1 = PaymentFailure.updateExtraEdges(pf.failures, d.request.extraEdges) - stay() using d.copy(pending = d.pending - pf.id, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(extraEdges = extraEdges1)) + val recipient1 = PaymentFailure.updateExtraEdges(pf.failures, d.request.recipient) + stay() using d.copy(pending = d.pending - pf.id, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(recipient = recipient1), retryRouteRequest = true) } // The recipient released the preimage without receiving the full payment amount. @@ -135,18 +126,17 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, if (abortPayment(pf, d)) { gotoAbortedOrStop(PaymentAborted(d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) } else if (d.remainingAttempts == 0) { - val failure = LocalFailure(d.request.totalAmount, Nil, PaymentError.RetryExhausted) + val failure = LocalFailure(d.request.recipient.totalAmount, Nil, PaymentError.RetryExhausted) Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure)).increment() gotoAbortedOrStop(PaymentAborted(d.request, d.failures ++ pf.failures :+ failure, d.pending.keySet - pf.id)) } else { val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) - val extraEdges1 = PaymentFailure.updateExtraEdges(pf.failures, d.request.extraEdges) + val recipient1 = PaymentFailure.updateExtraEdges(pf.failures, d.request.recipient) val stillPending = d.pending - pf.id - val (toSend, maxFee) = remainingToSend(d.request, stillPending.values, d.request.routeParams.includeLocalChannelCost) - log.debug("child payment failed, retry sending {} with maximum fee {}", toSend, maxFee) + log.debug("child payment failed, retrying payment ({} payments still pending for {})", stillPending.size, stillPending.values.map(_.amount).sum) val routeParams = d.request.routeParams.copy(randomize = true) // we randomize route selection when we retry - val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(extraEdges = extraEdges1)) - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d1, cfg) + val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(recipient = recipient1), retryRouteRequest = false) + router ! createRouteRequest(nodeParams, routeParams, d1, cfg) goto(WAIT_FOR_ROUTES) using d1 } @@ -184,7 +174,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val parts = d.parts ++ ps.parts val pending = d.pending - ps.parts.head.id if (pending.isEmpty) { - myStop(d.request, Right(cfg.createPaymentSent(d.preimage, parts))) + myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, parts))) } else { stay() using d.copy(parts = parts, pending = pending) } @@ -195,7 +185,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, log.warning(s"payment succeeded but partial payment failed (id=${pf.id})") val pending = d.pending - pf.id if (pending.isEmpty) { - myStop(d.request, Right(cfg.createPaymentSent(d.preimage, d.parts))) + myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts))) } else { stay() using d.copy(pending = pending) } @@ -223,7 +213,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, private def gotoSucceededOrStop(d: PaymentSucceeded): State = { d.request.replyTo ! PreimageReceived(paymentHash, d.preimage) if (d.pending.isEmpty) { - myStop(d.request, Right(cfg.createPaymentSent(d.preimage, d.parts))) + myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts))) } else goto(PAYMENT_SUCCEEDED) using d } @@ -240,7 +230,11 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val status = event match { case Right(_: PaymentSent) => "SUCCESS" case Left(f: PaymentFailed) => - if (f.failures.exists({ case r: RemoteFailure => r.e.originNode == cfg.recipientNodeId case _ => false })) { + val isRecipientFailure = f.failures.exists { + case r: RemoteFailure => r.e.originNode == request.recipient.nodeId + case _ => false + } + if (isRecipientFailure) { "RECIPIENT_FAILURE" } else { "FAILURE" @@ -252,7 +246,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val fees = event match { case Left(paymentFailed) => log.info(s"failed payment attempts details: ${PaymentFailure.jsonSummary(cfg, request.routeParams.experimentName, paymentFailed)}") - request.routeParams.getMaxFee(cfg.recipientAmount) + request.routeParams.getMaxFee(request.recipient.totalAmount) case Right(paymentSent) => val localFees = cfg.upstream match { case _: Upstream.Local => 0.msat // no local fees when we are the origin of the payment @@ -265,7 +259,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, } paymentSent.feesPaid + localFees } - context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, cfg.recipientAmount, fees, status, duration, now, isMultiPart = true, request.routeParams.experimentName, cfg.recipientNodeId, request.extraEdges)) + context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, request.recipient.totalAmount, fees, status, duration, now, isMultiPart = true, request.routeParams.experimentName, request.recipient.nodeId, request.recipient.extraEdges)) } Metrics.SentPaymentDuration .withTag(Tags.MultiPart, Tags.MultiPartType.Parent) @@ -304,30 +298,12 @@ object MultiPartPaymentLifecycle { * Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding * algorithm will run to find suitable payment routes. * - * @param paymentSecret payment secret to protect against probing (usually from a Bolt 11 invoice). - * @param targetNodeId target node (may be the final recipient when using source-routing, or the first trampoline - * node when using trampoline). - * @param totalAmount total amount to send to the target node. - * @param targetExpiry expiry at the target node (CLTV for the target node's received HTLCs). - * @param maxAttempts maximum number of retries. - * @param paymentMetadata payment metadata (usually from the Bolt 11 invoice). - * @param extraEdges routing hints (usually from a Bolt 11 invoice). - * @param routeParams parameters to fine-tune the routing algorithm. - * @param additionalTlvs when provided, additional tlvs that will be added to the onion sent to the target node. - * @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node. + * @param recipient final recipient. + * @param maxAttempts maximum number of retries. + * @param routeParams parameters to fine-tune the routing algorithm. */ - case class SendMultiPartPayment(replyTo: ActorRef, - paymentSecret: ByteVector32, - targetNodeId: PublicKey, - totalAmount: MilliSatoshi, - targetExpiry: CltvExpiry, - maxAttempts: Int, - paymentMetadata: Option[ByteVector], - extraEdges: Seq[ExtraEdge] = Nil, - routeParams: RouteParams, - additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil, - userCustomTlvs: Seq[GenericTlv] = Nil) { - require(totalAmount > 0.msat, s"total amount must be > 0") + case class SendMultiPartPayment(replyTo: ActorRef, recipient: Recipient, maxAttempts: Int, routeParams: RouteParams) { + require(recipient.totalAmount > 0.msat, "total amount must be > 0") } /** @@ -361,12 +337,14 @@ object MultiPartPaymentLifecycle { * @param remainingAttempts remaining attempts (after child payments fail). * @param pending pending child payments (payment sent, we are waiting for a fulfill or a failure). * @param ignore channels and nodes that should be ignored (previously returned a permanent error). + * @param retryRouteRequest if true, ignore the next [[RouteResponse]] and send another [[RouteRequest]]. * @param failures previous child payment failures. */ case class PaymentProgress(request: SendMultiPartPayment, remainingAttempts: Int, pending: Map[UUID, Route], ignore: Ignore, + retryRouteRequest: Boolean, failures: Seq[PaymentFailure]) extends Data /** @@ -391,37 +369,24 @@ object MultiPartPaymentLifecycle { */ case class PaymentSucceeded(request: SendMultiPartPayment, preimage: ByteVector32, parts: Seq[PartialPayment], pending: Set[UUID]) extends Data - private def createRouteRequest(nodeParams: NodeParams, toSend: MilliSatoshi, maxFee: MilliSatoshi, routeParams: RouteParams, d: PaymentProgress, cfg: SendPaymentConfig): RouteRequest = - RouteRequest( - nodeParams.nodeId, - d.request.targetNodeId, - toSend, - maxFee, - d.request.extraEdges, - d.ignore, - routeParams, - allowMultiPart = true, - d.pending.values.toSeq, - Some(cfg.paymentContext)) - - private def createChildPayment(replyTo: ActorRef, route: Route, request: SendMultiPartPayment): SendPaymentToRoute = { - val finalPayload = FinalPayload.Standard.createMultiPartPayload(route.amount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.paymentMetadata, request.additionalTlvs, request.userCustomTlvs) - SendPaymentToRoute(replyTo, Right(route), finalPayload) + private def createRouteRequest(nodeParams: NodeParams, routeParams: RouteParams, d: PaymentProgress, cfg: SendPaymentConfig): RouteRequest = { + RouteRequest(nodeParams.nodeId, d.request.recipient, routeParams, d.ignore, allowMultiPart = true, d.pending.values.toSeq, Some(cfg.paymentContext)) } - /** When we receive an error from the final recipient or payment gets settled on chain, we should fail the whole payment, it's useless to retry. */ + private def createChildPayment(replyTo: ActorRef, route: Route, request: SendMultiPartPayment): SendPaymentToRoute = { + SendPaymentToRoute(replyTo, Right(route), request.recipient) + } + + /** When we receive a final error or the payment gets settled on chain, we should fail the whole payment, it's useless to retry. */ private def abortPayment(pf: PaymentFailed, d: PaymentProgress): Boolean = pf.failures.exists { - case f: RemoteFailure => f.e.originNode == d.request.targetNodeId + case f: RemoteFailure => + val isRecipientFailure = f.e.originNode == d.request.recipient.nodeId + val isTrampolineFailure = f.route.lastOption.collect { case h: NodeHop if f.e.originNode == h.nodeId => h }.nonEmpty + isRecipientFailure || isTrampolineFailure case LocalFailure(_, _, _: HtlcOverriddenByLocalCommit) => true case LocalFailure(_, _, _: HtlcsWillTimeoutUpstream) => true case LocalFailure(_, _, _: HtlcsTimedoutDownstream) => true case _ => false } - private def remainingToSend(request: SendMultiPartPayment, pending: Iterable[Route], includeLocalChannelCost: Boolean): (MilliSatoshi, MilliSatoshi) = { - val sentAmount = pending.map(_.amount).sum - val sentFees = pending.map(_.fee(includeLocalChannelCost)).sum - (request.totalAmount - sentAmount, request.routeParams.getMaxFee(request.totalAmount) - sentFees) - } - } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 49ee4e9b0..2a05451dc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -26,9 +26,8 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentError._ import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, NodeParams, randomBytes32} import java.util.UUID import scala.util.{Failure, Success, Try} @@ -49,29 +48,34 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn // Immediately return the paymentId sender() ! paymentId } - val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil) + val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.invoice.nodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true) val finalExpiry = r.finalExpiry(nodeParams) - if (!nodeParams.features.invoiceFeatures().areSupported(r.invoice.features)) { - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(r.invoice.features)) :: Nil) - } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), r.invoice.features, Features.BasicMultiPartPayment)) { - val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, r.invoice.paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, userCustomTlvs = r.userCustomTlvs) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) - } else { - val finalPayload = FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata, r.userCustomTlvs) - val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, r.invoice.extraEdges, r.routeParams) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) + r.invoice match { + case invoice: Bolt11Invoice => + val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) + if (!nodeParams.features.invoiceFeatures().areSupported(recipient.features)) { + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(recipient.features)) :: Nil) + } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), recipient.features, Features.BasicMultiPartPayment)) { + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) + } else { + val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) + } + case _: Bolt12Invoice => + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) } case r: SendSpontaneousPayment => val paymentId = UUID.randomUUID() sender() ! paymentId - val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), None, storeInDb = true, publishEvent = true, recordPathFindingMetrics = r.recordPathFindingMetrics, Nil) + val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), None, storeInDb = true, publishEvent = true, recordPathFindingMetrics = r.recordPathFindingMetrics) val finalExpiry = nodeParams.paymentFinalExpiry.computeFinalExpiry(nodeParams.currentBlockHeight, Channel.MIN_CLTV_EXPIRY_DELTA) - val finalPayload = FinalPayload.Standard(TlvStream(Seq(OnionPaymentPayloadTlv.AmountToForward(r.recipientAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(randomBytes32(), r.recipientAmount), OnionPaymentPayloadTlv.KeySend(r.paymentPreimage)), r.userCustomTlvs)) + val recipient = SpontaneousRecipient(r.recipientNodeId, r.recipientAmount, finalExpiry, r.paymentPreimage, r.userCustomTlvs) val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, routeParams = r.routeParams) + fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) context become main(pending + (paymentId -> PendingSpontaneousPayment(sender(), r))) case r: SendTrampolinePayment => @@ -96,39 +100,44 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn case r: SendPaymentToRoute => val paymentId = UUID.randomUUID() val parentPaymentId = r.parentId.getOrElse(UUID.randomUUID()) - val finalExpiry = r.finalExpiry(nodeParams) - val additionalHops = r.trampolineNodes.sliding(2).map(hop => NodeHop(hop.head, hop(1), CltvExpiryDelta(0), 0 msat)).toSeq - val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, additionalHops) - r.trampolineNodes match { - case trampoline :: recipient :: Nil => - log.info(s"sending trampoline payment to $recipient with trampoline=$trampoline, trampoline fees=${r.trampolineFees}, expiry delta=${r.trampolineExpiryDelta}") - buildTrampolinePayment(r, trampoline, r.trampolineFees, r.trampolineExpiryDelta) match { - case Success((trampolineAmount, trampolineExpiry, trampolineOnion)) => - // We generate a random secret for the payment to the first trampoline node. - val trampolineSecret = r.trampolineSecret.getOrElse(randomBytes32()) - sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(trampolineSecret)) + r.trampoline_opt match { + case _ if !nodeParams.features.invoiceFeatures().areSupported(r.invoice.features) => + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(r.invoice.features)) :: Nil) + case Some(trampolineAttempt) => + val trampolineNodeId = r.route.targetNodeId + log.info(s"sending trampoline payment to ${r.recipientNodeId} with trampoline=$trampolineNodeId, trampoline fees=${trampolineAttempt.fees}, expiry delta=${trampolineAttempt.cltvExpiryDelta}") + val trampolineHop = NodeHop(trampolineNodeId, r.recipientNodeId, trampolineAttempt.cltvExpiryDelta, trampolineAttempt.fees) + buildTrampolineRecipient(r, trampolineHop) match { + case Success(recipient) => + sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(recipient.trampolinePaymentSecret)) + val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, r.invoice.paymentMetadata, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion))), r.invoice.extraEdges) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) case Failure(t) => log.warning("cannot send outgoing trampoline payment: {}", t.getMessage) sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, t) :: Nil) } - case Nil => - sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) - val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata), r.invoice.extraEdges) - context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) + case None => + r.invoice match { + case invoice: Bolt11Invoice => + sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) + val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) + val finalExpiry = r.finalExpiry(nodeParams) + val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil) + val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) + context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) + case _: Bolt12Invoice => + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) + } case _ => sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, TrampolineMultiNodeNotSupported) :: Nil) } case pf: PaymentFailed => pending.get(pf.id).foreach { case pp: PendingTrampolinePayment => - val trampolineRoute = Seq( - NodeHop(nodeParams.nodeId, pp.r.trampolineNodeId, nodeParams.channelConf.expiryDelta, 0 msat), - NodeHop(pp.r.trampolineNodeId, pp.r.recipientNodeId, pp.r.trampolineAttempts.last._2, pp.r.trampolineAttempts.last._1) - ) + val trampolineHop = NodeHop(pp.r.trampolineNodeId, pp.r.recipientNodeId, pp.r.trampolineAttempts.last._2, pp.r.trampolineAttempts.last._1) val decryptedFailures = pf.failures.collect { case RemoteFailure(_, _, Sphinx.DecryptedFailurePacket(_, f)) => f } val shouldRetry = decryptedFailures.contains(TrampolineFeeInsufficient) || decryptedFailures.contains(TrampolineExpiryTooSoon) if (shouldRetry) { @@ -140,14 +149,14 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn context become main(pending + (pf.id -> pp.copy(remainingAttempts = remaining))) case Failure(t) => log.warning("cannot send outgoing trampoline payment: {}", t.getMessage) - val localFailure = pf.copy(failures = Seq(LocalFailure(pp.r.recipientAmount, trampolineRoute, t))) + val localFailure = pf.copy(failures = Seq(LocalFailure(pp.r.recipientAmount, Seq(trampolineHop), t))) pp.sender ! localFailure context.system.eventStream.publish(localFailure) context become main(pending - pf.id) } case Nil => log.info("trampoline node couldn't find a route after all retries") - val localFailure = pf.copy(failures = Seq(LocalFailure(pp.r.recipientAmount, trampolineRoute, RouteNotFound))) + val localFailure = pf.copy(failures = Seq(LocalFailure(pp.r.recipientAmount, Seq(trampolineHop), RouteNotFound))) pp.sender ! localFailure context.system.eventStream.publish(localFailure) context become main(pending - pf.id) @@ -185,38 +194,28 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } - private def buildTrampolinePayment(r: SendRequestedPayment, trampolineNodeId: PublicKey, trampolineFees: MilliSatoshi, trampolineExpiryDelta: CltvExpiryDelta): Try[(MilliSatoshi, CltvExpiry, OnionRoutingPacket)] = { - val trampolineRoute = Seq( - NodeHop(nodeParams.nodeId, trampolineNodeId, nodeParams.channelConf.expiryDelta, 0 msat), - NodeHop(trampolineNodeId, r.recipientNodeId, trampolineExpiryDelta, trampolineFees) // for now we only use a single trampoline hop - ) - val finalPayload = if (r.invoice.features.hasFeature(Features.BasicMultiPartPayment)) { - FinalPayload.Standard.createMultiPartPayload(r.recipientAmount, r.recipientAmount, r.finalExpiry(nodeParams), r.invoice.paymentSecret, r.invoice.paymentMetadata) - } else { - FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, r.finalExpiry(nodeParams), r.invoice.paymentSecret, r.invoice.paymentMetadata) - } - // We assume that the trampoline node supports multi-part payments (it should). - val trampolinePacket_opt = if (r.invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) { - OutgoingPaymentPacket.buildTrampolinePacket(r.paymentHash, trampolineRoute, finalPayload) - } else { - r.invoice match { - case invoice: Bolt11Invoice => OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineRoute, finalPayload) - case _ => Failure(new Exception("Trampoline to legacy is only supported for Bolt11 invoices.")) - } - } - trampolinePacket_opt.map { - case (trampolineAmount, trampolineExpiry, trampolineOnion) => (trampolineAmount, trampolineExpiry, trampolineOnion.packet) + private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[ClearTrampolineRecipient] = { + r.invoice match { + case invoice: Bolt11Invoice => + // We generate a random secret for the payment to the trampoline node. + val trampolineSecret = r match { + case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) + case _ => randomBytes32() + } + val finalExpiry = r.finalExpiry(nodeParams) + val recipient = ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret) + Success(recipient) + case _: Bolt12Invoice => + Failure(new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) } } private def sendTrampolinePayment(paymentId: UUID, r: SendTrampolinePayment, trampolineFees: MilliSatoshi, trampolineExpiryDelta: CltvExpiryDelta): Try[Unit] = { - val paymentCfg = SendPaymentConfig(paymentId, paymentId, None, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = false, recordPathFindingMetrics = true, Seq(NodeHop(r.trampolineNodeId, r.recipientNodeId, trampolineExpiryDelta, trampolineFees))) - // We generate a random secret for this payment to avoid leaking the invoice secret to the first trampoline node. - val trampolineSecret = randomBytes32() - buildTrampolinePayment(r, r.trampolineNodeId, trampolineFees, trampolineExpiryDelta).map { - case (trampolineAmount, trampolineExpiry, trampolineOnion) => - val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, nodeParams.maxPaymentAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion))) + val trampolineHop = NodeHop(r.trampolineNodeId, r.recipientNodeId, trampolineExpiryDelta, trampolineFees) + val paymentCfg = SendPaymentConfig(paymentId, paymentId, None, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = false, recordPathFindingMetrics = true) + buildTrampolineRecipient(r, trampolineHop).map { recipient => + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, nodeParams.maxPaymentAttempts, r.routeParams) } } @@ -273,9 +272,7 @@ object PaymentInitiator { } /** - * We temporarily let the caller decide to use Trampoline (instead of a normal payment) and set the fees/cltv. - * Once we have trampoline fee estimation built into the router, the decision to use Trampoline or not should be done - * automatically by the router instead of the caller. + * This command should be used to test the trampoline implementation until the feature is fully specified. * * @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice). * @param invoice Bolt 11 invoice. @@ -330,49 +327,44 @@ object PaymentInitiator { val paymentHash = Crypto.sha256(paymentPreimage) } + /** + * @param paymentSecret this is a secret to protect the payment to the trampoline node against probing. + * @param fees fees for the trampoline node. + * @param cltvExpiryDelta expiry delta for the trampoline node. + */ + case class TrampolineAttempt(paymentSecret: ByteVector32, fees: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta) + /** * The sender can skip the routing algorithm by specifying the route to use. + * * When combining with MPP and Trampoline, extra-care must be taken to make sure payments are correctly grouped: only - * amount, route and trampolineNodes should be changing. + * amount, route and trampoline_opt should be changing. Splitting across multiple trampoline nodes isn't supported. * * Example 1: MPP containing two HTLCs for a 600 msat invoice: - * SendPaymentToRouteRequest(200 msat, 600 msat, None, parentId, invoice, CltvExpiryDelta(9), Seq(alice, bob, dave), None, 0 msat, CltvExpiryDelta(0), Nil) - * SendPaymentToRouteRequest(400 msat, 600 msat, None, parentId, invoice, CltvExpiryDelta(9), Seq(alice, carol, dave), None, 0 msat, CltvExpiryDelta(0), Nil) + * SendPaymentToRoute(600 msat, invoice, Route(200 msat, Seq(alice, bob, dave)), None, Some(parentId), None) + * SendPaymentToRoute(600 msat, invoice, Route(400 msat, Seq(alice, carol, dave)), None, Some(parentId), None) * * Example 2: Trampoline with MPP for a 600 msat invoice and 100 msat trampoline fees: - * SendPaymentToRouteRequest(250 msat, 600 msat, None, parentId, invoice, CltvExpiryDelta(9), Seq(alice, bob, dave), secret, 100 msat, CltvExpiryDelta(144), Seq(dave, peter)) - * SendPaymentToRouteRequest(450 msat, 600 msat, None, parentId, invoice, CltvExpiryDelta(9), Seq(alice, carol, dave), secret, 100 msat, CltvExpiryDelta(144), Seq(dave, peter)) + * SendPaymentToRoute(600 msat, invoice, Route(250 msat, Seq(alice, bob, ted)), None, Some(parentId), Some(TrampolineAttempt(secret, 100 msat, CltvExpiryDelta(144)))) + * SendPaymentToRoute(600 msat, invoice, Route(450 msat, Seq(alice, carol, ted)), None, Some(parentId), Some(TrampolineAttempt(secret, 100 msat, CltvExpiryDelta(144)))) * - * @param amount amount that should be received by the last node in the route (should take trampoline - * fees into account). - * @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice). - * This amount may be split between multiple requests if using MPP. - * @param invoice Bolt 11 invoice. - * @param route route to use to reach either the final recipient or the first trampoline node. - * @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB). - * @param parentId id of the whole payment. When manually sending a multi-part payment, you need to make - * sure all partial payments use the same parentId. If not provided, a random parentId will - * be generated that can be used for the remaining partial payments. - * @param trampolineSecret if trampoline is used, this is a secret to protect the payment to the first trampoline - * node against probing. When manually sending a multi-part payment, you need to make sure - * all partial payments use the same trampolineSecret. - * @param trampolineFees if trampoline is used, fees for the first trampoline node. This value must be the same - * for all partial payments in the set. - * @param trampolineExpiryDelta if trampoline is used, expiry delta for the first trampoline node. This value must be - * the same for all partial payments in the set. - * @param trampolineNodes if trampoline is used, list of trampoline nodes to use (we currently support only a - * single trampoline node). + * @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice). + * This amount may be split between multiple requests if using MPP. + * @param invoice Bolt 11 invoice. + * @param route route to use to reach either the final recipient or the trampoline node. + * @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB). + * @param parentId id of the whole payment. When manually sending a multi-part payment, you need to make + * sure all partial payments use the same parentId. If not provided, a random parentId will + * be generated that can be used for the remaining partial payments. + * @param trampoline_opt if trampoline is used, this field must be provided. When manually sending a multi-part + * payment, you need to make sure all partial payments share the same values. */ - case class SendPaymentToRoute(amount: MilliSatoshi, - recipientAmount: MilliSatoshi, + case class SendPaymentToRoute(recipientAmount: MilliSatoshi, invoice: Invoice, route: PredefinedRoute, externalId: Option[String], parentId: Option[UUID], - trampolineSecret: Option[ByteVector32], - trampolineFees: MilliSatoshi, - trampolineExpiryDelta: CltvExpiryDelta, - trampolineNodes: Seq[PublicKey]) extends SendRequestedPayment + trampoline_opt: Option[TrampolineAttempt]) extends SendRequestedPayment /** * @param paymentId id of the outgoing payment (mapped to a single outgoing HTLC). @@ -392,7 +384,6 @@ object PaymentInitiator { * each with a different id). * @param externalId externally-controlled identifier (to reconcile between application DB and eclair DB). * @param paymentHash payment hash. - * @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice). * @param recipientNodeId id of the final recipient. * @param upstream information about the payment origin (to link upstream to downstream when relaying a payment). * @param invoice Bolt 11 invoice. @@ -401,24 +392,18 @@ object PaymentInitiator { * @param publishEvent whether to publish a [[fr.acinq.eclair.payment.PaymentEvent]] on success/failure (e.g. for * multi-part child payments, we don't want to emit events for each child, only for the whole payment). * @param recordPathFindingMetrics We don't record metrics for payments that don't use path finding or that are a part of a bigger payment. - * @param additionalHops additional hops that the payment state machine isn't aware of (e.g. when using trampoline, hops - * that occur after the first trampoline node). */ case class SendPaymentConfig(id: UUID, parentId: UUID, externalId: Option[String], paymentHash: ByteVector32, - recipientAmount: MilliSatoshi, recipientNodeId: PublicKey, upstream: Upstream, invoice: Option[Invoice], storeInDb: Boolean, // e.g. for trampoline we don't want to store in the DB when we're relaying payments publishEvent: Boolean, - recordPathFindingMetrics: Boolean, - additionalHops: Seq[NodeHop]) { - def fullRoute(route: Route): Seq[Hop] = route.hops ++ additionalHops - - def createPaymentSent(preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipientAmount, recipientNodeId, parts) + recordPathFindingMetrics: Boolean) { + def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts) def paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 561dd517d..c5e0476e9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} -import fr.acinq.eclair.payment.Invoice.{BasicEdge, ExtraEdge} +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment.PaymentSent.PartialPayment @@ -34,7 +34,6 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle._ import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router._ -import fr.acinq.eclair.wire.protocol.PaymentOnion._ import fr.acinq.eclair.wire.protocol._ import java.util.concurrent.TimeUnit @@ -54,43 +53,43 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A startWith(WAITING_FOR_REQUEST, WaitingForRequest) when(WAITING_FOR_REQUEST) { - case Event(c: SendPaymentToRoute, WaitingForRequest) => - log.debug("sending {} to route {}", c.finalPayload.amount, c.printRoute()) - c.route.fold( - hops => router ! FinalizeRoute(c.finalPayload.amount, hops, c.extraEdges, paymentContext = Some(cfg.paymentContext)), + case Event(request: SendPaymentToRoute, WaitingForRequest) => + log.debug("sending {} to route {}", request.amount, request.printRoute()) + request.route.fold( + hops => router ! FinalizeRoute(hops, request.recipient.extraEdges, paymentContext = Some(cfg.paymentContext)), route => self ! RouteResponse(route :: Nil) ) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } - goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, Nil, Ignore.empty) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) - case Event(c: SendPaymentToNode, WaitingForRequest) => - log.debug("sending {} to {}", c.finalPayload.amount, c.targetNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, c.extraEdges, routeParams = c.routeParams, paymentContext = Some(cfg.paymentContext)) + case Event(request: SendPaymentToNode, WaitingForRequest) => + log.debug("sending {} to {}", request.amount, request.recipient.nodeId) + router ! RouteRequest(nodeParams.nodeId, request.recipient, request.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } - goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, Nil, Ignore.empty) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) } when(WAITING_FOR_ROUTE) { - case Event(RouteResponse(route +: _), WaitingForRoute(c, failures, ignore)) => - log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") - OutgoingPaymentPacket.buildCommand(self, cfg.upstream, paymentHash, route.hops, c.finalPayload) match { - case Success((cmd, sharedSecrets)) => - register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], route.hops.head.shortChannelId, cmd) - goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(c, cmd, failures, sharedSecrets, ignore, route) - case Failure(t) => - log.warning("cannot send outgoing payment: {}", t.getMessage) - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.finalPayload.amount, Nil, t))).increment() - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.finalPayload.amount, Nil, t)))) + case Event(RouteResponse(route +: _), WaitingForRoute(request, failures, ignore)) => + log.info(s"route found: attempt=${failures.size + 1}/${request.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") + OutgoingPaymentPacket.buildOutgoingPayment(self, cfg.upstream, paymentHash, route, request.recipient) match { + case Right(payment) => + register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], payment.outgoingChannel, payment.cmd) + goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(request, payment.cmd, failures, payment.sharedSecrets, ignore, route) + case Left(error) => + log.warning("cannot send outgoing payment: {}", error.getMessage) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(request.amount, route.fullRoute, error))).increment() + myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(request.amount, route.fullRoute, error)))) } - case Event(Status.Failure(t), WaitingForRoute(c, failures, _)) => + case Event(Status.Failure(t), WaitingForRoute(request, failures, _)) => log.warning("router error: {}", t.getMessage) - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.finalPayload.amount, Nil, t))).increment() - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.finalPayload.amount, Nil, t)))) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(request.amount, Nil, t))).increment() + myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(request.amount, Nil, t)))) } when(WAITING_FOR_PAYMENT_COMPLETE) { @@ -105,8 +104,8 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case Event(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill), d: WaitingForComplete) => router ! Router.RouteDidRelay(d.route) Metrics.PaymentAttempt.withTag(Tags.MultiPart, value = false).record(d.failures.size + 1) - val p = PartialPayment(id, d.c.finalPayload.amount, d.cmd.amount - d.c.finalPayload.amount, htlc.channelId, Some(cfg.fullRoute(d.route))) - myStop(d.c, Right(cfg.createPaymentSent(fulfill.paymentPreimage, p :: Nil))) + val p = PartialPayment(id, d.request.amount, d.cmd.amount - d.request.amount, htlc.channelId, Some(d.route.fullRoute)) + myStop(d.request, Right(cfg.createPaymentSent(d.recipient, fulfill.paymentPreimage, p :: Nil))) case Event(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail), d: WaitingForComplete) => fail match { @@ -136,11 +135,11 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } private def retry(failure: PaymentFailure, data: WaitingForComplete): FSM.State[PaymentLifecycle.State, PaymentLifecycle.Data] = { - data.c match { - case sendPaymentToNode: SendPaymentToNode => + data.request match { + case request: SendPaymentToNode => val ignore1 = PaymentFailure.updateIgnored(failure, data.ignore) - router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.finalPayload.amount, sendPaymentToNode.maxFee, data.c.extraEdges, ignore1, sendPaymentToNode.routeParams, paymentContext = Some(cfg.paymentContext)) - goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.c, data.failures :+ failure, ignore1) + router ! RouteRequest(nodeParams.nodeId, data.recipient, request.routeParams, ignore1, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.request, data.failures :+ failure, ignore1) case _: SendPaymentToRoute => log.error("unexpected retry during SendPaymentToRoute") stop(FSM.Normal) @@ -153,16 +152,16 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A private def handleLocalFail(d: WaitingForComplete, t: Throwable, isFatal: Boolean) = { t match { case UpdateMalformedException => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType.Malformed).increment() - case _ => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(d.c.finalPayload.amount, cfg.fullRoute(d.route), t))).increment() + case _ => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(d.request.amount, d.route.fullRoute, t))).increment() } // we only retry if the error isn't fatal, and we haven't exhausted the max number of retried - val doRetry = !isFatal && (d.failures.size + 1 < d.c.maxAttempts) - val localFailure = LocalFailure(d.c.finalPayload.amount, cfg.fullRoute(d.route), t) + val doRetry = !isFatal && (d.failures.size + 1 < d.request.maxAttempts) + val localFailure = LocalFailure(d.request.amount, d.route.fullRoute, t) if (doRetry) { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") retry(localFailure, d) } else { - myStop(d.c, Left(PaymentFailed(id, paymentHash, d.failures :+ localFailure))) + myStop(d.request, Left(PaymentFailed(id, paymentHash, d.failures :+ localFailure))) } } @@ -170,22 +169,22 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A import d._ ((Sphinx.FailurePacket.decrypt(fail.reason, sharedSecrets) match { case success@Success(e) => - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(d.c.finalPayload.amount, Nil, e))).increment() + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(request.amount, Nil, e))).increment() success case failure@Failure(_) => - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(d.c.finalPayload.amount, Nil))).increment() + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(request.amount, Nil))).increment() failure }) match { case res@Success(Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => // We have discovered some liquidity information with this payment: we update the router accordingly. - val stoppedRoute = d.route.stopAt(nodeId) + val stoppedRoute = route.stopAt(nodeId) if (stoppedRoute.hops.length > 1) { router ! Router.RouteCouldRelay(stoppedRoute) } failureMessage match { case TemporaryChannelFailure(update) => - d.route.hops.find(_.nodeId == nodeId) match { - case Some(failingHop) if ChannelRelayParams.areSame(failingHop.params, ChannelRelayParams.FromAnnouncement(update), ignoreHtlcSize = true) => + route.hops.find(_.nodeId == nodeId) match { + case Some(failingHop) if HopRelayParams.areSame(failingHop.params, HopRelayParams.FromAnnouncement(update), ignoreHtlcSize = true) => router ! Router.ChannelCouldNotRelay(stoppedRoute.amount, failingHop) case _ => // otherwise the relay parameters may have changed, so it's not necessarily a liquidity issue } @@ -194,11 +193,15 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A res case res => res }) match { - case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == c.targetNodeId => + case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == recipient.nodeId => // if destination node returns an error, we fail the payment immediately log.warning(s"received an error message from target nodeId=$nodeId, failing the payment (failure=$failureMessage)") - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e)))) - case res if failures.size + 1 >= c.maxAttempts => + myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ RemoteFailure(request.amount, route.fullRoute, e)))) + case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if route.finalHop_opt.collect { case h: NodeHop if h.nodeId == nodeId => h }.nonEmpty => + // if trampoline node returns an error, we fail the payment immediately + log.warning(s"received an error message from trampoline nodeId=$nodeId, failing the payment (failure=$failureMessage)") + myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ RemoteFailure(request.amount, route.fullRoute, e)))) + case res if failures.size + 1 >= request.maxAttempts => // otherwise we never try more than maxAttempts, no matter the kind of error returned val failure = res match { case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => @@ -207,51 +210,51 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case failureMessage: Update => handleUpdate(nodeId, failureMessage, d) case _ => } - RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + RemoteFailure(request.amount, route.fullRoute, e) case Failure(t) => log.warning(s"cannot parse returned error ${fail.reason.toHex} with sharedSecrets=$sharedSecrets: ${t.getMessage}") - UnreadableRemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route)) + UnreadableRemoteFailure(request.amount, route.fullRoute) } log.warning(s"too many failed attempts, failing the payment") - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ failure))) + myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ failure))) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}, route=${route.printNodes()}") - val failure = UnreadableRemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route)) + val failure = UnreadableRemoteFailure(request.amount, route.fullRoute) retry(failure, d) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(request.amount, route.fullRoute, e) retry(failure, d) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(request.amount, route.fullRoute, e) if (Announcements.checkSig(failureMessage.update, nodeId)) { - val extraEdges1 = handleUpdate(nodeId, failureMessage, d) + val recipient1 = handleUpdate(nodeId, failureMessage, d) val ignore1 = PaymentFailure.updateIgnored(failure, ignore) // let's try again, router will have updated its state - c match { + request match { case _: SendPaymentToRoute => log.error("unexpected retry during SendPaymentToRoute") stop(FSM.Normal) - case c: SendPaymentToNode => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, extraEdges1, ignore1, c.routeParams, paymentContext = Some(cfg.paymentContext)) - goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, failures :+ failure, ignore1) + case request: SendPaymentToNode => + router ! RouteRequest(nodeParams.nodeId, recipient1, request.routeParams, ignore1, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request.copy(recipient = recipient1), failures :+ failure, ignore1) } } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - c match { + request match { case _: SendPaymentToRoute => log.error("unexpected retry during SendPaymentToRoute") stop(FSM.Normal) - case c: SendPaymentToNode => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, c.extraEdges, ignore + nodeId, c.routeParams, paymentContext = Some(cfg.paymentContext)) - goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, failures :+ failure, ignore + nodeId) + case request: SendPaymentToNode => + router ! RouteRequest(nodeParams.nodeId, recipient, request.routeParams, ignore + nodeId, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, failures :+ failure, ignore + nodeId) } } case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(request.amount, route.fullRoute, e) retry(failure, d) } } @@ -261,10 +264,10 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A * * @return updated routing hints if applicable. */ - private def handleUpdate(nodeId: PublicKey, failure: Update, data: WaitingForComplete): Seq[ExtraEdge] = { + private def handleUpdate(nodeId: PublicKey, failure: Update, data: WaitingForComplete): Recipient = { val extraEdges1 = data.route.hops.find(_.nodeId == nodeId) match { case Some(hop) => hop.params match { - case ann: ChannelRelayParams.FromAnnouncement => + case ann: HopRelayParams.FromAnnouncement => if (ann.channelUpdate.shortChannelId != failure.update.shortChannelId) { // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine log.info("received an update for a different channel than the one we asked: requested={} actual={} update={}", ann.channelUpdate.shortChannelId, failure.update.shortChannelId, failure.update) @@ -280,34 +283,37 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } else { log.info("got a new update for shortChannelId={}: old={} new={}", ann.channelUpdate.shortChannelId, ann.channelUpdate, failure.update) } - data.c.extraEdges - case _: ChannelRelayParams.FromHint => + data.recipient.extraEdges + case _: HopRelayParams.FromHint => log.info("received an update for a routing hint (shortChannelId={} nodeId={} enabled={} update={})", failure.update.shortChannelId, nodeId, failure.update.channelFlags.isEnabled, failure.update) if (failure.update.channelFlags.isEnabled) { - data.c.extraEdges.map { - case edge: BasicEdge if edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId => edge.update(failure.update) - case edge: BasicEdge => edge + data.recipient.extraEdges.map { + case edge: ExtraEdge if edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId => edge.update(failure.update) + case edge: ExtraEdge => edge } } else { // if the channel is disabled, we temporarily exclude it: this is necessary because the routing hint doesn't // contain channel flags to indicate that it's disabled // we want the exclusion to be router-wide so that sister payments in the case of MPP are aware the channel is faulty - data.c.extraEdges - .find { case edge: BasicEdge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId } - .foreach { case edge: BasicEdge => router ! ExcludeChannel(ChannelDesc(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId), Some(nodeParams.routerConf.channelExcludeDuration)) } + data.recipient.extraEdges + .find(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) + .foreach(edge => router ! ExcludeChannel(ChannelDesc(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId), Some(nodeParams.routerConf.channelExcludeDuration))) // we remove this edge for our next payment attempt - data.c.extraEdges.filterNot { case edge: BasicEdge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId } + data.recipient.extraEdges.filterNot(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) } } case None => log.error(s"couldn't find node=$nodeId in the route, this should never happen") - data.c.extraEdges + data.recipient.extraEdges } // in all cases, we forward the update to the router: if the channel is disabled, the router will remove it from its routing table // if the channel is not announced (e.g. was from a hint), the router will simply ignore the update router ! failure.update - // we return updated assisted routes: they take precedence over the router's routing table - extraEdges1 + // we update the recipient's assisted routes: they take precedence over the router's routing table + data.recipient match { + case recipient: ClearRecipient => recipient.copy(extraEdges = extraEdges1) + case recipient => recipient + } } def myStop(request: SendPayment, result: Either[PaymentFailed, PaymentSent]): State = { @@ -328,7 +334,11 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A val status = result match { case Right(_: PaymentSent) => "SUCCESS" case Left(f: PaymentFailed) => - if (f.failures.exists({ case r: RemoteFailure => r.e.originNode == cfg.recipientNodeId case _ => false })) { + val isRecipientFailure = f.failures.exists { + case r: RemoteFailure => r.e.originNode == request.recipient.nodeId + case _ => false + } + if (isRecipientFailure) { "RECIPIENT_FAILURE" } else { "FAILURE" @@ -351,7 +361,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A paymentSent.feesPaid + localFees case Left(paymentFailed) => val (fees, pathFindingExperiment) = request match { - case s: SendPaymentToNode => (s.routeParams.getMaxFee(cfg.recipientAmount), s.routeParams.experimentName) + case request: SendPaymentToNode => (request.routeParams.getMaxFee(request.amount), request.routeParams.experimentName) case _: SendPaymentToRoute => (0 msat, "n/a") } log.info(s"failed payment attempts details: ${PaymentFailure.jsonSummary(cfg, pathFindingExperiment, paymentFailed)}") @@ -359,7 +369,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } request match { case request: SendPaymentToNode => - context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, request.finalPayload.amount, fees, status, duration, now, isMultiPart = false, request.routeParams.experimentName, cfg.recipientNodeId, request.extraEdges)) + context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, request.amount, fees, status, duration, now, isMultiPart = false, request.routeParams.experimentName, request.recipient.nodeId, request.recipient.extraEdges)) case _: SendPaymentToRoute => () } } @@ -390,9 +400,8 @@ object PaymentLifecycle { sealed trait SendPayment { // @formatter:off def replyTo: ActorRef - def finalPayload: FinalPayload - def extraEdges: Seq[ExtraEdge] - def targetNodeId: PublicKey + def amount: MilliSatoshi + def recipient: Recipient def maxAttempts: Int // @formatter:on } @@ -400,53 +409,41 @@ object PaymentLifecycle { /** * Send a payment to a given route. * - * @param route payment route to use. - * @param finalPayload onion payload for the target node. + * @param route payment route to use. + * @param recipient final recipient. */ - case class SendPaymentToRoute(replyTo: ActorRef, - route: Either[PredefinedRoute, Route], - finalPayload: FinalPayload, - extraEdges: Seq[ExtraEdge] = Nil) extends SendPayment { - require(route.fold(!_.isEmpty, _.hops.nonEmpty), "payment route must not be empty") + case class SendPaymentToRoute(replyTo: ActorRef, route: Either[PredefinedRoute, Route], recipient: Recipient) extends SendPayment { + require(route.fold(r => !r.isEmpty, r => r.hops.nonEmpty || r.finalHop_opt.nonEmpty), "payment route must not be empty") - val targetNodeId: PublicKey = route.fold(_.targetNodeId, _.hops.last.nextNodeId) - - override def maxAttempts: Int = 1 + override val maxAttempts: Int = 1 + override val amount = route.fold(_.amount, _.amount) def printRoute(): String = route match { - case Left(PredefinedChannelRoute(_, channels)) => channels.mkString("->") - case Left(PredefinedNodeRoute(nodes)) => nodes.mkString("->") - case Right(route) => route.hops.map(_.nextNodeId).mkString("->") + case Left(PredefinedChannelRoute(_, _, channels)) => channels.mkString("->") + case Left(PredefinedNodeRoute(_, nodes)) => nodes.mkString("->") + case Right(route) => route.printNodes() } } /** * Send a payment to a given node. A path-finding algorithm will run to find a suitable payment route. * - * @param targetNodeId target node (may be the final recipient when using source-routing, or the first trampoline - * node when using trampoline). - * @param finalPayload onion payload for the target node. - * @param maxAttempts maximum number of retries. - * @param extraEdges routing hints (usually from a Bolt 11 invoice). - * @param routeParams parameters to fine-tune the routing algorithm. + * @param recipient final recipient. + * @param maxAttempts maximum number of retries. + * @param routeParams parameters to fine-tune the routing algorithm. */ - case class SendPaymentToNode(replyTo: ActorRef, - targetNodeId: PublicKey, - finalPayload: FinalPayload, - maxAttempts: Int, - extraEdges: Seq[ExtraEdge] = Nil, - routeParams: RouteParams) extends SendPayment { - require(finalPayload.amount > 0.msat, s"amount must be > 0") - - val maxFee: MilliSatoshi = routeParams.getMaxFee(finalPayload.amount) - + case class SendPaymentToNode(replyTo: ActorRef, recipient: Recipient, maxAttempts: Int, routeParams: RouteParams) extends SendPayment { + require(recipient.totalAmount > 0.msat, "amount must be > 0") + override val amount = recipient.totalAmount } // @formatter:off sealed trait Data case object WaitingForRequest extends Data - case class WaitingForRoute(c: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data - case class WaitingForComplete(c: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data + case class WaitingForRoute(request: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data + case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data { + val recipient = request.recipient + } sealed trait State case object WAITING_FOR_REQUEST extends State diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala new file mode 100644 index 000000000..ed0690c40 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala @@ -0,0 +1,170 @@ +/* + * Copyright 2022 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.payment.send + +import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.payment.Invoice.ExtraEdge +import fr.acinq.eclair.payment.OutgoingPaymentPacket._ +import fr.acinq.eclair.payment.{Bolt11Invoice, OutgoingPaymentPacket} +import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop, Route} +import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} +import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket, PaymentOnionCodecs} +import fr.acinq.eclair.{CltvExpiry, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId} +import scodec.bits.ByteVector + +/** + * Created by t-bast on 28/10/2022. + */ + +sealed trait Recipient { + /** Id of the final receiving node. */ + def nodeId: PublicKey + + /** Total amount that will be received by the final receiving node. */ + def totalAmount: MilliSatoshi + + /** CLTV expiry that will be received by the final receiving node. */ + def expiry: CltvExpiry + + /** Features supported by the recipient. */ + def features: Features[InvoiceFeature] + + /** Edges that aren't part of the public graph and can be used to reach the recipient. */ + def extraEdges: Seq[ExtraEdge] + + /** Build a payment to the recipient using the route provided. */ + def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] +} + +object Recipient { + /** Iteratively build all the payloads for a payment relayed through channel hops. */ + def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, finalPayload: NodePayload, hops: Seq[ChannelHop]): PaymentPayloads = { + // We ignore the first hop since the route starts at our node. + hops.tail.foldRight(PaymentPayloads(finalAmount, finalExpiry, Seq(finalPayload))) { + case (hop, current) => + val payload = NodePayload(hop.nodeId, IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, current.amount, current.expiry)) + PaymentPayloads(current.amount + hop.fee(current.amount), current.expiry + hop.cltvExpiryDelta, payload +: current.payloads) + } + } +} + +/** A payment recipient that can directly be found in the routing graph. */ +case class ClearRecipient(nodeId: PublicKey, + features: Features[InvoiceFeature], + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + paymentSecret: ByteVector32, + extraEdges: Seq[ExtraEdge] = Nil, + paymentMetadata_opt: Option[ByteVector] = None, + nextTrampolineOnion_opt: Option[OnionRoutingPacket] = None, + customTlvs: Seq[GenericTlv] = Nil) extends Recipient { + override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { + ClearRecipient.validateRoute(nodeId, route).map(_ => { + val finalPayload = nextTrampolineOnion_opt match { + case Some(trampolinePacket) => NodePayload(nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, totalAmount, expiry, paymentSecret, trampolinePacket)) + case None => NodePayload(nodeId, FinalPayload.Standard.createPayload(route.amount, totalAmount, expiry, paymentSecret, paymentMetadata_opt, customTlvs)) + } + Recipient.buildPayloads(route.amount, expiry, finalPayload, route.hops) + }) + } +} + +object ClearRecipient { + def apply(invoice: Bolt11Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Seq[GenericTlv]): ClearRecipient = { + ClearRecipient(invoice.nodeId, invoice.features, totalAmount, expiry, invoice.paymentSecret, invoice.extraEdges, invoice.paymentMetadata, None, customTlvs) + } + + def validateRoute(nodeId: PublicKey, route: Route): Either[OutgoingPaymentError, Route] = { + route.hops.lastOption match { + case Some(hop) if hop.nextNodeId == nodeId => Right(route) + case Some(hop) => Left(InvalidRouteRecipient(nodeId, hop.nextNodeId)) + case None => Left(EmptyRoute) + } + } +} + +/** A payment recipient that doesn't expect to receive a payment and can directly be found in the routing graph. */ +case class SpontaneousRecipient(nodeId: PublicKey, + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + preimage: ByteVector32, + customTlvs: Seq[GenericTlv] = Nil) extends Recipient { + override val features = Features.empty + override val extraEdges = Nil + + override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { + ClearRecipient.validateRoute(nodeId, route).map(_ => { + val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createKeySendPayload(route.amount, totalAmount, expiry, preimage, customTlvs)) + Recipient.buildPayloads(totalAmount, expiry, finalPayload, route.hops) + }) + } +} + +/** A payment recipient that can be reached through a given trampoline node (usually not found in the routing graph). */ +case class ClearTrampolineRecipient(invoice: Bolt11Invoice, + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + trampolineHop: NodeHop, + trampolinePaymentSecret: ByteVector32, + customTlvs: Seq[GenericTlv] = Nil) extends Recipient { + require(trampolineHop.nextNodeId == invoice.nodeId, "trampoline hop must end at the recipient") + + val trampolineNodeId = trampolineHop.nodeId + val trampolineFee = trampolineHop.fee(totalAmount) + val trampolineAmount = totalAmount + trampolineFee + val trampolineExpiry = expiry + trampolineHop.cltvExpiryDelta + + override val nodeId = invoice.nodeId + override val features = invoice.features + override val extraEdges = Seq(ExtraEdge(trampolineNodeId, nodeId, ShortChannelId.generateLocalAlias(), trampolineFee, 0, trampolineHop.cltvExpiryDelta, 1 msat, None)) + + private def validateRoute(route: Route): Either[OutgoingPaymentError, NodeHop] = { + route.finalHop_opt match { + case Some(trampolineHop: NodeHop) => Right(trampolineHop) + case None => Left(MissingTrampolineHop(trampolineNodeId)) + } + } + + override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { + for { + trampolineHop <- validateRoute(route) + trampolineOnion <- createTrampolinePacket(paymentHash, trampolineHop) + } yield { + val trampolinePayload = NodePayload(trampolineHop.nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet)) + Recipient.buildPayloads(route.amount, trampolineExpiry, trampolinePayload, route.hops) + } + } + + private def createTrampolinePacket(paymentHash: ByteVector32, trampolineHop: NodeHop): Either[OutgoingPaymentError, Sphinx.PacketAndSecrets] = { + if (invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) { + // This is the payload the final recipient will receive, so we use the invoice's payment secret. + val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createPayload(totalAmount, totalAmount, expiry, invoice.paymentSecret, invoice.paymentMetadata, customTlvs)) + val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard(totalAmount, expiry, nodeId)) + val payloads = Seq(trampolinePayload, finalPayload) + OutgoingPaymentPacket.buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, payloads, paymentHash) + } else { + // The recipient doesn't support trampoline: the trampoline node will convert the payment to a non-trampoline payment. + // The final payload will thus never reach the recipient, so we create the smallest payload possible to avoid overflowing the trampoline onion size. + val dummyFinalPayload = NodePayload(nodeId, IntermediatePayload.ChannelRelay.Standard(ShortChannelId(0), 0 msat, CltvExpiry(0))) + val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(totalAmount, totalAmount, expiry, nodeId, invoice)) + val payloads = Seq(trampolinePayload, dummyFinalPayload) + OutgoingPaymentPacket.buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, payloads, paymentHash) + } + } +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index 554a25dbe..a6dfa4ac5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -433,7 +433,7 @@ object Graph { * @param capacity channel capacity * @param balance_opt (optional) available balance that can be sent through this edge */ - case class GraphEdge private(desc: ChannelDesc, params: ChannelRelayParams, capacity: Satoshi, balance_opt: Option[MilliSatoshi]) { + case class GraphEdge private(desc: ChannelDesc, params: HopRelayParams, capacity: Satoshi, balance_opt: Option[MilliSatoshi]) { def maxHtlcAmount(reservedCapacity: MilliSatoshi): MilliSatoshi = Seq( balance_opt.map(balance => balance - reservedCapacity), @@ -447,28 +447,26 @@ object Graph { object GraphEdge { def apply(u: ChannelUpdate, pc: PublicChannel): GraphEdge = GraphEdge( desc = ChannelDesc(u, pc.ann), - params = ChannelRelayParams.FromAnnouncement(u), + params = HopRelayParams.FromAnnouncement(u), capacity = pc.capacity, balance_opt = pc.getBalanceSameSideAs(u) ) def apply(u: ChannelUpdate, pc: PrivateChannel): GraphEdge = GraphEdge( desc = ChannelDesc(u, pc), - params = ChannelRelayParams.FromAnnouncement(u), + params = HopRelayParams.FromAnnouncement(u), capacity = pc.capacity, balance_opt = pc.getBalanceSameSideAs(u) ) - def apply(e: Invoice.ExtraEdge): GraphEdge = e match { - case e@Invoice.BasicEdge(sourceNodeId, targetNodeId, shortChannelId, _, _, _) => + def apply(e: Invoice.ExtraEdge): GraphEdge = { val maxBtc = 21e6.btc GraphEdge( - desc = ChannelDesc(shortChannelId, sourceNodeId, targetNodeId), - params = ChannelRelayParams.FromHint(e), - // Bolt 11 routing hints don't include the channel's capacity, so we assume it's big enough + desc = ChannelDesc(e.shortChannelId, e.sourceNodeId, e.targetNodeId), + params = HopRelayParams.FromHint(e), + // Routing hints don't include the channel's capacity, so we assume it's big enough. capacity = maxBtc.toSatoshi, - // we assume channels provided as hints have enough balance to handle the payment - balance_opt = Some(maxBtc.toMilliSatoshi) + balance_opt = None, ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index e2b98294f..b8de39d65 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,6 +22,7 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ +import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, Recipient, SpontaneousRecipient} import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} @@ -56,22 +57,23 @@ object RouteCalculation { paymentHash_opt = fr.paymentContext.map(_.paymentHash))) { implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val g = fr.extraEdges.map(GraphEdge(_)).foldLeft(d.graphWithBalances.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } + val extraEdges = fr.extraEdges.map(GraphEdge(_)) + val g = extraEdges.foldLeft(d.graphWithBalances.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } fr.route match { - case PredefinedNodeRoute(hops) => + case PredefinedNodeRoute(amount, hops) => // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => // select the largest edge (using balance when available, otherwise capacity). val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) val hops = selectedEdges.map(e => ChannelHop(getEdgeRelayScid(d, localNodeId, e), e.desc.a, e.desc.b, e.params)) - ctx.sender() ! RouteResponse(Route(fr.amount, hops) :: Nil) + ctx.sender() ! RouteResponse(Route(amount, hops, None) :: Nil) case _ => // some nodes in the supplied route aren't connected in our graph ctx.sender() ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) } - case PredefinedChannelRoute(targetNodeId, shortChannelIds) => + case PredefinedChannelRoute(amount, targetNodeId, shortChannelIds) => val (end, hops) = shortChannelIds.foldLeft((localNodeId, Seq.empty[ChannelHop])) { case ((currentNode, previousHops), shortChannelId) => val channelDesc_opt = d.resolve(shortChannelId) match { @@ -85,7 +87,7 @@ object RouteCalculation { case c.nodeId2 => Some(ChannelDesc(c.shortIds.localAlias, c.nodeId2, c.nodeId1)) case _ => None } - case None => fr.extraEdges.map(GraphEdge(_)).find(e => e.desc.shortChannelId == shortChannelId && e.desc.a == currentNode).map(_.desc) + case None => extraEdges.find(e => e.desc.shortChannelId == shortChannelId && e.desc.a == currentNode).map(_.desc) } channelDesc_opt.flatMap(c => g.getEdge(c)) match { case Some(edge) => (edge.desc.b, previousHops :+ ChannelHop(getEdgeRelayScid(d, localNodeId, edge), edge.desc.a, edge.desc.b, edge.params)) @@ -95,7 +97,7 @@ object RouteCalculation { if (end != targetNodeId || hops.length != shortChannelIds.length) { ctx.sender() ! Status.Failure(new IllegalArgumentException("The sequence of channels provided cannot be used to build a route to the target node")) } else { - ctx.sender() ! RouteResponse(Route(fr.amount, hops) :: Nil) + ctx.sender() ! RouteResponse(Route(amount, hops, None) :: Nil) } } @@ -103,6 +105,56 @@ object RouteCalculation { } } + /** + * Based on the type of recipient for the payment, this function returns: + * - the node to which routes should be found + * - the amount that should be sent to that node + * - the maximum allowed fee for routes to that node + * - an optional set of additional graph edges + * + * The routes found must then be post-processed by calling [[addFinalHop]]. + */ + private def computeTarget(r: RouteRequest, ignoredEdges: Set[ChannelDesc]): (PublicKey, MilliSatoshi, MilliSatoshi, Set[GraphEdge]) = { + val pendingAmount = r.pendingPayments.map(_.amount).sum + val totalMaxFee = r.routeParams.getMaxFee(r.target.totalAmount) + val pendingChannelFee = r.pendingPayments.map(_.channelFee(r.routeParams.includeLocalChannelCost)).sum + r.target match { + case recipient: ClearRecipient => + val targetNodeId = recipient.nodeId + val amountToSend = recipient.totalAmount - pendingAmount + val maxFee = totalMaxFee - pendingChannelFee + val extraEdges = recipient.extraEdges + .filter(_.sourceNodeId != r.source) // we ignore routing hints for our own channels, we have more accurate information + .map(GraphEdge(_)) + .filterNot(e => ignoredEdges.contains(e.desc)) + .toSet + (targetNodeId, amountToSend, maxFee, extraEdges) + case recipient: SpontaneousRecipient => + val targetNodeId = recipient.nodeId + val amountToSend = recipient.totalAmount - pendingAmount + val maxFee = totalMaxFee - pendingChannelFee + (targetNodeId, amountToSend, maxFee, Set.empty) + case recipient: ClearTrampolineRecipient => + // Trampoline payments require finding routes to the trampoline node, not the final recipient. + // This also ensures that we correctly take the trampoline fee into account only once, even when using MPP to + // reach the trampoline node (which will aggregate the incoming MPP payment and re-split as necessary). + val targetNodeId = recipient.trampolineHop.nodeId + val amountToSend = recipient.trampolineAmount - pendingAmount + val maxFee = totalMaxFee - pendingChannelFee - recipient.trampolineFee + (targetNodeId, amountToSend, maxFee, Set.empty) + } + } + + private def addFinalHop(recipient: Recipient, routes: Seq[Route]): Seq[Route] = { + routes.map(route => { + recipient match { + case _: ClearRecipient => route + case _: SpontaneousRecipient => route + case recipient: ClearTrampolineRecipient => route.copy(finalHop_opt = Some(recipient.trampolineHop)) + } + }) + } + def handleRouteRequest(d: Data, currentBlockHeight: BlockHeight, r: RouteRequest)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { Logs.withMdc(log)(Logs.mdc( category_opt = Some(LogCategory.PAYMENT), @@ -111,25 +163,27 @@ object RouteCalculation { paymentHash_opt = r.paymentContext.map(_.paymentHash))) { implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val extraEdges = r.extraEdges.map(GraphEdge(_)).filterNot(_.desc.a == r.source).toSet // we ignore routing hints for our own channels, we have more accurate information val ignoredEdges = r.ignore.channels ++ d.excludedChannels.keySet - val params = r.routeParams - val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 + val (targetNodeId, amountToSend, maxFee, extraEdges) = computeTarget(r, ignoredEdges) + val routesToFind = if (r.routeParams.randomize) DEFAULT_ROUTES_COUNT else 1 - log.info(s"finding routes ${r.source}->${r.target} with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", extraEdges.map(_.desc.shortChannelId).mkString(","), r.ignore.nodes.map(_.value).mkString(","), r.ignore.channels.mkString(","), d.excludedChannels.mkString(",")) - log.info("finding routes with params={}, multiPart={}", params, r.allowMultiPart) - log.info("local channels to recipient: {}", d.graphWithBalances.graph.getEdgesBetween(r.source, r.target).map(e => s"${e.desc.shortChannelId} (${e.balance_opt}/${e.capacity})").mkString(", ")) - val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(r.amount)) + log.info(s"finding routes ${r.source}->$targetNodeId with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", extraEdges.map(_.desc.shortChannelId).mkString(","), r.ignore.nodes.map(_.value).mkString(","), r.ignore.channels.mkString(","), d.excludedChannels.mkString(",")) + log.info("finding routes with params={}, multiPart={}", r.routeParams, r.allowMultiPart) + log.info("local channels to target node: {}", d.graphWithBalances.graph.getEdgesBetween(r.source, targetNodeId).map(e => s"${e.desc.shortChannelId} (${e.balance_opt}/${e.capacity})").mkString(", ")) + val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(amountToSend)) KamonExt.time(Metrics.FindRouteDuration.withTags(tags.withTag(Tags.NumberOfRoutes, routesToFind.toLong))) { val result = if (r.allowMultiPart) { - findMultiPartRoute(d.graphWithBalances.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight) + findMultiPartRoute(d.graphWithBalances.graph, r.source, targetNodeId, amountToSend, maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, r.routeParams, currentBlockHeight) } else { - findRoute(d.graphWithBalances.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight) + findRoute(d.graphWithBalances.graph, r.source, targetNodeId, amountToSend, maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, r.routeParams, currentBlockHeight) } - result match { + result.map(routes => addFinalHop(r.target, routes)) match { case Success(routes) => + // Note that we don't record the length of the whole route: we ignore the trampoline hop because we only + // care about the part that we found ourselves (and we don't even know the length that will be used between + // trampoline nodes). Metrics.RouteResults.withTags(tags).record(routes.length) - routes.foreach(route => Metrics.RouteLength.withTags(tags).record(route.length)) + routes.foreach(route => Metrics.RouteLength.withTags(tags).record(route.hops.length)) ctx.sender() ! RouteResponse(routes) case Failure(failure: InfiniteLoop) => log.error(s"found infinite loop ${failure.path.map(edge => edge.desc).mkString(" -> ")}") @@ -140,7 +194,7 @@ object RouteCalculation { Metrics.FindRouteErrors.withTags(tags.withTag(Tags.Error, "NegativeProbability")).increment() ctx.sender() ! Status.Failure(failure) case Failure(t) => - val failure = if (isNeighborBalanceTooLow(d.graphWithBalances.graph, r)) BalanceTooLow else t + val failure = if (isNeighborBalanceTooLow(d.graphWithBalances.graph, r.source, targetNodeId, amountToSend)) BalanceTooLow else t Metrics.FindRouteErrors.withTags(tags.withTag(Tags.Error, failure.getClass.getSimpleName)).increment() ctx.sender() ! Status.Failure(failure) } @@ -199,7 +253,7 @@ object RouteCalculation { routeParams: RouteParams, currentBlockHeight: BlockHeight): Try[Seq[Route]] = Try { findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams, currentBlockHeight) match { - case Right(routes) => routes.map(route => Route(amount, route.path.map(graphEdgeToHop))) + case Right(routes) => routes.map(route => Route(amount, route.path.map(graphEdgeToHop), None)) case Left(ex) => return Failure(ex) } } @@ -367,7 +421,7 @@ object RouteCalculation { val edgeMaxAmount = edge.maxHtlcAmount(usedCapacity.getOrElse(edge.desc.shortChannelId, 0 msat)) amountMinusFees.min(edgeMaxAmount) } - Route(amount.max(0 msat), route.map(graphEdgeToHop)) + Route(amount.max(0 msat), route.map(graphEdgeToHop), None) } /** Initialize known used capacity based on pending HTLCs. */ @@ -381,7 +435,7 @@ object RouteCalculation { /** Update used capacity by taking into account an HTLC sent to the given route. */ private def updateUsedCapacity(route: Route, usedCapacity: mutable.Map[ShortChannelId, MilliSatoshi]): Unit = { - route.hops.reverse.foldLeft(route.amount) { case (amount, hop) => + route.hops.foldRight(route.amount) { case (hop, amount) => usedCapacity.updateWith(hop.shortChannelId)(previous => Some(amount + previous.getOrElse(0 msat))) amount + hop.fee(amount) } @@ -389,7 +443,7 @@ object RouteCalculation { private def validateMultiPartRoute(amount: MilliSatoshi, maxFee: MilliSatoshi, routes: Seq[Route], includeLocalChannelCost: Boolean): Boolean = { val amountOk = routes.map(_.amount).sum == amount - val feeOk = routes.map(_.fee(includeLocalChannelCost)).sum <= maxFee + val feeOk = routes.map(_.channelFee(includeLocalChannelCost)).sum <= maxFee amountOk && feeOk } @@ -398,9 +452,9 @@ object RouteCalculation { * requested amount. We could potentially relay the payment by using indirect routes, but since we're connected to * the target node it means we'd like to reach it via direct channels as much as possible. */ - private def isNeighborBalanceTooLow(g: DirectedGraph, r: RouteRequest): Boolean = { - val neighborEdges = g.getEdgesBetween(r.source, r.target) - neighborEdges.nonEmpty && neighborEdges.map(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi)).sum < r.amount + private def isNeighborBalanceTooLow(g: DirectedGraph, source: PublicKey, target: PublicKey, amount: MilliSatoshi): Boolean = { + val neighborEdges = g.getEdgesBetween(source, target) + neighborEdges.nonEmpty && neighborEdges.map(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi)).sum < amount } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 8eabb9ac6..56d3970ab 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -33,6 +33,7 @@ import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.relay.Relayer +import fr.acinq.eclair.payment.send.Recipient import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph @@ -428,7 +429,7 @@ object Router { } // @formatter:on - trait Hop { + sealed trait Hop { /** @return the id of the start node. */ def nodeId: PublicKey @@ -445,54 +446,57 @@ object Router { def cltvExpiryDelta: CltvExpiryDelta } - // @formatter:off - /** Channel routing parameters */ - sealed trait ChannelRelayParams { + /** Routing parameters for relaying payments over a given hop. */ + sealed trait HopRelayParams { + // @formatter:off def cltvExpiryDelta: CltvExpiryDelta def relayFees: Relayer.RelayFees final def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(relayFees, amount) def htlcMinimum: MilliSatoshi def htlcMaximum_opt: Option[MilliSatoshi] + // @formatter:on } - object ChannelRelayParams { - /** We learnt about this channel from a channel_update */ - case class FromAnnouncement(channelUpdate: ChannelUpdate) extends ChannelRelayParams { - override def cltvExpiryDelta: CltvExpiryDelta = channelUpdate.cltvExpiryDelta - override def relayFees: Relayer.RelayFees = channelUpdate.relayFees - override def htlcMinimum: MilliSatoshi = channelUpdate.htlcMinimumMsat - override def htlcMaximum_opt: Option[MilliSatoshi] = Some(channelUpdate.htlcMaximumMsat) - } - /** We learnt about this channel from hints in an invoice */ - case class FromHint(extraHop: Invoice.ExtraEdge) extends ChannelRelayParams { - override def cltvExpiryDelta: CltvExpiryDelta = extraHop.cltvExpiryDelta - override def relayFees: Relayer.RelayFees = extraHop.relayFees - override def htlcMinimum: MilliSatoshi = extraHop.htlcMinimum - override def htlcMaximum_opt: Option[MilliSatoshi] = extraHop.htlcMaximum_opt + object HopRelayParams { + /** We learnt about this channel from a channel_update. */ + case class FromAnnouncement(channelUpdate: ChannelUpdate) extends HopRelayParams { + override val cltvExpiryDelta = channelUpdate.cltvExpiryDelta + override val relayFees = channelUpdate.relayFees + override val htlcMinimum = channelUpdate.htlcMinimumMsat + override val htlcMaximum_opt = Some(channelUpdate.htlcMaximumMsat) } - def areSame(a: ChannelRelayParams, b: ChannelRelayParams, ignoreHtlcSize: Boolean = false): Boolean = + /** We learnt about this hop from hints in an invoice. */ + case class FromHint(extraHop: Invoice.ExtraEdge) extends HopRelayParams { + override val cltvExpiryDelta = extraHop.cltvExpiryDelta + override val relayFees = extraHop.relayFees + override val htlcMinimum = extraHop.htlcMinimum + override val htlcMaximum_opt = extraHop.htlcMaximum_opt + } + + def areSame(a: HopRelayParams, b: HopRelayParams, ignoreHtlcSize: Boolean = false): Boolean = a.cltvExpiryDelta == b.cltvExpiryDelta && a.relayFees == b.relayFees && (ignoreHtlcSize || (a.htlcMinimum == b.htlcMinimum && a.htlcMaximum_opt == b.htlcMaximum_opt)) } - // @formatter:on /** - * A directed hop between two connected nodes using a specific channel. + * A directed hop between two nodes connected by a channel. * + * @param shortChannelId scid of the channel. * @param nodeId id of the start node. * @param nextNodeId id of the end node. - * @param shortChannelId scid that will be used to build the payment onion. * @param params source for the channel parameters. */ - case class ChannelHop(shortChannelId: ShortChannelId, nodeId: PublicKey, nextNodeId: PublicKey, params: ChannelRelayParams) extends Hop { + case class ChannelHop(shortChannelId: ShortChannelId, nodeId: PublicKey, nextNodeId: PublicKey, params: HopRelayParams) extends Hop { // @formatter:off - override def cltvExpiryDelta: CltvExpiryDelta = params.cltvExpiryDelta + override val cltvExpiryDelta = params.cltvExpiryDelta override def fee(amount: MilliSatoshi): MilliSatoshi = params.fee(amount) // @formatter:on } + sealed trait FinalHop extends Hop + /** * A directed hop between two trampoline nodes. * These nodes need not be connected and we don't need to know a route between them. @@ -503,7 +507,7 @@ object Router { * @param cltvExpiryDelta cltv expiry delta. * @param fee total fee for that hop. */ - case class NodeHop(nodeId: PublicKey, nextNodeId: PublicKey, cltvExpiryDelta: CltvExpiryDelta, fee: MilliSatoshi) extends Hop { + case class NodeHop(nodeId: PublicKey, nextNodeId: PublicKey, cltvExpiryDelta: CltvExpiryDelta, fee: MilliSatoshi) extends FinalHop { override def fee(amount: MilliSatoshi): MilliSatoshi = fee } @@ -536,18 +540,14 @@ object Router { } case class RouteRequest(source: PublicKey, - target: PublicKey, - amount: MilliSatoshi, - maxFee: MilliSatoshi, - extraEdges: Seq[ExtraEdge] = Nil, - ignore: Ignore = Ignore.empty, + target: Recipient, routeParams: RouteParams, + ignore: Ignore = Ignore.empty, allowMultiPart: Boolean = false, pendingPayments: Seq[Route] = Nil, paymentContext: Option[PaymentContext] = None) - case class FinalizeRoute(amount: MilliSatoshi, - route: PredefinedRoute, + case class FinalizeRoute(route: PredefinedRoute, extraEdges: Seq[ExtraEdge] = Nil, paymentContext: Option[PaymentContext] = None) @@ -556,14 +556,22 @@ object Router { */ case class PaymentContext(id: UUID, parentId: UUID, paymentHash: ByteVector32) - case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop]) { - require(hops.nonEmpty, "route cannot be empty") + case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop], finalHop_opt: Option[FinalHop]) { + require(hops.nonEmpty || finalHop_opt.nonEmpty, "route cannot be empty") - val length = hops.length + /** Full route including the final hop, if any. */ + val fullRoute: Seq[Hop] = hops ++ finalHop_opt.toSeq - def fee(includeLocalChannelCost: Boolean): MilliSatoshi = { + /** + * Fee paid for the trampoline hop, if any. + * Note that when using MPP to reach the trampoline node, the trampoline fee must be counted only once. + */ + val trampolineFee: MilliSatoshi = finalHop_opt.collect { case hop: NodeHop => hop.fee(amount) }.getOrElse(0 msat) + + /** Fee paid for the channel hops towards the recipient or the source of the final hop, if any. */ + def channelFee(includeLocalChannelCost: Boolean): MilliSatoshi = { val hopsToPay = if (includeLocalChannelCost) hops else hops.drop(1) - val amountToSend = hopsToPay.reverse.foldLeft(amount) { case (amount1, hop) => amount1 + hop.fee(amount1) } + val amountToSend = hopsToPay.foldRight(amount) { case (hop, amount1) => amount1 + hop.fee(amount1) } amountToSend - amount } @@ -573,7 +581,7 @@ object Router { def stopAt(nodeId: PublicKey): Route = { val amountAtStop = hops.reverse.takeWhile(_.nextNodeId != nodeId).foldLeft(amount) { case (amount1, hop) => amount1 + hop.fee(amount1) } - Route(amountAtStop, hops.takeWhile(_.nodeId != nodeId)) + Route(amountAtStop, hops.takeWhile(_.nodeId != nodeId), None) } } @@ -585,13 +593,14 @@ object Router { /** A pre-defined route chosen outside of eclair (e.g. manually by a user to do some re-balancing). */ sealed trait PredefinedRoute { def isEmpty: Boolean + def amount: MilliSatoshi def targetNodeId: PublicKey } - case class PredefinedNodeRoute(nodes: Seq[PublicKey]) extends PredefinedRoute { + case class PredefinedNodeRoute(amount: MilliSatoshi, nodes: Seq[PublicKey]) extends PredefinedRoute { override def isEmpty = nodes.isEmpty override def targetNodeId: PublicKey = nodes.last } - case class PredefinedChannelRoute(targetNodeId: PublicKey, channels: Seq[ShortChannelId]) extends PredefinedRoute { + case class PredefinedChannelRoute(amount: MilliSatoshi, targetNodeId: PublicKey, channels: Seq[ShortChannelId]) extends PredefinedRoute { override def isEmpty = channels.isEmpty } // @formatter:on @@ -616,7 +625,6 @@ object Router { case object RoutingStateStreamingUpToDate extends RemoteTypes case object GetRouterData case object GetNodes - case object GetLocalChannels case object GetChannels case object GetChannelsMap case object GetChannelUpdates diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 78cce6eb1..eca19fb94 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -375,24 +375,24 @@ object PaymentOnion { Right(Standard(records)) } - def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector], userCustomTlvs: Seq[GenericTlv] = Nil): Standard = { - val tlvs = Seq( - Some(AmountToForward(amount)), - Some(OutgoingCltv(expiry)), - Some(PaymentData(paymentSecret, amount)), - paymentMetadata.map(m => PaymentMetadata(m)) - ).flatten - Standard(TlvStream(tlvs, userCustomTlvs)) - } - - def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector], additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): Standard = { + def createPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector] = None, customTlvs: Seq[GenericTlv] = Nil): Standard = { val tlvs = Seq( Some(AmountToForward(amount)), Some(OutgoingCltv(expiry)), Some(PaymentData(paymentSecret, totalAmount)), paymentMetadata.map(m => PaymentMetadata(m)) ).flatten - Standard(TlvStream(tlvs ++ additionalTlvs, userCustomTlvs)) + Standard(TlvStream(tlvs, customTlvs)) + } + + def createKeySendPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, preimage: ByteVector32, customTlvs: Seq[GenericTlv] = Nil): Standard = { + val tlvs = Seq( + AmountToForward(amount), + OutgoingCltv(expiry), + PaymentData(preimage, totalAmount), + KeySend(preimage) + ) + Standard(TlvStream(tlvs, customTlvs)) } /** Create a trampoline outer payload. */ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 0fd137154..56b2303d8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -308,13 +308,12 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I import f._ val eclair = new EclairImpl(kit) - val route = PredefinedNodeRoute(Seq(randomKey().publicKey)) - val trampolines = Seq(randomKey().publicKey, randomKey().publicKey) + val route = PredefinedNodeRoute(1000 msat, Seq(randomKey().publicKey)) val parentId = UUID.randomUUID() val secret = randomBytes32() val pr = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(1234 msat), ByteVector32.One, randomKey(), Right(randomBytes32()), CltvExpiryDelta(18)) - eclair.sendToRoute(1000 msat, Some(1200 msat), Some("42"), Some(parentId), pr, route, Some(secret), Some(100 msat), Some(CltvExpiryDelta(144)), trampolines) - paymentInitiator.expectMsg(SendPaymentToRoute(1000 msat, 1200 msat, pr, route, Some("42"), Some(parentId), Some(secret), 100 msat, CltvExpiryDelta(144), trampolines)) + eclair.sendToRoute(Some(1200 msat), Some("42"), Some(parentId), pr, route, Some(secret), Some(100 msat), Some(CltvExpiryDelta(144))) + paymentInitiator.expectMsg(SendPaymentToRoute(1200 msat, pr, route, Some("42"), Some(parentId), Some(TrampolineAttempt(secret, 100 msat, CltvExpiryDelta(144))))) } test("call sendWithPreimage, which generates a random preimage, to perform a KeySend payment") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index b72db19e6..d06c59de4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -20,7 +20,6 @@ import akka.actor.typed.scaladsl.adapter.actorRefAdapter import akka.actor.{Actor, ActorLogging, ActorRef, Props} import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.ByteVector32 -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.DummyOnChainWallet @@ -34,7 +33,7 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveStandardPayment import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.relay.Relayer -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.payment.send.ClearRecipient import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -118,19 +117,20 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe // we don't want to be below htlcMinimumMsat val requiredAmount = 1000000 msat - def buildCmdAdd(paymentHash: ByteVector32, dest: PublicKey, paymentSecret: ByteVector32): CMD_ADD_HTLC = { + def buildCmdAdd(invoice: Bolt11Invoice): CMD_ADD_HTLC = { // allow overpaying (no more than 2 times the required amount) val amount = requiredAmount + Random.nextInt(requiredAmount.toLong.toInt).msat val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry(currentBlockHeight = BlockHeight(400000)) - OutgoingPaymentPacket.buildCommand(self, Upstream.Local(UUID.randomUUID()), paymentHash, ChannelHop(null, null, dest, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, paymentSecret, None)).get._1 + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(self, Upstream.Local(UUID.randomUUID()), invoice.paymentHash, makeSingleHopRoute(amount, invoice.nodeId), ClearRecipient(invoice, amount, expiry, Nil)) + payment.cmd } def initiatePaymentOrStop(remaining: Int): Unit = if (remaining > 0) { paymentHandler ! ReceiveStandardPayment(Some(requiredAmount), Left("One coffee")) context become { - case req: Invoice => - sendChannel ! buildCmdAdd(req.paymentHash, req.nodeId, req.paymentSecret) + case invoice: Bolt11Invoice => + sendChannel ! buildCmdAdd(invoice) context become { case RES_SUCCESS(_: CMD_ADD_HTLC, _) => () case RES_ADD_SETTLED(_, htlc, _: HtlcResult.Fulfill) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 0445aff7d..10bcc52e1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -33,9 +33,10 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.publish.TxPublisher import fr.acinq.eclair.channel.publish.TxPublisher.PublishReplaceableTx import fr.acinq.eclair.channel.states.ChannelStateTestsBase.FakeTxPublisherFactory -import fr.acinq.eclair.payment.OutgoingPaymentPacket import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.payment.send.SpontaneousRecipient +import fr.acinq.eclair.payment.{Invoice, OutgoingPaymentPacket} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, Route} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol._ @@ -355,10 +356,16 @@ trait ChannelStateTestsBase extends Assertions with Eventually { } def makeCmdAdd(amount: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta, destination: PublicKey, paymentPreimage: ByteVector32, currentBlockHeight: BlockHeight, upstream: Upstream, replyTo: ActorRef = TestProbe().ref): (ByteVector32, CMD_ADD_HTLC) = { - val paymentHash: ByteVector32 = Crypto.sha256(paymentPreimage) + val paymentHash = Crypto.sha256(paymentPreimage) val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight) - val cmd = OutgoingPaymentPacket.buildCommand(replyTo, upstream, paymentHash, ChannelHop(null, null, destination, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), None)).get._1.copy(commit = false) - (paymentPreimage, cmd) + val recipient = SpontaneousRecipient(destination, amount, expiry, paymentPreimage) + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(replyTo, upstream, paymentHash, makeSingleHopRoute(amount, destination), recipient) + (paymentPreimage, payment.cmd.copy(commit = false)) + } + + def makeSingleHopRoute(amount: MilliSatoshi, destination: PublicKey): Route = { + val dummyParams = HopRelayParams.FromHint(Invoice.ExtraEdge(randomKey().publicKey, destination, ShortChannelId(0), 0 msat, 0, CltvExpiryDelta(0), 0 msat, None)) + Route(amount, Seq(ChannelHop(ShortChannelId(0), dummyParams.extraHop.sourceNodeId, dummyParams.extraHop.targetNodeId, dummyParams)), None) } def addHtlc(amount: MilliSatoshi, s: TestFSMRef[ChannelState, ChannelData, Channel], r: TestFSMRef[ChannelState, ChannelData, Channel], s2r: TestProbe, r2s: TestProbe): (ByteVector32, UpdateAddHtlc) = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 4f79a07e7..b47394355 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -17,9 +17,9 @@ package fr.acinq.eclair.channel.states.f import akka.testkit.TestProbe +import fr.acinq.bitcoin.ScriptFlags import fr.acinq.bitcoin.scalacompat.Crypto.PrivateKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, Crypto, SatoshiLong, Transaction} -import fr.acinq.bitcoin.ScriptFlags import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher._ import fr.acinq.eclair.blockchain.fee.{FeeratePerKw, FeeratesPerKw} import fr.acinq.eclair.blockchain.{CurrentBlockHeight, CurrentFeerates} @@ -29,8 +29,8 @@ import fr.acinq.eclair.channel.states.{ChannelStateTestsBase, ChannelStateTestsT import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.Relayer._ -import fr.acinq.eclair.router.Router.ChannelHop -import fr.acinq.eclair.wire.protocol.{ClosingSigned, CommitSig, Error, FailureMessageCodecs, PaymentOnion, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} +import fr.acinq.eclair.payment.send.SpontaneousRecipient +import fr.acinq.eclair.wire.protocol.{ClosingSigned, CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, TestConstants, TestKitBaseClass, randomBytes32} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} @@ -58,9 +58,8 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val sender = TestProbe() // alice sends an HTLC to bob val h1 = Crypto.sha256(r1) - val amount1 = 300000000 msat - val expiry1 = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val cmd1 = OutgoingPaymentPacket.buildCommand(sender.ref, Upstream.Local(UUID.randomUUID), h1, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount1, expiry1, randomBytes32(), None)).get._1.copy(commit = false) + val recipient1 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 300_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r1) + val Right(cmd1) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h1, makeSingleHopRoute(recipient1.totalAmount, recipient1.nodeId), recipient1).map(_.cmd.copy(commit = false)) alice ! cmd1 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -68,9 +67,8 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit awaitCond(bob.stateData.asInstanceOf[DATA_NORMAL].commitments.remoteChanges.proposed == htlc1 :: Nil) // alice sends another HTLC to bob val h2 = Crypto.sha256(r2) - val amount2 = 200000000 msat - val expiry2 = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val cmd2 = OutgoingPaymentPacket.buildCommand(sender.ref, Upstream.Local(UUID.randomUUID), h2, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount2, expiry2, randomBytes32(), None)).get._1.copy(commit = false) + val recipient2 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 200_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r2) + val Right(cmd2) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h2, makeSingleHopRoute(recipient2.totalAmount, recipient2.nodeId), recipient2).map(_.cmd.copy(commit = false)) alice ! cmd2 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 92446779e..658e7baac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -785,7 +785,6 @@ class AuditDbSpec extends AnyFunSuite { dbName = PgAuditDb.DB_NAME, targetVersion = PgAuditDb.CURRENT_VERSION, postCheck = connection => { - val migratedDb = dbs.audit using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(PgAuditDb.CURRENT_VERSION)) } using(connection.prepareStatement(s"SELECT amount_msat, status, experiment_name, recipient_node_id FROM audit.path_finding_metrics ORDER BY timestamp")) { statement => val result = statement.executeQuery() @@ -899,7 +898,7 @@ class AuditDbSpec extends AnyFunSuite { assert(result.getString(5) == "my-test-experiment") if (isPg) { assert(result.getString(6) == recipientNodeId.toHex) - assert(result.getString(7) == "[{\"feeBase\": 1000, \"sourceNodeId\": \"033f2d90d6ba1f771e4b3586b35cc9f825cfcb7cdd7edaa2bfd63f0cb81b17580e\", \"targetNodeId\": \"02c15a88ff263cec5bf79c315b17b7f2e083f71d62a880e30281faaac0898cb2b7\", \"shortChannelId\": \"0x0x1\", \"cltvExpiryDelta\": 144, \"feeProportionalMillionths\": 100}, {\"feeBase\": 900, \"sourceNodeId\": \"02c15a88ff263cec5bf79c315b17b7f2e083f71d62a880e30281faaac0898cb2b7\", \"targetNodeId\": \"03f5b1f2768140178e1daac0fec11fce2eec6beec3ed64862bfb1114f7bc535b48\", \"shortChannelId\": \"0x0x2\", \"cltvExpiryDelta\": 12, \"feeProportionalMillionths\": 200}, {\"feeBase\": 800, \"sourceNodeId\": \"026ec3e3438308519a75ca4496822a6c1e229174fbcaadeeb174704c377112c331\", \"targetNodeId\": \"03f5b1f2768140178e1daac0fec11fce2eec6beec3ed64862bfb1114f7bc535b48\", \"shortChannelId\": \"0x0x3\", \"cltvExpiryDelta\": 78, \"feeProportionalMillionths\": 300}]") + assert(result.getString(7) == """[{"feeBase": 1000, "htlcMinimum": 1, "sourceNodeId": "033f2d90d6ba1f771e4b3586b35cc9f825cfcb7cdd7edaa2bfd63f0cb81b17580e", "targetNodeId": "02c15a88ff263cec5bf79c315b17b7f2e083f71d62a880e30281faaac0898cb2b7", "shortChannelId": "0x0x1", "cltvExpiryDelta": 144, "feeProportionalMillionths": 100}, {"feeBase": 900, "htlcMinimum": 1, "sourceNodeId": "02c15a88ff263cec5bf79c315b17b7f2e083f71d62a880e30281faaac0898cb2b7", "targetNodeId": "03f5b1f2768140178e1daac0fec11fce2eec6beec3ed64862bfb1114f7bc535b48", "shortChannelId": "0x0x2", "cltvExpiryDelta": 12, "feeProportionalMillionths": 200}, {"feeBase": 800, "htlcMinimum": 1, "sourceNodeId": "026ec3e3438308519a75ca4496822a6c1e229174fbcaadeeb174704c377112c331", "targetNodeId": "03f5b1f2768140178e1daac0fec11fce2eec6beec3ed64862bfb1114f7bc535b48", "shortChannelId": "0x0x3", "cltvExpiryDelta": 78, "feeProportionalMillionths": 300}]""") } assert(!result.next()) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 57dd15730..b2648c430 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.db.jdbc.JdbcUtils.{setVersion, using} import fr.acinq.eclair.db.pg.PgPaymentsDb import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams, NodeHop} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, NodeHop} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.{ChannelUpdate, TlvStream, UnknownNextPeer} import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, Paginated, ShortChannelId, TimestampMilli, TimestampMilliLong, TimestampSecond, TimestampSecondLong, randomBytes32, randomBytes64, randomKey} @@ -655,7 +655,7 @@ class PaymentsDbSpec extends AnyFunSuite { object PaymentsDbSpec { val (alicePriv, bobPriv, carolPriv, davePriv) = (randomKey(), randomKey(), randomKey(), randomKey()) val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey) - val hop_ab = ChannelHop(ShortChannelId(42), alice, bob, ChannelRelayParams.FromAnnouncement(ChannelUpdate(randomBytes64(), randomBytes32(), ShortChannelId(42), 1 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 1 msat, 1, 500_000_000 msat))) + val hop_ab = ChannelHop(ShortChannelId(42), alice, bob, HopRelayParams.FromAnnouncement(ChannelUpdate(randomBytes64(), randomBytes32(), ShortChannelId(42), 1 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 1 msat, 1, 500_000_000 msat))) val hop_bc = NodeHop(bob, carol, CltvExpiryDelta(14), 1 msat) val (preimage1, preimage2, preimage3, preimage4) = (randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32()) val (paymentHash1, paymentHash2, paymentHash3, paymentHash4) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3), Crypto.sha256(preimage4)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index e4a3e0a2a..b3cf7cca1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -159,7 +159,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { // first we retrieve a payment hash from D val amountMsat = 4200000.msat sender.send(nodes("D").paymentHandler, ReceiveStandardPayment(Some(amountMsat), Left("1 coffee"))) - val invoice = sender.expectMsgType[Invoice] + val invoice = sender.expectMsgType[Bolt11Invoice] assert(invoice.paymentMetadata.nonEmpty) // then we make the actual payment @@ -477,8 +477,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.paymentHash == invoice.paymentHash, paymentSent) assert(paymentSent.recipientNodeId == nodes("F").nodeParams.nodeId, paymentSent) assert(paymentSent.recipientAmount == amount, paymentSent) - assert(paymentSent.feesPaid == 1210100.msat, paymentSent) + assert(paymentSent.trampolineFees == 1210100.msat, paymentSent) assert(paymentSent.nonTrampolineFees == 0.msat, paymentSent) + assert(paymentSent.feesPaid == 1210100.msat, paymentSent) awaitCond(nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) @@ -492,12 +493,12 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(relayed.amountIn - relayed.amountOut > 0.msat, relayed) assert(relayed.amountIn - relayed.amountOut < 1210100.msat, relayed) - val outgoingSuccess = nodes("B").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]) - outgoingSuccess.collect { case p@OutgoingPayment(_, _, _, _, _, _, _, recipientNodeId, _, _, OutgoingPaymentStatus.Succeeded(_, _, route, _)) => - assert(recipientNodeId == nodes("F").nodeParams.nodeId, p) - assert(route.lastOption.contains(HopSummary(nodes("G").nodeParams.nodeId, nodes("F").nodeParams.nodeId)), p) - } - assert(outgoingSuccess.map(_.amount).sum == amount + 1210100.msat, outgoingSuccess) + val outgoingSuccess = nodes("B").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]).head + assert(outgoingSuccess.recipientNodeId == nodes("F").nodeParams.nodeId, outgoingSuccess) + assert(outgoingSuccess.recipientAmount == amount, outgoingSuccess) + assert(outgoingSuccess.amount == amount + 1210100.msat, outgoingSuccess) + val status = outgoingSuccess.status.asInstanceOf[OutgoingPaymentStatus.Succeeded] + assert(status.route.lastOption.contains(HopSummary(nodes("G").nodeParams.nodeId, nodes("F").nodeParams.nodeId)), status) } test("send a trampoline payment D->B (via trampoline C)") { @@ -506,7 +507,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { nodes("B").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) val amount = 2500000000L.msat sender.send(nodes("B").paymentHandler, ReceiveStandardPayment(Some(amount), Left("trampoline-MPP is so #reckless"))) - val invoice = sender.expectMsgType[Invoice] + val invoice = sender.expectMsgType[Bolt11Invoice] assert(invoice.features.hasFeature(Features.BasicMultiPartPayment)) assert(invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) assert(invoice.paymentMetadata.nonEmpty) @@ -522,8 +523,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.id == paymentId, paymentSent) assert(paymentSent.paymentHash == invoice.paymentHash, paymentSent) assert(paymentSent.recipientAmount == amount, paymentSent) - assert(paymentSent.feesPaid == 750000.msat, paymentSent) + assert(paymentSent.trampolineFees == 750000.msat, paymentSent) assert(paymentSent.nonTrampolineFees == 0.msat, paymentSent) + assert(paymentSent.feesPaid == 750000.msat, paymentSent) awaitCond(nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) @@ -538,12 +540,12 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(relayed.amountIn - relayed.amountOut > 0.msat, relayed) assert(relayed.amountIn - relayed.amountOut < 750000.msat, relayed) - val outgoingSuccess = nodes("D").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]) - outgoingSuccess.collect { case p@OutgoingPayment(_, _, _, _, _, _, _, recipientNodeId, _, _, OutgoingPaymentStatus.Succeeded(_, _, route, _)) => - assert(recipientNodeId == nodes("B").nodeParams.nodeId, p) - assert(route.lastOption.contains(HopSummary(nodes("C").nodeParams.nodeId, nodes("B").nodeParams.nodeId)), p) - } - assert(outgoingSuccess.map(_.amount).sum == amount + 750000.msat, outgoingSuccess) + val outgoingSuccess = nodes("D").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]).head + assert(outgoingSuccess.recipientNodeId == nodes("B").nodeParams.nodeId, outgoingSuccess) + assert(outgoingSuccess.recipientAmount == amount, outgoingSuccess) + assert(outgoingSuccess.amount == amount + 750000.msat, outgoingSuccess) + val status = outgoingSuccess.status.asInstanceOf[OutgoingPaymentStatus.Succeeded] + assert(status.route.lastOption.contains(HopSummary(nodes("C").nodeParams.nodeId, nodes("B").nodeParams.nodeId)), status) awaitCond(nodes("D").nodeParams.db.audit.listSent(start, TimestampMilli.now()).nonEmpty) val sent = nodes("D").nodeParams.db.audit.listSent(start, TimestampMilli.now()) @@ -562,7 +564,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { val amount = 3000000000L.msat sender.send(nodes("A").paymentHandler, ReceiveStandardPayment(Some(amount), Left("trampoline to non-trampoline is so #vintage"), extraHops = routingHints)) - val invoice = sender.expectMsgType[Invoice] + val invoice = sender.expectMsgType[Bolt11Invoice] assert(invoice.features.hasFeature(Features.BasicMultiPartPayment)) assert(!invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) assert(invoice.paymentMetadata.nonEmpty) @@ -575,6 +577,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.paymentHash == invoice.paymentHash, paymentSent) assert(paymentSent.recipientAmount == amount, paymentSent) assert(paymentSent.trampolineFees == 1500000.msat, paymentSent) + assert(paymentSent.nonTrampolineFees == 0.msat, paymentSent) awaitCond(nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) @@ -589,12 +592,12 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(relayed.amountIn - relayed.amountOut > 0.msat, relayed) assert(relayed.amountIn - relayed.amountOut < 1500000.msat, relayed) - val outgoingSuccess = nodes("F").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]) - outgoingSuccess.collect { case p@OutgoingPayment(_, _, _, _, _, _, _, recipientNodeId, _, _, OutgoingPaymentStatus.Succeeded(_, _, route, _)) => - assert(recipientNodeId == nodes("A").nodeParams.nodeId, p) - assert(route.lastOption.contains(HopSummary(nodes("C").nodeParams.nodeId, nodes("A").nodeParams.nodeId)), p) - } - assert(outgoingSuccess.map(_.amount).sum == amount + 1500000.msat, outgoingSuccess) + val outgoingSuccess = nodes("F").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]).head + assert(outgoingSuccess.recipientNodeId == nodes("A").nodeParams.nodeId, outgoingSuccess) + assert(outgoingSuccess.recipientAmount == amount, outgoingSuccess) + assert(outgoingSuccess.amount == amount + 1500000.msat, outgoingSuccess) + val status = outgoingSuccess.status.asInstanceOf[OutgoingPaymentStatus.Succeeded] + assert(status.route.lastOption.contains(HopSummary(nodes("C").nodeParams.nodeId, nodes("A").nodeParams.nodeId)), status) } test("send a trampoline payment B->D (temporary local failure at trampoline)") { @@ -648,6 +651,37 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(outgoingPayments.forall(p => p.status.isInstanceOf[OutgoingPaymentStatus.Failed]), outgoingPayments) } + test("send a trampoline payment A->D (via remote trampoline C)") { + val sender = TestProbe() + val amount = 500000000L.msat + sender.send(nodes("D").paymentHandler, ReceiveStandardPayment(Some(amount), Left("remote trampoline is so #reckless"))) + val invoice = sender.expectMsgType[Bolt11Invoice] + assert(invoice.features.hasFeature(Features.BasicMultiPartPayment)) + assert(invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) + + val payment = SendTrampolinePayment(amount, invoice, nodes("C").nodeParams.nodeId, Seq((500000 msat, CltvExpiryDelta(288))), routeParams = integrationTestRouteParams) + sender.send(nodes("A").paymentInitiator, payment) + val paymentId = sender.expectMsgType[UUID] + val paymentSent = sender.expectMsgType[PaymentSent](max = 30 seconds) + assert(paymentSent.id == paymentId, paymentSent) + assert(paymentSent.paymentHash == invoice.paymentHash, paymentSent) + assert(paymentSent.recipientAmount == amount, paymentSent) + assert(paymentSent.trampolineFees == 500000.msat, paymentSent) + assert(paymentSent.nonTrampolineFees > 0.msat, paymentSent) + assert(paymentSent.feesPaid > 500000.msat, paymentSent) + + awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + assert(receivedAmount == amount) + + val outgoingSuccess = nodes("A").nodeParams.db.payments.listOutgoingPayments(paymentId).filter(p => p.status.isInstanceOf[OutgoingPaymentStatus.Succeeded]).head + assert(outgoingSuccess.recipientNodeId == nodes("D").nodeParams.nodeId, outgoingSuccess) + assert(outgoingSuccess.recipientAmount == amount, outgoingSuccess) + assert(outgoingSuccess.amount == amount + 500000.msat, outgoingSuccess) + val status = outgoingSuccess.status.asInstanceOf[OutgoingPaymentStatus.Succeeded] + assert(status.route.lastOption.contains(HopSummary(nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), status) + } + test("generate and validate lots of channels") { val bitcoinClient = new BitcoinCoreClient(bitcoinrpcclient) // we simulate fake channels by publishing a funding tx and sending announcement messages to a node at random diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/zeroconf/ZeroConfAliasIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/zeroconf/ZeroConfAliasIntegrationSpec.scala index c530088dd..55890b23b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/zeroconf/ZeroConfAliasIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/zeroconf/ZeroConfAliasIntegrationSpec.scala @@ -100,9 +100,9 @@ class ZeroConfAliasIntegrationSpec extends FixtureSpec with IntegrationPatience private def createSelfRouteCarol(f: FixtureParam, scid_ab: ShortChannelId, scid_bc: ShortChannelId): Unit = { import f._ val sender = TestProbe("sender") - sender.send(carol.router, FinalizeRoute(50_000 msat, PredefinedNodeRoute(Seq(alice.nodeId, bob.nodeId, carol.nodeId)))) + sender.send(carol.router, FinalizeRoute(PredefinedNodeRoute(50_000 msat, Seq(alice.nodeId, bob.nodeId, carol.nodeId)))) val route = sender.expectMsgType[RouteResponse].routes.head - assert(route.length == 2) + assert(route.hops.length == 2) assert(route.hops.map(_.nodeId) == Seq(alice.nodeId, bob.nodeId)) assert(route.hops.map(_.nextNodeId) == Seq(bob.nodeId, carol.nodeId)) assert(route.hops.map(_.shortChannelId) == Seq(scid_ab, scid_bc)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 97acf1957..fd0959bbb 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -116,7 +116,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(Crypto.sha256(incoming.get.paymentPreimage) == invoice.paymentHash) val add = UpdateAddHtlc(ByteVector32.One, 1, amountMsat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] @@ -132,7 +132,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoice = sender.expectMsgType[Bolt11Invoice] val add = UpdateAddHtlc(ByteVector32.One, 1, 75_000 msat, invoice.paymentHash, defaultExpiry + CltvExpiryDelta(12), TestConstants.emptyOnionPacket, None) - sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(70_000 msat, defaultExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(70_000 msat, 70_000 msat, defaultExpiry, invoice.paymentSecret, invoice.paymentMetadata))) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] @@ -150,7 +150,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32.One, 2, amountMsat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] @@ -191,7 +191,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32.One, 0, amountMsat, invoice.paymentHash, CltvExpiryDelta(3).toCltvExpiry(nodeParams.currentBlockHeight), TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(amountMsat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -283,9 +283,9 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) val router = TestProbe() val (a, b, c, d) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) - val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, b, ShortChannelId(1), 1000 msat, 0, CltvExpiryDelta(100)))) - val hop_bd = Router.ChannelHop(ShortChannelId(2), b, d, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b, d, ShortChannelId(2), 800 msat, 0, CltvExpiryDelta(50)))) - val hop_cd = Router.ChannelHop(ShortChannelId(3), c, d, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, d, ShortChannelId(3), 0 msat, 0, CltvExpiryDelta(75)))) + val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(a, b, ShortChannelId(1), 1000 msat, 0, CltvExpiryDelta(100), 1 msat, None))) + val hop_bd = Router.ChannelHop(ShortChannelId(2), b, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(b, d, ShortChannelId(2), 800 msat, 0, CltvExpiryDelta(50), 1 msat, None))) + val hop_cd = Router.ChannelHop(ShortChannelId(3), c, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(c, d, ShortChannelId(3), 0 msat, 0, CltvExpiryDelta(75), 1 msat, None))) val receivingRoutes = Seq( ReceivingRoute(Seq(a, b, d), CltvExpiryDelta(100), Seq(DummyBlindedHop(150 msat, 0, CltvExpiryDelta(25)))), ReceivingRoute(Seq(c, d), CltvExpiryDelta(50), Seq(DummyBlindedHop(250 msat, 0, CltvExpiryDelta(10)), DummyBlindedHop(150 msat, 0, CltvExpiryDelta(80)))), @@ -293,11 +293,11 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ) sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, receivingRoutes, router.ref)) val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, d))) - router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ab, hop_bd))))) + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(25_000 msat, Seq(a, b, d))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_ab, hop_bd), None)))) val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(c, d))) - router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute2.amount, Seq(hop_cd))))) + assert(finalizeRoute2.route == Router.PredefinedNodeRoute(25_000 msat, Seq(c, d))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_cd), None)))) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.amount == 25_000.msat) assert(invoice.nodeId == privKey.publicKey) @@ -308,10 +308,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.blindedPaths.length == 3) assert(invoice.blindedPaths(0).blindedNodeIds.length == 4) assert(invoice.blindedPaths(0).introductionNodeId == a) - assert(invoice.blindedPathsInfo(0) == PaymentInfo(1950 msat, 0, CltvExpiryDelta(175), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPathsInfo(0) == PaymentInfo(1950 msat, 0, CltvExpiryDelta(175), 1 msat, 25_000 msat, Features.empty)) assert(invoice.blindedPaths(1).blindedNodeIds.length == 4) assert(invoice.blindedPaths(1).introductionNodeId == c) - assert(invoice.blindedPathsInfo(1) == PaymentInfo(400 msat, 0, CltvExpiryDelta(165), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPathsInfo(1) == PaymentInfo(400 msat, 0, CltvExpiryDelta(165), 1 msat, 25_000 msat, Features.empty)) assert(invoice.blindedPaths(2).blindedNodeIds.length == 1) assert(invoice.blindedPaths(2).introductionNodeId == d) assert(invoice.blindedPathsInfo(2) == PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 25_000 msat, Features.empty)) @@ -331,17 +331,17 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) val router = TestProbe() val (a, b, c) = (randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) - val hop_ac = Router.ChannelHop(ShortChannelId(1), a, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, c, ShortChannelId(1), 100 msat, 0, CltvExpiryDelta(50)))) + val hop_ac = Router.ChannelHop(ShortChannelId(1), a, c, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(a, c, ShortChannelId(1), 100 msat, 0, CltvExpiryDelta(50), 1 msat, None))) val receivingRoutes = Seq( ReceivingRoute(Seq(a, c), CltvExpiryDelta(100)), ReceivingRoute(Seq(b, c), CltvExpiryDelta(100)), ) sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, receivingRoutes, router.ref)) val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, c))) - router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ac))))) + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(25_000 msat, Seq(a, c))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_ac), None)))) val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(b, c))) + assert(finalizeRoute2.route == Router.PredefinedNodeRoute(25_000 msat, Seq(b, c))) router.send(router.lastSender, Status.Failure(new IllegalArgumentException("invalid route"))) sender.expectMsgType[Status.Failure] @@ -376,7 +376,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.isExpired()) val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] val Some(incoming) = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(incoming.invoice.isExpired() && incoming.status == IncomingPaymentStatus.Expired) @@ -391,7 +391,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.isExpired()) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) val Some(incoming) = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) @@ -406,7 +406,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(!invoice.features.hasFeature(BasicMultiPartPayment)) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -421,7 +421,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val lowCltvExpiry = nodeParams.channelConf.fulfillSafetyBeforeTimeout.toCltvExpiry(nodeParams.currentBlockHeight) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, lowCltvExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -435,7 +435,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(BasicMultiPartPayment)) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash.reverse, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -449,7 +449,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(BasicMultiPartPayment)) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 999 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 999 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(999 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -463,7 +463,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(BasicMultiPartPayment)) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 2001 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 2001 msat, add.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(2001 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -478,7 +478,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike // Invalid payment secret. val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret.reverse, invoice.paymentMetadata))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, invoice.paymentSecret.reverse, invoice.paymentMetadata))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -508,7 +508,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val add = UpdateAddHtlc(ByteVector32.One, 0, 5000 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, randomBytes32(), None))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, randomBytes32(), None))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(5000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status == IncomingPaymentStatus.Pending) @@ -612,13 +612,13 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike f.sender.send(handler, ReceiveStandardPayment(Some(1000 msat), Left("1 slow coffee"))) val pr1 = f.sender.expectMsgType[Bolt11Invoice] val add1 = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr1.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createMultiPartPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, pr1.paymentSecret, pr1.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, pr1.paymentSecret, pr1.paymentMetadata))) // Partial payment exceeding the invoice amount, but incomplete because it promises to overpay. f.sender.send(handler, ReceiveStandardPayment(Some(1500 msat), Left("1 slow latte"))) val pr2 = f.sender.expectMsgType[Bolt11Invoice] val add2 = UpdateAddHtlc(ByteVector32.One, 1, 1600 msat, pr2.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createMultiPartPayload(add2.amountMsat, 2000 msat, add2.cltvExpiry, pr2.paymentSecret, pr2.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createPayload(add2.amountMsat, 2000 msat, add2.cltvExpiry, pr2.paymentSecret, pr2.paymentMetadata))) awaitCond { f.sender.send(handler, GetPendingPayments) @@ -653,12 +653,12 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoice = f.sender.expectMsgType[Bolt11Invoice] val add1 = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createMultiPartPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) // Invalid payment secret -> should be rejected. val add2 = UpdateAddHtlc(ByteVector32.Zeroes, 42, 200 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createMultiPartPayload(add2.amountMsat, 1000 msat, add2.cltvExpiry, invoice.paymentSecret.reverse, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createPayload(add2.amountMsat, 1000 msat, add2.cltvExpiry, invoice.paymentSecret.reverse, invoice.paymentMetadata))) val add3 = add2.copy(id = 43) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add3, FinalPayload.Standard.createMultiPartPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add3, FinalPayload.Standard.createPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) f.register.expectMsgAllOf( Register.Forward(null, add2.channelId, CMD_FAIL_HTLC(add2.id, Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)), commit = true)), @@ -696,9 +696,9 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoice = f.sender.expectMsgType[Bolt11Invoice] val add1 = UpdateAddHtlc(randomBytes32(), 0, 1100 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createMultiPartPayload(add1.amountMsat, 1500 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createPayload(add1.amountMsat, 1500 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val add2 = UpdateAddHtlc(randomBytes32(), 1, 500 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createMultiPartPayload(add2.amountMsat, 1500 msat, add2.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createPayload(add2.amountMsat, 1500 msat, add2.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) f.register.expectMsgAllOf( Register.Forward(null, add1.channelId, CMD_FULFILL_HTLC(add1.id, preimage, commit = true)), @@ -723,7 +723,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.paymentHash == Crypto.sha256(preimage)) val add1 = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createMultiPartPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add1, FinalPayload.Standard.createPayload(add1.amountMsat, 1000 msat, add1.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) f.register.expectMsg(Register.Forward(null, ByteVector32.One, CMD_FAIL_HTLC(0, Right(PaymentTimeout), commit = true))) awaitCond({ f.sender.send(handler, GetPendingPayments) @@ -731,9 +731,9 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike }) val add2 = UpdateAddHtlc(ByteVector32.One, 2, 300 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createMultiPartPayload(add2.amountMsat, 1000 msat, add2.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add2, FinalPayload.Standard.createPayload(add2.amountMsat, 1000 msat, add2.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) val add3 = UpdateAddHtlc(ByteVector32.Zeroes, 5, 700 msat, invoice.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket, None) - f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add3, FinalPayload.Standard.createMultiPartPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) + f.sender.send(handler, IncomingPaymentPacket.FinalPacket(add3, FinalPayload.Standard.createPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, invoice.paymentSecret, invoice.paymentMetadata))) // the fulfill are not necessarily in the same order as the commands f.register.expectMsgAllOf( @@ -801,7 +801,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(nodeParams.db.payments.getIncomingPayment(paymentHash).isEmpty) val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret, None))) + sender.send(handlerWithoutMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, paymentSecret, None))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.id == add.id) assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) @@ -815,7 +815,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(nodeParams.db.payments.getIncomingPayment(paymentHash).isEmpty) val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None) - sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret, Some(hex"012345")))) + sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret, Some(hex"012345")))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.id == add.id) assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index ef13fd1ce..1ea206357 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -18,28 +18,28 @@ package fr.acinq.eclair.payment import akka.actor.{ActorContext, ActorRef, Status} import akka.testkit.{TestFSMRef, TestProbe} +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, SatoshiLong} import fr.acinq.eclair._ import fr.acinq.eclair.channel.{ChannelUnavailable, HtlcsTimedoutDownstream, RemoteCannotAffordFeesForNewHtlc} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.{FailureSummary, FailureType, OutgoingPaymentStatus} -import fr.acinq.eclair.payment.Invoice.BasicEdge +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute -import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator} +import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router.{Announcements, RouteNotFound} -import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike -import scodec.bits.{ByteVector, HexStringSyntax} +import scodec.bits.HexStringSyntax import java.util.UUID import scala.concurrent.duration._ @@ -67,7 +67,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS override def withFixture(test: OneArgTest): Outcome = { val id = UUID.randomUUID() - val cfg = SendPaymentConfig(id, id, Some("42"), paymentHash, finalAmount, finalRecipient, Upstream.Local(id), None, storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil) + val cfg = SendPaymentConfig(id, id, Some("42"), paymentHash, randomKey().publicKey, Upstream.Local(id), None, storeInDb = true, publishEvent = true, recordPathFindingMetrics = true) val nodeParams = TestConstants.Alice.nodeParams val (childPayFsm, router, sender, eventListener, metricsListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) val paymentHandler = TestFSMRef(new MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, FakePaymentFactory(childPayFsm))) @@ -80,24 +80,21 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 1, None, routeParams = routeParams.copy(randomize = true)) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 1, routeParams.copy(randomize = true)) sender.send(payFsm, payment) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) assert(payFsm.stateName == WAIT_FOR_ROUTES) - val singleRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + val singleRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None) router.send(payFsm, RouteResponse(Seq(singleRoute))) val childPayment = childPayFsm.expectMsgType[SendPaymentToRoute] assert(childPayment.route == Right(singleRoute)) - assert(childPayment.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(childPayment.finalPayload.expiry == expiry) - assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == payment.paymentSecret) - assert(childPayment.finalPayload.amount == finalAmount) - assert(childPayment.finalPayload.totalAmount == finalAmount) + assert(childPayment.amount == finalAmount) + assert(childPayment.recipient == payment.recipient) assert(payFsm.stateName == PAYMENT_IN_PROGRESS) - val result = fulfillPendingPayments(f, 1) + val result = fulfillPendingPayments(f, 1, e, finalAmount) assert(result.amountWithFees == finalAmount + 100.msat) assert(result.trampolineFees == 0.msat) assert(result.nonTrampolineFees == 100.msat) @@ -114,58 +111,64 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, 1200000 msat, expiry, 1, Some(hex"012345"), routeParams = routeParams.copy(randomize = false)) + val recipient = ClearRecipient(e, recipientFeatures, 1_200_000 msat, expiry, randomBytes32(), paymentMetadata_opt = Some(hex"012345")) + val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams.copy(randomize = false)) sender.send(payFsm, payment) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, 1200000 msat, maxFee, routeParams = routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + router.expectMsg(RouteRequest(nodeParams.nodeId, recipient, routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) assert(payFsm.stateName == WAIT_FOR_ROUTES) val routes = Seq( - Route(500000 msat, hop_ab_1 :: hop_be :: Nil), - Route(700000 msat, hop_ac_1 :: hop_ce :: Nil), + Route(500_000 msat, hop_ab_1 :: hop_be :: Nil, None), + Route(700_000 msat, hop_ac_1 :: hop_ce :: Nil, None), ) router.send(payFsm, RouteResponse(routes)) val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil assert(childPayments.map(_.route).toSet == routes.map(r => Right(r)).toSet) - assert(childPayments.map(_.finalPayload.expiry).toSet == Set(expiry)) - childPayments.foreach(childPayment => assert(childPayment.finalPayload.isInstanceOf[FinalPayload.Standard])) - assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret).toSet == Set(payment.paymentSecret)) - assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].paymentMetadata).toSet == Set(Some(hex"012345"))) - assert(childPayments.map(_.finalPayload.amount).toSet == Set(500000 msat, 700000 msat)) - assert(childPayments.map(_.finalPayload.totalAmount).toSet == Set(1200000 msat)) + childPayments.foreach(childPayment => assert(childPayment.recipient == recipient)) + assert(childPayments.map(_.amount).toSet == Set(500_000 msat, 700_000 msat)) assert(payFsm.stateName == PAYMENT_IN_PROGRESS) - val result = fulfillPendingPayments(f, 2) - assert(result.amountWithFees == 1200200.msat) - assert(result.trampolineFees == 200000.msat) + val result = fulfillPendingPayments(f, 2, e, 1_200_000 msat) + assert(result.amountWithFees == 1_200_200.msat) assert(result.nonTrampolineFees == 200.msat) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] assert(metrics.status == "SUCCESS") assert(metrics.experimentName == "my-test-experiment") - assert(metrics.amount == finalAmount) - assert(metrics.fees == 200200.msat) + assert(metrics.amount == 1_200_000.msat) + assert(metrics.fees == 200.msat) metricsListener.expectNoMessage() } - test("send custom tlv records") { f => + test("successful first attempt (trampoline)") { f => import f._ - // We include a bunch of additional tlv records. - val trampolineTlv = OnionPaymentPayloadTlv.TrampolineOnion(OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(400)(0), randomBytes32())) - val userCustomTlv = GenericTlv(UInt64(561), hex"deadbeef") - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount + 1000.msat, expiry, 1, None, routeParams = routeParams, additionalTlvs = Seq(trampolineTlv), userCustomTlvs = Seq(userCustomTlv)) + assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), randomBytes32(), randomKey(), Left("invoice"), CltvExpiryDelta(12)) + val trampolineHop = NodeHop(e, invoice.nodeId, CltvExpiryDelta(50), 1000 msat) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, expiry, trampolineHop, randomBytes32()) + val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams) sender.send(payFsm, payment) - router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(501000 msat, hop_ac_1 :: hop_ce :: Nil)))) - val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil - childPayments.map(_.finalPayload).foreach(p => { - assert(p.records.get[OnionPaymentPayloadTlv.TrampolineOnion].contains(trampolineTlv)) - assert(p.records.unknown.toSeq == Seq(userCustomTlv)) - }) - val result = fulfillPendingPayments(f, 2) + router.expectMsg(RouteRequest(nodeParams.nodeId, recipient, routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + assert(payFsm.stateName == WAIT_FOR_ROUTES) + + val routes = Seq( + Route(500_000 msat, hop_ab_1 :: hop_be :: Nil, Some(trampolineHop)), + Route(501_000 msat, hop_ac_1 :: hop_ce :: Nil, Some(trampolineHop)) + ) + router.send(payFsm, RouteResponse(routes)) + val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil + assert(childPayments.map(_.route).toSet == routes.map(r => Right(r)).toSet) + childPayments.foreach(childPayment => assert(childPayment.recipient == recipient)) + assert(childPayments.map(_.amount).toSet == Set(500_000 msat, 501_000 msat)) + assert(payFsm.stateName == PAYMENT_IN_PROGRESS) + + val result = fulfillPendingPayments(f, 2, invoice.nodeId, finalAmount) + assert(result.amountWithFees == 1_001_200.msat) assert(result.trampolineFees == 1000.msat) + assert(result.nonTrampolineFees == 200.msat) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] assert(metrics.status == "SUCCESS") @@ -178,10 +181,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("successful retry") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 3, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - val failingRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + val failingRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None) router.send(payFsm, RouteResponse(Seq(failingRoute))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) @@ -189,14 +192,14 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.amount, failingRoute.hops, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) // We retry ignoring the failing channel. - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = routeParams.copy(randomize = true), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ac_1 :: hop_ce :: Nil), Route(600000 msat, hop_ad :: hop_de :: Nil)))) + router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = true), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ac_1 :: hop_ce :: Nil, None), Route(600_000 msat, hop_ad :: hop_de :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(childId)) - val result = fulfillPendingPayments(f, 2) - assert(result.amountWithFees == 1000200.msat) + val result = fulfillPendingPayments(f, 2, e, finalAmount) + assert(result.amountWithFees == 1_000_200.msat) assert(result.trampolineFees == 0.msat) assert(result.nonTrampolineFees == 200.msat) @@ -211,10 +214,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry failures while waiting for routes") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 3, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ab_2 :: hop_be :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(600_000 msat, hop_ab_2 :: hop_be :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) @@ -223,23 +226,23 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(RemoteFailure(failedRoute1.amount, failedRoute1.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) // When we retry, we ignore the failing node and we let the router know about the remaining pending route. - router.expectMsg(RouteRequest(nodeParams.nodeId, e, failedRoute1.amount, maxFee - failedRoute1.fee(false), ignore = Ignore(Set(b), Set.empty), pendingPayments = Seq(failedRoute2), allowMultiPart = true, routeParams = routeParams.copy(randomize = true), paymentContext = Some(cfg.paymentContext))) + router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = true), Ignore(Set(b), Set.empty), pendingPayments = Seq(failedRoute2), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) // The second part fails while we're still waiting for new routes. childPayFsm.send(payFsm, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.amount, failedRoute2.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) // We receive a response to our first request, but it's now obsolete: we re-sent a new route request that takes into // account the latest failures. - router.send(payFsm, RouteResponse(Seq(Route(failedRoute1.amount, hop_ac_1 :: hop_ce :: Nil)))) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, ignore = Ignore(Set(b), Set.empty), allowMultiPart = true, routeParams = routeParams.copy(randomize = true), paymentContext = Some(cfg.paymentContext))) + router.send(payFsm, RouteResponse(Seq(Route(failedRoute1.amount, hop_ac_1 :: hop_ce :: Nil, None)))) + router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = true), Ignore(Set(b), Set.empty), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) awaitCond(payFsm.stateData.asInstanceOf[PaymentProgress].pending.isEmpty) childPayFsm.expectNoMessage(100 millis) // We receive new routes that work. - router.send(payFsm, RouteResponse(Seq(Route(300000 msat, hop_ac_1 :: hop_ce :: Nil), Route(700000 msat, hop_ad :: hop_de :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(300_000 msat, hop_ac_1 :: hop_ce :: Nil, None), Route(700_000 msat, hop_ad :: hop_de :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] - val result = fulfillPendingPayments(f, 2) - assert(result.amountWithFees == 1000200.msat) + val result = fulfillPendingPayments(f, 2, e, finalAmount) + assert(result.amountWithFees == 1_000_200.msat) assert(result.nonTrampolineFees == 200.msat) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] @@ -253,10 +256,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry local channel failures") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 3, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) @@ -264,24 +267,24 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, RemoteCannotAffordFeesForNewHtlc(randomBytes32(), finalAmount, 15 sat, 0 sat, 15 sat))))) // We retry without the failing channel. - val expectedRouteRequest = RouteRequest( - nodeParams.nodeId, e, - failedRoute.amount, maxFee, - ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), + router.expectMsg(RouteRequest( + nodeParams.nodeId, + clearRecipient, + routeParams.copy(randomize = true), + Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), pendingPayments = Nil, allowMultiPart = true, - routeParams = routeParams.copy(randomize = true), - paymentContext = Some(cfg.paymentContext)) - router.expectMsg(expectedRouteRequest) + paymentContext = Some(cfg.paymentContext) + )) } test("retry without ignoring channels") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 3, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ab_1 :: hop_be :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(500_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(500_000 msat, hop_ab_1 :: hop_be :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) @@ -292,22 +295,22 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // If the router doesn't find routes, we will retry without ignoring the channel: it may work with a different split // of the amount to send. val expectedRouteRequest = RouteRequest( - nodeParams.nodeId, e, - failedRoute.amount, maxFee - failedRoute.fee(false), - ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), + nodeParams.nodeId, + clearRecipient, + routeParams.copy(randomize = true), + Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), pendingPayments = Seq(pendingRoute), allowMultiPart = true, - routeParams = routeParams.copy(randomize = true), paymentContext = Some(cfg.paymentContext)) router.expectMsg(expectedRouteRequest) router.send(payFsm, Status.Failure(RouteNotFound)) router.expectMsg(expectedRouteRequest.copy(ignore = Ignore.empty)) - router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(500_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] - val result = fulfillPendingPayments(f, 2) - assert(result.amountWithFees == 1000200.msat) + val result = fulfillPendingPayments(f, 2, e, finalAmount) + assert(result.amountWithFees == 1_000_200.msat) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] assert(metrics.status == "SUCCESS") @@ -321,54 +324,56 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ // The B -> E channel is private and provided in the invoice routing hints. - val extraEdge = Invoice.BasicEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams, extraEdges = List(extraEdge)) + val extraEdge = ExtraEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta, hop_be.params.htlcMinimum, hop_be.params.htlcMaximum_opt) + val recipient = ClearRecipient(e, Features.empty, finalAmount, expiry, randomBytes32(), Seq(extraEdge)) + val payment = SendMultiPartPayment(sender.ref, recipient, 3, routeParams) sender.send(payFsm, payment) - assert(router.expectMsgType[RouteRequest].extraEdges.head == extraEdge) - val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + assert(router.expectMsgType[RouteRequest].target.extraEdges == Seq(extraEdge)) + val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None) router.send(payFsm, RouteResponse(Seq(route))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) // B changed his fees and expiry after the invoice was issued. - val channelUpdate = hop_be.params.asInstanceOf[ChannelRelayParams.FromAnnouncement].channelUpdate.copy(feeBaseMsat = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) + val channelUpdate = channelUpdate_be.copy(feeBaseMsat = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) // We update the routing hints accordingly before requesting a new route. - val updatedExtraEdge = router.expectMsgType[RouteRequest].extraEdges.head - assert(updatedExtraEdge == BasicEdge(b, e, hop_be.shortChannelId, channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, channelUpdate.cltvExpiryDelta)) + val extraEdge1 = extraEdge.copy(feeBase = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) + assert(router.expectMsgType[RouteRequest].target.extraEdges == Seq(extraEdge1)) } test("retry with ignored routing hints (temporary channel failure)") { f => import f._ // The B -> E channel is private and provided in the invoice routing hints. - val extraEdge = Invoice.BasicEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams, extraEdges = List(extraEdge)) + val extraEdge = ExtraEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta, hop_be.params.htlcMinimum, hop_be.params.htlcMaximum_opt) + val recipient = ClearRecipient(e, Features.empty, finalAmount, expiry, randomBytes32(), Seq(extraEdge)) + val payment = SendMultiPartPayment(sender.ref, recipient, 3, routeParams) sender.send(payFsm, payment) - assert(router.expectMsgType[RouteRequest].extraEdges.head == extraEdge) - val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + assert(router.expectMsgType[RouteRequest].target.extraEdges == Seq(extraEdge)) + val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None) router.send(payFsm, RouteResponse(Seq(route))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectNoMessage(100 millis) // B doesn't have enough liquidity on this channel. // NB: we need a channel update with a valid signature, otherwise we'll ignore the node instead of this specific channel. - val channelUpdateBE = hop_be.params.asInstanceOf[ChannelRelayParams.FromAnnouncement].channelUpdate - val channelUpdateBE1 = Announcements.makeChannelUpdate(channelUpdateBE.chainHash, priv_b, e, channelUpdateBE.shortChannelId, channelUpdateBE.cltvExpiryDelta, channelUpdateBE.htlcMinimumMsat, channelUpdateBE.feeBaseMsat, channelUpdateBE.feeProportionalMillionths, channelUpdateBE.htlcMaximumMsat) + val channelUpdate = Announcements.makeChannelUpdate(channelUpdate_be.chainHash, priv_b, e, channelUpdate_be.shortChannelId, channelUpdate_be.cltvExpiryDelta, channelUpdate_be.htlcMinimumMsat, channelUpdate_be.feeBaseMsat, channelUpdate_be.feeProportionalMillionths, channelUpdate_be.htlcMaximumMsat) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdateBE1)))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdate)))))) // We update the routing hints accordingly before requesting a new route and ignore the channel. val routeRequest = router.expectMsgType[RouteRequest] - assert(routeRequest.extraEdges.head == extraEdge) - assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(channelUpdateBE1.shortChannelId)) + assert(routeRequest.target.extraEdges == Seq(extraEdge)) + assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(channelUpdate.shortChannelId)) } test("update routing hints") { () => - val extraEdges = Seq( - BasicEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), BasicEdge(b, c, ShortChannelId(2), 0 msat, 100, CltvExpiryDelta(24)), - BasicEdge(a, c, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144)) - ) + val recipient = ClearRecipient(e, Features.empty, finalAmount, expiry, randomBytes32(), Seq( + ExtraEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12), 1 msat, None), + ExtraEdge(b, c, ShortChannelId(2), 0 msat, 100, CltvExpiryDelta(24), 1 msat, None), + ExtraEdge(a, c, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144), 1 msat, None) + )) def makeChannelUpdate(shortChannelId: ShortChannelId, feeBase: MilliSatoshi, feeProportional: Long, cltvExpiryDelta: CltvExpiryDelta): ChannelUpdate = { defaultChannelUpdate.copy(shortChannelId = shortChannelId, feeBaseMsat = feeBase, feeProportionalMillionths = feeProportional, cltvExpiryDelta = cltvExpiryDelta) @@ -381,10 +386,11 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS UnreadableRemoteFailure(finalAmount, Nil) ) val extraEdges1 = Seq( - BasicEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), BasicEdge(b, c, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48)), - BasicEdge(a, c, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144)) + ExtraEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12), 1 msat, None), + ExtraEdge(b, c, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48), defaultChannelUpdate.htlcMinimumMsat, Some(defaultChannelUpdate.htlcMaximumMsat)), + ExtraEdge(a, c, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144), 1 msat, None) ) - assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall { case (e1, e2) => e1 == e2 }) + assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, recipient).extraEdges).forall { case (e1, e2) => e1 == e2 }) } { val failures = Seq( @@ -394,27 +400,28 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS RemoteFailure(finalAmount, Nil, Sphinx.DecryptedFailurePacket(a, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23))))), ) val extraEdges1 = Seq( - BasicEdge(a, b, ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23)), BasicEdge(b, c, ShortChannelId(2), 21 msat, 21, CltvExpiryDelta(21)), - BasicEdge(a, c, ShortChannelId(3), 22 msat, 22, CltvExpiryDelta(22)) + ExtraEdge(a, b, ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23), defaultChannelUpdate.htlcMinimumMsat, Some(defaultChannelUpdate.htlcMaximumMsat)), + ExtraEdge(b, c, ShortChannelId(2), 21 msat, 21, CltvExpiryDelta(21), defaultChannelUpdate.htlcMinimumMsat, Some(defaultChannelUpdate.htlcMaximumMsat)), + ExtraEdge(a, c, ShortChannelId(3), 22 msat, 22, CltvExpiryDelta(22), defaultChannelUpdate.htlcMinimumMsat, Some(defaultChannelUpdate.htlcMaximumMsat)) ) - assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall { case (e1, e2) => e1 == e2 }) + assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, recipient).extraEdges).forall { case (e1, e2) => e1 == e2 }) } } test("abort after too many failed attempts") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 2, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 2, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(500_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(500_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId1, failedRoute1) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops)))) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ad :: hop_de :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(500_000 msat, hop_ad :: hop_de :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(failedId1)) @@ -424,10 +431,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(result.failures.contains(LocalFailure(finalAmount, Nil, RetryExhausted))) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] - assert(metrics.status == "FAILURE") + assert(metrics.status == "RECIPIENT_FAILURE") assert(metrics.experimentName == "my-test-experiment") assert(metrics.amount == finalAmount) - assert(metrics.fees == 15000.msat) + assert(metrics.fees == 15_000.msat) metricsListener.expectNoMessage() } @@ -435,7 +442,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ sender.watch(payFsm) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, Status.Failure(RouteNotFound)) @@ -458,38 +465,38 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(metrics.status == "FAILURE") assert(metrics.experimentName == "my-test-experiment") assert(metrics.amount == finalAmount) - assert(metrics.fees == 15000.msat) + assert(metrics.fees == 15_000.msat) metricsListener.expectNoMessage() } test("abort if recipient sends error") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600000 msat, BlockHeight(0))))))) + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600_000 msat, BlockHeight(0))))))) assert(result.failures.length == 1) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] - assert(metrics.status == "FAILURE") + assert(metrics.status == "RECIPIENT_FAILURE") assert(metrics.experimentName == "my-test-experiment") assert(metrics.amount == finalAmount) - assert(metrics.fees == 15000.msat) + assert(metrics.fees == 15_000.msat) metricsListener.expectNoMessage() } test("abort if payment gets settled on chain") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head @@ -500,10 +507,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("abort if recipient sends error during retry") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(600_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] @@ -518,10 +525,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("receive partial success after retriable failure (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(600_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] @@ -529,19 +536,19 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops)))) router.expectMsgType[RouteRequest] - val result = fulfillPendingPayments(f, 1) + val result = fulfillPendingPayments(f, 1, e, finalAmount) assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount - assert(result.nonTrampolineFees == successRoute.fee(false)) // we paid the fee for only one of the partial payments + assert(result.nonTrampolineFees == successRoute.channelFee(false)) // we paid the fee for only one of the partial payments assert(result.parts.length == 1 && result.parts.head.id == successId) } test("receive partial success after abort (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(600_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] @@ -550,7 +557,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS awaitCond(payFsm.stateName == PAYMENT_ABORTED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.fee(false), randomBytes32(), Some(successRoute.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.channelFee(false), randomBytes32(), Some(successRoute.hops))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] assert(result.id == cfg.id) @@ -558,9 +565,9 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(result.paymentPreimage == paymentPreimage) assert(result.parts.length == 1 && result.parts.head.id == successId) assert(result.recipientAmount == finalAmount) - assert(result.recipientNodeId == finalRecipient) + assert(result.recipientNodeId == e) assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount - assert(result.nonTrampolineFees == successRoute.fee(false)) // we paid the fee for only one of the partial payments + assert(result.nonTrampolineFees == successRoute.channelFee(false)) // we paid the fee for only one of the partial payments sender.expectTerminated(payFsm) sender.expectNoMessage(100 millis) @@ -571,15 +578,15 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("receive partial failure after success (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val payment = SendMultiPartPayment(sender.ref, clearRecipient, 5, routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] - router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ab_1 :: hop_be :: Nil, None), Route(600_000 msat, hop_ac_1 :: hop_ce :: Nil, None)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] val (childId, route) :: (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.fee(false), randomBytes32(), Some(route.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) awaitCond(payFsm.stateName == PAYMENT_SUCCEEDED) @@ -588,7 +595,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val result = sender.expectMsgType[PaymentSent] assert(result.parts.length == 1 && result.parts.head.id == childId) assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount - assert(result.nonTrampolineFees == route.fee(false)) // we paid the fee for only one of the partial payments + assert(result.nonTrampolineFees == route.channelFee(false)) // we paid the fee for only one of the partial payments sender.expectTerminated(payFsm) sender.expectNoMessage(100 millis) @@ -596,7 +603,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) } - def fulfillPendingPayments(f: FixtureParam, childCount: Int): PaymentSent = { + def fulfillPendingPayments(f: FixtureParam, childCount: Int, recipientNodeId: PublicKey, recipientAmount: MilliSatoshi): PaymentSent = { import f._ sender.watch(payFsm) @@ -604,7 +611,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(pending.size == childCount) val partialPayments = pending.map { - case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.fee(false), randomBytes32(), Some(route.hops)) + case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops)) } partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) @@ -613,8 +620,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(result.paymentHash == paymentHash) assert(result.paymentPreimage == paymentPreimage) assert(result.parts.toSet == partialPayments.toSet) - assert(result.recipientAmount == finalAmount) - assert(result.recipientNodeId == finalRecipient) + assert(result.recipientAmount == recipientAmount) + assert(result.recipientNodeId == recipientNodeId) sender.expectTerminated(payFsm) sender.expectNoMessage(100 millis) @@ -659,20 +666,19 @@ object MultiPartPaymentLifecycleSpec { val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) val expiry = CltvExpiry(1105) - val finalAmount = 1000000 msat - val finalRecipient = randomKey().publicKey + val finalAmount = 1_000_000 msat val routeParams = PathFindingConf( randomize = false, boundaries = SearchBoundaries( - 15000 msat, - 0.01, + 15_000 msat, + 0.00, 6, CltvExpiryDelta(1008)), Left(WeightRatios(1, 0, 0, 0, RelayFees(0 msat, 0))), MultiPartParams(1000 msat, 5), experimentName = "my-test-experiment", - experimentPercentage = 100).getDefaultRouteParams - val maxFee = 15000 msat // max fee for the defaultAmount + experimentPercentage = 100 + ).getDefaultRouteParams /** * We simulate a multi-part-friendly network: @@ -713,4 +719,11 @@ object MultiPartPaymentLifecycleSpec { val hop_ad = channelHopFromUpdate(a, d, channelUpdate_ad) val hop_de = channelHopFromUpdate(d, e, channelUpdate_de) + val recipientFeatures = Features( + Features.VariableLengthOnion -> FeatureSupport.Mandatory, + Features.PaymentSecret -> FeatureSupport.Mandatory, + Features.BasicMultiPartPayment -> FeatureSupport.Optional, + ).invoiceFeatures() + val clearRecipient = ClearRecipient(e, recipientFeatures, finalAmount, expiry, ByteVector32.One) + } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index f513d4970..7627e47c6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -31,13 +31,11 @@ import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment import fr.acinq.eclair.payment.send.PaymentError.UnsupportedFeatures import fr.acinq.eclair.payment.send.PaymentInitiator._ -import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator, PaymentLifecycle} +import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, KeySend, OutgoingCltv} -import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} @@ -112,10 +110,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, req) sender.expectMsgType[UUID] payFsm.expectMsgType[SendPaymentConfig] - val tlvs = payFsm.expectMsgType[PaymentLifecycle.SendPayment].finalPayload.records - assert(tlvs.get[AmountToForward].get.amount == finalAmount) - assert(tlvs.get[OutgoingCltv].get.cltv == req.invoice.minFinalCltvExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1)) - assert(tlvs.unknown == customRecords) + val payment = payFsm.expectMsgType[PaymentLifecycle.SendPayment] + assert(payment.amount == finalAmount) + assert(payment.recipient.nodeId == invoice.nodeId) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.expiry == req.invoice.minFinalCltvExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1)) + assert(payment.recipient.isInstanceOf[ClearRecipient]) + assert(payment.recipient.asInstanceOf[ClearRecipient].customTlvs == customRecords) } test("forward keysend payment") { f => @@ -124,11 +125,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, req) sender.expectMsgType[UUID] payFsm.expectMsgType[SendPaymentConfig] - val tlvs = payFsm.expectMsgType[PaymentLifecycle.SendPayment].finalPayload.records - assert(tlvs.get[AmountToForward].get.amount == finalAmount) - assert(tlvs.get[OutgoingCltv].get.cltv == Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)) - assert(tlvs.get[KeySend].get.paymentPreimage == paymentPreimage) - assert(tlvs.unknown.isEmpty) + val payment = payFsm.expectMsgType[PaymentLifecycle.SendPayment] + assert(payment.amount == finalAmount) + assert(payment.recipient.nodeId == c) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.expiry == Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)) + assert(payment.recipient.isInstanceOf[SpontaneousRecipient]) + assert(payment.recipient.asInstanceOf[SpontaneousRecipient].preimage == paymentPreimage) } test("reject payment with unsupported mandatory feature") { f => @@ -160,12 +163,12 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val finalExpiryDelta = CltvExpiryDelta(36) val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some invoice"), finalExpiryDelta) - val route = PredefinedNodeRoute(Seq(a, b, c)) - val request = SendPaymentToRoute(finalAmount, finalAmount, invoice, route, None, None, None, 0 msat, CltvExpiryDelta(0), Nil) + val route = PredefinedNodeRoute(finalAmount, Seq(a, b, c)) + val request = SendPaymentToRoute(finalAmount, invoice, route, None, None, None) sender.send(initiator, request) val payment = sender.expectMsgType[SendPaymentToRouteResponse] - payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, Nil)) - payFsm.expectMsg(PaymentLifecycle.SendPaymentToRoute(initiator, Left(route), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1), invoice.paymentSecret, invoice.paymentMetadata))) + payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false)) + payFsm.expectMsg(PaymentLifecycle.SendPaymentToRoute(initiator, Left(route), ClearRecipient(invoice, finalAmount, finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1), Nil))) sender.send(initiator, GetPayment(Left(payment.paymentId))) sender.expectMsg(PaymentIsPending(payment.paymentId, invoice.paymentHash, PendingPaymentToRoute(sender.ref, request))) @@ -189,8 +192,8 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(req.finalExpiry(nodeParams) == (finalExpiryDelta + 1).toCltvExpiry(nodeParams.currentBlockHeight)) sender.send(initiator, req) val id = sender.expectMsgType[UUID] - payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)) - payFsm.expectMsg(PaymentLifecycle.SendPaymentToNode(initiator, c, FinalPayload.Standard(TlvStream(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(req.finalExpiry(nodeParams)), OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, finalAmount))), 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) + payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + payFsm.expectMsg(PaymentLifecycle.SendPaymentToNode(initiator, ClearRecipient(invoice, finalAmount, req.finalExpiry(nodeParams), Nil), 1, nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) sender.send(initiator, GetPayment(Left(id))) sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) @@ -212,8 +215,8 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val req = SendPaymentToNode(finalAmount + 100.msat, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] - multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount + 100.msat, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)) - multiPartPayFsm.expectMsg(SendMultiPartPayment(initiator, invoice.paymentSecret, c, finalAmount + 100.msat, req.finalExpiry(nodeParams), 1, invoice.paymentMetadata, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) + multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + multiPartPayFsm.expectMsg(SendMultiPartPayment(initiator, ClearRecipient(invoice, finalAmount + 100.msat, req.finalExpiry(nodeParams), Nil), 1, nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) sender.send(initiator, GetPayment(Left(id))) sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) @@ -236,28 +239,28 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] - multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)) + multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) val payment = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(nodeParams.currentBlockHeight + invoiceFinalExpiryDelta.toInt + 50 <= payment.targetExpiry.blockHeight) - assert(payment.targetExpiry.blockHeight <= nodeParams.currentBlockHeight + invoiceFinalExpiryDelta.toInt + 200) + val expiry = payment.recipient.asInstanceOf[ClearRecipient].expiry + assert(nodeParams.currentBlockHeight + invoiceFinalExpiryDelta.toInt + 50 <= expiry.blockHeight) + assert(expiry.blockHeight <= nodeParams.currentBlockHeight + invoiceFinalExpiryDelta.toInt + 200) } test("forward multi-part payment with pre-defined route") { f => import f._ val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some invoice"), CltvExpiryDelta(18), features = featuresWithMpp) - val route = PredefinedChannelRoute(c, Seq(channelUpdate_ab.shortChannelId, channelUpdate_bc.shortChannelId)) - val req = SendPaymentToRoute(finalAmount / 2, finalAmount, invoice, route, None, None, None, 0 msat, CltvExpiryDelta(0), Nil) + val route = PredefinedChannelRoute(finalAmount / 2, c, Seq(channelUpdate_ab.shortChannelId, channelUpdate_bc.shortChannelId)) + val req = SendPaymentToRoute(finalAmount, invoice, route, None, None, None) sender.send(initiator, req) val payment = sender.expectMsgType[SendPaymentToRouteResponse] - payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, Nil)) + payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false)) val msg = payFsm.expectMsgType[PaymentLifecycle.SendPaymentToRoute] assert(msg.replyTo == initiator) assert(msg.route == Left(route)) - assert(msg.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(msg.finalPayload.amount == finalAmount / 2) - assert(msg.finalPayload.expiry == req.finalExpiry(nodeParams)) - assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == invoice.paymentSecret) - assert(msg.finalPayload.totalAmount == finalAmount) + assert(msg.amount == finalAmount / 2) + assert(msg.recipient.nodeId == c) + assert(msg.recipient.totalAmount == finalAmount) + assert(msg.recipient.expiry == req.finalExpiry(nodeParams)) sender.send(initiator, GetPayment(Left(payment.paymentId))) sender.expectMsg(PaymentIsPending(payment.paymentId, invoice.paymentHash, PendingPaymentToRoute(sender.ref, req))) @@ -277,8 +280,8 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val ignoredRoutingHints = List(List(ExtraHop(b, channelUpdate_bc.shortChannelId, feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some phoenix invoice"), CltvExpiryDelta(9), features = featuresWithTrampoline, extraHops = ignoredRoutingHints) - val trampolineFees = 21000 msat - val req = SendTrampolinePayment(finalAmount, invoice, b, Seq((trampolineFees, CltvExpiryDelta(12))), routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + val trampolineFees = 21_000 msat + val req = SendTrampolinePayment(finalAmount, invoice, b, Seq((trampolineFees, CltvExpiryDelta(12))), nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] multiPartPayFsm.expectMsgType[SendPaymentConfig] @@ -289,63 +292,38 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingTrampolinePayment(sender.ref, Nil, req))) val msg = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg.paymentSecret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node - assert(msg.targetNodeId == b) - assert(msg.targetExpiry.toLong == currentBlockCount + 9 + 12 + 1) - assert(msg.totalAmount == finalAmount + trampolineFees) - assert(msg.additionalTlvs.head.isInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion]) + assert(msg.recipient.nodeId == c) + assert(msg.recipient.totalAmount == finalAmount) + assert(msg.recipient.expiry.toLong == currentBlockCount + 9 + 1) + assert(msg.recipient.features.hasFeature(Features.TrampolinePaymentPrototype)) + assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineNodeId == b) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node assert(msg.maxAttempts == nodeParams.maxPaymentAttempts) - - // Verify that the trampoline node can correctly peel the trampoline onion. - val trampolineOnion = msg.additionalTlvs.head.asInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion].packet - val Right(decrypted) = Sphinx.peel(priv_b.privateKey, Some(invoice.paymentHash), trampolineOnion) - assert(!decrypted.isLastPacket) - val Right(trampolinePayload) = IntermediatePayload.NodeRelay.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted.payload.bits).require.value) - assert(trampolinePayload.amountToForward == finalAmount) - assert(trampolinePayload.totalAmount == finalAmount) - assert(trampolinePayload.outgoingCltv.toLong == currentBlockCount + 9 + 1) - assert(trampolinePayload.outgoingNodeId == c) - assert(trampolinePayload.paymentSecret.isEmpty) // we're not leaking the invoice secret to the trampoline node - assert(trampolinePayload.invoiceRoutingInfo.isEmpty) - assert(trampolinePayload.invoiceFeatures.isEmpty) - - // Verify that the recipient can correctly peel the trampoline onion. - val Right(decrypted1) = Sphinx.peel(priv_c.privateKey, Some(invoice.paymentHash), decrypted.nextPacket) - assert(decrypted1.isLastPacket) - val Right(finalPayload) = FinalPayload.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted1.payload.bits).require.value) - assert(finalPayload.amount == finalAmount) - assert(finalPayload.totalAmount == finalAmount) - assert(finalPayload.expiry.toLong == currentBlockCount + 9 + 1) - assert(finalPayload.paymentSecret == invoice.paymentSecret) } test("forward trampoline to legacy payment") { f => import f._ - val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some eclair-mobile invoice"), CltvExpiryDelta(9)) - val trampolineFees = 21000 msat + val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some wallet invoice"), CltvExpiryDelta(9)) + val trampolineFees = 21_000 msat val req = SendTrampolinePayment(finalAmount, invoice, b, Seq((trampolineFees, CltvExpiryDelta(12))), routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) sender.expectMsgType[UUID] multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg.paymentSecret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node - assert(msg.targetNodeId == b) - assert(msg.targetExpiry.toLong == currentBlockCount + 9 + 12 + 1) - assert(msg.totalAmount == finalAmount + trampolineFees) - assert(msg.additionalTlvs.head.isInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion]) - - // Verify that the trampoline node can correctly peel the trampoline onion. - val trampolineOnion = msg.additionalTlvs.head.asInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion].packet - val Right(decrypted) = Sphinx.peel(priv_b.privateKey, Some(invoice.paymentHash), trampolineOnion) - assert(!decrypted.isLastPacket) - val Right(trampolinePayload) = IntermediatePayload.NodeRelay.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted.payload.bits).require.value) - assert(trampolinePayload.amountToForward == finalAmount) - assert(trampolinePayload.totalAmount == finalAmount) - assert(trampolinePayload.outgoingCltv.toLong == currentBlockCount + 9 + 1) - assert(trampolinePayload.outgoingNodeId == c) - assert(trampolinePayload.paymentSecret.contains(invoice.paymentSecret)) - assert(trampolinePayload.invoiceFeatures.contains(hex"4100")) // var_onion_optin, payment_secret + assert(msg.recipient.nodeId == c) + assert(msg.recipient.totalAmount == finalAmount) + assert(msg.recipient.expiry.toLong == currentBlockCount + 9 + 1) + assert(!msg.recipient.features.hasFeature(Features.TrampolinePaymentPrototype)) + assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineNodeId == b) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node + assert(msg.maxAttempts == nodeParams.maxPaymentAttempts) } test("reject trampoline to legacy payment for 0-value invoice") { f => @@ -353,7 +331,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike // This is disabled because it would let the trampoline node steal the whole payment (if malicious). val routingHints = List(List(Bolt11Invoice.ExtraHop(b, channelUpdate_bc.shortChannelId, 10 msat, 100, CltvExpiryDelta(144)))) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_a.privateKey, Left("#abittooreckless"), CltvExpiryDelta(18), None, None, routingHints, features = featuresWithMpp) - val trampolineFees = 21000 msat + val trampolineFees = 21_000 msat val req = SendTrampolinePayment(finalAmount, invoice, b, Seq((trampolineFees, CltvExpiryDelta(12))), routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] @@ -365,24 +343,10 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike payFsm.expectNoMessage(50 millis) } - test("reject trampoline payment with onion too big") { f => - import f._ - val paymentMetadata = ByteVector.fromValidHex("01" * 400) - val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Much payment very metadata"), CltvExpiryDelta(9), features = featuresWithTrampoline, paymentMetadata = Some(paymentMetadata)) - val trampolineFees = 21000 msat - val req = SendTrampolinePayment(finalAmount, invoice, b, Seq((trampolineFees, CltvExpiryDelta(12))), routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) - sender.send(initiator, req) - val id = sender.expectMsgType[UUID] - val fail = sender.expectMsgType[PaymentFailed] - assert(fail.id == id) - assert(fail.failures.length == 1) - assert(fail.failures.head.asInstanceOf[LocalFailure].t.getMessage == "requirement failed: packet per-hop payloads cannot exceed 400 bytes") - } - test("retry trampoline payment") { f => import f._ val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some phoenix invoice"), CltvExpiryDelta(18), features = featuresWithTrampoline) - val trampolineAttempts = (21000 msat, CltvExpiryDelta(12)) :: (25000 msat, CltvExpiryDelta(24)) :: Nil + val trampolineAttempts = (21_000 msat, CltvExpiryDelta(12)) :: (25_000 msat, CltvExpiryDelta(24)) :: Nil val req = SendTrampolinePayment(finalAmount, invoice, b, trampolineAttempts, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] @@ -391,16 +355,18 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(!cfg.publishEvent) val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg1.totalAmount == finalAmount + 21000.msat) + assert(msg1.recipient.totalAmount == finalAmount) + assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) sender.send(initiator, GetPayment(Left(id))) sender.expectMsgType[PaymentIsPending] // Simulate a failure which should trigger a retry. - multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) + multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg2.totalAmount == finalAmount + 25000.msat) + assert(msg2.recipient.totalAmount == finalAmount) + assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Simulate success which should publish the event and respond to the original sender. val success = PaymentSent(cfg.parentId, invoice.paymentHash, randomBytes32(), finalAmount, c, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), 1000 msat, 500 msat, randomBytes32(), None))) @@ -417,7 +383,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike test("retry trampoline payment and fail") { f => import f._ val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some phoenix invoice"), CltvExpiryDelta(18), features = featuresWithTrampoline) - val trampolineAttempts = (21000 msat, CltvExpiryDelta(12)) :: (25000 msat, CltvExpiryDelta(24)) :: Nil + val trampolineAttempts = (21_000 msat, CltvExpiryDelta(12)) :: (25_000 msat, CltvExpiryDelta(24)) :: Nil val req = SendTrampolinePayment(finalAmount, invoice, b, trampolineAttempts, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) sender.expectMsgType[UUID] @@ -426,16 +392,18 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(!cfg.publishEvent) val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg1.totalAmount == finalAmount + 21000.msat) + assert(msg1.recipient.totalAmount == finalAmount) + assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) // Simulate a failure which should trigger a retry. - multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) + multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg2.totalAmount == finalAmount + 25000.msat) + assert(msg2.recipient.totalAmount == finalAmount) + assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Simulate a failure that exhausts payment attempts. - val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg2.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure)))) + val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg2.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure)))) multiPartPayFsm.send(initiator, failed) sender.expectMsg(failed) eventListener.expectMsg(failed) @@ -446,25 +414,25 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike test("retry trampoline payment and fail (route not found)") { f => import f._ val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some phoenix invoice"), CltvExpiryDelta(18), features = featuresWithTrampoline) - val trampolineAttempts = (21000 msat, CltvExpiryDelta(12)) :: (25000 msat, CltvExpiryDelta(24)) :: Nil + val trampolineAttempts = (21_000 msat, CltvExpiryDelta(12)) :: (25_000 msat, CltvExpiryDelta(24)) :: Nil val req = SendTrampolinePayment(finalAmount, invoice, b, trampolineAttempts, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) sender.expectMsgType[UUID] val cfg = multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg1.totalAmount == finalAmount + 21000.msat) + assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) // Trampoline node couldn't find a route for the given fee. - val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) + val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) multiPartPayFsm.send(initiator, failed) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg2.totalAmount == finalAmount + 25000.msat) + assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Trampoline node couldn't find a route even with the increased fee. multiPartPayFsm.send(initiator, failed) val failure = sender.expectMsgType[PaymentFailed] - assert(failure.failures == Seq(LocalFailure(finalAmount, Seq(NodeHop(nodeParams.nodeId, b, nodeParams.channelConf.expiryDelta, 0 msat), NodeHop(b, c, CltvExpiryDelta(24), 25000 msat)), RouteNotFound))) + assert(failure.failures == Seq(LocalFailure(finalAmount, Seq(NodeHop(b, c, CltvExpiryDelta(24), 25_000 msat)), RouteNotFound))) eventListener.expectMsg(failure) sender.expectNoMessage(100 millis) eventListener.expectNoMessage(100 millis) @@ -473,30 +441,20 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike test("forward trampoline payment with pre-defined route") { f => import f._ val invoice = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, Left("Some invoice"), CltvExpiryDelta(18)) - val trampolineFees = 100 msat - val route = PredefinedNodeRoute(Seq(a, b)) - val req = SendPaymentToRoute(finalAmount + trampolineFees, finalAmount, invoice, route, None, None, None, trampolineFees, CltvExpiryDelta(144), Seq(b, c)) + val trampolineAttempt = TrampolineAttempt(randomBytes32(), 100 msat, CltvExpiryDelta(144)) + val route = PredefinedNodeRoute(finalAmount + trampolineAttempt.fees, Seq(a, b)) + val req = SendPaymentToRoute(finalAmount, invoice, route, None, None, Some(trampolineAttempt)) sender.send(initiator, req) val payment = sender.expectMsgType[SendPaymentToRouteResponse] - assert(payment.trampolineSecret.nonEmpty) - payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, Seq(NodeHop(b, c, CltvExpiryDelta(0), 0 msat)))) + assert(payment.trampolineSecret.contains(trampolineAttempt.paymentSecret)) + payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false)) val msg = payFsm.expectMsgType[PaymentLifecycle.SendPaymentToRoute] assert(msg.route == Left(route)) - assert(msg.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(msg.finalPayload.amount == finalAmount + trampolineFees) - assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == payment.trampolineSecret.get) - assert(msg.finalPayload.totalAmount == finalAmount + trampolineFees) - val trampolineOnion = msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion] - assert(trampolineOnion.nonEmpty) - - // Verify that the trampoline node can correctly peel the trampoline onion. - val Right(decrypted) = Sphinx.peel(priv_b.privateKey, Some(invoice.paymentHash), trampolineOnion.get.packet) - assert(!decrypted.isLastPacket) - val Right(trampolinePayload) = IntermediatePayload.NodeRelay.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted.payload.bits).require.value) - assert(trampolinePayload.amountToForward == finalAmount) - assert(trampolinePayload.totalAmount == finalAmount) - assert(trampolinePayload.outgoingNodeId == c) - assert(trampolinePayload.paymentSecret.contains(invoice.paymentSecret)) + assert(msg.amount == finalAmount + trampolineAttempt.fees) + assert(msg.recipient.totalAmount == finalAmount) + assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineAttempt.fees) + assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret == payment.trampolineSecret.get) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 3bbcd723c..e97cf86ff 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -31,13 +31,13 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} import fr.acinq.eclair.io.Peer.PeerRoutingMessage -import fr.acinq.eclair.payment.Invoice.BasicEdge +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig -import fr.acinq.eclair.payment.send.PaymentLifecycle import fr.acinq.eclair.payment.send.PaymentLifecycle._ +import fr.acinq.eclair.payment.send.{ClearRecipient, PaymentLifecycle} import fr.acinq.eclair.router.Announcements.makeChannelUpdate import fr.acinq.eclair.router.BaseRouterSpec.{channelAnnouncement, channelHopFromUpdate} import fr.acinq.eclair.router.Graph.WeightRatios @@ -56,17 +56,17 @@ import scala.concurrent.duration._ class PaymentLifecycleSpec extends BaseRouterSpec { - val defaultAmountMsat = 142000000 msat - val defaultMaxFee = 4260000 msat // 3% of defaultAmountMsat - val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(BlockHeight(40000)) + val defaultAmountMsat = 142_000_000 msat + val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(BlockHeight(40_000)) val defaultPaymentPreimage = randomBytes32() val defaultPaymentHash = Crypto.sha256(defaultPaymentPreimage) val defaultOrigin = Origin.LocalCold(UUID.randomUUID()) val defaultExternalId = UUID.randomUUID().toString val defaultInvoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, defaultPaymentHash, priv_d, Left("test"), Channel.MIN_CLTV_EXPIRY_DELTA) + val defaultRecipient = ClearRecipient(defaultInvoice, defaultAmountMsat, defaultExpiry, Nil) val defaultRouteParams = TestConstants.Alice.nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams - def defaultRouteRequest(source: PublicKey, target: PublicKey, cfg: SendPaymentConfig): RouteRequest = RouteRequest(source, target, defaultAmountMsat, defaultMaxFee, paymentContext = Some(cfg.paymentContext), routeParams = defaultRouteParams) + def defaultRouteRequest(source: PublicKey, cfg: SendPaymentConfig): RouteRequest = RouteRequest(source, defaultRecipient, defaultRouteParams, paymentContext = Some(cfg.paymentContext)) case class PaymentFixture(cfg: SendPaymentConfig, nodeParams: NodeParams, @@ -81,7 +81,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { def createPaymentLifecycle(storeInDb: Boolean = true, publishEvent: Boolean = true, recordMetrics: Boolean = true): PaymentFixture = { val (id, parentId) = (UUID.randomUUID(), UUID.randomUUID()) val nodeParams = TestConstants.Alice.nodeParams.copy(nodeKeyManager = testNodeKeyManager, channelKeyManager = testChannelKeyManager) - val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, Upstream.Local(id), Some(defaultInvoice), storeInDb, publishEvent, recordMetrics, Nil) + val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, d, Upstream.Local(id), Some(defaultInvoice), storeInDb, publishEvent, recordMetrics) val (routerForwarder, register, sender, monitor, eventListener, metricsListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, cfg, routerForwarder.ref, register.ref)) paymentFSM ! SubscribeTransitionCallBack(monitor.ref) @@ -104,8 +104,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ // pre-computed route going from A to D - val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil) - val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil, None) + val request = SendPaymentToRoute(sender.ref, Right(route), defaultRecipient) sender.send(paymentFSM, request) routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -121,7 +121,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(ps.parts.head.route.contains(route.hops)) awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) assert(routerForwarder.expectMsgType[RouteDidRelay].route === route) } @@ -132,11 +132,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ // pre-computed route going from A to D - val route = PredefinedNodeRoute(Seq(a, b, c, d)) - val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val route = PredefinedNodeRoute(defaultAmountMsat, Seq(a, b, c, d)) + val request = SendPaymentToRoute(sender.ref, Left(route), defaultRecipient) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, route, paymentContext = Some(cfg.paymentContext))) + routerForwarder.expectMsg(FinalizeRoute(route, paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -150,7 +150,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(ps.id == parentId) awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.nodeId) === Seq(a, b, c)) } @@ -159,7 +159,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle(recordMetrics = false) import payFixture._ - val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(defaultAmountMsat, Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), defaultRecipient) sender.send(paymentFSM, brokenRoute) routerForwarder.expectMsgType[FinalizeRoute] routerForwarder.forward(routerFixture.router) @@ -167,7 +167,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val failureMessage = eventListener.expectMsgType[PaymentFailed].failures.head.asInstanceOf[LocalFailure].t.getMessage assert(failureMessage == "Not all the nodes in the supplied route are connected with public channels") - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) routerForwarder.expectNoMessage(100 millis) } @@ -176,7 +176,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle(recordMetrics = false) import payFixture._ - val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(defaultAmountMsat, randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), defaultRecipient) sender.send(paymentFSM, brokenRoute) routerForwarder.expectMsgType[FinalizeRoute] routerForwarder.forward(routerFixture.router) @@ -184,7 +184,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val failureMessage = eventListener.expectMsgType[PaymentFailed].failures.head.asInstanceOf[LocalFailure].t.getMessage assert(failureMessage == "The sequence of channels provided cannot be used to build a route to the target node") - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) routerForwarder.expectNoMessage(100 millis) } @@ -194,13 +194,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val recipient = randomKey().publicKey - val route = PredefinedNodeRoute(Seq(a, b, c, recipient)) - val extraEdges = Seq(BasicEdge(c, recipient, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144))) - val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), extraEdges) + val recipientNodeId = randomKey().publicKey + val route = PredefinedNodeRoute(defaultAmountMsat, Seq(a, b, c, recipientNodeId)) + val extraEdges = Seq(ExtraEdge(c, recipientNodeId, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144), 1 msat, None)) + val recipient = ClearRecipient(recipientNodeId, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, extraEdges) + val request = SendPaymentToRoute(sender.ref, Left(route), recipient) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, route, extraEdges, paymentContext = Some(cfg.paymentContext))) + routerForwarder.expectMsg(FinalizeRoute(route, extraEdges, paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -213,7 +214,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(ps.id == parentId) awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.nodeId) === Seq(a, b, c)) } @@ -223,7 +224,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, f, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, ClearRecipient(f, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret), 5, defaultRouteParams) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -237,8 +238,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.status == "FAILURE") assert(metrics.experimentName == "alice-test-experiment") assert(metrics.amount == defaultAmountMsat) - assert(metrics.fees == 4260000.msat) - metricsListener.expectNoMessage() + assert(metrics.fees == 4_260_000.msat) + metricsListener.expectNoMessage(100 millis) routerForwarder.expectNoMessage(100 millis) } @@ -252,11 +253,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { randomize = false, boundaries = SearchBoundaries(100 msat, 0.0, 20, CltvExpiryDelta(2016)), Left(WeightRatios(1, 0, 0, 0, RelayFees(0 msat, 0))), - MultiPartParams(10000 msat, 5), + MultiPartParams(10_000 msat, 5), "my-test-experiment", experimentPercentage = 100 ).getDefaultRouteParams - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = routeParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 5, routeParams) sender.send(paymentFSM, request) val routeRequest = routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -270,7 +271,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.experimentName == "my-test-experiment") assert(metrics.amount == defaultAmountMsat) assert(metrics.fees == 100.msat) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) routerForwarder.expectNoMessage(100 millis) } @@ -281,7 +282,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ val paymentMetadataTooBig = ByteVector.fromValidHex("01" * 1300) - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, Some(paymentMetadataTooBig)), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, ClearRecipient(d, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, paymentMetadata_opt = Some(paymentMetadataTooBig)), 5, defaultRouteParams) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -300,9 +301,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData @@ -315,7 +316,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, randomBytes32())))) // unparsable message // then the payment lifecycle will ask for a new route excluding all intermediate nodes - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set(c), Set.empty))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set(c), Set.empty))) // let's simulate a response by the router with another route sender.send(paymentFSM, RouteResponse(route :: Nil)) @@ -334,8 +335,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.status == "FAILURE") assert(metrics.experimentName == "alice-test-experiment") assert(metrics.amount == defaultAmountMsat) - assert(metrics.fees == 4260000.msat) - metricsListener.expectNoMessage() + assert(metrics.fees == 4_260_000.msat) + metricsListener.expectNoMessage(100 millis) routerForwarder.expectNoMessage(100 millis) } @@ -345,7 +346,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) routerForwarder.expectMsgType[RouteRequest] @@ -357,7 +358,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, RES_ADD_FAILED(cmd1, ChannelUnavailable(ByteVector32.Zeroes), None)) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } @@ -366,7 +367,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) routerForwarder.expectMsgType[RouteRequest] @@ -377,7 +378,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { register.send(paymentFSM, ForwardShortIdFailure(fwd)) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } @@ -386,12 +387,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, _, _, _) = paymentFSM.stateData @@ -400,7 +401,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFailMalformed(UpdateFailMalformedHtlc(ByteVector32.Zeroes, 0, randomBytes32(), FailureMessageCodecs.BADONION)))) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) } @@ -409,12 +410,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, _, _, _) = paymentFSM.stateData @@ -432,12 +433,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, _, _, _) = paymentFSM.stateData @@ -447,7 +448,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, addCompleted(HtlcResult.DisconnectedBeforeSigned(update_bc_disabled))) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) } @@ -455,11 +456,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route) = paymentFSM.stateData @@ -475,7 +476,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(update_bc) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(update_bc.shortChannelId, b, c))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(update_bc.shortChannelId, b, c))))) routerForwarder.forward(routerFixture.router) // we allow 2 tries, so we send a 2nd request to the router assert(sender.expectMsgType[PaymentFailed].failures == RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(defaultAmountMsat, Nil, RouteNotFound) :: Nil) @@ -486,12 +487,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 5, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData @@ -506,7 +507,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update @@ -526,7 +527,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // but it will still forward the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified_2) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) // this time the router can't find a route: game over @@ -540,9 +541,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 1, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 1, defaultRouteParams) sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData @@ -566,17 +567,16 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ // we build an assisted route for channel bc and cd - val extraEdges = Seq( - BasicEdge(b, c, scid_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, update_bc.cltvExpiryDelta), - BasicEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) - ) - - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, extraEdges = extraEdges, routeParams = defaultRouteParams) + val recipient = ClearRecipient(d, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, Seq( + ExtraEdge(b, c, scid_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, update_bc.cltvExpiryDelta, 1 msat, None), + ExtraEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta, 1 msat, None) + )) + val request = SendPaymentToNode(sender.ref, recipient, 5, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(extraEdges = extraEdges)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(target = recipient)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData @@ -592,10 +592,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING val extraEdges1 = Seq( - extraEdges(0).update(channelUpdate_bc_modified), - extraEdges(1) + recipient.extraEdges(0).update(channelUpdate_bc_modified), + recipient.extraEdges(1) ) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(extraEdges = extraEdges1)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(target = recipient.copy(extraEdges = extraEdges1))) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update @@ -611,12 +611,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ // we build an assisted route for channel cd - val extraEdges = Seq(BasicEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta)) - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 1, extraEdges = extraEdges, routeParams = defaultRouteParams) + val recipient = ClearRecipient(d, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, Seq( + ExtraEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta, 1 msat, None) + )) + val request = SendPaymentToNode(sender.ref, recipient, 1, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(extraEdges = extraEdges)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(target = recipient)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData @@ -638,12 +640,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData @@ -653,7 +655,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_bc, b, c))))) + routerForwarder.expectMsg(defaultRouteRequest(a, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(scid_bc, b, c))))) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route @@ -676,7 +678,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 5, defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -701,7 +703,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.experimentName == "alice-test-experiment") assert(metrics.amount == defaultAmountMsat) assert(metrics.fees == 730.msat) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc, update_cd).map(_.shortChannelId)) } @@ -732,7 +734,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ // we send a payment to H - val request = SendPaymentToNode(sender.ref, h, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, ClearRecipient(h, Features.empty, defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret), 5, defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] @@ -744,20 +746,20 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, addCompleted(HtlcResult.OnChainFulfill(defaultPaymentPreimage))) val paymentOK = sender.expectMsgType[PaymentSent] val PaymentSent(_, _, paymentOK.paymentPreimage, finalAmount, _, PartialPayment(_, partAmount, fee, ByteVector32.Zeroes, _, _) :: Nil) = eventListener.expectMsgType[PaymentSent] - assert(partAmount == request.finalPayload.amount) + assert(partAmount == request.amount) assert(finalAmount == defaultAmountMsat) // NB: A -> B doesn't pay fees because it's our direct neighbor // NB: B -> H doesn't asks for fees at all assert(fee == 0.msat) - assert(paymentOK.recipientAmount == request.finalPayload.amount) + assert(paymentOK.recipientAmount == request.amount) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] assert(metrics.status == "SUCCESS") assert(metrics.experimentName == "alice-test-experiment") assert(metrics.amount == defaultAmountMsat) assert(metrics.fees == 0.msat) - metricsListener.expectNoMessage() + metricsListener.expectNoMessage(100 millis) assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, channelUpdate_bh).map(_.shortChannelId)) } @@ -795,7 +797,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty), // unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty), - (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil), Set(c, d), Set.empty) + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: NodeHop(d, e, CltvExpiryDelta(24), 0 msat) :: Nil), Set(c, d), Set.empty) ) for ((failure, expectedNodes, expectedChannels) <- testCases) { @@ -819,7 +821,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 3, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, defaultRecipient, 3, defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -841,8 +843,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ // pre-computed route going from A to D - val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil) - val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil, None) + val request = SendPaymentToRoute(sender.ref, Right(route), defaultRecipient) sender.send(paymentFSM, request) routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 168554719..a607d9bad 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -17,18 +17,17 @@ package fr.acinq.eclair.payment import akka.actor.ActorRef -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.DeterministicWallet.ExtendedPrivateKey import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, DeterministicWallet, OutPoint, Satoshi, SatoshiLong, TxOut} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel -import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.IncomingPaymentPacket.{ChannelRelayPacket, FinalPacket, NodeRelayPacket, decrypt} import fr.acinq.eclair.payment.OutgoingPaymentPacket._ +import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient} import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate -import fr.acinq.eclair.router.Router.NodeHop +import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.transactions.Transactions.InputInfo import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, OutgoingCltv, PaymentData} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} @@ -36,11 +35,9 @@ import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite -import scodec.Attempt import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID -import scala.util.Success /** * Created by PM on 31/05/2016. @@ -53,24 +50,25 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { implicit val log: akka.event.LoggingAdapter = akka.event.NoLogging test("compute fees") { - val feeBaseMsat = 150000 msat - val feeProportionalMillionth = 4L - val htlcAmountMsat = 42000000 msat + val feeBaseMsat = 150_000 msat + val feeProportionalMillionth = 4 + val htlcAmountMsat = 42_000_000 msat // spec: fee-base-msat + htlc-amount-msat * fee-proportional-millionths / 1000000 - val ref = feeBaseMsat + htlcAmountMsat * feeProportionalMillionth / 1000000 + val ref = feeBaseMsat + htlcAmountMsat * feeProportionalMillionth / 1_000_000 val fee = nodeFee(feeBaseMsat, feeProportionalMillionth, htlcAmountMsat) assert(ref == fee) } - def testBuildOnion(): Unit = { - val Right(finalPayload) = FinalPayload.Standard.validate(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(paymentSecret, 0 msat))) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, finalPayload) - assert(firstAmount == amount_ab) - assert(firstExpiry == expiry_ab) - assert(onion.packet.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + def testBuildOutgoingPayment(): Unit = { + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) + assert(payment.cmd.onion.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) // let's peel the onion - testPeelOnion(onion.packet) + testPeelOnion(payment.cmd.onion) } def testPeelOnion(packet_b: OnionRoutingPacket): Unit = { @@ -114,30 +112,21 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_e.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) } - test("build onion with final payload") { - testBuildOnion() + test("build outgoing payment onion") { + testBuildOutgoingPayment() } - test("build a command including the onion") { - val Success((add, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - assert(add.amount > finalAmount) - assert(add.cltvExpiry == finalExpiry + channelUpdate_de.cltvExpiryDelta + channelUpdate_cd.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta) - assert(add.paymentHash == paymentHash) - assert(add.onion.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + test("build outgoing payment for direct peer") { + val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) + val route = Route(finalAmount, hops.take(1), None) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.cmd.amount == finalAmount) + assert(payment.cmd.cltvExpiry == finalExpiry) + assert(payment.cmd.paymentHash == paymentHash) + assert(payment.cmd.onion.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) // let's peel the onion - testPeelOnion(add.onion) - } - - test("build a command with no hops") { - val Success((add, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, Some(paymentMetadata))) - assert(add.amount == finalAmount) - assert(add.cltvExpiry == finalExpiry) - assert(add.paymentHash == paymentHash) - assert(add.onion.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) - - // let's peel the onion - val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion, None) + val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, payment.cmd.onion, None) val Right(FinalPacket(add_b2, payload_b)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(add_b2 == add_b) assert(payload_b.isInstanceOf[FinalPayload.Standard]) @@ -148,11 +137,13 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentMetadata.contains(paymentMetadata)) } - test("build a command with greater amount and expiry") { - val Success((add, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + test("build outgoing payment with greater amount and expiry") { + val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) + val route = Route(finalAmount, hops.take(1), None) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) // let's peel the onion - val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount + 100.msat, paymentHash, finalExpiry + CltvExpiryDelta(6), add.onion, None) + val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount + 100.msat, paymentHash, finalExpiry + CltvExpiryDelta(6), payment.cmd.onion, None) val Right(FinalPacket(_, payload_b)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(payload_b.isInstanceOf[FinalPayload.Standard]) assert(payload_b.amount == finalAmount) @@ -161,209 +152,246 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) } - test("build a trampoline payment") { + test("build outgoing trampoline payment") { // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e + // .----. + // / \ + // a -> b -> c e + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + assert(recipient.trampolineAmount == amount_bc) + assert(recipient.trampolineExpiry == expiry_bc) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, Some(hex"010203"))) - assert(amount_ac == amount_bc) - assert(expiry_ac == expiry_bc) - - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - assert(firstAmount == amount_ab) - assert(firstExpiry == expiry_ab) - - val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) + val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(add_b2, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(add_b2 == add_b) assert(payload_b == IntermediatePayload.ChannelRelay.Standard(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc)) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Right(NodeRelayPacket(add_c2, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) + val Right(NodeRelayPacket(add_c2, outer_c, inner_c, trampolinePacket_e)) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(add_c2 == add_c) assert(outer_c.amount == amount_bc) assert(outer_c.totalAmount == amount_bc) assert(outer_c.expiry == expiry_bc) - assert(inner_c.amountToForward == amount_cd) - assert(inner_c.outgoingCltv == expiry_cd) - assert(inner_c.outgoingNodeId == d) + assert(outer_c.paymentSecret != invoice.paymentSecret) + assert(inner_c.amountToForward == finalAmount) + assert(inner_c.outgoingCltv == finalExpiry) + assert(inner_c.outgoingNodeId == e) assert(inner_c.invoiceRoutingInfo.isEmpty) assert(inner_c.invoiceFeatures.isEmpty) assert(inner_c.paymentSecret.isEmpty) assert(inner_c.paymentMetadata.isEmpty) - // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - assert(amount_d == amount_cd) - assert(expiry_d == expiry_cd) - val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) - val Right(NodeRelayPacket(add_d2, outer_d, inner_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + // c forwards the trampoline payment to e through d. + val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment_e.cmd.amount == amount_cd) + assert(payment_e.cmd.cltvExpiry == expiry_cd) + val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) + val Right(ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) assert(add_d2 == add_d) - assert(outer_d.amount == amount_cd) - assert(outer_d.totalAmount == amount_cd) - assert(outer_d.expiry == expiry_cd) - assert(inner_d.amountToForward == amount_de) - assert(inner_d.outgoingCltv == expiry_de) - assert(inner_d.outgoingNodeId == e) - assert(inner_d.invoiceRoutingInfo.isEmpty) - assert(inner_d.invoiceFeatures.isEmpty) - assert(inner_d.paymentSecret.isEmpty) - assert(inner_d.paymentMetadata.isEmpty) + assert(payload_d == IntermediatePayload.ChannelRelay.Standard(channelUpdate_de.shortChannelId, amount_de, expiry_de)) - // d forwards the trampoline payment to e. - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, expiry_de, randomBytes32(), packet_e)) - assert(amount_e == amount_de) - assert(expiry_e == expiry_de) - val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None) + val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_de, paymentHash, expiry_de, packet_e, None) val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(add_e2 == add_e) - assert(payload_e == FinalPayload.Standard(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(paymentSecret, finalAmount * 3), OnionPaymentPayloadTlv.PaymentMetadata(hex"010203")))) + assert(payload_e == FinalPayload.Standard(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(paymentSecret, finalAmount), OnionPaymentPayloadTlv.PaymentMetadata(hex"010203")))) } - test("build a trampoline payment with non-trampoline recipient") { + test("build outgoing trampoline payment with non-trampoline recipient") { // simple trampoline route to e where e doesn't support trampoline: - // .--. - // / \ - // a -> b -> c d -> e - + // .----. + // / \ + // a -> b -> c e val routingHints = List(List(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(42), 10 msat, 100, CltvExpiryDelta(144)))) val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional) - val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_a.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHints, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, None)) - assert(amount_ac == amount_bc) - assert(expiry_ac == expiry_bc) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), extraHops = routingHints, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + assert(recipient.trampolineAmount == amount_bc) + assert(recipient.trampolineExpiry == expiry_bc) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - assert(firstAmount == amount_ab) - assert(firstExpiry == expiry_ab) - - val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) + val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Right(NodeRelayPacket(_, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) + val Right(NodeRelayPacket(_, outer_c, inner_c, _)) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(outer_c.amount == amount_bc) assert(outer_c.totalAmount == amount_bc) assert(outer_c.expiry == expiry_bc) - assert(outer_c.paymentSecret !== invoice.paymentSecret) - assert(inner_c.amountToForward == amount_cd) - assert(inner_c.outgoingCltv == expiry_cd) - assert(inner_c.outgoingNodeId == d) - assert(inner_c.invoiceRoutingInfo.isEmpty) - assert(inner_c.invoiceFeatures.isEmpty) - assert(inner_c.paymentSecret.isEmpty) + assert(outer_c.paymentSecret != invoice.paymentSecret) + assert(inner_c.amountToForward == finalAmount) + assert(inner_c.totalAmount == finalAmount) + assert(inner_c.outgoingCltv == finalExpiry) + assert(inner_c.outgoingNodeId == e) + assert(inner_c.paymentSecret.contains(invoice.paymentSecret)) + assert(inner_c.paymentMetadata.contains(hex"010203")) + assert(inner_c.invoiceFeatures.contains(invoiceFeatures.toByteVector)) + assert(inner_c.invoiceRoutingInfo.contains(routingHints)) - // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - assert(amount_d == amount_cd) - assert(expiry_d == expiry_cd) - val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) - val Right(NodeRelayPacket(_, outer_d, inner_d, _)) = decrypt(add_d, priv_d.privateKey, Features.empty) - assert(outer_d.amount == amount_cd) - assert(outer_d.totalAmount == amount_cd) - assert(outer_d.expiry == expiry_cd) - assert(outer_d.paymentSecret !== invoice.paymentSecret) - assert(inner_d.amountToForward == finalAmount) - assert(inner_d.outgoingCltv == expiry_de) - assert(inner_d.outgoingNodeId == e) - assert(inner_d.totalAmount == finalAmount) - assert(inner_d.paymentSecret.contains(invoice.paymentSecret)) - assert(inner_d.paymentMetadata.contains(hex"010203")) - assert(inner_d.invoiceFeatures.contains(hex"024100")) // var_onion_optin, payment_secret, basic_mpp - assert(inner_d.invoiceRoutingInfo.contains(routingHints)) + // c forwards the trampoline payment to e through d. + val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment_e.cmd.amount == amount_cd) + assert(payment_e.cmd.cltvExpiry == expiry_cd) + val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) + val Right(ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + assert(add_d2 == add_d) + assert(payload_d == IntermediatePayload.ChannelRelay.Standard(channelUpdate_de.shortChannelId, amount_de, expiry_de)) + + val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_de, paymentHash, expiry_de, packet_e, None) + val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) + assert(add_e2 == add_e) + assert(payload_e == FinalPayload.Standard(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(invoice.paymentSecret, finalAmount), OnionPaymentPayloadTlv.PaymentMetadata(hex"010203")))) } - test("fail to build a trampoline payment when too much invoice data is provided") { + test("fail to build outgoing trampoline payment when too much invoice data is provided") { val routingHintOverflow = List(List.fill(7)(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(1), 10 msat, 100, CltvExpiryDelta(12)))) - val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_a.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHintOverflow) - assert(buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, invoice.paymentMetadata)).isFailure) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHintOverflow) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + assert(failure.isInstanceOf[CannotCreateOnion]) + } + + test("fail to build outgoing trampoline payment when too much payment metadata is provided") { + val paymentMetadata = ByteVector.fromValidHex("01" * 400) + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("Much payment very metadata"), CltvExpiryDelta(9), features = invoiceFeatures, paymentMetadata = Some(paymentMetadata)) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + assert(failure.isInstanceOf[CannotCreateOnion]) } test("fail to decrypt when the onion is invalid") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reverse), None) + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(payload = payment.cmd.onion.payload.reverse), None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt when the trampoline onion is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reverse))) - val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + + val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Left(failure) = decrypt(add_c, priv_c.privateKey, Features.empty) + val Right(NodeRelayPacket(_, _, inner_c, trampolinePacket_e)) = decrypt(add_c, priv_c.privateKey, Features.empty) + + // c forwards an invalid trampoline onion to e through d. + val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e.copy(payload = trampolinePacket_e.payload.reverse))) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) + val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) + val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + + val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_de, paymentHash, expiry_de, packet_e, None) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt when payment hash doesn't match associated data") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash.reverse, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash.reverse, Route(finalAmount, hops, None), recipient) + val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt at the final node when amount has been modified by next-to-last node") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet, None) + val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) + val route = Route(finalAmount, hops.take(1), None) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount - 100.msat, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) - assert(failure == FinalIncorrectHtlcAmount(firstAmount - 100.msat)) + assert(failure == FinalIncorrectHtlcAmount(payment.cmd.amount - 100.msat)) } test("fail to decrypt at the final node when expiry has been modified by next-to-last node") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet, None) + val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) + val route = Route(finalAmount, hops.take(1), None) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry - CltvExpiryDelta(12), payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) - assert(failure == FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))) + assert(failure == FinalIncorrectCltvExpiry(payment.cmd.cltvExpiry - CltvExpiryDelta(12))) + } + + // Create a trampoline payment to e: + // .----. + // / \ + // a -> b -> c e + // + // and return the HTLC sent by b to c. + def createIntermediateTrampolinePayment(): UpdateAddHtlc = { + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + + val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + + UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) } test("fail to decrypt at the final trampoline node when amount has been decreased by next-to-last trampoline") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) - val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) - // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) - // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). - val invalidTotalAmount = amount_de - 1.msat - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, invalidTotalAmount, expiry_de, randomBytes32(), packet_e)) - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) + val add_c = createIntermediateTrampolinePayment() + val Right(NodeRelayPacket(_, _, inner_c, trampolinePacket_e)) = decrypt(add_c, priv_c.privateKey, Features.empty) + + // c forwards an invalid amount to e through (the outer total amount doesn't match the inner amount). + val invalidTotalAmount = inner_c.amountToForward - 1.msat + val recipient_e = ClearRecipient(e, Features.empty, invalidTotalAmount, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) + val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) + val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + + val add_e = UpdateAddHtlc(randomBytes32(), 4, payload_d.amountToForward(add_d.amountMsat), paymentHash, payload_d.outgoingCltv(add_d.cltvExpiry), packet_e, None) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(invalidTotalAmount)) } test("fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) - val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) - // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) - // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). - val invalidExpiry = expiry_de - CltvExpiryDelta(12) - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, invalidExpiry, randomBytes32(), packet_e)) - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) + val add_c = createIntermediateTrampolinePayment() + val Right(NodeRelayPacket(_, _, inner_c, trampolinePacket_e)) = decrypt(add_c, priv_c.privateKey, Features.empty) + + // c forwards an invalid amount to e through (the outer expiry doesn't match the inner expiry). + val invalidExpiry = inner_c.outgoingCltv - CltvExpiryDelta(12) + val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, invalidExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) + val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + + val add_e = UpdateAddHtlc(randomBytes32(), 4, payload_d.amountToForward(add_d.amountMsat), paymentHash, payload_d.outgoingCltv(add_d.cltvExpiry), packet_e, None) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(invalidExpiry)) } test("fail to decrypt at intermediate trampoline node when amount is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) + val add_c = createIntermediateTrampolinePayment() // A trampoline relay is very similar to a final node: it can validate that the HTLC amount matches the onion outer amount. - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc - 100.msat, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) - assert(failure == FinalIncorrectHtlcAmount(amount_bc - 100.msat)) + val Left(failure) = decrypt(add_c.copy(amountMsat = add_c.amountMsat - 100.msat), priv_c.privateKey, Features.empty) + assert(failure == FinalIncorrectHtlcAmount(add_c.amountMsat - 100.msat)) } test("fail to decrypt at intermediate trampoline node when expiry is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) - // A trampoline relay is very similar to a final node: it can validate that the HTLC expiry matches the onion outer expiry. - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc - CltvExpiryDelta(12), packet_c, None), priv_c.privateKey, Features.empty) + val add_c = createIntermediateTrampolinePayment() + val invalidAdd = add_c.copy(cltvExpiry = add_c.cltvExpiry - CltvExpiryDelta(12)) + // A trampoline relay is very similar to a final node: it validates that the HTLC expiry matches the onion outer expiry. + val Left(failure) = decrypt(invalidAdd, priv_c.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(expiry_bc - CltvExpiryDelta(12))) } @@ -371,17 +399,6 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { object PaymentPacketSpec { - /** Build onion from arbitrary tlv stream (potentially invalid). */ - def buildTlvOnion(packetPayloadLength: Int, nodes: Seq[PublicKey], payloads: Seq[TlvStream[OnionPaymentPayloadTlv]], associatedData: ByteVector32): OnionRoutingPacket = { - require(nodes.size == payloads.size) - val sessionKey = randomKey() - val payloadsBin: Seq[ByteVector] = payloads.map(PaymentOnionCodecs.perHopPayloadCodec.encode).map { - case Attempt.Successful(bitVector) => bitVector.bytes - case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") - } - Sphinx.create(sessionKey, packetPayloadLength, nodes, payloadsBin, Some(associatedData)).get.packet - } - def makeCommitments(channelId: ByteVector32, testAvailableBalanceForSend: MilliSatoshi = 50000000 msat, testAvailableBalanceForReceive: MilliSatoshi = 50000000 msat, testCapacity: Satoshi = 100000 sat, channelFeatures: ChannelFeatures = ChannelFeatures()): Commitments = { val channelReserve = testCapacity * 0.01 val params = LocalParams(null, null, null, Long.MaxValue.msat, Some(channelReserve), null, null, 0, isInitiator = true, null, None, null) @@ -400,21 +417,21 @@ object PaymentPacketSpec { val (a, b, c, d, e) = (priv_a.publicKey, priv_b.publicKey, priv_c.publicKey, priv_d.publicKey, priv_e.publicKey) val sig = Crypto.sign(Crypto.sha256(ByteVector.empty), priv_a.privateKey) val defaultChannelUpdate = ChannelUpdate(sig, Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(0), 42000 msat, 0 msat, 0, 500_000_000 msat) - val channelUpdate_ab = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1), cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 642000 msat, feeProportionalMillionths = 7) - val channelUpdate_bc = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(2), cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 153000 msat, feeProportionalMillionths = 4) - val channelUpdate_cd = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(3), cltvExpiryDelta = CltvExpiryDelta(10), feeBaseMsat = 60000 msat, feeProportionalMillionths = 1) - val channelUpdate_de = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(4), cltvExpiryDelta = CltvExpiryDelta(7), feeBaseMsat = 766000 msat, feeProportionalMillionths = 10) + val channelUpdate_ab = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1), cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 642_000 msat, feeProportionalMillionths = 7) + val channelUpdate_bc = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(2), cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 153_000 msat, feeProportionalMillionths = 4) + val channelUpdate_cd = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(3), cltvExpiryDelta = CltvExpiryDelta(10), feeBaseMsat = 60_000 msat, feeProportionalMillionths = 1) + val channelUpdate_de = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(4), cltvExpiryDelta = CltvExpiryDelta(7), feeBaseMsat = 766_000 msat, feeProportionalMillionths = 10) // simple route a -> b -> c -> d -> e + val hops = Seq( + channelHopFromUpdate(a, b, channelUpdate_ab), + channelHopFromUpdate(b, c, channelUpdate_bc), + channelHopFromUpdate(c, d, channelUpdate_cd), + channelHopFromUpdate(d, e, channelUpdate_de), + ) - val hops = - channelHopFromUpdate(a, b, channelUpdate_ab) :: - channelHopFromUpdate(b, c, channelUpdate_bc) :: - channelHopFromUpdate(c, d, channelUpdate_cd) :: - channelHopFromUpdate(d, e, channelUpdate_de) :: Nil - - val finalAmount = 42000000 msat - val currentBlockCount = 400000 + val finalAmount = 42_000_000 msat + val currentBlockCount = 400_000 val finalExpiry = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) @@ -437,17 +454,20 @@ object PaymentPacketSpec { val amount_ab = amount_bc + fee_b // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e + // .----. + // / \ + // a -> b -> c e - val trampolineHops = - NodeHop(a, c, channelUpdate_ab.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta, fee_b) :: - NodeHop(c, d, channelUpdate_cd.cltvExpiryDelta, fee_c) :: - NodeHop(d, e, channelUpdate_de.cltvExpiryDelta, fee_d) :: Nil + val trampolineHop = NodeHop(c, e, channelUpdate_cd.cltvExpiryDelta + channelUpdate_de.cltvExpiryDelta, fee_c + fee_d) - val trampolineChannelHops = - channelHopFromUpdate(a, b, channelUpdate_ab) :: - channelHopFromUpdate(b, c, channelUpdate_bc) :: Nil + val trampolineChannelHops = Seq( + channelHopFromUpdate(a, b, channelUpdate_ab), + channelHopFromUpdate(b, c, channelUpdate_bc) + ) + + val afterTrampolineChannelHops = Seq( + channelHopFromUpdate(c, d, channelUpdate_cd), + channelHopFromUpdate(d, e, channelUpdate_de), + ) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 92b304967..9d7e66ab6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -27,10 +27,12 @@ import fr.acinq.eclair.channel.Helpers.Closing import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.states.ChannelStateTestsBase import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} -import fr.acinq.eclair.payment.OutgoingPaymentPacket.{Upstream, buildCommand} +import fr.acinq.eclair.payment.OutgoingPaymentPacket.{Upstream, buildOutgoingPayment} import fr.acinq.eclair.payment.PaymentPacketSpec._ import fr.acinq.eclair.payment.relay.{PostRestartHtlcCleaner, Relayer} +import fr.acinq.eclair.payment.send.SpontaneousRecipient import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate +import fr.acinq.eclair.router.Router.Route import fr.acinq.eclair.transactions.Transactions.{ClaimRemoteDelayedOutputTx, InputInfo} import fr.acinq.eclair.transactions.{DirectedHtlc, IncomingHtlc, OutgoingHtlc} import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec @@ -43,7 +45,6 @@ import scodec.bits.ByteVector import java.util.UUID import scala.concurrent.Promise import scala.concurrent.duration._ -import scala.util.Success /** * Created by t-bast on 21/11/2019. @@ -720,8 +721,8 @@ object PostRestartHtlcCleanerSpec { val (paymentHash1, paymentHash2, paymentHash3) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3)) def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): UpdateAddHtlc = { - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None)) - UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32())) + UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) } def buildHtlcIn(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = IncomingHtlc(buildHtlc(htlcId, channelId, paymentHash)) @@ -729,8 +730,8 @@ object PostRestartHtlcCleanerSpec { def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash)) def buildFinalHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = { - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, TestConstants.Bob.nodeParams.nodeId, channelUpdate_ab) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None)) - IncomingHtlc(UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32())) + IncomingHtlc(UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None)) } def buildForwardFail(add: UpdateAddHtlc, origin: Origin.Cold): RES_ADD_SETTLED[Origin.Cold, HtlcResult.Fail] = diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 78b52cf71..ad91b203a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -32,10 +32,11 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket -import fr.acinq.eclair.payment.Invoice.BasicEdge +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey +import fr.acinq.eclair.payment.send.ClearRecipient import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode @@ -220,7 +221,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // and then one extra val extra = IncomingPaymentPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None), - FinalPayload.Standard.createMultiPartPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), + FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) nodeRelayer ! NodeRelay.Relay(extra) @@ -249,7 +250,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive new extraneous multi-part HTLC. val i1 = IncomingPaymentPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None), - FinalPayload.Standard.createMultiPartPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), + FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) nodeRelayer ! NodeRelay.Relay(i1) @@ -262,7 +263,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive new HTLC with different details, but for the same payment hash. val i2 = IncomingPaymentPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(1500 msat, CltvExpiry(499990), incomingSecret, None), + PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) nodeRelayer ! NodeRelay.Relay(i2) @@ -726,18 +727,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] - assert(outgoingPayment.paymentSecret == invoice.paymentSecret) // we should use the provided secret - assert(outgoingPayment.paymentMetadata == invoice.paymentMetadata) // we should use the provided metadata - assert(outgoingPayment.totalAmount == outgoingAmount) - assert(outgoingPayment.targetExpiry == outgoingExpiry) - assert(outgoingPayment.targetNodeId == outgoingNodeId) - assert(outgoingPayment.additionalTlvs == Nil) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].shortChannelId == ShortChannelId(42)) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].sourceNodeId == hints.head.nodeId) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].targetNodeId == outgoingNodeId) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].feeBase == 10.msat) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].feeProportionalMillionths == 1) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].cltvExpiryDelta == CltvExpiryDelta(12)) + assert(outgoingPayment.recipient.nodeId == outgoingNodeId) + assert(outgoingPayment.recipient.totalAmount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.extraEdges.head == ExtraEdge(hints.head.nodeId, outgoingNodeId, ShortChannelId(42), 10 msat, 1, CltvExpiryDelta(12), 1 msat, None)) + assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) + val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] + assert(recipient.nextTrampolineOnion_opt.isEmpty) + assert(recipient.paymentSecret == invoice.paymentSecret) // we should use the provided secret + assert(recipient.paymentMetadata_opt == invoice.paymentMetadata) // we should use the provided metadata // those are adapters for pay-fsm messages val nodeRelayerAdapters = outgoingPayment.replyTo @@ -773,17 +771,14 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] - assert(outgoingPayment.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(outgoingPayment.finalPayload.amount == outgoingAmount) - assert(outgoingPayment.finalPayload.expiry == outgoingExpiry) - assert(outgoingPayment.finalPayload.asInstanceOf[FinalPayload.Standard].paymentMetadata == invoice.paymentMetadata) // we should use the provided metadata - assert(outgoingPayment.targetNodeId == outgoingNodeId) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].shortChannelId == ShortChannelId(42)) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].sourceNodeId == hints.head.nodeId) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].targetNodeId == outgoingNodeId) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].feeBase == 10.msat) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].feeProportionalMillionths == 1) - assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].cltvExpiryDelta == CltvExpiryDelta(12)) + assert(outgoingPayment.recipient.nodeId == outgoingNodeId) + assert(outgoingPayment.amount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.extraEdges.head == ExtraEdge(hints.head.nodeId, outgoingNodeId, ShortChannelId(42), 10 msat, 1, CltvExpiryDelta(12), 1 msat, None)) + assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) + val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] + assert(recipient.nextTrampolineOnion_opt.isEmpty) + assert(recipient.paymentMetadata_opt == invoice.paymentMetadata) // we should use the provided metadata // those are adapters for pay-fsm messages val nodeRelayerAdapters = outgoingPayment.replyTo @@ -829,18 +824,19 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl assert(!outgoingCfg.storeInDb) assert(outgoingCfg.paymentHash == paymentHash) assert(outgoingCfg.invoice.isEmpty) - assert(outgoingCfg.recipientAmount == outgoingAmount) assert(outgoingCfg.recipientNodeId == outgoingNodeId) assert(outgoingCfg.upstream == upstream) } def validateOutgoingPayment(outgoingPayment: SendMultiPartPayment): Unit = { - assert(outgoingPayment.paymentSecret !== incomingSecret) // we should generate a new outgoing secret - assert(outgoingPayment.totalAmount == outgoingAmount) - assert(outgoingPayment.targetExpiry == outgoingExpiry) - assert(outgoingPayment.targetNodeId == outgoingNodeId) - assert(outgoingPayment.additionalTlvs == Seq(OnionPaymentPayloadTlv.TrampolineOnion(nextTrampolinePacket))) - assert(outgoingPayment.extraEdges == Nil) + assert(outgoingPayment.recipient.nodeId == outgoingNodeId) + assert(outgoingPayment.recipient.totalAmount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.extraEdges == Nil) + assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) + val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] + assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret + assert(recipient.nextTrampolineOnion_opt.contains(nextTrampolinePacket)) } def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = { @@ -887,11 +883,7 @@ object NodeRelayerSpec { PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None))) def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry): IncomingPaymentPacket.NodeRelayPacket = { - val outerPayload = if (amountIn == totalAmountIn) { - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amountIn, expiryIn, incomingSecret, None) - } else { - FinalPayload.Standard.createMultiPartPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None) - } + val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None) IncomingPaymentPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None), outerPayload, @@ -904,7 +896,7 @@ object NodeRelayerSpec { val amountIn = incomingAmount / 2 IncomingPaymentPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None), - FinalPayload.Standard.createMultiPartPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None), + FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId), nextTrampolinePacket) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index d06d3f55f..f0f07e895 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -22,16 +22,19 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps} import com.typesafe.config.ConfigFactory -import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} +import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.payment.Bolt11Invoice import fr.acinq.eclair.payment.IncomingPaymentPacket.FinalPacket -import fr.acinq.eclair.payment.OutgoingPaymentPacket.{Upstream, buildCommand} +import fr.acinq.eclair.payment.OutgoingPaymentPacket.{NodePayload, Upstream, buildOnion, buildOutgoingPayment} import fr.acinq.eclair.payment.PaymentPacketSpec._ import fr.acinq.eclair.payment.relay.Relayer._ -import fr.acinq.eclair.payment.{OutgoingPaymentPacket, PaymentPacketSpec} +import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient} import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate -import fr.acinq.eclair.router.Router.NodeHop +import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{NodeParams, TestConstants, randomBytes32, _} @@ -41,7 +44,6 @@ import org.scalatest.{Outcome, Tag} import java.util.UUID import scala.concurrent.duration.DurationInt -import scala.util.Success class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { @@ -88,9 +90,9 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } // we use this to build a valid onion - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc - val add_ab = UpdateAddHtlc(channelId = randomBytes32(), id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) + val add_ab = UpdateAddHtlc(randomBytes32(), 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) register.expectMessageType[Register.Forward[CMD_ADD_HTLC]] } @@ -98,34 +100,29 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat test("relay an htlc-add at the final node to the payment handler") { f => import f._ - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) - + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) val fp = paymentHandler.expectMessageType[FinalPacket] assert(fp.add == add_ab) - assert(fp.payload == FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + assert(fp.payload == FinalPayload.Standard.createPayload(finalAmount, finalAmount, finalExpiry, paymentSecret)) register.expectNoMessage(50 millis) } test("relay a trampoline htlc-add at the final node to the payment handler") { f => - import PaymentPacketSpec._ import f._ - val a = PaymentPacketSpec.a // We simulate a payment split between multiple trampoline routes. val totalAmount = finalAmount * 3 - val trampolineHops = NodeHop(a, b, channelUpdate_ab.cltvExpiryDelta, 0 msat) :: Nil - val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, totalAmount, finalExpiry, paymentSecret, None)) - assert(trampolineAmount == finalAmount) - assert(trampolineExpiry == finalExpiry) - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet)) - assert(cmd.amount == finalAmount) - assert(cmd.cltvExpiry == finalExpiry) - val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) - + val finalTrampolinePayload = NodePayload(b, FinalPayload.Standard.createPayload(finalAmount, totalAmount, finalExpiry, paymentSecret)) + val Right(trampolineOnion) = buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, Seq(finalTrampolinePayload), paymentHash) + val recipient = ClearRecipient(b, nodeParams.features.invoiceFeatures(), finalAmount, finalExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolineOnion.packet)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), None), recipient) + assert(payment.cmd.amount == finalAmount) + assert(payment.cmd.cltvExpiry == finalExpiry) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) val fp = paymentHandler.expectMessageType[FinalPacket] @@ -143,10 +140,9 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat import f._ // we use this to build a valid onion - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc with an invalid onion (hmac) - val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion.copy(hmac = cmd.onion.hmac.reverse), None) - + val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(hmac = payment.cmd.onion.hmac.reverse), None) relayer ! RelayForward(add_ab) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message @@ -158,18 +154,17 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } test("fail to relay a trampoline htlc-add when trampoline is disabled", Tag("trampoline-disabled")) { f => - import PaymentPacketSpec._ import f._ - val a = PaymentPacketSpec.a // we use this to build a valid trampoline onion inside a normal onion - val trampolineHops = NodeHop(a, b, channelUpdate_ab.cltvExpiryDelta, 0 msat) :: NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) :: Nil - val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet)) + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_c.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) + val trampolineHop = NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) // and then manually build an htlc - val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) - + val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]].message diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala index 8c4e6371c..9784f9334 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Satoshi, SatoshiLong} import fr.acinq.eclair.payment.Invoice import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} -import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelRelayParams} +import fr.acinq.eclair.router.Router.{ChannelDesc, HopRelayParams} import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TimestampSecond, randomKey} import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper import org.scalatest.funsuite.AnyFunSuite @@ -38,7 +38,7 @@ class BalanceEstimateSpec extends AnyFunSuite { def makeEdge(nodeId1: PublicKey, nodeId2: PublicKey, channelId: Long, capacity: Satoshi): GraphEdge = GraphEdge( ChannelDesc(ShortChannelId(channelId), nodeId1, nodeId2), - ChannelRelayParams.FromHint(Invoice.BasicEdge(nodeId1, nodeId2, ShortChannelId(channelId), 0 msat, 0, CltvExpiryDelta(0))), + HopRelayParams.FromHint(Invoice.ExtraEdge(nodeId1, nodeId2, ShortChannelId(channelId), 0 msat, 0, CltvExpiryDelta(0), 0 msat, None)), capacity, None) def makeEdge(channelId: Long, capacity: Satoshi): GraphEdge = diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index cf305c000..3aef6d1cf 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -23,8 +23,9 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.Script.{pay2wsh, write} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, SatoshiLong, Transaction, TxOut} import fr.acinq.eclair.TestConstants.Alice +import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{UtxoStatus, ValidateRequest, ValidateResult, WatchExternalChannelSpent} -import fr.acinq.eclair.channel.{CommitmentsSpec, LocalChannelUpdate, RealScidStatus, ShortChannelIdAssigned, ShortIds} +import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -33,7 +34,6 @@ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair._ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.ByteVector @@ -88,7 +88,7 @@ abstract class BaseRouterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLi val alias_ga_private = ShortChannelId.generateLocalAlias() val scids_ab = ShortIds(RealScidStatus.Final(scid_ab), alias_ab, Some(alias_ba)) - val scids_ag_private = ShortIds(RealScidStatus.Final(scid_ag_private), alias_ag_private, Some(alias_ga_private)) + val scids_ag_private = ShortIds(RealScidStatus.Final(scid_ag_private), alias_ag_private, Some(alias_ga_private)) val chan_ab = channelAnnouncement(scid_ab, priv_a, priv_b, priv_funding_a, priv_funding_b) val chan_bc = channelAnnouncement(scid_bc, priv_b, priv_c, priv_funding_b, priv_funding_c) @@ -237,6 +237,7 @@ object BaseRouterSpec { makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, shortChannelId, node1_priv.publicKey, node2_priv.publicKey, funding1_priv.publicKey, funding2_priv.publicKey, node1_sig, node2_sig, funding1_sig, funding2_sig) } - def channelHopFromUpdate(nodeId: PublicKey, nextNodeId: PublicKey, channelUpdate: ChannelUpdate): ChannelHop = - ChannelHop(channelUpdate.shortChannelId, nodeId, nextNodeId, ChannelRelayParams.FromAnnouncement(channelUpdate)) + def channelHopFromUpdate(nodeId: PublicKey, nextNodeId: PublicKey, channelUpdate: ChannelUpdate): ChannelHop = { + ChannelHop(channelUpdate.shortChannelId, nodeId, nextNodeId, HopRelayParams.FromAnnouncement(channelUpdate)) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala index 9f4217615..f617a3e3a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.router import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort -import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.wire.protocol.{BlindedRouteData, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite @@ -44,8 +44,8 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { val pathId = randomBytes32() val (scid1, scid2) = (ShortChannelId(1), ShortChannelId(2)) val hops = Seq( - ChannelHop(scid1, a.publicKey, b.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid1, a.publicKey, b.publicKey, 10 msat, 300, cltvDelta = CltvExpiryDelta(200)))), - ChannelHop(scid2, b.publicKey, c.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))), + ChannelHop(scid1, a.publicKey, b.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid1, a.publicKey, b.publicKey, 10 msat, 300, cltvDelta = CltvExpiryDelta(200)))), + ChannelHop(scid2, b.publicKey, c.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))), ) val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) assert(route.route.introductionNodeId == a.publicKey) @@ -77,7 +77,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { val feeBase = rand.nextInt(10_000).msat val feeProp = rand.nextInt(5000) val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) - val params = ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid, nodeId, nodeId, feeBase, feeProp, cltvDelta = cltvExpiryDelta)) + val params = HopRelayParams.FromAnnouncement(makeUpdateShort(scid, nodeId, nodeId, feeBase, feeProp, cltvDelta = cltvExpiryDelta)) ChannelHop(scid, nodeId, nodeId, params) }) for (_ <- 0 to 100) { @@ -86,7 +86,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { assert(payInfo.cltvExpiryDelta == CltvExpiryDelta(hops.map(_.cltvExpiryDelta.toInt).sum)) // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). val aggregatedFee = payInfo.fee(amount) - val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) + val actualFee = Router.Route(amount, hops, None).channelFee(includeLocalChannelCost = true) assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 7ecf11c5a..50ce90911 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.router.RouteCalculation._ import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TimestampSecond, TimestampSecondLong, ToMilliSatoshiConversion, randomKey} +import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestConstants, TimestampSecond, TimestampSecondLong, ToMilliSatoshiConversion, randomKey} import org.scalatest.TryValues.convertTryToSuccessOrFailure import org.scalatest.funsuite.AnyFunSuite import org.scalatest.{ParallelTestExecution, Tag} @@ -428,14 +428,14 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val ued = ChannelUpdate(DUMMY_SIG, Block.RegtestGenesisBlock.hash, ShortChannelId(4L), 1 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isNode1 = false, isEnabled = false), CltvExpiryDelta(1), 49 msat, 2507 msat, 147, DEFAULT_CAPACITY.toMilliSatoshi) val edges = Seq( - GraphEdge(ChannelDesc(ShortChannelId(1L), a, b), ChannelRelayParams.FromAnnouncement(uab), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(1L), b, a), ChannelRelayParams.FromAnnouncement(uba), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(2L), b, c), ChannelRelayParams.FromAnnouncement(ubc), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(2L), c, b), ChannelRelayParams.FromAnnouncement(ucb), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(3L), c, d), ChannelRelayParams.FromAnnouncement(ucd), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(3L), d, c), ChannelRelayParams.FromAnnouncement(udc), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(4L), d, e), ChannelRelayParams.FromAnnouncement(ude), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(4L), e, d), ChannelRelayParams.FromAnnouncement(ued), DEFAULT_CAPACITY, None) + GraphEdge(ChannelDesc(ShortChannelId(1L), a, b), HopRelayParams.FromAnnouncement(uab), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(1L), b, a), HopRelayParams.FromAnnouncement(uba), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(2L), b, c), HopRelayParams.FromAnnouncement(ubc), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(2L), c, b), HopRelayParams.FromAnnouncement(ucb), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(3L), c, d), HopRelayParams.FromAnnouncement(ucd), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(3L), d, c), HopRelayParams.FromAnnouncement(udc), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(4L), d, e), HopRelayParams.FromAnnouncement(ude), DEFAULT_CAPACITY, None), + GraphEdge(ChannelDesc(ShortChannelId(4L), e, d), HopRelayParams.FromAnnouncement(ued), DEFAULT_CAPACITY, None) ) val g = DirectedGraph(edges) @@ -544,7 +544,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { ) val publicChannels = channels.map { case (shortChannelId, announcement) => - val ChannelRelayParams.FromAnnouncement(update) = edges.find(_.desc.shortChannelId == shortChannelId).get.params + val HopRelayParams.FromAnnouncement(update) = edges.find(_.desc.shortChannelId == shortChannelId).get.params val (update_1_opt, update_2_opt) = if (update.channelFlags.isNode1) (Some(update), None) else (None, Some(update)) val pc = PublicChannel(announcement, ByteVector32.Zeroes, Satoshi(1000), update_1_opt, update_2_opt, None) (shortChannelId, pc) @@ -928,7 +928,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 351000 msat val Success(route :: Nil) = findRoute(g, thisNode, targetNode, amount, DEFAULT_MAX_FEE, 1, Set.empty, Set.empty, Set.empty, params, currentBlockHeight = BlockHeight(567634)) // simulate mainnet block for heuristic - assert(route.length == 2) + assert(route.hops.length == 2) assert(route.hops.last.nextNodeId == targetNode) } @@ -966,13 +966,13 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { { val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = routeParams, currentBlockHeight = BlockHeight(400000)) assert(routes.length == 4, routes) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkRouteAmounts(routes, amount, 0 msat) } { val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = routeParams.copy(randomize = true), currentBlockHeight = BlockHeight(400000)) assert(routes.length >= 4, routes) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkRouteAmounts(routes, amount, 0 msat) } { @@ -1009,7 +1009,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 65000 msat val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(routes.length == 4, routes) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkRouteAmounts(routes, amount, 0 msat) } @@ -1028,14 +1028,14 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { { val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(routes.length == 3, routes) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkIgnoredChannels(routes, 2L) checkRouteAmounts(routes, amount, 0 msat) } { val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = DEFAULT_ROUTE_PARAMS.copy(randomize = true), currentBlockHeight = BlockHeight(400000)) assert(routes.length >= 3, routes) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkIgnoredChannels(routes, 2L) checkRouteAmounts(routes, amount, 0 msat) } @@ -1054,7 +1054,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 20000 msat val ignoredEdges = Set(ChannelDesc(ShortChannelId(2L), a, b), ChannelDesc(ShortChannelId(3L), a, b)) val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, ignoredEdges = ignoredEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkIgnoredChannels(routes, 2L, 3L) checkRouteAmounts(routes, amount, 0 msat) } @@ -1073,9 +1073,9 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 50000 msat // These pending HTLCs will have already been taken into account in the edge's `balance_opt` field: findMultiPartRoute // should ignore this information. - val pendingHtlcs = Seq(Route(10000 msat, graphEdgeToHop(edge_ab_1) :: Nil), Route(5000 msat, graphEdgeToHop(edge_ab_2) :: Nil)) + val pendingHtlcs = Seq(Route(10000 msat, graphEdgeToHop(edge_ab_1) :: Nil, None), Route(5000 msat, graphEdgeToHop(edge_ab_2) :: Nil, None)) val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, pendingHtlcs = pendingHtlcs, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) checkRouteAmounts(routes, amount, 0 msat) } @@ -1089,7 +1089,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 50000 msat val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) assert(routes.length >= 10, routes) assert(routes.forall(_.amount <= 5000.msat), routes) checkRouteAmounts(routes, amount, 0 msat) @@ -1106,7 +1106,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val amount = 30000 msat val routeParams = DEFAULT_ROUTE_PARAMS.copy(mpp = MultiPartParams(2500 msat, 5)) val Success(routes) = findMultiPartRoute(g, a, b, amount, 1 msat, routeParams = routeParams, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length == 1), routes) + assert(routes.forall(_.hops.length == 1), routes) assert(routes.length == 3, routes) checkRouteAmounts(routes, amount, 0 msat) } @@ -1125,7 +1125,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { assert(maxFeeTooLow == Failure(RouteNotFound)) val Success(routes) = findMultiPartRoute(g, a, b, amount, 20 msat, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length <= 2), routes) + assert(routes.forall(_.hops.length <= 2), routes) assert(routes.length == 3, routes) checkRouteAmounts(routes, amount, 20 msat) } @@ -1375,7 +1375,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val Success(routes) = findMultiPartRoute(g, a, d, amount, maxFee, routeParams = routeParams, currentBlockHeight = BlockHeight(400000)) assert(routes.length == 5) routes.foreach(route => { - assert(route.length == 2) + assert(route.hops.length == 2) assert(route.amount <= 1_200_000.msat) assert(!route.hops.flatMap(h => Seq(h.nodeId, h.nextNodeId)).contains(c)) }) @@ -1545,9 +1545,9 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { makeEdge(6L, d, e, 50 msat, 0, minHtlc = 100 msat, capacity = 25 sat), )) - val pendingHtlcs = Seq(Route(5000 msat, graphEdgeToHop(edge_ab) :: graphEdgeToHop(edge_be) :: Nil)) + val pendingHtlcs = Seq(Route(5000 msat, graphEdgeToHop(edge_ab) :: graphEdgeToHop(edge_be) :: Nil, None)) val Success(routes) = findMultiPartRoute(g, a, e, amount, maxFee, pendingHtlcs = pendingHtlcs, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) - assert(routes.forall(_.length == 2), routes) + assert(routes.forall(_.hops.length == 2), routes) checkRouteAmounts(routes, amount, maxFee) checkIgnoredChannels(routes, 1L, 2L) } @@ -1704,7 +1704,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val Success(routes) = findRoute(g, a, b, DEFAULT_AMOUNT_MSAT, 100000000 msat, numRoutes = 10, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(routes.distinct.length == 10) - val fees = routes.map(_.fee(false)) + val fees = routes.map(_.channelFee(false)) assert(fees.forall(_ == fees.head)) } @@ -1909,9 +1909,10 @@ object RouteCalculationSpec { val noopBoundaries = { _: RichWeight => true } - val DEFAULT_AMOUNT_MSAT = 10000000 msat - val DEFAULT_MAX_FEE = 100000 msat - val DEFAULT_CAPACITY = 100000 sat + val DEFAULT_AMOUNT_MSAT = 10_000_000 msat + val DEFAULT_MAX_FEE = 100_000 msat + val DEFAULT_EXPIRY = CltvExpiry(TestConstants.defaultBlockHeight) + val DEFAULT_CAPACITY = 100_000 sat val NO_WEIGHT_RATIOS: WeightRatios = WeightRatios(1, 0, 0, 0, RelayFees(0 msat, 0)) val DEFAULT_ROUTE_PARAMS = PathFindingConf( @@ -1940,7 +1941,7 @@ object RouteCalculationSpec { capacity: Satoshi = DEFAULT_CAPACITY, balance_opt: Option[MilliSatoshi] = None): GraphEdge = { val update = makeUpdateShort(ShortChannelId(shortChannelId), nodeId1, nodeId2, feeBase, feeProportionalMillionth, minHtlc, maxHtlc, cltvDelta) - GraphEdge(ChannelDesc(RealShortChannelId(shortChannelId), nodeId1, nodeId2), ChannelRelayParams.FromAnnouncement(update), capacity, balance_opt) + GraphEdge(ChannelDesc(RealShortChannelId(shortChannelId), nodeId1, nodeId2), HopRelayParams.FromAnnouncement(update), capacity, balance_opt) } def makeUpdateShort(shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionth: Int, minHtlc: MilliSatoshi = DEFAULT_AMOUNT_MSAT, maxHtlc: Option[MilliSatoshi] = None, cltvDelta: CltvExpiryDelta = CltvExpiryDelta(0), timestamp: TimestampSecond = 0 unixsec): ChannelUpdate = @@ -1968,6 +1969,8 @@ object RouteCalculationSpec { def route2Nodes(route: Route): Seq[(PublicKey, PublicKey)] = route.hops.map(hop => (hop.nodeId, hop.nextNodeId)) + def route2NodeIds(route: Route): Seq[PublicKey] = route.hops.head.nodeId +: route.hops.map(_.nextNodeId) + def checkIgnoredChannels(routes: Seq[Route], shortChannelIds: Long*): Unit = { shortChannelIds.foreach(shortChannelId => routes.foreach(route => { assert(route.hops.forall(_.shortChannelId.toLong != shortChannelId), route) @@ -1976,7 +1979,7 @@ object RouteCalculationSpec { def checkRouteAmounts(routes: Seq[Route], totalAmount: MilliSatoshi, maxFee: MilliSatoshi): Unit = { assert(routes.map(_.amount).sum == totalAmount, routes) - assert(routes.map(_.fee(false)).sum <= maxFee, routes) + assert(routes.map(_.channelFee(false)).sum <= maxFee, routes) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index b3a4a8efa..cefde9282 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -21,21 +21,22 @@ import akka.actor.Status.Failure import akka.testkit.TestProbe import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Script.{pay2wsh, write} -import fr.acinq.bitcoin.scalacompat.{Block, SatoshiLong, Transaction, TxOut} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, SatoshiLong, Transaction, TxOut} import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher._ import fr.acinq.eclair.channel.{AvailableBalanceChanged, CommitmentsSpec, LocalChannelUpdate} import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop +import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, SpontaneousRecipient} import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.router.Graph.RoutingHeuristics -import fr.acinq.eclair.router.RouteCalculationSpec.{DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, DEFAULT_ROUTE_PARAMS} +import fr.acinq.eclair.router.RouteCalculationSpec.{DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, DEFAULT_ROUTE_PARAMS, route2NodeIds} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestConstants, TimestampSecond, randomKey} +import fr.acinq.eclair.{BlockHeight, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestConstants, TimestampSecond, randomBytes32, randomKey} import scodec.bits._ import scala.concurrent.duration._ @@ -379,7 +380,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(a, f, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(f, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) } @@ -387,7 +388,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(randomKey().publicKey, f, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(randomKey().publicKey, SpontaneousRecipient(f, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) } @@ -395,22 +396,22 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(a, randomKey().publicKey, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(randomKey().publicKey, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) } test("route found") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) val res = sender.expectMsgType[RouteResponse] - assert(res.routes.head.hops.map(_.nodeId).toList == a :: b :: c :: Nil) - assert(res.routes.head.hops.last.nextNodeId == d) + assert(route2NodeIds(res.routes.head) == Seq(a, b, c, d)) + assert(res.routes.head.finalHop_opt.isEmpty) - sender.send(router, RouteRequest(a, h, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(h, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) val res1 = sender.expectMsgType[RouteResponse] - assert(res1.routes.head.hops.map(_.nodeId).toList == a :: g :: Nil) - assert(res1.routes.head.hops.last.nextNodeId == h) + assert(route2NodeIds(res1.routes.head) == Seq(a, g, h)) + assert(res1.routes.head.finalHop_opt.isEmpty) } test("route found (with extra routing info)") { fixture => @@ -422,17 +423,67 @@ class RouterSpec extends BaseRouterSpec { val extraHop_cx = ExtraHop(c, ShortChannelId(1), 10 msat, 11, CltvExpiryDelta(12)) val extraHop_xy = ExtraHop(x, ShortChannelId(2), 10 msat, 11, CltvExpiryDelta(12)) val extraHop_yz = ExtraHop(y, ShortChannelId(3), 20 msat, 21, CltvExpiryDelta(22)) - sender.send(router, RouteRequest(a, z, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, extraEdges = Bolt11Invoice.toExtraEdges(extraHop_cx :: extraHop_xy :: extraHop_yz :: Nil, z), routeParams = DEFAULT_ROUTE_PARAMS)) + val recipient = ClearRecipient(z, Features.empty, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One, Bolt11Invoice.toExtraEdges(extraHop_cx :: extraHop_xy :: extraHop_yz :: Nil, z)) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS)) val res = sender.expectMsgType[RouteResponse] - assert(res.routes.head.hops.map(_.nodeId).toList == a :: b :: c :: x :: y :: Nil) - assert(res.routes.head.hops.last.nextNodeId == z) + assert(route2NodeIds(res.routes.head) == Seq(a, b, c, x, y, z)) + assert(res.routes.head.finalHop_opt.isEmpty) + } + + test("routes found (with pending payments)") { fixture => + import fixture._ + val sender = TestProbe() + val routeParams = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(15 msat, 0.0, 6, CltvExpiryDelta(1008))) + val recipient = ClearRecipient(c, Features.empty, 500_000 msat, DEFAULT_EXPIRY, randomBytes32()) + sender.send(router, RouteRequest(a, recipient, routeParams)) + val route1 = sender.expectMsgType[RouteResponse].routes.head + assert(route1.amount == 500_000.msat) + assert(route2NodeIds(route1) == Seq(a, b, c)) + assert(route1.channelFee(false) == 10.msat) + // We can't find another route to complete the payment amount because it exceeds the fee budget. + sender.send(router, RouteRequest(a, recipient, routeParams, pendingPayments = Seq(route1.copy(amount = 200_000 msat)))) + sender.expectMsg(Failure(RouteNotFound)) + // But if we increase the fee budget, we're able to find a second route. + sender.send(router, RouteRequest(a, recipient, routeParams.copy(boundaries = routeParams.boundaries.copy(maxFeeFlat = 20 msat)), pendingPayments = Seq(route1.copy(amount = 200_000 msat)))) + val route2 = sender.expectMsgType[RouteResponse].routes.head + assert(route2.amount == 300_000.msat) + assert(route2NodeIds(route2) == Seq(a, b, c)) + assert(route2.channelFee(false) == 10.msat) + } + + test("routes found (with trampoline hop)") { fixture => + import fixture._ + val sender = TestProbe() + val routeParams = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(25_015 msat, 0.0, 6, CltvExpiryDelta(1008))) + val recipientKey = randomKey() + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, randomBytes32(), recipientKey, Left("invoice"), CltvExpiryDelta(6)) + val trampolineHop = NodeHop(c, recipientKey.publicKey, CltvExpiryDelta(100), 25_000 msat) + val recipient = ClearTrampolineRecipient(invoice, 725_000 msat, DEFAULT_EXPIRY, trampolineHop, randomBytes32()) + sender.send(router, RouteRequest(a, recipient, routeParams)) + val route1 = sender.expectMsgType[RouteResponse].routes.head + assert(route1.amount == 750_000.msat) + assert(route2NodeIds(route1) == Seq(a, b, c)) + assert(route1.channelFee(false) == 10.msat) + assert(route1.trampolineFee == 25_000.msat) + assert(route1.finalHop_opt.contains(trampolineHop)) + // We can't find another route to complete the payment amount because it exceeds the fee budget. + sender.send(router, RouteRequest(a, recipient, routeParams, pendingPayments = Seq(route1.copy(500_000 msat)))) + sender.expectMsg(Failure(RouteNotFound)) + // But if we increase the fee budget, we're able to find a second route. + sender.send(router, RouteRequest(a, recipient, routeParams.copy(boundaries = routeParams.boundaries.copy(maxFeeFlat = 25_020 msat)), pendingPayments = Seq(route1.copy(500_000 msat)))) + val route2 = sender.expectMsgType[RouteResponse].routes.head + assert(route2.amount == 250_000.msat) + assert(route2NodeIds(route2) == Seq(a, b, c)) + assert(route2.channelFee(false) == 10.msat) + assert(route2.trampolineFee == 25_000.msat) + assert(route2.finalHop_opt.contains(trampolineHop)) } test("route not found (channel disabled)") { fixture => import fixture._ val sender = TestProbe() val peerConnection = TestProbe() - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) val res = sender.expectMsgType[RouteResponse] assert(res.routes.head.hops.map(_.nodeId).toList == a :: b :: c :: Nil) assert(res.routes.head.hops.last.nextNodeId == d) @@ -440,21 +491,21 @@ class RouterSpec extends BaseRouterSpec { val channelUpdate_cd1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, scid_cd, CltvExpiryDelta(3), 0 msat, 153000 msat, 4, htlcMaximum, enable = false) peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_cd1)) peerConnection.expectMsg(TransportHandler.ReadAck(channelUpdate_cd1)) - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) } test("route not found (private channel disabled)") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, h, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(h, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) val res = sender.expectMsgType[RouteResponse] assert(res.routes.head.hops.map(_.nodeId).toList == a :: g :: Nil) assert(res.routes.head.hops.last.nextNodeId == h) val channelUpdate_ag1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, g, alias_ga_private, CltvExpiryDelta(7), 0 msat, 10 msat, 10, htlcMaximum, enable = false) sender.send(router, LocalChannelUpdate(sender.ref, channelId_ag_private, scids_ag_private, g, None, channelUpdate_ag1, CommitmentsSpec.makeCommitments(10000 msat, 15000 msat, a, g, announceChannel = false))) - sender.send(router, RouteRequest(a, h, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(h, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) } @@ -463,57 +514,57 @@ class RouterSpec extends BaseRouterSpec { val sender = TestProbe() // Via private channels. - sender.send(router, RouteRequest(a, g, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(g, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] - sender.send(router, RouteRequest(a, g, 50000000 msat, Long.MaxValue.msat, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(g, 50000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(BalanceTooLow)) - sender.send(router, RouteRequest(a, g, 50000000 msat, Long.MaxValue.msat, allowMultiPart = true, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(g, 50000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) sender.expectMsg(Failure(BalanceTooLow)) // Via public channels. - sender.send(router, RouteRequest(a, b, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(b, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] val commitments1 = CommitmentsSpec.makeCommitments(10000000 msat, 20000000 msat, a, b, announceChannel = true) sender.send(router, LocalChannelUpdate(sender.ref, null, scids_ab, b, Some(chan_ab), update_ab, commitments1)) - sender.send(router, RouteRequest(a, b, 12000000 msat, Long.MaxValue.msat, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(b, 12000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(BalanceTooLow)) - sender.send(router, RouteRequest(a, b, 12000000 msat, Long.MaxValue.msat, allowMultiPart = true, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(b, 12000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) sender.expectMsg(Failure(BalanceTooLow)) - sender.send(router, RouteRequest(a, b, 5000000 msat, Long.MaxValue.msat, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(b, 5000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] - sender.send(router, RouteRequest(a, b, 5000000 msat, Long.MaxValue.msat, allowMultiPart = true, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(b, 5000000 msat, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) sender.expectMsgType[RouteResponse] } test("temporary channel exclusion") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] val bc = ChannelDesc(scid_bc, b, c) // let's exclude channel b->c sender.send(router, ExcludeChannel(bc, Some(1 hour))) - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) // note that cb is still available! - sender.send(router, RouteRequest(d, a, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(d, SpontaneousRecipient(a, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] // let's remove the exclusion sender.send(router, LiftChannelExclusion(bc)) - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] } test("concurrent channel exclusions") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] val bc = ChannelDesc(scid_bc, b, c) sender.send(router, ExcludeChannel(bc, Some(1 second))) sender.send(router, ExcludeChannel(bc, Some(10 minute))) sender.send(router, ExcludeChannel(bc, Some(1 second))) - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) sender.send(router, GetExcludedChannels) val excludedChannels1 = sender.expectMsgType[Map[ChannelDesc, ExcludedChannelStatus]] @@ -521,7 +572,7 @@ class RouterSpec extends BaseRouterSpec { assert(excludedChannels1(bc).isInstanceOf[ExcludedUntil]) assert(excludedChannels1(bc).asInstanceOf[ExcludedUntil].liftExclusionAt > TimestampSecond.now() + 9.minute) sender.send(router, LiftChannelExclusion(bc)) - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) + sender.send(router, RouteRequest(a, SpontaneousRecipient(d, DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, ByteVector32.One), DEFAULT_ROUTE_PARAMS)) sender.expectMsgType[RouteResponse] sender.send(router, ExcludeChannel(bc, None)) sender.send(router, GetExcludedChannels) @@ -547,25 +598,25 @@ class RouterSpec extends BaseRouterSpec { val sender = TestProbe() { - val preComputedRoute = PredefinedNodeRoute(Seq(a, b, c, d)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedNodeRoute(10000 msat, Seq(a, b, c, d)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.head.hops.map(_.nodeId) == Seq(a, b, c)) assert(response.routes.head.hops.map(_.nextNodeId) == Seq(b, c, d)) assert(response.routes.head.hops.map(_.shortChannelId) == Seq(scid_ab, scid_bc, scid_cd)) - assert(response.routes.head.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) + assert(response.routes.head.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ab), HopRelayParams.FromAnnouncement(update_bc), HopRelayParams.FromAnnouncement(update_cd))) } { - val preComputedRoute = PredefinedNodeRoute(Seq(a, g, h)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedNodeRoute(10000 msat, Seq(a, g, h)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.head.hops.map(_.nodeId) == Seq(a, g)) assert(response.routes.head.hops.map(_.nextNodeId) == Seq(g, h)) assert(response.routes.head.hops.map(_.shortChannelId) == Seq(alias_ag_private, scid_gh)) } { - val preComputedRoute = PredefinedNodeRoute(Seq(a, g, a)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedNodeRoute(10000 msat, Seq(a, g, a)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.head.hops.map(_.nodeId) == Seq(a, g)) assert(response.routes.head.hops.map(_.nextNodeId) == Seq(g, a)) @@ -577,15 +628,15 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() - val preComputedRoute = PredefinedChannelRoute(d, Seq(scid_ab, scid_bc, scid_cd)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, scid_bc, scid_cd)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] // the route hasn't changed (nodes are the same) assert(response.routes.head.hops.map(_.nodeId) == Seq(a, b, c)) assert(response.routes.head.hops.map(_.nextNodeId) == Seq(b, c, d)) assert(response.routes.head.hops.map(_.shortChannelId) == Seq(scid_ab, scid_bc, scid_cd)) - assert(response.routes.head.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) + assert(response.routes.head.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ab), HopRelayParams.FromAnnouncement(update_bc), HopRelayParams.FromAnnouncement(update_cd))) } test("given a pre-defined private channels route add the proper channel updates") { fixture => @@ -594,50 +645,50 @@ class RouterSpec extends BaseRouterSpec { { // using the channel alias - val preComputedRoute = PredefinedChannelRoute(g, Seq(alias_ag_private)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, g, Seq(alias_ag_private)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) + assert(route.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ag_private))) assert(route.hops.head.nodeId == a) assert(route.hops.head.nextNodeId == g) assert(route.hops.head.shortChannelId == alias_ag_private) } { // using the channel alias routing to ourselves: a -> g -> a - val preComputedRoute = PredefinedChannelRoute(a, Seq(alias_ag_private, alias_ag_private)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, a, Seq(alias_ag_private, alias_ag_private)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_ga_private))) + assert(route.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ag_private), HopRelayParams.FromAnnouncement(update_ga_private))) assert(route.hops.map(_.nodeId) == Seq(a, g)) assert(route.hops.map(_.nextNodeId) == Seq(g, a)) assert(route.hops.map(_.shortChannelId) == Seq(alias_ag_private, alias_ga_private)) } { // using the real scid - val preComputedRoute = PredefinedChannelRoute(g, Seq(scid_ag_private)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, g, Seq(scid_ag_private)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) + assert(route.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ag_private))) assert(route.hops.head.nodeId == a) assert(route.hops.head.nextNodeId == g) assert(route.hops.head.shortChannelId == alias_ag_private) } { - val preComputedRoute = PredefinedChannelRoute(h, Seq(scid_ag_private, scid_gh)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, h, Seq(scid_ag_private, scid_gh)) + sender.send(router, FinalizeRoute(preComputedRoute)) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, g)) assert(route.hops.map(_.nextNodeId) == Seq(g, h)) assert(route.hops.map(_.shortChannelId) == Seq(alias_ag_private, scid_gh)) - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) + assert(route.hops.map(_.params) == Seq(HopRelayParams.FromAnnouncement(update_ag_private), HopRelayParams.FromAnnouncement(update_gh))) } } @@ -647,36 +698,36 @@ class RouterSpec extends BaseRouterSpec { val targetNodeId = randomKey().publicKey { - val invoiceRoutingHint = Invoice.BasicEdge(b, targetNodeId, RealShortChannelId(BlockHeight(420000), 516, 1105), 10 msat, 150, CltvExpiryDelta(96)) - val preComputedRoute = PredefinedChannelRoute(targetNodeId, Seq(scid_ab, invoiceRoutingHint.shortChannelId)) val amount = 10_000.msat + val invoiceRoutingHint = Invoice.ExtraEdge(b, targetNodeId, RealShortChannelId(BlockHeight(420000), 516, 1105), 10 msat, 150, CltvExpiryDelta(96), 1 msat, None) + val preComputedRoute = PredefinedChannelRoute(amount, targetNodeId, Seq(scid_ab, invoiceRoutingHint.shortChannelId)) // the amount affects the way we estimate the channel capacity of the hinted channel assert(amount < RoutingHeuristics.CAPACITY_CHANNEL_LOW) - sender.send(router, FinalizeRoute(amount, preComputedRoute, extraEdges = Seq(invoiceRoutingHint))) + sender.send(router, FinalizeRoute(preComputedRoute, extraEdges = Seq(invoiceRoutingHint))) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, b)) assert(route.hops.map(_.nextNodeId) == Seq(b, targetNodeId)) assert(route.hops.map(_.shortChannelId) == Seq(scid_ab, invoiceRoutingHint.shortChannelId)) - assert(route.hops.head.params == ChannelRelayParams.FromAnnouncement(update_ab)) - assert(route.hops.last.params == ChannelRelayParams.FromHint(invoiceRoutingHint)) + assert(route.hops.head.params == HopRelayParams.FromAnnouncement(update_ab)) + assert(route.hops.last.params == HopRelayParams.FromHint(invoiceRoutingHint)) } { - val invoiceRoutingHint = Invoice.BasicEdge(h, targetNodeId, RealShortChannelId(BlockHeight(420000), 516, 1105), 10 msat, 150, CltvExpiryDelta(96)) - val preComputedRoute = PredefinedChannelRoute(targetNodeId, Seq(scid_ag_private, scid_gh, invoiceRoutingHint.shortChannelId)) val amount = RoutingHeuristics.CAPACITY_CHANNEL_LOW * 2 + val invoiceRoutingHint = Invoice.ExtraEdge(h, targetNodeId, RealShortChannelId(BlockHeight(420000), 516, 1105), 10 msat, 150, CltvExpiryDelta(96), 1 msat, None) + val preComputedRoute = PredefinedChannelRoute(amount, targetNodeId, Seq(scid_ag_private, scid_gh, invoiceRoutingHint.shortChannelId)) // the amount affects the way we estimate the channel capacity of the hinted channel assert(amount > RoutingHeuristics.CAPACITY_CHANNEL_LOW) - sender.send(router, FinalizeRoute(amount, preComputedRoute, extraEdges = Seq(invoiceRoutingHint))) + sender.send(router, FinalizeRoute(preComputedRoute, extraEdges = Seq(invoiceRoutingHint))) val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, g, h)) assert(route.hops.map(_.nextNodeId) == Seq(g, h, targetNodeId)) assert(route.hops.map(_.shortChannelId) == Seq(alias_ag_private, scid_gh, invoiceRoutingHint.shortChannelId)) - assert(route.hops.map(_.params).dropRight(1) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) - assert(route.hops.last.params == ChannelRelayParams.FromHint(invoiceRoutingHint)) + assert(route.hops.map(_.params).dropRight(1) == Seq(HopRelayParams.FromAnnouncement(update_ag_private), HopRelayParams.FromAnnouncement(update_gh))) + assert(route.hops.last.params == HopRelayParams.FromHint(invoiceRoutingHint)) } } @@ -685,23 +736,23 @@ class RouterSpec extends BaseRouterSpec { val sender = TestProbe() { - val preComputedRoute = PredefinedChannelRoute(d, Seq(scid_ab, scid_cd)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, scid_cd)) + sender.send(router, FinalizeRoute(preComputedRoute)) sender.expectMsgType[Status.Failure] } { - val preComputedRoute = PredefinedChannelRoute(d, Seq(scid_ab, scid_bc)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, scid_bc)) + sender.send(router, FinalizeRoute(preComputedRoute)) sender.expectMsgType[Status.Failure] } { - val preComputedRoute = PredefinedChannelRoute(d, Seq(scid_bc, scid_cd)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_bc, scid_cd)) + sender.send(router, FinalizeRoute(preComputedRoute)) sender.expectMsgType[Status.Failure] } { - val preComputedRoute = PredefinedChannelRoute(d, Seq(scid_ab, ShortChannelId(1105), scid_cd)) - sender.send(router, FinalizeRoute(10000 msat, preComputedRoute)) + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, ShortChannelId(1105), scid_cd)) + sender.send(router, FinalizeRoute(preComputedRoute)) sender.expectMsgType[Status.Failure] } } diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala index f6b68fbde..1229dba81 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala @@ -17,11 +17,10 @@ package fr.acinq.eclair.api.handlers import akka.http.scaladsl.server.{MalformedFormFieldRejection, Route} -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Satoshi} import fr.acinq.eclair.api.Service import fr.acinq.eclair.api.directives.EclairDirectives -import fr.acinq.eclair.api.serde.FormParamExtractors.{pubkeyListUnmarshaller, _} +import fr.acinq.eclair.api.serde.FormParamExtractors._ import fr.acinq.eclair.payment.Bolt11Invoice import fr.acinq.eclair.router.Router.{PredefinedChannelRoute, PredefinedNodeRoute} import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, randomBytes32} @@ -39,7 +38,7 @@ trait Payment { val payInvoice: Route = postRequest("payinvoice") { implicit t => formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "maxFeeFlatSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?, "blocking".as[Boolean].?, "pathFindingExperimentName".?) { - case (invoice@Bolt11Invoice(_, Some(amount), _, nodeId, _, _), None, maxAttempts, maxFeeFlat_opt, maxFeePct_opt, externalId_opt, blocking_opt, pathFindingExperimentName_opt) => + case (invoice@Bolt11Invoice(_, Some(amount), _, _, _, _), None, maxAttempts, maxFeeFlat_opt, maxFeePct_opt, externalId_opt, blocking_opt, pathFindingExperimentName_opt) => blocking_opt match { case Some(true) => complete(eclairApi.sendBlocking(externalId_opt, amount, invoice, maxAttempts, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt)) case _ => complete(eclairApi.send(externalId_opt, amount, invoice, maxAttempts, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt)) @@ -56,16 +55,15 @@ trait Payment { val sendToRoute: Route = postRequest("sendtoroute") { implicit t => withRoute { hops => formFields(amountMsatFormParam, "recipientAmountMsat".as[MilliSatoshi].?, invoiceFormParam, "externalId".?, "parentId".as[UUID].?, - "trampolineSecret".as[ByteVector32].?, "trampolineFeesMsat".as[MilliSatoshi].?, "trampolineCltvExpiry".as[Int].?, "trampolineNodes".as[List[PublicKey]](pubkeyListUnmarshaller).?) { - (amountMsat, recipientAmountMsat_opt, invoice, externalId_opt, parentId_opt, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt, trampolineNodes_opt) => { + "trampolineSecret".as[ByteVector32].?, "trampolineFeesMsat".as[MilliSatoshi].?, "trampolineCltvExpiry".as[Int].?) { + (amountMsat, recipientAmountMsat_opt, invoice, externalId_opt, parentId_opt, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt) => { val route = hops match { - case Left(shortChannelIds) => PredefinedChannelRoute(invoice.nodeId, shortChannelIds) - case Right(nodeIds) => PredefinedNodeRoute(nodeIds) + case Left(shortChannelIds) => PredefinedChannelRoute(amountMsat, invoice.nodeId, shortChannelIds) + case Right(nodeIds) => PredefinedNodeRoute(amountMsat, nodeIds) } complete(eclairApi.sendToRoute( - amountMsat, recipientAmountMsat_opt, externalId_opt, parentId_opt, invoice, route, trampolineSecret_opt, trampolineFeesMsat_opt, - trampolineCltvExpiry_opt.map(CltvExpiryDelta), trampolineNodes_opt.getOrElse(Nil) - )) + recipientAmountMsat_opt, externalId_opt, parentId_opt, invoice, route, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt.map(CltvExpiryDelta)) + ) } } } diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 8b5cbc233..280954583 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -46,7 +46,7 @@ import fr.acinq.eclair.payment.relay.Relayer.ChannelBalance import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.PreimageReceived import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentToRouteResponse import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.{ChannelRelayParams, PredefinedNodeRoute} +import fr.acinq.eclair.router.Router.{HopRelayParams, PredefinedNodeRoute} import fr.acinq.eclair.wire.protocol._ import org.json4s.{Formats, Serialization} import org.mockito.scalatest.IdiomaticMockito @@ -952,11 +952,11 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val expected = """{"paymentId":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","parentId":"2ad8c6d7-99cb-4238-8f67-89024b8eed0d"}""" val externalId = UUID.randomUUID().toString val pr = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(1234 msat), ByteVector32.Zeroes, randomKey(), Left("Some invoice"), CltvExpiryDelta(24)) - val expectedRoute = PredefinedNodeRoute(Seq(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28"))) + val expectedRoute = PredefinedNodeRoute(1234 msat, Seq(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28"))) val jsonNodes = serialization.write(expectedRoute.nodes) val eclair = mock[Eclair] - eclair.sendToRoute(any[MilliSatoshi], any[Option[MilliSatoshi]], any[Option[String]], any[Option[UUID]], any[Bolt11Invoice], any[PredefinedNodeRoute], any[Option[ByteVector32]], any[Option[MilliSatoshi]], any[Option[CltvExpiryDelta]], any[List[PublicKey]])(any[Timeout]) returns Future.successful(payment) + eclair.sendToRoute(any[Option[MilliSatoshi]], any[Option[String]], any[Option[UUID]], any[Bolt11Invoice], any[PredefinedNodeRoute], any[Option[ByteVector32]], any[Option[MilliSatoshi]], any[Option[CltvExpiryDelta]])(any[Timeout]) returns Future.successful(payment) val mockService = new MockService(eclair) Post("/sendtoroute", FormData("nodeIds" -> jsonNodes, "amountMsat" -> "1234", "finalCltvExpiry" -> "190", "externalId" -> externalId, "invoice" -> pr.toString).toEntity) ~> @@ -967,7 +967,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM assert(handled) assert(status == OK) assert(entityAs[String] == expected) - eclair.sendToRoute(1234 msat, None, Some(externalId), None, pr, expectedRoute, None, None, None, Nil)(any[Timeout]).wasCalled(once) + eclair.sendToRoute(None, Some(externalId), None, pr, expectedRoute, None, None, None)(any[Timeout]).wasCalled(once) } } @@ -975,11 +975,11 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val payment = SendPaymentToRouteResponse(UUID.fromString("487da196-a4dc-4b1e-92b4-3e5e905e9f3f"), UUID.fromString("2ad8c6d7-99cb-4238-8f67-89024b8eed0d"), None) val expected = """{"paymentId":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","parentId":"2ad8c6d7-99cb-4238-8f67-89024b8eed0d"}""" val pr = Bolt11Invoice(Block.LivenetGenesisBlock.hash, Some(1234 msat), ByteVector32.Zeroes, randomKey(), Left("Some invoice"), CltvExpiryDelta(24)) - val expectedRoute = PredefinedNodeRoute(Seq(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28"))) + val expectedRoute = PredefinedNodeRoute(1234 msat, Seq(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28"))) val csvNodes = "0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9, 0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3, 026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28" val eclair = mock[Eclair] - eclair.sendToRoute(any[MilliSatoshi], any[Option[MilliSatoshi]], any[Option[String]], any[Option[UUID]], any[Bolt11Invoice], any[PredefinedNodeRoute], any[Option[ByteVector32]], any[Option[MilliSatoshi]], any[Option[CltvExpiryDelta]], any[List[PublicKey]])(any[Timeout]) returns Future.successful(payment) + eclair.sendToRoute(any[Option[MilliSatoshi]], any[Option[String]], any[Option[UUID]], any[Bolt11Invoice], any[PredefinedNodeRoute], any[Option[ByteVector32]], any[Option[MilliSatoshi]], any[Option[CltvExpiryDelta]])(any[Timeout]) returns Future.successful(payment) val mockService = new MockService(eclair) // this test uses CSV encoded route @@ -991,7 +991,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM assert(handled) assert(status == OK) assert(entityAs[String] == expected) - eclair.sendToRoute(1234 msat, None, None, None, pr, expectedRoute, None, None, None, Nil)(any[Timeout]).wasCalled(once) + eclair.sendToRoute(None, None, None, pr, expectedRoute, None, None, None)(any[Timeout]).wasCalled(once) } } @@ -1015,14 +1015,14 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val mockChannelUpdate2 = mockChannelUpdate1.copy(shortChannelId = RealShortChannelId(BlockHeight(1), 2, 4)) val mockChannelUpdate3 = mockChannelUpdate1.copy(shortChannelId = RealShortChannelId(BlockHeight(1), 2, 5)) - val mockHop1 = Router.ChannelHop(mockChannelUpdate1.shortChannelId, PublicKey.fromBin(ByteVector.fromValidHex("03007e67dc5a8fd2b2ef21cb310ab6359ddb51f3f86a8b79b8b1e23bc3a6ea150a")), PublicKey.fromBin(ByteVector.fromValidHex("026105f6cb4862810be989385d16f04b0f748f6f2a14040338b1a534d45b4be1c1")), ChannelRelayParams.FromAnnouncement(mockChannelUpdate1)) - val mockHop2 = Router.ChannelHop(mockChannelUpdate2.shortChannelId, mockHop1.nextNodeId, PublicKey.fromBin(ByteVector.fromValidHex("038cfa2b5857843ee90cff91b06f692c0d8fe201921ee6387aee901d64f43699f0")), ChannelRelayParams.FromAnnouncement(mockChannelUpdate2)) - val mockHop3 = Router.ChannelHop(mockChannelUpdate3.shortChannelId, mockHop2.nextNodeId, PublicKey.fromBin(ByteVector.fromValidHex("02be60276e294c6921240daae33a361d214d02578656df0e74c61a09c3196e51df")), ChannelRelayParams.FromAnnouncement(mockChannelUpdate3)) + val mockHop1 = Router.ChannelHop(mockChannelUpdate1.shortChannelId, PublicKey.fromBin(ByteVector.fromValidHex("03007e67dc5a8fd2b2ef21cb310ab6359ddb51f3f86a8b79b8b1e23bc3a6ea150a")), PublicKey.fromBin(ByteVector.fromValidHex("026105f6cb4862810be989385d16f04b0f748f6f2a14040338b1a534d45b4be1c1")), HopRelayParams.FromAnnouncement(mockChannelUpdate1)) + val mockHop2 = Router.ChannelHop(mockChannelUpdate2.shortChannelId, mockHop1.nextNodeId, PublicKey.fromBin(ByteVector.fromValidHex("038cfa2b5857843ee90cff91b06f692c0d8fe201921ee6387aee901d64f43699f0")), HopRelayParams.FromAnnouncement(mockChannelUpdate2)) + val mockHop3 = Router.ChannelHop(mockChannelUpdate3.shortChannelId, mockHop2.nextNodeId, PublicKey.fromBin(ByteVector.fromValidHex("02be60276e294c6921240daae33a361d214d02578656df0e74c61a09c3196e51df")), HopRelayParams.FromAnnouncement(mockChannelUpdate3)) val mockHops = Seq(mockHop1, mockHop2, mockHop3) val eclair = mock[Eclair] val mockService = new MockService(eclair) - eclair.findRoute(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(Router.RouteResponse(Seq(Router.Route(456.msat, mockHops)))) + eclair.findRoute(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(Router.RouteResponse(Seq(Router.Route(456.msat, mockHops, None)))) // invalid format Post("/findroute", FormData("format" -> "invalid-output-format", "invoice" -> serializedInvoice, "amountMsat" -> "456")) ~>