From 255c280bd0315af4f718240e802687f6a1967b70 Mon Sep 17 00:00:00 2001 From: araspitzu Date: Thu, 20 Dec 2018 16:52:42 +0100 Subject: [PATCH] Routing: use custom implementation for the shortest path algorithm and the graph (#779) * Consider htlc_minimum/maximum_msat when computing a route * Compare shortChannelIds first as it is less costly than comparing the pubkeys * Remove export to dot functionality * Remove dependency jgraph * Add optimized constructor to build the graph faster * Use fibonacci heaps from jheaps.org * Use Set instead of Seq for extraEdges, remove redundant publishing of channel updates * Use Set for ignored edges --- eclair-core/pom.xml | 21 +- .../scala/fr/acinq/eclair/api/Service.scala | 7 +- .../eclair/payment/PaymentLifecycle.scala | 14 +- .../scala/fr/acinq/eclair/router/Graph.scala | 326 ++++++++++++++++++ .../scala/fr/acinq/eclair/router/Router.scala | 205 ++++------- .../eclair/payment/PaymentLifecycleSpec.scala | 26 +- .../fr/acinq/eclair/router/GraphSpec.scala | 220 ++++++++++++ .../eclair/router/RouteCalculationSpec.scala | 177 ++++++++-- .../fr/acinq/eclair/router/RouterSpec.scala | 45 +-- .../src/main/resources/gui/main/main.fxml | 6 - .../scala/fr/acinq/eclair/gui/Handlers.scala | 15 - .../gui/controllers/MainController.scala | 8 - 12 files changed, 798 insertions(+), 272 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala diff --git a/eclair-core/pom.xml b/eclair-core/pom.xml index 5e97166b4..a6de36c71 100644 --- a/eclair-core/pom.xml +++ b/eclair-core/pom.xml @@ -190,27 +190,16 @@ 1.3.1 + + org.jheaps + jheaps + 0.9 + org.xerial sqlite-jdbc 3.21.0.1 - - org.jgrapht - jgrapht-core - 1.0.1 - - - org.jgrapht - jgrapht-ext - 1.0.1 - - - org.tinyjee.jgraphx - jgraphx - - - com.google.code.findbugs diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala index ea5d8de5e..5a946dc22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -40,7 +40,8 @@ import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo} import fr.acinq.eclair.io.{NodeURI, Peer} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.{ChannelDesc, RouteRequest, RouteResponse} +import fr.acinq.eclair.router.{ChannelDesc, RouteRequest, RouteResponse, Router} +import fr.acinq.eclair.router.Router.DEFAULT_AMOUNT_MSAT import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} import fr.acinq.eclair.{Kit, ShortChannelId, feerateByte2Kw} import grizzled.slf4j.Logging @@ -244,11 +245,11 @@ trait Service extends Logging { case "findroute" => req.params match { case JString(nodeId) :: Nil if nodeId.length() == 66 => Try(PublicKey(nodeId)) match { - case Success(pk) => completeRpcFuture(req.id, (router ? RouteRequest(appKit.nodeParams.nodeId, pk)).mapTo[RouteResponse]) + case Success(pk) => completeRpcFuture(req.id, (router ? RouteRequest(appKit.nodeParams.nodeId, pk, DEFAULT_AMOUNT_MSAT)).mapTo[RouteResponse]) case Failure(_) => reject(RpcValidationRejection(req.id, s"invalid nodeId hash '$nodeId'")) } case JString(paymentRequest) :: Nil => Try(PaymentRequest.read(paymentRequest)) match { - case Success(pr) => completeRpcFuture(req.id, (router ? RouteRequest(appKit.nodeParams.nodeId, pr.nodeId)).mapTo[RouteResponse]) + case Success(pr) => completeRpcFuture(req.id, (router ? RouteRequest(appKit.nodeParams.nodeId, pr.nodeId, pr.amount.map(_.toLong).getOrElse(DEFAULT_AMOUNT_MSAT))).mapTo[RouteResponse]) case Failure(t) => reject(RpcValidationRejection(req.id, s"invalid payment request ${t.getLocalizedMessage}")) } case _ => reject(UnknownParamsRejection(req.id, "[payment_request] or [nodeId]")) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index fa2ae0a45..24255995d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -43,7 +43,7 @@ class PaymentLifecycle(sourceNodeId: PublicKey, router: ActorRef, register: Acto when(WAITING_FOR_REQUEST) { case Event(c: SendPayment, WaitingForRequest) => - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -103,12 +103,12 @@ class PaymentLifecycle(sourceNodeId: PublicKey, router: ActorRef, register: Acto // in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1) log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}") - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops)) case Success(e@ErrorPacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") // let's try to route around this node - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@ErrorPacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") @@ -136,18 +136,18 @@ class PaymentLifecycle(sourceNodeId: PublicKey, router: ActorRef, register: Acto // in any case, we forward the update to the router router ! failureMessage.update // let's try again, router will have updated its state - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes, ignoreChannels) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes, ignoreChannels) } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@ErrorPacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") // let's try again without the channel outgoing from nodeId val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)) - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) } @@ -166,7 +166,7 @@ class PaymentLifecycle(sourceNodeId: PublicKey, router: ActorRef, register: Acto } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId) - router ! RouteRequest(sourceNodeId, c.targetNodeId, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel) + router ! RouteRequest(sourceNodeId, c.targetNodeId, c.amountMsat, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala new file mode 100644 index 000000000..f8d0308b2 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -0,0 +1,326 @@ +package fr.acinq.eclair.router + +import fr.acinq.bitcoin.Crypto.PublicKey + +import scala.collection.mutable +import fr.acinq.eclair._ +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} +import fr.acinq.eclair.wire.ChannelUpdate + +object Graph { + + import DirectedGraph._ + + case class WeightedNode(key: PublicKey, weight: Long) + + /** + * This comparator must be consistent with the "equals" behavior, thus for two weighted nodes with + * the same weight we distinguish them by their public key. See https://docs.oracle.com/javase/8/docs/api/java/util/Comparator.html + */ + object QueueComparator extends Ordering[WeightedNode] { + override def compare(x: WeightedNode, y: WeightedNode): Int = { + val weightCmp = x.weight.compareTo(y.weight) + if (weightCmp == 0) x.key.toString().compareTo(y.key.toString()) + else weightCmp + } + } + + /** + * Finds the shortest path in the graph, Dijsktra's algorithm + * + * @param g the graph on which will be performed the search + * @param sourceNode the starting node of the path we're looking for + * @param targetNode the destination node of the path + * @param amountMsat the amount (in millisatoshis) we want to transmit + * @param ignoredEdges a list of edges we do not want to consider + * @param extraEdges a list of extra edges we want to consider but are not currently in the graph + * @return + */ + def shortestPath(g: DirectedGraph, sourceNode: PublicKey, targetNode: PublicKey, amountMsat: Long, ignoredEdges: Set[ChannelDesc], extraEdges: Set[GraphEdge]): Seq[Hop] = { + dijkstraShortestPath(g, sourceNode, targetNode, amountMsat, ignoredEdges, extraEdges).map(graphEdgeToHop) + } + + def dijkstraShortestPath(g: DirectedGraph, sourceNode: PublicKey, targetNode: PublicKey, amountMsat: Long, ignoredEdges: Set[ChannelDesc], extraEdges: Set[GraphEdge]): Seq[GraphEdge] = { + + // optionally add the extra edges to the graph + val graphVerticesWithExtra = extraEdges.nonEmpty match { + case true => g.vertexSet() ++ extraEdges.map(_.desc.a) ++ extraEdges.map(_.desc.b) + case false => g.vertexSet() + } + + // the graph does not contain source/destination nodes + if (!graphVerticesWithExtra.contains(sourceNode)) return Seq.empty + if (!graphVerticesWithExtra.contains(targetNode)) return Seq.empty + + val maxMapSize = graphVerticesWithExtra.size + 1 + + // this is not the actual optimal size for the maps, because we only put in there all the vertices in the worst case scenario. + val cost = new java.util.HashMap[PublicKey, Long](maxMapSize) + val prev = new java.util.HashMap[PublicKey, GraphEdge](maxMapSize) + val vertexQueue = new org.jheaps.tree.SimpleFibonacciHeap[WeightedNode, Short](QueueComparator) + + // initialize the queue and cost array + cost.put(sourceNode, 0) + vertexQueue.insert(WeightedNode(sourceNode, 0)) + + var targetFound = false + + while (!vertexQueue.isEmpty && !targetFound) { + + // node with the smallest distance from the source + val current = vertexQueue.deleteMin().getKey // O(log(n)) + + if (current.key != targetNode) { + + // build the neighbors with optional extra edges + val currentNeighbors = extraEdges.isEmpty match { + case true => g.edgesOf(current.key) + case false => g.edgesOf(current.key) ++ extraEdges.filter(_.desc.a == current.key) + } + + // for each neighbor + currentNeighbors.foreach { edge => + + // test for ignored edges + if (!(edge.update.htlcMaximumMsat.exists(_ < amountMsat) || + amountMsat < edge.update.htlcMinimumMsat || + ignoredEdges.contains(edge.desc)) + ) { + + val neighbor = edge.desc.b + + // note: the default value here will never be used, as there is always an entry for the current in the 'cost' map + val newMinimumKnownCost = cost.get(current.key) + edgeWeightByAmount(edge, amountMsat) + + // we call containsKey first because "getOrDefault" is not available in JDK7 + val neighborCost = cost.containsKey(neighbor) match { + case false => Long.MaxValue + case true => cost.get(neighbor) + } + + // if this neighbor has a shorter distance than previously known + if (newMinimumKnownCost < neighborCost) { + + // update the visiting tree + prev.put(neighbor, edge) + + // update the queue + vertexQueue.insert(WeightedNode(neighbor, newMinimumKnownCost)) // O(1) + + // update the minimum known distance array + cost.put(neighbor, newMinimumKnownCost) + } + } + } + } else { // we popped the target node from the queue, no need to search any further + targetFound = true + } + } + + targetFound match { + case false => Seq.empty[GraphEdge] + case true => { + // we traverse the list of "previous" backward building the final list of edges that make the shortest path + val edgePath = new mutable.ArrayBuffer[GraphEdge](21) // max path length is 20! https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#clarifications + var current = prev.get(targetNode) + + while (current != null) { + + edgePath += current + current = prev.get(current.desc.a) + } + + edgePath.reverse + } + } + } + + private def edgeWeightByAmount(edge: GraphEdge, amountMsat: Long): Long = { + nodeFee(edge.update.feeBaseMsat, edge.update.feeProportionalMillionths, amountMsat) + } + + /** + * A graph data structure that uses the adjacency lists + */ + object GraphStructure { + + /** + * Representation of an edge of the graph + * + * @param desc channel description + * @param update channel info + */ + case class GraphEdge(desc: ChannelDesc, update: ChannelUpdate) + + case class DirectedGraph(private val vertices: Map[PublicKey, List[GraphEdge]]) { + + def addEdge(d: ChannelDesc, u: ChannelUpdate): DirectedGraph = addEdge(GraphEdge(d, u)) + + def addEdges(edges: Seq[(ChannelDesc, ChannelUpdate)]): DirectedGraph = { + edges.foldLeft(this)((acc, edge) => acc.addEdge(edge._1, edge._2)) + } + + /** + * Adds and edge to the graph, if one of the two vertices is not found, it will be created + * + * @param edge the edge that is going to be added to the graph + * @return a new graph containing this edge + */ + def addEdge(edge: GraphEdge): DirectedGraph = { + + val vertexIn = edge.desc.a + val vertexOut = edge.desc.b + + // the graph is allowed to have multiple edges between the same vertices but only one per channel + if (containsEdge(edge.desc)) { + removeEdge(edge.desc).addEdge(edge) + } else { + val withVertices = addVertex(vertexIn).addVertex(vertexOut) + DirectedGraph(withVertices.vertices.updated(vertexIn, edge +: withVertices.vertices(vertexIn))) + } + } + + /** + * Removes the edge corresponding to the given pair channel-desc/channel-update, + * NB: this operation does NOT remove any vertex + * + * @param desc the channel description associated to the edge that will be removed + * @return + */ + def removeEdge(desc: ChannelDesc): DirectedGraph = { + containsEdge(desc) match { + case true => DirectedGraph(vertices.updated(desc.a, vertices(desc.a).filterNot(_.desc == desc))) + case false => this + } + } + + def removeEdges(descList: Seq[ChannelDesc]): DirectedGraph = { + descList.foldLeft(this)((acc, edge) => acc.removeEdge(edge)) + } + + /** + * @param edge + * @return For edges to be considered equal they must have the same in/out vertices AND same shortChannelId + */ + def getEdge(edge: GraphEdge): Option[GraphEdge] = getEdge(edge.desc) + + def getEdge(desc: ChannelDesc): Option[GraphEdge] = vertices.get(desc.a).flatMap { adj => + adj.find(e => e.desc.shortChannelId == desc.shortChannelId && e.desc.b == desc.b) + } + + /** + * @param keyA the key associated with the starting vertex + * @param keyB the key associated with the ending vertex + * @return all the edges going from keyA --> keyB (there might be more than one if it refers to different shortChannelId) + */ + def getEdgesBetween(keyA: PublicKey, keyB: PublicKey): Seq[GraphEdge] = { + vertices.get(keyA) match { + case None => Seq.empty + case Some(adj) => adj.filter(e => e.desc.b == keyB) + } + } + + def getIncomingEdgesOf(keyA: PublicKey): Seq[GraphEdge] = { + edgeSet().filter(_.desc.b == keyA).toSeq + } + + /** + * Removes a vertex and all it's associated edges (both incoming and outgoing) + * + * @param key + * @return + */ + def removeVertex(key: PublicKey): DirectedGraph = { + DirectedGraph(removeEdges(getIncomingEdgesOf(key).map(_.desc)).vertices - key) + } + + /** + * Adds a new vertex to the graph, starting with no edges + * + * @param key + * @return + */ + def addVertex(key: PublicKey): DirectedGraph = { + vertices.get(key) match { + case None => DirectedGraph(vertices + (key -> List.empty)) + case _ => this + } + } + + /** + * @param key + * @return a list of the outgoing edges of vertex @param key, if the edge doesn't exists an empty list is returned + */ + def edgesOf(key: PublicKey): Seq[GraphEdge] = vertices.getOrElse(key, List.empty) + + /** + * @return the set of all the vertices in this graph + */ + def vertexSet(): Set[PublicKey] = vertices.keySet + + /** + * @return an iterator of all the edges in this graph + */ + def edgeSet(): Iterable[GraphEdge] = vertices.values.flatten + + /** + * @param key + * @return true if this graph contain a vertex with this key, false otherwise + */ + def containsVertex(key: PublicKey): Boolean = vertices.contains(key) + + /** + * @param desc + * @return true if this edge desc is in the graph. For edges to be considered equal they must have the same in/out vertices AND same shortChannelId + */ + def containsEdge(desc: ChannelDesc): Boolean = vertices.get(desc.a) match { + case None => false + case Some(adj) => adj.exists(neighbor => neighbor.desc.shortChannelId == desc.shortChannelId && neighbor.desc.b == desc.b) + } + + def prettyPrint(): String = { + vertices.foldLeft("") { case (acc, (vertex, adj)) => + acc + s"[${vertex.toString().take(5)}]: ${adj.map("-> " + _.desc.b.toString().take(5))} \n" + } + } + } + + object DirectedGraph { + + // convenience constructors + def apply(): DirectedGraph = new DirectedGraph(Map()) + + def apply(key: PublicKey): DirectedGraph = new DirectedGraph(Map(key -> List.empty)) + + def apply(edge: GraphEdge): DirectedGraph = new DirectedGraph(Map()).addEdge(edge.desc, edge.update) + + def apply(edges: Seq[GraphEdge]): DirectedGraph = { + makeGraph(edges.map(e => e.desc -> e.update).toMap) + } + + // optimized constructor + def makeGraph(descAndUpdates: Map[ChannelDesc, ChannelUpdate]): DirectedGraph = { + + // initialize the map with the appropriate size to avoid resizing during the graph initialization + val mutableMap = new {} with mutable.HashMap[PublicKey, List[GraphEdge]] { + override def initialSize: Int = descAndUpdates.size + 1 + } + + // add all the vertices and edges in one go + descAndUpdates.foreach { case (desc, update) => + // create or update vertex (desc.a) and update its neighbor + mutableMap.put(desc.a, GraphEdge(desc, update) +: mutableMap.getOrElse(desc.a, List.empty[GraphEdge])) + mutableMap.get(desc.b) match { + case None => mutableMap += desc.b -> List.empty[GraphEdge] + case _ => + } + } + + new DirectedGraph(mutableMap.toMap) + } + + def graphEdgeToHop(graphEdge: GraphEdge): Hop = Hop(graphEdge.desc.a, graphEdge.desc.b, graphEdge.update) + } + + } +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 63ab83c09..827d750d0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair.router -import java.io.StringWriter - import akka.actor.{ActorRef, Props, Status} import akka.event.Logging.MDC import akka.pattern.pipe @@ -28,29 +26,27 @@ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler -import fr.acinq.eclair.io.Peer.{ChannelClosed, NonexistingChannel, InvalidSignature, PeerRoutingMessage} +import fr.acinq.eclair.io.Peer.{ChannelClosed, InvalidSignature, NonexistingChannel, PeerRoutingMessage} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ -import org.jgrapht.WeightedGraph -import org.jgrapht.alg.shortestpath.DijkstraShortestPath -import org.jgrapht.ext._ -import org.jgrapht.graph._ -import scala.collection.JavaConversions._ -import scala.collection.SortedSet +import scala.collection.{SortedSet, mutable} import scala.collection.immutable.{SortedMap, TreeMap} import scala.compat.Platform import scala.concurrent.duration._ -import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.{ExecutionContext, Promise} import scala.util.Try // @formatter:off case class ChannelDesc(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey) case class Hop(nodeId: PublicKey, nextNodeId: PublicKey, lastUpdate: ChannelUpdate) -case class RouteRequest(source: PublicKey, target: PublicKey, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, ignoreNodes: Set[PublicKey] = Set.empty, ignoreChannels: Set[ChannelDesc] = Set.empty) -case class RouteResponse(hops: Seq[Hop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]) { require(hops.size > 0, "route cannot be empty") } +case class RouteRequest(source: PublicKey, target: PublicKey, amountMsat: Long, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, ignoreNodes: Set[PublicKey] = Set.empty, ignoreChannels: Set[ChannelDesc] = Set.empty) +case class RouteResponse(hops: Seq[Hop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]) { + require(hops.size > 0, "route cannot be empty") +} case class ExcludeChannel(desc: ChannelDesc) // this is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) case class LiftChannelExclusion(desc: ChannelDesc) case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef) @@ -61,8 +57,6 @@ case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[ActorRef]], update case class Sync(missing: SortedSet[ShortChannelId], totalMissingCount: Int) -case class DescEdge(desc: ChannelDesc, u: ChannelUpdate) extends DefaultWeightedEdge - case class Data(nodes: Map[PublicKey, NodeAnnouncement], channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate], @@ -72,9 +66,9 @@ case class Data(nodes: Map[PublicKey, NodeAnnouncement], privateChannels: Map[ShortChannelId, PublicKey], // short_channel_id -> node_id privateUpdates: Map[ChannelDesc, ChannelUpdate], excludedChannels: Set[ChannelDesc], // those channels are temporarily excluded from route calculation, because their node returned a TemporaryChannelFailure - graph: DirectedWeightedPseudograph[PublicKey, DescEdge], + graph: DirectedGraph, sync: Map[PublicKey, Sync] // keep tracks of channel range queries sent to each peer. If there is an entry in the map, it means that there is an ongoing query - // for which we have not yet received an 'end' message + // for which we have not yet received an 'end' message ) sealed trait State @@ -111,23 +105,18 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom val nodes = db.listNodes() val updates = db.listChannelUpdates() log.info("loaded from db: channels={} nodes={} updates={}", channels.size, nodes.size, updates.size) - - // this will be used to calculate routes - val graph = new DirectedWeightedPseudograph[PublicKey, DescEdge](classOf[DescEdge]) - val initChannels = channels.keys.foldLeft(TreeMap.empty[ShortChannelId, ChannelAnnouncement]) { case (m, c) => m + (c.shortChannelId -> c) } val initChannelUpdates = updates.map { u => val desc = getDesc(u, initChannels(u.shortChannelId)) - addEdge(graph, desc, u) - (desc) -> u + desc -> u }.toMap + // this will be used to calculate routes + val graph = DirectedGraph.makeGraph(initChannelUpdates) val initNodes = nodes.map(n => (n.nodeId -> n)).toMap - // send events for remaining channels/nodes initChannels.values.foreach(c => context.system.eventStream.publish(ChannelDiscovered(c, channels(c)._2))) initChannelUpdates.values.foreach(u => context.system.eventStream.publish(ChannelUpdateReceived(u))) initNodes.values.foreach(n => context.system.eventStream.publish(NodeDiscovered(n))) - initChannelUpdates.values.foreach(u => context.system.eventStream.publish(ChannelUpdateReceived(u))) // watch the funding tx of all these channels // note: some of them may already have been spent, in that case we will receive the watch event immediately @@ -189,10 +178,11 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom val desc1 = ChannelDesc(shortChannelId, nodeParams.nodeId, remoteNodeId) val desc2 = ChannelDesc(shortChannelId, remoteNodeId, nodeParams.nodeId) // we remove the corresponding updates from the graph - removeEdge(d.graph, desc1) - removeEdge(d.graph, desc2) + val graph1 = d.graph + .removeEdge(desc1) + .removeEdge(desc2) // and we remove the channel and channel_update from our state - stay using d.copy(privateChannels = d.privateChannels - shortChannelId, privateUpdates = d.privateUpdates - desc1 - desc2) + stay using d.copy(privateChannels = d.privateChannels - shortChannelId, privateUpdates = d.privateUpdates - desc1 - desc2, graph = graph1) } else { stay } @@ -294,9 +284,11 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom // let's clean the db and send the events log.info("pruning shortChannelId={} (spent)", shortChannelId) db.removeChannel(shortChannelId) // NB: this also removes channel updates - // we also need to remove updates from the graph - removeEdge(d.graph, ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) - removeEdge(d.graph, ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId2, lostChannel.nodeId1)) + // we also need to remove updates from the graph + val graph1 = d.graph + .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) + .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId2, lostChannel.nodeId1)) + context.system.eventStream.publish(ChannelLost(shortChannelId)) lostNodes.foreach { case nodeId => @@ -304,7 +296,7 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom db.removeNode(nodeId) context.system.eventStream.publish(NodeLost(nodeId)) } - stay using d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, updates = d.updates.filterKeys(_.shortChannelId != shortChannelId)) + stay using d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, updates = d.updates.filterKeys(_.shortChannelId != shortChannelId), graph = graph1) case Event(TickBroadcast, d) => if (d.rebroadcast.channels.isEmpty && d.rebroadcast.updates.isEmpty && d.rebroadcast.nodes.isEmpty) { @@ -335,18 +327,21 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom db.addToPruned(shortChannelId) context.system.eventStream.publish(ChannelLost(shortChannelId)) } - // we also need to remove updates from the graph - staleChannels.map(d.channels).foreach { c => - removeEdge(d.graph, ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2)) - removeEdge(d.graph, ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1)) - } + + val staleChannelsToRemove = new mutable.MutableList[ChannelDesc] + staleChannels.map(d.channels).foreach( ca => { + staleChannelsToRemove += ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2) + staleChannelsToRemove += ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1) + }) + + val graph1 = d.graph.removeEdges(staleChannelsToRemove) staleNodes.foreach { case nodeId => log.info("pruning nodeId={} (stale)", nodeId) db.removeNode(nodeId) context.system.eventStream.publish(NodeLost(nodeId)) } - stay using d.copy(nodes = d.nodes -- staleNodes, channels = channels1, updates = d.updates -- staleUpdates) + stay using d.copy(nodes = d.nodes -- staleNodes, channels = channels1, updates = d.updates -- staleUpdates, graph = graph1) case Event(ExcludeChannel(desc@ChannelDesc(shortChannelId, nodeId, _)), d) => val banDuration = nodeParams.channelExcludeDuration @@ -374,11 +369,7 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom sender ! (d.updates ++ d.privateUpdates) stay - case Event('dot, d) => - graph2dot(d.nodes, d.channels) pipeTo sender - stay - - case Event(RouteRequest(start, end, assistedRoutes, ignoreNodes, ignoreChannels), d) => + case Event(RouteRequest(start, end, amount, assistedRoutes, ignoreNodes, ignoreChannels), d) => // we convert extra routing info provided in the payment request to fake channel_update // it takes precedence over all other channel_updates we know val assistedUpdates = assistedRoutes.flatMap(toFakeUpdates(_, end)).toMap @@ -386,9 +377,10 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom // TODO: in case of duplicates, d.updates will be overridden by assistedUpdates even if they are more recent! val ignoredUpdates = getIgnoredChannelDesc(d.updates ++ d.privateUpdates ++ assistedUpdates, ignoreNodes) ++ ignoreChannels ++ d.excludedChannels log.info(s"finding a route $start->$end with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedUpdates.keys.mkString(","), ignoreNodes.map(_.toBin).mkString(","), ignoreChannels.mkString(","), d.excludedChannels.mkString(",")) - findRoute(d.graph, start, end, withEdges = assistedUpdates, withoutEdges = ignoredUpdates) - .map(r => sender ! RouteResponse(r, ignoreNodes, ignoreChannels)) - .recover { case t => sender ! Status.Failure(t) } + val extraEdges = assistedUpdates.map { case (c, u) => GraphEdge(c, u) }.toSet + findRoute(d.graph, start, end, amount, extraEdges = extraEdges, ignoredEdges = ignoredUpdates.toSet) + .map(r => sender ! RouteResponse(r, ignoreNodes, ignoreChannels)) + .recover { case t => sender ! Status.Failure(t) } stay case Event(SendChannelQuery(remoteNodeId, remote), d) => @@ -599,17 +591,19 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdateReceived(u)) db.updateChannelUpdate(u) - // we also need to update the graph - removeEdge(d.graph, desc) - addEdge(d.graph, desc, u) - d.copy(updates = d.updates + (desc -> u), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin)))) + // update the graph + val graph1 = Announcements.isEnabled(u.channelFlags) match { + case true => d.graph.removeEdge(desc).addEdge(desc, u) + case false => d.graph.removeEdge(desc) // if the channel is now disabled, we remove it from the graph + } + d.copy(updates = d.updates + (desc -> u), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) } else { log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdateReceived(u)) db.addChannelUpdate(u) // we also need to update the graph - addEdge(d.graph, desc, u) - d.copy(updates = d.updates + (desc -> u), privateUpdates = d.privateUpdates - desc, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin)))) + val graph1 = d.graph.addEdge(desc, u) + d.copy(updates = d.updates + (desc -> u), privateUpdates = d.privateUpdates - desc, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) } } else if (d.awaiting.keys.exists(c => c.shortChannelId == u.shortChannelId)) { // channel is currently being validated @@ -640,15 +634,14 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdateReceived(u)) // we also need to update the graph - removeEdge(d.graph, desc) - addEdge(d.graph, desc, u) - d.copy(privateUpdates = d.privateUpdates + (desc -> u)) + val graph1 = d.graph.removeEdge(desc).addEdge(desc, u) + d.copy(privateUpdates = d.privateUpdates + (desc -> u), graph = graph1) } else { log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdateReceived(u)) // we also need to update the graph - addEdge(d.graph, desc, u) - d.copy(privateUpdates = d.privateUpdates + (desc -> u)) + val graph1 = d.graph.addEdge(desc, u) + d.copy(privateUpdates = d.privateUpdates + (desc -> u), graph = graph1) } } else if (db.isPruned(u.shortChannelId) && !isStale(u)) { // the channel was recently pruned, but if we are here, it means that the update is not stale so this is the case @@ -780,107 +773,27 @@ object Router { } /** - * Routing fee have a variable part, as a simplification we compute fees using a default constant value for the amount + * Routing fee have a variable part, this value will be used as a default if none is provided when search for a route */ val DEFAULT_AMOUNT_MSAT = 10000000 /** - * Careful: this function *mutates* the graph - * - * Note that we only add the edge if the corresponding channel is enabled - */ - def addEdge(g: WeightedGraph[PublicKey, DescEdge], d: ChannelDesc, u: ChannelUpdate) = { - if (Announcements.isEnabled(u.channelFlags)) { - g.addVertex(d.a) - g.addVertex(d.b) - val e = new DescEdge(d, u) - val weight = nodeFee(u.feeBaseMsat, u.feeProportionalMillionths, DEFAULT_AMOUNT_MSAT).toDouble - g.addEdge(d.a, d.b, e) - g.setEdgeWeight(e, weight) - } - } - - /** - * Careful: this function *mutates* the graph - * - * NB: we don't clean up vertices - * - */ - def removeEdge(g: WeightedGraph[PublicKey, DescEdge], d: ChannelDesc) = { - import scala.collection.JavaConversions._ - Option(g.getAllEdges(d.a, d.b)) match { - case Some(edges) => edges.find(_.desc == d) match { - case Some(e) => g.removeEdge(e) - case None => () - } - case None => () - } - } - - /** - * Find a route in the graph between localNodeId and targetNodeId + * Find a route in the graph between localNodeId and targetNodeId, returns the route and its cost * * @param g * @param localNodeId * @param targetNodeId - * @param withEdges those will be added before computing the route, and removed after so that g is left unchanged - * @param withoutEdges those will be removed before computing the route, and added back after so that g is left unchanged - * @return + * @param amountMsat the amount that will be sent along this route + * @param extraEdges a set of extra edges we want to CONSIDER during the search + * @param ignoredEdges a set of extra edges we want to IGNORE during the search + * @return the computed route to the destination @targetNodeId */ - def findRoute(g: DirectedWeightedPseudograph[PublicKey, DescEdge], localNodeId: PublicKey, targetNodeId: PublicKey, withEdges: Map[ChannelDesc, ChannelUpdate] = Map.empty, withoutEdges: Iterable[ChannelDesc] = Iterable.empty): Try[Seq[Hop]] = Try { + def findRoute(g: DirectedGraph, localNodeId: PublicKey, targetNodeId: PublicKey, amountMsat: Long, extraEdges: Set[GraphEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty): Try[Seq[Hop]] = Try { if (localNodeId == targetNodeId) throw CannotRouteToSelf - val workingGraph = if (withEdges.isEmpty && withoutEdges.isEmpty) { - // no filtering, let's work on the base graph - g - } else { - // slower but safer: we duplicate the graph and add/remove updates from the duplicated version - val clonedGraph = g.clone().asInstanceOf[DirectedWeightedPseudograph[PublicKey, DescEdge]] - withEdges.foreach { case (d, u) => - removeEdge(clonedGraph, d) - addEdge(clonedGraph, d, u) - } - withoutEdges.foreach { d => removeEdge(clonedGraph, d) } - clonedGraph - } - if (!workingGraph.containsVertex(localNodeId)) throw RouteNotFound - if (!workingGraph.containsVertex(targetNodeId)) throw RouteNotFound - val route_opt = Option(DijkstraShortestPath.findPathBetween(workingGraph, localNodeId, targetNodeId)) - route_opt match { - case Some(path) => path.getEdgeList.map(edge => Hop(edge.desc.a, edge.desc.b, edge.u)) - case None => throw RouteNotFound - } - } - def graph2dot(nodes: Map[PublicKey, NodeAnnouncement], channels: Map[ShortChannelId, ChannelAnnouncement])(implicit ec: ExecutionContext): Future[String] = Future { - case class DescEdge(shortChannelId: ShortChannelId) extends DefaultEdge - val g = new SimpleGraph[PublicKey, DescEdge](classOf[DescEdge]) - channels.foreach(d => { - g.addVertex(d._2.nodeId1) - g.addVertex(d._2.nodeId2) - g.addEdge(d._2.nodeId1, d._2.nodeId2, new DescEdge(d._1)) - }) - val vertexIDProvider = new ComponentNameProvider[PublicKey]() { - override def getName(nodeId: PublicKey): String = "\"" + nodeId.toString() + "\"" - } - val edgeLabelProvider = new ComponentNameProvider[DescEdge]() { - override def getName(e: DescEdge): String = e.shortChannelId.toString - } - val vertexAttributeProvider = new ComponentAttributeProvider[PublicKey]() { - - override def getComponentAttributes(nodeId: PublicKey): java.util.Map[String, String] = - - nodes.get(nodeId) match { - case Some(ann) => Map("label" -> ann.alias, "color" -> ann.rgbColor.toString) - case None => Map.empty[String, String] - } - } - val exporter = new DOTExporter[PublicKey, DescEdge](vertexIDProvider, null, edgeLabelProvider, vertexAttributeProvider, null) - val writer = new StringWriter() - try { - exporter.exportGraph(g, writer) - writer.toString - } finally { - writer.close() + Graph.shortestPath(g, localNodeId, targetNodeId, amountMsat, ignoredEdges, extraEdges) match { + case Nil => throw RouteNotFound + case path => path } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 5d33d25b6..a0a6e485e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -89,7 +89,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, _, _, _, hops) = paymentFSM.stateData @@ -98,7 +98,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFailHtlc("00" * 32, 0, defaultPaymentHash)) // unparsable message // then the payment lifecycle will ask for a new route excluding all intermediate nodes - routerForwarder.expectMsg(RouteRequest(a, d, ignoreNodes = Set(c), ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, ignoreNodes = Set(c), ignoreChannels = Set.empty)) // let's simulate a response by the router with another route sender.send(paymentFSM, RouteResponse(hops, Set(c), Set.empty)) @@ -127,7 +127,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, _, _, _, hops) = paymentFSM.stateData @@ -136,7 +136,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, Status.Failure(AddHtlcFailed("00" * 32, request.paymentHash, ChannelUnavailable("00" * 32), Local(Some(paymentFSM.underlying.self)), None, None))) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) } @@ -155,7 +155,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, _, _, _, hops) = paymentFSM.stateData @@ -164,7 +164,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFailMalformedHtlc("00" * 32, 0, defaultPaymentHash, FailureMessageCodecs.BADONION)) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) } @@ -183,7 +183,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, hops) = paymentFSM.stateData @@ -199,7 +199,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router sender.expectMsg(PaymentFailed(request.paymentHash, RemoteFailure(hops, ErrorPacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil)) @@ -220,7 +220,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, hops) = paymentFSM.stateData @@ -235,7 +235,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) // router answers with a new route, taking into account the new update @@ -255,7 +255,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // but it will still forward the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified_2) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) // this time the router can't find a route: game over @@ -277,7 +277,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, hops) = paymentFSM.stateData @@ -289,7 +289,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(RouteRequest(a, d, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_bc, b, c)))) + routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_bc, b, c)))) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala new file mode 100644 index 000000000..b05713372 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala @@ -0,0 +1,220 @@ +package fr.acinq.eclair.router + +import fr.acinq.bitcoin.Crypto.PublicKey +import org.scalatest.FunSuite +import RouteCalculationSpec._ +import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.router.Graph.GraphStructure.{GraphEdge, DirectedGraph} +import fr.acinq.eclair.wire.ChannelUpdate + +class GraphSpec extends FunSuite { + + val (a, b, c, d, e, f, g) = ( + PublicKey("02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), //a + PublicKey("03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), //b + PublicKey("0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), //c + PublicKey("029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c"), //d + PublicKey("02f38f4e37142cc05df44683a83e22dea608cf4691492829ff4cf99888c5ec2d3a"), //e + PublicKey("03fc5b91ce2d857f146fd9b986363374ffe04dc143d8bcd6d7664c8873c463cdfc"), //f + PublicKey("03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f") //g + ) + + /** + * /--> D --\ + * A --> B --> C + * \-> E/ + * + * @return + */ + def makeTestGraph() = { + + val updates = Seq( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0), + makeUpdate(3L, a, d, 0, 0), + makeUpdate(4L, d, c, 0, 0), + makeUpdate(5L, c, e, 0, 0), + makeUpdate(6L, b, e, 0, 0) + ) + + DirectedGraph.makeGraph(updates.toMap) + } + + test("instantiate a graph, with vertices and then add edges") { + + val graph = DirectedGraph(a) + .addVertex(b) + .addVertex(c) + .addVertex(d) + .addVertex(e) + + assert(graph.containsVertex(a) && graph.containsVertex(e)) + assert(graph.vertexSet().size === 5) + + val otherGraph = graph.addVertex(a) //adding the same vertex twice! + assert(otherGraph.vertexSet().size === 5) + + // add some edges to the graph + + val (descAB, updateAB) = makeUpdate(1L, a, b, 0, 0) + val (descBC, updateBC) = makeUpdate(2L, b, c, 0, 0) + val (descAD, updateAD) = makeUpdate(3L, a, d, 0, 0) + val (descDC, updateDC) = makeUpdate(4L, d, c, 0, 0) + val (descCE, updateCE) = makeUpdate(5L, c, e, 0, 0) + + val graphWithEdges = graph + .addEdge(descAB, updateAB) + .addEdge(descAD, updateAD) + .addEdge(descBC, updateBC) + .addEdge(descDC, updateDC) + .addEdge(descCE, updateCE) + + assert(graphWithEdges.edgesOf(a).size === 2) + assert(graphWithEdges.edgesOf(b).size === 1) + assert(graphWithEdges.edgesOf(c).size === 1) + assert(graphWithEdges.edgesOf(d).size === 1) + assert(graphWithEdges.edgesOf(e).size === 0) + + val withRemovedEdges = graphWithEdges.removeEdge(descDC) + + assert(withRemovedEdges.edgesOf(d).size === 0) + } + + test("instantiate a graph adding edges only") { + + val edgeAB = edgeFromDesc(makeUpdate(1L, a, b, 0, 0)) + val (descBC, updateBC) = makeUpdate(2L, b, c, 0, 0) + val (descAD, updateAD) = makeUpdate(3L, a, d, 0, 0) + val (descDC, updateDC) = makeUpdate(4L, d, c, 0, 0) + val (descCE, updateCE) = makeUpdate(5L, c, e, 0, 0) + val (descBE, updateBE) = makeUpdate(6L, b, e, 0, 0) + + val graph = DirectedGraph(edgeAB) + .addEdge(descAD, updateAD) + .addEdge(descBC, updateBC) + .addEdge(descDC, updateDC) + .addEdge(descCE, updateCE) + .addEdge(descBE, updateBE) + + assert(graph.vertexSet().size === 5) + assert(graph.edgesOf(c).size === 1) + assert(graph.edgeSet().size === 6) + } + + test("containsEdge should return true if the graph contains that edge, false otherwise") { + + val updates = Seq( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0), + makeUpdate(3L, c, d, 0, 0), + makeUpdate(4L, d, e, 0, 0) + ) + + val graph = DirectedGraph().addEdges(updates) + + assert(graph.containsEdge(descFromNodes(1, a, b))) + assert(!graph.containsEdge(descFromNodes(5, b, a))) + assert(graph.containsEdge(descFromNodes(2, b, c))) + assert(graph.containsEdge(descFromNodes(3, c, d))) + assert(graph.containsEdge(descFromNodes(4, d, e))) + assert(graph.containsEdge(ChannelDesc(ShortChannelId(4L), d, e))) // by channel desc + assert(!graph.containsEdge(ChannelDesc(ShortChannelId(4L), a, g))) // by channel desc + assert(!graph.containsEdge(descFromNodes(50, a, e))) + assert(!graph.containsEdge(descFromNodes(66, c, f))) // f isn't even in the graph + } + + test("should remove a set of edges") { + + val graph = makeTestGraph() + + val (descBE, _) = makeUpdate(6L, b, e, 0, 0) + val (descCE, _) = makeUpdate(5L, c, e, 0, 0) + val (descAD, _) = makeUpdate(3L, a, d, 0, 0) + val (descDC, _) = makeUpdate(4L, d, c, 0, 0) + + assert(graph.edgeSet().size === 6) + + val withRemovedEdge = graph.removeEdge(descBE) + assert(withRemovedEdge.edgeSet().size === 5) + + val withRemovedList = graph.removeEdges(Seq(descAD, descDC)) + assert(withRemovedList.edgeSet().size === 4) + + val withoutAnyIncomingEdgeInE = graph.removeEdges(Seq(descBE, descCE)) + assert(withoutAnyIncomingEdgeInE.containsVertex(e)) + assert(withoutAnyIncomingEdgeInE.getIncomingEdgesOf(e).size == 0) + } + + test("should get an edge given two vertices") { + + // contains an edge A --> B + val updates = Seq( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0) + ) + + val graph = DirectedGraph().addEdges(updates) + + val edgesAB = graph.getEdgesBetween(a, b) + + assert(edgesAB.size === 1) //there should be an edge a --> b + assert(edgesAB.head.desc.a === a) + assert(edgesAB.head.desc.b === b) + + val bNeighbors = graph.edgesOf(b) + assert(bNeighbors.size === 1) + assert(bNeighbors.exists(_.desc.a === b)) //there should be an edge b -- c + assert(bNeighbors.exists(_.desc.b === c)) + } + + test("there can be multiple edges between the same vertices") { + + val graph = makeTestGraph() + + // A --> B , A --> D + assert(graph.edgesOf(a).size == 2) + + //now add a new edge a -> b but with a different channel update and a different ShortChannelId + val newEdgeForNewChannel = edgeFromDesc(makeUpdate(15L, a, b, 20, 0)) + val mutatedGraph = graph.addEdge(newEdgeForNewChannel) + + assert(mutatedGraph.edgesOf(a).size == 3) + + //if the ShortChannelId is the same we replace the edge and the update, this edge have an update with a different 'feeBaseMsat' + val edgeForTheSameChannel = edgeFromDesc(makeUpdate(15L, a, b, 30, 0)) + val mutatedGraph2 = mutatedGraph.addEdge(edgeForTheSameChannel) + + assert(mutatedGraph2.edgesOf(a).size == 3) // A --> B , A --> B , A --> D + assert(mutatedGraph2.getEdgesBetween(a, b).size === 2) + + assert(mutatedGraph2.getEdge(edgeForTheSameChannel).get.update.feeBaseMsat === 30) + } + + test("remove a vertex with incoming edges and check those edges are removed too") { + val graph = makeTestGraph() + + assert(graph.vertexSet().size === 5) + assert(graph.containsVertex(e)) + assert(graph.containsEdge(descFromNodes(5, c, e))) + assert(graph.containsEdge(descFromNodes(6, b, e))) + + //E has 2 incoming edges + val withoutE = graph.removeVertex(e) + + assert(withoutE.vertexSet().size === 4) + assert(!withoutE.containsVertex(e)) + assert(!withoutE.containsEdge(descFromNodes(5, c, e))) + assert(!withoutE.containsEdge(descFromNodes(6, b, e))) + } + + def edgeFromDesc(tuple: (ChannelDesc, ChannelUpdate)): GraphEdge = GraphEdge(tuple._1, tuple._2) + + def descFromNodes(shortChannelId: Long, a: PublicKey, b: PublicKey): ChannelDesc = { + makeUpdate(shortChannelId, a, b, 0, 0)._1 + } + + def edgeFromNodes(shortChannelId: Long, a: PublicKey, b: PublicKey): GraphEdge = { + edgeFromDesc(makeUpdate(shortChannelId, a, b, 0, 0)) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 0cf6f43db..310f157e9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -19,9 +19,10 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{BinaryData, Block, Crypto} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.wire._ +import fr.acinq.eclair.router.Router.DEFAULT_AMOUNT_MSAT import fr.acinq.eclair.{ShortChannelId, randomKey} -import org.jgrapht.graph.DirectedWeightedPseudograph import org.scalatest.FunSuite import scala.util.{Failure, Success} @@ -47,7 +48,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e) + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -63,15 +64,60 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e) + val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) - Router.removeEdge(g, ChannelDesc(ShortChannelId(3L), c, d)) - val route2 = Router.findRoute(g, a, e) + val graphWithRemovedEdge = g.removeEdge(ChannelDesc(ShortChannelId(3L), c, d)) + val route2 = Router.findRoute(graphWithRemovedEdge, a, e, DEFAULT_AMOUNT_MSAT) assert(route2.map(hops2Ids) === Failure(RouteNotFound)) } + test("calculate the shortest path (hardcoded nodes)") { + + val (f, g, h, i) = ( + PublicKey("02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), //source + PublicKey("03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), + PublicKey("0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), + PublicKey("029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c") //target + ) + + val updates = List( + makeUpdate(1L, f, g, 0, 0), + makeUpdate(2L, g, h, 0, 0), + makeUpdate(3L, h, i, 0, 0), + makeUpdate(4L, f, i, 50, 0) //direct channel, more expensive + ).toMap + + val graph = makeGraph(updates) + + val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT) + assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: Nil)) + + } + + test("if there are multiple channels between the same node, select the cheapest") { + + val (f, g, h, i) = ( + PublicKey("02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), //F source + PublicKey("03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), //G + PublicKey("0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), //H + PublicKey("029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c") //I target + ) + + val updates = List( + makeUpdate(1L, f, g, 0, 0), + makeUpdate(2L, g, h, 5, 5), //expensive g -> h channel + makeUpdate(6L, g, h, 0, 0), //cheap g -> h channel + makeUpdate(3L, h, i, 0, 0) + ).toMap + + val graph = makeGraph(updates) + + val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT) + assert(route.map(hops2Ids) === Success(1 :: 6 :: 3 :: Nil)) + } + test("calculate longer but cheaper route") { val updates = List( @@ -84,7 +130,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e) + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -97,7 +143,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e) + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } @@ -111,7 +157,34 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e) + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) + assert(route.map(hops2Ids) === Failure(RouteNotFound)) + } + + test("route not found (source node not connected)") { + + val updates = List( + makeUpdate(2L, b, c, 0, 0), + makeUpdate(4L, d, e, 0, 0) + ).toMap + + val g = makeGraph(updates).addVertex(a) + + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) + assert(route.map(hops2Ids) === Failure(RouteNotFound)) + } + + test("route not found (target node not connected)") { + + val updates = List( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0), + makeUpdate(3L, c, d, 0, 0) + ).toMap + + val g = makeGraph(updates) + + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } @@ -125,10 +198,44 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e) + val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } + test("route not found (amount too high)") { + + val highAmount = DEFAULT_AMOUNT_MSAT * 10 + + val updates = List( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0, maxHtlcMsat = Some(DEFAULT_AMOUNT_MSAT)), + makeUpdate(3L, c, d, 0, 0) + ).toMap + + val g = makeGraph(updates) + + val route = Router.findRoute(g, a, d, highAmount) + assert(route.map(hops2Ids) === Failure(RouteNotFound)) + + } + + test("route not found (amount too low)") { + + val lowAmount = DEFAULT_AMOUNT_MSAT / 10 + + val updates = List( + makeUpdate(1L, a, b, 0, 0), + makeUpdate(2L, b, c, 0, 0, minHtlcMsat = DEFAULT_AMOUNT_MSAT), + makeUpdate(3L, c, d, 0, 0) + ).toMap + + val g = makeGraph(updates) + + val route = Router.findRoute(g, a, d, lowAmount) + assert(route.map(hops2Ids) === Failure(RouteNotFound)) + + } + test("route to self") { val updates = List( @@ -139,7 +246,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, a) + val route = Router.findRoute(g, a, a, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Failure(CannotRouteToSelf)) } @@ -154,7 +261,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route = Router.findRoute(g, a, b) + val route = Router.findRoute(g, a, b, DEFAULT_AMOUNT_MSAT) assert(route.map(hops2Ids) === Success(1 :: Nil)) } @@ -170,10 +277,10 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e) + val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) - val route2 = Router.findRoute(g, e, a) + val route2 = Router.findRoute(g, e, a, DEFAULT_AMOUNT_MSAT) assert(route2.map(hops2Ids) === Failure(RouteNotFound)) } @@ -210,7 +317,7 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val hops = Router.findRoute(g, a, e).get + val hops = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT).get assert(hops === Hop(a, b, uab) :: Hop(b, c, ubc) :: Hop(c, d, ucd) :: Hop(d, e, ude) :: Nil) } @@ -250,16 +357,16 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, withoutEdges = ChannelDesc(ShortChannelId(3L), c, d) :: Nil) + val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, ignoredEdges = Set(ChannelDesc(ShortChannelId(3L), c, d))) assert(route1.map(hops2Ids) === Failure(RouteNotFound)) // verify that we left the graph untouched - assert(g.containsEdge(c, d)) + assert(g.containsEdge(makeUpdate(3L, c, d, 0, 0)._1)) // c -> d assert(g.containsVertex(c)) assert(g.containsVertex(d)) // make sure we can find a route if without the blacklist - val route2 = Router.findRoute(g, a, e) + val route2 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route2.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -273,11 +380,15 @@ class RouteCalculationSpec extends FunSuite { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e) + val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) assert(route1.get.head.lastUpdate.feeBaseMsat == 10) - val route2 = Router.findRoute(g, a, e, withEdges = Map(makeUpdate(1L, a, b, 5, 5))) + val extraUpdate = makeUpdate(1L, a, b, 5, 5) + + val extraGraphEdges = Set(GraphEdge(extraUpdate._1, extraUpdate._2)) + + val route2 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, extraEdges = extraGraphEdges) assert(route2.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) assert(route2.get.head.lastUpdate.feeBaseMsat == 5) } @@ -321,9 +432,7 @@ class RouteCalculationSpec extends FunSuite { ChannelDesc(ShortChannelId(3L), c, d), ChannelDesc(ShortChannelId(8L), i, j) )) - } - } object RouteCalculationSpec { @@ -335,15 +444,25 @@ object RouteCalculationSpec { ChannelAnnouncement(DUMMY_SIG, DUMMY_SIG, DUMMY_SIG, DUMMY_SIG, "", Block.RegtestGenesisBlock.hash, ShortChannelId(shortChannelId), nodeId1, nodeId2, randomKey.publicKey, randomKey.publicKey) } - def makeUpdate(shortChannelId: Long, nodeId1: PublicKey, nodeId2: PublicKey, feeBaseMsat: Int, feeProportionalMillionth: Int): (ChannelDesc, ChannelUpdate) = - (ChannelDesc(ShortChannelId(shortChannelId), nodeId1, nodeId2) -> ChannelUpdate(DUMMY_SIG, Block.RegtestGenesisBlock.hash, ShortChannelId(shortChannelId), 0L, 0, 0, 1, 42, feeBaseMsat, feeProportionalMillionth, None)) + def makeUpdate(shortChannelId: Long, nodeId1: PublicKey, nodeId2: PublicKey, feeBaseMsat: Int, feeProportionalMillionth: Int, minHtlcMsat: Long = DEFAULT_AMOUNT_MSAT, maxHtlcMsat: Option[Long] = None): (ChannelDesc, ChannelUpdate) = + ChannelDesc(ShortChannelId(shortChannelId), nodeId1, nodeId2) -> ChannelUpdate( + signature = DUMMY_SIG, + chainHash = Block.RegtestGenesisBlock.hash, + shortChannelId = ShortChannelId(shortChannelId), + timestamp = 0L, + messageFlags = maxHtlcMsat match { + case Some(_) => 1 + case None => 0 + }, + channelFlags = 0, + cltvExpiryDelta = 0, + htlcMinimumMsat = minHtlcMsat, + feeBaseMsat = feeBaseMsat, + feeProportionalMillionths = feeProportionalMillionth, + htlcMaximumMsat = maxHtlcMsat + ) - - def makeGraph(updates: Map[ChannelDesc, ChannelUpdate]) = { - val g = new DirectedWeightedPseudograph[PublicKey, DescEdge](classOf[DescEdge]) - updates.foreach { case (d, u) => Router.addEdge(g, d, u) } - g - } + def makeGraph(updates: Map[ChannelDesc, ChannelUpdate]) = DirectedGraph().addEdges(updates.toSeq) def hops2Ids(route: Seq[Hop]) = route.map(hop => hop.lastUpdate.shortChannelId.toLong) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index d18c69ce3..81c50d97a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.router import akka.actor.Status.Failure import akka.testkit.TestProbe +import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Script.{pay2wsh, write} import fr.acinq.bitcoin.{Block, Satoshi, Transaction, TxOut} import fr.acinq.eclair.blockchain._ @@ -26,6 +27,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.{InvalidSignature, PeerRoutingMessage} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Announcements.makeChannelUpdate +import fr.acinq.eclair.router.Router.DEFAULT_AMOUNT_MSAT import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire.QueryShortChannelIds import fr.acinq.eclair.{Globals, ShortChannelId, randomKey} @@ -144,7 +146,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(a, f)) + sender.send(router, RouteRequest(a, f, DEFAULT_AMOUNT_MSAT)) sender.expectMsg(Failure(RouteNotFound)) } @@ -152,7 +154,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(randomKey.publicKey, f)) + sender.send(router, RouteRequest(randomKey.publicKey, f, DEFAULT_AMOUNT_MSAT)) sender.expectMsg(Failure(RouteNotFound)) } @@ -160,14 +162,14 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val sender = TestProbe() // no route a->f - sender.send(router, RouteRequest(a, randomKey.publicKey)) + sender.send(router, RouteRequest(a, randomKey.publicKey, DEFAULT_AMOUNT_MSAT)) sender.expectMsg(Failure(RouteNotFound)) } test("route found") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) val res = sender.expectMsgType[RouteResponse] assert(res.hops.map(_.nodeId).toList === a :: b :: c :: Nil) assert(res.hops.last.nextNodeId === d) @@ -176,13 +178,13 @@ class RouterSpec extends BaseRouterSpec { test("route found (with extra routing info)") { fixture => import fixture._ val sender = TestProbe() - val x = randomKey.publicKey - val y = randomKey.publicKey - val z = randomKey.publicKey + val x = PublicKey("02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73") + val y = PublicKey("03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a") + val z = PublicKey("0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484") val extraHop_cx = ExtraHop(c, ShortChannelId(1), 10, 11, 12) val extraHop_xy = ExtraHop(x, ShortChannelId(2), 10, 11, 12) val extraHop_yz = ExtraHop(y, ShortChannelId(3), 20, 21, 22) - sender.send(router, RouteRequest(a, z, assistedRoutes = Seq(extraHop_cx :: extraHop_xy :: extraHop_yz :: Nil))) + sender.send(router, RouteRequest(a, z, DEFAULT_AMOUNT_MSAT, assistedRoutes = Seq(extraHop_cx :: extraHop_xy :: extraHop_yz :: Nil))) val res = sender.expectMsgType[RouteResponse] assert(res.hops.map(_.nodeId).toList === a :: b :: c :: x :: y :: Nil) assert(res.hops.last.nextNodeId === z) @@ -191,7 +193,7 @@ class RouterSpec extends BaseRouterSpec { test("route not found (channel disabled)") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) val res = sender.expectMsgType[RouteResponse] assert(res.hops.map(_.nodeId).toList === a :: b :: c :: Nil) assert(res.hops.last.nextNodeId === d) @@ -199,44 +201,29 @@ class RouterSpec extends BaseRouterSpec { val channelUpdate_cd1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, cltvExpiryDelta = 3, 0, feeBaseMsat = 153000, feeProportionalMillionths = 4, htlcMaximumMsat = 500000000L, enable = false) sender.send(router, PeerRoutingMessage(null, remoteNodeId, channelUpdate_cd1)) sender.expectMsg(TransportHandler.ReadAck(channelUpdate_cd1)) - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) sender.expectMsg(Failure(RouteNotFound)) } test("temporary channel exclusion") { fixture => import fixture._ val sender = TestProbe() - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) sender.expectMsgType[RouteResponse] val bc = ChannelDesc(channelId_bc, b, c) // let's exclude channel b->c sender.send(router, ExcludeChannel(bc)) - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) sender.expectMsg(Failure(RouteNotFound)) // note that cb is still available! - sender.send(router, RouteRequest(d, a)) + sender.send(router, RouteRequest(d, a, DEFAULT_AMOUNT_MSAT)) sender.expectMsgType[RouteResponse] // let's remove the exclusion sender.send(router, LiftChannelExclusion(bc)) - sender.send(router, RouteRequest(a, d)) + sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT)) sender.expectMsgType[RouteResponse] } - test("export graph in dot format") { fixture => - import fixture._ - val sender = TestProbe() - sender.send(router, 'dot) - val dot = sender.expectMsgType[String] - /*Files.write(dot.getBytes(), new File("graph.dot")) - - import scala.sys.process._ - val input = new ByteArrayInputStream(dot.getBytes) - val output = new ByteArrayOutputStream() - "dot -Tpng" #< input #> output ! - val img = output.toByteArray - Files.write(img, new File("graph.png"))*/ - } - test("send routing state") { fixture => import fixture._ val sender = TestProbe() diff --git a/eclair-node-gui/src/main/resources/gui/main/main.fxml b/eclair-node-gui/src/main/resources/gui/main/main.fxml index ce287f83c..ab0d0d333 100644 --- a/eclair-node-gui/src/main/resources/gui/main/main.fxml +++ b/eclair-node-gui/src/main/resources/gui/main/main.fxml @@ -292,12 +292,6 @@ - - - - - diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala index e240c2e31..78c8bc6c2 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala @@ -107,21 +107,6 @@ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionConte res <- (kit.paymentHandler ? ReceivePayment(amountMsat_opt, description)).mapTo[PaymentRequest].map(PaymentRequest.write) } yield res - def exportToDot(file: File) = for { - kit <- fKit - dot <- (kit.router ? 'dot).mapTo[String] - _ = printToFile(file)(writer => writer.write(dot)) - } yield {} - - private def printToFile(f: java.io.File)(op: java.io.FileWriter => Unit) { - val p = new FileWriter(f) - try { - op(p) - } finally { - p.close - } - } - /** * Displays a system notification if the system supports it. * diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/MainController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/MainController.scala index 628d26328..ae6329d9d 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/MainController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/MainController.scala @@ -503,14 +503,6 @@ class MainController(val handlers: Handlers, val hostServices: HostServices) ext row } - @FXML def handleExportDot() = { - val fileChooser = new FileChooser - fileChooser.setTitle("Save as") - fileChooser.getExtensionFilters.addAll(new ExtensionFilter("DOT File (*.dot)", "*.dot")) - val file = fileChooser.showSaveDialog(getWindow.orNull) - if (file != null) handlers.exportToDot(file) - } - @FXML def handleOpenChannel() = { val openChannelStage = new OpenChannelStage(handlers) openChannelStage.initOwner(getWindow.orNull)