1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-02-24 14:50:46 +01:00

Fix race condition on early connection failure (#1430)

Both `Client` and `TransportHandler` were watching the connection actor,
which resulted in undeterministic behavior during termination of
`PeerConnection`.

We now always return a message when a connection fails during
authentication.

Took the opportunity to add more typing (insert
deathtoallthestring.jpg).
This commit is contained in:
Pierre-Marie Padiou 2020-05-19 12:45:53 +02:00 committed by GitHub
parent 9faaf24934
commit c01031708d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 73 additions and 39 deletions

View file

@ -29,7 +29,7 @@ import fr.acinq.eclair.channel.Register.{Forward, ForwardShortId}
import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel._
import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats} import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats}
import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo} import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo}
import fr.acinq.eclair.io.{NodeURI, Peer} import fr.acinq.eclair.io.{NodeURI, Peer, PeerConnection}
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivePayment import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivePayment
import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannels, UsableBalance} import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannels, UsableBalance}
@ -128,8 +128,8 @@ class EclairImpl(appKit: Kit) extends Eclair {
private val externalIdMaxLength = 66 private val externalIdMaxLength = 66
override def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] = target match { override def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] = target match {
case Left(uri) => (appKit.switchboard ? Peer.Connect(uri)).mapTo[String] case Left(uri) => (appKit.switchboard ? Peer.Connect(uri)).mapTo[PeerConnection.ConnectionResult].map(_.toString)
case Right(pubKey) => (appKit.switchboard ? Peer.Connect(pubKey, None)).mapTo[String] case Right(pubKey) => (appKit.switchboard ? Peer.Connect(pubKey, None)).mapTo[PeerConnection.ConnectionResult].map(_.toString)
} }
override def disconnect(nodeId: PublicKey)(implicit timeout: Timeout): Future[String] = { override def disconnect(nodeId: PublicKey)(implicit timeout: Timeout): Future[String] = {

View file

@ -268,7 +268,10 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co
} }
} }
override def aroundPostStop(): Unit = connection ! Tcp.Close // attempts to gracefully close the connection when dying onTermination {
case _: StopEvent =>
connection ! Tcp.Close // attempts to gracefully close the connection when dying
}
initialize() initialize()

View file

