diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 2d6e03ce2..670934b46 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -112,12 +112,12 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR val trampolineSecret = r.trampolineSecret.getOrElse(randomBytes32) sender ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(trampolineSecret)) val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(SendTrampolinePaymentRequest(r.recipientAmount, r.paymentRequest, trampoline, Seq((r.trampolineFees, r.trampolineExpiryDelta)), r.finalExpiryDelta), r.trampolineFees, r.trampolineExpiryDelta) - payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionTlv.TrampolineOnion(trampolineOnion)))) + payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionTlv.TrampolineOnion(trampolineOnion))), r.paymentRequest.routingInfo) case Nil => sender ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) r.paymentRequest.paymentSecret match { - case Some(paymentSecret) => payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, paymentSecret)) - case None => payFsm forward SendPaymentToRoute(r.route, FinalLegacyPayload(r.recipientAmount, finalExpiry)) + case Some(paymentSecret) => payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, paymentSecret), r.paymentRequest.routingInfo) + case None => payFsm forward SendPaymentToRoute(r.route, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.paymentRequest.routingInfo) } case _ => sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(TrampolineMultiNodeNotSupported) :: Nil) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index bd2da36f8..a5961a958 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -72,7 +72,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A span.tag("expiry", c.finalPayload.expiry.toLong) log.debug("sending {} to route {}", c.finalPayload.amount, c.hops.mkString("->")) val send = SendPayment(c.hops.last, c.finalPayload, maxAttempts = 1) - router ! FinalizeRoute(c.hops) + router ! FinalizeRoute(c.hops, c.assistedRoutes) if (cfg.storeInDb) { paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, Platform.currentTime, cfg.paymentRequest, OutgoingPaymentStatus.Pending)) } @@ -282,7 +282,7 @@ object PaymentLifecycle { * @param hops payment route to use. * @param finalPayload onion payload for the target node. */ - case class SendPaymentToRoute(hops: Seq[PublicKey], finalPayload: FinalPayload) { + case class SendPaymentToRoute(hops: Seq[PublicKey], finalPayload: FinalPayload, assistedRoutes: Seq[Seq[ExtraHop]] = Nil) { require(hops.nonEmpty, s"payment route must not be empty") val targetNodeId = hops.last } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index f7f417f91..0d60c4dce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -148,7 +148,7 @@ case class RouteRequest(source: PublicKey, ignoreChannels: Set[ChannelDesc] = Set.empty, routeParams: Option[RouteParams] = None) -case class FinalizeRoute(hops: Seq[PublicKey]) +case class FinalizeRoute(hops: Seq[PublicKey], assistedRoutes: Seq[Seq[ExtraHop]] = Nil) case class RouteResponse(hops: Seq[ChannelHop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc], allowEmpty: Boolean = false) { require(allowEmpty || hops.nonEmpty, "route cannot be empty") @@ -533,9 +533,13 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ sender ! d stay - case Event(FinalizeRoute(partialHops), d) => + case Event(FinalizeRoute(partialHops, assistedRoutes), d) => + // NB: using a capacity of 0 msat will impact the path-finding algorithm. However here we don't run any path-finding, so it's ok. + val assistedChannels: Map[ShortChannelId, AssistedChannel] = assistedRoutes.flatMap(toAssistedChannels(_, partialHops.last, 0 msat)).toMap + val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum))).toSet + val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - partialHops.sliding(2).map { case List(v1, v2) => d.graph.getEdgesBetween(v1, v2) }.toList match { + partialHops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => val selectedEdges = edges.map(_.maxBy(_.update.htlcMaximumMsat.getOrElse(0 msat))) // select the largest edge val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) @@ -1299,11 +1303,12 @@ object Router { /** * Build a `reply_channel_range` message - * @param chunk chunk of scids - * @param chainHash chain hash + * + * @param chunk chunk of scids + * @param chainHash chain hash * @param defaultEncoding default encoding - * @param queryFlags_opt query flag set by the requester - * @param channels channels map + * @param queryFlags_opt query flag set by the requester + * @param channels channels map * @return a ReplyChannelRange object */ def buildReplyChannelRange(chunk: ShortChannelIdsChunk, chainHash: ByteVector32, defaultEncoding: EncodingType, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 49240d2b3..e760c6ae7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -116,6 +116,29 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(failureMessage == "Not all the nodes in the supplied route are connected with public channels") } + test("send to route (routing hints)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + + val recipient = randomKey.publicKey + val routingHint = Seq(Seq(ExtraHop(c, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144)))) + val request = SendPaymentToRoute(Seq(a, b, c, recipient), FinalLegacyPayload(defaultAmountMsat, defaultExpiry), routingHint) + + sender.send(paymentFSM, request) + routerForwarder.expectMsg(FinalizeRoute(Seq(a, b, c, recipient), routingHint)) + val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) + + routerForwarder.forward(routerFixture.router) + val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) + + // Payment accepted by the recipient. + sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) + + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id === parentId) + awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + } + test("send with route prefix") { _ => val payFixture = createPaymentLifecycle() import payFixture._