1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 01:43:22 +01:00

Add limit for incoming connections from peers without channels (#2601)

This commit is contained in:
Richard Myers 2023-03-24 14:05:10 +01:00 committed by GitHub
parent e1cee96c12
commit 732eb31681
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 389 additions and 33 deletions

View File

@ -61,6 +61,18 @@ eclair.channel.channel-open-limits.max-pending-channels-per-peer = 3
eclair.channel.channel-open-limits.max-total-pending-channels-private-nodes = 99
```
#### Configurable limit on incoming connections (#2601)
We have added a parameter to `eclair.conf` to allow nodes to track the number of incoming connections they maintain from peers they do not have existing channels with. Once the limit is reached, Eclair will disconnect from the oldest tracked peers first.
Outgoing connections and peers on the `sync-whitelist` are exempt from and do not count towards the limit.
The new configuration option and default is as follows:
```conf
// maximum number of incoming connections from peers that do not have any channels with us
eclair.peer-connection.max-no-channels = 250
```
## Verifying signatures
You will need `gpg` and our release signing key 7A73FE77DE2C4027. Note that you can get it:

View File

@ -278,6 +278,7 @@ eclair {
// When enabled, if we receive an incoming connection, we will echo the source IP address in our init message.
// This should be disabled if your node is behind a load balancer that doesn't preserve source IP addresses.
send-remote-address-init = true
max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us
}
auto-reconnect = true

View File

@ -186,7 +186,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
}
override def disconnect(nodeId: PublicKey)(implicit timeout: Timeout): Future[String] = {
(appKit.switchboard ? Peer.Disconnect(nodeId)).mapTo[String]
(appKit.switchboard ? Peer.Disconnect(nodeId)).mapTo[Peer.DisconnectResponse].map(_.toString)
}
override def open(nodeId: PublicKey, fundingAmount: Satoshi, pushAmount_opt: Option[MilliSatoshi], channelType_opt: Option[SupportedChannelType], fundingFeeratePerByte_opt: Option[FeeratePerByte], announceChannel_opt: Option[Boolean], openTimeout_opt: Option[Timeout])(implicit timeout: Timeout): Future[ChannelOpenResponse] = {

View File

@ -449,6 +449,9 @@ object NodeParams extends Logging {
val maxPendingChannelsPerPeer = config.getInt("channel.channel-open-limits.max-pending-channels-per-peer")
val maxTotalPendingChannelsPrivateNodes = config.getInt("channel.channel-open-limits.max-total-pending-channels-private-nodes")
val maxNoChannels = config.getInt("peer-connection.max-no-channels")
require(maxNoChannels > 0, "peer-connection.max-no-channels must be > 0")
NodeParams(
nodeKeyManager = nodeKeyManager,
channelKeyManager = channelKeyManager,
@ -544,6 +547,7 @@ object NodeParams extends Logging {
killIdleDelay = FiniteDuration(config.getDuration("onion-messages.kill-transient-connection-after").getSeconds, TimeUnit.SECONDS),
maxOnionMessagesPerSecond = config.getInt("onion-messages.max-per-peer-per-second"),
sendRemoteAddressInit = config.getBoolean("peer-connection.send-remote-address-init"),
maxNoChannels = maxNoChannels,
),
routerConf = RouterConf(
watchSpentWindow = watchSpentWindow,

View File

@ -0,0 +1,84 @@
package fr.acinq.eclair.io
import akka.actor.typed.delivery.DurableProducerQueue.TimestampMillis
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.Logs.LogCategory
import fr.acinq.eclair.{Logs, NodeParams}
import fr.acinq.eclair.channel.ChannelOpened
import fr.acinq.eclair.io.IncomingConnectionsTracker.Command
import fr.acinq.eclair.io.Monitoring.Metrics
import fr.acinq.eclair.io.Peer.{Disconnect, DisconnectResponse}
/**
* A singleton actor that limits the total number of incoming connections from peers that do not have channels with us.
*
* When a new incoming connection request is received, the Switchboard should send an
* [[IncomingConnectionsTracker.TrackIncomingConnection]] message.
*
* When the number of tracked peers exceeds `eclair.peer-connection.max-no-channels`, send [[Peer.Disconnect]] to
* the tracked peer with the oldest incoming connection.
*
* When a tracked peer disconnects or confirms a channel, we will stop tracking that peer.
*
* We do not need to track peers that disconnect because they will terminate if they have no channels.
* Likewise, peers with channels will disconnect and terminate when their last channel closes.
*
* Note: Peers on the sync whitelist are not tracked.
*
* This rate-limiting can be abused by attackers to prevent us from accepting channels from unknown peers: attackers can
* create a continuous stream of incoming connections with random nodeIds, which forces us to constantly disconnect old
* connections before they have the opportunity to open a channel. This can be fixed by adding a TCP rate-limiter that
* rejects connections based on IP addresses, which forces the attacker to own a lot of IP addresses.
*/
object IncomingConnectionsTracker {
// @formatter:off
sealed trait Command
case class TrackIncomingConnection(remoteNodeId: PublicKey) extends Command
private[io] case class ForgetIncomingConnection(remoteNodeId: PublicKey) extends Command
private[io] case class CountIncomingConnections(replyTo: ActorRef[Int]) extends Command
// @formatter:on
def apply(nodeParams: NodeParams, switchboard: ActorRef[Disconnect]): Behavior[Command] = {
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(LogCategory.CONNECTION))) {
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](c => ForgetIncomingConnection(c.nodeId)))
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[ChannelOpened](c => ForgetIncomingConnection(c.remoteNodeId)))
new IncomingConnectionsTracker(nodeParams, switchboard, context).tracking(Map.empty)
}
}
}
}
private class IncomingConnectionsTracker(nodeParams: NodeParams, switchboard: ActorRef[Disconnect], context: ActorContext[Command]) {
import IncomingConnectionsTracker._
private def tracking(incomingConnections: Map[PublicKey, TimestampMillis]): Behavior[Command] = {
Metrics.IncomingConnectionsNoChannels.withoutTags().update(incomingConnections.size)
Behaviors.receiveMessage {
case TrackIncomingConnection(remoteNodeId) =>
if (nodeParams.syncWhitelist.contains(remoteNodeId)) {
Behaviors.same
} else {
if (incomingConnections.size >= nodeParams.peerConnectionConf.maxNoChannels) {
Metrics.IncomingConnectionsDisconnected.withoutTags().increment()
val oldest = incomingConnections.minBy(_._2)._1
context.log.warn(s"disconnecting peer=$oldest, too many incoming connections from peers without channels.")
switchboard ! Disconnect(oldest, Some(context.system.ignoreRef[DisconnectResponse]))
tracking(incomingConnections + (remoteNodeId -> System.currentTimeMillis()) - oldest)
}
else {
tracking(incomingConnections + (remoteNodeId -> System.currentTimeMillis()))
}
}
case ForgetIncomingConnection(remoteNodeId) => tracking(incomingConnections - remoteNodeId)
case CountIncomingConnections(replyTo) =>
replyTo ! incomingConnections.size
Behaviors.same
}
}
}

View File

@ -33,6 +33,9 @@ object Monitoring {
val OnionMessagesThrottled = Kamon.counter("onionmessages.throttled")
val OpenChannelRequestsPending = Kamon.gauge("openchannelrequests.pending")
val IncomingConnectionsNoChannels = Kamon.gauge("incomingconnections.nochannels")
val IncomingConnectionsDisconnected = Kamon.counter("incomingconnections.disconnected")
}
object Tags {

View File

@ -90,6 +90,17 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
stay() using d.copy(channels = channels1)
}
case Event(ConnectionDown(_), d: DisconnectedData) =>
Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION))) {
log.debug("connection lost while negotiating connection")
}
if (d.channels.isEmpty) {
// we have no existing channels, we can forget about this peer
stopPeer()
} else {
stay()
}
// This event is usually handled while we're connected, but if our peer disconnects right when we're emitting this,
// we still want to record the channelId mapping.
case Event(ChannelIdAssigned(channel, _, temporaryChannelId, channelId), d: DisconnectedData) =>
@ -226,9 +237,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
// we won't clean it up, but we won't remember the temporary id on channel termination
stay() using d.copy(channels = d.channels + (FinalChannelId(channelId) -> channel))
case Event(Disconnect(nodeId), d: ConnectedData) if nodeId == remoteNodeId =>
case Event(Disconnect(nodeId, replyTo_opt), d: ConnectedData) if nodeId == remoteNodeId =>
log.debug("disconnecting")
sender() ! "disconnecting"
val replyTo = replyTo_opt.getOrElse(sender().toTyped)
replyTo ! Disconnecting(nodeId)
d.peerConnection ! PeerConnection.Kill(KillReason.UserRequest)
stay()
@ -301,8 +313,9 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
sender() ! Status.Failure(new RuntimeException("not connected"))
stay()
case Event(_: Peer.Disconnect, _) =>
sender() ! Status.Failure(new RuntimeException("not connected"))
case Event(Disconnect(nodeId, replyTo_opt), _) =>
val replyTo = replyTo_opt.getOrElse(sender().toTyped)
replyTo ! NotConnected(nodeId)
stay()
case Event(r: GetPeerInfo, d) =>
@ -470,7 +483,12 @@ object Peer {
def apply(uri: NodeURI, replyTo: ActorRef, isPersistent: Boolean): Connect = new Connect(uri.nodeId, Some(uri.address), replyTo, isPersistent)
}
case class Disconnect(nodeId: PublicKey) extends PossiblyHarmful
case class Disconnect(nodeId: PublicKey, replyTo_opt: Option[typed.ActorRef[DisconnectResponse]] = None) extends PossiblyHarmful
sealed trait DisconnectResponse {
def nodeId: PublicKey
}
case class Disconnecting(nodeId: PublicKey) extends DisconnectResponse { override def toString: String = s"peer $nodeId disconnecting" }
case class NotConnected(nodeId: PublicKey) extends DisconnectResponse { override def toString: String = s"peer $nodeId not connected" }
case class OpenChannel(remoteNodeId: PublicKey,
fundingAmount: Satoshi,
@ -499,7 +517,7 @@ object Peer {
def nodeId: PublicKey
}
case class PeerInfo(peer: ActorRef, nodeId: PublicKey, state: State, address: Option[NodeAddress], channels: Set[ActorRef]) extends PeerInfoResponse
case class PeerNotFound(nodeId: PublicKey) extends PeerInfoResponse { override def toString: String = s"peer $nodeId not found" }
case class PeerNotFound(nodeId: PublicKey) extends PeerInfoResponse with DisconnectResponse { override def toString: String = s"peer $nodeId not found" }
case class PeerRoutingMessage(peerConnection: ActorRef, remoteNodeId: PublicKey, message: RoutingMessage) extends RemoteTypes

View File

@ -88,7 +88,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A
cancelTimer(AUTH_TIMER)
log.info(s"connection authenticated (direction=${if (d.pendingAuth.outgoing) "outgoing" else "incoming"})")
Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Authenticated).increment()
switchboard ! Authenticated(self, remoteNodeId)
switchboard ! Authenticated(self, remoteNodeId, d.pendingAuth.outgoing)
goto(BEFORE_INIT) using BeforeInitData(remoteNodeId, d.pendingAuth, d.transport, d.isPersistent)
case Event(AuthTimeout, d: AuthenticatingData) =>
@ -464,6 +464,9 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A
case StopEvent(_, CONNECTED, d: ConnectedData) =>
Metrics.PeerConnectionsConnected.withoutTags().decrement()
d.peer ! Peer.ConnectionDown(self)
case StopEvent(_, INITIALIZING, d: InitializingData) =>
log.debug("terminated while initializing.")
d.peer ! Peer.ConnectionDown(self)
}
/**
@ -542,7 +545,8 @@ object PeerConnection {
maxRebroadcastDelay: FiniteDuration,
killIdleDelay: FiniteDuration,
maxOnionMessagesPerSecond: Int,
sendRemoteAddressInit: Boolean)
sendRemoteAddressInit: Boolean,
maxNoChannels: Int)
// @formatter:off
@ -567,7 +571,7 @@ object PeerConnection {
case class PendingAuth(connection: ActorRef, remoteNodeId_opt: Option[PublicKey], address: NodeAddress, origin_opt: Option[ActorRef], transport_opt: Option[ActorRef] = None, isPersistent: Boolean) {
def outgoing: Boolean = remoteNodeId_opt.isDefined // if this is an outgoing connection, we know the node id in advance
}
case class Authenticated(peerConnection: ActorRef, remoteNodeId: PublicKey) extends RemoteTypes
case class Authenticated(peerConnection: ActorRef, remoteNodeId: PublicKey, outgoing: Boolean) extends RemoteTypes
case class InitializeConnection(peer: ActorRef, chainHash: ByteVector32, features: Features[InitFeature], doSync: Boolean) extends RemoteTypes
case class ConnectionReady(peerConnection: ActorRef, remoteNodeId: PublicKey, address: NodeAddress, outgoing: Boolean, localInit: protocol.Init, remoteInit: protocol.Init) extends RemoteTypes

View File

@ -17,15 +17,16 @@
package fr.acinq.eclair.io
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.adapter.ClassicActorContextOps
import akka.actor.typed.scaladsl.adapter.{ClassicActorContextOps, ClassicActorRefOps, TypedActorRefOps}
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Stash, Status, SupervisorStrategy, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.OnchainPubkeyCache
import fr.acinq.eclair.channel.Helpers.Closing
import fr.acinq.eclair.channel._
import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection
import fr.acinq.eclair.io.MessageRelay.RelayPolicy
import fr.acinq.eclair.io.Peer.PeerInfoResponse
import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound}
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.wire.protocol.OnionMessage
@ -39,6 +40,8 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
import Switchboard._
private val incomingConnectionsTracker = context.spawn(Behaviors.supervise(IncomingConnectionsTracker(nodeParams, context.self.toTyped)).onFailure(typed.SupervisorStrategy.resume), name = "incoming-connections-tracker")
context.system.eventStream.subscribe(self, classOf[ChannelIdAssigned])
context.system.eventStream.subscribe(self, classOf[LastChannelClosed])
context.system.eventStream.publish(SubscriptionsComplete(this.getClass))
@ -80,9 +83,10 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
peer forward c
case d: Peer.Disconnect =>
val replyTo = d.replyTo_opt.getOrElse(sender().toTyped)
getPeer(d.nodeId) match {
case Some(peer) => peer forward d
case None => sender() ! Status.Failure(new RuntimeException(s"peer ${d.nodeId} not found"))
case None => replyTo ! PeerNotFound(d.nodeId)
}
case o: Peer.OpenChannel =>
@ -95,15 +99,19 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
// if this is an incoming connection, we might not yet have created the peer
val peer = createOrGetPeer(authenticated.remoteNodeId, offlineChannels = Set.empty)
val features = nodeParams.initFeaturesFor(authenticated.remoteNodeId)
val hasChannels = peersWithChannels.contains(authenticated.remoteNodeId)
// if the peer is whitelisted, we sync with them, otherwise we only sync with peers with whom we have at least one channel
val doSync = nodeParams.syncWhitelist.contains(authenticated.remoteNodeId) || (nodeParams.syncWhitelist.isEmpty && peersWithChannels.contains(authenticated.remoteNodeId))
val doSync = nodeParams.syncWhitelist.contains(authenticated.remoteNodeId) || (nodeParams.syncWhitelist.isEmpty && hasChannels)
authenticated.peerConnection ! PeerConnection.InitializeConnection(peer, nodeParams.chainHash, features, doSync)
if (!hasChannels && !authenticated.outgoing) {
incomingConnectionsTracker ! TrackIncomingConnection(authenticated.remoteNodeId)
}
case ChannelIdAssigned(_, remoteNodeId, _, _) => context.become(normal(peersWithChannels + remoteNodeId))
case LastChannelClosed(_, remoteNodeId) => context.become(normal(peersWithChannels - remoteNodeId))
case GetPeers => sender() ! context.children
case GetPeers => sender() ! context.children.filterNot(_ == incomingConnectionsTracker.toClassic)
case GetPeerInfo(replyTo, remoteNodeId) =>
getPeer(remoteNodeId) match {
@ -134,7 +142,8 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
getPeer(remoteNodeId) match {
case Some(peer) => peer
case None =>
log.debug(s"creating new peer (current={})", context.children.size)
// do not count the incoming-connections-tracker child actor that is not a peer
log.debug(s"creating new peer (current={})", context.children.size - 1)
val peer = createPeer(remoteNodeId)
peer ! Peer.Init(offlineChannels)
peer

View File

@ -110,7 +110,8 @@ object EclairInternalsSerializer {
("maxRebroadcastDelay" | finiteDurationCodec) ::
("killIdleDelay" | finiteDurationCodec) ::
("maxOnionMessagesPerSecond" | int32) ::
("sendRemoteAddressInit" | bool(8))).as[PeerConnection.Conf]
("sendRemoteAddressInit" | bool(8)) ::
("maxNoChannels" | int32)).as[PeerConnection.Conf]
val peerConnectionDoSyncCodec: Codec[PeerConnection.DoSync] = bool(8).as[PeerConnection.DoSync]
@ -176,7 +177,7 @@ object EclairInternalsSerializer {
.typecase(1, (routerConfCodec :: peerConnectionConfCodec).as[RouterPeerConf])
.typecase(5, readAckCodec)
.typecase(7, connectionRequestCodec(system))
.typecase(10, (actorRefCodec(system) :: publicKey).as[PeerConnection.Authenticated])
.typecase(10, (actorRefCodec(system) :: publicKey :: bool).as[PeerConnection.Authenticated])
.typecase(11, initializeConnectionCodec(system))
.typecase(12, connectionReadyCodec(system))
.typecase(13, provide(PeerConnection.ConnectionResult.NoAddressFound))

View File

@ -696,5 +696,4 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I
eclair.usableBalances().pipeTo(sender.ref)
relayer.expectMsg(GetOutgoingChannels())
}
}

View File

@ -169,6 +169,7 @@ object TestConstants {
killIdleDelay = 1 seconds,
maxOnionMessagesPerSecond = 10,
sendRemoteAddressInit = true,
maxNoChannels = 250,
),
routerConf = RouterConf(
watchSpentWindow = 1 second,
@ -323,6 +324,7 @@ object TestConstants {
killIdleDelay = 10 seconds,
maxOnionMessagesPerSecond = 10,
sendRemoteAddressInit = true,
maxNoChannels = 250,
),
routerConf = RouterConf(
watchSpentWindow = 1 second,

View File

@ -2605,7 +2605,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
bob2alice.expectMsgType[Warning]
// we should fail the connection as per the BOLTs
bobPeer.fishForMessage(3 seconds) {
case Peer.Disconnect(nodeId) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case Peer.Disconnect(nodeId, _) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case _ => false
}
}
@ -2616,7 +2616,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
bob2alice.expectMsgType[Warning]
// we should fail the connection as per the BOLTs
bobPeer.fishForMessage(3 seconds) {
case Peer.Disconnect(nodeId) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case Peer.Disconnect(nodeId, _) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case _ => false
}
}
@ -2636,7 +2636,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
bob ! Shutdown(ByteVector32.Zeroes, hex"00112233445566778899")
// we should fail the connection as per the BOLTs
bobPeer.fishForMessage(3 seconds) {
case Peer.Disconnect(nodeId) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case Peer.Disconnect(nodeId, _) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case _ => false
}
}
@ -2647,7 +2647,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
// we should fail the connection as per the BOLTs
bobPeer.fishForMessage(3 seconds) {
case Peer.Disconnect(nodeId) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case Peer.Disconnect(nodeId, _) if nodeId == bob.stateData.asInstanceOf[DATA_NORMAL].commitments.params.remoteParams.nodeId => true
case _ => false
}
}

View File

@ -513,7 +513,7 @@ class StandardChannelIntegrationSpec extends ChannelIntegrationSpec {
// simulate a disconnection
sender.send(funder.switchboard, Peer.Disconnect(fundee.nodeParams.nodeId))
assert(sender.expectMsgType[String] == "disconnecting")
sender.expectMsgType[Peer.Disconnecting]
awaitCond({
fundee.register ! Register.Forward(sender.ref.toTyped[Any], channelId, CMD_GET_CHANNEL_STATE(ActorRef.noSender))

View File

@ -0,0 +1,114 @@
/*
* Copyright 2023 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.io
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto
import fr.acinq.eclair.TestConstants.Alice.nodeParams
import fr.acinq.eclair.channel.ChannelOpened
import fr.acinq.eclair.io.Peer.Disconnect
import fr.acinq.eclair.{randomBytes32, randomKey}
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import scala.concurrent.duration.DurationInt
class IncomingConnectionsTrackerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {
val connection1: Crypto.PublicKey = randomKey().publicKey
val connection2: Crypto.PublicKey = randomKey().publicKey
override def withFixture(test: OneArgTest): Outcome = {
val nodeParams1 = nodeParams.copy(peerConnectionConf = nodeParams.peerConnectionConf.copy(maxNoChannels = 2))
val switchboard = TestProbe[Disconnect]()
val tracker = testKit.spawn(IncomingConnectionsTracker(nodeParams1, switchboard.ref))
withFixture(test.toNoArgTest(FixtureParam(tracker, switchboard)))
}
case class FixtureParam(tracker: ActorRef[IncomingConnectionsTracker.Command], switchboard: TestProbe[Disconnect])
test("accept new node connections, after limit is reached kill oldest node connection first") { f =>
import f._
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection1)
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection2)
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
assert(switchboard.expectMessageType[Disconnect].nodeId === connection1)
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
assert(switchboard.expectMessageType[Disconnect].nodeId === connection2)
}
test("stop tracking a node that disconnects and free space for a new node connection") { f =>
import f._
// Track nodes without channels.
val probe = TestProbe[Int]()
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection1)
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection2)
eventually {
tracker ! IncomingConnectionsTracker.CountIncomingConnections(probe.ref)
probe.expectMessage(2)
}
// Untrack a node when it disconnects.
system.eventStream ! EventStream.Publish(PeerDisconnected(system.deadLetters.toClassic, connection1))
eventually {
tracker ! IncomingConnectionsTracker.CountIncomingConnections(probe.ref)
probe.expectMessage(1)
}
// Track a new node connection without disconnecting the oldest node connection.
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
switchboard.expectNoMessage(100 millis)
// Track a new node connection and disconnect the oldest node connection.
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
assert(switchboard.expectMessageType[Disconnect].nodeId === connection2)
}
test("stop tracking a node that creates a channel and free space for a new node connection") { f =>
import f._
// Track nodes without channels.
val probe = TestProbe[Int]()
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection1)
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(connection2)
eventually {
tracker ! IncomingConnectionsTracker.CountIncomingConnections(probe.ref)
probe.expectMessage(2)
}
// Untrack a node when a channel with it is confirmed on-chain.
system.eventStream ! EventStream.Publish(ChannelOpened(system.deadLetters.toClassic, connection1, randomBytes32()))
eventually {
tracker ! IncomingConnectionsTracker.CountIncomingConnections(probe.ref)
probe.expectMessage(1)
}
// Track a new node connection without disconnecting the oldest node connection.
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
switchboard.expectNoMessage(100 millis)
// Track a new node connection and disconnect the oldest node connection.
tracker ! IncomingConnectionsTracker.TrackIncomingConnection(randomKey().publicKey)
assert(switchboard.expectMessageType[Disconnect].nodeId === connection2)
}
}

View File

@ -24,6 +24,7 @@ import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
import fr.acinq.eclair.Features.{BasicMultiPartPayment, ChannelRangeQueries, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.TestConstants._
import fr.acinq.eclair.crypto.TransportHandler
import fr.acinq.eclair.io.Peer.ConnectionDown
import fr.acinq.eclair.message.OnionMessages.{Recipient, buildMessage}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.router.RoutingSyncSpec
@ -73,7 +74,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
val probe = TestProbe()
probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref), isPersistent = isPersistent))
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId))
switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId, outgoing = true))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, aliceParams.chainHash, aliceParams.features.initFeatures(), doSync))
transport.expectMsgType[TransportHandler.Listener]
val localInit = transport.expectMsgType[protocol.Init]
@ -101,7 +102,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
assert(!incomingConnection.outgoing)
probe.send(peerConnection, incomingConnection)
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId))
switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId, outgoing = false))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, nodeParams.features.initFeatures(), doSync = false))
transport.expectMsgType[TransportHandler.Listener]
val localInit = transport.expectMsgType[protocol.Init]
@ -137,6 +138,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, nodeParams.features.initFeatures(), doSync = true))
probe.expectTerminated(peerConnection, nodeParams.peerConnectionConf.initTimeout / transport.testKitSettings.TestTimeFactor + 1.second) // we don't want dilated time here
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("initialization timed out"))
peer.expectMsg(ConnectionDown(peerConnection))
}
test("disconnect if incompatible local features") { f =>
@ -153,6 +155,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible features"))
peer.expectMsg(ConnectionDown(peerConnection))
}
test("disconnect if incompatible global features") { f =>
@ -169,6 +172,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible features"))
peer.expectMsg(ConnectionDown(peerConnection))
}
test("disconnect if features dependencies not met") { f =>
@ -186,6 +190,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("basic_mpp is set but is missing a dependency (payment_secret)"))
peer.expectMsg(ConnectionDown(peerConnection))
}
test("disconnect if incompatible networks") { f =>
@ -202,6 +207,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
origin.expectMsg(PeerConnection.ConnectionResult.InitializationFailed("incompatible networks"))
peer.expectMsg(ConnectionDown(peerConnection))
}
test("sync when requested") { f =>
@ -261,6 +267,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
transport.expectMsgType[Ping]
deathWatcher.watch(transport.ref)
deathWatcher.expectTerminated(transport.ref, max = 11 seconds)
peer.expectMsg(ConnectionDown(peerConnection))
}
test("filter gossip message (no filtering)") { f =>

View File

@ -245,7 +245,7 @@ class PeerSpec extends FixtureSpec {
assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.CONNECTED)
probe.send(peer, Peer.Disconnect(f.remoteNodeId))
probe.expectMsg("disconnecting")
probe.expectMsgType[Peer.Disconnecting]
}
test("handle disconnect in state DISCONNECTED") { f =>
@ -260,7 +260,7 @@ class PeerSpec extends FixtureSpec {
}
probe.send(peer, Peer.Disconnect(f.remoteNodeId))
assert(probe.expectMsgType[Status.Failure].cause.getMessage == "not connected")
probe.expectMsgType[Peer.NotConnected]
}
test("handle new connection in state CONNECTED") { f =>
@ -634,6 +634,21 @@ class PeerSpec extends FixtureSpec {
assert(channelAborted.remoteNodeId == remoteNodeId)
assert(channelAborted.channelId == open.temporaryChannelId)
}
test("kill peer with no channels if connection dies before receiving `ConnectionReady`") { f =>
import f._
val probe = TestProbe()
probe.watch(peer)
switchboard.send(peer, Peer.Init(Set.empty))
eventually {
probe.send(peer, Peer.GetPeerInfo(None))
assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED)
}
// this will be sent if PeerConnection dies for any reason during the handshake
peer ! ConnectionDown(peerConnection.ref)
probe.expectTerminated(peer)
}
}
object PeerSpec {

View File

@ -1,12 +1,13 @@
package fr.acinq.eclair.io
import akka.actor.typed.scaladsl.adapter.ClassicActorRefOps
import akka.actor.{Actor, ActorContext, ActorRef, Props, Status}
import akka.actor.{Actor, ActorContext, ActorRef, Props}
import akka.testkit.{TestActorRef, TestProbe}
import fr.acinq.bitcoin.scalacompat.ByteVector64
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.TestConstants._
import fr.acinq.eclair.channel.{ChannelIdAssigned, PersistentChannelData}
import fr.acinq.eclair.io.Peer.PeerNotFound
import fr.acinq.eclair.io.Switchboard._
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
import fr.acinq.eclair.wire.protocol._
@ -14,6 +15,8 @@ import fr.acinq.eclair.{Features, InitFeature, NodeParams, TestKitBaseClass, Tim
import org.scalatest.funsuite.AnyFunSuiteLike
import scodec.bits._
import scala.concurrent.duration.DurationInt
class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
import SwitchboardSpec._
@ -60,7 +63,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
val unknownNodeId = randomKey().publicKey
probe.send(switchboard, Peer.Disconnect(unknownNodeId))
assert(probe.expectMsgType[Status.Failure].cause.getMessage == s"peer $unknownNodeId not found")
probe.expectMsgType[PeerNotFound]
probe.send(switchboard, Peer.Disconnect(remoteNodeId))
peer.expectMsg(Peer.Disconnect(remoteNodeId))
}
@ -69,7 +72,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
val (probe, peer, peerConnection) = (TestProbe(), TestProbe(), TestProbe())
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(probe, peer)))
switchboard ! Switchboard.Init(channels)
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId)
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId, outgoing = true)
val initConnection = peerConnection.expectMsgType[PeerConnection.InitializeConnection]
assert(initConnection.chainHash == nodeParams.chainHash)
assert(initConnection.features == expectedFeatures)
@ -91,7 +94,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
// We have a channel with our peer, so we trigger a sync when connecting.
switchboard ! ChannelIdAssigned(TestProbe().ref, remoteNodeId, randomBytes32(), randomBytes32())
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId)
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId, outgoing = true)
val initConnection1 = peerConnection.expectMsgType[PeerConnection.InitializeConnection]
assert(initConnection1.chainHash == nodeParams.chainHash)
assert(initConnection1.features == nodeParams.features.initFeatures())
@ -99,7 +102,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
// We don't have channels with our peer, so we won't trigger a sync when connecting.
switchboard ! LastChannelClosed(peer.ref, remoteNodeId)
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId)
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId, outgoing = true)
val initConnection2 = peerConnection.expectMsgType[PeerConnection.InitializeConnection]
assert(initConnection2.chainHash == nodeParams.chainHash)
assert(initConnection2.features == nodeParams.features.initFeatures())
@ -141,6 +144,66 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
peer.expectMsg(Peer.GetPeerInfo(Some(probe.ref.toTyped)))
}
test("track nodes with incoming connections that do not have a channel") {
val nodeParams = Alice.nodeParams.copy(peerConnectionConf = Alice.nodeParams.peerConnectionConf.copy(maxNoChannels = 2))
val (probe, peer, peerConnection, channel) = (TestProbe(), TestProbe(), TestProbe(), TestProbe())
val hasChannelsNodeId1 = randomKey().publicKey
val hasChannelsNodeId2 = randomKey().publicKey
val unknownNodeId1 = randomKey().publicKey
val unknownNodeId2 = randomKey().publicKey
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(probe, peer)))
switchboard ! Switchboard.Init(Nil)
// Do not track nodes we connect to.
switchboard ! PeerConnection.Authenticated(peerConnection.ref, randomKey().publicKey, outgoing = true)
peer.expectMsgType[Peer.Init]
// Do not track an incoming connection from a peer we have a channel with.
switchboard ! ChannelIdAssigned(channel.ref, hasChannelsNodeId1, randomBytes32(), randomBytes32())
switchboard ! PeerConnection.Authenticated(peerConnection.ref, hasChannelsNodeId1, outgoing = false)
peer.expectMsgType[Peer.Init]
// We do not yet have channels with these peers, so we track their incoming connections.
switchboard ! PeerConnection.Authenticated(peerConnection.ref, unknownNodeId1, outgoing = false)
peer.expectMsgType[Peer.Init]
switchboard ! PeerConnection.Authenticated(peerConnection.ref, unknownNodeId2, outgoing = false)
peer.expectMsgType[Peer.Init]
// Disconnect the oldest tracked peer when an incoming connection from a peer without channels connects.
switchboard ! PeerConnection.Authenticated(peerConnection.ref, randomKey().publicKey, outgoing = false)
peer.fishForMessage() {
case d: Peer.Disconnect => d.nodeId == unknownNodeId1
case _: Peer.Init => false
}
// Do not disconnect an old peer when a peer with channels connects.
switchboard ! ChannelIdAssigned(channel.ref, hasChannelsNodeId2, randomBytes32(), randomBytes32())
switchboard ! PeerConnection.Authenticated(peerConnection.ref, hasChannelsNodeId2, outgoing = false)
peer.expectMsgType[Peer.Init]
peer.expectNoMessage(100 millis)
// Disconnect the next oldest tracked peer when an incoming connection from a peer without channels connects.
switchboard ! PeerConnection.Authenticated(peerConnection.ref, randomKey().publicKey, outgoing = false)
peer.fishForMessage() {
case d: Peer.Disconnect => d.nodeId == unknownNodeId2
case _: Peer.Init => false
}
}
test("GetPeers should only return child nodes of type `Peer`") {
val nodeParams = Alice.nodeParams.copy(peerConnectionConf = Alice.nodeParams.peerConnectionConf.copy(maxNoChannels = 2))
val (peer, probe) = (TestProbe(), TestProbe())
val remoteNodeId = ChannelCodecsSpec.normal.commitments.remoteNodeId
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(TestProbe(), peer)))
switchboard ! Switchboard.Init(Nil)
switchboard ! Peer.Connect(remoteNodeId, None, TestProbe().ref, isPersistent = true)
peer.expectMsgType[Peer.Init]
probe.send(switchboard, GetPeers)
val peers = probe.expectMsgType[Iterable[ActorRef]]
assert(peers.size == 1)
assert(peers.head.path.name == peerActorName(remoteNodeId))
}
}
object SwitchboardSpec {

View File

@ -1216,6 +1216,26 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
}
}
test("'disconnect' should accept a nodeId") {
val remoteNodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87")
val eclair = mock[Eclair]
eclair.disconnect(any[PublicKey])(any[Timeout]) returns Future.successful(Peer.Disconnecting(remoteNodeId).toString)
val mockService = new MockService(eclair)
Post("/disconnect", FormData("nodeId" -> remoteNodeId.toString()).toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockService.disconnect) ~>
check {
assert(handled)
assert(status == OK)
val response = entityAs[String]
assert(response == s"\"peer $remoteNodeId disconnecting\"")
eclair.disconnect(remoteNodeId)(any[Timeout]).wasCalled(once)
}
}
private def matchTestJson(apiName: String, response: String) = {
val resource = getClass.getResourceAsStream(s"/api/$apiName")
val expectedResponse = Try(Source.fromInputStream(resource).mkString).getOrElse {