@ -24,7 +24,6 @@ import akka.io.Tcp.SO.KeepAlive
import akka.io.{IO, Tcp} import akka.io.{IO, Tcp}
import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.Logs.LogCategory
import fr.acinq.eclair.io.Client.ConnectionFailed
import fr.acinq.eclair.tor.Socks5Connection.{Socks5Connect, Socks5Connected, Socks5Error} import fr.acinq.eclair.tor.Socks5Connection.{Socks5Connect, Socks5Connected, Socks5Error}
import fr.acinq.eclair.tor.{Socks5Connection, Socks5ProxyParams} import fr.acinq.eclair.tor.{Socks5Connection, Socks5ProxyParams}
import fr.acinq.eclair.{Logs, NodeParams} import fr.acinq.eclair.{Logs, NodeParams}
@ -60,7 +59,7 @@ class Client(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, re
case Tcp.CommandFailed(c: Tcp.Connect) => case Tcp.CommandFailed(c: Tcp.Connect) =>
val peerOrProxyAddress = c.remoteAddress val peerOrProxyAddress = c.remoteAddress
log.info(s"connection failed to ${str(peerOrProxyAddress)}") log.info(s"connection failed to ${str(peerOrProxyAddress)}")
origin_opt.foreach(_ ! Status.Failure(ConnectionFailed(remoteAddress))) origin_opt.foreach(_ ! PeerConnection.ConnectionResult.ConnectionFailed(remoteAddress))
context stop self context stop self
case Tcp.Connected(peerOrProxyAddress, _) => case Tcp.Connected(peerOrProxyAddress, _) =>
@ -75,24 +74,28 @@ class Client(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, re
context become { context become {
case Tcp.CommandFailed(_: Socks5Connect) => case Tcp.CommandFailed(_: Socks5Connect) =>
log.info(s"connection failed to ${str(remoteAddress)} via SOCKS5 ${str(proxyAddress)}") log.info(s"connection failed to ${str(remoteAddress)} via SOCKS5 ${str(proxyAddress)}")
origin_opt.foreach(_ ! Status.Failure(ConnectionFailed(remoteAddress))) origin_opt.foreach(_ ! PeerConnection.ConnectionResult.ConnectionFailed(remoteAddress))
context stop self context stop self
case Socks5Connected(_) => case Socks5Connected(_) =>
log.info(s"connected to ${str(remoteAddress)} via SOCKS5 proxy ${str(proxyAddress)}") log.info(s"connected to ${str(remoteAddress)} via SOCKS5 proxy ${str(proxyAddress)}")
auth(proxy) context unwatch proxy
context become connected(proxy) val peerConnection = auth(proxy)
context watch peerConnection
context become connected(peerConnection)
case Terminated(actor) if actor == proxy =>
context stop self
} }
case None => case None =>
val peerAddress = peerOrProxyAddress val peerAddress = peerOrProxyAddress
log.info(s"connected to ${str(peerAddress)}") log.info(s"connected to ${str(peerAddress)}")
auth(connection) val peerConnection = auth(connection)
context watch connection context watch peerConnection
context become connected(connection) context become connected(peerConnection)
} }
} }
def connected(connection: ActorRef): Receive = { def connected(peerConnection: ActorRef): Receive = {
case Terminated(actor) if actor == connection => case Terminated(actor) if actor == peerConnection =>
context stop self context stop self
} }
@ -100,7 +103,7 @@ class Client(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, re
log.warning(s"unhandled message=$message") log.warning(s"unhandled message=$message")
} }
// we should not restart a failing socks client // we should not restart a failing socks client or transport handler
override val supervisorStrategy = OneForOneStrategy(loggingEnabled = false) { override val supervisorStrategy = OneForOneStrategy(loggingEnabled = false) {
case t => case t =>
Logs.withMdc(log)(Logs.mdc(remoteNodeId_opt = Some(remoteNodeId))) { Logs.withMdc(log)(Logs.mdc(remoteNodeId_opt = Some(remoteNodeId))) {
@ -116,13 +119,14 @@ class Client(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, re
private def str(address: InetSocketAddress): String = s"${address.getHostString}:${address.getPort}" private def str(address: InetSocketAddress): String = s"${address.getHostString}:${address.getPort}"
def auth(connection: ActorRef) = { def auth(connection: ActorRef): ActorRef = {
val peerConnection = context.actorOf(PeerConnection.props( val peerConnection = context.actorOf(PeerConnection.props(
nodeParams = nodeParams, nodeParams = nodeParams,
switchboard = switchboard, switchboard = switchboard,
router = router router = router
)) ))
peerConnection ! PeerConnection.PendingAuth(connection, remoteNodeId_opt = Some(remoteNodeId), address = remoteAddress, origin_opt = origin_opt) peerConnection ! PeerConnection.PendingAuth(connection, remoteNodeId_opt = Some(remoteNodeId), address = remoteAddress, origin_opt = origin_opt)
peerConnection
} }
} }
@ -130,6 +134,4 @@ object Client {
def props(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, address: InetSocketAddress, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): Props = Props(new Client(nodeParams, switchboard, router, address, remoteNodeId, origin_opt)) def props(nodeParams: NodeParams, switchboard: ActorRef, router: ActorRef, address: InetSocketAddress, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): Props = Props(new Client(nodeParams, switchboard, router, address, remoteNodeId, origin_opt))
case class ConnectionFailed(address: InetSocketAddress) extends RuntimeException(s"connection failed to $address")
} }

View file

@ -86,7 +86,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
when(CONNECTED) { when(CONNECTED) {
dropStaleMessages { dropStaleMessages {
case Event(_: Peer.Connect, _) => case Event(_: Peer.Connect, _) =>
sender ! "already connected" sender ! PeerConnection.ConnectionResult.AlreadyConnected
stay stay
case Event(Channel.OutgoingMessage(msg, peerConnection), d: ConnectedData) if peerConnection == d.peerConnection => // this is an outgoing message, but we need to make sure that this is for the current active connection case Event(Channel.OutgoingMessage(msg, peerConnection), d: ConnectedData) if peerConnection == d.peerConnection => // this is an outgoing message, but we need to make sure that this is for the current active connection

View file

@ -89,8 +89,9 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
switchboard ! Authenticated(self, remoteNodeId) switchboard ! Authenticated(self, remoteNodeId)
goto(BEFORE_INIT) using BeforeInitData(remoteNodeId, d.pendingAuth, d.transport) goto(BEFORE_INIT) using BeforeInitData(remoteNodeId, d.pendingAuth, d.transport)
case Event(AuthTimeout, _) => case Event(AuthTimeout, d: AuthenticatingData) =>
log.warning(s"authentication timed out after ${nodeParams.authTimeout}") log.warning(s"authentication timed out after ${nodeParams.authTimeout}")
d.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.AuthenticationFailed("authentication timed out"))
stop(FSM.Normal) stop(FSM.Normal)
} }
@ -133,19 +134,19 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(d.nodeParams.chainHash)) { if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(d.nodeParams.chainHash)) {
log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting") log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting")
d.pendingAuth.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible networks"))) d.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.InitializationFailed("incompatible networks"))
d.transport ! PoisonPill d.transport ! PoisonPill
stay stay
} else if (!Features.areSupported(remoteInit.features)) { } else if (!Features.areSupported(remoteInit.features)) {
log.warning("incompatible features, disconnecting") log.warning("incompatible features, disconnecting")
d.pendingAuth.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features"))) d.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.InitializationFailed("incompatible features"))
d.transport ! PoisonPill d.transport ! PoisonPill
stay stay
} else { } else {
Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Initialized).increment() Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Initialized).increment()
d.peer ! ConnectionReady(self, d.remoteNodeId, d.pendingAuth.address, d.pendingAuth.outgoing, d.localInit, remoteInit) d.peer ! ConnectionReady(self, d.remoteNodeId, d.pendingAuth.address, d.pendingAuth.outgoing, d.localInit, remoteInit)
d.pendingAuth.origin_opt.foreach(origin => origin ! "connected") d.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.Connected)
def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f) def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f)
@ -177,8 +178,9 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
goto(CONNECTED) using ConnectedData(d.nodeParams, d.remoteNodeId, d.transport, d.peer, d.localInit, remoteInit, rebroadcastDelay) goto(CONNECTED) using ConnectedData(d.nodeParams, d.remoteNodeId, d.transport, d.peer, d.localInit, remoteInit, rebroadcastDelay)
} }
case Event(InitTimeout, _) => case Event(InitTimeout, d: InitializingData) =>
log.warning(s"initialization timed out after ${nodeParams.initTimeout}") log.warning(s"initialization timed out after ${nodeParams.initTimeout}")
d.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.InitializationFailed("initialization timed out"))
stop(FSM.Normal) stop(FSM.Normal)
} }
} }
@ -382,6 +384,12 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION))) { Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION))) {
log.info("transport died, stopping") log.info("transport died, stopping")
} }
d match {
case a: AuthenticatingData => a.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.AuthenticationFailed("connection aborted while authenticating"))
case a: BeforeInitData => a.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.InitializationFailed("connection aborted while initializing"))
case a: InitializingData => a.pendingAuth.origin_opt.foreach(_ ! ConnectionResult.InitializationFailed("connection aborted while initializing"))
case _ => ()
}
stop(FSM.Normal) stop(FSM.Normal)
case Event(_: GossipDecision.Accepted, _) => stay // for now we don't do anything with those events case Event(_: GossipDecision.Accepted, _) => stay // for now we don't do anything with those events
@ -500,6 +508,19 @@ object PeerConnection {
case class InitializeConnection(peer: ActorRef) case class InitializeConnection(peer: ActorRef)
case class ConnectionReady(peerConnection: ActorRef, remoteNodeId: PublicKey, address: InetSocketAddress, outgoing: Boolean, localInit: wire.Init, remoteInit: wire.Init) case class ConnectionReady(peerConnection: ActorRef, remoteNodeId: PublicKey, address: InetSocketAddress, outgoing: Boolean, localInit: wire.Init, remoteInit: wire.Init)
sealed trait ConnectionResult
object ConnectionResult {
sealed trait Success extends ConnectionResult
sealed trait Failure extends ConnectionResult
case object NoAddressFound extends ConnectionResult.Failure { override def toString: String = "no address found" }
case class ConnectionFailed(address: InetSocketAddress) extends ConnectionResult.Failure { override def toString: String = s"connection failed to $address" }
case class AuthenticationFailed(reason: String) extends ConnectionResult.Failure { override def toString: String = reason }
case class InitializationFailed(reason: String) extends ConnectionResult.Failure { override def toString: String = reason }
case object AlreadyConnected extends ConnectionResult.Failure { override def toString: String = "already connected" }
case object Connected extends ConnectionResult.Success { override def toString: String = "connected" }
}
case class DelayedRebroadcast(rebroadcast: Rebroadcast) case class DelayedRebroadcast(rebroadcast: Rebroadcast)
case class Behavior(fundingTxAlreadySpentCount: Int = 0, ignoreNetworkAnnouncement: Boolean = false) case class Behavior(fundingTxAlreadySpentCount: Int = 0, ignoreNetworkAnnouncement: Boolean = false)

View file

@ -50,7 +50,7 @@ class ReconnectionTask(nodeParams: NodeParams, remoteNodeId: PublicKey) extends
startWith(IDLE, IdleData(Nothing)) startWith(IDLE, IdleData(Nothing))
when(CONNECTING) { when(CONNECTING) {
case Event(Status.Failure(_: Client.ConnectionFailed), d: ConnectingData) => case Event(_: PeerConnection.ConnectionResult.ConnectionFailed, d: ConnectingData) =>
log.info(s"connection failed, next reconnection in ${d.nextReconnectionDelay.toSeconds} seconds") log.info(s"connection failed, next reconnection in ${d.nextReconnectionDelay.toSeconds} seconds")
setReconnectTimer(d.nextReconnectionDelay) setReconnectTimer(d.nextReconnectionDelay)
goto(WAITING) using WaitingData(nextReconnectionDelay(d.nextReconnectionDelay, nodeParams.maxReconnectInterval)) goto(WAITING) using WaitingData(nextReconnectionDelay(d.nextReconnectionDelay, nodeParams.maxReconnectInterval))
@ -121,9 +121,7 @@ class ReconnectionTask(nodeParams: NodeParams, remoteNodeId: PublicKey) extends
} }
whenUnhandled { whenUnhandled {
case Event("connected", _) => stay case Event(_: PeerConnection.ConnectionResult, _) => stay
case Event(Status.Failure(_: Client.ConnectionFailed), _) => stay
case Event(TickReconnect, _) => stay case Event(TickReconnect, _) => stay
@ -135,7 +133,7 @@ class ReconnectionTask(nodeParams: NodeParams, remoteNodeId: PublicKey) extends
.map(hostAndPort2InetSocketAddress) .map(hostAndPort2InetSocketAddress)
.orElse(getPeerAddressFromDb(nodeParams.db.peers, nodeParams.db.network, remoteNodeId)) match { .orElse(getPeerAddressFromDb(nodeParams.db.peers, nodeParams.db.network, remoteNodeId)) match {
case Some(address) => connect(address, origin = sender) case Some(address) => connect(address, origin = sender)
case None => sender ! "no address found" case None => sender ! PeerConnection.ConnectionResult.NoAddressFound
} }
stay stay
} }

View file

@ -36,7 +36,7 @@ import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket
import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.crypto.TransportHandler
import fr.acinq.eclair.db._ import fr.acinq.eclair.db._
import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.{Peer, PeerConnection}
import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage}
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
@ -172,7 +172,7 @@ class IntegrationSpec extends TestKitBaseClass with BitcoindService with AnyFunS
nodeId = node2.nodeParams.nodeId, nodeId = node2.nodeParams.nodeId,
address_opt = Some(HostAndPort.fromParts(address.socketAddress.getHostString, address.socketAddress.getPort)) address_opt = Some(HostAndPort.fromParts(address.socketAddress.getHostString, address.socketAddress.getPort))
)) ))
sender.expectMsgAnyOf(10 seconds, "connected", "already connected") sender.expectMsgAnyOf(10 seconds, PeerConnection.ConnectionResult.Connected, PeerConnection.ConnectionResult.AlreadyConnected)
sender.send(node1.switchboard, Peer.OpenChannel( sender.send(node1.switchboard, Peer.OpenChannel(
remoteNodeId = node2.nodeParams.nodeId, remoteNodeId = node2.nodeParams.nodeId,
fundingSatoshis = fundingSatoshis, fundingSatoshis = fundingSatoshis,
@ -318,7 +318,7 @@ class IntegrationSpec extends TestKitBaseClass with BitcoindService with AnyFunS
nodeId = funder.nodeParams.nodeId, nodeId = funder.nodeParams.nodeId,
address_opt = Some(HostAndPort.fromParts(funder.nodeParams.publicAddresses.head.socketAddress.getHostString, funder.nodeParams.publicAddresses.head.socketAddress.getPort)) address_opt = Some(HostAndPort.fromParts(funder.nodeParams.publicAddresses.head.socketAddress.getHostString, funder.nodeParams.publicAddresses.head.socketAddress.getPort))
)) ))
sender.expectMsgAnyOf(10 seconds, "connected", "already connected", "reconnection in progress") sender.expectMsgAnyOf(10 seconds, PeerConnection.ConnectionResult.Connected, PeerConnection.ConnectionResult.AlreadyConnected)
sender.send(fundee.register, Forward(channelId, CMD_GETSTATE)) sender.send(fundee.register, Forward(channelId, CMD_GETSTATE))
val fundeeState = sender.expectMsgType[State](max = 30 seconds) val fundeeState = sender.expectMsgType[State](max = 30 seconds)

