From a58678eb0b2797370ce4712ecfd3b6a3f54ef423 Mon Sep 17 00:00:00 2001
From: Pierre-Marie Padiou
Date: Thu, 9 Apr 2020 16:08:10 +0200
Subject: [PATCH] Move router handlers to separate files (#1352)

Also, acknowledge all gossip with a `GossipDecision`.
---
 .../main/scala/fr/acinq/eclair/Eclair.scala   | 5 +-
 .../src/main/scala/fr/acinq/eclair/Logs.scala | 13 +-
 .../scala/fr/acinq/eclair/NodeParams.scala    | 2 +-
 .../scala/fr/acinq/eclair/db/NetworkDb.scala  | 2 +-
 .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 2 +-
 .../eclair/db/sqlite/SqliteNetworkDb.scala    | 2 +-
 .../fr/acinq/eclair/io/PeerConnection.scala   | 38 +-
 .../main/scala/fr/acinq/eclair/package.scala  | 5 +
 .../acinq/eclair/payment/PaymentEvents.scala  | 2 +-
 .../acinq/eclair/payment/PaymentPacket.scala  | 2 +-
 .../eclair/payment/relay/NodeRelayer.scala    | 5 +-
 .../acinq/eclair/payment/send/Autoprobe.scala | 3 +-
 .../send/MultiPartPaymentLifecycle.scala      | 1 +
 .../payment/send/PaymentInitiator.scala       | 2 +-
 .../payment/send/PaymentLifecycle.scala       | 1 +
 .../scala/fr/acinq/eclair/router/Graph.scala  | 2 +-
 .../fr/acinq/eclair/router/Monitoring.scala   | 11 +-
 .../fr/acinq/eclair/router/NetworkStats.scala | 1 +
 .../eclair/router/RouteCalculation.scala      | 207 +++
 .../scala/fr/acinq/eclair/router/Router.scala | 1442 +++--------------
 .../acinq/eclair/router/StaleChannels.scala   | 102 ++
 .../scala/fr/acinq/eclair/router/Sync.scala   | 515 ++++++
 .../fr/acinq/eclair/router/Validation.scala   | 450 +++++
 .../eclair/wire/LightningMessageTypes.scala   | 7 +-
 .../fr/acinq/eclair/EclairImplSpec.scala      | 3 +-
 .../scala/fr/acinq/eclair/TestConstants.scala | 4 +-
 .../fr/acinq/eclair/channel/FuzzySpec.scala   | 5 +-
 .../states/StateTestsHelperMethods.scala      | 2 +-
 .../channel/states/f/ShutdownStateSpec.scala  | 2 +-
 .../acinq/eclair/db/SqliteNetworkDbSpec.scala | 3 +-
 .../eclair/db/SqlitePaymentsDbSpec.scala      | 2 +-
 .../eclair/integration/IntegrationSpec.scala  | 13 +-
 .../acinq/eclair/io/PeerConnectionSpec.scala  | 25 +-
 .../MultiPartPaymentLifecycleSpec.scala       | 1 +
 .../eclair/payment/PaymentInitiatorSpec.scala | 6 +-
 .../eclair/payment/PaymentLifecycleSpec.scala | 46 +-
 .../eclair/payment/PaymentPacketSpec.scala    | 2 +-
 .../payment/PostRestartHtlcCleanerSpec.scala  | 2 +-
 .../fr/acinq/eclair/payment/RelayerSpec.scala | 3 +-
 .../acinq/eclair/router/BaseRouterSpec.scala  | 83 +-
 .../router/ChannelRangeQueriesSpec.scala      | 143 +-
 .../fr/acinq/eclair/router/GraphSpec.scala    | 1 +
 .../eclair/router/NetworkStatsSpec.scala      | 1 +
 .../eclair/router/RouteCalculationSpec.scala  | 102 +-
 .../fr/acinq/eclair/router/RouterSpec.scala   | 289 +++-
 .../acinq/eclair/router/RoutingSyncSpec.scala | 20 +-
 .../wire/ExtendedQueriesCodecsSpec.scala      | 6 +-
 .../fr/acinq/eclair/gui/GUIUpdater.scala      | 3 +-
 .../fr/acinq/eclair/api/JsonSerializers.scala | 3 +-
 49 files changed, 2036 insertions(+), 1556 deletions(-)
 create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala
 create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala
 create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala
 create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala

diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
index 5651a78d5..bae41da71 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
@@ -34,7 +34,8 @@ import fr.acinq.eclair.payment._
 import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivePayment
 import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannels, UsableBalance}
 import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentRequest, SendPaymentToRouteRequest, SendPaymentToRouteResponse}
-import fr.acinq.eclair.router._
+import fr.acinq.eclair.router.Router._
+import fr.acinq.eclair.router.{NetworkStats, RouteCalculation}
 import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement}
 import scodec.bits.ByteVector
 
@@ -232,7 +233,7 @@ class EclairImpl(appKit: Kit) extends Eclair {
   override def send(externalId_opt: Option[String], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = {
     val maxAttempts = maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts)
-    val defaultRouteParams = Router.getDefaultRouteParams(appKit.nodeParams.routerConf)
+    val defaultRouteParams = RouteCalculation.getDefaultRouteParams(appKit.nodeParams.routerConf)
     val routeParams = defaultRouteParams.copy(
       maxFeePct = maxFeePct_opt.getOrElse(defaultRouteParams.maxFeePct),
       maxFeeBase = feeThreshold_opt.map(_.toMilliSatoshi).getOrElse(defaultRouteParams.maxFeeBase)
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Logs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Logs.scala
index faafddf29..4c35152e5 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/Logs.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/Logs.scala
@@ -28,6 +28,7 @@ import fr.acinq.eclair.channel.{LocalChannelDown, LocalChannelUpdate}
 import fr.acinq.eclair.crypto.TransportHandler.HandshakeCompleted
 import fr.acinq.eclair.io.Peer.PeerRoutingMessage
 import fr.acinq.eclair.io.{Peer, PeerConnection}
+import fr.acinq.eclair.router.Router._
 import fr.acinq.eclair.router._
 import fr.acinq.eclair.wire._
 
@@ -44,12 +45,12 @@ object Logs {
   ).flatten.toMap
 
   /**
-    * Temporarily add the provided MDC to the current one, and then restore the original one.
-    *
-    * This is useful in some cases where we can't rely on the `aroundReceive` trick to set the MDC before processing a
-    * message because we don't have enough context. That's typically the case when handling `Terminated` messages.
-    */
-  def withMdc(log: DiagnosticLoggingAdapter)(mdc: MDC)(f: => Any): Any = {
+   * Temporarily add the provided MDC to the current one, and then restore the original one.
+   *
+   * This is useful in some cases where we can't rely on the `aroundReceive` trick to set the MDC before processing a
+   * message because we don't have enough context. That's typically the case when handling `Terminated` messages.
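+   *
+   * Illustrative usage (akka's MDC type is just a Map[String, Any]; the key and value below are made up):
+   * {{{
+   *   Logs.withMdc(log)(Map("remoteNodeId" -> remoteNodeId)) {
+   *     log.info("peer disconnected") // logged with remoteNodeId in the MDC, which is restored afterwards
+   *   }
+   * }}}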
+   */
+  def withMdc[T](log: DiagnosticLoggingAdapter)(mdc: MDC)(f: => T): T = {
     val mdc0 = log.mdc // backup the current mdc
     try {
       log.mdc(mdc0 ++ mdc) // add the new mdc to the current one
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
index eb145c618..a2f66d0a2 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
@@ -30,7 +30,7 @@ import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, OnChainFeeConf}
 import fr.acinq.eclair.channel.Channel
 import fr.acinq.eclair.crypto.KeyManager
 import fr.acinq.eclair.db._
-import fr.acinq.eclair.router.RouterConf
+import fr.acinq.eclair.router.Router.RouterConf
 import fr.acinq.eclair.tor.Socks5ProxyParams
 import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
 import scodec.bits.ByteVector
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala
index abaa68703..f89f82552 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala
@@ -21,7 +21,7 @@ import java.io.Closeable
 import fr.acinq.bitcoin.Crypto.PublicKey
 import fr.acinq.bitcoin.{ByteVector32, Satoshi}
 import fr.acinq.eclair.ShortChannelId
-import fr.acinq.eclair.router.PublicChannel
+import fr.acinq.eclair.router.Router.PublicChannel
 import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
 
 import scala.collection.immutable.SortedMap
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
index a559df4ad..73c4e8ecb 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 import fr.acinq.bitcoin.ByteVector32
 import fr.acinq.bitcoin.Crypto.PublicKey
 import fr.acinq.eclair.payment._
-import fr.acinq.eclair.router.{ChannelHop, Hop, NodeHop}
+import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
 import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}
 
 import scala.compat.Platform
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala
index a8e858173..2e42061d2 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala
@@ -21,7 +21,7 @@ import java.sql.Connection
 import fr.acinq.bitcoin.{ByteVector32, Crypto, Satoshi}
 import fr.acinq.eclair.ShortChannelId
 import fr.acinq.eclair.db.NetworkDb
-import fr.acinq.eclair.router.PublicChannel
+import fr.acinq.eclair.router.Router.PublicChannel
 import fr.acinq.eclair.wire.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
 import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
 import grizzled.slf4j.Logging
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala
index b0f8d942c..085564cd4 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala
@@ -26,7 +26,7 @@ import fr.acinq.eclair.crypto.Noise.KeyPair
 import fr.acinq.eclair.crypto.TransportHandler
 import fr.acinq.eclair.io.Monitoring.{Metrics, Tags}
 import fr.acinq.eclair.io.Peer.CHANNELID_ZERO
-import fr.acinq.eclair.router._
+import fr.acinq.eclair.router.Router._
 import fr.acinq.eclair.wire._
 import fr.acinq.eclair.{wire, _}
 import scodec.Attempt
@@ -160,7 +160,7 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
       val flags_opt = if (canUseChannelRangeQueriesEx) Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) else None
       if (d.nodeParams.syncWhitelist.isEmpty || d.nodeParams.syncWhitelist.contains(d.remoteNodeId)) {
         log.info(s"sending sync channel range query with flags_opt=$flags_opt")
-        router ! SendChannelQuery(d.remoteNodeId, self, flags_opt = flags_opt)
+        router ! SendChannelQuery(nodeParams.chainHash, d.remoteNodeId, self, flags_opt = flags_opt)
       } else {
         log.info("not syncing with this peer")
       }
@@ -258,11 +258,12 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
 
     case Event(DelayedRebroadcast(rebroadcast), d: ConnectedData) =>
+      val thisRemote = RemoteGossip(self, d.remoteNodeId)
 
      /**
       * Send and count in a single iteration
       */
      def sendAndCount(msgs: Map[_ <: RoutingMessage, Set[GossipOrigin]]): Int = msgs.foldLeft(0) {
-        case (count, (_, origins)) if origins.contains(RemoteGossip(self)) =>
+        case (count, (_, origins)) if origins.contains(thisRemote) =>
          // the announcement came from this peer, we don't send it back
          count
        case (count, (msg, origins)) if !timestampInRange(d.nodeParams, msg, origins, d.gossipTimestampFilter) =>
@@ -321,9 +322,9 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
       d.transport forward readAck
       stay
 
-    case Event(badMessage: BadMessage, d: ConnectedData) =>
-      val behavior1 = badMessage match {
-        case InvalidSignature(r) =>
+    case Event(rejectedGossip: GossipDecision.Rejected, d: ConnectedData) =>
+      val behavior1 = rejectedGossip match {
+        case GossipDecision.InvalidSignature(r) =>
           val bin: String = LightningMessageCodecs.meteredLightningMessageCodec.encode(r) match {
             case Attempt.Successful(b) => b.toHex
             case _ => "unknown"
           }
           // TODO: this doesn't actually disconnect the peer, once we introduce peer banning we should actively disconnect
           d.transport ! Error(CHANNELID_ZERO, ByteVector.view(s"bad announcement sig! bin=$bin".getBytes()))
           d.behavior
-        case InvalidAnnouncement(c) =>
+        case GossipDecision.InvalidAnnouncement(c) =>
           // they seem to be sending us fake announcements?
           log.error(s"couldn't find funding tx with valid scripts for shortChannelId=${c.shortChannelId}")
           // for now we just return an error, maybe ban the peer in the future?
           // TODO: this doesn't actually disconnect the peer, once we introduce peer banning we should actively disconnect
           d.transport ! Error(CHANNELID_ZERO, ByteVector.view(s"couldn't verify channel! shortChannelId=${c.shortChannelId}".getBytes()))
           d.behavior
-        case ChannelClosed(_) =>
+        case GossipDecision.ChannelClosed(_) =>
           if (d.behavior.ignoreNetworkAnnouncement) {
             // we already are ignoring announcements, we may have additional notifications for announcements that were received right before our ban
             d.behavior.copy(fundingTxAlreadySpentCount = d.behavior.fundingTxAlreadySpentCount + 1)
@@ -351,6 +352,15 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
             setTimer(ResumeAnnouncements.toString, ResumeAnnouncements, IGNORE_NETWORK_ANNOUNCEMENTS_PERIOD, repeat = false)
             d.behavior.copy(fundingTxAlreadySpentCount = d.behavior.fundingTxAlreadySpentCount + 1, ignoreNetworkAnnouncement = true)
           }
+        // other rejections are not considered punishable offenses
+        // we are not using a catch-all on purpose, to make the compiler warn us when a new error is added
+        case _: GossipDecision.Duplicate => d.behavior
+        case _: GossipDecision.NoKnownChannel => d.behavior
+        case _: GossipDecision.ValidationFailure => d.behavior
+        case _: GossipDecision.ChannelPruned => d.behavior
+        case _: GossipDecision.ChannelClosing => d.behavior
+        case _: GossipDecision.Stale => d.behavior
+        case _: GossipDecision.NoRelatedChannel => d.behavior
       }
       stay using d.copy(behavior = behavior1)
@@ -373,6 +383,10 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
       }
       stop(FSM.Normal)
 
+    case Event(_: GossipDecision.Accepted, _) => stay // for now we don't do anything with those events
+
+    case Event(_: GossipDecision.Rejected, _) => stay // we got disconnected while syncing
+
     case Event(_: Rebroadcast, _) => stay // ignored
 
     case Event(_: DelayedRebroadcast, _) => stay // ignored
@@ -386,9 +400,6 @@ class PeerConnection(nodeParams: NodeParams, switchboard: ActorRef, router: Acto
     case Event(_: Pong, _) => stay // we got disconnected before receiving the pong
 
     case Event(_: PingTimeout, _) => stay // we got disconnected after sending a ping
-
-    case Event(_: BadMessage, _) => stay // we got disconnected while syncing
-
   }
 
   onTransition {
@@ -490,11 +501,6 @@ object PeerConnection {
   case class DelayedRebroadcast(rebroadcast: Rebroadcast)
 
-  sealed trait BadMessage
-  case class InvalidSignature(r: RoutingMessage) extends BadMessage
-  case class InvalidAnnouncement(c: ChannelAnnouncement) extends BadMessage
-  case class ChannelClosed(c: ChannelAnnouncement) extends BadMessage
-
   case class Behavior(fundingTxAlreadySpentCount: Int = 0, ignoreNetworkAnnouncement: Boolean = false)
 
   // @formatter:on
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala
index f14482fcc..9aa4226af 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala
@@ -185,4 +185,9 @@ package object eclair {
   // @formatter:on
   }
 
+  /**
+   * Apparently .getClass.getSimpleName can crash java 8 with a "Malformed class name" error
+   */
+  def getSimpleClassName(o: Any): String = o.getClass.getName.split("\\$").last
+
 }
\ No newline at end of file
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala
index 22677e18b..0d4d3e1a3 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala
@@ -22,7 +22,7 @@ import fr.acinq.bitcoin.ByteVector32
 import fr.acinq.bitcoin.Crypto.PublicKey
 import fr.acinq.eclair.MilliSatoshi
 import fr.acinq.eclair.crypto.Sphinx
-import fr.acinq.eclair.router.Hop
+import fr.acinq.eclair.router.Router.Hop
 
 import scala.compat.Platform
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala
index 86c74335f..ee6e2e1ac 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala
@@ -21,7 +21,7 @@ import fr.acinq.bitcoin.ByteVector32
 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey}
 import fr.acinq.eclair.channel.{CMD_ADD_HTLC, Upstream}
 import fr.acinq.eclair.crypto.Sphinx
-import fr.acinq.eclair.router.{ChannelHop, Hop, NodeHop}
+import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
 import fr.acinq.eclair.wire._
 import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, UInt64, randomKey}
 import scodec.bits.ByteVector
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala
index 548bf0b9b..ff3da4732 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala
@@ -30,7 +30,8 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme
 import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
 import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
 import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentError, PaymentLifecycle}
-import fr.acinq.eclair.router.{RouteNotFound, RouteParams, Router}
+import fr.acinq.eclair.router.Router.RouteParams
+import fr.acinq.eclair.router.{RouteCalculation, RouteNotFound}
 import fr.acinq.eclair.wire._
 import fr.acinq.eclair.{CltvExpiry, Logs, MilliSatoshi, NodeParams, nodeFee, randomBytes32}
@@ -246,7 +247,7 @@ object NodeRelayer {
   private def computeRouteParams(nodeParams: NodeParams, amountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry): RouteParams = {
     val routeMaxCltv = expiryIn - expiryOut - nodeParams.expiryDeltaBlocks
     val routeMaxFee = amountIn - amountOut - nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, amountOut)
-    Router.getDefaultRouteParams(nodeParams.routerConf).copy(
+    RouteCalculation.getDefaultRouteParams(nodeParams.routerConf).copy(
       maxFeeBase = routeMaxFee,
       routeMaxCltv = routeMaxCltv,
       maxFeePct = 0 // we disable percent-based max fee calculation, we're only interested in collecting our node fee
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Autoprobe.scala
index 9d70e778d..bfd72b7b0 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Autoprobe.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Autoprobe.scala
@@ -20,7 +20,8 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props}
 import fr.acinq.bitcoin.Crypto.PublicKey
 import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket
 import fr.acinq.eclair.payment.{PaymentEvent, PaymentFailed, RemoteFailure}
-import fr.acinq.eclair.router.{Announcements, Data, PublicChannel}
+import fr.acinq.eclair.router.Announcements
+import fr.acinq.eclair.router.Router.{Data, PublicChannel}
 import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails
 import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala
index e556a395f..282521240 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala
@@ -32,6 +32,7 @@ import fr.acinq.eclair.payment._
 import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannel, OutgoingChannels}
 import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
 import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
+import fr.acinq.eclair.router.Router.{ChannelHop, GetNetworkStats, GetNetworkStatsResponse, RouteParams, TickComputeNetworkStats}
 import fr.acinq.eclair.router._
 import fr.acinq.eclair.wire._
 import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, ToMilliSatoshiConversion}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala
index 7a5f37287..3e93b4e94 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala
@@ -28,7 +28,7 @@ import fr.acinq.eclair.payment._
 import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment
 import fr.acinq.eclair.payment.send.PaymentError._
 import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
-import fr.acinq.eclair.router.{ChannelHop, Hop, NodeHop, RouteParams}
+import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop, RouteParams}
 import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
 import fr.acinq.eclair.wire._
 import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, randomBytes32}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala
index eb7f6bca8..56df6755e 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala
@@ -33,6 +33,7 @@ import fr.acinq.eclair.payment._
 import fr.acinq.eclair.payment.relay.Relayer
 import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
 import fr.acinq.eclair.payment.send.PaymentLifecycle._
+import fr.acinq.eclair.router.Router._
 import fr.acinq.eclair.router._
 import fr.acinq.eclair.wire.Onion._
 import fr.acinq.eclair.wire._
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
index 2db4b9656..c0f387931 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
@@ -275,7 +275,7 @@ object Graph {
       case false => Seq.empty[GraphEdge]
       case true =>
         // we traverse the list of "previous" backward building the final list of edges that make the shortest path
-        val edgePath = new mutable.ArrayBuffer[GraphEdge](ROUTE_MAX_LENGTH)
+        val edgePath = new mutable.ArrayBuffer[GraphEdge](RouteCalculation.ROUTE_MAX_LENGTH)
         var current = prev.get(sourceNode)
         while (current != null) {
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala
index fd929b008..33760bcf2 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala
@@ -16,14 +16,23 @@
 
 package fr.acinq.eclair.router
 
-import fr.acinq.eclair.{LongToBtcAmount, MilliSatoshi}
+import fr.acinq.eclair.router.Router.GossipDecision
+import fr.acinq.eclair.{LongToBtcAmount, MilliSatoshi, getSimpleClassName}
 import kamon.Kamon
+import kamon.metric.Counter
 
 object Monitoring {
 
   object Metrics {
     val FindRouteDuration = Kamon.timer("router.find-route.duration", "Path-finding duration")
     val RouteLength = Kamon.histogram("router.find-route.length", "Path-finding result length")
+
+    private val GossipResult = Kamon.counter("router.gossip.result")
+
+    def gossipResult(decision: GossipDecision): Counter = decision match {
+      case _: GossipDecision.Accepted => GossipResult.withTag("result", "accepted")
+      case rejected: GossipDecision.Rejected => GossipResult.withTag("result", "rejected").withTag("reason", getSimpleClassName(rejected))
+    }
   }
 
   object Tags {
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/NetworkStats.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/NetworkStats.scala
index 2d3d66eaf..910a383ff 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/router/NetworkStats.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/NetworkStats.scala
@@ -18,6 +18,7 @@ package fr.acinq.eclair.router
 
 import com.google.common.math.Quantiles.percentiles
 import fr.acinq.bitcoin.Satoshi
+import fr.acinq.eclair.router.Router.PublicChannel
 import fr.acinq.eclair.wire.ChannelUpdate
 import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala
new file mode 100644
index 000000000..d431f62b5
--- /dev/null
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2020 ACINQ SAS
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package fr.acinq.eclair.router
+
+import akka.actor.{ActorContext, ActorRef, Status}
+import akka.event.LoggingAdapter
+import fr.acinq.bitcoin.Crypto.PublicKey
+import fr.acinq.bitcoin.{ByteVector32, ByteVector64}
+import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
+import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop
+import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge}
+import fr.acinq.eclair.router.Graph.{RichWeight, RoutingHeuristics, WeightRatios}
+import fr.acinq.eclair.router.Monitoring.{Metrics, Tags}
+import fr.acinq.eclair.router.Router._
+import fr.acinq.eclair.wire.ChannelUpdate
+import fr.acinq.eclair.{ShortChannelId, _}
+
+import scala.compat.Platform
+import scala.concurrent.duration._
+import scala.util.{Random, Try}
+
+object RouteCalculation {
+
+  def finalizeRoute(d: Data, fr: FinalizeRoute)(implicit ctx: ActorContext, log: LoggingAdapter): Data = {
+    implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors
+
+    // NB: using a capacity of 0 msat will impact the path-finding algorithm. However here we don't run any path-finding, so it's ok.
+    val assistedChannels: Map[ShortChannelId, AssistedChannel] = fr.assistedRoutes.flatMap(toAssistedChannels(_, fr.hops.last, 0 msat)).toMap
+    val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum))).toSet
+    val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) }
+    // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs
+    fr.hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match {
+      case edges if edges.nonEmpty && edges.forall(_.nonEmpty) =>
+        val selectedEdges = edges.map(_.maxBy(_.update.htlcMaximumMsat.getOrElse(0 msat))) // select the largest edge
+        val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update))
+        ctx.sender ! RouteResponse(hops, Set.empty, Set.empty)
+      case _ => // some nodes in the supplied route aren't connected in our graph
+        ctx.sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels"))
+    }
+    d
+  }
+
+  def handleRouteRequest(d: Data, routerConf: RouterConf, currentBlockHeight: Long, r: RouteRequest)(implicit ctx: ActorContext, log: LoggingAdapter): Data = {
+    implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors
+    import r._
+
+    // we convert extra routing info provided in the payment request to fake channel_update
+    // it takes precedence over all other channel_updates we know
+    val assistedChannels: Map[ShortChannelId, AssistedChannel] = assistedRoutes.flatMap(toAssistedChannels(_, target, amount)).toMap
+    val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum))).toSet
+    val ignoredEdges = ignoreChannels ++ d.excludedChannels
+    val defaultRouteParams: RouteParams = getDefaultRouteParams(routerConf)
+    val params = routeParams.getOrElse(defaultRouteParams)
+    val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1
+
+    log.info(s"finding a route $source->$target with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), ignoreNodes.map(_.value).mkString(","), ignoreChannels.mkString(","), d.excludedChannels.mkString(","))
+    log.info(s"finding a route with randomize={} params={}", routesToFind > 1, params)
+    findRoute(d.graph, source, target, amount, numRoutes = routesToFind, extraEdges = extraEdges, ignoredEdges = ignoredEdges, ignoredVertices = ignoreNodes, routeParams = params, currentBlockHeight)
+      .map(r => ctx.sender ! RouteResponse(r, ignoreNodes, ignoreChannels))
+      .recover { case t => ctx.sender ! Status.Failure(t) }
+    d
+  }
+
+  def toFakeUpdate(extraHop: ExtraHop, htlcMaximum: MilliSatoshi): ChannelUpdate = {
+    // the `direction` bit in flags will not be accurate but it doesn't matter because it is not used
+    // what matters is that the `disable` bit is 0 so that this update doesn't get filtered out
+    ChannelUpdate(signature = ByteVector64.Zeroes, chainHash = ByteVector32.Zeroes, extraHop.shortChannelId, Platform.currentTime.milliseconds.toSeconds, messageFlags = 1, channelFlags = 0, extraHop.cltvExpiryDelta, htlcMinimumMsat = 0 msat, extraHop.feeBase, extraHop.feeProportionalMillionths, Some(htlcMaximum))
+  }
+
+  def toAssistedChannels(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey, amount: MilliSatoshi): Map[ShortChannelId, AssistedChannel] = {
+    // BOLT 11: "For each entry, the pubkey is the node ID of the start of the channel", and the last node is the destination
+    // The invoice doesn't explicitly specify the channel's htlcMaximumMsat, but we can safely assume that the channel
+    // should be able to route the payment, so we'll compute an htlcMaximumMsat accordingly.
+    // We could also get the channel capacity from the blockchain (since we have the shortChannelId) but that's more expensive.
+    // We also need to make sure the channel isn't excluded by our heuristics.
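+    // The fold below walks the route backwards (from recipient to sender): each assisted channel must be able to
+    // carry the payment amount plus the fees of all the hops that come after it, which is what nextAmount accumulates,
+    // with a floor on the seed value so that our capacity heuristics don't exclude the last channel.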
+    val lastChannelCapacity = amount.max(RoutingHeuristics.CAPACITY_CHANNEL_LOW)
+    val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId
+    extraRoute.zip(nextNodeIds).reverse.foldLeft((lastChannelCapacity, Map.empty[ShortChannelId, AssistedChannel])) {
+      case ((amount, acs), (extraHop: ExtraHop, nextNodeId)) =>
+        val nextAmount = amount + nodeFee(extraHop.feeBase, extraHop.feeProportionalMillionths, amount)
+        (nextAmount, acs + (extraHop.shortChannelId -> AssistedChannel(extraHop, nextNodeId, nextAmount)))
+    }._2
+  }
+
+  /**
+   * This method is used after a payment failed, and we want to exclude some nodes that we know are failing
+   */
+  def getIgnoredChannelDesc(channels: Map[ShortChannelId, PublicChannel], ignoreNodes: Set[PublicKey]): Iterable[ChannelDesc] = {
+    val desc = if (ignoreNodes.isEmpty) {
+      Iterable.empty[ChannelDesc]
+    } else {
+      // expensive, but node blacklisting shouldn't happen often
+      channels.values
+        .filter(channelData => ignoreNodes.contains(channelData.ann.nodeId1) || ignoreNodes.contains(channelData.ann.nodeId2))
+        .flatMap(channelData => Vector(ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId1, channelData.ann.nodeId2), ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId2, channelData.ann.nodeId1)))
+    }
+    desc
+  }
+
+  /**
+   * https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#clarifications
+   */
+  val ROUTE_MAX_LENGTH = 20
+
+  // Max allowed CLTV for a route
+  val DEFAULT_ROUTE_MAX_CLTV = CltvExpiryDelta(1008)
+
+  // The default number of routes we'll search for when findRoute is called with randomize = true
+  val DEFAULT_ROUTES_COUNT = 3
+
+  def getDefaultRouteParams(routerConf: RouterConf) = RouteParams(
+    randomize = routerConf.randomizeRouteSelection,
+    maxFeeBase = routerConf.searchMaxFeeBase.toMilliSatoshi,
+    maxFeePct = routerConf.searchMaxFeePct,
+    routeMaxLength = routerConf.searchMaxRouteLength,
+    routeMaxCltv = routerConf.searchMaxCltv,
+    ratios = routerConf.searchHeuristicsEnabled match {
+      case false => None
+      case true => Some(WeightRatios(
+        cltvDeltaFactor = routerConf.searchRatioCltv,
+        ageFactor = routerConf.searchRatioChannelAge,
+        capacityFactor = routerConf.searchRatioChannelCapacity
+      ))
+    }
+  )
+
+  /**
+   * Find a route in the graph between localNodeId and targetNodeId, returns the route.
+   * Will perform a k-shortest path selection given the @param numRoutes and randomly select one of the results.
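+   * If no route is found within the given constraints, the search is retried once with the length and CLTV limits
+   * relaxed to ROUTE_MAX_LENGTH and DEFAULT_ROUTE_MAX_CLTV; a single-hop (direct channel) route, when found, is
+   * always preferred over multi-hop candidates.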
+ * + * @param g graph of the whole network + * @param localNodeId sender node (payer) + * @param targetNodeId target node (final recipient) + * @param amount the amount that will be sent along this route + * @param numRoutes the number of shortest-paths to find + * @param extraEdges a set of extra edges we want to CONSIDER during the search + * @param ignoredEdges a set of extra edges we want to IGNORE during the search + * @param routeParams a set of parameters that can restrict the route search + * @return the computed route to the destination @targetNodeId + */ + def findRoute(g: DirectedGraph, + localNodeId: PublicKey, + targetNodeId: PublicKey, + amount: MilliSatoshi, + numRoutes: Int, + extraEdges: Set[GraphEdge] = Set.empty, + ignoredEdges: Set[ChannelDesc] = Set.empty, + ignoredVertices: Set[PublicKey] = Set.empty, + routeParams: RouteParams, + currentBlockHeight: Long): Try[Seq[ChannelHop]] = Try { + + if (localNodeId == targetNodeId) throw CannotRouteToSelf + + def feeBaseOk(fee: MilliSatoshi): Boolean = fee <= routeParams.maxFeeBase + + def feePctOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = { + val maxFee = amount * routeParams.maxFeePct + fee <= maxFee + } + + def feeOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = feeBaseOk(fee) || feePctOk(fee, amount) + + def lengthOk(length: Int): Boolean = length <= routeParams.routeMaxLength && length <= ROUTE_MAX_LENGTH + + def cltvOk(cltv: CltvExpiryDelta): Boolean = cltv <= routeParams.routeMaxCltv + + val boundaries: RichWeight => Boolean = { weight => + feeOk(weight.cost - amount, amount) && lengthOk(weight.length) && cltvOk(weight.cltv) + } + + val foundRoutes = KamonExt.time(Metrics.FindRouteDuration.withTag(Tags.NumberOfRoutes, numRoutes).withTag(Tags.Amount, Tags.amountBucket(amount))) { + Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries).toList + } + foundRoutes match { + case Nil if routeParams.routeMaxLength < ROUTE_MAX_LENGTH => // if not found within the constraints we relax and repeat the search + Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(0) + return findRoute(g, localNodeId, targetNodeId, amount, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams.copy(routeMaxLength = ROUTE_MAX_LENGTH, routeMaxCltv = DEFAULT_ROUTE_MAX_CLTV), currentBlockHeight) + case Nil => + Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(0) + throw RouteNotFound + case foundRoutes => + val routes = foundRoutes.find(_.path.size == 1) match { + case Some(directRoute) => directRoute :: Nil + case _ => foundRoutes + } + // At this point 'routes' cannot be empty + val randomizedRoutes = if (routeParams.randomize) Random.shuffle(routes) else routes + val route = randomizedRoutes.head.path.map(graphEdgeToHop) + Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(route.length) + route + } + } +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index abcc052a8..8ec30d27d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -17,12 +17,12 @@ package fr.acinq.eclair.router import akka.Done -import akka.actor.{ActorRef, Props, Status} +import akka.actor.{ActorRef, Props} +import akka.event.DiagnosticLoggingAdapter import akka.event.Logging.MDC -import 
akka.event.LoggingAdapter import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Script.{pay2wsh, write} -import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} +import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ @@ -30,192 +30,28 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage -import fr.acinq.eclair.io.PeerConnection -import fr.acinq.eclair.io.PeerConnection.InvalidAnnouncement import fr.acinq.eclair.payment.PaymentRequest.ExtraHop -import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} -import fr.acinq.eclair.router.Graph.{RichWeight, RoutingHeuristics, WeightRatios} -import fr.acinq.eclair.router.Monitoring.{Metrics, Tags} +import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph +import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ -import kamon.Kamon import kamon.context.Context -import scodec.bits.ByteVector -import shapeless.HNil -import scala.annotation.tailrec import scala.collection.immutable.SortedMap -import scala.collection.{SortedSet, mutable} -import scala.compat.Platform import scala.concurrent.duration._ import scala.concurrent.{ExecutionContext, Promise} -import scala.util.{Random, Try} +import scala.util.Try /** * Created by PM on 24/05/2016. */ - -case class RouterConf(randomizeRouteSelection: Boolean, - channelExcludeDuration: FiniteDuration, - routerBroadcastInterval: FiniteDuration, - networkStatsRefreshInterval: FiniteDuration, - requestNodeAnnouncements: Boolean, - encodingType: EncodingType, - channelRangeChunkSize: Int, - channelQueryChunkSize: Int, - searchMaxFeeBase: Satoshi, - searchMaxFeePct: Double, - searchMaxRouteLength: Int, - searchMaxCltv: CltvExpiryDelta, - searchHeuristicsEnabled: Boolean, - searchRatioCltv: Double, - searchRatioChannelAge: Double, - searchRatioChannelCapacity: Double) - -// @formatter:off -case class ChannelDesc(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey) -case class PublicChannel(ann: ChannelAnnouncement, fundingTxid: ByteVector32, capacity: Satoshi, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { - update_1_opt.foreach(u => assert(Announcements.isNode1(u.channelFlags))) - update_2_opt.foreach(u => assert(!Announcements.isNode1(u.channelFlags))) - - def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if (Announcements.isNode1(u.channelFlags)) ann.nodeId1 else ann.nodeId2 - - def getChannelUpdateSameSideAs(u: ChannelUpdate): Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt - - def updateChannelUpdateSameSideAs(u: ChannelUpdate): PublicChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) -} -case class PrivateChannel(localNodeId: PublicKey, remoteNodeId: PublicKey, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { - val (nodeId1, nodeId2) = if (Announcements.isNode1(localNodeId, remoteNodeId)) (localNodeId, remoteNodeId) else (remoteNodeId, localNodeId) - - def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if (Announcements.isNode1(u.channelFlags)) nodeId1 else nodeId2 - - def getChannelUpdateSameSideAs(u: ChannelUpdate): 
Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt - - def updateChannelUpdateSameSideAs(u: ChannelUpdate): PrivateChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) -} -// @formatter:on - -case class AssistedChannel(extraHop: ExtraHop, nextNodeId: PublicKey, htlcMaximum: MilliSatoshi) - -trait Hop { - /** @return the id of the start node. */ - def nodeId: PublicKey - - /** @return the id of the end node. */ - def nextNodeId: PublicKey - - /** - * @param amount amount to be forwarded. - * @return total fee required by the current hop. - */ - def fee(amount: MilliSatoshi): MilliSatoshi - - /** @return cltv delta required by the current hop. */ - def cltvExpiryDelta: CltvExpiryDelta -} - -/** - * A directed hop between two connected nodes using a specific channel. - * - * @param nodeId id of the start node. - * @param nextNodeId id of the end node. - * @param lastUpdate last update of the channel used for the hop. - */ -case class ChannelHop(nodeId: PublicKey, nextNodeId: PublicKey, lastUpdate: ChannelUpdate) extends Hop { - override lazy val cltvExpiryDelta: CltvExpiryDelta = lastUpdate.cltvExpiryDelta - - override def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(lastUpdate.feeBaseMsat, lastUpdate.feeProportionalMillionths, amount) -} - -/** - * A directed hop between two trampoline nodes. - * These nodes need not be connected and we don't need to know a route between them. - * The start node will compute the route to the end node itself when it receives our payment. - * - * @param nodeId id of the start node. - * @param nextNodeId id of the end node. - * @param cltvExpiryDelta cltv expiry delta. - * @param fee total fee for that hop. - */ -case class NodeHop(nodeId: PublicKey, nextNodeId: PublicKey, cltvExpiryDelta: CltvExpiryDelta, fee: MilliSatoshi) extends Hop { - override def fee(amount: MilliSatoshi): MilliSatoshi = fee -} - -case class RouteParams(randomize: Boolean, maxFeeBase: MilliSatoshi, maxFeePct: Double, routeMaxLength: Int, routeMaxCltv: CltvExpiryDelta, ratios: Option[WeightRatios]) - -case class RouteRequest(source: PublicKey, - target: PublicKey, - amount: MilliSatoshi, - assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - ignoreNodes: Set[PublicKey] = Set.empty, - ignoreChannels: Set[ChannelDesc] = Set.empty, - routeParams: Option[RouteParams] = None) - -case class FinalizeRoute(hops: Seq[PublicKey], assistedRoutes: Seq[Seq[ExtraHop]] = Nil) - -case class RouteResponse(hops: Seq[ChannelHop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc], allowEmpty: Boolean = false) { - require(allowEmpty || hops.nonEmpty, "route cannot be empty") -} - -// @formatter:off -/** This is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) */ -case class ExcludeChannel(desc: ChannelDesc) -case class LiftChannelExclusion(desc: ChannelDesc) -// @formatter:on - -// @formatter:off -case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[QueryChannelRangeTlv]) -case object GetNetworkStats -case class GetNetworkStatsResponse(stats: Option[NetworkStats]) -case object GetRoutingState -case class RoutingState(channels: Iterable[PublicChannel], nodes: Iterable[NodeAnnouncement]) -// @formatter:on - -// @formatter:off -sealed trait GossipOrigin -/** Gossip that we received from a remote peer. 
*/ -case class RemoteGossip(peer: ActorRef) extends GossipOrigin -/** Gossip that was generated by our node. */ -case object LocalGossip extends GossipOrigin - -case class Stash(updates: Map[ChannelUpdate, Set[GossipOrigin]], nodes: Map[NodeAnnouncement, Set[GossipOrigin]]) -case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[GossipOrigin]], updates: Map[ChannelUpdate, Set[GossipOrigin]], nodes: Map[NodeAnnouncement, Set[GossipOrigin]]) -// @formatter:on - -case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) - -case class Sync(pending: List[RoutingMessage], total: Int) - -case class Data(nodes: Map[PublicKey, NodeAnnouncement], - channels: SortedMap[ShortChannelId, PublicChannel], - stats: Option[NetworkStats], - stash: Stash, - rebroadcast: Rebroadcast, - awaiting: Map[ChannelAnnouncement, Seq[RemoteGossip]], // note: this is a seq because we want to preserve order: first actor is the one who we need to send a tcp-ack when validation is done - privateChannels: Map[ShortChannelId, PrivateChannel], // short_channel_id -> node_id - excludedChannels: Set[ChannelDesc], // those channels are temporarily excluded from route calculation, because their node returned a TemporaryChannelFailure - graph: DirectedGraph, - sync: Map[PublicKey, Sync] // keep tracks of channel range queries sent to each peer. If there is an entry in the map, it means that there is an ongoing query for which we have not yet received an 'end' message - ) - -// @formatter:off -sealed trait State -case object NORMAL extends State - -case object TickBroadcast -case object TickPruneStaleChannels -case object TickComputeNetworkStats -// @formatter:on - -class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) extends FSMDiagnosticActorLogging[State, Data] { - - import Router._ +class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) extends FSMDiagnosticActorLogging[Router.State, Router.Data] { import ExecutionContext.Implicits.global + import Router._ // we pass these to helpers classes so that they have the logging context - implicit def implicitLog: LoggingAdapter = log + implicit def implicitLog: DiagnosticLoggingAdapter = diagLog context.system.eventStream.subscribe(self, classOf[LocalChannelUpdate]) context.system.eventStream.subscribe(self, classOf[LocalChannelDown]) @@ -224,8 +60,6 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ setTimer(TickPruneStaleChannels.toString, TickPruneStaleChannels, 1 hour, repeat = true) setTimer(TickComputeNetworkStats.toString, TickComputeNetworkStats, nodeParams.routerConf.networkStatsRefreshInterval, repeat = true) - val defaultRouteParams: RouteParams = getDefaultRouteParams(nodeParams.routerConf) - val db: NetworkDb = nodeParams.db.network { @@ -265,54 +99,6 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ } when(NORMAL) { - case Event(LocalChannelUpdate(_, _, shortChannelId, remoteNodeId, channelAnnouncement_opt, u, _), d: Data) => - d.channels.get(shortChannelId) match { - case Some(_) => - // channel has already been announced and router knows about it, we can process the channel_update - stay using handle(u, LocalGossip, d) - case None => - channelAnnouncement_opt match { - case Some(c) if d.awaiting.contains(c) => - // channel is currently being verified, we can process the channel_update right away (it will be stashed) - stay using handle(u, LocalGossip, d) - case Some(c) => - // 
channel wasn't announced but here is the announcement, we will process it *before* the channel_update - watcher ! ValidateRequest(c) - val d1 = d.copy(awaiting = d.awaiting + (c -> Nil)) // no origin - // maybe the local channel was pruned (can happen if we were disconnected for more than 2 weeks) - db.removeFromPruned(c.shortChannelId) - stay using handle(u, LocalGossip, d1) - case None if d.privateChannels.contains(shortChannelId) => - // channel isn't announced but we already know about it, we can process the channel_update - stay using handle(u, LocalGossip, d) - case None => - // channel isn't announced and we never heard of it (maybe it is a private channel or maybe it is a public channel that doesn't yet have 6 confirmations) - // let's create a corresponding private channel and process the channel_update - log.debug("adding unannounced local channel to remote={} shortChannelId={}", remoteNodeId, shortChannelId) - stay using handle(u, LocalGossip, d.copy(privateChannels = d.privateChannels + (shortChannelId -> PrivateChannel(nodeParams.nodeId, remoteNodeId, None, None)))) - } - } - - case Event(LocalChannelDown(_, channelId, shortChannelId, remoteNodeId), d: Data) => - // a local channel has permanently gone down - if (d.channels.contains(shortChannelId)) { - // the channel was public, we will receive (or have already received) a WatchEventSpentBasic event, that will trigger a clean up of the channel - // so let's not do anything here - stay - } else if (d.privateChannels.contains(shortChannelId)) { - // the channel was private or public-but-not-yet-announced, let's do the clean up - log.debug("removing private local channel and channel_update for channelId={} shortChannelId={}", channelId, shortChannelId) - val desc1 = ChannelDesc(shortChannelId, nodeParams.nodeId, remoteNodeId) - val desc2 = ChannelDesc(shortChannelId, remoteNodeId, nodeParams.nodeId) - // we remove the corresponding updates from the graph - val graph1 = d.graph - .removeEdge(desc1) - .removeEdge(desc2) - // and we remove the channel and channel_update from our state - stay using d.copy(privateChannels = d.privateChannels - shortChannelId, graph = graph1) - } else { - stay - } case Event(SyncProgress(progress), d: Data) => if (d.stats.isEmpty && progress == 1.0 && d.channels.nonEmpty) { @@ -330,126 +116,6 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ sender ! GetNetworkStatsResponse(d.stats) stay - case Event(v@ValidateResult(c, _), d0) => - Kamon.runWithContextEntry(shortChannelIdKey, c.shortChannelId) { - Kamon.runWithSpan(Kamon.currentSpan(), finishSpan = true) { - Kamon.runWithSpan(Kamon.spanBuilder("process-validate-result").start(), finishSpan = true) { - d0.awaiting.get(c) match { - case Some(origin +: _) => origin.peer ! 
TransportHandler.ReadAck(c) // now we can acknowledge the message, we only need to do it for the first peer that sent us the announcement - case _ => () - } - log.info("got validation result for shortChannelId={} (awaiting={} stash.nodes={} stash.updates={})", c.shortChannelId, d0.awaiting.size, d0.stash.nodes.size, d0.stash.updates.size) - val publicChannel_opt = v match { - case ValidateResult(c, Left(t)) => - log.warning("validation failure for shortChannelId={} reason={}", c.shortChannelId, t.getMessage) - None - case ValidateResult(c, Right((tx, UtxoStatus.Unspent))) => - val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(c.shortChannelId) - val (fundingOutputScript, ok) = Kamon.runWithSpan(Kamon.spanBuilder("checked-pubkeyscript").start(), finishSpan = true) { - // let's check that the output is indeed a P2WSH multisig 2-of-2 of nodeid1 and nodeid2) - val fundingOutputScript = write(pay2wsh(Scripts.multiSig2of2(c.bitcoinKey1, c.bitcoinKey2))) - val ok = tx.txOut.size < outputIndex + 1 || fundingOutputScript != tx.txOut(outputIndex).publicKeyScript - (fundingOutputScript, ok) - } - if (ok) { - log.error(s"invalid script for shortChannelId={}: txid={} does not have script=$fundingOutputScript at outputIndex=$outputIndex ann={}", c.shortChannelId, tx.txid, c) - d0.awaiting.get(c) match { - case Some(origins) => origins.foreach(_.peer ! InvalidAnnouncement(c)) - case _ => () - } - None - } else { - watcher ! WatchSpentBasic(self, tx, outputIndex, BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(c.shortChannelId)) - // TODO: check feature bit set - log.debug("added channel channelId={}", c.shortChannelId) - val capacity = tx.txOut(outputIndex).amount - context.system.eventStream.publish(ChannelsDiscovered(SingleChannelDiscovered(c, capacity, None, None) :: Nil)) - Kamon.runWithSpan(Kamon.spanBuilder("add-to-db").start(), finishSpan = true) { - db.addChannel(c, tx.txid, capacity) - } - // in case we just validated our first local channel, we announce the local node - if (!d0.nodes.contains(nodeParams.nodeId) && isRelatedTo(c, nodeParams.nodeId)) { - log.info("first local channel validated, announcing local node") - val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.features) - self ! nodeAnn - } - Some(PublicChannel(c, tx.txid, capacity, None, None)) - } - case ValidateResult(c, Right((tx, fundingTxStatus: UtxoStatus.Spent))) => - if (fundingTxStatus.spendingTxConfirmed) { - log.warning("ignoring shortChannelId={} tx={} (funding tx already spent and spending tx is confirmed)", c.shortChannelId, tx.txid) - // the funding tx has been spent by a transaction that is now confirmed: peer shouldn't send us those - d0.awaiting.get(c) match { - case Some(origins) => origins.foreach(_.peer ! 
PeerConnection.ChannelClosed(c)) - case _ => () - } - } else { - log.debug("ignoring shortChannelId={} tx={} (funding tx already spent but spending tx isn't confirmed)", c.shortChannelId, tx.txid) - } - // there may be a record if we have just restarted - db.removeChannel(c.shortChannelId) - None - } - val span1 = Kamon.spanBuilder("reprocess-stash").start - // we also reprocess node and channel_update announcements related to channels that were just analyzed - val reprocessUpdates = d0.stash.updates.filterKeys(u => u.shortChannelId == c.shortChannelId) - val reprocessNodes = d0.stash.nodes.filterKeys(n => isRelatedTo(c, n.nodeId)) - // and we remove the reprocessed messages from the stash - val stash1 = d0.stash.copy(updates = d0.stash.updates -- reprocessUpdates.keys, nodes = d0.stash.nodes -- reprocessNodes.keys) - // we remove channel from awaiting map - val awaiting1 = d0.awaiting - c - span1.finish() - - publicChannel_opt match { - case Some(pc) => - Kamon.runWithSpan(Kamon.spanBuilder("build-new-state").start, finishSpan = true) { - // note: if the channel is graduating from private to public, the implementation (in the LocalChannelUpdate handler) guarantees that we will process a new channel_update - // right after the channel_announcement, channel_updates will be moved from private to public at that time - val d1 = d0.copy( - channels = d0.channels + (c.shortChannelId -> pc), - privateChannels = d0.privateChannels - c.shortChannelId, // we remove fake announcements that we may have made before - rebroadcast = d0.rebroadcast.copy(channels = d0.rebroadcast.channels + (c -> d0.awaiting.getOrElse(c, Nil).toSet)), // we also add the newly validated channels to the rebroadcast queue - stash = stash1, - awaiting = awaiting1) - // we only reprocess updates and nodes if validation succeeded - val d2 = reprocessUpdates.foldLeft(d1) { - case (d, (u, origins)) => origins.foldLeft(d) { case (d, origin) => handle(u, origin, d) } // we reprocess the same channel_update for every origin (to preserve origin information) - } - val d3 = reprocessNodes.foldLeft(d2) { - case (d, (n, origins)) => origins.foldLeft(d) { case (d, origin) => handle(n, origin, d) } // we reprocess the same node_announcement for every origins (to preserve origin information) - } - stay using d3 - } - case None => - stay using d0.copy(stash = stash1, awaiting = awaiting1) - } - } - } - } - - case Event(WatchEventSpentBasic(BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(shortChannelId)), d) if d.channels.contains(shortChannelId) => - val lostChannel = d.channels(shortChannelId).ann - log.info("funding tx of channelId={} has been spent", shortChannelId) - // we need to remove nodes that aren't tied to any channels anymore - val channels1 = d.channels - lostChannel.shortChannelId - val lostNodes = Seq(lostChannel.nodeId1, lostChannel.nodeId2).filterNot(nodeId => hasChannels(nodeId, channels1.values)) - // let's clean the db and send the events - log.info("pruning shortChannelId={} (spent)", shortChannelId) - db.removeChannel(shortChannelId) // NB: this also removes channel updates - // we also need to remove updates from the graph - val graph1 = d.graph - .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) - .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId2, lostChannel.nodeId1)) - - context.system.eventStream.publish(ChannelLost(shortChannelId)) - lostNodes.foreach { - nodeId => - log.info("pruning nodeId={} (spent)", nodeId) - db.removeNode(nodeId) - 
context.system.eventStream.publish(NodeLost(nodeId)) - } - stay using d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, graph = graph1) - case Event(TickBroadcast, d) => if (d.rebroadcast.channels.isEmpty && d.rebroadcast.updates.isEmpty && d.rebroadcast.nodes.isEmpty) { stay @@ -470,38 +136,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ } case Event(TickPruneStaleChannels, d) => - // first we select channels that we will prune - val staleChannels = getStaleChannels(d.channels.values, nodeParams.currentBlockHeight) - val staleChannelIds = staleChannels.map(_.ann.shortChannelId) - // then we remove nodes that aren't tied to any channels anymore (and deduplicate them) - val potentialStaleNodes = staleChannels.flatMap(c => Set(c.ann.nodeId1, c.ann.nodeId2)).toSet - val channels1 = d.channels -- staleChannelIds - // no need to iterate on all nodes, just on those that are affected by current pruning - val staleNodes = potentialStaleNodes.filterNot(nodeId => hasChannels(nodeId, channels1.values)) - - // let's clean the db and send the events - db.removeChannels(staleChannelIds) // NB: this also removes channel updates - // we keep track of recently pruned channels so we don't revalidate them (zombie churn) - db.addToPruned(staleChannelIds) - staleChannelIds.foreach { shortChannelId => - log.info("pruning shortChannelId={} (stale)", shortChannelId) - context.system.eventStream.publish(ChannelLost(shortChannelId)) - } - - val staleChannelsToRemove = new mutable.MutableList[ChannelDesc] - staleChannels.foreach(ca => { - staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId1, ca.ann.nodeId2) - staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId2, ca.ann.nodeId1) - }) - - val graph1 = d.graph.removeEdges(staleChannelsToRemove) - staleNodes.foreach { - nodeId => - log.info("pruning nodeId={} (stale)", nodeId) - db.removeNode(nodeId) - context.system.eventStream.publish(NodeLost(nodeId)) - } - stay using d.copy(nodes = d.nodes -- staleNodes, channels = channels1, graph = graph1) + stay using StaleChannels.handlePruneStaleChannels(d, nodeParams.db.network, nodeParams.currentBlockHeight) case Event(ExcludeChannel(desc@ChannelDesc(shortChannelId, nodeId, _)), d) => val banDuration = nodeParams.routerConf.channelExcludeDuration @@ -534,56 +169,11 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ sender ! d stay - case Event(FinalizeRoute(partialHops, assistedRoutes), d) => - // NB: using a capacity of 0 msat will impact the path-finding algorithm. However here we don't run any path-finding, so it's ok. - val assistedChannels: Map[ShortChannelId, AssistedChannel] = assistedRoutes.flatMap(toAssistedChannels(_, partialHops.last, 0 msat)).toMap - val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum))).toSet - val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } - // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - partialHops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { - case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => - val selectedEdges = edges.map(_.maxBy(_.update.htlcMaximumMsat.getOrElse(0 msat))) // select the largest edge - val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) - sender ! 
RouteResponse(hops, Set.empty, Set.empty) - case _ => // some nodes in the supplied route aren't connected in our graph - sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) - } - stay + case Event(fr: FinalizeRoute, d) => + stay using RouteCalculation.finalizeRoute(d, fr) - case Event(RouteRequest(start, end, amount, assistedRoutes, ignoreNodes, ignoreChannels, params_opt), d) => - // we convert extra routing info provided in the payment request to fake channel_update - // it takes precedence over all other channel_updates we know - val assistedChannels: Map[ShortChannelId, AssistedChannel] = assistedRoutes.flatMap(toAssistedChannels(_, end, amount)).toMap - val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum))).toSet - val ignoredEdges = ignoreChannels ++ d.excludedChannels - val params = params_opt.getOrElse(defaultRouteParams) - val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 - - log.info(s"finding a route $start->$end with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), ignoreNodes.map(_.value).mkString(","), ignoreChannels.mkString(","), d.excludedChannels.mkString(",")) - log.info(s"finding a route with randomize={} params={}", routesToFind > 1, params) - findRoute(d.graph, start, end, amount, numRoutes = routesToFind, extraEdges = extraEdges, ignoredEdges = ignoredEdges, ignoredVertices = ignoreNodes, routeParams = params, nodeParams.currentBlockHeight) - .map(r => sender ! RouteResponse(r, ignoreNodes, ignoreChannels)) - .recover { case t => sender ! Status.Failure(t) } - stay - - case Event(SendChannelQuery(remoteNodeId, remote, flags_opt), d) => - // ask for everything - // we currently send only one query_channel_range message per peer, when we just (re)connected to it, so we don't - // have to worry about sending a new query_channel_range when another query is still in progress - val query = QueryChannelRange(nodeParams.chainHash, firstBlockNum = 0L, numberOfBlocks = Int.MaxValue.toLong, TlvStream(flags_opt.toList)) - log.info("sending query_channel_range={}", query) - remote ! query - - // we also set a pass-all filter for now (we can update it later) for the future gossip messages, by setting - // the first_timestamp field to the current date/time and timestamp_range to the maximum value - // NB: we can't just set firstTimestamp to 0, because in that case peer would send us all past messages matching - // that (i.e. the whole routing table) - val filter = GossipTimestampFilter(nodeParams.chainHash, firstTimestamp = Platform.currentTime.milliseconds.toSeconds, timestampRange = Int.MaxValue) - remote ! filter - - // clean our sync state for this peer: we receive a SendChannelQuery just when we connect/reconnect to a peer and - // will start a new complete sync process - stay using d.copy(sync = d.sync - remoteNodeId) + case Event(r: RouteRequest, d) => + stay using RouteCalculation.handleRouteRequest(d, nodeParams.routerConf, nodeParams.currentBlockHeight, r) // Warning: order matters here, this must be the first match for HasChainHash messages ! 
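
// ---- Editor's note: a minimal, self-contained sketch (not part of the patch) of why the
// ---- HasChainHash match below must stay first. Scala tries `case` clauses top to bottom,
// ---- so if a more specific handler (e.g. for ChannelUpdate) came earlier, a message for
// ---- the wrong chain would be processed instead of rejected.
import scodec.bits.ByteVector

sealed trait Msg { def chainHash: ByteVector }
final case class Update(chainHash: ByteVector) extends Msg

def dispatch(ourChain: ByteVector, msg: Msg): String = msg match {
  case m if m.chainHash != ourChain => "rejected: wrong chain" // must be evaluated first
  case _: Update => "processed"
}
// e.g. dispatch(ByteVector.fromValidHex("01"), Update(ByteVector.fromValidHex("02"))) == "rejected: wrong chain"
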
case Event(PeerRoutingMessage(_, _, routingMessage: HasChainHash), _) if routingMessage.chainHash != nodeParams.chainHash => @@ -591,364 +181,60 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ log.warning("message {} for wrong chain {}, we're on {}", routingMessage, routingMessage.chainHash, nodeParams.chainHash) stay - case Event(u: ChannelUpdate, d: Data) => - // it was sent by us (e.g. the payment lifecycle); routing messages that are sent by our peers are wrapped in a PeerRoutingMessage - log.debug("received channel update from {}", sender) - stay using handle(u, LocalGossip, d) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, c: ChannelAnnouncement), d) => + stay using Validation.handleChannelAnnouncement(d, nodeParams.db.network, watcher, RemoteGossip(peerConnection, remoteNodeId), c) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, u: ChannelUpdate), d) => - sender ! TransportHandler.ReadAck(u) - log.debug("received channel update for shortChannelId={}", u.shortChannelId) - stay using handle(u, RemoteGossip(sender), d, remoteNodeId_opt = Some(remoteNodeId), peerConnection_opt = Some(peerConnection)) + case Event(r: ValidateResult, d) => + stay using Validation.handleChannelValidationResponse(d, nodeParams, watcher, r) - case Event(PeerRoutingMessage(_, _, c: ChannelAnnouncement), d) => - log.debug("received channel announcement for shortChannelId={} nodeId1={} nodeId2={}", c.shortChannelId, c.nodeId1, c.nodeId2) - if (d.channels.contains(c.shortChannelId)) { - sender ! TransportHandler.ReadAck(c) - log.debug("ignoring {} (duplicate)", c) - stay - } else if (d.awaiting.contains(c)) { - sender ! TransportHandler.ReadAck(c) - log.debug("ignoring {} (being verified)", c) - // adding the sender to the list of origins so that we don't send back the same announcement to this peer later - val origins = d.awaiting(c) :+ RemoteGossip(sender) - stay using d.copy(awaiting = d.awaiting + (c -> origins)) - } else if (db.isPruned(c.shortChannelId)) { - sender ! TransportHandler.ReadAck(c) - // channel was pruned and we haven't received a recent channel_update, so we have no reason to revalidate it - log.debug("ignoring {} (was pruned)", c) - stay - } else if (!Announcements.checkSigs(c)) { - sender ! TransportHandler.ReadAck(c) - log.warning("bad signature for announcement {}", c) - sender ! PeerConnection.InvalidSignature(c) - stay - } else { - log.info("validating shortChannelId={}", c.shortChannelId) - Kamon.runWithContextEntry(shortChannelIdKey, c.shortChannelId) { - Kamon.runWithSpan(Kamon.spanBuilder("validate-channel").tag("shortChannelId", c.shortChannelId.toString).start(), finishSpan = false) { - watcher ! ValidateRequest(c) - } - } - // we don't acknowledge the message just yet - stay using d.copy(awaiting = d.awaiting + (c -> Seq(RemoteGossip(sender)))) - } + case Event(WatchEventSpentBasic(e: BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT), d) if d.channels.contains(e.shortChannelId) => + stay using Validation.handleChannelSpent(d, nodeParams.db.network, e) case Event(n: NodeAnnouncement, d: Data) => - // it was sent by us, routing messages that are sent by our peers are wrapped in a PeerRoutingMessage - log.debug("received node announcement from {}", sender) - stay using handle(n, LocalGossip, d) + stay using Validation.handleNodeAnnouncement(d, nodeParams.db.network, Set(LocalGossip), n) - case Event(PeerRoutingMessage(_, _, n: NodeAnnouncement), d: Data) => - sender ! 
TransportHandler.ReadAck(n) - log.debug("received node announcement for nodeId={}", n.nodeId) - stay using handle(n, RemoteGossip(sender), d) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, n: NodeAnnouncement), d: Data) => + stay using Validation.handleNodeAnnouncement(d, nodeParams.db.network, Set(RemoteGossip(peerConnection, remoteNodeId)), n) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, routingMessage@QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks, extendedQueryFlags_opt)), d) => - sender ! TransportHandler.ReadAck(routingMessage) - Kamon.runWithContextEntry(remoteNodeIdKey, remoteNodeId.toString) { - Kamon.runWithSpan(Kamon.spanBuilder("query-channel-range").start(), finishSpan = true) { - log.info("received query_channel_range with firstBlockNum={} numberOfBlocks={} extendedQueryFlags_opt={}", firstBlockNum, numberOfBlocks, extendedQueryFlags_opt) - // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] - val shortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _)) - log.info("replying with {} items for range=({}, {})", shortChannelIds.size, firstBlockNum, numberOfBlocks) - val chunks = Kamon.runWithSpan(Kamon.spanBuilder("split-channel-ids").start(), finishSpan = true) { - split(shortChannelIds, firstBlockNum, numberOfBlocks, nodeParams.routerConf.channelRangeChunkSize) - } + case Event(u: ChannelUpdate, d: Data) => + stay using Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, Set(LocalGossip), u) - Kamon.runWithSpan(Kamon.spanBuilder("compute-timestamps-checksums").start(), finishSpan = true) { - chunks.foreach { chunk => - val reply = Router.buildReplyChannelRange(chunk, chainHash, nodeParams.routerConf.encodingType, routingMessage.queryFlags_opt, d.channels) - peerConnection ! reply - } - } - stay - } - } + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, u: ChannelUpdate), d) => + stay using Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, Set(RemoteGossip(peerConnection, remoteNodeId)), u) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, routingMessage@ReplyChannelRange(chainHash, _, _, _, shortChannelIds, _)), d) => - sender ! 
TransportHandler.ReadAck(routingMessage) + case Event(lcu: LocalChannelUpdate, d: Data) => + stay using Validation.handleLocalChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, nodeParams.nodeId, watcher, lcu) - Kamon.runWithContextEntry(remoteNodeIdKey, remoteNodeId.toString) { - Kamon.runWithSpan(Kamon.spanBuilder("reply-channel-range").start(), finishSpan = true) { + case Event(lcd: LocalChannelDown, d: Data) => + stay using Validation.handleLocalChannelDown(d, nodeParams.nodeId, lcd) - @tailrec - def loop(ids: List[ShortChannelId], timestamps: List[ReplyChannelRangeTlv.Timestamps], checksums: List[ReplyChannelRangeTlv.Checksums], acc: List[ShortChannelIdAndFlag] = List.empty[ShortChannelIdAndFlag]): List[ShortChannelIdAndFlag] = { - ids match { - case Nil => acc.reverse - case head :: tail => - val flag = computeFlag(d.channels)(head, timestamps.headOption, checksums.headOption, nodeParams.routerConf.requestNodeAnnouncements) - // 0 means nothing to query, just don't include it - val acc1 = if (flag != 0) ShortChannelIdAndFlag(head, flag) :: acc else acc - loop(tail, timestamps.drop(1), checksums.drop(1), acc1) - } - } + case Event(s: SendChannelQuery, d) => + stay using Sync.handleSendChannelQuery(d, s) - val timestamps_opt = routingMessage.timestamps_opt.map(_.timestamps).getOrElse(List.empty[ReplyChannelRangeTlv.Timestamps]) - val checksums_opt = routingMessage.checksums_opt.map(_.checksums).getOrElse(List.empty[ReplyChannelRangeTlv.Checksums]) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryChannelRange), d) => + Sync.handleQueryChannelRange(d.channels, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId), q) + stay - val shortChannelIdAndFlags = Kamon.runWithSpan(Kamon.spanBuilder("compute-flags").start(), finishSpan = true) { - loop(shortChannelIds.array, timestamps_opt, checksums_opt) - } + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, r: ReplyChannelRange), d) => + stay using Sync.handleReplyChannelRange(d, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId), r) - val (channelCount, updatesCount) = shortChannelIdAndFlags.foldLeft((0, 0)) { - case ((c, u), ShortChannelIdAndFlag(_, flag)) => - val c1 = c + (if (QueryShortChannelIdsTlv.QueryFlagType.includeChannelAnnouncement(flag)) 1 else 0) - val u1 = u + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate1(flag)) 1 else 0) + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate2(flag)) 1 else 0) - (c1, u1) - } - log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={}", shortChannelIds.array.size, channelCount, updatesCount, shortChannelIds.encoding) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryShortChannelIds), d) => + Sync.handleQueryShortChannelIds(d.nodes, d.channels, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId), q) + stay - def buildQuery(chunk: List[ShortChannelIdAndFlag]): QueryShortChannelIds = { - // always encode empty lists as UNCOMPRESSED - val encoding = if (chunk.isEmpty) EncodingType.UNCOMPRESSED else shortChannelIds.encoding - QueryShortChannelIds(chainHash, - shortChannelIds = EncodedShortChannelIds(encoding, chunk.map(_.shortChannelId)), - if (routingMessage.timestamps_opt.isDefined || routingMessage.checksums_opt.isDefined) - TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(encoding, chunk.map(_.flag))) - else - TlvStream.empty - ) - } - - // we update our sync data to this node (there may be multiple channel range 
responses and we can only query one set of ids at a time) - val replies = shortChannelIdAndFlags - .grouped(nodeParams.routerConf.channelQueryChunkSize) - .map(buildQuery) - .toList - - val (sync1, replynow_opt) = addToSync(d.sync, remoteNodeId, replies) - // we only send a reply right away if there were no pending requests - replynow_opt.foreach(peerConnection ! _) - val progress = syncProgress(sync1) - context.system.eventStream.publish(progress) - self ! progress - stay using d.copy(sync = sync1) - } - } - - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, routingMessage@QueryShortChannelIds(chainHash, shortChannelIds, _)), d) => - sender ! TransportHandler.ReadAck(routingMessage) - - Kamon.runWithContextEntry(remoteNodeIdKey, remoteNodeId.toString) { - Kamon.runWithSpan(Kamon.spanBuilder("query-short-channel-ids").start(), finishSpan = true) { - - val flags = routingMessage.queryFlags_opt.map(_.array).getOrElse(List.empty[Long]) - - var channelCount = 0 - var updateCount = 0 - var nodeCount = 0 - - Router.processChannelQuery(d.nodes, d.channels)( - shortChannelIds.array, - flags, - ca => { - channelCount = channelCount + 1 - peerConnection ! ca - }, - cu => { - updateCount = updateCount + 1 - peerConnection ! cu - }, - na => { - nodeCount = nodeCount + 1 - peerConnection ! na - } - ) - log.info("received query_short_channel_ids with {} items, sent back {} channels and {} updates and {} nodes", shortChannelIds.array.size, channelCount, updateCount, nodeCount) - peerConnection ! ReplyShortChannelIdsEnd(chainHash, 1) - stay - } - } - - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, routingMessage: ReplyShortChannelIdsEnd), d) => - sender ! TransportHandler.ReadAck(routingMessage) - // have we more channels to ask this peer? - val sync1 = d.sync.get(remoteNodeId) match { - case Some(sync) => - sync.pending match { - case nextRequest +: rest => - log.info(s"asking for the next slice of short_channel_ids (remaining=${sync.pending.size}/${sync.total})") - peerConnection ! nextRequest - d.sync + (remoteNodeId -> sync.copy(pending = rest)) - case Nil => - // we received reply_short_channel_ids_end for our last query and have not sent another one, we can now remove - // the remote peer from our map - log.info(s"sync complete (total=${sync.total})") - d.sync - remoteNodeId - } - case _ => d.sync - } - val progress = syncProgress(sync1) - context.system.eventStream.publish(progress) - self ! progress - stay using d.copy(sync = sync1) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, r: ReplyShortChannelIdsEnd), d) => + stay using Sync.handleReplyShortChannelIdsEnd(d, RemoteGossip(peerConnection, remoteNodeId), r) } initialize() - def handle(n: NodeAnnouncement, origin: GossipOrigin, d: Data): Data = - if (d.stash.nodes.contains(n)) { - log.debug("ignoring {} (already stashed)", n) - val origins = d.stash.nodes(n) + origin - d.copy(stash = d.stash.copy(nodes = d.stash.nodes + (n -> origins))) - } else if (d.rebroadcast.nodes.contains(n)) { - log.debug("ignoring {} (pending rebroadcast)", n) - val origins = d.rebroadcast.nodes(n) + origin - d.copy(rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> origins))) - } else if (d.nodes.contains(n.nodeId) && d.nodes(n.nodeId).timestamp >= n.timestamp) { - log.debug("ignoring {} (duplicate)", n) - d - } else if (!Announcements.checkSig(n)) { - log.warning("bad signature for {}", n) - origin match { - case RemoteGossip(peer) => peer ! 
PeerConnection.InvalidSignature(n) - case LocalGossip => - } - d - } else if (d.nodes.contains(n.nodeId)) { - log.debug("updated node nodeId={}", n.nodeId) - context.system.eventStream.publish(NodeUpdated(n)) - db.updateNode(n) - d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> Set(origin)))) - } else if (d.channels.values.exists(c => isRelatedTo(c.ann, n.nodeId))) { - log.debug("added node nodeId={}", n.nodeId) - context.system.eventStream.publish(NodesDiscovered(n :: Nil)) - db.addNode(n) - d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> Set(origin)))) - } else if (d.awaiting.keys.exists(c => isRelatedTo(c, n.nodeId))) { - log.debug("stashing {}", n) - d.copy(stash = d.stash.copy(nodes = d.stash.nodes + (n -> Set(origin)))) - } else { - log.debug("ignoring {} (no related channel found)", n) - // there may be a record if we have just restarted - db.removeNode(n.nodeId) - d - } - - def handle(u: ChannelUpdate, origin: GossipOrigin, d: Data, remoteNodeId_opt: Option[PublicKey] = None, peerConnection_opt: Option[ActorRef] = None): Data = - if (d.channels.contains(u.shortChannelId)) { - // related channel is already known (note: this means no related channel_update is in the stash) - val publicChannel = true - val pc = d.channels(u.shortChannelId) - val desc = getDesc(u, pc.ann) - if (d.rebroadcast.updates.contains(u)) { - log.debug("ignoring {} (pending rebroadcast)", u) - val origins = d.rebroadcast.updates(u) + origin - d.copy(rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins))) - } else if (isStale(u)) { - log.debug("ignoring {} (stale)", u) - d - } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { - log.debug("ignoring {} (duplicate)", u) - d - } else if (!Announcements.checkSig(u, pc.getNodeIdSameSideAs(u))) { - log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) - origin match { - case RemoteGossip(peer) => peer ! 
PeerConnection.InvalidSignature(u) - case LocalGossip => - } - d - } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { - log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) - context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - db.updateChannel(u) - // update the graph - val graph1 = Announcements.isEnabled(u.channelFlags) match { - case true => d.graph.removeEdge(desc).addEdge(desc, u) - case false => d.graph.removeEdge(desc) // if the channel is now disabled, we remove it from the graph - } - d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) - } else { - log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) - context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - db.updateChannel(u) - // we also need to update the graph - val graph1 = d.graph.addEdge(desc, u) - d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), privateChannels = d.privateChannels - u.shortChannelId, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) - } - } else if (d.awaiting.keys.exists(c => c.shortChannelId == u.shortChannelId)) { - // channel is currently being validated - if (d.stash.updates.contains(u)) { - log.debug("ignoring {} (already stashed)", u) - val origins = d.stash.updates(u) + origin - d.copy(stash = d.stash.copy(updates = d.stash.updates + (u -> origins))) - } else { - log.debug("stashing {}", u) - d.copy(stash = d.stash.copy(updates = d.stash.updates + (u -> Set(origin)))) - } - } else if (d.privateChannels.contains(u.shortChannelId)) { - val publicChannel = false - val pc = d.privateChannels(u.shortChannelId) - val desc = if (Announcements.isNode1(u.channelFlags)) ChannelDesc(u.shortChannelId, pc.nodeId1, pc.nodeId2) else ChannelDesc(u.shortChannelId, pc.nodeId2, pc.nodeId1) - if (isStale(u)) { - log.debug("ignoring {} (stale)", u) - d - } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { - log.debug("ignoring {} (already know same or newer)", u) - d - } else if (!Announcements.checkSig(u, desc.a)) { - log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) - origin match { - case RemoteGossip(peer) => peer ! 
PeerConnection.InvalidSignature(u) - case LocalGossip => - } - d - } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { - log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) - context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - // we also need to update the graph - val graph1 = d.graph.removeEdge(desc).addEdge(desc, u) - d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) - } else { - log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) - context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - // we also need to update the graph - val graph1 = d.graph.addEdge(desc, u) - d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) - } - } else if (db.isPruned(u.shortChannelId) && !isStale(u)) { - // the channel was recently pruned, but if we are here, it means that the update is not stale so this is the case - // of a zombie channel coming back from the dead. they probably sent us a channel_announcement right before this update, - // but we ignored it because the channel was in the 'pruned' list. Now that we know that the channel is alive again, - // let's remove the channel from the zombie list and ask the sender to re-send announcements (channel_announcement + updates) - // about that channel. We can ignore this update since we will receive it again - log.info(s"channel shortChannelId=${u.shortChannelId} is back from the dead! requesting announcements about this channel") - db.removeFromPruned(u.shortChannelId) - - // peerConnection_opt will contain a valid peerConnection only when we're handling an update that we received from a peer, not - // when we're sending updates to ourselves - (peerConnection_opt, remoteNodeId_opt) match { - case (Some(peerConnection), Some(remoteNodeId)) => - val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(nodeParams.routerConf.encodingType, List(u.shortChannelId)), TlvStream.empty) - d.sync.get(remoteNodeId) match { - case Some(sync) => - // we already have a pending request to that node, let's add this channel to the list and we'll get it later - // TODO: we only request channels with old style channel_query - d.copy(sync = d.sync + (remoteNodeId -> sync.copy(pending = sync.pending :+ query, total = sync.total + 1))) - case None => - // we send the query right away - peerConnection ! query - d.copy(sync = d.sync + (remoteNodeId -> Sync(pending = Nil, total = 1))) - } - case _ => - // we don't know which node this update came from (maybe it was stashed and the channel got pruned in the meantime or some other corner case). - // or we don't have a peerConnection to send our query to. 
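
// ---- Editor's note: a hedged sketch (hypothetical helper, not in the patch) of the
// ---- queue-or-send rule used just above when a zombie channel comes back from the dead,
// ---- written against the renamed `Syncing` state defined later in the companion object:
// ---- if a sync with this peer is already in flight we append the query, otherwise we
// ---- send it right away and record a fresh sync entry.
import akka.actor.ActorRef
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.router.Router.Syncing
import fr.acinq.eclair.wire.QueryShortChannelIds

def queueOrSend(sync: Map[PublicKey, Syncing], remoteNodeId: PublicKey,
                peerConnection: ActorRef, query: QueryShortChannelIds): Map[PublicKey, Syncing] =
  sync.get(remoteNodeId) match {
    case Some(s) => // a query is in flight, piggy-back on it and wait for the 'end' message
      sync + (remoteNodeId -> s.copy(pending = s.pending :+ query, total = s.total + 1))
    case None => // no pending request with this peer, query immediately
      peerConnection ! query
      sync + (remoteNodeId -> Syncing(pending = Nil, total = 1))
  }
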
- // anyway, that's not really a big deal because we have removed the channel from the pruned db so next time it shows up we will revalidate it - d - } - } else { - log.debug("ignoring announcement {} (unknown channel)", u) - d - } - override def mdc(currentMessage: Any): MDC = { val category_opt = LogCategory(currentMessage) currentMessage match { - case SendChannelQuery(remoteNodeId, _, _) => Logs.mdc(category_opt, remoteNodeId_opt = Some(remoteNodeId)) - case PeerRoutingMessage(_, remoteNodeId, _) => Logs.mdc(category_opt, remoteNodeId_opt = Some(remoteNodeId)) - case LocalChannelUpdate(_, _, _, remoteNodeId, _, _, _) => Logs.mdc(category_opt, remoteNodeId_opt = Some(remoteNodeId)) + case s: SendChannelQuery => Logs.mdc(category_opt, remoteNodeId_opt = Some(s.remoteNodeId)) + case prm: PeerRoutingMessage => Logs.mdc(category_opt, remoteNodeId_opt = Some(prm.remoteNodeId)) + case lcu: LocalChannelUpdate => Logs.mdc(category_opt, remoteNodeId_opt = Some(lcu.remoteNodeId)) case _ => Logs.mdc(category_opt) } } @@ -959,36 +245,179 @@ object Router { val shortChannelIdKey = Context.key[ShortChannelId]("shortChannelId", ShortChannelId(0)) val remoteNodeIdKey = Context.key[String]("remoteNodeId", "unknown") - // maximum number of ids we can keep in a single chunk and still have an encoded reply that is smaller than 65Kb - // please note that: - // - this is based on the worst case scenario where peer want timestamps and checksums and the reply is not compressed - // - the maximum number of public channels in a single block so far is less than 300, and the maximum number of tx per block - // almost never exceeds 2800 so this is not a real limitation yet - val MAXIMUM_CHUNK_SIZE = 3200 - def props(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) = Props(new Router(nodeParams, watcher, initialized)) - def toFakeUpdate(extraHop: ExtraHop, htlcMaximum: MilliSatoshi): ChannelUpdate = { - // the `direction` bit in flags will not be accurate but it doesn't matter because it is not used - // what matters is that the `disable` bit is 0 so that this update doesn't get filtered out - ChannelUpdate(signature = ByteVector64.Zeroes, chainHash = ByteVector32.Zeroes, extraHop.shortChannelId, Platform.currentTime.milliseconds.toSeconds, messageFlags = 1, channelFlags = 0, extraHop.cltvExpiryDelta, htlcMinimumMsat = 0 msat, extraHop.feeBase, extraHop.feeProportionalMillionths, Some(htlcMaximum)) + case class RouterConf(randomizeRouteSelection: Boolean, + channelExcludeDuration: FiniteDuration, + routerBroadcastInterval: FiniteDuration, + networkStatsRefreshInterval: FiniteDuration, + requestNodeAnnouncements: Boolean, + encodingType: EncodingType, + channelRangeChunkSize: Int, + channelQueryChunkSize: Int, + searchMaxFeeBase: Satoshi, + searchMaxFeePct: Double, + searchMaxRouteLength: Int, + searchMaxCltv: CltvExpiryDelta, + searchHeuristicsEnabled: Boolean, + searchRatioCltv: Double, + searchRatioChannelAge: Double, + searchRatioChannelCapacity: Double) + + // @formatter:off + case class ChannelDesc(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey) + case class PublicChannel(ann: ChannelAnnouncement, fundingTxid: ByteVector32, capacity: Satoshi, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { + update_1_opt.foreach(u => assert(Announcements.isNode1(u.channelFlags))) + update_2_opt.foreach(u => assert(!Announcements.isNode1(u.channelFlags))) + + def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if 
(Announcements.isNode1(u.channelFlags)) ann.nodeId1 else ann.nodeId2 + + def getChannelUpdateSameSideAs(u: ChannelUpdate): Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt + + def updateChannelUpdateSameSideAs(u: ChannelUpdate): PublicChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) + } + case class PrivateChannel(localNodeId: PublicKey, remoteNodeId: PublicKey, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { + val (nodeId1, nodeId2) = if (Announcements.isNode1(localNodeId, remoteNodeId)) (localNodeId, remoteNodeId) else (remoteNodeId, localNodeId) + + def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if (Announcements.isNode1(u.channelFlags)) nodeId1 else nodeId2 + + def getChannelUpdateSameSideAs(u: ChannelUpdate): Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt + + def updateChannelUpdateSameSideAs(u: ChannelUpdate): PrivateChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) + } + // @formatter:on + + case class AssistedChannel(extraHop: ExtraHop, nextNodeId: PublicKey, htlcMaximum: MilliSatoshi) + + trait Hop { + /** @return the id of the start node. */ + def nodeId: PublicKey + + /** @return the id of the end node. */ + def nextNodeId: PublicKey + + /** + * @param amount amount to be forwarded. + * @return total fee required by the current hop. + */ + def fee(amount: MilliSatoshi): MilliSatoshi + + /** @return cltv delta required by the current hop. */ + def cltvExpiryDelta: CltvExpiryDelta } - def toAssistedChannels(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey, amount: MilliSatoshi): Map[ShortChannelId, AssistedChannel] = { - // BOLT 11: "For each entry, the pubkey is the node ID of the start of the channel", and the last node is the destination - // The invoice doesn't explicitly specify the channel's htlcMaximumMsat, but we can safely assume that the channel - // should be able to route the payment, so we'll compute an htlcMaximumMsat accordingly. - // We could also get the channel capacity from the blockchain (since we have the shortChannelId) but that's more expensive. - // We also need to make sure the channel isn't excluded by our heuristics. - val lastChannelCapacity = amount.max(RoutingHeuristics.CAPACITY_CHANNEL_LOW) - val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId - extraRoute.zip(nextNodeIds).reverse.foldLeft((lastChannelCapacity, Map.empty[ShortChannelId, AssistedChannel])) { - case ((amount, acs), (extraHop: ExtraHop, nextNodeId)) => - val nextAmount = amount + nodeFee(extraHop.feeBase, extraHop.feeProportionalMillionths, amount) - (nextAmount, acs + (extraHop.shortChannelId -> AssistedChannel(extraHop, nextNodeId, nextAmount))) - }._2 + /** + * A directed hop between two connected nodes using a specific channel. + * + * @param nodeId id of the start node. + * @param nextNodeId id of the end node. + * @param lastUpdate last update of the channel used for the hop. + */ + case class ChannelHop(nodeId: PublicKey, nextNodeId: PublicKey, lastUpdate: ChannelUpdate) extends Hop { + override lazy val cltvExpiryDelta: CltvExpiryDelta = lastUpdate.cltvExpiryDelta + + override def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(lastUpdate.feeBaseMsat, lastUpdate.feeProportionalMillionths, amount) } + /** + * A directed hop between two trampoline nodes. 
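
// ---- Editor's note: a usage sketch (not part of the patch) showing how `Hop.fee` composes
// ---- along a route: walking backwards from the destination, each forwarding hop adds its
// ---- fee on the amount it must forward. The first hop is skipped here because the sender
// ---- pays no fee to itself; this is an assumption about the caller, which may treat the
// ---- first hop differently.
import fr.acinq.eclair.MilliSatoshi
import fr.acinq.eclair.router.Router.Hop

def amountAtSource(hops: Seq[Hop], amountAtDestination: MilliSatoshi): MilliSatoshi =
  hops.drop(1).foldRight(amountAtDestination) { (hop, amount) => amount + hop.fee(amount) }
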
+ * These nodes need not be connected and we don't need to know a route between them. + * The start node will compute the route to the end node itself when it receives our payment. + * + * @param nodeId id of the start node. + * @param nextNodeId id of the end node. + * @param cltvExpiryDelta cltv expiry delta. + * @param fee total fee for that hop. + */ + case class NodeHop(nodeId: PublicKey, nextNodeId: PublicKey, cltvExpiryDelta: CltvExpiryDelta, fee: MilliSatoshi) extends Hop { + override def fee(amount: MilliSatoshi): MilliSatoshi = fee + } + + case class RouteParams(randomize: Boolean, maxFeeBase: MilliSatoshi, maxFeePct: Double, routeMaxLength: Int, routeMaxCltv: CltvExpiryDelta, ratios: Option[WeightRatios]) + + case class RouteRequest(source: PublicKey, + target: PublicKey, + amount: MilliSatoshi, + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + ignoreNodes: Set[PublicKey] = Set.empty, + ignoreChannels: Set[ChannelDesc] = Set.empty, + routeParams: Option[RouteParams] = None) + + case class FinalizeRoute(hops: Seq[PublicKey], assistedRoutes: Seq[Seq[ExtraHop]] = Nil) + + case class RouteResponse(hops: Seq[ChannelHop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc], allowEmpty: Boolean = false) { + require(allowEmpty || hops.nonEmpty, "route cannot be empty") + } + + // @formatter:off + /** This is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) */ + case class ExcludeChannel(desc: ChannelDesc) + case class LiftChannelExclusion(desc: ChannelDesc) + // @formatter:on + + // @formatter:off + case class SendChannelQuery(chainHash: ByteVector32, remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[QueryChannelRangeTlv]) + case object GetNetworkStats + case class GetNetworkStatsResponse(stats: Option[NetworkStats]) + case object GetRoutingState + case class RoutingState(channels: Iterable[PublicChannel], nodes: Iterable[NodeAnnouncement]) + case object GetRoutingStateStreaming + case object RoutingStateStreamingUpToDate + // @formatter:on + + // @formatter:off + sealed trait GossipOrigin + /** Gossip that we received from a remote peer. */ + case class RemoteGossip(peerConnection: ActorRef, nodeId: PublicKey) extends GossipOrigin + /** Gossip that was generated by our node. 
*/ + case object LocalGossip extends GossipOrigin + + sealed trait GossipDecision { def ann: AnnouncementMessage } + object GossipDecision { + case class Accepted(ann: AnnouncementMessage) extends GossipDecision + + sealed trait Rejected extends GossipDecision + case class Duplicate(ann: AnnouncementMessage) extends Rejected + case class InvalidSignature(ann: AnnouncementMessage) extends Rejected + case class NoKnownChannel(ann: NodeAnnouncement) extends Rejected + case class ValidationFailure(ann: ChannelAnnouncement) extends Rejected + case class InvalidAnnouncement(ann: ChannelAnnouncement) extends Rejected + case class ChannelPruned(ann: ChannelAnnouncement) extends Rejected + case class ChannelClosing(ann: ChannelAnnouncement) extends Rejected + case class ChannelClosed(ann: ChannelAnnouncement) extends Rejected + case class Stale(ann: ChannelUpdate) extends Rejected + case class NoRelatedChannel(ann: ChannelUpdate) extends Rejected + } + + case class Stash(updates: Map[ChannelUpdate, Set[GossipOrigin]], nodes: Map[NodeAnnouncement, Set[GossipOrigin]]) + case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[GossipOrigin]], updates: Map[ChannelUpdate, Set[GossipOrigin]], nodes: Map[NodeAnnouncement, Set[GossipOrigin]]) + // @formatter:on + + case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) + + case class Syncing(pending: List[RoutingMessage], total: Int) + + case class Data(nodes: Map[PublicKey, NodeAnnouncement], + channels: SortedMap[ShortChannelId, PublicChannel], + stats: Option[NetworkStats], + stash: Stash, + rebroadcast: Rebroadcast, + awaiting: Map[ChannelAnnouncement, Seq[RemoteGossip]], // note: this is a seq because we want to preserve order: first actor is the one who we need to send a tcp-ack when validation is done + privateChannels: Map[ShortChannelId, PrivateChannel], // short_channel_id -> node_id + excludedChannels: Set[ChannelDesc], // those channels are temporarily excluded from route calculation, because their node returned a TemporaryChannelFailure + graph: DirectedGraph, + sync: Map[PublicKey, Syncing] // keep tracks of channel range queries sent to each peer. 
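
// ---- Editor's note: illustrative only (hypothetical helper, not in the patch). Per the
// ---- commit message, all gossip is now acknowledged with a GossipDecision; the hierarchy
// ---- above makes a pattern like this possible, where every announcement's outcome is
// ---- reported to the peer connection(s) it came from, whether accepted or rejected.
import akka.actor.ActorRef
import fr.acinq.eclair.router.Router.{GossipDecision, GossipOrigin, LocalGossip, RemoteGossip}

def sendDecision(origins: Set[GossipOrigin], decision: GossipDecision): Unit =
  origins.foreach {
    case RemoteGossip(peerConnection, _) => peerConnection ! decision // remote peers get an explicit ack
    case LocalGossip => () // nothing to acknowledge for our own gossip
  }
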
If there is an entry in the map, it means that there is an ongoing query for which we have not yet received an 'end' message + ) + + // @formatter:off + sealed trait State + case object NORMAL extends State + + case object TickBroadcast + case object TickPruneStaleChannels + case object TickComputeNetworkStats + // @formatter:on + def getDesc(u: ChannelUpdate, channel: ChannelAnnouncement): ChannelDesc = { // the least significant bit tells us if it is node1 or node2 if (Announcements.isNode1(u.channelFlags)) ChannelDesc(u.shortChannelId, channel.nodeId1, channel.nodeId2) else ChannelDesc(u.shortChannelId, channel.nodeId2, channel.nodeId1) @@ -997,449 +426,4 @@ object Router { def isRelatedTo(c: ChannelAnnouncement, nodeId: PublicKey) = nodeId == c.nodeId1 || nodeId == c.nodeId2 def hasChannels(nodeId: PublicKey, channels: Iterable[PublicChannel]): Boolean = channels.exists(c => isRelatedTo(c.ann, nodeId)) - - def isStale(u: ChannelUpdate): Boolean = isStale(u.timestamp) - - def isStale(timestamp: Long): Boolean = { - // BOLT 7: "nodes MAY prune channels should the timestamp of the latest channel_update be older than 2 weeks" - // but we don't want to prune brand new channels for which we didn't yet receive a channel update - val staleThresholdSeconds = (Platform.currentTime.milliseconds - 14.days).toSeconds - timestamp < staleThresholdSeconds - } - - def isAlmostStale(timestamp: Long): Boolean = { - // we define almost stale as 2 weeks minus 4 days - val staleThresholdSeconds = (Platform.currentTime.milliseconds - 10.days).toSeconds - timestamp < staleThresholdSeconds - } - - /** - * Is stale a channel that: - * (1) is older than 2 weeks (2*7*144 = 2016 blocks) - * AND - * (2) has no channel_update younger than 2 weeks - * - * @param update1_opt update corresponding to one side of the channel, if we have it - * @param update2_opt update corresponding to the other side of the channel, if we have it - * @return - */ - def isStale(channel: ChannelAnnouncement, update1_opt: Option[ChannelUpdate], update2_opt: Option[ChannelUpdate], currentBlockHeight: Long): Boolean = { - // BOLT 7: "nodes MAY prune channels should the timestamp of the latest channel_update be older than 2 weeks (1209600 seconds)" - // but we don't want to prune brand new channels for which we didn't yet receive a channel update, so we keep them as long as they are less than 2 weeks (2016 blocks) old - val staleThresholdBlocks = currentBlockHeight - 2016 - val TxCoordinates(blockHeight, _, _) = ShortChannelId.coordinates(channel.shortChannelId) - blockHeight < staleThresholdBlocks && update1_opt.forall(isStale) && update2_opt.forall(isStale) - } - - def getStaleChannels(channels: Iterable[PublicChannel], currentBlockHeight: Long): Iterable[PublicChannel] = channels.filter(data => isStale(data.ann, data.update_1_opt, data.update_2_opt, currentBlockHeight)) - - /** - * Filters channels that we want to send to nodes asking for a channel range - */ - def keep(firstBlockNum: Long, numberOfBlocks: Long, id: ShortChannelId): Boolean = { - val height = id.blockHeight - height >= firstBlockNum && height < (firstBlockNum + numberOfBlocks) - } - - def shouldRequestUpdate(ourTimestamp: Long, ourChecksum: Long, theirTimestamp_opt: Option[Long], theirChecksum_opt: Option[Long]): Boolean = { - (theirTimestamp_opt, theirChecksum_opt) match { - case (Some(theirTimestamp), Some(theirChecksum)) => - // we request their channel_update if all those conditions are met: - // - it is more recent than ours - // - it is different from ours, or it is 
the same but ours is about to be stale - // - it is not stale - val theirsIsMoreRecent = ourTimestamp < theirTimestamp - val areDifferent = ourChecksum != theirChecksum - val oursIsAlmostStale = isAlmostStale(ourTimestamp) - val theirsIsStale = isStale(theirTimestamp) - theirsIsMoreRecent && (areDifferent || oursIsAlmostStale) && !theirsIsStale - case (Some(theirTimestamp), None) => - // if we only have their timestamp, we request their channel_update if theirs is more recent than ours - val theirsIsMoreRecent = ourTimestamp < theirTimestamp - val theirsIsStale = isStale(theirTimestamp) - theirsIsMoreRecent && !theirsIsStale - case (None, Some(theirChecksum)) => - // if we only have their checksum, we request their channel_update if it is different from ours - // NB: a zero checksum means that they don't have the data - val areDifferent = theirChecksum != 0 && ourChecksum != theirChecksum - areDifferent - case (None, None) => - // if we have neither their timestamp nor their checksum we request their channel_update - true - } - } - - def computeFlag(channels: SortedMap[ShortChannelId, PublicChannel])( - shortChannelId: ShortChannelId, - theirTimestamps_opt: Option[ReplyChannelRangeTlv.Timestamps], - theirChecksums_opt: Option[ReplyChannelRangeTlv.Checksums], - includeNodeAnnouncements: Boolean): Long = { - import QueryShortChannelIdsTlv.QueryFlagType._ - - val flagsNodes = if (includeNodeAnnouncements) INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2 else 0 - - val flags = if (!channels.contains(shortChannelId)) { - INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 - } else { - // we already know this channel - val (ourTimestamps, ourChecksums) = Router.getChannelDigestInfo(channels)(shortChannelId) - // if they don't provide timestamps or checksums, we set appropriate default values: - // - we assume their timestamp is more recent than ours by setting timestamp = Long.MaxValue - // - we assume their update is different from ours by setting checkum = Long.MaxValue (NB: our default value for checksum is 0) - val shouldRequestUpdate1 = shouldRequestUpdate(ourTimestamps.timestamp1, ourChecksums.checksum1, theirTimestamps_opt.map(_.timestamp1), theirChecksums_opt.map(_.checksum1)) - val shouldRequestUpdate2 = shouldRequestUpdate(ourTimestamps.timestamp2, ourChecksums.checksum2, theirTimestamps_opt.map(_.timestamp2), theirChecksums_opt.map(_.checksum2)) - val flagUpdate1 = if (shouldRequestUpdate1) INCLUDE_CHANNEL_UPDATE_1 else 0 - val flagUpdate2 = if (shouldRequestUpdate2) INCLUDE_CHANNEL_UPDATE_2 else 0 - flagUpdate1 | flagUpdate2 - } - - if (flags == 0) 0 else flags | flagsNodes - } - - /** - * Handle a query message, which includes a list of channel ids and flags. - * - * @param nodes node id -> node announcement - * @param channels channel id -> channel announcement + updates - * @param ids list of channel ids - * @param flags list of query flags, either empty one flag per channel id - * @param onChannel called when a channel announcement matches (i.e. 
its bit is set in the query flag and we have it) - * @param onUpdate called when a channel update matches - * @param onNode called when a node announcement matches - * - */ - def processChannelQuery(nodes: Map[PublicKey, NodeAnnouncement], - channels: SortedMap[ShortChannelId, PublicChannel])( - ids: List[ShortChannelId], - flags: List[Long], - onChannel: ChannelAnnouncement => Unit, - onUpdate: ChannelUpdate => Unit, - onNode: NodeAnnouncement => Unit)(implicit log: LoggingAdapter): Unit = { - import QueryShortChannelIdsTlv.QueryFlagType - - // we loop over channel ids and query flag. We track node Ids for node announcement - // we've already sent to avoid sending them multiple times, as requested by the BOLTs - @tailrec - def loop(ids: List[ShortChannelId], flags: List[Long], numca: Int = 0, numcu: Int = 0, nodesSent: Set[PublicKey] = Set.empty[PublicKey]): (Int, Int, Int) = ids match { - case Nil => (numca, numcu, nodesSent.size) - case head :: tail if !channels.contains(head) => - log.warning("received query for shortChannelId={} that we don't have", head) - loop(tail, flags.drop(1), numca, numcu, nodesSent) - case head :: tail => - val numca1 = numca - val numcu1 = numcu - var sent1 = nodesSent - val pc = channels(head) - val flag_opt = flags.headOption - // no flag means send everything - - val includeChannel = flag_opt.forall(QueryFlagType.includeChannelAnnouncement) - val includeUpdate1 = flag_opt.forall(QueryFlagType.includeUpdate1) - val includeUpdate2 = flag_opt.forall(QueryFlagType.includeUpdate2) - val includeNode1 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement1) - val includeNode2 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement2) - - if (includeChannel) { - onChannel(pc.ann) - } - if (includeUpdate1) { - pc.update_1_opt.foreach { u => - onUpdate(u) - } - } - if (includeUpdate2) { - pc.update_2_opt.foreach { u => - onUpdate(u) - } - } - if (includeNode1 && !sent1.contains(pc.ann.nodeId1)) { - nodes.get(pc.ann.nodeId1).foreach { n => - onNode(n) - sent1 = sent1 + pc.ann.nodeId1 - } - } - if (includeNode2 && !sent1.contains(pc.ann.nodeId2)) { - nodes.get(pc.ann.nodeId2).foreach { n => - onNode(n) - sent1 = sent1 + pc.ann.nodeId2 - } - } - loop(tail, flags.drop(1), numca1, numcu1, sent1) - } - - loop(ids, flags) - } - - /** - * Returns overall progress on synchronization - * - * @return a sync progress indicator (1 means fully synced) - */ - def syncProgress(sync: Map[PublicKey, Sync]): SyncProgress = { - // NB: progress is in terms of requests, not individual channels - val (pending, total) = sync.foldLeft((0, 0)) { - case ((p, t), (_, sync)) => (p + sync.pending.size, t + sync.total) - } - if (total == 0) { - SyncProgress(1) - } else { - SyncProgress((total - pending) / (1.0 * total)) - } - } - - /** - * This method is used after a payment failed, and we want to exclude some nodes that we know are failing - */ - def getIgnoredChannelDesc(channels: Map[ShortChannelId, PublicChannel], ignoreNodes: Set[PublicKey]): Iterable[ChannelDesc] = { - val desc = if (ignoreNodes.isEmpty) { - Iterable.empty[ChannelDesc] - } else { - // expensive, but node blacklisting shouldn't happen often - channels.values - .filter(channelData => ignoreNodes.contains(channelData.ann.nodeId1) || ignoreNodes.contains(channelData.ann.nodeId2)) - .flatMap(channelData => Vector(ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId1, channelData.ann.nodeId2), ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId2, channelData.ann.nodeId1))) - } - desc - } - - def 
getChannelDigestInfo(channels: SortedMap[ShortChannelId, PublicChannel])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = { - val c = channels(shortChannelId) - val timestamp1 = c.update_1_opt.map(_.timestamp).getOrElse(0L) - val timestamp2 = c.update_2_opt.map(_.timestamp).getOrElse(0L) - val checksum1 = c.update_1_opt.map(getChecksum).getOrElse(0L) - val checksum2 = c.update_2_opt.map(getChecksum).getOrElse(0L) - (ReplyChannelRangeTlv.Timestamps(timestamp1 = timestamp1, timestamp2 = timestamp2), ReplyChannelRangeTlv.Checksums(checksum1 = checksum1, checksum2 = checksum2)) - } - - def crc32c(data: ByteVector): Long = { - import com.google.common.hash.Hashing - Hashing.crc32c().hashBytes(data.toArray).asInt() & 0xFFFFFFFFL - } - - def getChecksum(u: ChannelUpdate): Long = { - import u._ - - val data = serializationResult(LightningMessageCodecs.channelUpdateChecksumCodec.encode(chainHash :: shortChannelId :: messageFlags :: channelFlags :: cltvExpiryDelta :: htlcMinimumMsat :: feeBaseMsat :: feeProportionalMillionths :: htlcMaximumMsat :: HNil)) - crc32c(data) - } - - case class ShortChannelIdsChunk(firstBlock: Long, numBlocks: Long, shortChannelIds: List[ShortChannelId]) { - /** - * - * @param maximumSize maximum size of the short channel ids list - * @return a chunk with at most `maximumSize` ids - */ - def enforceMaximumSize(maximumSize: Int) = { - if (shortChannelIds.size <= maximumSize) this else { - // we use a random offset here, so even if shortChannelIds.size is much bigger than maximumSize (which should - // not happen) peers will eventually receive info about all channels in this chunk - val offset = Random.nextInt(shortChannelIds.size - maximumSize + 1) - this.copy(shortChannelIds = this.shortChannelIds.slice(offset, offset + maximumSize)) - } - } - } - - /** - * Split short channel ids into chunks, because otherwise message could be too big - * there could be several reply_channel_range messages for a single query, but we make sure that the returned - * chunks fully covers the [firstBlockNum, numberOfBlocks] range that was requested - * - * @param shortChannelIds list of short channel ids to split - * @param firstBlockNum first block height requested by our peers - * @param numberOfBlocks number of blocks requested by our peer - * @param channelRangeChunkSize target chunk size. 
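
// ---- Editor's note: a property-style sketch (test helper, not in the patch) of the coverage
// ---- invariant documented in this scaladoc: the chunks returned by `split` start at the
// ---- requested first block, are contiguous, and together span exactly the requested range.
// ---- `Chunk` is a stand-in mirroring ShortChannelIdsChunk's (firstBlock, numBlocks) fields.
case class Chunk(firstBlock: Long, numBlocks: Long)

def coversRequestedRange(chunks: List[Chunk], firstBlockNum: Long, numberOfBlocks: Long): Boolean =
  chunks.headOption.forall(_.firstBlock == firstBlockNum) &&
    chunks.lastOption.forall(c => c.firstBlock + c.numBlocks == firstBlockNum + numberOfBlocks) &&
    chunks.sliding(2).forall {
      case List(a, b) => a.firstBlock + a.numBlocks == b.firstBlock // contiguous: no gaps, no overlaps
      case _ => true // zero or one chunk is trivially contiguous
    }
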
All ids that have the same block height will be grouped together, so - * returned chunks may still contain more than `channelRangeChunkSize` elements - * @return a list of short channel id chunks - */ - def split(shortChannelIds: SortedSet[ShortChannelId], firstBlockNum: Long, numberOfBlocks: Long, channelRangeChunkSize: Int): List[ShortChannelIdsChunk] = { - // see BOLT7: MUST encode a short_channel_id for every open channel it knows in blocks first_blocknum to first_blocknum plus number_of_blocks minus one - val it = shortChannelIds.iterator.dropWhile(_.blockHeight < firstBlockNum).takeWhile(_.blockHeight < firstBlockNum + numberOfBlocks) - if (it.isEmpty) { - List(ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, List.empty)) - } else { - // we want to split ids in different chunks, with the following rules by order of priority - // ids that have the same block height must be grouped in the same chunk - // chunk should contain `channelRangeChunkSize` ids - @tailrec - def loop(currentChunk: List[ShortChannelId], acc: List[ShortChannelIdsChunk]): List[ShortChannelIdsChunk] = { - if (it.hasNext) { - val id = it.next() - val currentHeight = currentChunk.head.blockHeight - if (id.blockHeight == currentHeight) - loop(id :: currentChunk, acc) // same height => always add to the current chunk - else if (currentChunk.size < channelRangeChunkSize) // different height but we're under the size target => add to the current chunk - loop(id :: currentChunk, acc) // different height and over the size target => start a new chunk - else { - // we always prepend because it's more efficient so we have to reverse the current chunk - // for the first chunk, we make sure that we start at the request first block - // for the next chunks we start at the end of the range covered by the last chunk - val first = if (acc.isEmpty) firstBlockNum else acc.head.firstBlock + acc.head.numBlocks - val count = currentChunk.head.blockHeight - first + 1 - loop(id :: Nil, ShortChannelIdsChunk(first, count, currentChunk.reverse) :: acc) - } - } - else { - // for the last chunk, we make sure that we cover the requested block range - val first = if (acc.isEmpty) firstBlockNum else acc.head.firstBlock + acc.head.numBlocks - val count = numberOfBlocks - first + firstBlockNum - (ShortChannelIdsChunk(first, count, currentChunk.reverse) :: acc).reverse - } - } - - val first = it.next() - val chunks = loop(first :: Nil, Nil) - - // make sure that all our chunks match our max size policy - enforceMaximumSize(chunks) - } - } - - /** - * Enforce max-size constraints for each chunk - * - * @param chunks list of short channel id chunks - * @return a processed list of chunks - */ - def enforceMaximumSize(chunks: List[ShortChannelIdsChunk]): List[ShortChannelIdsChunk] = chunks.map(_.enforceMaximumSize(MAXIMUM_CHUNK_SIZE)) - - /** - * Build a `reply_channel_range` message - * - * @param chunk chunk of scids - * @param chainHash chain hash - * @param defaultEncoding default encoding - * @param queryFlags_opt query flag set by the requester - * @param channels channels map - * @return a ReplyChannelRange object - */ - def buildReplyChannelRange(chunk: ShortChannelIdsChunk, chainHash: ByteVector32, defaultEncoding: EncodingType, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = { - val encoding = if (chunk.shortChannelIds.isEmpty) EncodingType.UNCOMPRESSED else defaultEncoding - val (timestamps, checksums) = queryFlags_opt match { - case Some(extension) if 
extension.wantChecksums | extension.wantTimestamps => - // we always compute timestamps and checksums even if we don't need both, overhead is negligible - val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(channels)).unzip - val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(encoding, timestamps)) else None - val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None - (encodedTimestamps, encodedChecksums) - case _ => (None, None) - } - ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, - complete = 1, - shortChannelIds = EncodedShortChannelIds(encoding, chunk.shortChannelIds), - timestamps = timestamps, - checksums = checksums) - } - - def addToSync(syncMap: Map[PublicKey, Sync], remoteNodeId: PublicKey, pending: List[RoutingMessage]): (Map[PublicKey, Sync], Option[RoutingMessage]) = { - pending match { - case head +: rest => - // they may send back several reply_channel_range messages for a single query_channel_range query, and we must not - // send another query_short_channel_ids query if they're still processing one - syncMap.get(remoteNodeId) match { - case None => - // we don't have a pending query with this peer, let's send it - (syncMap + (remoteNodeId -> Sync(rest, pending.size)), Some(head)) - case Some(sync) => - // we already have a pending query with this peer, add missing ids to our "sync" state - (syncMap + (remoteNodeId -> Sync(sync.pending ++ pending, sync.total + pending.size)), None) - } - case Nil => - // there is nothing to send - (syncMap, None) - } - } - - /** - * https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#clarifications - */ - val ROUTE_MAX_LENGTH = 20 - - // Max allowed CLTV for a route - val DEFAULT_ROUTE_MAX_CLTV = CltvExpiryDelta(1008) - - // The default number of routes we'll search for when findRoute is called with randomize = true - val DEFAULT_ROUTES_COUNT = 3 - - def getDefaultRouteParams(routerConf: RouterConf) = RouteParams( - randomize = routerConf.randomizeRouteSelection, - maxFeeBase = routerConf.searchMaxFeeBase.toMilliSatoshi, - maxFeePct = routerConf.searchMaxFeePct, - routeMaxLength = routerConf.searchMaxRouteLength, - routeMaxCltv = routerConf.searchMaxCltv, - ratios = routerConf.searchHeuristicsEnabled match { - case false => None - case true => Some(WeightRatios( - cltvDeltaFactor = routerConf.searchRatioCltv, - ageFactor = routerConf.searchRatioChannelAge, - capacityFactor = routerConf.searchRatioChannelCapacity - )) - } - ) - - /** - * Find a route in the graph between localNodeId and targetNodeId, returns the route. - * Will perform a k-shortest path selection given the @param numRoutes and randomly select one of the result. 
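
// ---- Editor's note: a condensed sketch (not the actual implementation) of the search
// ---- boundary findRoute applies, mirroring the feeOk/lengthOk/cltvOk helpers in its body.
// ---- Note the fee check is an OR: a route passes if its fee is under the flat cap *or*
// ---- under the proportional cap, so small payments are not rejected by the percentage
// ---- limit and large payments are not rejected by the flat limit.
import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi}
import fr.acinq.eclair.router.Router.RouteParams

def withinBoundaries(fee: MilliSatoshi, amount: MilliSatoshi, length: Int, cltv: CltvExpiryDelta, p: RouteParams): Boolean =
  (fee <= p.maxFeeBase || fee <= amount * p.maxFeePct) && // feeOk
    length <= p.routeMaxLength && // lengthOk (findRoute additionally caps at ROUTE_MAX_LENGTH)
    cltv <= p.routeMaxCltv // cltvOk
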
- * - * @param g graph of the whole network - * @param localNodeId sender node (payer) - * @param targetNodeId target node (final recipient) - * @param amount the amount that will be sent along this route - * @param numRoutes the number of shortest-paths to find - * @param extraEdges a set of extra edges we want to CONSIDER during the search - * @param ignoredEdges a set of extra edges we want to IGNORE during the search - * @param routeParams a set of parameters that can restrict the route search - * @return the computed route to the destination @targetNodeId - */ - def findRoute(g: DirectedGraph, - localNodeId: PublicKey, - targetNodeId: PublicKey, - amount: MilliSatoshi, - numRoutes: Int, - extraEdges: Set[GraphEdge] = Set.empty, - ignoredEdges: Set[ChannelDesc] = Set.empty, - ignoredVertices: Set[PublicKey] = Set.empty, - routeParams: RouteParams, - currentBlockHeight: Long): Try[Seq[ChannelHop]] = Try { - - if (localNodeId == targetNodeId) throw CannotRouteToSelf - - def feeBaseOk(fee: MilliSatoshi): Boolean = fee <= routeParams.maxFeeBase - - def feePctOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = { - val maxFee = amount * routeParams.maxFeePct - fee <= maxFee - } - - def feeOk(fee: MilliSatoshi, amount: MilliSatoshi): Boolean = feeBaseOk(fee) || feePctOk(fee, amount) - - def lengthOk(length: Int): Boolean = length <= routeParams.routeMaxLength && length <= ROUTE_MAX_LENGTH - - def cltvOk(cltv: CltvExpiryDelta): Boolean = cltv <= routeParams.routeMaxCltv - - val boundaries: RichWeight => Boolean = { weight => - feeOk(weight.cost - amount, amount) && lengthOk(weight.length) && cltvOk(weight.cltv) - } - - val foundRoutes = KamonExt.time(Metrics.FindRouteDuration.withTag(Tags.NumberOfRoutes, numRoutes).withTag(Tags.Amount, Tags.amountBucket(amount))) { - Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries).toList - } - foundRoutes match { - case Nil if routeParams.routeMaxLength < ROUTE_MAX_LENGTH => // if not found within the constraints we relax and repeat the search - Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(0) - return findRoute(g, localNodeId, targetNodeId, amount, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams.copy(routeMaxLength = ROUTE_MAX_LENGTH, routeMaxCltv = DEFAULT_ROUTE_MAX_CLTV), currentBlockHeight) - case Nil => - Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(0) - throw RouteNotFound - case foundRoutes => - val routes = foundRoutes.find(_.path.size == 1) match { - case Some(directRoute) => directRoute :: Nil - case _ => foundRoutes - } - // At this point 'routes' cannot be empty - val randomizedRoutes = if (routeParams.randomize) Random.shuffle(routes) else routes - val route = randomizedRoutes.head.path.map(graphEdgeToHop) - Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(amount)).record(route.length) - route - } - } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala new file mode 100644 index 000000000..8c33b551a --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala @@ -0,0 +1,102 @@ +/* + * Copyright 2020 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package fr.acinq.eclair.router
+
+import akka.actor.ActorContext
+import akka.event.LoggingAdapter
+import fr.acinq.eclair.db.NetworkDb
+import fr.acinq.eclair.router.Router.{ChannelDesc, Data, PublicChannel, hasChannels}
+import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate}
+import fr.acinq.eclair.{ShortChannelId, TxCoordinates}
+
+import scala.collection.mutable
+import scala.compat.Platform
+import scala.concurrent.duration._
+
+object StaleChannels {
+
+  def handlePruneStaleChannels(d: Data, db: NetworkDb, currentBlockHeight: Long)(implicit ctx: ActorContext, log: LoggingAdapter): Data = {
+    // first we select channels that we will prune
+    val staleChannels = getStaleChannels(d.channels.values, currentBlockHeight)
+    val staleChannelIds = staleChannels.map(_.ann.shortChannelId)
+    // then we remove nodes that aren't tied to any channels anymore (and deduplicate them)
+    val potentialStaleNodes = staleChannels.flatMap(c => Set(c.ann.nodeId1, c.ann.nodeId2)).toSet
+    val channels1 = d.channels -- staleChannelIds
+    // no need to iterate on all nodes, just on those that are affected by current pruning
+    val staleNodes = potentialStaleNodes.filterNot(nodeId => hasChannels(nodeId, channels1.values))
+
+    // let's clean the db and send the events
+    db.removeChannels(staleChannelIds) // NB: this also removes channel updates
+    // we keep track of recently pruned channels so we don't revalidate them (zombie churn)
+    db.addToPruned(staleChannelIds)
+    staleChannelIds.foreach { shortChannelId =>
+      log.info("pruning shortChannelId={} (stale)", shortChannelId)
+      ctx.system.eventStream.publish(ChannelLost(shortChannelId))
+    }
+
+    val staleChannelsToRemove = new mutable.MutableList[ChannelDesc]
+    staleChannels.foreach(ca => {
+      staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId1, ca.ann.nodeId2)
+      staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId2, ca.ann.nodeId1)
+    })
+
+    val graph1 = d.graph.removeEdges(staleChannelsToRemove)
+    staleNodes.foreach {
+      nodeId =>
+        log.info("pruning nodeId={} (stale)", nodeId)
+        db.removeNode(nodeId)
+        ctx.system.eventStream.publish(NodeLost(nodeId))
+    }
+    d.copy(nodes = d.nodes -- staleNodes, channels = channels1, graph = graph1)
+  }
+
+  def isStale(u: ChannelUpdate): Boolean = isStale(u.timestamp)
+
+  def isStale(timestamp: Long): Boolean = {
+    // BOLT 7: "nodes MAY prune channels should the timestamp of the latest channel_update be older than 2 weeks"
+    // but we don't want to prune brand new channels for which we didn't yet receive a channel update
+    val staleThresholdSeconds = (Platform.currentTime.milliseconds - 14.days).toSeconds
+    timestamp < staleThresholdSeconds
+  }
+
+  def isAlmostStale(timestamp: Long): Boolean = {
+    // we define almost stale as 2 weeks minus 4 days
+    val staleThresholdSeconds = (Platform.currentTime.milliseconds - 10.days).toSeconds
+    timestamp < staleThresholdSeconds
+  }
+
+  /**
+   * A channel is considered stale if it:
+   * (1) is older than 2 weeks (2*7*144 = 2016 blocks)
+   * AND
+   * (2) has no channel_update younger than 2 weeks
+   *
+   * @param update1_opt update
corresponding to one side of the channel, if we have it
+   * @param update2_opt update corresponding to the other side of the channel, if we have it
+   * @return true if the channel is stale and should be pruned
+   */
+  def isStale(channel: ChannelAnnouncement, update1_opt: Option[ChannelUpdate], update2_opt: Option[ChannelUpdate], currentBlockHeight: Long): Boolean = {
+    // BOLT 7: "nodes MAY prune channels should the timestamp of the latest channel_update be older than 2 weeks (1209600 seconds)"
+    // but we don't want to prune brand new channels for which we didn't yet receive a channel update, so we keep them as long as they are less than 2 weeks (2016 blocks) old
+    val staleThresholdBlocks = currentBlockHeight - 2016
+    val TxCoordinates(blockHeight, _, _) = ShortChannelId.coordinates(channel.shortChannelId)
+    blockHeight < staleThresholdBlocks && update1_opt.forall(isStale) && update2_opt.forall(isStale)
+  }
+
+  def getStaleChannels(channels: Iterable[PublicChannel], currentBlockHeight: Long): Iterable[PublicChannel] = channels.filter(data => isStale(data.ann, data.update_1_opt, data.update_2_opt, currentBlockHeight))
+
+}
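To make the two-week staleness rule above concrete, here is a minimal standalone sketch of the arithmetic (not part of the patch; all values are hypothetical and chosen only for illustration):

object StalenessExample extends App {
  import scala.compat.Platform
  import scala.concurrent.duration._

  val currentBlockHeight = 632500L // pretend chain tip
  val fundingHeight = 630000L // block containing the funding tx
  // (1) the channel must be at least 2 weeks old: 2 * 7 * 144 = 2016 blocks
  val oldEnough = fundingHeight < currentBlockHeight - 2016 // true: 630000 < 630484
  // (2) the latest channel_update on each side must itself be more than 2 weeks old
  val nowSeconds = Platform.currentTime.milliseconds.toSeconds
  val updateTimestamp = nowSeconds - 15.days.toSeconds // last update seen 15 days ago
  val updateIsStale = updateTimestamp < nowSeconds - 14.days.toSeconds // true
  // the channel is pruned only when both conditions hold
  println(oldEnough && updateIsStale) // true
}

diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala
new file mode 100644
index 000000000..be3df33da
--- /dev/null
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala
@@ -0,0 +1,515 @@
+/*
+ * Copyright 2020 ACINQ SAS
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.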
+ */
+
+package fr.acinq.eclair.router
+
+import akka.actor.{ActorContext, ActorRef}
+import akka.event.LoggingAdapter
+import fr.acinq.bitcoin.ByteVector32
+import fr.acinq.bitcoin.Crypto.PublicKey
+import fr.acinq.eclair.crypto.TransportHandler
+import fr.acinq.eclair.router.Router._
+import fr.acinq.eclair.wire._
+import fr.acinq.eclair.{ShortChannelId, serializationResult}
+import kamon.Kamon
+import scodec.bits.ByteVector
+import shapeless.HNil
+
+import scala.annotation.tailrec
+import scala.collection.SortedSet
+import scala.collection.immutable.SortedMap
+import scala.compat.Platform
+import scala.concurrent.duration._
+import scala.util.Random
+
+object Sync {
+
+  // maximum number of ids we can keep in a single chunk and still have an encoded reply that is smaller than 65Kb
+  // please note that:
+  // - this is based on the worst case scenario where the peer wants timestamps and checksums and the reply is not compressed
+  // - the maximum number of public channels in a single block so far is less than 300, and the maximum number of tx per block
+  //   almost never exceeds 2800, so this is not a real limitation yet
+  val MAXIMUM_CHUNK_SIZE = 3200
+
+  def handleSendChannelQuery(d: Data, s: SendChannelQuery)(implicit ctx: ActorContext, log: LoggingAdapter): Data = {
+    implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors
+    // ask for everything
+    // we currently send only one query_channel_range message per peer, when we just (re)connected to it, so we don't
+    // have to worry about sending a new query_channel_range when another query is still in progress
+    val query = QueryChannelRange(s.chainHash, firstBlockNum = 0L, numberOfBlocks = Int.MaxValue.toLong, TlvStream(s.flags_opt.toList))
+    log.info("sending query_channel_range={}", query)
+    s.to ! query
+
+    // we also set a pass-all filter for now (we can update it later) for the future gossip messages, by setting
+    // the first_timestamp field to the current date/time and timestamp_range to the maximum value
+    // NB: we can't just set firstTimestamp to 0, because in that case the peer would send us all past messages matching
+    // that (i.e. the whole routing table)
+    val filter = GossipTimestampFilter(s.chainHash, firstTimestamp = Platform.currentTime.milliseconds.toSeconds, timestampRange = Int.MaxValue)
+    s.to ! filter
+
+    // clean our sync state for this peer: we receive a SendChannelQuery just when we connect/reconnect to a peer and
+    // will start a new complete sync process
+    d.copy(sync = d.sync - s.remoteNodeId)
+  }
+
+  def handleQueryChannelRange(channels: SortedMap[ShortChannelId, PublicChannel], routerConf: RouterConf, origin: RemoteGossip, q: QueryChannelRange)(implicit ctx: ActorContext, log: LoggingAdapter): Unit = {
+    implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors
+    ctx.sender !
TransportHandler.ReadAck(q) + Kamon.runWithContextEntry(remoteNodeIdKey, origin.nodeId.toString) { + Kamon.runWithSpan(Kamon.spanBuilder("query-channel-range").start(), finishSpan = true) { + log.info("received query_channel_range with firstBlockNum={} numberOfBlocks={} extendedQueryFlags_opt={}", q.firstBlockNum, q.numberOfBlocks, q.tlvStream) + // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] + val shortChannelIds: SortedSet[ShortChannelId] = channels.keySet.filter(keep(q.firstBlockNum, q.numberOfBlocks, _)) + log.info("replying with {} items for range=({}, {})", shortChannelIds.size, q.firstBlockNum, q.numberOfBlocks) + val chunks = Kamon.runWithSpan(Kamon.spanBuilder("split-channel-ids").start(), finishSpan = true) { + split(shortChannelIds, q.firstBlockNum, q.numberOfBlocks, routerConf.channelRangeChunkSize) + } + + Kamon.runWithSpan(Kamon.spanBuilder("compute-timestamps-checksums").start(), finishSpan = true) { + chunks.foreach { chunk => + val reply = buildReplyChannelRange(chunk, q.chainHash, routerConf.encodingType, q.queryFlags_opt, channels) + origin.peerConnection ! reply + } + } + } + } + } + + def handleReplyChannelRange(d: Data, routerConf: RouterConf, origin: RemoteGossip, r: ReplyChannelRange)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + ctx.sender ! TransportHandler.ReadAck(r) + + Kamon.runWithContextEntry(remoteNodeIdKey, origin.nodeId.toString) { + Kamon.runWithSpan(Kamon.spanBuilder("reply-channel-range").start(), finishSpan = true) { + + @tailrec + def loop(ids: List[ShortChannelId], timestamps: List[ReplyChannelRangeTlv.Timestamps], checksums: List[ReplyChannelRangeTlv.Checksums], acc: List[ShortChannelIdAndFlag] = List.empty[ShortChannelIdAndFlag]): List[ShortChannelIdAndFlag] = { + ids match { + case Nil => acc.reverse + case head :: tail => + val flag = computeFlag(d.channels)(head, timestamps.headOption, checksums.headOption, routerConf.requestNodeAnnouncements) + // 0 means nothing to query, just don't include it + val acc1 = if (flag != 0) ShortChannelIdAndFlag(head, flag) :: acc else acc + loop(tail, timestamps.drop(1), checksums.drop(1), acc1) + } + } + + val timestamps_opt = r.timestamps_opt.map(_.timestamps).getOrElse(List.empty[ReplyChannelRangeTlv.Timestamps]) + val checksums_opt = r.checksums_opt.map(_.checksums).getOrElse(List.empty[ReplyChannelRangeTlv.Checksums]) + + val shortChannelIdAndFlags = Kamon.runWithSpan(Kamon.spanBuilder("compute-flags").start(), finishSpan = true) { + loop(r.shortChannelIds.array, timestamps_opt, checksums_opt) + } + + val (channelCount, updatesCount) = shortChannelIdAndFlags.foldLeft((0, 0)) { + case ((c, u), ShortChannelIdAndFlag(_, flag)) => + val c1 = c + (if (QueryShortChannelIdsTlv.QueryFlagType.includeChannelAnnouncement(flag)) 1 else 0) + val u1 = u + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate1(flag)) 1 else 0) + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate2(flag)) 1 else 0) + (c1, u1) + } + log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={}", r.shortChannelIds.array.size, channelCount, updatesCount, r.shortChannelIds.encoding) + + def buildQuery(chunk: List[ShortChannelIdAndFlag]): QueryShortChannelIds = { + // always encode empty lists as UNCOMPRESSED + val encoding = if (chunk.isEmpty) EncodingType.UNCOMPRESSED else r.shortChannelIds.encoding + 
QueryShortChannelIds(r.chainHash, + shortChannelIds = EncodedShortChannelIds(encoding, chunk.map(_.shortChannelId)), + if (r.timestamps_opt.isDefined || r.checksums_opt.isDefined) + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(encoding, chunk.map(_.flag))) + else + TlvStream.empty + ) + } + + // we update our sync data to this node (there may be multiple channel range responses and we can only query one set of ids at a time) + val replies = shortChannelIdAndFlags + .grouped(routerConf.channelQueryChunkSize) + .map(buildQuery) + .toList + + val (sync1, replynow_opt) = addToSync(d.sync, origin.nodeId, replies) + // we only send a reply right away if there were no pending requests + replynow_opt.foreach(origin.peerConnection ! _) + val progress = syncProgress(sync1) + ctx.system.eventStream.publish(progress) + ctx.self ! progress + d.copy(sync = sync1) + } + } + } + + def handleQueryShortChannelIds(nodes: Map[PublicKey, NodeAnnouncement], channels: SortedMap[ShortChannelId, PublicChannel], routerConf: RouterConf, origin: RemoteGossip, q: QueryShortChannelIds)(implicit ctx: ActorContext, log: LoggingAdapter): Unit = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + ctx.sender ! TransportHandler.ReadAck(q) + + Kamon.runWithContextEntry(remoteNodeIdKey, origin.nodeId.toString) { + Kamon.runWithSpan(Kamon.spanBuilder("query-short-channel-ids").start(), finishSpan = true) { + + val flags = q.queryFlags_opt.map(_.array).getOrElse(List.empty[Long]) + + var channelCount = 0 + var updateCount = 0 + var nodeCount = 0 + + processChannelQuery(nodes, channels)( + q.shortChannelIds.array, + flags, + ca => { + channelCount = channelCount + 1 + origin.peerConnection ! ca + }, + cu => { + updateCount = updateCount + 1 + origin.peerConnection ! cu + }, + na => { + nodeCount = nodeCount + 1 + origin.peerConnection ! na + } + ) + log.info("received query_short_channel_ids with {} items, sent back {} channels and {} updates and {} nodes", q.shortChannelIds.array.size, channelCount, updateCount, nodeCount) + origin.peerConnection ! ReplyShortChannelIdsEnd(q.chainHash, 1) + } + } + } + + def handleReplyShortChannelIdsEnd(d: Data, origin: RemoteGossip, r: ReplyShortChannelIdsEnd)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + ctx.sender ! TransportHandler.ReadAck(r) + // have we more channels to ask this peer? + val sync1 = d.sync.get(origin.nodeId) match { + case Some(sync) => + sync.pending match { + case nextRequest +: rest => + log.info(s"asking for the next slice of short_channel_ids (remaining=${sync.pending.size}/${sync.total})") + origin.peerConnection ! nextRequest + d.sync + (origin.nodeId -> sync.copy(pending = rest)) + case Nil => + // we received reply_short_channel_ids_end for our last query and have not sent another one, we can now remove + // the remote peer from our map + log.info(s"sync complete (total=${sync.total})") + d.sync - origin.nodeId + } + case _ => d.sync + } + val progress = syncProgress(sync1) + ctx.system.eventStream.publish(progress) + ctx.self ! 
progress
+    d.copy(sync = sync1)
+  }
+
+  /**
+   * Filters channels that we want to send to nodes asking for a channel range
+   */
+  def keep(firstBlockNum: Long, numberOfBlocks: Long, id: ShortChannelId): Boolean = {
+    val height = id.blockHeight
+    height >= firstBlockNum && height < (firstBlockNum + numberOfBlocks)
+  }
+
+  def shouldRequestUpdate(ourTimestamp: Long, ourChecksum: Long, theirTimestamp_opt: Option[Long], theirChecksum_opt: Option[Long]): Boolean = {
+    (theirTimestamp_opt, theirChecksum_opt) match {
+      case (Some(theirTimestamp), Some(theirChecksum)) =>
+        // we request their channel_update if all those conditions are met:
+        // - it is more recent than ours
+        // - it is different from ours, or it is the same but ours is about to be stale
+        // - it is not stale
+        val theirsIsMoreRecent = ourTimestamp < theirTimestamp
+        val areDifferent = ourChecksum != theirChecksum
+        val oursIsAlmostStale = StaleChannels.isAlmostStale(ourTimestamp)
+        val theirsIsStale = StaleChannels.isStale(theirTimestamp)
+        theirsIsMoreRecent && (areDifferent || oursIsAlmostStale) && !theirsIsStale
+      case (Some(theirTimestamp), None) =>
+        // if we only have their timestamp, we request their channel_update if theirs is more recent than ours
+        val theirsIsMoreRecent = ourTimestamp < theirTimestamp
+        val theirsIsStale = StaleChannels.isStale(theirTimestamp)
+        theirsIsMoreRecent && !theirsIsStale
+      case (None, Some(theirChecksum)) =>
+        // if we only have their checksum, we request their channel_update if it is different from ours
+        // NB: a zero checksum means that they don't have the data
+        val areDifferent = theirChecksum != 0 && ourChecksum != theirChecksum
+        areDifferent
+      case (None, None) =>
+        // if we have neither their timestamp nor their checksum we request their channel_update
+        true
+    }
+  }
+
+  def computeFlag(channels: SortedMap[ShortChannelId, PublicChannel])(
+    shortChannelId: ShortChannelId,
+    theirTimestamps_opt: Option[ReplyChannelRangeTlv.Timestamps],
+    theirChecksums_opt: Option[ReplyChannelRangeTlv.Checksums],
+    includeNodeAnnouncements: Boolean): Long = {
+    import QueryShortChannelIdsTlv.QueryFlagType._
+
+    val flagsNodes = if (includeNodeAnnouncements) INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2 else 0
+
+    val flags = if (!channels.contains(shortChannelId)) {
+      INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2
+    } else {
+      // we already know this channel
+      val (ourTimestamps, ourChecksums) = getChannelDigestInfo(channels)(shortChannelId)
+      // if they didn't provide a timestamp or a checksum, shouldRequestUpdate handles the missing value explicitly,
+      // erring on the side of requesting their channel_update (NB: a zero checksum means they don't have the update)
+      val shouldRequestUpdate1 = shouldRequestUpdate(ourTimestamps.timestamp1, ourChecksums.checksum1, theirTimestamps_opt.map(_.timestamp1), theirChecksums_opt.map(_.checksum1))
+      val shouldRequestUpdate2 = shouldRequestUpdate(ourTimestamps.timestamp2, ourChecksums.checksum2, theirTimestamps_opt.map(_.timestamp2), theirChecksums_opt.map(_.checksum2))
+      val flagUpdate1 = if (shouldRequestUpdate1) INCLUDE_CHANNEL_UPDATE_1 else 0
+      val flagUpdate2 = if (shouldRequestUpdate2) INCLUDE_CHANNEL_UPDATE_2 else 0
+      flagUpdate1 | flagUpdate2
+    }
+
+    if (flags == 0) 0 else flags | flagsNodes
+  }
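As a concrete illustration of shouldRequestUpdate's decision rules, here is a minimal standalone sketch (not part of the patch; the timestamps and checksums are hypothetical, chosen so the assertions hold whenever the snippet runs):

object ShouldRequestUpdateExample extends App {
  import fr.acinq.eclair.router.Sync
  import scala.compat.Platform
  import scala.concurrent.duration._

  val now = Platform.currentTime.milliseconds.toSeconds
  val ourTimestamp = now - 3600 // our update is one hour old: neither stale nor almost stale
  val ourChecksum = 42L
  // theirs is newer and has a different checksum => request it
  assert(Sync.shouldRequestUpdate(ourTimestamp, ourChecksum, Some(now), Some(43L)))
  // theirs is newer but identical, and ours is not about to become stale => nothing to request
  assert(!Sync.shouldRequestUpdate(ourTimestamp, ourChecksum, Some(now), Some(42L)))
  // no timestamp, and a zero checksum means they don't have the update at all => nothing to request
  assert(!Sync.shouldRequestUpdate(ourTimestamp, ourChecksum, None, Some(0L)))
}

+
+  /**
+   * Handle a query message, which includes a list of channel ids and flags.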
+   *
+   * @param nodes     node id -> node announcement
+   * @param channels  channel id -> channel announcement + updates
+   * @param ids       list of channel ids
+   * @param flags     list of query flags, either empty or with one flag per channel id
+   * @param onChannel called when a channel announcement matches (i.e. its bit is set in the query flag and we have it)
+   * @param onUpdate  called when a channel update matches
+   * @param onNode    called when a node announcement matches
+   */
+  def processChannelQuery(nodes: Map[PublicKey, NodeAnnouncement],
+                          channels: SortedMap[ShortChannelId, PublicChannel])(
+                           ids: List[ShortChannelId],
+                           flags: List[Long],
+                           onChannel: ChannelAnnouncement => Unit,
+                           onUpdate: ChannelUpdate => Unit,
+                           onNode: NodeAnnouncement => Unit)(implicit log: LoggingAdapter): Unit = {
+    import QueryShortChannelIdsTlv.QueryFlagType
+
+    // we loop over channel ids and query flags; we track the node ids of node announcements
+    // we've already sent to avoid sending them multiple times, as requested by the BOLTs
+    @tailrec
+    def loop(ids: List[ShortChannelId], flags: List[Long], numca: Int = 0, numcu: Int = 0, nodesSent: Set[PublicKey] = Set.empty[PublicKey]): (Int, Int, Int) = ids match {
+      case Nil => (numca, numcu, nodesSent.size)
+      case head :: tail if !channels.contains(head) =>
+        log.warning("received query for shortChannelId={} that we don't have", head)
+        loop(tail, flags.drop(1), numca, numcu, nodesSent)
+      case head :: tail =>
+        val numca1 = numca
+        val numcu1 = numcu
+        var sent1 = nodesSent
+        val pc = channels(head)
+        val flag_opt = flags.headOption
+        // no flag means send everything
+        val includeChannel = flag_opt.forall(QueryFlagType.includeChannelAnnouncement)
+        val includeUpdate1 = flag_opt.forall(QueryFlagType.includeUpdate1)
+        val includeUpdate2 = flag_opt.forall(QueryFlagType.includeUpdate2)
+        val includeNode1 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement1)
+        val includeNode2 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement2)
+
+        if (includeChannel) {
+          onChannel(pc.ann)
+        }
+        if (includeUpdate1) {
+          pc.update_1_opt.foreach { u =>
+            onUpdate(u)
+          }
+        }
+        if (includeUpdate2) {
+          pc.update_2_opt.foreach { u =>
+            onUpdate(u)
+          }
+        }
+        if (includeNode1 && !sent1.contains(pc.ann.nodeId1)) {
+          nodes.get(pc.ann.nodeId1).foreach { n =>
+            onNode(n)
+            sent1 = sent1 + pc.ann.nodeId1
+          }
+        }
+        if (includeNode2 && !sent1.contains(pc.ann.nodeId2)) {
+          nodes.get(pc.ann.nodeId2).foreach { n =>
+            onNode(n)
+            sent1 = sent1 + pc.ann.nodeId2
+          }
+        }
+        loop(tail, flags.drop(1), numca1, numcu1, sent1)
+    }
+
+    loop(ids, flags)
+  }
+
+  /**
+   * Returns overall progress on synchronization
+   *
+   * @return a sync progress indicator (1 means fully synced)
+   */
+  def syncProgress(sync: Map[PublicKey, Syncing]): SyncProgress = {
+    // NB: progress is in terms of requests, not individual channels
+    val (pending, total) = sync.foldLeft((0, 0)) {
+      case ((p, t), (_, sync)) => (p + sync.pending.size, t + sync.total)
+    }
+    if (total == 0) {
+      SyncProgress(1)
+    } else {
+      SyncProgress((total - pending) / (1.0 * total))
+    }
+  }
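Since syncProgress counts outstanding requests rather than channels, the arithmetic is easy to check with a small standalone sketch (not part of the patch; the node ids and the dummy message are hypothetical placeholders):

object SyncProgressExample extends App {
  import fr.acinq.bitcoin.ByteVector32
  import fr.acinq.eclair.randomKey
  import fr.acinq.eclair.router.Router.Syncing
  import fr.acinq.eclair.router.Sync
  import fr.acinq.eclair.wire.ReplyShortChannelIdsEnd

  val (alice, bob) = (randomKey.publicKey, randomKey.publicKey)
  val dummy = ReplyShortChannelIdsEnd(ByteVector32.Zeroes, 1) // any RoutingMessage will do here
  val sync = Map(
    alice -> Syncing(pending = List.fill(3)(dummy), total = 10), // 3 of alice's 10 requests still pending
    bob -> Syncing(pending = List.fill(1)(dummy), total = 2) // 1 of bob's 2 requests still pending
  )
  println(Sync.syncProgress(sync)) // SyncProgress((12 - 4) / 12.0), i.e. about 0.67
}

+
+  def getChannelDigestInfo(channels: SortedMap[ShortChannelId, PublicChannel])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = {
+    val c = channels(shortChannelId)
+    val timestamp1 = c.update_1_opt.map(_.timestamp).getOrElse(0L)
+    val timestamp2 = c.update_2_opt.map(_.timestamp).getOrElse(0L)
+    val checksum1 = c.update_1_opt.map(getChecksum).getOrElse(0L)
+    val checksum2 = c.update_2_opt.map(getChecksum).getOrElse(0L)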
+    (ReplyChannelRangeTlv.Timestamps(timestamp1 = timestamp1, timestamp2 = timestamp2), ReplyChannelRangeTlv.Checksums(checksum1 = checksum1, checksum2 = checksum2))
+  }
+
+  def crc32c(data: ByteVector): Long = {
+    import com.google.common.hash.Hashing
+    Hashing.crc32c().hashBytes(data.toArray).asInt() & 0xFFFFFFFFL
+  }
+
+  def getChecksum(u: ChannelUpdate): Long = {
+    import u._
+
+    val data = serializationResult(LightningMessageCodecs.channelUpdateChecksumCodec.encode(chainHash :: shortChannelId :: messageFlags :: channelFlags :: cltvExpiryDelta :: htlcMinimumMsat :: feeBaseMsat :: feeProportionalMillionths :: htlcMaximumMsat :: HNil))
+    crc32c(data)
+  }
+
+  case class ShortChannelIdsChunk(firstBlock: Long, numBlocks: Long, shortChannelIds: List[ShortChannelId]) {
+    /**
+     * @param maximumSize maximum size of the short channel ids list
+     * @return a chunk with at most `maximumSize` ids
+     */
+    def enforceMaximumSize(maximumSize: Int): ShortChannelIdsChunk = {
+      if (shortChannelIds.size <= maximumSize) this else {
+        // we use a random offset here, so even if shortChannelIds.size is much bigger than maximumSize (which should
+        // not happen) peers will eventually receive info about all channels in this chunk
+        val offset = Random.nextInt(shortChannelIds.size - maximumSize + 1)
+        this.copy(shortChannelIds = this.shortChannelIds.slice(offset, offset + maximumSize))
+      }
+    }
+  }
+
+  /**
+   * Split short channel ids into chunks, because otherwise the reply message could be too big.
+   * There may be several reply_channel_range messages for a single query, but we make sure that the returned
+   * chunks fully cover the [firstBlockNum, firstBlockNum + numberOfBlocks] range that was requested.
+   *
+   * @param shortChannelIds       list of short channel ids to split
+   * @param firstBlockNum         first block height requested by our peer
+   * @param numberOfBlocks        number of blocks requested by our peer
+   * @param channelRangeChunkSize target chunk size.
All ids that have the same block height will be grouped together, so
+   *                              returned chunks may still contain more than `channelRangeChunkSize` elements
+   * @return a list of short channel id chunks
+   */
+  def split(shortChannelIds: SortedSet[ShortChannelId], firstBlockNum: Long, numberOfBlocks: Long, channelRangeChunkSize: Int): List[ShortChannelIdsChunk] = {
+    // see BOLT7: MUST encode a short_channel_id for every open channel it knows in blocks first_blocknum to first_blocknum plus number_of_blocks minus one
+    val it = shortChannelIds.iterator.dropWhile(_.blockHeight < firstBlockNum).takeWhile(_.blockHeight < firstBlockNum + numberOfBlocks)
+    if (it.isEmpty) {
+      List(ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, List.empty))
+    } else {
+      // we want to split ids in different chunks, with the following rules by order of priority:
+      // - ids that have the same block height must be grouped in the same chunk
+      // - a chunk should contain around `channelRangeChunkSize` ids
+      @tailrec
+      def loop(currentChunk: List[ShortChannelId], acc: List[ShortChannelIdsChunk]): List[ShortChannelIdsChunk] = {
+        if (it.hasNext) {
+          val id = it.next()
+          val currentHeight = currentChunk.head.blockHeight
+          if (id.blockHeight == currentHeight)
+            loop(id :: currentChunk, acc) // same height => always add to the current chunk
+          else if (currentChunk.size < channelRangeChunkSize)
+            loop(id :: currentChunk, acc) // different height but we're under the size target => add to the current chunk
+          else {
+            // different height and over the size target => start a new chunk
+            // we always prepend because it's more efficient so we have to reverse the current chunk
+            // for the first chunk, we make sure that we start at the requested first block
+            // for the next chunks we start at the end of the range covered by the last chunk
+            val first = if (acc.isEmpty) firstBlockNum else acc.head.firstBlock + acc.head.numBlocks
+            val count = currentChunk.head.blockHeight - first + 1
+            loop(id :: Nil, ShortChannelIdsChunk(first, count, currentChunk.reverse) :: acc)
+          }
+        }
+        else {
+          // for the last chunk, we make sure that we cover the requested block range
+          val first = if (acc.isEmpty) firstBlockNum else acc.head.firstBlock + acc.head.numBlocks
+          val count = numberOfBlocks - first + firstBlockNum
+          (ShortChannelIdsChunk(first, count, currentChunk.reverse) :: acc).reverse
+        }
+      }
+
+      val first = it.next()
+      val chunks = loop(first :: Nil, Nil)
+
+      // make sure that all our chunks match our max size policy
+      enforceMaximumSize(chunks)
+    }
+  }
+
+  /**
+   * Enforce max-size constraints for each chunk
+   *
+   * @param chunks list of short channel id chunks
+   * @return a processed list of chunks
+   */
+  def enforceMaximumSize(chunks: List[ShortChannelIdsChunk]): List[ShortChannelIdsChunk] = chunks.map(_.enforceMaximumSize(MAXIMUM_CHUNK_SIZE))
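A quick worked run of split helps check the grouping and range-covering rules; this standalone sketch is not part of the patch, and the scids (written here by their block heights) are hypothetical:

object SplitExample extends App {
  import fr.acinq.eclair.ShortChannelId
  import fr.acinq.eclair.router.Sync
  import scala.collection.SortedSet

  // five hypothetical scids at block heights 100, 100, 100, 105 and 110
  val ids = SortedSet(
    ShortChannelId(100, 0, 0), ShortChannelId(100, 1, 0), ShortChannelId(100, 2, 0),
    ShortChannelId(105, 0, 0), ShortChannelId(110, 0, 0)
  )
  val chunks = Sync.split(ids, firstBlockNum = 90, numberOfBlocks = 1000, channelRangeChunkSize = 2)
  // the three ids at height 100 stay together even though the size target is 2, the first chunk
  // starts at the requested first block, and the last chunk is stretched to the end of the range:
  //   ShortChannelIdsChunk(90, 11, ids at height 100)    -> covers blocks [90, 101)
  //   ShortChannelIdsChunk(101, 989, ids at 105 and 110) -> covers blocks [101, 1090)
  println(chunks)
}

+
+  /**
+   * Build a `reply_channel_range` message
+   *
+   * @param chunk           chunk of scids
+   * @param chainHash       chain hash
+   * @param defaultEncoding default encoding
+   * @param queryFlags_opt  query flag set by the requester
+   * @param channels        channels map
+   * @return a ReplyChannelRange object
+   */
+  def buildReplyChannelRange(chunk: ShortChannelIdsChunk, chainHash: ByteVector32, defaultEncoding: EncodingType, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = {
+    val encoding = if (chunk.shortChannelIds.isEmpty) EncodingType.UNCOMPRESSED else defaultEncoding
+    val (timestamps, checksums) = queryFlags_opt match {
+      case Some(extension) if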
extension.wantChecksums || extension.wantTimestamps =>
+        // we always compute timestamps and checksums even if we don't need both, overhead is negligible
+        val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(channels)).unzip
+        val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(encoding, timestamps)) else None
+        val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None
+        (encodedTimestamps, encodedChecksums)
+      case _ => (None, None)
+    }
+    ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks,
+      complete = 1,
+      shortChannelIds = EncodedShortChannelIds(encoding, chunk.shortChannelIds),
+      timestamps = timestamps,
+      checksums = checksums)
+  }
+
+  def addToSync(syncMap: Map[PublicKey, Syncing], remoteNodeId: PublicKey, pending: List[RoutingMessage]): (Map[PublicKey, Syncing], Option[RoutingMessage]) = {
+    pending match {
+      case head +: rest =>
+        // they may send back several reply_channel_range messages for a single query_channel_range query, and we must not
+        // send another query_short_channel_ids query if they're still processing one
+        syncMap.get(remoteNodeId) match {
+          case None =>
+            // we don't have a pending query with this peer, let's send it
+            (syncMap + (remoteNodeId -> Syncing(rest, pending.size)), Some(head))
+          case Some(sync) =>
+            // we already have a pending query with this peer, add missing ids to our "sync" state
+            (syncMap + (remoteNodeId -> Syncing(sync.pending ++ pending, sync.total + pending.size)), None)
+        }
+      case Nil =>
+        // there is nothing to send
+        (syncMap, None)
+    }
+  }
+
+}
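The two branches of addToSync are easiest to see side by side; a small standalone sketch (not part of the patch; the node id and queries are hypothetical placeholders):

object AddToSyncExample extends App {
  import fr.acinq.bitcoin.ByteVector32
  import fr.acinq.eclair.randomKey
  import fr.acinq.eclair.router.Sync
  import fr.acinq.eclair.wire.{EncodedShortChannelIds, EncodingType, QueryShortChannelIds, TlvStream}

  val alice = randomKey.publicKey
  def query = QueryShortChannelIds(ByteVector32.Zeroes, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), TlvStream.empty)
  val (q1, q2, q3) = (query, query, query)

  // first reply from a peer: the first query is returned so the caller can send it right away
  val (sync1, toSend1) = Sync.addToSync(Map.empty, alice, List(q1, q2))
  assert(toSend1 == Some(q1) && sync1(alice).pending == List(q2) && sync1(alice).total == 2)
  // a later reply from the same peer: the new queries are only queued, nothing is sent yet
  val (sync2, toSend2) = Sync.addToSync(sync1, alice, List(q3))
  assert(toSend2.isEmpty && sync2(alice).pending == List(q2, q3) && sync2(alice).total == 3)
}

diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala
new file mode 100644
index 000000000..8d035d350
--- /dev/null
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala
@@ -0,0 +1,450 @@
+/*
+ * Copyright 2020 ACINQ SAS
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package fr.acinq.eclair.router
+
+import akka.actor.{ActorContext, ActorRef}
+import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter}
+import fr.acinq.bitcoin.Crypto.PublicKey
+import fr.acinq.bitcoin.Script.{pay2wsh, write}
+import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, WatchSpentBasic}
+import fr.acinq.eclair.channel.{BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT, LocalChannelDown, LocalChannelUpdate}
+import fr.acinq.eclair.crypto.TransportHandler
+import fr.acinq.eclair.db.NetworkDb
+import fr.acinq.eclair.router.Monitoring.Metrics
+import fr.acinq.eclair.router.Router._
+import fr.acinq.eclair.transactions.Scripts
+import fr.acinq.eclair.wire._
+import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TxCoordinates}
+import kamon.Kamon
+
+object Validation {
+
+  def sendDecision(peerConnection: ActorRef, decision: GossipDecision)(implicit sender: ActorRef): Unit = {
+    peerConnection !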
decision + Metrics.gossipResult(decision).increment() + } + + def handleChannelAnnouncement(d: Data, db: NetworkDb, watcher: ActorRef, origin: RemoteGossip, c: ChannelAnnouncement)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + log.debug("received channel announcement for shortChannelId={} nodeId1={} nodeId2={}", c.shortChannelId, c.nodeId1, c.nodeId2) + if (d.channels.contains(c.shortChannelId)) { + origin.peerConnection ! TransportHandler.ReadAck(c) + log.debug("ignoring {} (duplicate)", c) + sendDecision(origin.peerConnection, GossipDecision.Duplicate(c)) + d + } else if (d.awaiting.contains(c)) { + origin.peerConnection ! TransportHandler.ReadAck(c) + log.debug("ignoring {} (being verified)", c) + // adding the sender to the list of origins so that we don't send back the same announcement to this peer later + val origins = d.awaiting(c) :+ origin + d.copy(awaiting = d.awaiting + (c -> origins)) + } else if (db.isPruned(c.shortChannelId)) { + origin.peerConnection ! TransportHandler.ReadAck(c) + // channel was pruned and we haven't received a recent channel_update, so we have no reason to revalidate it + log.debug("ignoring {} (was pruned)", c) + sendDecision(origin.peerConnection, GossipDecision.ChannelPruned(c)) + d + } else if (!Announcements.checkSigs(c)) { + origin.peerConnection ! TransportHandler.ReadAck(c) + log.warning("bad signature for announcement {}", c) + sendDecision(origin.peerConnection, GossipDecision.InvalidSignature(c)) + d + } else { + log.info("validating shortChannelId={}", c.shortChannelId) + Kamon.runWithContextEntry(shortChannelIdKey, c.shortChannelId) { + Kamon.runWithSpan(Kamon.spanBuilder("validate-channel").tag("shortChannelId", c.shortChannelId.toString).start(), finishSpan = false) { + watcher ! ValidateRequest(c) + } + } + // we don't acknowledge the message just yet + d.copy(awaiting = d.awaiting + (c -> Seq(origin))) + } + } + + def handleChannelValidationResponse(d0: Data, nodeParams: NodeParams, watcher: ActorRef, r: ValidateResult)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + import nodeParams.db.{network => db} + import r.c + Kamon.runWithContextEntry(shortChannelIdKey, c.shortChannelId) { + Kamon.runWithSpan(Kamon.currentSpan(), finishSpan = true) { + Kamon.runWithSpan(Kamon.spanBuilder("process-validate-result").start(), finishSpan = true) { + d0.awaiting.get(c) match { + case Some(origin +: _) => origin.peerConnection ! 
TransportHandler.ReadAck(c) // now we can acknowledge the message, we only need to do it for the first peer that sent us the announcement
+            case _ => ()
+          }
+          val remoteOrigins_opt = d0.awaiting.get(c)
+          Logs.withMdc(log)(Logs.mdc(remoteNodeId_opt = remoteOrigins_opt.flatMap(_.headOption).map(_.nodeId))) { // in the MDC we use the node id that sent us the announcement first
+            log.info("got validation result for shortChannelId={} (awaiting={} stash.nodes={} stash.updates={})", c.shortChannelId, d0.awaiting.size, d0.stash.nodes.size, d0.stash.updates.size)
+            val publicChannel_opt = r match {
+              case ValidateResult(c, Left(t)) =>
+                log.warning("validation failure for shortChannelId={} reason={}", c.shortChannelId, t.getMessage)
+                remoteOrigins_opt.foreach(_.foreach(o => sendDecision(o.peerConnection, GossipDecision.ValidationFailure(c))))
+                None
+              case ValidateResult(c, Right((tx, UtxoStatus.Unspent))) =>
+                val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(c.shortChannelId)
+                val (fundingOutputScript, fundingOutputIsInvalid) = Kamon.runWithSpan(Kamon.spanBuilder("checked-pubkeyscript").start(), finishSpan = true) {
+                  // let's check that the output is indeed a P2WSH multisig 2-of-2 of nodeid1 and nodeid2
+                  val fundingOutputScript = write(pay2wsh(Scripts.multiSig2of2(c.bitcoinKey1, c.bitcoinKey2)))
+                  val fundingOutputIsInvalid = tx.txOut.size < outputIndex + 1 || fundingOutputScript != tx.txOut(outputIndex).publicKeyScript
+                  (fundingOutputScript, fundingOutputIsInvalid)
+                }
+                if (fundingOutputIsInvalid) {
+                  log.error(s"invalid script for shortChannelId={}: txid={} does not have script=$fundingOutputScript at outputIndex=$outputIndex ann={}", c.shortChannelId, tx.txid, c)
+                  remoteOrigins_opt.foreach(_.foreach(o => sendDecision(o.peerConnection, GossipDecision.InvalidAnnouncement(c))))
+                  None
+                } else {
+                  watcher ! WatchSpentBasic(ctx.self, tx, outputIndex, BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(c.shortChannelId))
+                  // TODO: check feature bit set
+                  log.debug("added channel shortChannelId={}", c.shortChannelId)
+                  remoteOrigins_opt.foreach(_.foreach(o => sendDecision(o.peerConnection, GossipDecision.Accepted(c))))
+                  val capacity = tx.txOut(outputIndex).amount
+                  ctx.system.eventStream.publish(ChannelsDiscovered(SingleChannelDiscovered(c, capacity, None, None) :: Nil))
+                  Kamon.runWithSpan(Kamon.spanBuilder("add-to-db").start(), finishSpan = true) {
+                    db.addChannel(c, tx.txid, capacity)
+                  }
+                  // in case we just validated our first local channel, we announce the local node
+                  if (!d0.nodes.contains(nodeParams.nodeId) && isRelatedTo(c, nodeParams.nodeId)) {
+                    log.info("first local channel validated, announcing local node")
+                    val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.features)
+                    ctx.self !
nodeAnn + } + Some(PublicChannel(c, tx.txid, capacity, None, None)) + } + case ValidateResult(c, Right((tx, fundingTxStatus: UtxoStatus.Spent))) => + if (fundingTxStatus.spendingTxConfirmed) { + log.warning("ignoring shortChannelId={} tx={} (funding tx already spent and spending tx is confirmed)", c.shortChannelId, tx.txid) + // the funding tx has been spent by a transaction that is now confirmed: peer shouldn't send us those + remoteOrigins_opt.foreach(_.foreach(o => sendDecision(o.peerConnection, GossipDecision.ChannelClosed(c)))) + } else { + log.debug("ignoring shortChannelId={} tx={} (funding tx already spent but spending tx isn't confirmed)", c.shortChannelId, tx.txid) + remoteOrigins_opt.foreach(_.foreach(o => sendDecision(o.peerConnection, GossipDecision.ChannelClosing(c)))) + } + // there may be a record if we have just restarted + db.removeChannel(c.shortChannelId) + None + } + val span1 = Kamon.spanBuilder("reprocess-stash").start + // we also reprocess node and channel_update announcements related to channels that were just analyzed + val reprocessUpdates = d0.stash.updates.filterKeys(u => u.shortChannelId == c.shortChannelId) + val reprocessNodes = d0.stash.nodes.filterKeys(n => isRelatedTo(c, n.nodeId)) + // and we remove the reprocessed messages from the stash + val stash1 = d0.stash.copy(updates = d0.stash.updates -- reprocessUpdates.keys, nodes = d0.stash.nodes -- reprocessNodes.keys) + // we remove channel from awaiting map + val awaiting1 = d0.awaiting - c + span1.finish() + + publicChannel_opt match { + case Some(pc) => + Kamon.runWithSpan(Kamon.spanBuilder("build-new-state").start, finishSpan = true) { + // note: if the channel is graduating from private to public, the implementation (in the LocalChannelUpdate handler) guarantees that we will process a new channel_update + // right after the channel_announcement, channel_updates will be moved from private to public at that time + val d1 = d0.copy( + channels = d0.channels + (c.shortChannelId -> pc), + privateChannels = d0.privateChannels - c.shortChannelId, // we remove fake announcements that we may have made before + rebroadcast = d0.rebroadcast.copy(channels = d0.rebroadcast.channels + (c -> d0.awaiting.getOrElse(c, Nil).toSet)), // we also add the newly validated channels to the rebroadcast queue + stash = stash1, + awaiting = awaiting1) + // we only reprocess updates and nodes if validation succeeded + val d2 = reprocessUpdates.foldLeft(d1) { + case (d, (u, origins)) => Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, origins, u, wasStashed = true) + } + val d3 = reprocessNodes.foldLeft(d2) { + case (d, (n, origins)) => Validation.handleNodeAnnouncement(d, nodeParams.db.network, origins, n, wasStashed = true) + } + d3 + } + case None => + reprocessUpdates.foreach { case (u, origins) => origins.collect { case o: RemoteGossip => sendDecision(o.peerConnection, GossipDecision.NoRelatedChannel(u)) } } + reprocessNodes.foreach { case (n, origins) => origins.collect { case o: RemoteGossip => sendDecision(o.peerConnection, GossipDecision.NoKnownChannel(n)) } } + d0.copy(stash = stash1, awaiting = awaiting1) + } + } + } + } + } + } + + def handleChannelSpent(d: Data, db: NetworkDb, event: BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + import event.shortChannelId + val lostChannel = d.channels(shortChannelId).ann + 
log.info("funding tx of channelId={} has been spent", shortChannelId) + // we need to remove nodes that aren't tied to any channels anymore + val channels1 = d.channels - lostChannel.shortChannelId + val lostNodes = Seq(lostChannel.nodeId1, lostChannel.nodeId2).filterNot(nodeId => hasChannels(nodeId, channels1.values)) + // let's clean the db and send the events + log.info("pruning shortChannelId={} (spent)", shortChannelId) + db.removeChannel(shortChannelId) // NB: this also removes channel updates + // we also need to remove updates from the graph + val graph1 = d.graph + .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) + .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId2, lostChannel.nodeId1)) + + ctx.system.eventStream.publish(ChannelLost(shortChannelId)) + lostNodes.foreach { + nodeId => + log.info("pruning nodeId={} (spent)", nodeId) + db.removeNode(nodeId) + ctx.system.eventStream.publish(NodeLost(nodeId)) + } + d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, graph = graph1) + } + + def handleNodeAnnouncement(d: Data, db: NetworkDb, origins: Set[GossipOrigin], n: NodeAnnouncement, wasStashed: Boolean = false)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + val remoteOrigins = origins flatMap { + case r: RemoteGossip if wasStashed => + Some(r.peerConnection) + case RemoteGossip(peerConnection, _) => + peerConnection ! TransportHandler.ReadAck(n) + log.debug("received node announcement for nodeId={}", n.nodeId) + Some(peerConnection) + case LocalGossip => + log.debug("received node announcement from {}", ctx.sender) + None + } + if (d.stash.nodes.contains(n)) { + log.debug("ignoring {} (already stashed)", n) + val origins1 = d.stash.nodes(n) ++ origins + d.copy(stash = d.stash.copy(nodes = d.stash.nodes + (n -> origins1))) + } else if (d.rebroadcast.nodes.contains(n)) { + log.debug("ignoring {} (pending rebroadcast)", n) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(n))) + val origins1 = d.rebroadcast.nodes(n) ++ origins + d.copy(rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> origins1))) + } else if (d.nodes.contains(n.nodeId) && d.nodes(n.nodeId).timestamp >= n.timestamp) { + log.debug("ignoring {} (duplicate)", n) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Duplicate(n))) + d + } else if (!Announcements.checkSig(n)) { + log.warning("bad signature for {}", n) + remoteOrigins.foreach(sendDecision(_, GossipDecision.InvalidSignature(n))) + d + } else if (d.nodes.contains(n.nodeId)) { + log.debug("updated node nodeId={}", n.nodeId) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(n))) + ctx.system.eventStream.publish(NodeUpdated(n)) + db.updateNode(n) + d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> origins))) + } else if (d.channels.values.exists(c => isRelatedTo(c.ann, n.nodeId))) { + log.debug("added node nodeId={}", n.nodeId) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(n))) + ctx.system.eventStream.publish(NodesDiscovered(n :: Nil)) + db.addNode(n) + d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> origins))) + } else if (d.awaiting.keys.exists(c => isRelatedTo(c, n.nodeId))) { + log.debug("stashing {}", n) + d.copy(stash = d.stash.copy(nodes = d.stash.nodes 
+ (n -> origins))) + } else { + log.debug("ignoring {} (no related channel found)", n) + remoteOrigins.foreach(sendDecision(_, GossipDecision.NoKnownChannel(n))) + // there may be a record if we have just restarted + db.removeNode(n.nodeId) + d + } + } + + def handleChannelUpdate(d: Data, db: NetworkDb, routerConf: RouterConf, origins: Set[GossipOrigin], u: ChannelUpdate, wasStashed: Boolean = false)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + val remoteOrigins = origins flatMap { + case r: RemoteGossip if wasStashed => + Some(r.peerConnection) + case RemoteGossip(peerConnection, _) => + peerConnection ! TransportHandler.ReadAck(u) + log.debug("received channel update for shortChannelId={}", u.shortChannelId) + Some(peerConnection) + case LocalGossip => + log.debug("received channel update from {}", ctx.sender) + None + } + if (d.channels.contains(u.shortChannelId)) { + // related channel is already known (note: this means no related channel_update is in the stash) + val publicChannel = true + val pc = d.channels(u.shortChannelId) + val desc = getDesc(u, pc.ann) + if (d.rebroadcast.updates.contains(u)) { + log.debug("ignoring {} (pending rebroadcast)", u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(u))) + val origins1 = d.rebroadcast.updates(u) ++ origins + d.copy(rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins1))) + } else if (StaleChannels.isStale(u)) { + log.debug("ignoring {} (stale)", u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Stale(u))) + d + } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { + log.debug("ignoring {} (duplicate)", u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Duplicate(u))) + d + } else if (!Announcements.checkSig(u, pc.getNodeIdSameSideAs(u))) { + log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.InvalidSignature(u))) + d + } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { + log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(u))) + ctx.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) + db.updateChannel(u) + // update the graph + val graph1 = if (Announcements.isEnabled(u.channelFlags)) { + d.graph.removeEdge(desc).addEdge(desc, u) + } else { + d.graph.removeEdge(desc) + } + d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins)), graph = graph1) + } else { + log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(u))) + ctx.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) + db.updateChannel(u) + // we also need to update the graph + val graph1 = d.graph.addEdge(desc, u) + d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), privateChannels = d.privateChannels - u.shortChannelId, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins)), graph = graph1) + } + } else if (d.awaiting.keys.exists(c => c.shortChannelId == u.shortChannelId)) { + // channel is 
currently being validated + if (d.stash.updates.contains(u)) { + log.debug("ignoring {} (already stashed)", u) + val origins1 = d.stash.updates(u) ++ origins + d.copy(stash = d.stash.copy(updates = d.stash.updates + (u -> origins1))) + } else { + log.debug("stashing {}", u) + d.copy(stash = d.stash.copy(updates = d.stash.updates + (u -> origins))) + } + } else if (d.privateChannels.contains(u.shortChannelId)) { + val publicChannel = false + val pc = d.privateChannels(u.shortChannelId) + val desc = if (Announcements.isNode1(u.channelFlags)) ChannelDesc(u.shortChannelId, pc.nodeId1, pc.nodeId2) else ChannelDesc(u.shortChannelId, pc.nodeId2, pc.nodeId1) + if (StaleChannels.isStale(u)) { + log.debug("ignoring {} (stale)", u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Stale(u))) + d + } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { + log.debug("ignoring {} (already know same or newer)", u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Duplicate(u))) + d + } else if (!Announcements.checkSig(u, desc.a)) { + log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.InvalidSignature(u))) + d + } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { + log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(u))) + ctx.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) + // we also need to update the graph + val graph1 = d.graph.removeEdge(desc).addEdge(desc, u) + d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) + } else { + log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) + remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(u))) + ctx.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) + // we also need to update the graph + val graph1 = d.graph.addEdge(desc, u) + d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) + } + } else if (db.isPruned(u.shortChannelId) && !StaleChannels.isStale(u)) { + // the channel was recently pruned, but if we are here, it means that the update is not stale so this is the case + // of a zombie channel coming back from the dead. they probably sent us a channel_announcement right before this update, + // but we ignored it because the channel was in the 'pruned' list. Now that we know that the channel is alive again, + // let's remove the channel from the zombie list and ask the sender to re-send announcements (channel_announcement + updates) + // about that channel. We can ignore this update since we will receive it again + log.info(s"channel shortChannelId=${u.shortChannelId} is back from the dead! 
requesting announcements about this channel")
+        remoteOrigins.foreach(sendDecision(_, GossipDecision.Duplicate(u)))
+        db.removeFromPruned(u.shortChannelId)
+
+        // we only have a valid peerConnection when we're handling an update that we received from a peer, not
+        // when we're sending updates to ourselves
+        origins.head match {
+          case RemoteGossip(peerConnection, remoteNodeId) =>
+            val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(routerConf.encodingType, List(u.shortChannelId)), TlvStream.empty)
+            d.sync.get(remoteNodeId) match {
+              case Some(sync) =>
+                // we already have a pending request to that node, let's add this channel to the list and we'll get it later
+                // TODO: we only request channels with old style channel_query
+                d.copy(sync = d.sync + (remoteNodeId -> sync.copy(pending = sync.pending :+ query, total = sync.total + 1)))
+              case None =>
+                // we send the query right away
+                peerConnection ! query
+                d.copy(sync = d.sync + (remoteNodeId -> Syncing(pending = Nil, total = 1)))
+            }
+          case _ =>
+            // we don't know which node this update came from (maybe it was stashed and the channel got pruned in the meantime or some other corner case),
+            // or we don't have a peerConnection to send our query to.
+            // anyway, that's not really a big deal because we have removed the channel from the pruned db so next time it shows up we will revalidate it
+            d
+        }
+      } else {
+        log.debug("ignoring announcement {} (unknown channel)", u)
+        remoteOrigins.foreach(sendDecision(_, GossipDecision.NoRelatedChannel(u)))
+        d
+      }
+  }
+
+  def handleLocalChannelUpdate(d: Data, db: NetworkDb, routerConf: RouterConf, localNodeId: PublicKey, watcher: ActorRef, lcu: LocalChannelUpdate)(implicit ctx: ActorContext, log: LoggingAdapter): Data = {
+    implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors
+    import lcu.{channelAnnouncement_opt, shortChannelId, channelUpdate => u}
+    d.channels.get(shortChannelId) match {
+      case Some(_) =>
+        // channel has already been announced and router knows about it, we can process the channel_update
+        handleChannelUpdate(d, db, routerConf, Set(LocalGossip), u)
+      case None =>
+        channelAnnouncement_opt match {
+          case Some(c) if d.awaiting.contains(c) =>
+            // channel is currently being verified, we can process the channel_update right away (it will be stashed)
+            handleChannelUpdate(d, db, routerConf, Set(LocalGossip), u)
+          case Some(c) =>
+            // channel wasn't announced but here is the announcement, we will process it *before* the channel_update
+            watcher !
ValidateRequest(c) + val d1 = d.copy(awaiting = d.awaiting + (c -> Nil)) // no origin + // maybe the local channel was pruned (can happen if we were disconnected for more than 2 weeks) + db.removeFromPruned(c.shortChannelId) + handleChannelUpdate(d1, db, routerConf, Set(LocalGossip), u) + case None if d.privateChannels.contains(shortChannelId) => + // channel isn't announced but we already know about it, we can process the channel_update + handleChannelUpdate(d, db, routerConf, Set(LocalGossip), u) + case None => + // channel isn't announced and we never heard of it (maybe it is a private channel or maybe it is a public channel that doesn't yet have 6 confirmations) + // let's create a corresponding private channel and process the channel_update + log.debug("adding unannounced local channel to remote={} shortChannelId={}", lcu.remoteNodeId, shortChannelId) + val d1 = d.copy(privateChannels = d.privateChannels + (shortChannelId -> PrivateChannel(localNodeId, lcu.remoteNodeId, None, None))) + handleChannelUpdate(d1, db, routerConf, Set(LocalGossip), u) + } + } + } + + def handleLocalChannelDown(d: Data, localNodeId: PublicKey, lcd: LocalChannelDown)(implicit log: LoggingAdapter): Data = { + import lcd.{channelId, remoteNodeId, shortChannelId} + // a local channel has permanently gone down + if (d.channels.contains(shortChannelId)) { + // the channel was public, we will receive (or have already received) a WatchEventSpentBasic event, that will trigger a clean up of the channel + // so let's not do anything here + d + } else if (d.privateChannels.contains(shortChannelId)) { + // the channel was private or public-but-not-yet-announced, let's do the clean up + log.debug("removing private local channel and channel_update for channelId={} shortChannelId={}", channelId, shortChannelId) + val desc1 = ChannelDesc(shortChannelId, localNodeId, remoteNodeId) + val desc2 = ChannelDesc(shortChannelId, remoteNodeId, localNodeId) + // we remove the corresponding updates from the graph + val graph1 = d.graph + .removeEdge(desc1) + .removeEdge(desc2) + // and we remove the channel and channel_update from our state + d.copy(privateChannels = d.privateChannels - shortChannelId, graph = graph1) + } else { + d + } + } +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index 9e9889aa0..90055d853 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -38,6 +38,7 @@ sealed trait SetupMessage extends LightningMessage sealed trait ChannelMessage extends LightningMessage sealed trait HtlcMessage extends LightningMessage sealed trait RoutingMessage extends LightningMessage +sealed trait AnnouncementMessage extends RoutingMessage // <- not in the spec sealed trait HasTimestamp extends LightningMessage { def timestamp: Long } sealed trait HasTemporaryChannelId extends LightningMessage { def temporaryChannelId: ByteVector32 } // <- not in the spec sealed trait HasChannelId extends LightningMessage { def channelId: ByteVector32 } // <- not in the spec @@ -168,7 +169,7 @@ case class ChannelAnnouncement(nodeSignature1: ByteVector64, nodeId2: PublicKey, bitcoinKey1: PublicKey, bitcoinKey2: PublicKey, - unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with HasChainHash + unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with AnnouncementMessage with 
HasChainHash case class Color(r: Byte, g: Byte, b: Byte) { override def toString: String = f"#$r%02x$g%02x$b%02x" // to hexa s"# ${r}%02x ${r & 0xFF}${g & 0xFF}${b & 0xFF}" @@ -211,7 +212,7 @@ case class NodeAnnouncement(signature: ByteVector64, rgbColor: Color, alias: String, addresses: List[NodeAddress], - unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with HasTimestamp + unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with AnnouncementMessage with HasTimestamp case class ChannelUpdate(signature: ByteVector64, chainHash: ByteVector32, @@ -224,7 +225,7 @@ case class ChannelUpdate(signature: ByteVector64, feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long, htlcMaximumMsat: Option[MilliSatoshi], - unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with HasTimestamp with HasChainHash { + unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with AnnouncementMessage with HasTimestamp with HasChainHash { require(((messageFlags & 1) != 0) == htlcMaximumMsat.isDefined, "htlcMaximumMsat is not consistent with messageFlags") def isNode1 = Announcements.isNode1(channelFlags) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 588b646a7..9c5708a68 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -34,7 +34,8 @@ import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivePayment import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentRequest, SendPaymentToRouteRequest} import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdate -import fr.acinq.eclair.router.{Announcements, GetNetworkStats, GetNetworkStatsResponse, NetworkStats, PublicChannel, Router, Stats} +import fr.acinq.eclair.router.Router.{GetNetworkStats, GetNetworkStatsResponse, PublicChannel} +import fr.acinq.eclair.router.{Announcements, NetworkStats, Router, Stats} import org.mockito.Mockito import org.mockito.scalatest.IdiomaticMockito import org.scalatest.{Outcome, ParallelTestExecution, fixture} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index ba11bd7d6..9d2814ccd 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -20,13 +20,13 @@ import java.sql.{Connection, DriverManager} import java.util.concurrent.atomic.AtomicLong import fr.acinq.bitcoin.Crypto.PrivateKey -import fr.acinq.bitcoin.{Block, Btc, ByteVector32, Script} +import fr.acinq.bitcoin.{Block, ByteVector32, Script} import fr.acinq.eclair.NodeParams.BITCOIND import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeeratesPerKw, OnChainFeeConf} import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer -import fr.acinq.eclair.router.RouterConf +import fr.acinq.eclair.router.Router.RouterConf import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} import scodec.bits.ByteVector diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index dbd58c02f..22b060f20 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ 
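The new AnnouncementMessage marker trait above groups the three gossip messages (channel_announcement, node_announcement, channel_update), which is what lets the connection layer handle all gossip uniformly when acknowledging it with a GossipDecision. A tiny illustration (describe is a hypothetical helper):

  import fr.acinq.eclair.wire._

  // With the sealed marker trait, all gossip dispatches through a single match.
  def describe(msg: AnnouncementMessage): String = msg match {
    case _: ChannelAnnouncement => "channel_announcement"
    case _: NodeAnnouncement => "node_announcement"
    case _: ChannelUpdate => "channel_update"
  }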
b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -20,8 +20,7 @@ import java.util.UUID import java.util.concurrent.CountDownLatch import akka.actor.{Actor, ActorLogging, ActorRef, Props, Status} -import akka.testkit -import akka.testkit.{TestActor, TestFSMRef, TestProbe} +import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} @@ -32,7 +31,7 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivePayment import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.relay.{CommandBuffer, Relayer} -import fr.acinq.eclair.router.ChannelHop +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import grizzled.slf4j.Logging diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala index a6e47e3c6..404182e2d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.blockchain.fee.FeeTargets import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment.OutgoingPacket -import fr.acinq.eclair.router.ChannelHop +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, TestConstants, randomBytes32, _} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index ecafcf963..f4472902d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.Relayer._ import fr.acinq.eclair.payment.relay.{CommandBuffer, Origin} -import fr.acinq.eclair.router.ChannelHop +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire.{CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, TestConstants, TestkitBaseClass, randomBytes32} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala index d1237e64d..60af98901 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala @@ -22,7 +22,8 @@ import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi} import fr.acinq.eclair.db.sqlite.SqliteNetworkDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ -import fr.acinq.eclair.router.{Announcements, PublicChannel} +import fr.acinq.eclair.router.Announcements +import 
fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index 700c15f4b..a35470624 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -24,7 +24,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.{ChannelHop, NodeHop} +import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop} import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomBytes64, randomKey} import org.scalatest.FunSuite diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index b8e9bc599..76463e020 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -33,6 +33,7 @@ import fr.acinq.eclair.channel.ChannelCommandResponse.ChannelOpened import fr.acinq.eclair.channel.Register.{Forward, ForwardShortId} import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket +import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} @@ -44,9 +45,11 @@ import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannels} import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentRequest, SendTrampolinePaymentRequest} import fr.acinq.eclair.payment.send.PaymentLifecycle.{State => _} +import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec} import fr.acinq.eclair.router.Graph.WeightRatios -import fr.acinq.eclair.router.Router.ROUTE_MAX_LENGTH -import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, PublicChannel, RouteParams} +import fr.acinq.eclair.router.RouteCalculation.ROUTE_MAX_LENGTH +import fr.acinq.eclair.router.Router.{GossipDecision, PublicChannel, RouteParams} +import fr.acinq.eclair.router.Router.{NORMAL => _, State => _, _} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions.{HtlcSuccessTx, HtlcTimeoutTx} import fr.acinq.eclair.wire._ @@ -1255,7 +1258,11 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService // then we make the announcements val announcements = channels.map(c => AnnouncementsBatchValidationSpec.makeChannelAnnouncement(c)) - announcements.foreach(ann => nodes("A").router ! PeerRoutingMessage(sender.ref, remoteNodeId, ann)) + announcements.foreach { ann => + nodes("A").router ! 
PeerRoutingMessage(sender.ref, remoteNodeId, ann) + sender.expectMsg(TransportHandler.ReadAck(ann)) + sender.expectMsg(GossipDecision.Accepted(ann)) + } awaitCond({ sender.send(nodes("D").router, 'channels) sender.expectMsgType[Iterable[ChannelAnnouncement]](5 seconds).size == channels.size + 8 // 8 remaining channels because D->F{1-5} have disappeared diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index ed89863f4..6b5ca302d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.io import java.net.{Inet4Address, InetSocketAddress} -import akka.actor.{ActorRef, PoisonPill} +import akka.actor.PoisonPill import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.Block import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} @@ -26,6 +26,7 @@ import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.crypto.TransportHandler +import fr.acinq.eclair.router.Router.{GossipDecision, GossipOrigin, LocalGossip, Rebroadcast, RemoteGossip, SendChannelQuery} import fr.acinq.eclair.router.{RoutingSyncSpec, _} import fr.acinq.eclair.wire._ import org.scalatest.{Outcome, Tag} @@ -260,7 +261,7 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { test("filter gossip message (no filtering)") { f => import f._ val probe = TestProbe() - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey.publicKey)) connect(remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) val rebroadcast = Rebroadcast(channels.map(_ -> gossipOrigin).toMap, updates.map(_ -> gossipOrigin).toMap, nodes.map(_ -> gossipOrigin).toMap) probe.send(peerConnection, rebroadcast) @@ -270,12 +271,12 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { test("filter gossip message (filtered by origin)") { f => import f._ connect(remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref)) - val pcActor: ActorRef = peerConnection + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey.publicKey)) + val bobOrigin = RemoteGossip(peerConnection, remoteNodeId) val rebroadcast = Rebroadcast( - channels.map(_ -> gossipOrigin).toMap + (channels(5) -> Set(RemoteGossip(pcActor))), - updates.map(_ -> gossipOrigin).toMap + (updates(6) -> (gossipOrigin + RemoteGossip(pcActor))) + (updates(10) -> Set(RemoteGossip(pcActor))), - nodes.map(_ -> gossipOrigin).toMap + (nodes(4) -> Set(RemoteGossip(pcActor)))) + channels.map(_ -> gossipOrigin).toMap + (channels(5) -> Set(bobOrigin)), + updates.map(_ -> gossipOrigin).toMap + (updates(6) -> (gossipOrigin + bobOrigin)) + (updates(10) -> Set(bobOrigin)), + nodes.map(_ -> gossipOrigin).toMap + (nodes(4) -> Set(bobOrigin))) val filter = wire.GossipTimestampFilter(Alice.nodeParams.chainHash, 0, Long.MaxValue) // no filtering on timestamps transport.send(peerConnection, filter) transport.expectMsg(TransportHandler.ReadAck(filter)) @@ -289,7 +290,7 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { test("filter gossip message (filtered by timestamp)") { f => 
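The integration-test hunk above makes the acknowledgement contract of this PR explicit: every routing message forwarded to the router is now answered by a transport-level TransportHandler.ReadAck followed by a router-level GossipDecision. A sketch of a helper a test might define around that contract (expectAccepted is a hypothetical name):

  import akka.actor.ActorRef
  import akka.testkit.TestProbe
  import fr.acinq.bitcoin.Crypto.PublicKey
  import fr.acinq.eclair.crypto.TransportHandler
  import fr.acinq.eclair.io.Peer.PeerRoutingMessage
  import fr.acinq.eclair.router.Router.GossipDecision
  import fr.acinq.eclair.wire.AnnouncementMessage

  def expectAccepted(router: ActorRef, peer: TestProbe, remoteNodeId: PublicKey, ann: AnnouncementMessage): Unit = {
    router ! PeerRoutingMessage(peer.ref, remoteNodeId, ann)
    peer.expectMsg(TransportHandler.ReadAck(ann)) // the message was read off the wire
    peer.expectMsg(GossipDecision.Accepted(ann)) // the router accepted the gossip
  }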
import f._ connect(remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey.publicKey)) val rebroadcast = Rebroadcast(channels.map(_ -> gossipOrigin).toMap, updates.map(_ -> gossipOrigin).toMap, nodes.map(_ -> gossipOrigin).toMap) val timestamps = updates.map(_.timestamp).sorted.slice(10, 30) val filter = wire.GossipTimestampFilter(Alice.nodeParams.chainHash, timestamps.head, timestamps.last - timestamps.head) @@ -307,7 +308,7 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { import f._ val probe = TestProbe() connect(remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey.publicKey)) val rebroadcast = Rebroadcast( channels.map(_ -> gossipOrigin).toMap + (channels(5) -> Set(LocalGossip)), updates.map(_ -> gossipOrigin).toMap + (updates(6) -> (gossipOrigin + LocalGossip)) + (updates(10) -> Set(LocalGossip)), @@ -340,7 +341,7 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { // let's assume that the router isn't happy with those channels because the funding tx is already spent for (c <- channels) { - router.send(peerConnection, PeerConnection.ChannelClosed(c)) + router.send(peerConnection, GossipDecision.ChannelClosed(c)) } // peer will temporary ignore announcements coming from bob for (ann <- channels ++ updates) { @@ -363,14 +364,14 @@ class PeerConnectionSpec extends TestkitBaseClass with StateTestsHelperMethods { transport.expectNoMsg(1 second) // peer hasn't acknowledged the messages // now let's assume that the router isn't happy with those channels because the announcement is invalid - router.send(peerConnection, PeerConnection.InvalidAnnouncement(channels(0))) + router.send(peerConnection, GossipDecision.InvalidAnnouncement(channels(0))) // peer will return a connection-wide error, including the hex-encoded representation of the bad message val error1 = transport.expectMsgType[Error] assert(error1.channelId === Peer.CHANNELID_ZERO) assert(new String(error1.data.toArray).startsWith("couldn't verify channel! 
shortChannelId=")) // let's assume that one of the sigs were invalid - router.send(peerConnection, PeerConnection.InvalidSignature(channels(0))) + router.send(peerConnection, GossipDecision.InvalidSignature(channels(0))) // peer will return a connection-wide error, including the hex-encoded representation of the bad message val error2 = transport.expectMsgType[Error] assert(error2.channelId === Peer.CHANNELID_ZERO) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 5ea9efed4..5767e2b4d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentError} +import fr.acinq.eclair.router.Router.{ChannelHop, GetNetworkStats, GetNetworkStatsResponse, RouteParams, TickComputeNetworkStats} import fr.acinq.eclair.router._ import fr.acinq.eclair.wire._ import org.scalatest.{Outcome, Tag, fixture} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 99397db55..ba582226e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -22,6 +22,7 @@ import akka.actor.{ActorRef, ActorSystem} import akka.testkit.{TestActorRef, TestKit, TestProbe} import fr.acinq.bitcoin.Block import fr.acinq.eclair.Features._ +import fr.acinq.eclair.UInt64.Conversions._ import fr.acinq.eclair.channel.{Channel, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.PaymentPacketSpec._ @@ -30,12 +31,11 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator} -import fr.acinq.eclair.router.{NodeHop, RouteParams} +import fr.acinq.eclair.router.Router.{NodeHop, RouteParams} import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload} import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} -import fr.acinq.eclair.wire._ +import fr.acinq.eclair.wire.{Onion, OnionCodecs, OnionTlv, TrampolineFeeInsufficient, _} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, NodeParams, TestConstants, randomBytes32, randomKey} -import fr.acinq.eclair.UInt64.Conversions._ import org.scalatest.{Outcome, Tag, fixture} import scodec.bits.HexStringSyntax diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 9c7b6dafa..f6f36e5b7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -38,6 +38,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentConfig, SendPay import fr.acinq.eclair.payment.send.PaymentLifecycle import 
fr.acinq.eclair.payment.send.PaymentLifecycle._ import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} +import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, ExcludeChannel, FinalizeRoute, RouteParams, RouteRequest, RouteResponse} import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire.Onion.FinalLegacyPayload @@ -145,13 +146,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, channelUpdate_ab), ChannelHop(b, c, channelUpdate_bc))) + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) sender.send(paymentFSM, request) routerForwarder.expectMsg(RouteRequest(c, d, defaultAmountMsat, ignoreNodes = Set(a, b))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - routerForwarder.send(paymentFSM, RouteResponse(Seq(ChannelHop(c, d, channelUpdate_cd)), Set.empty, Set.empty)) + routerForwarder.send(paymentFSM, RouteResponse(Seq(ChannelHop(c, d, update_cd)), Set.empty, Set.empty)) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) } @@ -159,7 +160,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPayment(c, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, channelUpdate_ab), ChannelHop(b, c, channelUpdate_bc))) + val request = SendPayment(c, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) sender.send(paymentFSM, request) routerForwarder.expectNoMsg(50 millis) // we don't need the router when we already have the whole route val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -171,13 +172,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, channelUpdate_ab), ChannelHop(b, c, channelUpdate_bc))) + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) sender.send(paymentFSM, request) routerForwarder.expectMsg(RouteRequest(c, d, defaultAmountMsat, ignoreNodes = Set(a, b))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - routerForwarder.send(paymentFSM, RouteResponse(Seq(ChannelHop(c, d, channelUpdate_cd)), Set(a, b), Set.empty)) + routerForwarder.send(paymentFSM, RouteResponse(Seq(ChannelHop(c, d, update_cd)), Set(a, b), Set.empty)) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) sender.send(paymentFSM, UpdateFailHtlc(randomBytes32, 0, randomBytes(Sphinx.FailurePacket.PacketLength))) @@ -306,14 +307,14 
@@ class PaymentLifecycleSpec extends BaseRouterSpec { val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, hops) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) - val failure = TemporaryChannelFailure(channelUpdate_bc) + val failure = TemporaryChannelFailure(update_bc) sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) // payment lifecycle will ask the router to temporarily exclude this channel from its route calculations - routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(channelUpdate_bc.shortChannelId, b, c))) + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c))) routerForwarder.forward(routerFixture.router) // payment lifecycle forwards the embedded channelUpdate to the router - routerForwarder.expectMsg(channelUpdate_bc) + routerForwarder.expectMsg(update_bc) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(routerFixture.router) @@ -337,7 +338,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { register.expectMsg(ForwardShortId(channelId_ab, cmd1)) // we change the cltv expiry - val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(42), htlcMinimumMsat = channelUpdate_bc.htlcMinimumMsat, feeBaseMsat = channelUpdate_bc.feeBaseMsat, feeProportionalMillionths = channelUpdate_bc.feeProportionalMillionths, htlcMaximumMsat = channelUpdate_bc.htlcMaximumMsat.get) + val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat.get) val failure = IncorrectCltvExpiry(CltvExpiry(5), channelUpdate_bc_modified) // and node replies with a failure containing a new channel update sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) @@ -354,13 +355,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { register.expectMsg(ForwardShortId(channelId_ab, cmd2)) // we change the cltv expiry one more time - val channelUpdate_bc_modified_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(43), htlcMinimumMsat = channelUpdate_bc.htlcMinimumMsat, feeBaseMsat = channelUpdate_bc.feeBaseMsat, feeProportionalMillionths = channelUpdate_bc.feeProportionalMillionths, htlcMaximumMsat = channelUpdate_bc.htlcMaximumMsat.get) + val channelUpdate_bc_modified_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(43), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat.get) val failure2 = IncorrectCltvExpiry(CltvExpiry(5), channelUpdate_bc_modified_2) // and node replies with a failure containing a new channel update sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head._1, failure2))) // this time the payment lifecycle will ask the router to temporarily exclude this channel from its route calculations - 
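The assertions above pin down the failure-handling flow: when a remote node reports a TemporaryChannelFailure carrying a channel_update, the payment lifecycle first asks the router to temporarily exclude that channel, then forwards the embedded update anyway. A sketch under simplified assumptions (onTemporaryChannelFailure is a hypothetical helper, not eclair code):

  import akka.actor.ActorRef
  import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, ExcludeChannel}
  import fr.acinq.eclair.wire.TemporaryChannelFailure

  def onTemporaryChannelFailure(router: ActorRef, hop: ChannelHop, failure: TemporaryChannelFailure): Unit = {
    // temporarily exclude the failing channel from route calculations...
    router ! ExcludeChannel(ChannelDesc(failure.update.shortChannelId, hop.nodeId, hop.nextNodeId))
    // ...but still forward the embedded channel_update so the router learns the latest parameters
    router ! failure.update
  }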
routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(channelUpdate_bc.shortChannelId, b, c))) + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c))) routerForwarder.forward(routerFixture.router) // but it will still forward the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified_2) @@ -379,8 +380,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // we build an assisted route for channel bc and cd val assistedRoutes = Seq(Seq( - ExtraHop(b, channelId_bc, channelUpdate_bc.feeBaseMsat, channelUpdate_bc.feeProportionalMillionths, channelUpdate_bc.cltvExpiryDelta), - ExtraHop(c, channelId_cd, channelUpdate_cd.feeBaseMsat, channelUpdate_cd.feeProportionalMillionths, channelUpdate_cd.cltvExpiryDelta) + ExtraHop(b, channelId_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, update_bc.cltvExpiryDelta), + ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) )) val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5, assistedRoutes = assistedRoutes) @@ -395,7 +396,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { register.expectMsg(ForwardShortId(channelId_ab, cmd1)) // we change the cltv expiry - val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(42), htlcMinimumMsat = channelUpdate_bc.htlcMinimumMsat, feeBaseMsat = channelUpdate_bc.feeBaseMsat, feeProportionalMillionths = channelUpdate_bc.feeProportionalMillionths, htlcMaximumMsat = channelUpdate_bc.htlcMaximumMsat.get) + val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat.get) val failure = IncorrectCltvExpiry(CltvExpiry(5), channelUpdate_bc_modified) // and node replies with a failure containing a new channel update sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) @@ -404,8 +405,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING val assistedRoutes1 = Seq(Seq( - ExtraHop(b, channelId_bc, channelUpdate_bc.feeBaseMsat, channelUpdate_bc.feeProportionalMillionths, channelUpdate_bc_modified.cltvExpiryDelta), - ExtraHop(c, channelId_cd, channelUpdate_cd.feeBaseMsat, channelUpdate_cd.feeProportionalMillionths, channelUpdate_cd.cltvExpiryDelta) + ExtraHop(b, channelId_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, channelUpdate_bc_modified.cltvExpiryDelta), + ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) )) routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = assistedRoutes1, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(routerFixture.router) @@ -494,10 +495,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_bg = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, g, channelId_bg, CltvExpiryDelta(9), 
htlcMinimumMsat = 0 msat, feeBaseMsat = 0 msat, feeProportionalMillionths = 0, htlcMaximumMsat = 500000000 msat) val channelUpdate_gb = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_g, b, channelId_bg, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) assert(Router.getDesc(channelUpdate_bg, chan_bg) === ChannelDesc(chan_bg.shortChannelId, priv_b.publicKey, priv_g.publicKey)) - router ! PeerRoutingMessage(null, remoteNodeId, chan_bg) - router ! PeerRoutingMessage(null, remoteNodeId, ann_g) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_bg) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_gb) + val peerConnection = TestProbe() + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_bg) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, ann_g) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_bg) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_gb) watcher.expectMsg(ValidateRequest(chan_bg)) watcher.send(router, ValidateResult(chan_bg, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_b, funding_g)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) watcher.expectMsgType[WatchSpentBasic] @@ -528,9 +530,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("filter errors properly") { _ => - val failures = LocalFailure(RouteNotFound) :: RemoteFailure(ChannelHop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(AddHtlcFailed(ByteVector32.Zeroes, ByteVector32.Zeroes, ChannelUnavailable(ByteVector32.Zeroes), Local(UUID.randomUUID(), None), None, None)) :: LocalFailure(RouteNotFound) :: Nil + val failures = LocalFailure(RouteNotFound) :: RemoteFailure(ChannelHop(a, b, update_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(AddHtlcFailed(ByteVector32.Zeroes, ByteVector32.Zeroes, ChannelUnavailable(ByteVector32.Zeroes), Local(UUID.randomUUID(), None), None, None)) :: LocalFailure(RouteNotFound) :: Nil val filtered = PaymentFailure.transformForUser(failures) - assert(filtered == LocalFailure(RouteNotFound) :: RemoteFailure(ChannelHop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(ChannelUnavailable(ByteVector32.Zeroes)) :: Nil) + assert(filtered == LocalFailure(RouteNotFound) :: RemoteFailure(ChannelHop(a, b, update_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(ChannelUnavailable(ByteVector32.Zeroes)) :: Nil) } test("disable database and events") { routerFixture => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 163b4778a..e0017690a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.IncomingPacket.{ChannelRelayPacket, FinalPacket, NodeRelayPacket, decrypt} import fr.acinq.eclair.payment.OutgoingPacket._ import fr.acinq.eclair.payment.PaymentRequest.Features -import fr.acinq.eclair.router.{ChannelHop, NodeHop} +import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop} import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, 
FinalTlvPayload, RelayLegacyPayload} import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv, PaymentData} import fr.acinq.eclair.wire._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 84f7d9f32..450a69395 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} import fr.acinq.eclair.payment.OutgoingPacket.buildCommand import fr.acinq.eclair.payment.PaymentPacketSpec._ import fr.acinq.eclair.payment.relay.{CommandBuffer, Origin, PostRestartHtlcCleaner, Relayer} -import fr.acinq.eclair.router.ChannelHop +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.transactions.{DirectedHtlc, IncomingHtlc, OutgoingHtlc} import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 7def845cf..fa7b3ec2f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -28,7 +28,8 @@ import fr.acinq.eclair.payment.OutgoingPacket.{buildCommand, buildOnion, buildPa import fr.acinq.eclair.payment.relay.Origin._ import fr.acinq.eclair.payment.relay.Relayer._ import fr.acinq.eclair.payment.relay.{CommandBuffer, Relayer} -import fr.acinq.eclair.router._ +import fr.acinq.eclair.router.Router.{ChannelHop, GetNetworkStats, GetNetworkStatsResponse, NodeHop, TickComputeNetworkStats} +import fr.acinq.eclair.router.{Announcements, _} import fr.acinq.eclair.wire.Onion.{ChannelRelayTlvPayload, FinalLegacyPayload, FinalTlvPayload, PerHopPayload} import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, NodeParams, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index c358490b4..f38387f45 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -26,6 +26,7 @@ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.router.Announcements._ +import fr.acinq.eclair.router.Router.ChannelDesc import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import fr.acinq.eclair.{TestkitBaseClass, randomKey, _} @@ -42,7 +43,7 @@ import scala.concurrent.duration._ abstract class BaseRouterSpec extends TestkitBaseClass { - case class FixtureParam(router: ActorRef, watcher: TestProbe) + case class FixtureParam(nodeParams: NodeParams, router: ActorRef, watcher: TestProbe) val remoteNodeId = PrivateKey(ByteVector32(ByteVector.fill(32)(1))).publicKey @@ -55,12 +56,12 @@ abstract class BaseRouterSpec extends TestkitBaseClass { val (priv_funding_a, priv_funding_b, priv_funding_c, priv_funding_d, priv_funding_e, priv_funding_f) = (randomKey, randomKey, randomKey, randomKey, randomKey, 
randomKey) val (funding_a, funding_b, funding_c, funding_d, funding_e, funding_f) = (priv_funding_a.publicKey, priv_funding_b.publicKey, priv_funding_c.publicKey, priv_funding_d.publicKey, priv_funding_e.publicKey, priv_funding_f.publicKey) - val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil, hex"0200") - val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil, hex"") - val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, hex"0200") - val ann_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil, hex"00") - val ann_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil, hex"00") - val ann_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil, hex"00") + val node_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil, hex"0200") + val node_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil, hex"") + val node_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, hex"0200") + val node_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil, hex"00") + val node_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil, hex"00") + val node_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil, hex"00") val channelId_ab = ShortChannelId(420000, 1, 0) val channelId_bc = ShortChannelId(420000, 2, 0) @@ -78,14 +79,14 @@ abstract class BaseRouterSpec extends TestkitBaseClass { val chan_cd = channelAnnouncement(channelId_cd, priv_c, priv_d, priv_funding_c, priv_funding_d) val chan_ef = channelAnnouncement(channelId_ef, priv_e, priv_f, priv_funding_e, priv_funding_f) - val channelUpdate_ab = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, b, channelId_ab, CltvExpiryDelta(7), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 10, htlcMaximumMsat = 500000000 msat) - val channelUpdate_ba = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, a, channelId_ab, CltvExpiryDelta(7), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 10, htlcMaximumMsat = 500000000 msat) - val channelUpdate_bc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(5), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 1, htlcMaximumMsat = 500000000 msat) - val channelUpdate_cb = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, b, channelId_bc, CltvExpiryDelta(5), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 1, htlcMaximumMsat = 500000000 msat) - val channelUpdate_cd = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = 500000000 msat) - val channelUpdate_dc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_d, c, channelId_cd, CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = 500000000 msat) - val channelUpdate_ef = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_e, f, channelId_ef, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) - val channelUpdate_fe = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_f, e, channelId_ef, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) + val 
update_ab = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, b, channelId_ab, CltvExpiryDelta(7), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 10, htlcMaximumMsat = 500000000 msat) + val update_ba = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, a, channelId_ab, CltvExpiryDelta(7), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 10, htlcMaximumMsat = 500000000 msat) + val update_bc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, channelId_bc, CltvExpiryDelta(5), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 1, htlcMaximumMsat = 500000000 msat) + val update_cb = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, b, channelId_bc, CltvExpiryDelta(5), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 1, htlcMaximumMsat = 500000000 msat) + val update_cd = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = 500000000 msat) + val update_dc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_d, c, channelId_cd, CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = 500000000 msat) + val update_ef = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_e, f, channelId_ef, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) + val update_fe = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_f, e, channelId_ef, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) override def withFixture(test: OneArgTest): Outcome = { // the network will be a --(1)--> b ---(2)--> c --(3)--> d and e --(4)--> f (we are a) @@ -93,36 +94,38 @@ abstract class BaseRouterSpec extends TestkitBaseClass { within(30 seconds) { // first we make sure that we correctly resolve channelId+direction to nodeId - assert(Router.getDesc(channelUpdate_ab, chan_ab) === ChannelDesc(chan_ab.shortChannelId, priv_a.publicKey, priv_b.publicKey)) - assert(Router.getDesc(channelUpdate_bc, chan_bc) === ChannelDesc(chan_bc.shortChannelId, priv_b.publicKey, priv_c.publicKey)) - assert(Router.getDesc(channelUpdate_cd, chan_cd) === ChannelDesc(chan_cd.shortChannelId, priv_c.publicKey, priv_d.publicKey)) - assert(Router.getDesc(channelUpdate_ef, chan_ef) === ChannelDesc(chan_ef.shortChannelId, priv_e.publicKey, priv_f.publicKey)) + assert(Router.getDesc(update_ab, chan_ab) === ChannelDesc(chan_ab.shortChannelId, priv_a.publicKey, priv_b.publicKey)) + assert(Router.getDesc(update_bc, chan_bc) === ChannelDesc(chan_bc.shortChannelId, priv_b.publicKey, priv_c.publicKey)) + assert(Router.getDesc(update_cd, chan_cd) === ChannelDesc(chan_cd.shortChannelId, priv_c.publicKey, priv_d.publicKey)) + assert(Router.getDesc(update_ef, chan_ef) === ChannelDesc(chan_ef.shortChannelId, priv_e.publicKey, priv_f.publicKey)) - // let's we set up the router + // let's set up the router + val peerConnection = TestProbe() val watcher = TestProbe() - val router = system.actorOf(Router.props(Alice.nodeParams, watcher.ref)) + val nodeParams = Alice.nodeParams + val router = system.actorOf(Router.props(nodeParams, watcher.ref)) // we announce channels - router ! PeerRoutingMessage(null, remoteNodeId, chan_ab) - router ! 
PeerRoutingMessage(null, remoteNodeId, chan_bc) - router ! PeerRoutingMessage(null, remoteNodeId, chan_cd) - router ! PeerRoutingMessage(null, remoteNodeId, chan_ef) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_bc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_cd)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ef)) // then nodes - router ! PeerRoutingMessage(null, remoteNodeId, ann_a) - router ! PeerRoutingMessage(null, remoteNodeId, ann_b) - router ! PeerRoutingMessage(null, remoteNodeId, ann_c) - router ! PeerRoutingMessage(null, remoteNodeId, ann_d) - router ! PeerRoutingMessage(null, remoteNodeId, ann_e) - router ! PeerRoutingMessage(null, remoteNodeId, ann_f) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_a)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_b)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_c)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_d)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_e)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_f)) // then channel updates - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_ab) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_ba) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_bc) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_cb) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_cd) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_dc) - router ! PeerRoutingMessage(null, remoteNodeId, channelUpdate_ef) - router ! 
PeerRoutingMessage(null, remoteNodeId, channelUpdate_fe) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ba)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_bc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_cb)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_cd)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_dc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ef)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_fe)) // watcher receives the get tx requests watcher.expectMsg(ValidateRequest(chan_ab)) watcher.expectMsg(ValidateRequest(chan_bc)) @@ -151,7 +154,7 @@ abstract class BaseRouterSpec extends TestkitBaseClass { nodes.size === 6 && channels.size === 4 && updates.size === 8 }, max = 10 seconds, interval = 1 second) - withFixture(test.toNoArgTest(FixtureParam(router, watcher))) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, watcher))) } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index 860883760..94748d8ac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -17,10 +17,11 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.{Block, ByteVector32} -import fr.acinq.eclair.router.Router.ShortChannelIdsChunk +import fr.acinq.eclair.router.Router.PublicChannel +import fr.acinq.eclair.router.Sync._ import fr.acinq.eclair.wire.QueryChannelRangeTlv.QueryFlags -import fr.acinq.eclair.wire.{EncodedShortChannelIds, EncodingType, QueryChannelRange, QueryChannelRangeTlv, ReplyChannelRange} import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ +import fr.acinq.eclair.wire.{EncodedShortChannelIds, EncodingType, ReplyChannelRange} import fr.acinq.eclair.{LongToBtcAmount, ShortChannelId, randomKey} import org.scalatest.FunSuite import scodec.bits.ByteVector @@ -36,51 +37,51 @@ class ChannelRangeQueriesSpec extends FunSuite { test("ask for update test") { // they don't provide anything => we always ask for the update - assert(Router.shouldRequestUpdate(0, 0, None, None)) - assert(Router.shouldRequestUpdate(Int.MaxValue, 12345, None, None)) + assert(shouldRequestUpdate(0, 0, None, None)) + assert(shouldRequestUpdate(Int.MaxValue, 12345, None, None)) // their update is older => don't ask val now = Platform.currentTime / 1000 - assert(!Router.shouldRequestUpdate(now, 0, Some(now - 1), None)) - assert(!Router.shouldRequestUpdate(now, 0, Some(now - 1), Some(12345))) - assert(!Router.shouldRequestUpdate(now, 12344, Some(now - 1), None)) - assert(!Router.shouldRequestUpdate(now, 12344, Some(now - 1), Some(12345))) + assert(!shouldRequestUpdate(now, 0, Some(now - 1), None)) + assert(!shouldRequestUpdate(now, 0, Some(now - 1), Some(12345))) + assert(!shouldRequestUpdate(now, 12344, Some(now - 1), None)) + assert(!shouldRequestUpdate(now, 12344, Some(now - 1), Some(12345))) // their update is newer but stale => don't ask val old = now - 4 * 2016 * 24 * 3600 - assert(!Router.shouldRequestUpdate(old - 1, 0, Some(old), None)) - 
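Taken together, the "ask for update" assertions above and below fully specify the re-request rule. A sketch consistent with those assertions, assuming a two-week staleness window; the "almost stale" margin is an assumption, and eclair's actual Sync.shouldRequestUpdate and StaleChannels helpers may be structured differently:

  import scala.compat.Platform

  // Assumed thresholds: updates older than two weeks are stale; the "almost
  // stale" margin of a few days before the cutoff is an assumption.
  def isStale(timestamp: Long): Boolean = timestamp < Platform.currentTime / 1000 - 14 * 24 * 3600
  def isAlmostStale(timestamp: Long): Boolean = timestamp < Platform.currentTime / 1000 - 10 * 24 * 3600

  def shouldRequestUpdate(ourTimestamp: Long, ourChecksum: Long, theirTimestamp_opt: Option[Long], theirChecksum_opt: Option[Long]): Boolean =
    (theirTimestamp_opt, theirChecksum_opt) match {
      case (Some(theirTimestamp), Some(theirChecksum)) =>
        // theirs must be more recent and not stale; ask if the contents differ,
        // or if ours is about to go stale and needs to be refreshed anyway
        ourTimestamp < theirTimestamp && !isStale(theirTimestamp) && (ourChecksum != theirChecksum || isAlmostStale(ourTimestamp))
      case (Some(theirTimestamp), None) =>
        // timestamp only: ask if theirs is more recent and not stale
        ourTimestamp < theirTimestamp && !isStale(theirTimestamp)
      case (None, Some(theirChecksum)) =>
        // checksum only: ask if it differs from ours (0 means they have no data)
        theirChecksum != 0 && theirChecksum != ourChecksum
      case (None, None) =>
        // they provided nothing: always ask
        true
    }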
assert(!Router.shouldRequestUpdate(old - 1, 0, Some(old), Some(12345))) - assert(!Router.shouldRequestUpdate(old - 1, 12344, Some(old), None)) - assert(!Router.shouldRequestUpdate(old - 1, 12344, Some(old), Some(12345))) + assert(!shouldRequestUpdate(old - 1, 0, Some(old), None)) + assert(!shouldRequestUpdate(old - 1, 0, Some(old), Some(12345))) + assert(!shouldRequestUpdate(old - 1, 12344, Some(old), None)) + assert(!shouldRequestUpdate(old - 1, 12344, Some(old), Some(12345))) // their update is newer but with the same checksum, and ours is stale or about to be => ask (we want to renew our update) - assert(Router.shouldRequestUpdate(old, 12345, Some(now), Some(12345))) + assert(shouldRequestUpdate(old, 12345, Some(now), Some(12345))) // their update is newer but with the same checksum => don't ask - assert(!Router.shouldRequestUpdate(now - 1, 12345, Some(now), Some(12345))) + assert(!shouldRequestUpdate(now - 1, 12345, Some(now), Some(12345))) // their update is newer with a different checksum => always ask - assert(Router.shouldRequestUpdate(now - 1, 0, Some(now), None)) - assert(Router.shouldRequestUpdate(now - 1, 0, Some(now), Some(12345))) - assert(Router.shouldRequestUpdate(now - 1, 12344, Some(now), None)) - assert(Router.shouldRequestUpdate(now - 1, 12344, Some(now), Some(12345))) + assert(shouldRequestUpdate(now - 1, 0, Some(now), None)) + assert(shouldRequestUpdate(now - 1, 0, Some(now), Some(12345))) + assert(shouldRequestUpdate(now - 1, 12344, Some(now), None)) + assert(shouldRequestUpdate(now - 1, 12344, Some(now), Some(12345))) // they just provided a 0 checksum => don't ask - assert(!Router.shouldRequestUpdate(0, 0, None, Some(0))) - assert(!Router.shouldRequestUpdate(now, 1234, None, Some(0))) + assert(!shouldRequestUpdate(0, 0, None, Some(0))) + assert(!shouldRequestUpdate(now, 1234, None, Some(0))) // they just provided a checksum that is the same as us => don't ask - assert(!Router.shouldRequestUpdate(now, 1234, None, Some(1234))) + assert(!shouldRequestUpdate(now, 1234, None, Some(1234))) // they just provided a different checksum that is the same as us => ask - assert(Router.shouldRequestUpdate(now, 1234, None, Some(1235))) + assert(shouldRequestUpdate(now, 1234, None, Some(1235))) } test("compute checksums") { - assert(Router.crc32c(ByteVector.fromValidHex("00" * 32)) == 0x8a9136aaL) - assert(Router.crc32c(ByteVector.fromValidHex("FF" * 32)) == 0x62a8ab43L) - assert(Router.crc32c(ByteVector((0 to 31).map(_.toByte))) == 0x46dd794eL) - assert(Router.crc32c(ByteVector((31 to 0 by -1).map(_.toByte))) == 0x113fdb5cL) + assert(crc32c(ByteVector.fromValidHex("00" * 32)) == 0x8a9136aaL) + assert(crc32c(ByteVector.fromValidHex("FF" * 32)) == 0x62a8ab43L) + assert(crc32c(ByteVector((0 to 31).map(_.toByte))) == 0x46dd794eL) + assert(crc32c(ByteVector((31 to 0 by -1).map(_.toByte))) == 0x113fdb5cL) } test("compute flag tests") { @@ -110,28 +111,28 @@ class ChannelRangeQueriesSpec extends FunSuite { import fr.acinq.eclair.wire.QueryShortChannelIdsTlv.QueryFlagType._ - assert(Router.getChannelDigestInfo(channels)(ab.shortChannelId) == (Timestamps(now, now), Checksums(1697591108L, 3692323747L))) + assert(getChannelDigestInfo(channels)(ab.shortChannelId) == (Timestamps(now, now), Checksums(1697591108L, 3692323747L))) // no extended info but we know the channel: we ask for the updates - assert(Router.computeFlag(channels)(ab.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) - assert(Router.computeFlag(channels)(ab.shortChannelId, None, None, 
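The checksum vectors in the "compute checksums" test above are the standard CRC-32C (Castagnoli) test vectors from RFC 3720, so any correct CRC-32C implementation passes them. On Java 9+ a sketch can simply delegate to the JDK (eclair's own crc32c may be implemented differently):

  import java.util.zip.CRC32C
  import scodec.bits.ByteVector

  // Minimal CRC-32C over a ByteVector, delegating to java.util.zip (Java 9+).
  def crc32c(data: ByteVector): Long = {
    val checksum = new CRC32C()
    checksum.update(data.toArray)
    checksum.getValue
  }

  // e.g. crc32c(ByteVector.fromValidHex("00" * 32)) == 0x8a9136aaL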
true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(ab.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(computeFlag(channels)(ab.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // same checksums, newer timestamps: we don't ask anything - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(1697591108L, 3692323747L)), true) === 0) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(1697591108L, 3692323747L)), true) === 0) // different checksums, newer timestamps: we ask for the updates - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 3692323747L)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(1697591108L, 45664546)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546 + 6)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 3692323747L)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(1697591108L, 45664546)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546 + 6)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // different checksums, older timestamps: we don't ask anything - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 3692323747L)), true) === 0) - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(1697591108L, 45664546)), true) === 0) - assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546)), true) === 0) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 3692323747L)), true) === 0) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(1697591108L, 45664546)), true) === 0) + assert(computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546)), true) === 0) // missing channel update: we ask for it - assert(Router.computeFlag(channels)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L)), 
true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // unknown channel: we ask everything - assert(Router.computeFlag(channels)(ef.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) - assert(Router.computeFlag(channels)(ef.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(computeFlag(channels)(ef.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(computeFlag(channels)(ef.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) } def makeShortChannelIds(height: Int, count: Int): List[ShortChannelId] = { @@ -151,7 +152,7 @@ class ChannelRangeQueriesSpec extends FunSuite { } def validate(chunk: ShortChannelIdsChunk) = { - require(chunk.shortChannelIds.forall(Router.keep(chunk.firstBlock, chunk.numBlocks, _))) + require(chunk.shortChannelIds.forall(keep(chunk.firstBlock, chunk.numBlocks, _))) } // check that chunks contain exactly the ids they were built from and are consistent, i.e. each chunk covers a range that immediately follows the previous one @@ -167,7 +168,7 @@ class ChannelRangeQueriesSpec extends FunSuite { // aggregate ids from all chunks, to check that they match our input ids exactly val chunkIds = SortedSet.empty[ShortChannelId] ++ chunks.flatMap(_.shortChannelIds).toSet - val expected = ids.filter(Router.keep(firstBlockNum, numberOfBlocks, _)) + val expected = ids.filter(keep(firstBlockNum, numberOfBlocks, _)) if (expected.isEmpty) require(chunks == List(ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, Nil))) chunks.foreach(validate) @@ -177,7 +178,7 @@ class ChannelRangeQueriesSpec extends FunSuite { require(noOverlap(chunks)) } - test("limit channel ids chunk size") { + test("limit channel ids chunk size") { val ids = makeShortChannelIds(1, 3) val chunk = ShortChannelIdsChunk(0, 10, ids) @@ -200,7 +201,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = Nil val firstBlockNum = 10 val numberOfBlocks = 100 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) assert(chunks == ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, Nil) :: Nil) } @@ -209,7 +210,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1001), id(1002), id(1003), id(1004), id(1005)) val firstBlockNum = 10 val numberOfBlocks = 100 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) assert(chunks == ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, Nil) :: Nil) } @@ -218,7 +219,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1001), id(1002), id(1003), id(1004), id(1005)) val firstBlockNum = 1100 val numberOfBlocks = 100 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) assert(chunks ==
ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, Nil) :: Nil) } @@ -227,7 +228,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1001), id(1002), id(1003), id(1004), id(1005)) val firstBlockNum = 900 val numberOfBlocks = 200 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, ids.size) assert(chunks == ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, ids) :: Nil) } @@ -237,7 +238,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000, 0), id(1000, 1), id(1000, 2), id(1000, 3), id(1000, 4), id(1000, 5)) val firstBlockNum = 900 val numberOfBlocks = 200 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) assert(chunks == ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, ids) :: Nil) } @@ -246,7 +247,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1005), id(1012), id(1013), id(1040), id(1050)) val firstBlockNum = 900 val numberOfBlocks = 200 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) assert(chunks == List( ShortChannelIdsChunk(firstBlockNum, 100 + 6, List(ids(0), ids(1))), ShortChannelIdsChunk(1006, 8, List(ids(2), ids(3))), @@ -259,7 +260,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1005), id(1012), id(1013), id(1040), id(1050)) val firstBlockNum = 1001 val numberOfBlocks = 200 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) assert(chunks == List( ShortChannelIdsChunk(firstBlockNum, 12, List(ids(1), ids(2))), ShortChannelIdsChunk(1013, 1040 - 1013 + 1, List(ids(3), ids(4))), @@ -272,20 +273,20 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = List(id(1000), id(1001), id(1002), id(1003), id(1004), id(1005)) val firstBlockNum = 900 val numberOfBlocks = 105 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) assert(chunks == List( ShortChannelIdsChunk(firstBlockNum, 100 + 2, List(ids(0), ids(1))), ShortChannelIdsChunk(1002, 2, List(ids(2), ids(3))), ShortChannelIdsChunk(1004, numberOfBlocks - 1004 + firstBlockNum, List(ids(4))) )) - } + } // all ids in different blocks, chunk size == 2, first and last id outside of range { val ids = List(id(1000), id(1001), id(1002), id(1003), id(1004), id(1005)) val firstBlockNum = 1001 val numberOfBlocks = 4 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 2) assert(chunks == List( ShortChannelIdsChunk(firstBlockNum, 2, List(ids(1), ids(2))), ShortChannelIdsChunk(1003, 2, List(ids(3), ids(4))) @@ -297,7 +298,7 @@ class ChannelRangeQueriesSpec extends FunSuite { val ids = makeShortChannelIds(1000, 100) val firstBlockNum = 900 val numberOfBlocks = 200 - val chunks = Router.split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 
10) + val chunks = split(SortedSet.empty[ShortChannelId] ++ ids, firstBlockNum, numberOfBlocks, 10) assert(chunks == ShortChannelIdsChunk(firstBlockNum, numberOfBlocks, ids) :: Nil) } } @@ -307,11 +308,11 @@ class ChannelRangeQueriesSpec extends FunSuite { val firstBlockNum = 0 val numberOfBlocks = 1000 - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, 1)) - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, 20)) - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, 50)) - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, 100)) - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, 1000)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, 1)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, 20)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, 50)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, 100)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, 1000)) } test("split short channel ids correctly (comprehensive tests)") { @@ -319,7 +320,7 @@ class ChannelRangeQueriesSpec extends FunSuite { for (firstBlockNum <- 0 to 60) { for (numberOfBlocks <- 1 to 60) { for (chunkSize <- 1 :: 2 :: 20 :: 50 :: 100 :: 1000 :: Nil) { - validate(ids, firstBlockNum, numberOfBlocks, Router.split(ids, firstBlockNum, numberOfBlocks, chunkSize)) + validate(ids, firstBlockNum, numberOfBlocks, split(ids, firstBlockNum, numberOfBlocks, chunkSize)) } } } @@ -327,11 +328,11 @@ class ChannelRangeQueriesSpec extends FunSuite { test("enforce maximum size of short channel lists") { - def makeChunk(startBlock: Int, count : Int) = ShortChannelIdsChunk(startBlock, count, makeShortChannelIds(startBlock, count)) + def makeChunk(startBlock: Int, count: Int) = ShortChannelIdsChunk(startBlock, count, makeShortChannelIds(startBlock, count)) def validate(before: ShortChannelIdsChunk, after: ShortChannelIdsChunk) = { require(before.shortChannelIds.containsSlice(after.shortChannelIds)) - require(after.shortChannelIds.size <= Router.MAXIMUM_CHUNK_SIZE) + require(after.shortChannelIds.size <= Sync.MAXIMUM_CHUNK_SIZE) } def validateChunks(before: List[ShortChannelIdsChunk], after: List[ShortChannelIdsChunk]): Unit = { @@ -341,40 +342,40 @@ class ChannelRangeQueriesSpec extends FunSuite { // empty chunk { val chunks = makeChunk(0, 0) :: Nil - assert(Router.enforceMaximumSize(chunks) == chunks) + assert(enforceMaximumSize(chunks) == chunks) } // chunks are just below the limit { - val chunks = makeChunk(0, Router.MAXIMUM_CHUNK_SIZE) :: makeChunk(Router.MAXIMUM_CHUNK_SIZE, Router.MAXIMUM_CHUNK_SIZE) :: Nil - assert(Router.enforceMaximumSize(chunks) == chunks) + val chunks = makeChunk(0, Sync.MAXIMUM_CHUNK_SIZE) :: makeChunk(Sync.MAXIMUM_CHUNK_SIZE, Sync.MAXIMUM_CHUNK_SIZE) :: Nil + assert(enforceMaximumSize(chunks) == chunks) } - + // fuzzy tests { val chunks = collection.mutable.ArrayBuffer.empty[ShortChannelIdsChunk] // we select parameters to make sure that some chunks will have too many ids - for (i <- 0 until 100) chunks += makeChunk(0, Router.MAXIMUM_CHUNK_SIZE - 500 + Random.nextInt(1000)) - val pruned = Router.enforceMaximumSize(chunks.toList) + for (i <- 0 until 100) chunks += makeChunk(0, Sync.MAXIMUM_CHUNK_SIZE - 500 + 
Random.nextInt(1000)) + val pruned = enforceMaximumSize(chunks.toList) validateChunks(chunks.toList, pruned) } } test("do not encode empty lists as COMPRESSED_ZLIB") { { - val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_ALL)), SortedMap()) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_ALL)), SortedMap()) assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), Some(EncodedChecksums(Nil)))) } { - val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_TIMESTAMPS)), SortedMap()) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_TIMESTAMPS)), SortedMap()) assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), None)) } { - val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_CHECKSUMS)), SortedMap()) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_CHECKSUMS)), SortedMap()) assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, Some(EncodedChecksums(Nil)))) } { - val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, None, SortedMap()) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, None, SortedMap()) assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, None)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala index a8808316d..85d7b6e0a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.RouteCalculationSpec._ +import fr.acinq.eclair.router.Router.ChannelDesc import fr.acinq.eclair.wire.ChannelUpdate import fr.acinq.eclair.{LongToBtcAmount, ShortChannelId} import org.scalatest.FunSuite diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/NetworkStatsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/NetworkStatsSpec.scala index 97eee5818..811716c99 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/NetworkStatsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/NetworkStatsSpec.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.router import 
fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Satoshi +import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, randomBytes32, randomBytes64, randomKey} import org.scalatest.FunSuite diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 06e9da01c..5ea976767 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -17,11 +17,13 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{Block, Btc, ByteVector32, ByteVector64, Satoshi} +import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{RichWeight, WeightRatios} +import fr.acinq.eclair.router.RouteCalculation._ +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, ToMilliSatoshiConversion, randomKey} @@ -51,7 +53,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -74,7 +76,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(maxFeeBase = 1 msat), currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(maxFeeBase = 1 msat), currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -111,7 +113,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val Success(route) = Router.findRoute(graph, a, d, amount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val Success(route) = findRoute(graph, a, d, amount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) val totalCost = Graph.pathWeight(hops2Edges(route), amount, isPartial = false, 0, None).cost @@ -122,7 +124,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val (desc, update) = makeUpdate(5L, e, f, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat, maxHtlc = Some(10005 msat)) val graph1 = graph.addEdge(desc, update) - val Success(route1) = Router.findRoute(graph1, a, d, amount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val Success(route1) = findRoute(graph1, a, d, amount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) 
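// A minimal sketch of the refactored call site (illustrative only, reusing this spec's
// fixtures graph1/a/d and the DEFAULT_* constants): findRoute now lives in the
// RouteCalculation object, brought into scope by the new
// `import fr.acinq.eclair.router.RouteCalculation._` above, and its signature is unchanged:
//   val attempt = RouteCalculation.findRoute(graph1, a, d, amount, numRoutes = 1,
//     routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000)
//   attempt.map(hops2Ids) // Success(1 :: 2 :: 3 :: Nil) here, Failure(RouteNotFound) otherwise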
assert(hops2Ids(route1) === 1 :: 2 :: 3 :: Nil) } @@ -137,7 +139,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { ).toMap val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(2 :: 5 :: Nil)) } @@ -152,11 +154,11 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) val graphWithRemovedEdge = g.removeEdge(ChannelDesc(ShortChannelId(3L), c, d)) - val route2 = Router.findRoute(graphWithRemovedEdge, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route2 = findRoute(graphWithRemovedEdge, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route2.map(hops2Ids) === Failure(RouteNotFound)) } @@ -177,7 +179,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(4 :: 3 :: Nil)) } @@ -199,7 +201,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 2, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 2, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(4 :: Nil)) } @@ -220,7 +222,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) == Success(1 :: 2 :: 3 :: Nil)) } @@ -241,7 +243,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } @@ -262,7 +264,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val graph = makeGraph(updates) - val route = Router.findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + 
val route = findRoute(graph, f, i, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: 6 :: 3 :: Nil)) } @@ -277,7 +279,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -289,7 +291,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } @@ -302,7 +304,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Failure(RouteNotFound)) } @@ -314,8 +316,8 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates).addVertex(a).addVertex(e) - assert(Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) - assert(Router.findRoute(g, b, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) + assert(findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) + assert(findRoute(g, b, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) } test("route not found (amount too high OR too low)") { @@ -337,8 +339,8 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updatesHi) val g1 = makeGraph(updatesLo) - assert(Router.findRoute(g, a, d, highAmount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) - assert(Router.findRoute(g1, a, d, lowAmount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) + assert(findRoute(g, a, d, highAmount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) + assert(findRoute(g1, a, d, lowAmount, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) === Failure(RouteNotFound)) } test("route to self") { @@ -350,7 +352,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, a, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, a, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = 
DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Failure(CannotRouteToSelf)) } @@ -364,7 +366,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, b, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, b, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: Nil)) } @@ -380,10 +382,10 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) - val route2 = Router.findRoute(g, e, a, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route2 = findRoute(g, e, a, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route2.map(hops2Ids) === Failure(RouteNotFound)) } @@ -412,7 +414,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val hops = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).get + val hops = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).get assert(hops === ChannelHop(a, b, uab) :: ChannelHop(b, c, ubc) :: ChannelHop(c, d, ucd) :: ChannelHop(d, e, ude) :: Nil) } @@ -431,7 +433,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val extraHops = extraHop1 :: extraHop2 :: extraHop3 :: extraHop4 :: Nil val amount = 90000 sat // below RoutingHeuristics.CAPACITY_CHANNEL_LOW - val assistedChannels = Router.toAssistedChannels(extraHops, e, amount.toMilliSatoshi) + val assistedChannels = toAssistedChannels(extraHops, e, amount.toMilliSatoshi) assert(assistedChannels(extraHop4.shortChannelId) === AssistedChannel(extraHop4, e, 100050.sat.toMilliSatoshi)) assert(assistedChannels(extraHop3.shortChannelId) === AssistedChannel(extraHop3, d, 100200.sat.toMilliSatoshi)) @@ -449,7 +451,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, ignoredEdges = Set(ChannelDesc(ShortChannelId(3L), c, d)), routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, ignoredEdges = Set(ChannelDesc(ShortChannelId(3L), c, d)), routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Failure(RouteNotFound)) // verify that we left the graph untouched @@ -458,7 +460,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { assert(g.containsVertex(d)) // make sure we can find a route without the blacklist - val route2 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS,
currentBlockHeight = 400000) assert(route2.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -471,14 +473,14 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Failure(RouteNotFound)) // now we add the missing edge to reach the destination val (extraDesc, extraUpdate) = makeUpdate(4L, d, e, 5 msat, 5) val extraGraphEdges = Set(GraphEdge(extraDesc, extraUpdate)) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, extraEdges = extraGraphEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, extraEdges = extraGraphEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) } @@ -492,7 +494,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) assert(route1.get(1).lastUpdate.feeBaseMsat === 10.msat) @@ -500,7 +502,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val extraGraphEdges = Set(GraphEdge(extraDesc, extraUpdate)) - val route2 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, extraEdges = extraGraphEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route2 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, extraEdges = extraGraphEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route2.map(hops2Ids) === Success(1 :: 2 :: 3 :: 4 :: Nil)) assert(route2.get(1).lastUpdate.feeBaseMsat === 5.msat) } @@ -542,7 +544,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { (shortChannelId, pc) } - val ignored = Router.getIgnoredChannelDesc(publicChannels, ignoreNodes = Set(c, j, randomKey.publicKey)) + val ignored = getIgnoredChannelDesc(publicChannels, ignoreNodes = Set(c, j, randomKey.publicKey)) assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(2L), b, c))) assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(2L), c, b))) @@ -560,10 +562,10 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - assert(Router.findRoute(g, nodes(0), nodes(18), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 18)) - assert(Router.findRoute(g, nodes(0), nodes(19), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 19)) - assert(Router.findRoute(g, nodes(0), nodes(20), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 20)) - assert(Router.findRoute(g, nodes(0), nodes(21), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, 
currentBlockHeight = 400000).map(hops2Ids) === Failure(RouteNotFound)) + assert(findRoute(g, nodes(0), nodes(18), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 18)) + assert(findRoute(g, nodes(0), nodes(19), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 19)) + assert(findRoute(g, nodes(0), nodes(20), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Success(0 until 20)) + assert(findRoute(g, nodes(0), nodes(21), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000).map(hops2Ids) === Failure(RouteNotFound)) } test("ignore cheaper route when it has more than 20 hops") { @@ -579,7 +581,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates2) - val route = Router.findRoute(g, nodes(0), nodes(49), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route = findRoute(g, nodes(0), nodes(49), DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(0 :: 1 :: 99 :: 48 :: Nil)) } @@ -595,7 +597,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { makeUpdate(6, f, d, feeBase = 5 msat, 0, minHtlc = 0 msat, maxHtlc = None, CltvExpiryDelta(9)) ).toMap) - val route = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(routeMaxCltv = CltvExpiryDelta(28)), currentBlockHeight = 400000) + val route = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(routeMaxCltv = CltvExpiryDelta(28)), currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(4 :: 5 :: 6 :: Nil)) } @@ -611,7 +613,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { makeUpdate(6, b, f, feeBase = 5 msat, 0, minHtlc = 0 msat, maxHtlc = None, CltvExpiryDelta(9)) ).toMap) - val route = Router.findRoute(g, a, f, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(routeMaxLength = 3), currentBlockHeight = 400000) + val route = findRoute(g, a, f, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(routeMaxLength = 3), currentBlockHeight = 400000) assert(route.map(hops2Ids) === Success(1 :: 6 :: Nil)) } @@ -626,7 +628,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 2 :: 4 :: 5 :: Nil)) } @@ -644,7 +646,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val route1 = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val route1 = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(route1.map(hops2Ids) === Success(1 :: 3 :: 5 :: Nil)) } @@ -772,7 +774,7 @@ class RouteCalculationSpec extends FunSuite with 
ParallelTestExecution { makeUpdate(7L, e, c, feeBase = 9 msat, 0) ).toMap) - (for {_ <- 0 to 10} yield Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 3, routeParams = strictFeeParams, currentBlockHeight = 400000)).map { + (for {_ <- 0 to 10} yield findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 3, routeParams = strictFeeParams, currentBlockHeight = 400000)).map { case Failure(thr) => fail(thr) case Success(someRoute) => @@ -801,10 +803,10 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val g = makeGraph(updates) - val Success(routeFeeOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) + val Success(routeFeeOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000) assert(hops2Nodes(routeFeeOptimized) === (a, b) :: (b, c) :: (c, d) :: Nil) - val Success(routeCltvOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( + val Success(routeCltvOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( cltvDeltaFactor = 1, ageFactor = 0, capacityFactor = 0 @@ -812,7 +814,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { assert(hops2Nodes(routeCltvOptimized) === (a, e) :: (e, f) :: (f, d) :: Nil) - val Success(routeCapacityOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( + val Success(routeCapacityOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 0, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( cltvDeltaFactor = 0, ageFactor = 0, capacityFactor = 1 @@ -833,7 +835,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { makeUpdateShort(ShortChannelId(s"${currentBlockHeight}x0x6"), f, d, feeBase = 1 msat, 0, minHtlc = 0 msat, maxHtlc = None, cltvDelta = CltvExpiryDelta(144)) ).toMap) - val Success(routeScoreOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT / 2, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( + val Success(routeScoreOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT / 2, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( ageFactor = 0.33, cltvDeltaFactor = 0.33, capacityFactor = 0.33 @@ -852,7 +854,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { makeUpdateShort(ShortChannelId(s"0x0x6"), f, d, feeBase = 1 msat, 0, minHtlc = 0 msat, maxHtlc = None, cltvDelta = CltvExpiryDelta(12)) ).toMap) - val Success(routeScoreOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( + val Success(routeScoreOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( ageFactor = 0.33, cltvDeltaFactor = 0.33, capacityFactor = 0.33 @@ -873,7 +875,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { makeUpdateShort(ShortChannelId(s"0x0x6"), f, d, feeBase = 1 msat, 0, minHtlc = 0 msat, maxHtlc = None, cltvDelta = CltvExpiryDelta(144)) ).toMap) - val Success(routeScoreOptimized) = Router.findRoute(g, a, d, DEFAULT_AMOUNT_MSAT / 2, numRoutes = 1, routeParams = 
DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( + val Success(routeScoreOptimized) = findRoute(g, a, d, DEFAULT_AMOUNT_MSAT / 2, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS.copy(ratios = Some(WeightRatios( ageFactor = 0.33, cltvDeltaFactor = 0.33, capacityFactor = 0.33 @@ -918,7 +920,7 @@ class RouteCalculationSpec extends FunSuite with ParallelTestExecution { val targetNode = PublicKey(hex"024655b768ef40951b20053a5c4b951606d4d86085d51238f2c67c7dec29c792ca") val amount = 351000 msat - val Success(route) = Router.findRoute(g, thisNode, targetNode, amount, 1, Set.empty, Set.empty, Set.empty, params, currentBlockHeight = 567634) // simulate mainnet block for heuristic + val Success(route) = findRoute(g, thisNode, targetNode, amount, 1, Set.empty, Set.empty, Set.empty, params, currentBlockHeight = 567634) // simulate mainnet block for heuristic assert(route.size == 2) assert(route.last.nextNodeId == targetNode) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 444f64e1f..473ad757c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -25,12 +25,12 @@ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.channel.BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage -import fr.acinq.eclair.io.PeerConnection.InvalidSignature import fr.acinq.eclair.payment.PaymentRequest.ExtraHop -import fr.acinq.eclair.router.Announcements.makeChannelUpdate +import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router.RouteCalculationSpec.DEFAULT_AMOUNT_MSAT +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts -import fr.acinq.eclair.wire.QueryShortChannelIds +import fr.acinq.eclair.wire.{Color, QueryShortChannelIds} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, randomKey} import scodec.bits._ @@ -49,46 +49,212 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val eventListener = TestProbe() system.eventStream.subscribe(eventListener.ref, classOf[NetworkEvent]) + system.eventStream.subscribe(eventListener.ref, classOf[Rebroadcast]) + val peerConnection = TestProbe() - val channelId_ac = ShortChannelId(420000, 5, 0) - val chan_ac = channelAnnouncement(channelId_ac, priv_a, priv_c, priv_funding_a, priv_funding_c) - val update_ac = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, channelId_ac, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) - // a-x will not be found - val priv_x = randomKey - val chan_ax = channelAnnouncement(ShortChannelId(42001), priv_a, priv_x, priv_funding_a, randomKey) - val update_ax = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_x.publicKey, chan_ax.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) - // a-y will have an invalid script - val priv_y = randomKey - val priv_funding_y = randomKey - val chan_ay = channelAnnouncement(ShortChannelId(42002), priv_a, priv_y, priv_funding_a, priv_funding_y) - val update_ay = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_y.publicKey, chan_ay.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) - // a-z will be spent - val priv_z = randomKey - val priv_funding_z = randomKey - val chan_az = 
channelAnnouncement(ShortChannelId(42003), priv_a, priv_z, priv_funding_a, priv_funding_z) - val update_az = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_z.publicKey, chan_az.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) + { + // valid channel announcement, no stashing + val chan_ac = channelAnnouncement(ShortChannelId(420000, 5, 0), priv_a, priv_c, priv_funding_a, priv_funding_c) + val update_ac = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, chan_ac.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) + val node_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, hex"0200", timestamp = Platform.currentTime.milliseconds.toSeconds + 1) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ac) + peerConnection.expectNoMsg(100 millis) // we don't immediately acknowledge the announcement (back pressure) + watcher.expectMsg(ValidateRequest(chan_ac)) + watcher.send(router, ValidateResult(chan_ac, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_ac)) + peerConnection.expectMsg(GossipDecision.Accepted(chan_ac)) + assert(peerConnection.sender() == router) + watcher.expectMsgType[WatchSpentBasic] + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ac) + peerConnection.expectMsg(TransportHandler.ReadAck(update_ac)) + peerConnection.expectMsg(GossipDecision.Accepted(update_ac)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_c) + peerConnection.expectMsg(TransportHandler.ReadAck(node_c)) + peerConnection.expectMsg(GossipDecision.Accepted(node_c)) + eventListener.expectMsg(ChannelsDiscovered(SingleChannelDiscovered(chan_ac, 1000000 sat, None, None) :: Nil)) + eventListener.expectMsg(ChannelUpdatesReceived(update_ac :: Nil)) + eventListener.expectMsg(NodeUpdated(node_c)) + peerConnection.expectNoMsg(100 millis) + eventListener.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectMsgType[Rebroadcast] + } - router ! PeerRoutingMessage(null, remoteNodeId, chan_ac) - router ! PeerRoutingMessage(null, remoteNodeId, chan_ax) - router ! PeerRoutingMessage(null, remoteNodeId, chan_ay) - router ! PeerRoutingMessage(null, remoteNodeId, chan_az) - // router won't validate channels before it has a recent enough channel update - router ! PeerRoutingMessage(null, remoteNodeId, update_ac) - router ! PeerRoutingMessage(null, remoteNodeId, update_ax) - router ! PeerRoutingMessage(null, remoteNodeId, update_ay) - router ! 
PeerRoutingMessage(null, remoteNodeId, update_az) - watcher.expectMsg(ValidateRequest(chan_ac)) - watcher.expectMsg(ValidateRequest(chan_ax)) - watcher.expectMsg(ValidateRequest(chan_ay)) - watcher.expectMsg(ValidateRequest(chan_az)) - watcher.send(router, ValidateResult(chan_ac, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) - watcher.send(router, ValidateResult(chan_ax, Left(new RuntimeException(s"funding tx not found")))) - watcher.send(router, ValidateResult(chan_ay, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, randomKey.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) - watcher.send(router, ValidateResult(chan_az, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, priv_funding_z.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Spent(spendingTxConfirmed = true)))) - watcher.expectMsgType[WatchSpentBasic] - watcher.expectNoMsg(1 second) + { + // valid channel announcement, stashing while validating channel announcement + val priv_u = randomKey + val priv_funding_u = randomKey + val chan_uc = channelAnnouncement(ShortChannelId(420000, 6, 0), priv_u, priv_c, priv_funding_u, priv_funding_c) + val update_uc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_u, c, chan_uc.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) + val node_u = makeNodeAnnouncement(priv_u, "node-U", Color(-120, -20, 60), Nil, hex"00") + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_uc) + peerConnection.expectNoMsg(200 millis) // we don't immediately acknowledge the announcement (back pressure) + watcher.expectMsg(ValidateRequest(chan_uc)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_uc) + peerConnection.expectMsg(TransportHandler.ReadAck(update_uc)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_u) + peerConnection.expectMsg(TransportHandler.ReadAck(node_u)) + watcher.send(router, ValidateResult(chan_uc, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(2000000 sat, write(pay2wsh(Scripts.multiSig2of2(priv_funding_u.publicKey, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_uc)) + peerConnection.expectMsg(GossipDecision.Accepted(chan_uc)) + assert(peerConnection.sender() == router) + watcher.expectMsgType[WatchSpentBasic] + peerConnection.expectMsg(GossipDecision.Accepted(update_uc)) + peerConnection.expectMsg(GossipDecision.Accepted(node_u)) + eventListener.expectMsg(ChannelsDiscovered(SingleChannelDiscovered(chan_uc, 2000000 sat, None, None) :: Nil)) + eventListener.expectMsg(ChannelUpdatesReceived(update_uc :: Nil)) + eventListener.expectMsg(NodesDiscovered(node_u :: Nil)) + peerConnection.expectNoMsg(100 millis) + eventListener.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectMsgType[Rebroadcast] + } + + { + // duplicates + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_a) + peerConnection.expectMsg(TransportHandler.ReadAck(node_a)) + peerConnection.expectMsg(GossipDecision.Duplicate(node_a)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ab) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_ab)) + peerConnection.expectMsg(GossipDecision.Duplicate(chan_ab)) + router ! 
PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab) + peerConnection.expectMsg(TransportHandler.ReadAck(update_ab)) + peerConnection.expectMsg(GossipDecision.Duplicate(update_ab)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // invalid signatures + val invalid_node_a = node_a.copy(timestamp = node_a.timestamp + 10) + val invalid_chan_a = channelAnnouncement(ShortChannelId(420000, 5, 1), priv_a, priv_c, priv_funding_a, priv_funding_c).copy(nodeId1 = randomKey.publicKey) + val invalid_update_ab = update_ab.copy(cltvExpiryDelta = CltvExpiryDelta(21), timestamp = update_ab.timestamp + 1) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_node_a) + peerConnection.expectMsg(TransportHandler.ReadAck(invalid_node_a)) + peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_node_a)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_chan_a) + peerConnection.expectMsg(TransportHandler.ReadAck(invalid_chan_a)) + peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_chan_a)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_update_ab) + peerConnection.expectMsg(TransportHandler.ReadAck(invalid_update_ab)) + peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_update_ab)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // pruned channel + val priv_v = randomKey + val priv_funding_v = randomKey + val chan_vc = channelAnnouncement(ShortChannelId(420000, 7, 0), priv_v, priv_c, priv_funding_v, priv_funding_c) + nodeParams.db.network.addToPruned(chan_vc.shortChannelId :: Nil) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_vc) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_vc)) + peerConnection.expectMsg(GossipDecision.ChannelPruned(chan_vc)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // stale channel update + val update_ab = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_b.publicKey, chan_ab.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat, timestamp = (Platform.currentTime.milliseconds - 15.days).toSeconds) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab) + peerConnection.expectMsg(TransportHandler.ReadAck(update_ab)) + peerConnection.expectMsg(GossipDecision.Stale(update_ab)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // unknown channel + val priv_y = randomKey + val update_ay = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_y.publicKey, ShortChannelId(4646464), CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) + val node_y = makeNodeAnnouncement(priv_y, "node-Y", Color(123, 100, -40), Nil, hex"0200") + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ay) + peerConnection.expectMsg(TransportHandler.ReadAck(update_ay)) + peerConnection.expectMsg(GossipDecision.NoRelatedChannel(update_ay)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_y) + peerConnection.expectMsg(TransportHandler.ReadAck(node_y)) + peerConnection.expectMsg(GossipDecision.NoKnownChannel(node_y)) + peerConnection.expectNoMsg(100 millis) + router ! 
Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // invalid announcement + reject stashed + val priv_y = randomKey + val priv_funding_y = randomKey // a-y will have an invalid script + val chan_ay = channelAnnouncement(ShortChannelId(42002), priv_a, priv_y, priv_funding_a, priv_funding_y) + val update_ay = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_y.publicKey, chan_ay.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat) + val node_y = makeNodeAnnouncement(priv_y, "node-Y", Color(123, 100, -40), Nil, hex"0200") + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ay) + watcher.expectMsg(ValidateRequest(chan_ay)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ay) + peerConnection.expectMsg(TransportHandler.ReadAck(update_ay)) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_y) + peerConnection.expectMsg(TransportHandler.ReadAck(node_y)) + watcher.send(router, ValidateResult(chan_ay, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, randomKey.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_ay)) + peerConnection.expectMsg(GossipDecision.InvalidAnnouncement(chan_ay)) + peerConnection.expectMsg(GossipDecision.NoRelatedChannel(update_ay)) + peerConnection.expectMsg(GossipDecision.NoKnownChannel(node_y)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // validation failure + val priv_x = randomKey + val chan_ax = channelAnnouncement(ShortChannelId(42001), priv_a, priv_x, priv_funding_a, randomKey) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ax) + watcher.expectMsg(ValidateRequest(chan_ax)) + watcher.send(router, ValidateResult(chan_ax, Left(new RuntimeException("funding tx not found")))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_ax)) + peerConnection.expectMsg(GossipDecision.ValidationFailure(chan_ax)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // funding tx spent (funding tx not confirmed) + val priv_z = randomKey + val priv_funding_z = randomKey + val chan_az = channelAnnouncement(ShortChannelId(42003), priv_a, priv_z, priv_funding_a, priv_funding_z) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_az) + watcher.expectMsg(ValidateRequest(chan_az)) + watcher.send(router, ValidateResult(chan_az, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, priv_funding_z.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Spent(spendingTxConfirmed = false)))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_az)) + peerConnection.expectMsg(GossipDecision.ChannelClosing(chan_az)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + { + // funding tx spent (funding tx confirmed) + val priv_z = randomKey + val priv_funding_z = randomKey + val chan_az = channelAnnouncement(ShortChannelId(42003), priv_a, priv_z, priv_funding_a, priv_funding_z) + router ! 
PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_az) + watcher.expectMsg(ValidateRequest(chan_az)) + watcher.send(router, ValidateResult(chan_az, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, priv_funding_z.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Spent(spendingTxConfirmed = true)))) + peerConnection.expectMsg(TransportHandler.ReadAck(chan_az)) + peerConnection.expectMsg(GossipDecision.ChannelClosed(chan_az)) + peerConnection.expectNoMsg(100 millis) + router ! Router.TickBroadcast + eventListener.expectNoMsg(100 millis) + } + + watcher.expectNoMsg(100 millis) - eventListener.expectMsg(ChannelsDiscovered(SingleChannelDiscovered(chan_ac, 1000000 sat, None, None) :: Nil)) } test("properly announce lost channels and nodes") { fixture => @@ -122,27 +288,27 @@ class RouterSpec extends BaseRouterSpec { val channelId_ac = ShortChannelId(420000, 5, 0) val chan_ac = channelAnnouncement(channelId_ac, priv_a, priv_c, priv_funding_a, priv_funding_c) val buggy_chan_ac = chan_ac.copy(nodeSignature1 = chan_ac.nodeSignature2) - sender.send(router, PeerRoutingMessage(null, remoteNodeId, buggy_chan_ac)) + sender.send(router, PeerRoutingMessage(sender.ref, remoteNodeId, buggy_chan_ac)) sender.expectMsg(TransportHandler.ReadAck(buggy_chan_ac)) - sender.expectMsg(InvalidSignature(buggy_chan_ac)) + sender.expectMsg(GossipDecision.InvalidSignature(buggy_chan_ac)) } test("handle bad signature for NodeAnnouncement") { fixture => import fixture._ - val sender = TestProbe() - val buggy_ann_a = ann_a.copy(signature = ann_b.signature, timestamp = ann_a.timestamp + 1) - sender.send(router, PeerRoutingMessage(null, remoteNodeId, buggy_ann_a)) - sender.expectMsg(TransportHandler.ReadAck(buggy_ann_a)) - sender.expectMsg(InvalidSignature(buggy_ann_a)) + val peerConnection = TestProbe() + val buggy_ann_a = node_a.copy(signature = node_b.signature, timestamp = node_a.timestamp + 1) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, buggy_ann_a)) + peerConnection.expectMsg(TransportHandler.ReadAck(buggy_ann_a)) + peerConnection.expectMsg(GossipDecision.InvalidSignature(buggy_ann_a)) } test("handle bad signature for ChannelUpdate") { fixture => import fixture._ - val sender = TestProbe() - val buggy_channelUpdate_ab = channelUpdate_ab.copy(signature = ann_b.signature, timestamp = channelUpdate_ab.timestamp + 1) - sender.send(router, PeerRoutingMessage(null, remoteNodeId, buggy_channelUpdate_ab)) - sender.expectMsg(TransportHandler.ReadAck(buggy_channelUpdate_ab)) - sender.expectMsg(InvalidSignature(buggy_channelUpdate_ab)) + val peerConnection = TestProbe() + val buggy_channelUpdate_ab = update_ab.copy(signature = node_b.signature, timestamp = update_ab.timestamp + 1) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, buggy_channelUpdate_ab)) + peerConnection.expectMsg(TransportHandler.ReadAck(buggy_channelUpdate_ab)) + peerConnection.expectMsg(GossipDecision.InvalidSignature(buggy_channelUpdate_ab)) } test("route not found (unreachable target)") { fixture => @@ -196,14 +362,15 @@ class RouterSpec extends BaseRouterSpec { test("route not found (channel disabled)") { fixture => import fixture._ val sender = TestProbe() + val peerConnection = TestProbe() sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, routeParams = relaxedRouteParams)) val res = sender.expectMsgType[RouteResponse] assert(res.hops.map(_.nodeId).toList === a :: b :: c :: Nil) 
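// the route found above is a -> b -> c -> d; the channel_update built below re-announces
// channel c-d with enable = false, so the router still read-acks the (valid) update but
// excludes the disabled edge from path-finding, and the same RouteRequest then fails
// with RouteNotFound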
assert(res.hops.last.nextNodeId === d) val channelUpdate_cd1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(3), 0 msat, 153000 msat, 4, 500000000L msat, enable = false) - sender.send(router, PeerRoutingMessage(null, remoteNodeId, channelUpdate_cd1)) - sender.expectMsg(TransportHandler.ReadAck(channelUpdate_cd1)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_cd1)) + peerConnection.expectMsg(TransportHandler.ReadAck(channelUpdate_cd1)) sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, routeParams = relaxedRouteParams)) sender.expectMsg(Failure(RouteNotFound)) } @@ -271,7 +438,7 @@ class RouterSpec extends BaseRouterSpec { // the route hasn't changed (nodes are the same) assert(response.hops.map(_.nodeId).toList == preComputedRoute.dropRight(1).toList) assert(response.hops.last.nextNodeId == preComputedRoute.last) - assert(response.hops.map(_.lastUpdate).toList == List(channelUpdate_ab, channelUpdate_bc, channelUpdate_cd)) + assert(response.hops.map(_.lastUpdate).toList == List(update_ab, update_bc, update_cd)) } test("ask for channels that we marked as stale for which we receive a new update") { fixture => @@ -283,9 +450,9 @@ class RouterSpec extends BaseRouterSpec { val update = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, channelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 5 msat, timestamp = timestamp) val probe = TestProbe() probe.ignoreMsg { case _: TransportHandler.ReadAck => true } - probe.send(router, PeerRoutingMessage(null, remoteNodeId, announcement)) + probe.send(router, PeerRoutingMessage(probe.ref, remoteNodeId, announcement)) watcher.expectMsgType[ValidateRequest] - probe.send(router, PeerRoutingMessage(null, remoteNodeId, update)) + probe.send(router, PeerRoutingMessage(probe.ref, remoteNodeId, update)) watcher.send(router, ValidateResult(announcement, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) probe.send(router, TickPruneStaleChannels) @@ -296,9 +463,11 @@ class RouterSpec extends BaseRouterSpec { val update1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, channelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 500000000L msat, timestamp = Platform.currentTime.millisecond.toSeconds) // we want to make sure that transport receives the query - val transport = TestProbe() - probe.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, update1)) - val query = transport.expectMsgType[QueryShortChannelIds] + val peerConnection = TestProbe() + peerConnection.ignoreMsg { case _: GossipDecision.Duplicate => true } + probe.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update1)) + peerConnection.expectMsg(TransportHandler.ReadAck(update1)) + val query = peerConnection.expectMsgType[QueryShortChannelIds] assert(query.shortChannelIds.array == List(channelId)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index 35c2b77be..94d167b99 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -27,6 +27,8 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, 
makeNodeAnnouncement} import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement +import fr.acinq.eclair.router.Router.{Data, GossipDecision, PublicChannel, SendChannelQuery, State} +import fr.acinq.eclair.router.Sync._ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import org.scalatest.{FunSuiteLike, ParallelTestExecution} @@ -77,10 +79,12 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike wit pipe.ignoreMsg { case _: TransportHandler.ReadAck => true case _: GossipTimestampFilter => true + case _: GossipDecision.Duplicate => true + case _: GossipDecision.Accepted => true } val srcId = src.underlyingActor.nodeParams.nodeId val tgtId = tgt.underlyingActor.nodeParams.nodeId - sender.send(src, SendChannelQuery(tgtId, pipe.ref, extendedQueryFlags_opt)) + sender.send(src, SendChannelQuery(src.underlyingActor.nodeParams.chainHash, tgtId, pipe.ref, extendedQueryFlags_opt)) // src sends a query_channel_range to bob val qcr = pipe.expectMsgType[QueryChannelRange] pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, qcr)) @@ -253,7 +257,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike wit val remoteNodeId = TestConstants.Bob.nodeParams.nodeId // ask router to send a channel range query - sender.send(router, SendChannelQuery(remoteNodeId, sender.ref, None)) + sender.send(router, SendChannelQuery(params.chainHash, remoteNodeId, sender.ref, None)) val QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks, _) = sender.expectMsgType[QueryChannelRange] sender.expectMsgType[GossipTimestampFilter] @@ -269,7 +273,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike wit assert(sync.total == 1) // simulate a re-connection - sender.send(router, SendChannelQuery(remoteNodeId, sender.ref, None)) + sender.send(router, SendChannelQuery(params.chainHash, remoteNodeId, sender.ref, None)) sender.expectMsgType[QueryChannelRange] sender.expectMsgType[GossipTimestampFilter] assert(router.stateData.sync.get(remoteNodeId).isEmpty) @@ -282,17 +286,17 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike wit val nodeidA = randomKey.publicKey val nodeidB = randomKey.publicKey - val (sync1, _) = Router.addToSync(Map.empty, nodeidA, List(req, req, req, req)) - assert(Router.syncProgress(sync1) == SyncProgress(0.25D)) + val (sync1, _) = addToSync(Map.empty, nodeidA, List(req, req, req, req)) + assert(syncProgress(sync1) == SyncProgress(0.25D)) - val (sync2, _) = Router.addToSync(sync1, nodeidB, List(req, req, req, req, req, req, req, req, req, req, req, req)) - assert(Router.syncProgress(sync2) == SyncProgress(0.125D)) + val (sync2, _) = addToSync(sync1, nodeidB, List(req, req, req, req, req, req, req, req, req, req, req, req)) + assert(syncProgress(sync2) == SyncProgress(0.125D)) // let's assume we made some progress val sync3 = sync2 .updated(nodeidA, sync2(nodeidA).copy(pending = List(req))) .updated(nodeidB, sync2(nodeidB).copy(pending = List(req))) - assert(Router.syncProgress(sync3) == SyncProgress(0.875D)) + assert(syncProgress(sync3) == SyncProgress(0.875D)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala index 05b01530d..8f536ed32 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala @@ -17,7 +17,7 @@ package 
fr.acinq.eclair.wire import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} -import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Sync import fr.acinq.eclair.wire.LightningMessageCodecs._ import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, UInt64} @@ -163,7 +163,7 @@ class ExtendedQueriesCodecsSpec extends FunSuite { val check = ByteVector.fromValidHex("010276df7e70c63cc2b63ef1c062b99c6d934a80ef2fd4dae9e1d86d277f47674af3255a97fa52ade7f129263f591ed784996eba6383135896cc117a438c8029328206226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00006700000100005d50f933000000900000000000000000000003e80000000a") assert(LightningMessageCodecs.channelUpdateCodec.encode(update).require.bytes == check.drop(2)) - val checksum = Router.getChecksum(update) + val checksum = Sync.getChecksum(update) assert(checksum == 0x1112fa30L) } @@ -184,7 +184,7 @@ class ExtendedQueriesCodecsSpec extends FunSuite { val check = ByteVector.fromValidHex("010206737e9e18d3e4d0ab4066ccaecdcc10e648c5f1c5413f1610747e0d463fa7fa39c1b02ea2fd694275ecfefe4fe9631f24afd182ab75b805e16cd550941f858c06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00006d00000100005d50f935010000300000000000000000000000640000000b00000000000186a0") assert(LightningMessageCodecs.channelUpdateCodec.encode(update).require.bytes == check.drop(2)) - val checksum = Router.getChecksum(update) + val checksum = Sync.getChecksum(update) assert(checksum == 0xf32ce968L) } } diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala index ae8ccf692..6df8ff675 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala @@ -27,7 +27,8 @@ import fr.acinq.eclair.blockchain.electrum.ElectrumClient.{ElectrumDisconnected, import fr.acinq.eclair.channel._ import fr.acinq.eclair.gui.controllers._ import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.{NORMAL => _, _} +import fr.acinq.eclair.router.{Announcements, ChannelLost, ChannelUpdatesReceived, ChannelsDiscovered, NodeLost, NodeUpdated, NodesDiscovered, SingleChannelDiscovered} +import fr.acinq.eclair.router.Router.{NORMAL => _, _} import javafx.application.Platform import javafx.fxml.FXMLLoader import javafx.scene.layout.VBox diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala index 441c00285..c05abc8e6 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala @@ -27,13 +27,12 @@ import fr.acinq.eclair.channel.{ChannelCommandResponse, ChannelVersion, State} import fr.acinq.eclair.crypto.ShaChain import fr.acinq.eclair.db.{IncomingPaymentStatus, OutgoingPaymentStatus} import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.RouteResponse +import fr.acinq.eclair.router.Router.RouteResponse import fr.acinq.eclair.transactions.DirectedHtlc import fr.acinq.eclair.transactions.Transactions.{InputInfo, TransactionWithInputInfo} import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshi, ShortChannelId, UInt64} import org.json4s.JsonAST._ -import org.json4s.jackson.Serialization import org.json4s.{CustomKeySerializer, CustomSerializer, DefaultFormats, Extraction, TypeHints, jackson} 
import scodec.bits.ByteVector