diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala index 3ae5f2ffe..2d7709c97 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala @@ -170,6 +170,7 @@ package object eclair { def +(other: MilliSatoshi) = MilliSatoshi(amount + other.amount) def -(other: MilliSatoshi) = MilliSatoshi(amount - other.amount) def *(m: Long) = MilliSatoshi(amount * m) + def *(m: Double) = MilliSatoshi((amount * m).toLong) def /(d: Long) = MilliSatoshi(amount / d) def compare(other: MilliSatoshi): Int = if (amount == other.amount) 0 else if (amount < other.amount) -1 else 1 def <= (that: MilliSatoshi): Boolean = compare(that) <= 0 diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 6e969dbb1..59d241fa7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -19,9 +19,9 @@ package fr.acinq.eclair.router import akka.Done import akka.actor.{ActorRef, Props, Status} import akka.event.Logging.MDC -import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Script.{pay2wsh, write} +import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.channel._ @@ -33,7 +33,6 @@ import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{RichWeight, WeightRatios} import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ -import scodec.bits.ByteVector import scala.collection.immutable.{SortedMap, TreeMap} import scala.collection.{SortedSet, mutable} @@ -868,10 +867,21 @@ object Router { val currentBlockHeight = Globals.blockCount.get() + def feeBaseOk(fee: MilliSatoshi): Boolean = fee <= routeParams.maxFeeBase + + def feePctOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = { + val maxFee = amount * routeParams.maxFeePct + fee <= maxFee + } + + def feeOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = feeBaseOk(fee) || feePctOk(fee, amount) + + def lengthOk(length: Int): Boolean = length <= routeParams.routeMaxLength && length <= ROUTE_MAX_LENGTH + + def cltvOk(cltv: Int): Boolean = cltv <= routeParams.routeMaxCltv + val boundaries: RichWeight => Boolean = { weight => - ((weight.cost - amount) < routeParams.maxFeeBase || (weight.cost - amount) < amount * routeParams.maxFeePct.toLong) && - weight.length <= routeParams.routeMaxLength && weight.length <= ROUTE_MAX_LENGTH && - weight.cltv <= routeParams.routeMaxCltv + feeOk(weight.cost - amount, amount) && lengthOk(weight.length) && cltvOk(weight.cltv) } val foundRoutes = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries).toList match { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index c4930b0bd..6eaddc527 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.router -import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto} +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} @@ -56,6 +56,30 @@ class RouteCalculationSpec extends FunSuite { assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } + test("check fee against max pct properly") { + + // fee is acceptable is it is either + // - below our maximum fee base + // - below our maximum fraction of the paid amount + + // here we have a maximum fee base of 1 msat, and all our updates have a base fee of 10 msat + // so our fee will always be above the base fee, and we will always check that it is below our maximum percentage + // of the amount being paid + + val updates = List( + makeUpdate(1L, a, b, MilliSatoshi(10), 10, cltvDelta = 1), + makeUpdate(2L, b, c, MilliSatoshi(10), 10, cltvDelta = 1), + makeUpdate(3L, c, d, MilliSatoshi(10), 10, cltvDelta = 1), + makeUpdate(4L, d, e, MilliSatoshi(10), 10, cltvDelta = 1) + ).toMap + + val g = makeGraph(updates) + + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(maxFeeBase = MilliSatoshi(1))) + + assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) + } + test("calculate the shortest path (correct fees)") { val (a, b, c, d, e, f) = ( @@ -920,7 +944,6 @@ class RouteCalculationSpec extends FunSuite { assert(route.size == 2) assert(route.last.nextNodeId == targetNode) } - } object RouteCalculationSpec {