View file

@ -107,26 +107,31 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
test("disconnect if authentication timeout") { f => test("disconnect if authentication timeout") { f =>
import f._ import f._
val probe = TestProbe() val probe = TestProbe()
val origin = TestProbe()
probe.watch(peerConnection) probe.watch(peerConnection)
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = Some(origin.ref), transport_opt = Some(transport.ref)))
probe.expectTerminated(peerConnection, nodeParams.authTimeout / transport.testKitSettings.TestTimeFactor + 1.second) // we don't want dilated time here probe.expectTerminated(peerConnection, nodeParams.authTimeout / transport.testKitSettings.TestTimeFactor + 1.second) // we don't want dilated time here
origin.expectMsg(PeerConnection.ConnectionResult.AuthenticationFailed("authentication timed out"))
} }
test("disconnect if init timeout") { f => test("disconnect if init timeout") { f =>
import f._ import f._
val probe = TestProbe() val probe = TestProbe()
val origin = TestProbe()
probe.watch(peerConnection) probe.watch(peerConnection)
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = Some(origin.ref), transport_opt = Some(transport.ref)))
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref)) probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref))
probe.expectTerminated(peerConnection, nodeParams.initTimeout / transport.testKitSettings.TestTimeFactor + 1.second) // we don't want dilated time here probe.expectTerminated(peerConnection, nodeParams.initTimeout / transport.testKitSettings.TestTimeFactor + 1.second) // we don't want dilated time here
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("initialization timed out"))
} }
test("disconnect if incompatible local features") { f => test("disconnect if incompatible local features") { f =>
import f._ import f._
val probe = TestProbe() val probe = TestProbe()
val origin = TestProbe()
probe.watch(transport.ref) probe.watch(transport.ref)
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = Some(origin.ref), transport_opt = Some(transport.ref)))
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref)) probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref))
transport.expectMsgType[TransportHandler.Listener] transport.expectMsgType[TransportHandler.Listener]
@ -134,13 +139,15 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.send(peerConnection, LightningMessageCodecs.initCodec.decode(hex"0000 00050100000000".bits).require.value) transport.send(peerConnection, LightningMessageCodecs.initCodec.decode(hex"0000 00050100000000".bits).require.value)
transport.expectMsgType[TransportHandler.ReadAck] transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref) probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible features"))
} }
test("disconnect if incompatible global features") { f => test("disconnect if incompatible global features") { f =>
import f._ import f._
val probe = TestProbe() val probe = TestProbe()
val origin = TestProbe()
probe.watch(transport.ref) probe.watch(transport.ref)
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = Some(origin.ref), transport_opt = Some(transport.ref)))
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref)) probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref))
transport.expectMsgType[TransportHandler.Listener] transport.expectMsgType[TransportHandler.Listener]
@ -148,6 +155,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.send(peerConnection, LightningMessageCodecs.initCodec.decode(hex"00050100000000 0000".bits).require.value) transport.send(peerConnection, LightningMessageCodecs.initCodec.decode(hex"00050100000000 0000".bits).require.value)
transport.expectMsgType[TransportHandler.ReadAck] transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref) probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible features"))
} }
test("masks off MPP and PaymentSecret features") { f => test("masks off MPP and PaymentSecret features") { f =>
@ -178,8 +186,9 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
test("disconnect if incompatible networks") { f => test("disconnect if incompatible networks") { f =>
import f._ import f._
val probe = TestProbe() val probe = TestProbe()
val origin = TestProbe()
probe.watch(transport.ref) probe.watch(transport.ref)
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = Some(origin.ref), transport_opt = Some(transport.ref)))
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref)) probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref))
transport.expectMsgType[TransportHandler.Listener] transport.expectMsgType[TransportHandler.Listener]
@ -187,6 +196,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.send(peerConnection, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil)))) transport.send(peerConnection, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil))))
transport.expectMsgType[TransportHandler.ReadAck] transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref) probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible networks"))
} }
test("sync if no whitelist is defined") { f => test("sync if no whitelist is defined") { f =>

View file

@ -92,7 +92,7 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with StateTe
val probe = TestProbe() val probe = TestProbe()
probe.send(peer, Peer.Init(Set.empty)) probe.send(peer, Peer.Init(Set.empty))
probe.send(peer, Peer.Connect(remoteNodeId, address_opt = None)) probe.send(peer, Peer.Connect(remoteNodeId, address_opt = None))
probe.expectMsg(s"no address found") probe.expectMsg(PeerConnection.ConnectionResult.NoAddressFound)
} }
test("successfully connect to peer at user request") { f => test("successfully connect to peer at user request") { f =>
@ -156,7 +156,7 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with StateTe
connect(remoteNodeId, peer, peerConnection, channels = Set(ChannelCodecsSpec.normal)) connect(remoteNodeId, peer, peerConnection, channels = Set(ChannelCodecsSpec.normal))
probe.send(peer, Peer.Connect(remoteNodeId, None)) probe.send(peer, Peer.Connect(remoteNodeId, None))
probe.expectMsg("already connected") probe.expectMsg(PeerConnection.ConnectionResult.AlreadyConnected)
} }
test("handle disconnect in state CONNECTED") { f => test("handle disconnect in state CONNECTED") { f =>