diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index fd58844a1..9bf4cb33f 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -318,6 +318,13 @@ eclair { max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us } + // When relaying payments or messages to mobile peers who are disconnected, we may try to wake them up using a mobile + // notification system, or we attempt connecting to the last known address. + peer-wake-up { + enabled = false + timeout = 60 seconds + } + auto-reconnect = true initial-random-reconnect-delay = 5 seconds // we add a random delay before the first reconnection attempt, capped by this value max-reconnect-interval = 1 hour // max interval between two reconnection attempts, after the exponential backoff period diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 1c66625b5..25b355604 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.crypto.Noise.KeyPair import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager, OnChainKeyManager} import fr.acinq.eclair.db._ import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy} -import fr.acinq.eclair.io.PeerConnection +import fr.acinq.eclair.io.{PeerConnection, PeerReadyNotifier} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Announcements.AddressException @@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, blockchainWatchdogSources: Seq[String], onionMessageConfig: OnionMessageConfig, purgeInvoicesInterval: Option[FiniteDuration], - revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) { + revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, + peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -611,7 +612,11 @@ object NodeParams extends Logging { revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config( batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"), interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) - ) + ), + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig( + enabled = config.getBoolean("peer-wake-up.enabled"), + timeout = FiniteDuration(config.getDuration("peer-wake-up.timeout").getSeconds, TimeUnit.SECONDS) + ), ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index b6ca12c5e..2b28bb041 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -360,7 +360,8 @@ class Setup(val datadir: File, offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume)) triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer") - relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, triggerer, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) + peerReadyManager = system.spawn(Behaviors.supervise(PeerReadyManager()).onFailure(typed.SupervisorStrategy.restart), name = "peer-ready-manager") + relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) _ = relayer ! PostRestartHtlcCleaner.Init(channels) // Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system, // we want to make sure the handler for post-restart broken HTLCs has finished initializing. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index a25a16691..368d883c1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -44,29 +44,18 @@ object MessageRelay { policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command - case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command - case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command + private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command - sealed trait Status { - val messageId: ByteVector32 - } + sealed trait Status { val messageId: ByteVector32 } case class Sent(messageId: ByteVector32) extends Status sealed trait Failure extends Status - case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { - override def toString: String = s"Relay prevented by policy $policy" - } - case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { - override def toString: String = s"Can't connect to peer: ${failure.toString}" - } - case class Disconnected(messageId: ByteVector32) extends Failure { - override def toString: String = "Peer is not connected" - } - case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { - override def toString: String = s"Unknown channel: $channelId" - } - case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { - override def toString: String = s"Message dropped: $reason" - } + case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" } + case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" } + case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } + case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" } + case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy @@ -106,7 +95,7 @@ private class MessageRelay(nodeParams: NodeParams, def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { nextNode match { case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => - withNextNodeId(msg, nodeParams.nodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId)) case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId) @@ -114,7 +103,7 @@ private class MessageRelay(nodeParams: NodeParams, router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) waitForNextNodeId(msg, scid) case Right(encodedNodeId: EncodedNodeId.WithPublicKey) => - withNextNodeId(msg, encodedNodeId.publicKey) + withNextNodeId(msg, encodedNodeId) } } @@ -127,34 +116,39 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped case WrappedOptionalNodeId(Some(nextNodeId)) => log.info("found outgoing node {} for channel {}", nextNodeId, channelId) - withNextNodeId(msg, nextNodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId)) } } - private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { - if (nextNodeId == nodeParams.nodeId) { - OnionMessages.process(nodeParams.privateKey, msg) match { - case OnionMessages.DropMessage(reason) => - Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment() - replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) - Behaviors.stopped - case OnionMessages.SendMessage(nextNode, nextMessage) => - // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. - queryNextNodeId(nextMessage, nextNode) - case received: OnionMessages.ReceiveMessage => - context.system.eventStream ! EventStream.Publish(received) - replyTo_opt.foreach(_ ! Sent(messageId)) - Behaviors.stopped - } - } else { - policy match { - case RelayChannelsOnly => - switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) - waitForPreviousPeerForPolicyCheck(msg, nextNodeId) - case RelayAll => - switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) - waitForConnection(msg, nextNodeId) - } + private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = { + nextNodeId match { + case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId => + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(reason) => + Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment() + replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) + Behaviors.stopped + case OnionMessages.SendMessage(nextNode, nextMessage) => + // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. + queryNextNodeId(nextMessage, nextNode) + case received: OnionMessages.ReceiveMessage => + context.system.eventStream ! EventStream.Publish(received) + replyTo_opt.foreach(_ ! Sent(messageId)) + Behaviors.stopped + } + case EncodedNodeId.WithPublicKey.Plain(nodeId) => + policy match { + case RelayChannelsOnly => + switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) + waitForPreviousPeerForPolicyCheck(msg, nodeId) + case RelayAll => + switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) + waitForConnection(msg, nodeId) + } + case EncodedNodeId.WithPublicKey.Wallet(nodeId) => + val notifier = context.spawnAnonymous(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + waitForWalletNodeUp(msg, nodeId) } } @@ -197,4 +191,18 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped } } + + private def waitForWalletNodeUp(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) => + log.info("successfully woke up {}: relaying onion message", nextNodeId) + r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) + Behaviors.stopped + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, Tags.Reasons.ConnectionFailure).increment() + log.info("could not wake up {}: onion message cannot be relayed", nextNodeId) + replyTo_opt.foreach(_ ! Disconnected(messageId)) + Behaviors.stopped + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 81d6c71b5..f4a8b8d67 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -17,36 +17,104 @@ package fr.acinq.eclair.io import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.receptionist.{Receptionist, ServiceKey} import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler} -import akka.actor.typed.{ActorRef, Behavior} +import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.{BlockHeight, Logs, channel} import scala.concurrent.duration.{DurationInt, FiniteDuration} +/** + * This actor tracks the set of pending [[PeerReadyNotifier]]. + * It can be used to ensure that notifications are only sent once, even if there are multiple parallel operations + * waiting for that peer to come online. + */ +object PeerReadyManager { + + val PeerReadyManagerServiceKey: ServiceKey[Register] = ServiceKey[Register]("peer-ready-manager") + + // @formatter:off + sealed trait Command + case class Register(replyTo: ActorRef[Registered], remoteNodeId: PublicKey) extends Command + case class List(replyTo: ActorRef[Set[PublicKey]]) extends Command + private case class Completed(remoteNodeId: PublicKey, actor: ActorRef[Registered]) extends Command + // @formatter:on + + /** + * @param otherAttempts number of already pending [[PeerReadyNotifier]] instances for that peer. + */ + case class Registered(remoteNodeId: PublicKey, otherAttempts: Int) + + def apply(): Behavior[Command] = { + Behaviors.setup { context => + context.system.receptionist ! Receptionist.Register(PeerReadyManagerServiceKey, context.self) + watch(Map.empty, context) + } + } + + private def watch(pending: Map[PublicKey, Set[ActorRef[Registered]]], context: ActorContext[Command]): Behavior[Command] = { + Behaviors.receiveMessage { + case Register(replyTo, remoteNodeId) => + context.watchWith(replyTo, Completed(remoteNodeId, replyTo)) + pending.get(remoteNodeId) match { + case Some(attempts) => + replyTo ! Registered(remoteNodeId, otherAttempts = attempts.size) + val attempts1 = attempts + replyTo + watch(pending + (remoteNodeId -> attempts1), context) + case None => + replyTo ! Registered(remoteNodeId, otherAttempts = 0) + watch(pending + (remoteNodeId -> Set(replyTo)), context) + } + case Completed(remoteNodeId, actor) => + pending.get(remoteNodeId) match { + case Some(attempts) => + val attempts1 = attempts - actor + if (attempts1.isEmpty) { + watch(pending - remoteNodeId, context) + } else { + watch(pending + (remoteNodeId -> attempts1), context) + } + case None => + Behaviors.same + } + case List(replyTo) => + replyTo ! pending.keySet + Behaviors.same + } + } + +} + /** * This actor waits for a given peer to be online and ready to process payments. - * It automatically stops after the timeout provided. + * It automatically stops after the timeout provided if the peer doesn't connect. + * There may be multiple instances of this actor running in parallel for the same peer, which is fine because they + * may use different timeouts. + * Having separate actor instances for each caller guarantees that the caller will always receive a response. */ object PeerReadyNotifier { + case class WakeUpConfig(enabled: Boolean, timeout: FiniteDuration) + // @formatter:off sealed trait Command case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command + private final case class WrappedRegistered(registered: PeerReadyManager.Registered) extends Command private case object PeerNotConnected extends Command - private case class SomePeerConnected(nodeId: PublicKey) extends Command - private case class SomePeerDisconnected(nodeId: PublicKey) extends Command + private case object PeerConnected extends Command + private case object PeerDisconnected extends Command private case class WrappedPeerInfo(peer: ActorRef[Peer.GetPeerChannels], channelCount: Int) extends Command private case class NewBlockNotTimedOut(currentBlockHeight: BlockHeight) extends Command private case object CheckChannelsReady extends Command private case class WrappedPeerChannels(wrapped: Peer.PeerChannels) extends Command private case object Timeout extends Command + private case object ToBeIgnored extends Command - sealed trait Result + sealed trait Result { def remoteNodeId: PublicKey } case class PeerReady(remoteNodeId: PublicKey, peer: akka.actor.ActorRef, channelInfos: Seq[Peer.ChannelInfo]) extends Result { val channelsCount: Int = channelInfos.size } case class PeerUnavailable(remoteNodeId: PublicKey) extends Result @@ -66,102 +134,40 @@ object PeerReadyNotifier { case cbc => NewBlockNotTimedOut(cbc.blockHeight) }) } - // In case the peer is not currently connected, we will wait for them to connect instead of regularly - // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => SomePeerConnected(e.nodeId))) - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => SomePeerDisconnected(e.nodeId))) - findSwitchboard(replyTo, remoteNodeId, context, timers) + // The actor should never throw, but for extra safety we wrap it with a supervisor. + Behaviors.supervise { + start(replyTo, remoteNodeId, context, timers) + }.onFailure(SupervisorStrategy.stop) } } } } } - private def findSwitchboard(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { - context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) + private def start(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { + // We start by registering ourself to see if other instances are running. + context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) Behaviors.receiveMessagePartial { - case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => + case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) => listings.headOption match { - case Some(switchboard) => - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) + case Some(peerReadyManager) => + peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId) + Behaviors.same case None => - context.log.error("no switchboard found") + context.log.error("no peer-ready-manager found") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped - } - } - } - - private def waitForPeerConnected(replyTo: ActorRef[Result], remoteNodeId: PublicKey, switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { - val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] { - // We receive this when we don't have any channel to the given peer and are not currently connected to them. - // In that case we still want to wait for a connection, because we may want to open a channel to them. - case _: Peer.PeerNotFound => PeerNotConnected - case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected - case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size) - } - // We check whether the peer is already connected. - switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) - Behaviors.receiveMessagePartial { - case PeerNotConnected => - context.log.debug("peer is not connected yet") - Behaviors.same - case SomePeerConnected(nodeId) => - if (nodeId == remoteNodeId) { - switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) } - Behaviors.same - case SomePeerDisconnected(_) => - Behaviors.same - case WrappedPeerInfo(peer, channelCount) => - if (channelCount == 0) { - context.log.info("peer is ready with no channels") - replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty) - Behaviors.stopped - } else { - context.log.debug("peer is connected with {} channels", channelCount) - waitForChannelsReady(replyTo, remoteNodeId, peer, switchboard, context, timers) - } - case NewBlockNotTimedOut(currentBlockHeight) => - context.log.debug("waiting for peer to connect at block {}", currentBlockHeight) - Behaviors.same + case WrappedRegistered(registered) => + context.log.info("checking if peer is ready ({} other attempts)", registered.otherAttempts) + val isFirstAttempt = registered.otherAttempts == 0 + // In case the peer is not currently connected, we will wait for them to connect instead of regularly + // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) + new PeerReadyNotifier(replyTo, remoteNodeId, isFirstAttempt, context, timers).findSwitchboard() case Timeout => - context.log.info("timed out waiting for peer to connect") - replyTo ! PeerUnavailable(remoteNodeId) - Behaviors.stopped - } - } - - private def waitForChannelsReady(replyTo: ActorRef[Result], remoteNodeId: PublicKey, peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { - timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second) - Behaviors.receiveMessagePartial { - case CheckChannelsReady => - context.log.debug("checking channel states") - peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels)) - Behaviors.same - case WrappedPeerChannels(peerChannels) => - if (peerChannels.channels.map(_.state).forall(isChannelReady)) { - replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels) - Behaviors.stopped - } else { - context.log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state))) - Behaviors.same - } - case NewBlockNotTimedOut(currentBlockHeight) => - context.log.debug("waiting for channels to be ready at block {}", currentBlockHeight) - Behaviors.same - case SomePeerConnected(_) => - Behaviors.same - case SomePeerDisconnected(nodeId) => - if (nodeId == remoteNodeId) { - context.log.debug("peer disconnected, waiting for them to reconnect") - timers.cancel(ChannelsReadyTimerKey) - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) - } else { - Behaviors.same - } - case Timeout => - context.log.info("timed out waiting for channels to be ready") + context.log.info("timed out finding peer-ready-manager actor") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped } @@ -199,3 +205,109 @@ object PeerReadyNotifier { } } + +private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], + remoteNodeId: PublicKey, + isFirstAttempt: Boolean, + context: ActorContext[PeerReadyNotifier.Command], + timers: TimerScheduler[PeerReadyNotifier.Command]) { + + import PeerReadyNotifier._ + + private val log = context.log + + private def findSwitchboard(): Behavior[Command] = { + context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) + Behaviors.receiveMessagePartial { + case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => + listings.headOption match { + case Some(switchboard) => + waitForPeerConnected(switchboard) + case None => + log.error("no switchboard found") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + } + case Timeout => + log.info("timed out finding switchboard actor") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + case ToBeIgnored => + Behaviors.same + } + } + + private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { + val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] { + // We receive this when we don't have any channel to the given peer and are not currently connected to them. + // In that case we still want to wait for a connection, because we may want to open a channel to them. + case _: Peer.PeerNotFound => PeerNotConnected + case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected + case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size) + } + // We check whether the peer is already connected. + switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) + Behaviors.receiveMessagePartial { + case PeerNotConnected => + log.debug("peer is not connected yet") + Behaviors.same + case PeerConnected => + switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) + Behaviors.same + case PeerDisconnected => + Behaviors.same + case WrappedPeerInfo(peer, channelCount) => + if (channelCount == 0) { + log.info("peer is ready with no channels") + replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty) + Behaviors.stopped + } else { + log.debug("peer is connected with {} channels", channelCount) + waitForChannelsReady(peer, switchboard) + } + case NewBlockNotTimedOut(currentBlockHeight) => + log.debug("waiting for peer to connect at block {}", currentBlockHeight) + Behaviors.same + case Timeout => + log.info("timed out waiting for peer to connect") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + case ToBeIgnored => + Behaviors.same + } + } + + private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { + timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second) + Behaviors.receiveMessagePartial { + case CheckChannelsReady => + log.debug("checking channel states") + peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels)) + Behaviors.same + case WrappedPeerChannels(peerChannels) => + if (peerChannels.channels.map(_.state).forall(isChannelReady)) { + replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels) + Behaviors.stopped + } else { + log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state))) + Behaviors.same + } + case NewBlockNotTimedOut(currentBlockHeight) => + log.debug("waiting for channels to be ready at block {}", currentBlockHeight) + Behaviors.same + case PeerConnected => + Behaviors.same + case PeerDisconnected => + log.debug("peer disconnected, waiting for them to reconnect") + timers.cancel(ChannelsReadyTimerKey) + waitForPeerConnected(switchboard) + case Timeout => + log.info("timed out waiting for channels to be ready") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + case ToBeIgnored => + Behaviors.same + } + } + +} \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index d9e1c424a..085fa9bc2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -119,6 +119,7 @@ object Monitoring { val Failure = "failure" object FailureType { + val WakeUp = "WakeUp" val Remote = "Remote" val Malformed = "MalformedHtlc" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 3f17db19c..479315307 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -126,7 +126,7 @@ object IncomingPaymentPacket { decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap { - case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf => + case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) => decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features) case relayPacket => Right(relayPacket) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala index 5adccac80..965e4b520 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.payment.relay import akka.actor.typed.ActorRef.ActorRefOps import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} +import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory @@ -99,7 +99,7 @@ private class AsyncPaymentTriggerer(context: ActorContext[Command]) { case Watch(replyTo, remoteNodeId, paymentHash, timeout) => peers.get(remoteNodeId) match { case None => - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(remoteNodeId, timeout_opt = None)).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(PeerReadyNotifier(remoteNodeId, timeout_opt = None)) context.watchWith(notifier, NotifierStopped(remoteNodeId)) notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)) val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 25b6cbe2c..1181f75bc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -16,16 +16,17 @@ package fr.acinq.eclair.payment.relay +import akka.actor.ActorRef import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} @@ -44,6 +45,7 @@ object ChannelRelay { // @formatter:off sealed trait Command private case object DoRelay extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command // @formatter:on @@ -57,7 +59,7 @@ object ChannelRelay { def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], - originNode:PublicKey, + originNode: PublicKey, relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] = Behaviors.setup { context => @@ -67,9 +69,8 @@ object ChannelRelay { paymentHash_opt = Some(r.add.paymentHash), nodeAlias_opt = Some(nodeParams.alias))) { val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode) - context.self ! DoRelay val confidence = (r.add.endorsement + 0.5) / 8 - new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).relay(Seq.empty) + new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).start() } } @@ -77,7 +78,7 @@ object ChannelRelay { * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * errors that we should return upstream. */ - def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { + private def translateLocalError(error: ChannelException, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { (error, channelUpdate_opt) match { case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooBig, _) => ExpiryTooFar() @@ -121,13 +122,57 @@ class ChannelRelay private(nodeParams: NodeParams, private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) + private val nextBlindingKey_opt = r.payload match { + case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) + case _: IntermediatePayload.ChannelRelay.Standard => None + } + + /** Channel id explicitly requested in the onion payload. */ + private val requestedChannelId_opt = r.payload.outgoing match { + case Left(_) => None + case Right(outgoingChannelId) => channels.collectFirst { + case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId + case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId + } + } + + private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match { + case Left(walletNodeId) => (None, Some(walletNodeId)) + case Right(shortChannelId) => (Some(shortChannelId), None) + } + private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) + def start(): Behavior[Command] = { + walletNodeId_opt match { + case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => wakeUp(walletNodeId) + case _ => + context.self ! DoRelay + relay(Seq.empty) + } + } + + private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { + context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel) + context.log.info("rejecting htlc: failed to wake-up remote peer") + safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + context.self ! DoRelay + relay(Seq.empty) + } + } + def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { - context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse("")) + val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) + context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) handleRelay(previousFailures) match { @@ -143,7 +188,7 @@ class ChannelRelay private(nodeParams: NodeParams, } } - def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}") @@ -156,23 +201,23 @@ class ChannelRelay private(nodeParams: NodeParams, context.self ! DoRelay relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) - case WrappedAddResponse(r: RES_SUCCESS[_]) => + case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") - waitForAddSettled(r.channelId) + waitForAddSettled() } - def waitForAddSettled(channelId: ByteVector32): Behavior[Command] = + private def waitForAddSettled(): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) => - context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) + context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId) Metrics.relayFulfill(confidence) val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true) context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now())) recordRelayDuration(isSuccess = true) safeSendAndStop(upstream.add.channelId, cmd) - case WrappedAddResponse(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail)) => - context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) + case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fail: HtlcResult.Fail)) => + context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId) Metrics.relayFail(confidence) Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel) val cmd = translateRelayFailure(upstream.add.id, fail) @@ -180,7 +225,7 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(upstream.add.channelId, cmd) } - def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { + private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { val toSend = cmd match { case _: CMD_FULFILL_HTLC => cmd case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match { @@ -211,49 +256,44 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) selectPreferredChannel(alreadyTried) match { - case None if previousFailures.nonEmpty => - // no more channels to try - val error = previousFailures - // we return the error for the initially requested channel if it exists - .find(failure => requestedChannelId_opt.contains(failure.channelId)) - // otherwise we return the error for the first channel tried - .getOrElse(previousFailures.head) - .failure - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true)) - case outgoingChannel_opt => - relayOrFail(outgoingChannel_opt) + case Some(outgoingChannel) => relayOrFail(outgoingChannel) + case None => + // No more channels to try. + val cmdFail = if (previousFailures.nonEmpty) { + val error = previousFailures + // We return the error for the initially requested channel if it exists. + .find(failure => requestedChannelId_opt.contains(failure.channelId)) + // Otherwise we return the error for the first channel tried. + .getOrElse(previousFailures.head) + .failure + CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true) + } else { + CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true) + } + RelayFailure(cmdFail) } } - /** all the channels point to the same next node, we take the first one */ - private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) - - /** channel id explicitly requested in the onion payload */ - private val requestedChannelId_opt = channels.collectFirst { - case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId - case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId - } - /** * Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is * compatible in terms of fees, expiry_delta, etc. * * If no suitable channel is found we default to the originally requested channel. */ - def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { - val requestedShortChannelId = r.payload.outgoingChannelId - context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId) + private def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { + context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt) // we filter out channels that we have already tried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried // and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta) candidateChannels .values .map { channel => - val relayResult = relayOrFail(Some(channel)) - context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}", + val relayResult = relayOrFail(channel) + context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}", + channel.channelId, channel.commitments.availableBalanceForSend, channel.commitments.latest.capacity, channel.channelUpdate, @@ -279,7 +319,7 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.debug("requested short channel id is our preferred channel") Some(channel) } else { - context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) + context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) Some(channel) } case None => @@ -300,28 +340,35 @@ class ChannelRelay private(nodeParams: NodeParams, * channel, because some parameters don't match with our settings for that channel. In that case we directly fail the * htlc. */ - def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { - outgoingChannel_opt match { + private def relayOrFail(outgoingChannel: OutgoingChannelParams): RelayResult = { + val update = outgoingChannel.channelUpdate + validateRelayParams(outgoingChannel) match { + case Some(fail) => + RelayFailure(fail) + case None if !update.channelFlags.isEnabled => + RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(update.messageFlags, update.channelFlags, Some(update))), commit = true)) case None => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) - case Some(c) if !c.channelUpdate.channelFlags.isEnabled => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(c.channelUpdate.messageFlags, c.channelUpdate.channelFlags, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.amountToForward < c.channelUpdate.htlcMinimumMsat => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.expiryDelta < c.channelUpdate.cltvExpiryDelta => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.relayFeeMsat < nodeFee(c.channelUpdate.relayFees, r.amountToForward) && - // fees also do not satisfy the previous channel update for `enforcementDelay` seconds after current update - (TimestampSecond.now() - c.channelUpdate.timestamp > nodeParams.relayParams.enforcementDelay || - outgoingChannel_opt.flatMap(_.prevChannelUpdate).forall(c => r.relayFeeMsat < nodeFee(c.relayFees, r.amountToForward))) => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(c.channelUpdate))), commit = true)) - case Some(c: OutgoingChannel) => val origin = Origin.Hot(addResponseAdapter.toClassic, upstream) - val nextBlindingKey_opt = r.payload match { - case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) - case _: IntermediatePayload.ChannelRelay.Standard => None - } - RelaySuccess(c.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true)) + RelaySuccess(outgoingChannel.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true)) + } + } + + private def validateRelayParams(outgoingChannel: OutgoingChannelParams): Option[CMD_FAIL_HTLC] = { + val update = outgoingChannel.channelUpdate + // If our current channel update was recently created, we accept payments that used our previous channel update. + val allowPreviousUpdate = TimestampSecond.now() - update.timestamp <= nodeParams.relayParams.enforcementDelay + val prevUpdate_opt = if (allowPreviousUpdate) outgoingChannel.prevChannelUpdate else None + val htlcMinimumOk = update.htlcMinimumMsat <= r.amountToForward || prevUpdate_opt.exists(_.htlcMinimumMsat <= r.amountToForward) + val expiryDeltaOk = update.cltvExpiryDelta <= r.expiryDelta || prevUpdate_opt.exists(_.cltvExpiryDelta <= r.expiryDelta) + val feesOk = nodeFee(update.relayFees, r.amountToForward) <= r.relayFeeMsat || prevUpdate_opt.exists(u => nodeFee(u.relayFees, r.amountToForward) <= r.relayFeeMsat) + if (!htlcMinimumOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(update))), commit = true)) + } else if (!expiryDeltaOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(update))), commit = true)) + } else if (!feesOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(update))), commit = true)) + } else { + None } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 39d61a22c..8c635df70 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.IncomingPaymentPacket -import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete} import java.util.UUID import scala.collection.mutable @@ -70,9 +70,12 @@ object ChannelRelayer { Behaviors.receiveMessage { case Relay(channelRelayPacket, originNode) => val relayId = UUID.randomUUID() - val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { - case Some(channelId) => channels.get(channelId).map(_.nextNodeId) - case None => None + val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match { + case Some(channelId) => channels.get(channelId).map(_.nextNodeId) + case None => None + } } val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match { case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 49471f82b..de22feba6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -26,6 +26,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ @@ -40,7 +41,7 @@ import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit @@ -62,7 +63,7 @@ object NodeRelay { private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command - private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command // @formatter:on @@ -88,7 +89,6 @@ object NodeRelay { relayId: UUID, nodeRelayPacket: NodeRelayPacket, outgoingPaymentFactory: OutgoingPaymentFactory, - triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: ActorRef): Behavior[Command] = Behaviors.setup { context => val paymentHash = nodeRelayPacket.add.paymentHash @@ -108,7 +108,7 @@ object NodeRelay { case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket) case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None } - new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, triggerer, router) + new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, router) .receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler) } } @@ -125,14 +125,29 @@ object NodeRelay { Some(InvalidOnionPayload(UInt64(2), 0)) } else { payloadOut match { - case payloadOut: IntermediatePayload.NodeRelay.Standard => - if (payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty) { - Some(InvalidOnionPayload(UInt64(8), 0)) // payment secret field is missing - } else { - None - } - case _: IntermediatePayload.NodeRelay.ToBlindedPaths => - None + // If we're relaying a standard payment to a non-trampoline recipient, we need the payment secret. + case payloadOut: IntermediatePayload.NodeRelay.Standard if payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty => Some(InvalidOnionPayload(UInt64(8), 0)) + case _: IntermediatePayload.NodeRelay.Standard => None + case _: IntermediatePayload.NodeRelay.ToBlindedPaths => None + } + } + } + + /** This function identifies whether the next node is a wallet node directly connected to us, and returns its node_id. */ + private def nextWalletNodeId(nodeParams: NodeParams, recipient: Recipient): Option[PublicKey] = { + recipient match { + // These two recipients are only used when we're the payment initiator. + case _: SpontaneousRecipient => None + case _: TrampolineRecipient => None + // When relaying to a trampoline node, the next node may be a wallet node directly connected to us, but we don't + // want to have false positives. Feature branches should check an internal DB/cache to confirm. + case r: ClearRecipient if r.nextTrampolineOnion_opt.nonEmpty => None + // If we're relaying to a non-trampoline recipient, it's never a wallet node. + case _: ClearRecipient => None + // When using blinded paths, we may be the introduction node for a wallet node directly connected to us. + case r: BlindedRecipient => r.blindedHops.head.resolved.route match { + case BlindedPathsResolver.PartialBlindedRoute(walletNodeId: EncodedNodeId.WithPublicKey.Wallet, _, _) => Some(walletNodeId.publicKey) + case _ => None } } } @@ -188,7 +203,6 @@ class NodeRelay private(nodeParams: NodeParams, paymentSecret: ByteVector32, context: ActorContext[NodeRelay.Command], outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, - triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: ActorRef) { import NodeRelay._ @@ -223,44 +237,102 @@ class NodeRelay private(nodeParams: NodeParams, rejectPayment(upstream, Some(failure)) stopping() case None => - nextPayload match { - // TODO: async payments are not currently supported for blinded recipients. We should update the AsyncPaymentTriggerer to decrypt the blinded path. - case nextPayload: IntermediatePayload.NodeRelay.Standard if nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype) => - waitForTrigger(upstream, nextPayload, nextPacket_opt) - case _ => - doSend(upstream, nextPayload, nextPacket_opt) - } + resolveNextNode(upstream, nextPayload, nextPacket_opt) } } - private def waitForTrigger(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})") - val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks - val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight - // wait for notification until which ever occurs first: the hold timeout block or the safety block - val notifierTimeout = Seq(timeoutBlock, safetyBlock).min - val peerReadyResultAdapter = context.messageAdapter[AsyncPaymentTriggerer.Result](WrappedPeerReadyResult) - - triggerer ! AsyncPaymentTriggerer.Watch(peerReadyResultAdapter, nextPayload.outgoingNodeId, paymentHash, notifierTimeout) - context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash)) - Behaviors.receiveMessagePartial { - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTimeout) => - context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout) - rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized - stopping() - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentCanceled) => - context.log.warn(s"payment sender canceled a waiting async payment") - rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized - stopping() - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTriggered) => - doSend(upstream, nextPayload, nextPacket_opt) + /** Once we've fully received the incoming HTLC set, we must identify the next node before forwarding the payment. */ + private def resolveNextNode(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + nextPayload match { + case payloadOut: IntermediatePayload.NodeRelay.Standard => + // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient. + payloadOut.invoiceFeatures match { + case Some(features) => + val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId)) + val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay + val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata) + context.log.debug("forwarding payment to non-trampoline recipient {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, None) + case None => + val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks + val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = nextPacket_opt) + context.log.debug("forwarding payment to the next trampoline node {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt) + } + case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths => + // Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to + // resolve that to a nodeId in order to reach that introduction node and use the blinded path. + // If we are the introduction node ourselves, we'll also need to decrypt the onion and identify the next node. + context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths) + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + case WrappedResolvedPaths(resolved) if resolved.isEmpty => + context.log.warn("rejecting trampoline payment to blinded paths: no usable blinded path") + rejectPayment(upstream, Some(UnknownNextPeer())) + stopping() + case WrappedResolvedPaths(resolved) => + // We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient. + val blindedNodeId = resolved.head.route.blindedNodeIds.last + val recipient = BlindedRecipient.fromPaths(blindedNodeId, Features(payloadOut.invoiceFeatures).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty) + context.log.debug("forwarding payment to blinded recipient {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt) + } + } } } - private def doSend(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - context.log.debug(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})") + /** + * The next node may be a mobile wallet directly connected to us: in that case, we'll need to wake them up before + * relaying the payment. + */ + private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + nextWalletNodeId(nodeParams, recipient) match { + case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt) + case _ => relay(upstream, recipient, nextPayload, nextPacket_opt) + } + } + + /** + * The next node is the payment recipient. They are directly connected to us and may be offline. We try to wake them + * up and will relay the payment once they're connected and channels are reestablished. + */ + private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + context.log.info("trying to wake up next peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + context.log.warn("rejecting payment: failed to wake-up remote peer") + rejectPayment(upstream, Some(UnknownNextPeer())) + stopping() + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + relay(upstream, recipient, nextPayload, nextPacket_opt) + } + } + } + + /** Relay the payment to the next identified node: this is similar to sending an outgoing payment. */ + private def relay(upstream: Upstream.Hot.Trampoline, recipient: Recipient, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={})", upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8 - relay(upstream, nextPayload, nextPacket_opt, confidence) + val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, recipient.nodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence) + val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) + // If the next node is using trampoline, we assume that they support MPP. + val useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment) || packetOut_opt.nonEmpty + val payFsmAdapters = { + context.messageAdapter[PreimageReceived](WrappedPreimageReceived) + context.messageAdapter[PaymentSent](WrappedPaymentSent) + context.messageAdapter[PaymentFailed](WrappedPaymentFailed) + }.toClassic + val payment = if (useMultiPart) { + SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) + } else { + SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) + } + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) + payFSM ! payment + sending(upstream, payloadOut, recipient, TimestampMilli.now(), fulfilledUpstream = false) } /** @@ -270,7 +342,11 @@ class NodeRelay private(nodeParams: NodeParams, * @param nextPayload relay instructions. * @param fulfilledUpstream true if we already fulfilled the payment upstream. */ - private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = + private def sending(upstream: Upstream.Hot.Trampoline, + nextPayload: IntermediatePayload.NodeRelay, + recipient: Recipient, + startedAt: TimestampMilli, + fulfilledUpstream: Boolean): Behavior[Command] = Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { // this is the fulfill that arrives from downstream channels @@ -279,7 +355,7 @@ class NodeRelay private(nodeParams: NodeParams, // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). context.log.debug("got preimage from downstream") fulfillPayment(upstream, paymentPreimage) - sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) + sending(upstream, nextPayload, recipient, startedAt, fulfilledUpstream = true) } else { // we don't want to fulfill multiple times Behaviors.same @@ -311,80 +387,6 @@ class NodeRelay private(nodeParams: NodeParams, } } - private val payFsmAdapters = { - context.messageAdapter[PreimageReceived](WrappedPreimageReceived) - context.messageAdapter[PaymentSent](WrappedPaymentSent) - context.messageAdapter[PaymentFailed](WrappedPaymentFailed) - }.toClassic - - private def relay(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket], confidence: Double): Behavior[Command] = { - val displayNodeId = payloadOut match { - case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId - case _: IntermediatePayload.NodeRelay.ToBlindedPaths => randomKey().publicKey - } - val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence) - val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) - payloadOut match { - case payloadOut: IntermediatePayload.NodeRelay.Standard => - // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient. - payloadOut.invoiceFeatures match { - case Some(features) => - val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId)) - val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay - val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata) - context.log.debug("sending the payment to non-trampoline recipient (MPP={})", recipient.features.hasFeature(Features.BasicMultiPartPayment)) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment)) - case None => - context.log.debug("sending the payment to the next trampoline node") - val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks - val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = packetOut_opt) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true) - } - case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths => - context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths) - waitForResolvedPaths(upstream, payloadOut, paymentCfg, routeParams) - } - } - - private def relayToRecipient(upstream: Upstream.Hot.Trampoline, - payloadOut: IntermediatePayload.NodeRelay, - recipient: Recipient, - paymentCfg: SendPaymentConfig, - routeParams: RouteParams, - useMultiPart: Boolean): Behavior[Command] = { - val payment = - if (useMultiPart) { - SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) - } else { - SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) - } - val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) - payFSM ! payment - sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) - } - - /** - * Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to resolve - * that to a nodeId in order to reach that introduction node and use the blinded path. - */ - private def waitForResolvedPaths(upstream: Upstream.Hot.Trampoline, - payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths, - paymentCfg: SendPaymentConfig, - routeParams: RouteParams): Behavior[Command] = - Behaviors.receiveMessagePartial { - case WrappedResolvedPaths(resolved) if resolved.isEmpty => - context.log.warn(s"rejecting trampoline payment to blinded paths: no usable blinded path") - rejectPayment(upstream, Some(UnknownNextPeer())) - stopping() - case WrappedResolvedPaths(resolved) => - val features = Features(payloadOut.invoiceFeatures).invoiceFeatures() - // We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient. - val blindedNodeId = resolved.head.route.blindedNodeIds.last - val recipient = BlindedRecipient.fromPaths(blindedNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty) - context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment)) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment)) - } - private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = { case Relay(nodeRelayPacket, _) => rejectExtraHtlc(nodeRelayPacket.add) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index 20d65b199..75bb545c8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -16,7 +16,6 @@ package fr.acinq.eclair.payment.relay -import akka.actor.typed import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 @@ -58,7 +57,7 @@ object NodeRelayer { * NB: the payment secret used here is different from the invoice's payment secret and ensures we can * group together HTLCs that the previous trampoline node sent in the same MPP. */ - def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = + def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = Behaviors.setup { context => Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) { Behaviors.receiveMessage { @@ -73,15 +72,15 @@ object NodeRelayer { case None => val relayId = UUID.randomUUID() context.log.debug(s"spawning a new handler with relayId=$relayId") - val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString) + val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, router), relayId.toString) context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId) handler ! NodeRelay.Relay(nodeRelayPacket, originNode) - apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler)) + apply(nodeParams, register, outgoingPaymentFactory, router, children + (childKey -> handler)) } case RelayComplete(childHandler, paymentHash, paymentSecret) => // we do a back-and-forth between parent and child before stopping the child to prevent a race condition childHandler ! NodeRelay.Stop - apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children - PaymentKey(paymentHash, paymentSecret)) + apply(nodeParams, register, outgoingPaymentFactory, router, children - PaymentKey(paymentHash, paymentSecret)) case GetPendingPayments(replyTo) => replyTo ! children Behaviors.same diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index f9f5c0039..d85f9876a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -49,7 +49,7 @@ import scala.util.Random * It also receives channel HTLC events (fulfill / failed) and relays those to the appropriate handlers. * It also maintains an up-to-date view of local channel balances. */ -class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging { +class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging { import Relayer._ @@ -58,7 +58,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner") private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer") - private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), triggerer, router)).onFailure(SupervisorStrategy.resume), name = "node-relayer") + private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), router)).onFailure(SupervisorStrategy.resume), name = "node-relayer") def receive: Receive = { case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init @@ -120,8 +120,8 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym object Relayer extends Logging { - def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None): Props = - Props(new Relayer(nodeParams, router, register, paymentHandler, triggerer, initialized)) + def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None): Props = + Props(new Relayer(nodeParams, router, register, paymentHandler, initialized)) // @formatter:off case class RelayFees(feeBase: MilliSatoshi, feeProportionalMillionths: Long) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala index a12799d8d..12ee84f46 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala @@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs} -import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams} +import fr.acinq.eclair.{EncodedNodeId, Logs, MilliSatoshiLong, NodeParams, ShortChannelId} import scodec.bits.ByteVector import scala.annotation.tailrec @@ -45,8 +45,8 @@ object BlindedPathsResolver { override val firstNodeId: PublicKey = introductionNodeId } /** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */ - case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { - override val firstNodeId: PublicKey = nextNodeId + case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { + override val firstNodeId: PublicKey = nextNodeId.publicKey } // @formatter:on @@ -111,8 +111,14 @@ private class BlindedPathsResolver(nodeParams: NodeParams, feeProportionalMillionths = nextFeeProportionalMillionths, cltvExpiryDelta = nextCltvExpiryDelta ) - register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId) - waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + paymentRelayData.outgoing match { + case Left(outgoingNodeId) => + // The next node seems to be a wallet node directly connected to us. + validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + case Right(outgoingChannelId) => + register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId) + waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + } } } case encodedNodeId: EncodedNodeId.WithPublicKey => @@ -129,7 +135,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams, } /** Resolve the next node in the blinded path when we are the introduction node. */ - private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo, + private def waitForNextNodeId(outgoingChannelId: ShortChannelId, + nextPaymentInfo: OfferTypes.PaymentInfo, paymentRelayData: BlindedRouteData.PaymentRelayData, nextBlinding: PublicKey, nextBlindedNodes: Seq[RouteBlinding.BlindedNode], @@ -137,29 +144,42 @@ private class BlindedPathsResolver(nodeParams: NodeParams, resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedNodeId(None) => - context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId) + context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId) resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId => // The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice. context.log.warn("ignoring blinded path starting at our node relaying to ourselves") resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) => - // Note that we default to private fees if we don't have a channel yet with that node. - // The announceChannel parameter is ignored if we already have a channel. - val relayFees = getRelayFees(nodeParams, nodeId, announceChannel = false) - val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && - paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && - paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta - if (shouldRelay) { - context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId) - val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) - resolveBlindedPaths(toResolve, resolved :+ path) - } else { - context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) - resolveBlindedPaths(toResolve, resolved) - } + validateRelay(EncodedNodeId.WithPublicKey.Plain(nodeId), nextPaymentInfo, paymentRelayData, nextBlinding, nextBlindedNodes, toResolve, resolved) } + private def validateRelay(nextNodeId: EncodedNodeId.WithPublicKey, + nextPaymentInfo: OfferTypes.PaymentInfo, + paymentRelayData: BlindedRouteData.PaymentRelayData, + nextBlinding: PublicKey, + nextBlindedNodes: Seq[RouteBlinding.BlindedNode], + toResolve: Seq[PaymentBlindedRoute], + resolved: Seq[ResolvedPath]): Behavior[Command] = { + // Note that we default to private fees if we don't have a channel yet with that node. + // The announceChannel parameter is ignored if we already have a channel. + val relayFees = getRelayFees(nodeParams, nextNodeId.publicKey, announceChannel = false) + val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && + paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && + paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta && + nextPaymentInfo.feeBase >= 0.msat && + nextPaymentInfo.feeProportionalMillionths >= 0 && + nextPaymentInfo.cltvExpiryDelta.toInt >= 0 + if (shouldRelay) { + context.log.debug("unwrapped blinded path starting at our node: next_node={}", nextNodeId.publicKey) + val path = ResolvedPath(PartialBlindedRoute(nextNodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) + resolveBlindedPaths(toResolve, resolved :+ path) + } else { + context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) + resolveBlindedPaths(toResolve, resolved) + } + } + /** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */ private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala index 81605c5cc..37e31c7ff 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, randomKey} import scodec.bits.ByteVector object BlindedRouteCreation { @@ -77,7 +77,7 @@ object BlindedRouteCreation { Total: 24 to 36 bytes */ val targetLength = 36 - val paddedPayloads = payloads.map(tlvs =>{ + val paddedPayloads = payloads.map(tlvs => { val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0))) }) @@ -95,4 +95,19 @@ object BlindedRouteCreation { Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) } + /** Create a blinded route where the recipient is a wallet node. */ + def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId), + )).require.bytes + val intermediatePayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(hop.nextNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(hop.cltvExpiryDelta, hop.params.relayFees.feeProportionalMillionths, hop.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + )).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(hop.nodeId, hop.nextNodeId), Seq(intermediatePayload, finalPayload)) + } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 10bd99f44..4468ed717 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -227,7 +227,8 @@ object PaymentOnion { object IntermediatePayload { sealed trait ChannelRelay extends IntermediatePayload { // @formatter:off - def outgoingChannelId: ShortChannelId + /** The outgoing channel, or the nodeId of one of our peers. */ + def outgoing: Either[PublicKey, ShortChannelId] def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry // @formatter:on @@ -238,7 +239,7 @@ object PaymentOnion { // @formatter:off val amountOut = records.get[AmountToForward].get.amount val cltvOut = records.get[OutgoingCltv].get.cltv - override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut // @formatter:on @@ -258,12 +259,12 @@ object PaymentOnion { } /** - * @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv. - * @param nextBlinding blinding point that must be forwarded to the next hop. + * @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv. + * @param nextBlinding blinding point that must be forwarded to the next hop. */ case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay { // @formatter:off - override val outgoingChannelId = paymentRelayData.outgoingChannelId + override val outgoing = paymentRelayData.outgoing override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount) override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv) // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 096c1f8ed..be53f4aab 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -98,7 +98,11 @@ object BlindedRouteData { } case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { - val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId + // This is usually a channel, unless the next node is a mobile wallet connected to our node. + val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { + case Some(r) => Right(r.shortChannelId) + case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey) + } val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) @@ -110,7 +114,9 @@ object BlindedRouteData { } def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { - if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + // Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id. + if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 0b0483b43..0e70083fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -26,7 +26,7 @@ import fr.acinq.eclair.channel.{ChannelFlags, LocalParams, Origin, Upstream} import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.db.RevokedHtlcInfoCleaner import fr.acinq.eclair.io.MessageRelay.RelayAll -import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection} +import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection, PeerReadyNotifier} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios} @@ -231,7 +231,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -401,7 +402,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index dc25ecedc..d3f7f47da 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -66,8 +66,8 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe val bobRegister = system.actorOf(Props(new TestRegister())) val alicePaymentHandler = system.actorOf(Props(new PaymentHandler(aliceParams, aliceRegister, TestProbe().ref))) val bobPaymentHandler = system.actorOf(Props(new PaymentHandler(bobParams, bobRegister, TestProbe().ref))) - val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler, TestProbe().ref)) - val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler, TestProbe().ref)) + val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler)) + val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler)) val wallet = new DummyOnChainWallet() val alice: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(aliceParams, wallet, bobParams.nodeId, alice2blockchain.ref, aliceRelayer, FakeTxPublisherFactory(alice2blockchain)), alicePeer.ref) val bob: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(bobParams, wallet, aliceParams.nodeId, bob2blockchain.ref, bobRelayer, FakeTxPublisherFactory(bob2blockchain)), bobPeer.ref) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index e4f064ee3..57a93308a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol) - assert(payloadBob.outgoingChannelId == ShortChannelId(1)) + assert(payloadBob.outgoing.contains(ShortChannelId(1))) assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat) assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100)) assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat)) @@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave) - assert(payloadCarol.outgoingChannelId == ShortChannelId(2)) + assert(payloadCarol.outgoing.contains(ShortChannelId(2))) assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat) assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025)) assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat)) @@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve) - assert(payloadDave.outgoingChannelId == ShortChannelId(3)) + assert(payloadDave.outgoing.contains(ShortChannelId(3))) assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat) assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000)) assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index 1fcbada4a..bbb153a27 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -90,13 +90,12 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val bitcoinClient = new TestBitcoinCoreClient() val wallet = new SingleKeyOnChainWallet() val watcher = TestProbe("watcher") - val triggerer = TestProbe("payment-triggerer") val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command] val register = system.actorOf(Register.props(), "register") val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router") val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager") val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler") - val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler, triggerer.ref.toTyped), "relayer") + val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler), "relayer") val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient) val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory) val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 50da4b3c0..ae45c4c95 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -19,8 +19,10 @@ package fr.acinq.eclair.io import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.testkit.TestProbe +import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} @@ -33,8 +35,8 @@ import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey} -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt @@ -43,19 +45,30 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId - case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) + val wakeUpEnabled = "wake_up_enabled" + val wakeUpTimeout = "wake_up_timeout" + + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], peerReadyManager: TestProbe, probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { + val peerReadyManager = TestProbe("peer-ready-manager")(system.classicSystem) + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped) val switchboard = TestProbe("switchboard")(system.classicSystem) + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) val register = TestProbe("register")(system.classicSystem) val router = TypedProbe[Router.GetNodeId]("router") val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref)) + val nodeParams = Alice.nodeParams + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) + val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { - withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, peerReadyManager, probe))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) testKit.stop(relay) } } @@ -86,6 +99,23 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } + test("relay after waking up next node", Tag(wakeUpEnabled)) { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None) + + val register = peerReadyManager.expectMsgType[PeerReadyManager.Register] + assert(register.remoteNodeId == bobId) + register.replyTo ! PeerReadyManager.Registered(bobId, otherAttempts = 0) + + val request = switchboard.expectMsgType[GetPeerInfo] + assert(request.remoteNodeId == bobId) + request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty) + assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) + } + test("can't open new connection") { f => import f._ @@ -99,6 +129,15 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound)) } + test("can't wake up next node", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, Some(probe.ref)) + probe.expectMessage(Disconnected(messageId)) + } + test("no channel with previous node") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala new file mode 100644 index 000000000..7ef663ec5 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.io + +import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} +import com.typesafe.config.ConfigFactory +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.randomKey +import org.scalatest.funsuite.AnyFunSuiteLike + +class PeerReadyManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with AnyFunSuiteLike { + + test("watch pending notifiers") { + val manager = testKit.spawn(PeerReadyManager()) + val remoteNodeId1 = randomKey().publicKey + val notifier1a = TestProbe[PeerReadyManager.Registered]() + val notifier1b = TestProbe[PeerReadyManager.Registered]() + + manager ! PeerReadyManager.Register(notifier1a.ref, remoteNodeId1) + assert(notifier1a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + manager ! PeerReadyManager.Register(notifier1b.ref, remoteNodeId1) + assert(notifier1b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 1) + + val remoteNodeId2 = randomKey().publicKey + val notifier2a = TestProbe[PeerReadyManager.Registered]() + val notifier2b = TestProbe[PeerReadyManager.Registered]() + + // Later attempts aren't affected by previously completed attempts. + manager ! PeerReadyManager.Register(notifier2a.ref, remoteNodeId2) + assert(notifier2a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + notifier2a.stop() + val probe = TestProbe[Set[PublicKey]]() + probe.awaitAssert({ + manager ! PeerReadyManager.List(probe.ref) + assert(probe.expectMessageType[Set[PublicKey]] == Set(remoteNodeId1)) + }) + manager ! PeerReadyManager.Register(notifier2b.ref, remoteNodeId2) + assert(notifier2b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala index 4d2dad780..bb5f17409 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala @@ -33,17 +33,20 @@ import scala.concurrent.duration.DurationInt class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result]) + case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result]) override def withFixture(test: OneArgTest): Outcome = { val remoteNodeId = randomKey().publicKey + val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager") + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) val peer = TestProbe[Peer.GetPeerChannels]("peer") val probe = TestProbe[PeerReadyNotifier.Result]() try { - withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } } @@ -53,7 +56,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis)))) notifier ! NotifyWhenPeerReady(probe.ref) - assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) probe.expectMessage(PeerUnavailable(remoteNodeId)) } @@ -62,6 +65,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) // We haven't reached the timeout yet. @@ -78,6 +82,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set.empty) probe.expectMessage(PeerReadyNotifier.PeerReady(remoteNodeId, peer.ref.toClassic, Seq.empty)) @@ -88,6 +93,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) @@ -115,6 +121,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) peer.expectNoMessage(100 millis) @@ -137,6 +144,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = None)) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerNotFound(remoteNodeId) peer.expectNoMessage(100 millis) @@ -161,6 +169,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 5) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set.empty) peer.expectNoMessage(100 millis) @@ -185,6 +194,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(1 second)))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) peer.expectMessageType[Peer.GetPeerChannels] @@ -196,6 +206,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 2) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) peer.expectMessageType[Peer.GetPeerChannels] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 1aab2b2c9..6cca2f3c8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -85,7 +85,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward == amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) @@ -95,7 +95,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward == amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) @@ -105,7 +105,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward == amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) @@ -175,7 +175,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -185,7 +185,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward >= amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -196,7 +196,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward >= amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -238,7 +238,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -547,7 +547,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // A smaller amount is sent to d, who doesn't know that it's invalid. val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.amountToForward < amount_de) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding @@ -569,7 +569,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1) val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount)) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index cd45aa1c7..91dac119f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -57,7 +57,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit case class FixtureParam(nodeParams: NodeParams, register: TestProbe, sender: TestProbe, eventListener: TestProbe) { def createRelayer(nodeParams1: NodeParams): (ActorRef, ActorRef) = { - val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref.toTyped)) + val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref)) // we need ensure the post-htlc-restart child actor is initialized sender.send(relayer, Relayer.GetChildActors(sender.ref)) (relayer, sender.expectMsgType[Relayer.ChildActors].postRestartCleaner) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala index 223b7af75..eb19cfd3a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala @@ -1,17 +1,18 @@ package fr.acinq.eclair.payment.relay import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} -import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.{ActorRef, Behavior} import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.channel.NEGOTIATING import fr.acinq.eclair.io.Switchboard.GetPeerInfo -import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard} +import fr.acinq.eclair.io.{Peer, PeerConnected, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._ import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey} import org.scalatest.Outcome @@ -23,8 +24,20 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) + object DummyPeerReadyManager { + def apply(): Behavior[PeerReadyManager.Command] = { + Behaviors.receiveMessagePartial { + case PeerReadyManager.Register(replyTo, remoteNodeId) => + replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) + Behaviors.same + } + } + } + override def withFixture(test: OneArgTest): Outcome = { val remoteNodeId = TestConstants.Alice.nodeParams.nodeId + val peerReadyManager = testKit.spawn(DummyPeerReadyManager()) + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager) val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) val peer = TestProbe[Peer.GetPeerChannels]("peer") @@ -33,6 +46,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. try { withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } } @@ -170,7 +184,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. val probe2 = TestProbe[Result]() triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101)) val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo] - request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) + request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId2, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) // First remote node times out system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100))) @@ -192,6 +206,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. test("triggerer treats an unexpected stop of the notifier as a cancel") { f => import f._ + triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index a25391031..433f49c5d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.payment.relay import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory @@ -29,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec} @@ -39,19 +41,26 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload.ChannelRel import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _} import org.scalatest.Inside.inside -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax +import scala.concurrent.duration.DurationInt + class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { import ChannelRelayerSpec._ + val wakeUpEnabled = "wake_up_enabled" + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... val nodeParams = TestConstants.Bob.nodeParams + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { @@ -157,7 +166,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a import f._ val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) - val payload = createBlindedPayload(u.channelUpdate, isIntroduction = false) + val payload = createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction = false) val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) @@ -166,6 +175,34 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) } + test("relay blinded payment (wake up wallet node)", Tag(wakeUpEnabled)) { f => + import f._ + + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + Seq(true, false).foreach(isIntroduction => { + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) + }) + + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay with retries") { f => import f._ @@ -270,7 +307,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => // The outgoing channel is disabled, so we won't be able to relay the payment. val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false) - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) @@ -293,6 +330,31 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a } } + test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => + import f._ + + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = true) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node, but we timeout before they connect. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)))) + + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay when expiry larger than our requirements") { f => import f._ @@ -519,7 +581,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => testCases.foreach { htlcResult => - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry, 0) @@ -653,13 +715,16 @@ object ChannelRelayerSpec { localAlias2 -> channelId2, ) - def createBlindedPayload(update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { + def createBlindedPayload(outgoing: Either[PublicKey, ShortChannelId], update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { val tlvs = TlvStream[OnionPaymentPayloadTlv](Set( Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")), if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None, ).flatten[OnionPaymentPayloadTlv]) val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(update.shortChannelId), + outgoing match { + case Left(nodeId) => RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeId)) + case Right(scid) => RouteBlindingEncryptedDataTlv.OutgoingChannelId(scid) + }, RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat), ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 64b270f4c..21c1891e4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -20,35 +20,37 @@ import akka.actor.Status import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.ActorContext import akka.actor.typed.scaladsl.adapter._ import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} -import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.Invoice.ExtraEdge +import fr.acinq.eclair.payment.OutgoingPaymentPacket.NodePayload import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.{AsyncPaymentCanceled, AsyncPaymentTimeout, AsyncPaymentTriggered, Watch} import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient} -import fr.acinq.eclair.router.Router.RouteRequest -import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, RouteRequest} +import fr.acinq.eclair.router.{BalanceTooLow, BlindedRouteCreation, RouteNotFound, Router} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} +import fr.acinq.eclair.{Alias, BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} @@ -65,11 +67,14 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ - case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent], triggerer: TestProbe[AsyncPaymentTriggerer.Command]) { + val wakeUpEnabled = "wake_up_enabled" + val wakeUpTimeout = "wake_up_timeout" + + case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { val parent = TestProbe[NodeRelayer.Command]("parent-relayer") val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this) - val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, router.ref.toClassic)) (nodeRelay, parent) } } @@ -92,21 +97,21 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl override def withFixture(test: OneArgTest): Outcome = { val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) - .modify(_.features).setToIf(test.tags.contains("async_payments"))(Features(AsyncPaymentPrototype -> Optional)) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") system.eventStream ! EventStream.Subscribe(eventListener.ref) val mockPayFSM = TestProbe[Any]("pay-fsm") - val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer") - withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener, triggerer))) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener))) } test("create child handlers for new payments") { f => import f._ val probe = TestProbe[Any]() - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), router.ref.toClassic)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) @@ -145,7 +150,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f) { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) } @@ -153,7 +158,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (paymentHash1, paymentSecret1, child1) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentHash2, paymentSecret2, child2) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref) - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(children) @@ -169,7 +174,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (paymentSecret1, child1) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentSecret2, child2) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref) - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(children) @@ -179,7 +184,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref)) } { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic)) parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] @@ -228,7 +233,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey) // the extra payment will be rejected @@ -257,7 +262,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey) val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -270,7 +275,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey) val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -335,115 +340,6 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("fail to relay when not triggered before the hold timeout", Tag("async_payments")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // publish notification that peer is unavailable at the timeout height - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) < asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncTimeoutHeight(nodeParams)) - peerWatch.replyTo ! AsyncPaymentTimeout - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - register.expectNoMessage(100 millis) - } - - test("relay the payment when triggered while waiting", Tag("async_payments"), Tag("long_hold_timeout")) { f => - import f._ - - val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // publish notification that peer is ready before the safety interval before the current incoming payment expires (and before the timeout height) - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - peerWatch.replyTo ! AsyncPaymentTriggered - - // upstream payment relayed - val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5) - val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] - validateOutgoingPayment(outgoingPayment) - // those are adapters for pay-fsm messages - val nodeRelayerAdapters = outgoingPayment.replyTo - - // A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream. - nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) - } - - // Once all the downstream payments have settled, we should emit the relayed event. - nodeRelayerAdapters ! createSuccessEvent() - val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] - validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)).toSet == incomingAsyncPayment.map(i => (i.add.amountMsat, i.add.channelId)).toSet) - assert(relayEvent.outgoing.nonEmpty) - parent.expectMessageType[NodeRelayer.RelayComplete] - register.expectNoMessage(100 millis) - } - - test("fail to relay when not triggered before the incoming expiry safety timeout", Tag("async_payments"), Tag("long_hold_timeout")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment - - // publish notification that peer is unavailable at the cancel-safety-before-timeout-block threshold before the current incoming payment expires (and before the timeout height) - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - peerWatch.replyTo ! AsyncPaymentTimeout - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - - register.expectNoMessage(100 millis) - } - - test("fail to relay payment when canceled by sender before timeout", Tag("async_payments")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // fail the payment if waiting when payment sender sends cancel message - nodeRelayer ! NodeRelay.WrappedPeerReadyResult(AsyncPaymentCanceled) - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - register.expectNoMessage(100 millis) - } - test("relay the payment immediately when the async payment feature is disabled") { f => import f._ @@ -827,26 +723,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } } - def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { - val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) - } - test("relay to blinded paths without multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -856,7 +741,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -865,7 +750,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -874,18 +759,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("relay to blinded paths with multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -895,7 +774,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -904,25 +783,89 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) } + test("relay to blinded path with wake-up", Tag(wakeUpEnabled)) { f => + import f._ + + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + assert(outgoingPayment.recipient.totalAmount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.isInstanceOf[BlindedRecipient]) + + // those are adapters for pay-fsm messages + val nodeRelayerAdapters = outgoingPayment.replyTo + + nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) + } + + nodeRelayerAdapters ! createSuccessEvent() + val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] + validateRelayEvent(relayEvent) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.outgoing.length == 1) + parent.expectMessageType[NodeRelayer.RelayComplete] + register.expectNoMessage(100 millis) + } + + test("fail to relay to blinded path when wake-up fails", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => + import f._ + + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + mockPayFSM.expectNoMessage(100 millis) + + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) + } + } + test("relay to compact blinded paths") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) @@ -932,7 +875,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -942,7 +885,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -951,7 +894,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -960,16 +903,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail to relay to compact blinded paths with unknown scid") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) @@ -980,7 +915,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl mockPayFSM.expectNoMessage(100 millis) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) @@ -1008,7 +943,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret - assert(recipient.nextTrampolineOnion_opt.contains(nextTrampolinePacket)) + assert(recipient.nextTrampolineOnion_opt.nonEmpty) + // The recipient is able to decrypt the trampoline onion. + recipient.nextTrampolineOnion_opt.foreach(onion => assert(IncomingPaymentPacket.decryptOnion(paymentHash, outgoingNodeKey, onion).isRight)) } def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = { @@ -1025,10 +962,7 @@ object NodeRelayerSpec { val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) - - // This is the result of decrypting the incoming trampoline onion packet. - // It should be forwarded to the next trampoline node. - val nextTrampolinePacket = OnionRoutingPacket(0, hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619", randomBytes(400), randomBytes32()) + val paymentSecret = randomBytes32() val outgoingAmount = 40_000_000 msat val outgoingExpiry = CltvExpiry(490000) @@ -1054,6 +988,12 @@ object NodeRelayerSpec { def createSuccessEvent(): PaymentSent = PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None))) + def createTrampolinePacket(amount: MilliSatoshi, expiry: CltvExpiry): OnionRoutingPacket = { + val payload = NodePayload(outgoingNodeId, FinalPayload.Standard.createPayload(amount, amount, expiry, paymentSecret)) + val Right(onion) = OutgoingPaymentPacket.buildOnion(Seq(payload), paymentHash, None) + onion.packet + } + def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry, endorsementIn: Int = 7): RelayToTrampolinePacket = { val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None) val tlvs = TlvStream[UpdateAddHtlcTlv](UpdateAddHtlcTlv.Endorsement(endorsementIn)) @@ -1061,7 +1001,7 @@ object NodeRelayerSpec { UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, tlvs), outerPayload, IntermediatePayload.NodeRelay.Standard(amountOut, expiryOut, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(amountOut, expiryOut)) } def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): RelayToTrampolinePacket = { @@ -1071,7 +1011,46 @@ object NodeRelayerSpec { UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, expiryOut)) + } + + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { + val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + } + + /** Create payments to a blinded path that starts at a remote node. */ + def createIncomingPaymentsToRemoteBlindedPath(features: Features[Bolt12Feature], scidDir_opt: Option[EncodedNodeId.ShortChannelIdDir]): Seq[RelayToBlindedPathsPacket] = { + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val paymentBlindedRoute = scidDir_opt match { + case Some(scidDir) => + val nonCompact = createPaymentBlindedRoute(outgoingNodeId) + nonCompact.copy(route = nonCompact.route.copy(introductionNodeId = scidDir)) + case None => + createPaymentBlindedRoute(outgoingNodeId) + } + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(paymentBlindedRoute)) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) + } + + /** Create payments to a blinded path that starts at our node and relays to a wallet node. */ + def createIncomingPaymentsToWalletBlindedPath(nodeParams: NodeParams): Seq[RelayToBlindedPathsPacket] = { + val features: Features[Bolt12Feature] = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional) + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), randomKey(), Block.RegtestGenesisBlock.hash) + val edge = ExtraEdge(nodeParams.nodeId, outgoingNodeId, Alias(561), 2_000_000 msat, 250, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, outgoingNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, outgoingExpiry).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(outgoingAmount, Seq(hop), CltvExpiryDelta(12)) + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(PaymentBlindedRoute(route, paymentInfo))) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) } } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index 8e68b899e..2c53d7e43 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -55,11 +55,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val paymentHandler = TestProbe[Any]("payment-handler") - val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer") val probe = TestProbe[Any]() // we can't spawn top-level actors with akka typed testKit.spawn(Behaviors.setup[Any] { context => - val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic, triggerer.ref)) + val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic)) probe.ref ! relayer Behaviors.empty[Any] }) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala index 038f15aa8..8a9ae7d64 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.send.BlindedPathsResolver.{FullBlindedRoute, Part import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.router.{BlindedRouteCreation, Router} import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo -import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax @@ -151,6 +151,31 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l } } + test("resolve route starting at our node (wallet node)") { f => + import f._ + + val probe = TestProbe() + val walletNodeId = randomKey().publicKey + val edge = ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(561), 5_000_000 msat, 200, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, Seq(hop), CltvExpiryDelta(12)) + val resolver = testKit.spawn(BlindedPathsResolver(nodeParams, randomBytes32(), router.ref, register.ref)) + resolver ! Resolve(probe.ref, Seq(PaymentBlindedRoute(route, paymentInfo))) + // We are the introduction node: we decrypt the payload and discover that the next node is a wallet node. + val resolved = probe.expectMsgType[Seq[ResolvedPath]] + assert(resolved.size == 1) + assert(resolved.head.route.isInstanceOf[PartialBlindedRoute]) + val partialRoute = resolved.head.route.asInstanceOf[PartialBlindedRoute] + assert(partialRoute.firstNodeId == walletNodeId) + assert(partialRoute.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(walletNodeId)) + assert(partialRoute.blindedNodes == route.subsequentNodes) + assert(partialRoute.nextBlinding != route.blindingKey) + // We don't need to resolve the nodeId. + register.expectNoMessage(100 millis) + router.expectNoMessage(100 millis) + } + test("ignore blinded paths that cannot be resolved") { f => import f._ @@ -181,8 +206,9 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l val probe = TestProbe() val scid = RealShortChannelId(BlockHeight(750_000), 3, 7) - val edgeLowFees = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) - val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) + val nextNodeId = randomKey().publicKey + val edgeLowFees = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) + val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) val toResolve = Seq( // We don't allow paying blinded routes to ourselves. BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, @@ -190,6 +216,8 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low cltv_expiry_delta. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + // We reject blinded routes with low fees, even when the next node seems to be a wallet node. + BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes that cannot be decrypted. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey) ).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 13e22eca5..c0b691b58 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite { val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded) assert(payload.amountOut == 561.msat) assert(payload.cltvOut == CltvExpiry(42)) - assert(payload.outgoingChannelId == ShortChannelId(1105)) + assert(payload.outgoing.contains(ShortChannelId(1105))) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) } @@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite { val decoded = perHopPayloadCodec.decode(bin.bits).require.value assert(decoded == expected) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey) - assert(payload.outgoingChannelId == ShortChannelId(42)) + assert(payload.outgoing.contains(ShortChannelId(42))) assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.paymentRelayData.allowedFeatures.isEmpty) @@ -119,6 +119,20 @@ class PaymentOnionSpec extends AnyFunSuite { } } + test("encode/decode channel relay blinded per-hop-payload (with wallet node_id)") { + val walletNodeId = PublicKey(hex"0221cd519eba9c8b840a5e40b65dc2c040e159a766979723ed770efceb97260ec8") + val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), + ) + val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey) + assert(payload.outgoing == Left(walletNodeId)) + assert(payload.amountToForward(10_000 msat) == 9990.msat) + assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) + assert(payload.paymentRelayData.allowedFeatures.isEmpty) + } + test("encode/decode node relay per-hop payload") { val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId)) @@ -292,6 +306,8 @@ class PaymentOnionSpec extends AnyFunSuite { TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs), // Missing encrypted outgoing channel. TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), + // Forbidden encrypted outgoing plain node_id. + TestCase(ForbiddenTlv(UInt64(4)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(randomKey().publicKey)), RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment relay data. TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment constraint.