diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 6f9c9e838..7c425328c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.Register.{Forward, ForwardShortId} import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats} import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo} -import fr.acinq.eclair.io.{NodeURI, Peer} +import fr.acinq.eclair.io.{NodeURI, Peer, Switchboard} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.{ChannelDesc, RouteRequest, RouteResponse, Router} import scodec.bits.ByteVector @@ -55,7 +55,9 @@ object TimestampQueryFilters { trait Eclair { - def connect(uri: String)(implicit timeout: Timeout): Future[String] + def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] + + def disconnect(nodeId: PublicKey)(implicit timeout: Timeout): Future[String] def open(nodeId: PublicKey, fundingSatoshis: Long, pushMsat_opt: Option[Long], fundingFeerateSatByte_opt: Option[Long], flags_opt: Option[Int], openTimeout_opt: Option[Timeout])(implicit timeout: Timeout): Future[String] @@ -109,8 +111,13 @@ class EclairImpl(appKit: Kit) extends Eclair { implicit val ec = appKit.system.dispatcher - override def connect(uri: String)(implicit timeout: Timeout): Future[String] = { - (appKit.switchboard ? Peer.Connect(NodeURI.parse(uri))).mapTo[String] + override def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] = target match { + case Left(uri) => (appKit.switchboard ? Peer.Connect(uri)).mapTo[String] + case Right(pubKey) => (appKit.switchboard ? Peer.Connect(pubKey, None)).mapTo[String] + } + + override def disconnect(nodeId: PublicKey)(implicit timeout: Timeout): Future[String] = { + (appKit.switchboard ? Peer.Disconnect(nodeId)).mapTo[String] } override def open(nodeId: PublicKey, fundingSatoshis: Long, pushMsat_opt: Option[Long], fundingFeerateSatByte_opt: Option[Long], flags_opt: Option[Int], openTimeout_opt: Option[Timeout])(implicit timeout: Timeout): Future[String] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala b/eclair-core/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala index a9e24f544..17425cf36 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala @@ -24,9 +24,9 @@ import akka.util.Timeout import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.payment.PaymentRequest import scodec.bits.ByteVector - import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -60,6 +60,10 @@ object FormParamExtractors { Timeout(str.toInt.seconds) } + implicit val nodeURIUnmarshaller: Unmarshaller[String, NodeURI] = Unmarshaller.strict { str => + NodeURI.parse(str) + } + implicit val pubkeyListUnmarshaller: Unmarshaller[String, List[PublicKey]] = Unmarshaller.strict { str => Try(serialization.read[List[String]](str).map { el => PublicKey(ByteVector.fromValidHex(el), checkValid = false) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala index b8d633ada..f816aa8be 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -30,6 +30,7 @@ import akka.http.scaladsl.server.directives.Credentials import akka.stream.scaladsl.{BroadcastHub, Flow, Keep, Source} import akka.stream.{ActorMaterializer, OverflowStrategy} import akka.util.Timeout +import com.google.common.net.HostAndPort import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.api.FormParamExtractors._ @@ -41,6 +42,7 @@ import fr.acinq.eclair.{Eclair, ShortChannelId} import grizzled.slf4j.Logging import org.json4s.jackson.Serialization import scodec.bits.ByteVector + import scala.concurrent.Future import scala.concurrent.duration._ @@ -135,10 +137,17 @@ trait Service extends ExtraDirectives with Logging { complete(eclairApi.getInfoResponse()) } ~ path("connect") { - formFields("uri".as[String]) { uri => - complete(eclairApi.connect(uri)) + formFields("uri".as[NodeURI]) { uri => + complete(eclairApi.connect(Left(uri))) } ~ formFields(nodeIdFormParam, "host".as[String], "port".as[Int].?) { (nodeId, host, port_opt) => - complete(eclairApi.connect(s"$nodeId@$host:${port_opt.getOrElse(NodeURI.DEFAULT_PORT)}")) + complete(eclairApi.connect(Left(NodeURI(nodeId, HostAndPort.fromParts(host, port_opt.getOrElse(NodeURI.DEFAULT_PORT)))))) + } ~ formFields(nodeIdFormParam) { nodeId => + complete(eclairApi.connect(Right(nodeId))) + } + } ~ + path("disconnect") { + formFields(nodeIdFormParam) { nodeId => + complete(eclairApi.disconnect(nodeId)) } } ~ path("open") { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala index 1387170a0..546516785 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala @@ -27,6 +27,8 @@ trait NetworkDb { def updateNode(n: NodeAnnouncement) + def getNode(nodeId: PublicKey): Option[NodeAnnouncement] + def removeNode(nodeId: PublicKey) def listNodes(): Seq[NodeAnnouncement] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index 51e29cd81..4bca51ee2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -59,6 +59,14 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb { } } + override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = { + using(sqlite.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement => + statement.setBytes(1, nodeId.toBin.toArray) + val rs = statement.executeQuery() + codecSequence(rs, nodeAnnouncementCodec).headOption + } + } + override def removeNode(nodeId: Crypto.PublicKey): Unit = { using(sqlite.prepareStatement("DELETE FROM nodes WHERE node_id=?")) { statement => statement.setBytes(1, nodeId.toBin.toArray) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index eadf81caa..020ad791a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -23,18 +23,17 @@ import java.nio.ByteOrder import akka.actor.{ActorRef, FSM, OneForOneStrategy, PoisonPill, Props, Status, SupervisorStrategy, Terminated} import akka.event.Logging.MDC import akka.util.Timeout +import com.google.common.net.HostAndPort import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, DeterministicWallet, MilliSatoshi, Protocol, Satoshi} import fr.acinq.eclair.blockchain.EclairWallet import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler -import fr.acinq.eclair.secureRandom import fr.acinq.eclair.router._ import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{wire, _} +import fr.acinq.eclair.{secureRandom, wire, _} import scodec.Attempt import scodec.bits.ByteVector - import scala.compat.Platform import scala.concurrent.duration._ import scala.util.Random @@ -59,26 +58,34 @@ class Peer(nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: Actor } when(DISCONNECTED) { - case Event(Peer.Connect(NodeURI(_, hostAndPort)), d: DisconnectedData) => - val address = new InetSocketAddress(hostAndPort.getHost, hostAndPort.getPort) - if (d.address_opt.contains(address)) { - // we already know this address, we'll reconnect automatically - sender ! "reconnection in progress" - stay - } else { - // we immediately process explicit connection requests to new addresses - context.actorOf(Client.props(nodeParams, authenticator, address, remoteNodeId, origin_opt = Some(sender()))) - stay + case Event(Peer.Connect(_, address_opt), d: DisconnectedData) => + address_opt + .map(hostAndPort2InetSocketAddress) + .orElse(getPeerAddressFromNodeAnnouncement) match { + case None => + sender ! "no address found" + stay + case Some(address) => + if (d.address_opt.contains(address)) { + // we already know this address, we'll reconnect automatically + sender ! "reconnection in progress" + stay + } else { + // we immediately process explicit connection requests to new addresses + context.actorOf(Client.props(nodeParams, authenticator, address, remoteNodeId, origin_opt = Some(sender()))) + stay using d.copy(address_opt = Some(address)) + } } case Event(Reconnect, d: DisconnectedData) => - d.address_opt match { - case None => stay // no-op (this peer didn't initiate the connection and doesn't have the ip of the counterparty) - case _ if d.channels.isEmpty => stay // no-op (no more channels with this peer) + d.address_opt.orElse(getPeerAddressFromNodeAnnouncement) match { + case _ if d.channels.isEmpty => stay // no-op, no more channels with this peer + case None => stay // no-op, we don't know any address to this peer and we won't try reconnecting again case Some(address) => context.actorOf(Client.props(nodeParams, authenticator, address, remoteNodeId, origin_opt = None)) + log.info(s"reconnecting to $address") // exponential backoff retry with a finite max - setTimer(RECONNECT_TIMER, Reconnect, Math.min(10 + Math.pow(2, d.attempts), 60) seconds, repeat = false) + setTimer(RECONNECT_TIMER, Reconnect, Math.min(10 + Math.pow(2, d.attempts), 3600) seconds, repeat = false) stay using d.copy(attempts = d.attempts + 1) } @@ -177,6 +184,13 @@ class Peer(nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: Actor } else { stay using d.copy(channels = channels1) } + + case Event(Disconnect(nodeId), d: InitializingData) if nodeId == remoteNodeId => + log.info("disconnecting") + sender ! "disconnecting" + d.transport ! PoisonPill + stay + } when(CONNECTED) { @@ -411,7 +425,9 @@ class Peer(nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: Actor log.info(s"resuming processing of network announcements for peer") stay using d.copy(behavior = d.behavior.copy(fundingTxAlreadySpentCount = 0, ignoreNetworkAnnouncement = false)) - case Event(Disconnect, d: ConnectedData) => + case Event(Disconnect(nodeId), d: ConnectedData) if nodeId == remoteNodeId => + log.info(s"disconnecting") + sender ! "disconnecting" d.transport ! PoisonPill stay @@ -478,8 +494,8 @@ class Peer(nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: Actor onTransition { case INSTANTIATING -> DISCONNECTED if nodeParams.autoReconnect && nextStateData.address_opt.isDefined => self ! Reconnect // we reconnect right away if we just started the peer - case _ -> DISCONNECTED if nodeParams.autoReconnect && nextStateData.address_opt.isDefined => setTimer(RECONNECT_TIMER, Reconnect, 1 second, repeat = false) - case DISCONNECTED -> _ if nodeParams.autoReconnect && stateData.address_opt.isDefined => cancelTimer(RECONNECT_TIMER) + case _ -> DISCONNECTED if nodeParams.autoReconnect => setTimer(RECONNECT_TIMER, Reconnect, 1 second, repeat = false) + case DISCONNECTED -> _ if nodeParams.autoReconnect => cancelTimer(RECONNECT_TIMER) } def createNewChannel(nodeParams: NodeParams, funder: Boolean, fundingSatoshis: Long, origin_opt: Option[ActorRef]): (ActorRef, LocalParams) = { @@ -501,6 +517,11 @@ class Peer(nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: Actor stop(FSM.Normal) } + // TODO gets the first of the list, improve selection? + def getPeerAddressFromNodeAnnouncement: Option[InetSocketAddress] = { + nodeParams.db.network.getNode(remoteNodeId).flatMap(_.addresses.headOption.map(_.socketAddress)) + } + // a failing channel won't be restarted, it should handle its states override val supervisorStrategy = OneForOneStrategy(loggingEnabled = true) { case _ => SupervisorStrategy.Stop } @@ -549,9 +570,14 @@ object Peer { case object CONNECTED extends State case class Init(previousKnownAddress: Option[InetSocketAddress], storedChannels: Set[HasCommitments]) - case class Connect(uri: NodeURI) + case class Connect(nodeId: PublicKey, address_opt: Option[HostAndPort]) { + def uri: Option[NodeURI] = address_opt.map(NodeURI(nodeId, _)) + } + object Connect { + def apply(uri: NodeURI): Connect = new Connect(uri.nodeId, Some(uri.address)) + } case object Reconnect - case object Disconnect + case class Disconnect(nodeId: PublicKey) case object ResumeAnnouncements case class OpenChannel(remoteNodeId: PublicKey, fundingSatoshis: Satoshi, pushMsat: MilliSatoshi, fundingTxFeeratePerKw_opt: Option[Long], channelFlags: Option[Byte], timeout_opt: Option[Timeout]) { require(fundingSatoshis.amount < Channel.MAX_FUNDING_SATOSHIS, s"fundingSatoshis must be less than ${Channel.MAX_FUNDING_SATOSHIS}") @@ -617,4 +643,6 @@ object Peer { case _ => true // if there is a filter and message doesn't have a timestamp (e.g. channel_announcement), then we send it } } + + def hostAndPort2InetSocketAddress(hostAndPort: HostAndPort): InetSocketAddress = new InetSocketAddress(hostAndPort.getHost, hostAndPort.getPort) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index 2abc2f008..56e918cdd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -65,7 +65,11 @@ class Switchboard(nodeParams: NodeParams, authenticator: ActorRef, watcher: Acto channels .groupBy(_.commitments.remoteParams.nodeId) .map { - case (remoteNodeId, states) => (remoteNodeId, states, peers.get(remoteNodeId)) + case (remoteNodeId, states) => + val address_opt = peers.get(remoteNodeId).orElse { + nodeParams.db.network.getNode(remoteNodeId).flatMap(_.addresses.headOption) // gets the first of the list! TODO improve selection? + } + (remoteNodeId, states, address_opt) } .foreach { case (remoteNodeId, states, nodeaddress_opt) => @@ -77,14 +81,20 @@ class Switchboard(nodeParams: NodeParams, authenticator: ActorRef, watcher: Acto def receive: Receive = { - case Peer.Connect(NodeURI(publicKey, _)) if publicKey == nodeParams.nodeId => + case Peer.Connect(publicKey, _) if publicKey == nodeParams.nodeId => sender ! Status.Failure(new RuntimeException("cannot open connection with oneself")) case c: Peer.Connect => // we create a peer if it doesn't exist - val peer = createOrGetPeer(c.uri.nodeId, previousKnownAddress = None, offlineChannels = Set.empty) + val peer = createOrGetPeer(c.nodeId, previousKnownAddress = None, offlineChannels = Set.empty) peer forward c + case d: Peer.Disconnect => + getPeer(d.nodeId) match { + case Some(peer) => peer forward d + case None => sender ! Status.Failure(new RuntimeException("peer not found")) + } + case o: Peer.OpenChannel => getPeer(o.remoteNodeId) match { case Some(peer) => peer forward o diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestkitBaseClass.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestkitBaseClass.scala index 949de4793..6afa9e45f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestkitBaseClass.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestkitBaseClass.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair import akka.actor.{ActorNotFound, ActorSystem, PoisonPill} import akka.testkit.TestKit +import com.typesafe.config.ConfigFactory import fr.acinq.eclair.blockchain.fee.FeeratesPerKw import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, fixture} @@ -47,4 +48,4 @@ abstract class TestkitBaseClass extends TestKit(ActorSystem("test")) with fixtur Globals.feeratesPerKw.set(FeeratesPerKw.single(1)) } -} +} \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 60c79cd29..abe4b102a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -30,6 +30,9 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, Crypto, MilliSatoshi} import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair._ +import fr.acinq.eclair.channel.RES_GETINFO +import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats} +import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.io.Peer.PeerInfo import fr.acinq.eclair.payment.PaymentLifecycle.PaymentFailed import fr.acinq.eclair.payment._ @@ -204,35 +207,35 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'connect' method should accept an URI and a triple with nodeId/host/port") { - val remoteNodeId = "030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87" - val remoteHost = "93.137.102.239" - val remoteUri = "030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87@93.137.102.239:9735" + val remoteNodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87") + val remoteUri = NodeURI.parse("030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87@93.137.102.239:9735") val eclair = mock[Eclair] - eclair.connect(any[String])(any[Timeout]) returns Future.successful("connected") + eclair.connect(any[Either[NodeURI, PublicKey]])(any[Timeout]) returns Future.successful("connected") val mockService = new MockService(eclair) - Post("/connect", FormData("nodeId" -> remoteNodeId, "host" -> remoteHost).toEntity) ~> + Post("/connect", FormData("nodeId" -> remoteNodeId.toHex).toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> Route.seal(mockService.route) ~> check { assert(handled) assert(status == OK) assert(entityAs[String] == "\"connected\"") - eclair.connect(remoteUri)(any[Timeout]).wasCalled(once) + eclair.connect(Right(remoteNodeId))(any[Timeout]).wasCalled(once) } - Post("/connect", FormData("uri" -> remoteUri).toEntity) ~> + Post("/connect", FormData("uri" -> remoteUri.toString).toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> Route.seal(mockService.route) ~> check { assert(handled) assert(status == OK) assert(entityAs[String] == "\"connected\"") - eclair.connect(remoteUri)(any[Timeout]).wasCalled(twice) // must account for the previous, identical, invocation + eclair.connect(Left(remoteUri))(any[Timeout]).wasCalled(once) // must account for the previous, identical, invocation } } + test("'send' method should correctly forward amount parameters to EclairImpl") { val invoice = "lnbc12580n1pw2ywztpp554ganw404sh4yjkwnysgn3wjcxfcq7gtx53gxczkjr9nlpc3hzvqdq2wpskwctddyxqr4rqrzjqwryaup9lh50kkranzgcdnn2fgvx390wgj5jd07rwr3vxeje0glc7z9rtvqqwngqqqqqqqlgqqqqqeqqjqrrt8smgjvfj7sg38dwtr9kc9gg3era9k3t2hvq3cup0jvsrtrxuplevqgfhd3rzvhulgcxj97yjuj8gdx8mllwj4wzjd8gdjhpz3lpqqvk2plh" diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala index 32dad143e..f94dcabe6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala @@ -47,6 +47,7 @@ class SqliteNetworkDbSpec extends FunSuite { assert(db.listNodes().toSet === Set.empty) db.addNode(node_1) db.addNode(node_1) // duplicate is ignored + assert(db.getNode(node_1.nodeId) == Some(node_1)) assert(db.listNodes().size === 1) db.addNode(node_2) db.addNode(node_3) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index aaed01027..d2e56fa2c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -159,9 +159,10 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService def connect(node1: Kit, node2: Kit, fundingSatoshis: Long, pushMsat: Long) = { val sender = TestProbe() val address = node2.nodeParams.publicAddresses.head - sender.send(node1.switchboard, Peer.Connect(NodeURI( + sender.send(node1.switchboard, Peer.Connect( nodeId = node2.nodeParams.nodeId, - address = HostAndPort.fromParts(address.socketAddress.getHostString, address.socketAddress.getPort)))) + address_opt = Some(HostAndPort.fromParts(address.socketAddress.getHostString, address.socketAddress.getPort)) + )) sender.expectMsgAnyOf(10 seconds, "connected", "already connected") sender.send(node1.switchboard, Peer.OpenChannel( remoteNodeId = node2.nodeParams.nodeId, @@ -485,7 +486,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("F1").switchboard, 'peers) val peers = sender.expectMsgType[Iterable[ActorRef]] // F's only node is C - peers.head ! Disconnect + peers.head ! Peer.Disconnect(nodes("C").nodeParams.nodeId) // we then wait for F to be in disconnected state awaitCond({ sender.send(nodes("F1").register, Forward(htlc.channelId, CMD_GETSTATE)) @@ -566,7 +567,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("F2").switchboard, 'peers) val peers = sender.expectMsgType[Iterable[ActorRef]] // F's only node is C - peers.head ! Disconnect + peers.head ! Disconnect(nodes("C").nodeParams.nodeId) // we then wait for F to be in disconnected state awaitCond({ sender.send(nodes("F2").register, Forward(htlc.channelId, CMD_GETSTATE)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index c5bb498ff..157c6212a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -16,12 +16,14 @@ package fr.acinq.eclair.io -import java.net.InetSocketAddress +import java.net.{Inet4Address, InetSocketAddress} -import akka.actor.ActorRef -import akka.testkit.{TestFSMRef, TestProbe} +import akka.actor.{ActorRef, ActorSystem, PoisonPill} +import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition} +import akka.testkit.{EventFilter, TestFSMRef, TestKit, TestProbe} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.TestConstants._ +import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.EclairWallet import fr.acinq.eclair.channel.HasCommitments import fr.acinq.eclair.crypto.TransportHandler @@ -29,14 +31,18 @@ import fr.acinq.eclair.db.ChannelStateSpec import fr.acinq.eclair.io.Peer._ import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo import fr.acinq.eclair.router.{ChannelRangeQueries, ChannelRangeQueriesSpec, Rebroadcast} -import fr.acinq.eclair.wire.{Error, Ping, Pong} -import fr.acinq.eclair.{ShortChannelId, TestkitBaseClass, randomBytes, wire} -import org.scalatest.Outcome +import fr.acinq.eclair.wire.LightningMessageCodecsSpec.randomSignature +import fr.acinq.eclair.wire.{Color, Error, IPv4, NodeAddress, NodeAnnouncement, Ping, Pong} +import org.scalatest.{Outcome, Tag} +import scodec.bits.ByteVector import scala.concurrent.duration._ - class PeerSpec extends TestkitBaseClass { + + def ipv4FromInet4(address: InetSocketAddress) = IPv4.apply(address.getAddress.asInstanceOf[Inet4Address], address.getPort) + + val fakeIPAddress = NodeAddress.fromParts("1.2.3.4", 42000).get val shortChannelIds = ChannelRangeQueriesSpec.shortChannelIds.take(100) val fakeRoutingInfo = shortChannelIds.map(makeFakeRoutingInfo) val channels = fakeRoutingInfo.map(_._1).toList @@ -46,6 +52,15 @@ class PeerSpec extends TestkitBaseClass { case class FixtureParam(remoteNodeId: PublicKey, authenticator: TestProbe, watcher: TestProbe, router: TestProbe, relayer: TestProbe, connection: TestProbe, transport: TestProbe, peer: TestFSMRef[Peer.State, Peer.Data, Peer]) override protected def withFixture(test: OneArgTest): Outcome = { + val aParams = Alice.nodeParams + val aliceParams = test.tags.contains("with_node_announcements") match { + case true => + val aliceAnnouncement = NodeAnnouncement(randomSignature, ByteVector.empty, 1, Bob.nodeParams.nodeId, Color(100.toByte, 200.toByte, 300.toByte), "node-alias", fakeIPAddress :: Nil) + aParams.db.network.addNode(aliceAnnouncement) + aParams + case false => aParams + } + val authenticator = TestProbe() val watcher = TestProbe() val router = TestProbe() @@ -54,7 +69,7 @@ class PeerSpec extends TestkitBaseClass { val transport = TestProbe() val wallet: EclairWallet = null // unused val remoteNodeId = Bob.nodeParams.nodeId - val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(Alice.nodeParams, remoteNodeId, authenticator.ref, watcher.ref, router.ref, relayer.ref, wallet)) + val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, authenticator.ref, watcher.ref, router.ref, relayer.ref, wallet)) withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, authenticator, watcher, router, relayer, connection, transport, peer))) } @@ -62,7 +77,7 @@ class PeerSpec extends TestkitBaseClass { // let's simulate a connection val probe = TestProbe() probe.send(peer, Peer.Init(None, channels)) - authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, new InetSocketAddress("1.2.3.4", 42000), outgoing = true, None)) + authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None)) transport.expectMsgType[TransportHandler.Listener] transport.expectMsgType[wire.Init] transport.send(peer, wire.Init(Bob.nodeParams.globalFeatures, Bob.nodeParams.localFeatures)) @@ -77,7 +92,38 @@ class PeerSpec extends TestkitBaseClass { val probe = TestProbe() connect(remoteNodeId, authenticator, watcher, router, relayer, connection, transport, peer, channels = Set(ChannelStateSpec.normal)) probe.send(peer, Peer.GetPeerInfo) - probe.expectMsg(PeerInfo(remoteNodeId, "CONNECTED", Some(new InetSocketAddress("1.2.3.4", 42000)), 1)) + probe.expectMsg(PeerInfo(remoteNodeId, "CONNECTED", Some(fakeIPAddress.socketAddress), 1)) + } + + test("fail to connect if no address provided or found") { f => + import f._ + + val probe = TestProbe() + val monitor = TestProbe() + + peer ! SubscribeTransitionCallBack(monitor.ref) + + probe.send(peer, Peer.Init(None, Set.empty)) + val CurrentState(_, INSTANTIATING) = monitor.expectMsgType[CurrentState[_]] + val Transition(_, INSTANTIATING, DISCONNECTED) = monitor.expectMsgType[Transition[_]] + probe.send(peer, Peer.Connect(remoteNodeId, address_opt = None)) + probe.expectMsg(s"no address found") + } + + test("if no address was specified during connection use the one from node_announcement", Tag("with_node_announcements")) { f => + import f._ + + val probe = TestProbe() + val monitor = TestProbe() + + peer ! SubscribeTransitionCallBack(monitor.ref) + + probe.send(peer, Peer.Init(None, Set.empty)) + val CurrentState(_, INSTANTIATING) = monitor.expectMsgType[CurrentState[_]] + val Transition(_, INSTANTIATING, DISCONNECTED) = monitor.expectMsgType[Transition[_]] + + probe.send(peer, Peer.Connect(remoteNodeId, None)) + awaitCond(peer.stateData.address_opt == Some(fakeIPAddress.socketAddress)) } test("ignore connect to same address") { f => @@ -119,7 +165,7 @@ class PeerSpec extends TestkitBaseClass { awaitCond(peer.stateData.asInstanceOf[DisconnectedData].attempts == 3) } - test("disconnect if incompatible features") {f => + test("disconnect if incompatible features") { f => import f._ val probe = TestProbe() probe.watch(transport.ref) @@ -133,6 +179,33 @@ class PeerSpec extends TestkitBaseClass { probe.expectTerminated(transport.ref) } + test("handle disconnect in status INITIALIZING") { f => + import f._ + + val probe = TestProbe() + probe.send(peer, Peer.Init(None, Set(ChannelStateSpec.normal))) + authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None)) + + probe.send(peer, Peer.GetPeerInfo) + assert(probe.expectMsgType[Peer.PeerInfo].state == "INITIALIZING") + + probe.send(peer, Peer.Disconnect(f.remoteNodeId)) + probe.expectMsg("disconnecting") + } + + test("handle disconnect in status CONNECTED") { f => + import f._ + + val probe = TestProbe() + connect(remoteNodeId, authenticator, watcher, router, relayer, connection, transport, peer, channels = Set(ChannelStateSpec.normal)) + + probe.send(peer, Peer.GetPeerInfo) + assert(probe.expectMsgType[Peer.PeerInfo].state == "CONNECTED") + + probe.send(peer, Peer.Disconnect(f.remoteNodeId)) + probe.expectMsg("disconnecting") + } + test("reply to ping") { f => import f._ val probe = TestProbe() @@ -203,7 +276,7 @@ class PeerSpec extends TestkitBaseClass { probe.send(peer, filter) probe.send(peer, rebroadcast) // peer doesn't filter channel announcements - channels.foreach(transport.expectMsg(_)) + channels.foreach(transport.expectMsg(10 seconds, _)) // but it will only send updates and node announcements matching the filter updates.filter(u => timestamps.contains(u.timestamp)).foreach(transport.expectMsg(_)) nodes.filter(u => timestamps.contains(u.timestamp)).foreach(transport.expectMsg(_)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpecWithLogging.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpecWithLogging.scala new file mode 100644 index 000000000..265ccf826 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpecWithLogging.scala @@ -0,0 +1,43 @@ +package fr.acinq.eclair.io + +import akka.actor.{ActorRef, ActorSystem} +import akka.testkit.{EventFilter, TestFSMRef, TestKit, TestProbe} +import com.typesafe.config.ConfigFactory +import fr.acinq.eclair.db.ChannelStateSpec +import org.scalatest.{FunSuiteLike, Outcome, Tag} +import scala.concurrent.duration._ +import akka.testkit.{TestFSMRef, TestProbe} +import fr.acinq.eclair.TestConstants.{Alice, Bob} +import fr.acinq.eclair.blockchain.EclairWallet +import fr.acinq.eclair.wire.LightningMessageCodecsSpec.randomSignature +import fr.acinq.eclair.wire.{Color, IPv4, NodeAddress, NodeAnnouncement} +import scodec.bits.ByteVector + +class PeerSpecWithLogging extends TestKit(ActorSystem("test", ConfigFactory.parseString("""akka.loggers = ["akka.testkit.TestEventListener"]"""))) with FunSuiteLike { + + val fakeIPAddress = NodeAddress.fromParts("1.2.3.4", 42000).get + + test("reconnect using the address from node_announcement") { + val aliceParams = Alice.nodeParams + val aliceAnnouncement = NodeAnnouncement(randomSignature, ByteVector.empty, 1, Bob.nodeParams.nodeId, Color(100.toByte, 200.toByte, 300.toByte), "node-alias", fakeIPAddress :: Nil) + aliceParams.db.network.addNode(aliceAnnouncement) + val authenticator = TestProbe() + val watcher = TestProbe() + val router = TestProbe() + val relayer = TestProbe() + val wallet: EclairWallet = null // unused + val remoteNodeId = Bob.nodeParams.nodeId + val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, authenticator.ref, watcher.ref, router.ref, relayer.ref, wallet)) + + + val probe = TestProbe() + awaitCond({peer.stateName.toString == "INSTANTIATING"}, 10 seconds) + probe.send(peer, Peer.Init(None, Set(ChannelStateSpec.normal))) + awaitCond({peer.stateName.toString == "DISCONNECTED" && peer.stateData.address_opt.isEmpty}, 10 seconds) + EventFilter.info(message = s"reconnecting to ${fakeIPAddress.socketAddress}", occurrences = 1) intercept { + probe.send(peer, Peer.Reconnect) + } + } + + +}