From b084d73e96a9602d611e76f7756167be9c3396df Mon Sep 17 00:00:00 2001 From: rorp Date: Wed, 7 Jun 2023 07:01:09 -0700 Subject: [PATCH] Add `maxFeeMsat` parameter to `sendtoroute` RPC call (#2626) This ensures that routes found with `findroute*` and a max fee are correctly ignored if we later use `sendtoroute` and the route fee has increased. --- .../payment/send/PaymentLifecycle.scala | 4 ++-- .../eclair/router/RouteCalculation.scala | 22 +++++++++++++++---- .../scala/fr/acinq/eclair/router/Router.scala | 5 +++-- .../fr/acinq/eclair/router/RouterSpec.scala | 16 ++++++++++++++ .../acinq/eclair/api/handlers/Payment.scala | 8 +++---- 5 files changed, 43 insertions(+), 12 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 68488f0a2..381ce075e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -431,8 +431,8 @@ object PaymentLifecycle { override val amount = route.fold(_.amount, _.amount) def printRoute(): String = route match { - case Left(PredefinedChannelRoute(_, _, channels)) => channels.mkString("->") - case Left(PredefinedNodeRoute(_, nodes)) => nodes.mkString("->") + case Left(PredefinedChannelRoute(_, _, channels, _)) => channels.mkString("->") + case Left(PredefinedNodeRoute(_, nodes, _)) => nodes.mkString("->") case Right(route) => route.printNodes() } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index bc779dedd..40e6f61a5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -50,6 +50,14 @@ object RouteCalculation { } def finalizeRoute(d: Data, localNodeId: PublicKey, fr: FinalizeRoute)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + def validateMaxRouteFee(route: Route, maxFee_opt: Option[MilliSatoshi]): Try[Route] = { + val routeFee = route.channelFee(includeLocalChannelCost = false) + maxFee_opt match { + case Some(maxFee) if maxFee < routeFee => Failure(new IllegalArgumentException(s"Route fee ($routeFee) was above the maximum allowed fee ($maxFee) for route ${route.printChannels()}")) + case _ => Success(route) + } + } + Logs.withMdc(log)(Logs.mdc( category_opt = Some(LogCategory.PAYMENT), parentPaymentId_opt = fr.paymentContext.map(_.parentId), @@ -61,19 +69,22 @@ object RouteCalculation { val g = extraEdges.foldLeft(d.graphWithBalances.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } fr.route match { - case PredefinedNodeRoute(amount, hops) => + case PredefinedNodeRoute(amount, hops, maxFee_opt) => // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => // select the largest edge (using balance when available, otherwise capacity). val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) val hops = selectedEdges.map(e => ChannelHop(getEdgeRelayScid(d, localNodeId, e), e.desc.a, e.desc.b, e.params)) - ctx.sender() ! RouteResponse(Route(amount, hops, None) :: Nil) + validateMaxRouteFee(Route(amount, hops, None), maxFee_opt) match { + case Success(route) => ctx.sender() ! RouteResponse(route :: Nil) + case Failure(f) => ctx.sender() ! Status.Failure(f) + } case _ => // some nodes in the supplied route aren't connected in our graph ctx.sender() ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) } - case PredefinedChannelRoute(amount, targetNodeId, shortChannelIds) => + case PredefinedChannelRoute(amount, targetNodeId, shortChannelIds, maxFee_opt) => val (end, hops) = shortChannelIds.foldLeft((localNodeId, Seq.empty[ChannelHop])) { case ((currentNode, previousHops), shortChannelId) => val channelDesc_opt = d.resolve(shortChannelId) match { @@ -97,7 +108,10 @@ object RouteCalculation { if (end != targetNodeId || hops.length != shortChannelIds.length) { ctx.sender() ! Status.Failure(new IllegalArgumentException("The sequence of channels provided cannot be used to build a route to the target node")) } else { - ctx.sender() ! RouteResponse(Route(amount, hops, None) :: Nil) + validateMaxRouteFee(Route(amount, hops, None), maxFee_opt) match { + case Success(route) => ctx.sender() ! RouteResponse(route :: Nil) + case Failure(f) => ctx.sender() ! Status.Failure(f) + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 13aeac79d..7950dec57 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -637,12 +637,13 @@ object Router { def isEmpty: Boolean def amount: MilliSatoshi def targetNodeId: PublicKey + def maxFee_opt: Option[MilliSatoshi] } - case class PredefinedNodeRoute(amount: MilliSatoshi, nodes: Seq[PublicKey]) extends PredefinedRoute { + case class PredefinedNodeRoute(amount: MilliSatoshi, nodes: Seq[PublicKey], maxFee_opt: Option[MilliSatoshi] = None) extends PredefinedRoute { override def isEmpty = nodes.isEmpty override def targetNodeId: PublicKey = nodes.last } - case class PredefinedChannelRoute(amount: MilliSatoshi, targetNodeId: PublicKey, channels: Seq[ShortChannelId]) extends PredefinedRoute { + case class PredefinedChannelRoute(amount: MilliSatoshi, targetNodeId: PublicKey, channels: Seq[ShortChannelId], maxFee_opt: Option[MilliSatoshi] = None) extends PredefinedRoute { override def isEmpty = channels.isEmpty } // @formatter:on diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 50f179d79..a8c82aa9b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -892,6 +892,22 @@ class RouterSpec extends BaseRouterSpec { } } + test("given a pre-defined channels route properly handles provided max fee") { fixture => + import fixture._ + val sender = TestProbe() + + { + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, scid_bc, scid_cd), maxFee_opt = Some(19.msat)) + sender.send(router, FinalizeRoute(preComputedRoute)) + sender.expectMsgType[Status.Failure] + } + { + val preComputedRoute = PredefinedChannelRoute(10000 msat, d, Seq(scid_ab, scid_bc, scid_cd), maxFee_opt = Some(20.msat)) + sender.send(router, FinalizeRoute(preComputedRoute)) + sender.expectMsgType[Router.RouteResponse] + } + } + test("restore stale channel that comes back from the dead") { fixture => import fixture._ diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala index c55a368b9..66e650f05 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala @@ -56,11 +56,11 @@ trait Payment { val sendToRoute: Route = postRequest("sendtoroute") { implicit t => withRoute { hops => formFields(amountMsatFormParam, "recipientAmountMsat".as[MilliSatoshi].?, invoiceFormParam, "externalId".?, "parentId".as[UUID].?, - "trampolineSecret".as[ByteVector32].?, "trampolineFeesMsat".as[MilliSatoshi].?, "trampolineCltvExpiry".as[Int].?) { - (amountMsat, recipientAmountMsat_opt, invoice, externalId_opt, parentId_opt, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt) => { + "trampolineSecret".as[ByteVector32].?, "trampolineFeesMsat".as[MilliSatoshi].?, "trampolineCltvExpiry".as[Int].?, maxFeeMsatFormParam.?) { + (amountMsat, recipientAmountMsat_opt, invoice, externalId_opt, parentId_opt, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt, maxFee_opt) => { val route = hops match { - case Left(shortChannelIds) => PredefinedChannelRoute(amountMsat, invoice.nodeId, shortChannelIds) - case Right(nodeIds) => PredefinedNodeRoute(amountMsat, nodeIds) + case Left(shortChannelIds) => PredefinedChannelRoute(amountMsat, invoice.nodeId, shortChannelIds, maxFee_opt) + case Right(nodeIds) => PredefinedNodeRoute(amountMsat, nodeIds, maxFee_opt) } complete(eclairApi.sendToRoute( recipientAmountMsat_opt, externalId_opt, parentId_opt, invoice, route, trampolineSecret_opt, trampolineFeesMsat_opt, trampolineCltvExpiry_opt.map(CltvExpiryDelta))