diff --git a/eclair-demo/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-demo/src/main/scala/fr/acinq/eclair/api/Service.scala index 9cf6d76c8..2129f2694 100644 --- a/eclair-demo/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-demo/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -23,6 +23,7 @@ import scala.concurrent.duration._ import scala.util.{Failure, Success} import akka.pattern.ask import fr.acinq.eclair.channel.Register.ListChannels +import fr.acinq.eclair.channel.Router.CreatePayment /** * Created by PM on 25/01/2016. @@ -66,20 +67,12 @@ trait Service extends Logging { .flatMap(l => Future.sequence(l.map(c => c ? CMD_GETINFO))) case JsonRPCBody(_, _, "network", _) => (router ? 'network).mapTo[Iterable[channel_desc]] - case JsonRPCBody(_, _, "addhtlc", JInt(amount) :: JString(rhash) :: JInt(expiry) :: tail) => - val nodeIds = tail.map { - case JString(nodeId) => nodeId - } - Boot.system.actorSelection(Register.actorPathToNodeId(nodeIds.head)) - .resolveOne(2 seconds) - .map { channel => - channel ! CMD_ADD_HTLC(amount.toInt, BinaryData(rhash), locktime(Seconds(expiry.toInt)), nodeIds.drop(1)) - channel.toString() - } + case JsonRPCBody(_, _, "addhtlc", JInt(amount) :: JString(rhash) :: JString(nodeId) :: Nil) => + (router ? CreatePayment(amount.toInt, BinaryData(rhash), BinaryData(nodeId))).mapTo[ActorRef] case JsonRPCBody(_, _, "sign", JString(channel) :: Nil) => sendCommand(channel, CMD_SIGN) case JsonRPCBody(_, _, "fulfillhtlc", JString(channel) :: JDouble(id) :: JString(r) :: Nil) => - sendCommand(channel, CMD_FULFILL_HTLC(id.toLong, BinaryData(r))) + sendCommand(channel, CMD_FULFILL_HTLC(id.toLong, BinaryData(r), commit = true)) case JsonRPCBody(_, _, "close", JString(channel) :: JString(scriptPubKey) :: Nil) => sendCommand(channel, CMD_CLOSE(Some(scriptPubKey))) case JsonRPCBody(_, _, "help", _) => diff --git a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Channel.scala index c21ba5e2a..c7a244483 100644 --- a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -268,13 +268,14 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann when(NORMAL) { - case Event(CMD_ADD_HTLC(amount, rHash, expiry, nodeIds, origin, id_opt), d@DATA_NORMAL(_, _, _, htlcIdx, _, _, ourChanges, _, _, _, downstreams)) => + case Event(CMD_ADD_HTLC(amount, rHash, expiry, nodeIds, origin, id_opt, commit), d@DATA_NORMAL(_, _, _, htlcIdx, _, _, ourChanges, _, _, _, downstreams)) => // TODO: should we take pending htlcs into account? // TODO: assert(commitment.state.commit_changes(staged).us.pay_msat >= amount, "insufficient funds!") // TODO: nodeIds are ignored val id: Long = id_opt.getOrElse(htlcIdx + 1) - val htlc = update_add_htlc(id, amount, rHash, expiry, routing(ByteString.EMPTY)) + val htlc = update_add_htlc(id, amount, rHash, expiry, routing(ByteString.copyFromUtf8(nodeIds.mkString(",")))) them ! htlc + if (commit) self ! CMD_SIGN stay using d.copy(htlcIdx = htlc.id, ourChanges = ourChanges.copy(proposed = ourChanges.proposed :+ htlc), downstreams = downstreams + (htlc.id -> origin)) case Event(htlc@update_add_htlc(htlcId, amount, rHash, expiry, nodeIds), d@DATA_NORMAL(_, _, _, _, _, _, _, theirChanges, _, _, _)) => @@ -283,11 +284,12 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann // TODO: nodeIds are ignored stay using d.copy(theirChanges = theirChanges.copy(proposed = theirChanges.proposed :+ htlc)) - case Event(CMD_FULFILL_HTLC(id, r), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _, _)) => + case Event(CMD_FULFILL_HTLC(id, r, commit), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _, _)) => theirChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match { case Some(htlc) if htlc.rHash == bin2sha256(Crypto.sha256(r)) => val fulfill = update_fulfill_htlc(id, r) them ! fulfill + if (commit) self ! CMD_SIGN stay using d.copy(ourChanges = ourChanges.copy(proposed = ourChanges.proposed :+ fulfill)) case Some(htlc) => throw new RuntimeException(s"invalid htlc preimage for htlc $id") case None => throw new RuntimeException(s"unknown htlc id=$id") @@ -306,7 +308,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann .onComplete { case Success(downstream) => log.info(s"forwarding r value to downstream=$downstream") - downstream ! CMD_FULFILL_HTLC(id, r) + downstream ! CMD_FULFILL_HTLC(id, r, commit = true) case Failure(t: Throwable) => log.warning(s"couldn't resolve downstream node, htlc #${htlc.id} will timeout", t) } @@ -318,11 +320,12 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann case None => throw new RuntimeException(s"unknown htlc id=$id") // TODO : we should fail the channel } - case Event(CMD_FAIL_HTLC(id, reason), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _, _)) => + case Event(CMD_FAIL_HTLC(id, reason, commit), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _, _)) => theirChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match { case Some(htlc) => val fail = update_fail_htlc(id, fail_reason(ByteString.copyFromUtf8(reason))) them ! fail + if (commit) self ! CMD_SIGN stay using d.copy(ourChanges = ourChanges.copy(proposed = ourChanges.proposed :+ fail)) case None => throw new RuntimeException(s"unknown htlc id=$id") } @@ -339,7 +342,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann .onComplete { case Success(downstream) => log.info(s"forwarding fail to downstream=$downstream") - downstream ! CMD_FAIL_HTLC(id, reason.info.toStringUtf8) + downstream ! CMD_FAIL_HTLC(id, reason.info.toStringUtf8, commit =true) case Failure(t: Throwable) => log.warning(s"couldn't resolve downstream node, htlc #${htlc.id} will timeout", t) } @@ -384,7 +387,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann them ! update_revocation(ourRevocationPreimage, ourNextRevocationHash) // now that we have their sig, we should propagate the htlcs newly received (spec.htlcs_in -- ourCommit.spec.htlcs_in).foreach(htlc => { - val nextNodeIds = htlc.route.info.toStringUtf8.split(",").toSeq.filterNot(_.isEmpty) + val nextNodeIds = htlc.route.info.toStringUtf8.split(",").toSeq.filterNot(_.isEmpty).map(BinaryData(_)) nextNodeIds.headOption match { case Some(nextNodeId) => log.debug(s"propagating htlc #${htlc.id} to $nextNodeId") @@ -395,7 +398,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann case Success(upstream) => log.info(s"forwarding htlc #${htlc.id} to upstream=$upstream") // TODO : we should decrement expiry !! - upstream ! CMD_ADD_HTLC(htlc.amountMsat, htlc.rHash, htlc.expiry, nextNodeIds.drop(1), Some(d.anchorId)) + upstream ! CMD_ADD_HTLC(htlc.amountMsat, htlc.rHash, htlc.expiry, nextNodeIds.drop(1), Some(d.anchorId), commit = true) case Failure(t: Throwable) => // TODO : send "fail route error" log.warning(s"couldn't resolve upstream node, htlc #${htlc.id} will timeout", t) diff --git a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala index 0ace1da8c..9087ec201 100644 --- a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala +++ b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala @@ -93,9 +93,9 @@ sealed trait Command * @param originChannelId * @param id should only be provided in tests otherwise it will be assigned automatically */ -final case class CMD_ADD_HTLC(amount: Int, rHash: sha256_hash, expiry: locktime, nodeIds: Seq[String] = Seq.empty[String], originChannelId: Option[BinaryData] = None, id: Option[Long] = None) extends Command -final case class CMD_FULFILL_HTLC(id: Long, r: sha256_hash) extends Command -final case class CMD_FAIL_HTLC(id: Long, reason: String) extends Command +final case class CMD_ADD_HTLC(amount: Int, rHash: sha256_hash, expiry: locktime, nodeIds: Seq[BinaryData] = Seq.empty[BinaryData], originChannelId: Option[BinaryData] = None, id: Option[Long] = None, commit: Boolean = false) extends Command +final case class CMD_FULFILL_HTLC(id: Long, r: sha256_hash, commit: Boolean = false) extends Command +final case class CMD_FAIL_HTLC(id: Long, reason: String, commit: Boolean = false) extends Command case object CMD_SIGN extends Command final case class CMD_CLOSE(scriptPubKey: Option[BinaryData]) extends Command case object CMD_GETSTATE extends Command diff --git a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Router.scala b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Router.scala index dbff1c9f0..027775850 100644 --- a/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Router.scala +++ b/eclair-demo/src/main/scala/fr/acinq/eclair/channel/Router.scala @@ -2,9 +2,13 @@ package fr.acinq.eclair.channel import akka.actor.{Actor, ActorLogging} import fr.acinq.bitcoin.BinaryData +import fr.acinq.eclair.{Boot, Globals} +import fr.acinq.eclair._ import lightning._ +import lightning.locktime.Locktime.{Blocks} -import scala.concurrent.ExecutionContext +import scala.annotation.tailrec +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ /** @@ -13,6 +17,8 @@ import scala.concurrent.duration._ class Router extends Actor with ActorLogging { import ExecutionContext.Implicits.global + import Router._ + context.system.scheduler.schedule(5 seconds, 10 seconds, self, 'tick) def receive: Receive = main(Map()) @@ -25,6 +31,38 @@ class Router extends Actor with ActorLogging { channels.values.foreach(sel ! register_channel(_)) case 'network => sender ! channels.values + case c: CreatePayment => + val s = sender + findRoute(Globals.Node.publicKey, c.targetNodeId, channels).map(route => { + Boot.system.actorSelection(Register.actorPathToNodeId(route.head)) + .resolveOne(2 seconds) + .map { channel => + // TODO : expiry is not correctly calculated + channel ! CMD_ADD_HTLC(c.amount, c.h, locktime(Blocks(route.size - 1)), route.drop(1), commit = true) + s ! channel + } + }) } } + +object Router { + + // @formatter:off + case class CreatePayment(amount: Int, h: sha256_hash, targetNodeId: BinaryData) + // @formatter:on + + @tailrec + def findRoute(myNodeId: BinaryData, targetNodeId: BinaryData, channels: Map[BinaryData, channel_desc], route: Seq[BinaryData]): Seq[BinaryData] = { + channels.values.map(c => (c.nodeIdA: BinaryData, c.nodeIdB: BinaryData) ::(c.nodeIdB: BinaryData, c.nodeIdA: BinaryData) :: Nil).flatten.find(_._1 == targetNodeId) match { + case Some((_, previous)) if previous == myNodeId => targetNodeId +: route + case Some((_, previous)) => findRoute(myNodeId, previous, channels, targetNodeId +: route) + case None => throw new RuntimeException(s"cannot find route to $targetNodeId") + } + } + + def findRoute(myNodeId: BinaryData, targetNodeId: BinaryData, channels: Map[BinaryData, channel_desc])(implicit ec: ExecutionContext): Future[Seq[BinaryData]] = Future { + findRoute(myNodeId, targetNodeId, channels, Seq()) + } + +} \ No newline at end of file diff --git a/eclair-demo/src/test/scala/fr/acinq/eclair/RouterSpec.scala b/eclair-demo/src/test/scala/fr/acinq/eclair/RouterSpec.scala new file mode 100644 index 000000000..85b59af5a --- /dev/null +++ b/eclair-demo/src/test/scala/fr/acinq/eclair/RouterSpec.scala @@ -0,0 +1,45 @@ +package fr.acinq.eclair + +import com.google.protobuf.ByteString +import fr.acinq.bitcoin.BinaryData +import fr.acinq.eclair.channel.Router +import lightning.channel_desc +import org.junit.runner.RunWith +import org.scalatest.FunSuite +import org.scalatest.junit.JUnitRunner + +/** + * Created by PM on 31/05/2016. + */ +@RunWith(classOf[JUnitRunner]) +class RouterSpec extends FunSuite { + + test("calculate simple route") { + + val channels: Map[BinaryData, channel_desc] = Map( + BinaryData("0a") -> channel_desc(BinaryData("0a"), BinaryData("01"), BinaryData("02")), + BinaryData("0b") -> channel_desc(BinaryData("0b"), BinaryData("03"), BinaryData("02")), + BinaryData("0c") -> channel_desc(BinaryData("0c"), BinaryData("03"), BinaryData("04")), + BinaryData("0d") -> channel_desc(BinaryData("0d"), BinaryData("04"), BinaryData("05")) + ) + + val route = Router.findRoute(BinaryData("01"), BinaryData("05"), channels, Seq()) + + assert(route === BinaryData("02") :: BinaryData("03") :: BinaryData("04") :: BinaryData("05") :: Nil) + + } + + test("calculate simple route 2") { + + val channels: Map[BinaryData, channel_desc] = Map( + BinaryData("99e542c274b073d215af02e57f814f3d16a2373a00ac52b49ef4a1949c912609") -> channel_desc(BinaryData("99e542c274b073d215af02e57f814f3d16a2373a00ac52b49ef4a1949c912609"), BinaryData("032b2e37d202658eb5216a698e52da665c25c5d04de0faf1d29aa2af7fb374a003"), BinaryData("0382887856e9f10a8a1ffade96b4009769141e5f3692f2ffc35fd4221f6057643b")), + BinaryData("44d7f822e0498e21473a8a40045c9b7e7bd2e78730b5274cb5836e64bc0b6125") -> channel_desc(BinaryData("44d7f822e0498e21473a8a40045c9b7e7bd2e78730b5274cb5836e64bc0b6125"), BinaryData("023cda4e9506ce0a5fd3e156fc6d1bff16873375c8e823ee18aa36fa6844c0ae61"), BinaryData("0382887856e9f10a8a1ffade96b4009769141e5f3692f2ffc35fd4221f6057643b")) + ) + + val route = Router.findRoute(BinaryData("032b2e37d202658eb5216a698e52da665c25c5d04de0faf1d29aa2af7fb374a003"), BinaryData("023cda4e9506ce0a5fd3e156fc6d1bff16873375c8e823ee18aa36fa6844c0ae61"), channels, Seq()) + + assert(route === BinaryData("0382887856e9f10a8a1ffade96b4009769141e5f3692f2ffc35fd4221f6057643b") :: BinaryData("023cda4e9506ce0a5fd3e156fc6d1bff16873375c8e823ee18aa36fa6844c0ae61") :: Nil) + + } + +}