1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 01:43:22 +01:00

Wake up wallet nodes before relaying messages or payments (#2865)

We refactor `NodeRelay.scala` to re-order some steps. The steps are:

1. Fully receive the incoming payment
2. Resolve the next node (unwrap blinded paths if needed)
3. Wake-up the next node if necessary (mobile wallet)
4. Relay outgoing payment

Note that we introduce a wake-up step, that can be extended to include
mobile notifications. We introduce that same wake-up step in channel
relay and message relay. We also allow relaying data to contain a wallet
`node_id` instead of an scid. When that's the case, we start by waking
up that wallet node before we try relaying onion messages or payments.

This wake-up step doesn't contain any logic right now apart from waiting
for the peer to connect, if it isn't connected already. But it can easily be
extended to send a mobile notification to prompt the wallet to connect.
This commit is contained in:
Bastien Teinturier 2024-08-28 09:43:11 +02:00 committed by GitHub
parent c440007b52
commit fcd88b0a0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 1029 additions and 594 deletions

View File

@ -318,6 +318,13 @@ eclair {
max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us
} }
// When relaying payments or messages to mobile peers who are disconnected, we may try to wake them up using a mobile
// notification system, or we attempt connecting to the last known address.
peer-wake-up {
enabled = false
timeout = 60 seconds
}
auto-reconnect = true auto-reconnect = true
initial-random-reconnect-delay = 5 seconds // we add a random delay before the first reconnection attempt, capped by this value initial-random-reconnect-delay = 5 seconds // we add a random delay before the first reconnection attempt, capped by this value
max-reconnect-interval = 1 hour // max interval between two reconnection attempts, after the exponential backoff period max-reconnect-interval = 1 hour // max interval between two reconnection attempts, after the exponential backoff period

View File

@ -28,7 +28,7 @@ import fr.acinq.eclair.crypto.Noise.KeyPair
import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager, OnChainKeyManager} import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager, OnChainKeyManager}
import fr.acinq.eclair.db._ import fr.acinq.eclair.db._
import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy} import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy}
import fr.acinq.eclair.io.PeerConnection import fr.acinq.eclair.io.{PeerConnection, PeerReadyNotifier}
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Announcements.AddressException import fr.acinq.eclair.router.Announcements.AddressException
@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
blockchainWatchdogSources: Seq[String], blockchainWatchdogSources: Seq[String],
onionMessageConfig: OnionMessageConfig, onionMessageConfig: OnionMessageConfig,
purgeInvoicesInterval: Option[FiniteDuration], purgeInvoicesInterval: Option[FiniteDuration],
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) { revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey
val nodeId: PublicKey = nodeKeyManager.nodeId val nodeId: PublicKey = nodeKeyManager.nodeId
@ -611,7 +612,11 @@ object NodeParams extends Logging {
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config( revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(
batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"), batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"),
interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS)
) ),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(
enabled = config.getBoolean("peer-wake-up.enabled"),
timeout = FiniteDuration(config.getDuration("peer-wake-up.timeout").getSeconds, TimeUnit.SECONDS)
),
) )
} }
} }

View File

@ -360,7 +360,8 @@ class Setup(val datadir: File,
offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager")
paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume)) paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume))
triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer") triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer")
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, triggerer, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) peerReadyManager = system.spawn(Behaviors.supervise(PeerReadyManager()).onFailure(typed.SupervisorStrategy.restart), name = "peer-ready-manager")
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume))
_ = relayer ! PostRestartHtlcCleaner.Init(channels) _ = relayer ! PostRestartHtlcCleaner.Init(channels)
// Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system, // Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system,
// we want to make sure the handler for post-restart broken HTLCs has finished initializing. // we want to make sure the handler for post-restart broken HTLCs has finished initializing.

View File

@ -44,29 +44,18 @@ object MessageRelay {
policy: RelayPolicy, policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command
case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
sealed trait Status { sealed trait Status { val messageId: ByteVector32 }
val messageId: ByteVector32
}
case class Sent(messageId: ByteVector32) extends Status case class Sent(messageId: ByteVector32) extends Status
sealed trait Failure extends Status sealed trait Failure extends Status
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" }
override def toString: String = s"Relay prevented by policy $policy" case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" }
} case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" }
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" }
override def toString: String = s"Can't connect to peer: ${failure.toString}" case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" }
}
case class Disconnected(messageId: ByteVector32) extends Failure {
override def toString: String = "Peer is not connected"
}
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown channel: $channelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
}
sealed trait RelayPolicy sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy case object RelayChannelsOnly extends RelayPolicy
@ -106,7 +95,7 @@ private class MessageRelay(nodeParams: NodeParams,
def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match { nextNode match {
case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf =>
withNextNodeId(msg, nodeParams.nodeId) withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId))
case Left(outgoingChannelId) => case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(msg, outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId)
@ -114,7 +103,7 @@ private class MessageRelay(nodeParams: NodeParams,
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid) waitForNextNodeId(msg, scid)
case Right(encodedNodeId: EncodedNodeId.WithPublicKey) => case Right(encodedNodeId: EncodedNodeId.WithPublicKey) =>
withNextNodeId(msg, encodedNodeId.publicKey) withNextNodeId(msg, encodedNodeId)
} }
} }
@ -127,34 +116,39 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped Behaviors.stopped
case WrappedOptionalNodeId(Some(nextNodeId)) => case WrappedOptionalNodeId(Some(nextNodeId)) =>
log.info("found outgoing node {} for channel {}", nextNodeId, channelId) log.info("found outgoing node {} for channel {}", nextNodeId, channelId)
withNextNodeId(msg, nextNodeId) withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId))
} }
} }
private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) { nextNodeId match {
OnionMessages.process(nodeParams.privateKey, msg) match { case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId =>
case OnionMessages.DropMessage(reason) => OnionMessages.process(nodeParams.privateKey, msg) match {
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment() case OnionMessages.DropMessage(reason) =>
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment()
Behaviors.stopped replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
case OnionMessages.SendMessage(nextNode, nextMessage) => Behaviors.stopped
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. case OnionMessages.SendMessage(nextNode, nextMessage) =>
queryNextNodeId(nextMessage, nextNode) // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
case received: OnionMessages.ReceiveMessage => queryNextNodeId(nextMessage, nextNode)
context.system.eventStream ! EventStream.Publish(received) case received: OnionMessages.ReceiveMessage =>
replyTo_opt.foreach(_ ! Sent(messageId)) context.system.eventStream ! EventStream.Publish(received)
Behaviors.stopped replyTo_opt.foreach(_ ! Sent(messageId))
} Behaviors.stopped
} else { }
policy match { case EncodedNodeId.WithPublicKey.Plain(nodeId) =>
case RelayChannelsOnly => policy match {
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) case RelayChannelsOnly =>
waitForPreviousPeerForPolicyCheck(msg, nextNodeId) switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
case RelayAll => waitForPreviousPeerForPolicyCheck(msg, nodeId)
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) case RelayAll =>
waitForConnection(msg, nextNodeId) switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
} waitForConnection(msg, nodeId)
}
case EncodedNodeId.WithPublicKey.Wallet(nodeId) =>
val notifier = context.spawnAnonymous(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
waitForWalletNodeUp(msg, nodeId)
} }
} }
@ -197,4 +191,18 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped Behaviors.stopped
} }
} }
private def waitForWalletNodeUp(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) =>
log.info("successfully woke up {}: relaying onion message", nextNodeId)
r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Behaviors.stopped
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, Tags.Reasons.ConnectionFailure).increment()
log.info("could not wake up {}: onion message cannot be relayed", nextNodeId)
replyTo_opt.foreach(_ ! Disconnected(messageId))
Behaviors.stopped
}
}
} }

View File

@ -17,36 +17,104 @@
package fr.acinq.eclair.io package fr.acinq.eclair.io
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.receptionist.{Receptionist, ServiceKey}
import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps}
import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler} import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler}
import akka.actor.typed.{ActorRef, Behavior} import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy}
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.{BlockHeight, Logs, channel} import fr.acinq.eclair.{BlockHeight, Logs, channel}
import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.concurrent.duration.{DurationInt, FiniteDuration}
/**
* This actor tracks the set of pending [[PeerReadyNotifier]].
* It can be used to ensure that notifications are only sent once, even if there are multiple parallel operations
* waiting for that peer to come online.
*/
object PeerReadyManager {
val PeerReadyManagerServiceKey: ServiceKey[Register] = ServiceKey[Register]("peer-ready-manager")
// @formatter:off
sealed trait Command
case class Register(replyTo: ActorRef[Registered], remoteNodeId: PublicKey) extends Command
case class List(replyTo: ActorRef[Set[PublicKey]]) extends Command
private case class Completed(remoteNodeId: PublicKey, actor: ActorRef[Registered]) extends Command
// @formatter:on
/**
* @param otherAttempts number of already pending [[PeerReadyNotifier]] instances for that peer.
*/
case class Registered(remoteNodeId: PublicKey, otherAttempts: Int)
def apply(): Behavior[Command] = {
Behaviors.setup { context =>
context.system.receptionist ! Receptionist.Register(PeerReadyManagerServiceKey, context.self)
watch(Map.empty, context)
}
}
private def watch(pending: Map[PublicKey, Set[ActorRef[Registered]]], context: ActorContext[Command]): Behavior[Command] = {
Behaviors.receiveMessage {
case Register(replyTo, remoteNodeId) =>
context.watchWith(replyTo, Completed(remoteNodeId, replyTo))
pending.get(remoteNodeId) match {
case Some(attempts) =>
replyTo ! Registered(remoteNodeId, otherAttempts = attempts.size)
val attempts1 = attempts + replyTo
watch(pending + (remoteNodeId -> attempts1), context)
case None =>
replyTo ! Registered(remoteNodeId, otherAttempts = 0)
watch(pending + (remoteNodeId -> Set(replyTo)), context)
}
case Completed(remoteNodeId, actor) =>
pending.get(remoteNodeId) match {
case Some(attempts) =>
val attempts1 = attempts - actor
if (attempts1.isEmpty) {
watch(pending - remoteNodeId, context)
} else {
watch(pending + (remoteNodeId -> attempts1), context)
}
case None =>
Behaviors.same
}
case List(replyTo) =>
replyTo ! pending.keySet
Behaviors.same
}
}
}
/** /**
* This actor waits for a given peer to be online and ready to process payments. * This actor waits for a given peer to be online and ready to process payments.
* It automatically stops after the timeout provided. * It automatically stops after the timeout provided if the peer doesn't connect.
* There may be multiple instances of this actor running in parallel for the same peer, which is fine because they
* may use different timeouts.
* Having separate actor instances for each caller guarantees that the caller will always receive a response.
*/ */
object PeerReadyNotifier { object PeerReadyNotifier {
case class WakeUpConfig(enabled: Boolean, timeout: FiniteDuration)
// @formatter:off // @formatter:off
sealed trait Command sealed trait Command
case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command
private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command
private final case class WrappedRegistered(registered: PeerReadyManager.Registered) extends Command
private case object PeerNotConnected extends Command private case object PeerNotConnected extends Command
private case class SomePeerConnected(nodeId: PublicKey) extends Command private case object PeerConnected extends Command
private case class SomePeerDisconnected(nodeId: PublicKey) extends Command private case object PeerDisconnected extends Command
private case class WrappedPeerInfo(peer: ActorRef[Peer.GetPeerChannels], channelCount: Int) extends Command private case class WrappedPeerInfo(peer: ActorRef[Peer.GetPeerChannels], channelCount: Int) extends Command
private case class NewBlockNotTimedOut(currentBlockHeight: BlockHeight) extends Command private case class NewBlockNotTimedOut(currentBlockHeight: BlockHeight) extends Command
private case object CheckChannelsReady extends Command private case object CheckChannelsReady extends Command
private case class WrappedPeerChannels(wrapped: Peer.PeerChannels) extends Command private case class WrappedPeerChannels(wrapped: Peer.PeerChannels) extends Command
private case object Timeout extends Command private case object Timeout extends Command
private case object ToBeIgnored extends Command
sealed trait Result sealed trait Result { def remoteNodeId: PublicKey }
case class PeerReady(remoteNodeId: PublicKey, peer: akka.actor.ActorRef, channelInfos: Seq[Peer.ChannelInfo]) extends Result { val channelsCount: Int = channelInfos.size } case class PeerReady(remoteNodeId: PublicKey, peer: akka.actor.ActorRef, channelInfos: Seq[Peer.ChannelInfo]) extends Result { val channelsCount: Int = channelInfos.size }
case class PeerUnavailable(remoteNodeId: PublicKey) extends Result case class PeerUnavailable(remoteNodeId: PublicKey) extends Result
@ -66,102 +134,40 @@ object PeerReadyNotifier {
case cbc => NewBlockNotTimedOut(cbc.blockHeight) case cbc => NewBlockNotTimedOut(cbc.blockHeight)
}) })
} }
// In case the peer is not currently connected, we will wait for them to connect instead of regularly // The actor should never throw, but for extra safety we wrap it with a supervisor.
// polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. Behaviors.supervise {
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => SomePeerConnected(e.nodeId))) start(replyTo, remoteNodeId, context, timers)
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => SomePeerDisconnected(e.nodeId))) }.onFailure(SupervisorStrategy.stop)
findSwitchboard(replyTo, remoteNodeId, context, timers)
} }
} }
} }
} }
} }
private def findSwitchboard(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { private def start(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) // We start by registering ourself to see if other instances are running.
context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing))
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) =>
listings.headOption match { listings.headOption match {
case Some(switchboard) => case Some(peerReadyManager) =>
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId)
Behaviors.same
case None => case None =>
context.log.error("no switchboard found") context.log.error("no peer-ready-manager found")
replyTo ! PeerUnavailable(remoteNodeId) replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped Behaviors.stopped
}
}
}
private def waitForPeerConnected(replyTo: ActorRef[Result], remoteNodeId: PublicKey, switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] {
// We receive this when we don't have any channel to the given peer and are not currently connected to them.
// In that case we still want to wait for a connection, because we may want to open a channel to them.
case _: Peer.PeerNotFound => PeerNotConnected
case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected
case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size)
}
// We check whether the peer is already connected.
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.receiveMessagePartial {
case PeerNotConnected =>
context.log.debug("peer is not connected yet")
Behaviors.same
case SomePeerConnected(nodeId) =>
if (nodeId == remoteNodeId) {
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
} }
Behaviors.same case WrappedRegistered(registered) =>
case SomePeerDisconnected(_) => context.log.info("checking if peer is ready ({} other attempts)", registered.otherAttempts)
Behaviors.same val isFirstAttempt = registered.otherAttempts == 0
case WrappedPeerInfo(peer, channelCount) => // In case the peer is not currently connected, we will wait for them to connect instead of regularly
if (channelCount == 0) { // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments.
context.log.info("peer is ready with no channels") context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored))
replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty) context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored))
Behaviors.stopped new PeerReadyNotifier(replyTo, remoteNodeId, isFirstAttempt, context, timers).findSwitchboard()
} else {
context.log.debug("peer is connected with {} channels", channelCount)
waitForChannelsReady(replyTo, remoteNodeId, peer, switchboard, context, timers)
}
case NewBlockNotTimedOut(currentBlockHeight) =>
context.log.debug("waiting for peer to connect at block {}", currentBlockHeight)
Behaviors.same
case Timeout => case Timeout =>
context.log.info("timed out waiting for peer to connect") context.log.info("timed out finding peer-ready-manager actor")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
}
private def waitForChannelsReady(replyTo: ActorRef[Result], remoteNodeId: PublicKey, peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second)
Behaviors.receiveMessagePartial {
case CheckChannelsReady =>
context.log.debug("checking channel states")
peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels))
Behaviors.same
case WrappedPeerChannels(peerChannels) =>
if (peerChannels.channels.map(_.state).forall(isChannelReady)) {
replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels)
Behaviors.stopped
} else {
context.log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state)))
Behaviors.same
}
case NewBlockNotTimedOut(currentBlockHeight) =>
context.log.debug("waiting for channels to be ready at block {}", currentBlockHeight)
Behaviors.same
case SomePeerConnected(_) =>
Behaviors.same
case SomePeerDisconnected(nodeId) =>
if (nodeId == remoteNodeId) {
context.log.debug("peer disconnected, waiting for them to reconnect")
timers.cancel(ChannelsReadyTimerKey)
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers)
} else {
Behaviors.same
}
case Timeout =>
context.log.info("timed out waiting for channels to be ready")
replyTo ! PeerUnavailable(remoteNodeId) replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped Behaviors.stopped
} }
@ -199,3 +205,109 @@ object PeerReadyNotifier {
} }
} }
private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result],
remoteNodeId: PublicKey,
isFirstAttempt: Boolean,
context: ActorContext[PeerReadyNotifier.Command],
timers: TimerScheduler[PeerReadyNotifier.Command]) {
import PeerReadyNotifier._
private val log = context.log
private def findSwitchboard(): Behavior[Command] = {
context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing))
Behaviors.receiveMessagePartial {
case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) =>
listings.headOption match {
case Some(switchboard) =>
waitForPeerConnected(switchboard)
case None =>
log.error("no switchboard found")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
case Timeout =>
log.info("timed out finding switchboard actor")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = {
val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] {
// We receive this when we don't have any channel to the given peer and are not currently connected to them.
// In that case we still want to wait for a connection, because we may want to open a channel to them.
case _: Peer.PeerNotFound => PeerNotConnected
case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected
case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size)
}
// We check whether the peer is already connected.
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.receiveMessagePartial {
case PeerNotConnected =>
log.debug("peer is not connected yet")
Behaviors.same
case PeerConnected =>
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.same
case PeerDisconnected =>
Behaviors.same
case WrappedPeerInfo(peer, channelCount) =>
if (channelCount == 0) {
log.info("peer is ready with no channels")
replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty)
Behaviors.stopped
} else {
log.debug("peer is connected with {} channels", channelCount)
waitForChannelsReady(peer, switchboard)
}
case NewBlockNotTimedOut(currentBlockHeight) =>
log.debug("waiting for peer to connect at block {}", currentBlockHeight)
Behaviors.same
case Timeout =>
log.info("timed out waiting for peer to connect")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = {
timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second)
Behaviors.receiveMessagePartial {
case CheckChannelsReady =>
log.debug("checking channel states")
peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels))
Behaviors.same
case WrappedPeerChannels(peerChannels) =>
if (peerChannels.channels.map(_.state).forall(isChannelReady)) {
replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels)
Behaviors.stopped
} else {
log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state)))
Behaviors.same
}
case NewBlockNotTimedOut(currentBlockHeight) =>
log.debug("waiting for channels to be ready at block {}", currentBlockHeight)
Behaviors.same
case PeerConnected =>
Behaviors.same
case PeerDisconnected =>
log.debug("peer disconnected, waiting for them to reconnect")
timers.cancel(ChannelsReadyTimerKey)
waitForPeerConnected(switchboard)
case Timeout =>
log.info("timed out waiting for channels to be ready")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
}

View File

@ -119,6 +119,7 @@ object Monitoring {
val Failure = "failure" val Failure = "failure"
object FailureType { object FailureType {
val WakeUp = "WakeUp"
val Remote = "Remote" val Remote = "Remote"
val Malformed = "MalformedHtlc" val Malformed = "MalformedHtlc"

View File

@ -126,7 +126,7 @@ object IncomingPaymentPacket {
decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap {
case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) =>
validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap { validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap {
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf => case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) =>
decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features) decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features)
case relayPacket => Right(relayPacket) case relayPacket => Right(relayPacket)
} }

View File

@ -19,7 +19,7 @@ package fr.acinq.eclair.payment.relay
import akka.actor.typed.ActorRef.ActorRefOps import akka.actor.typed.ActorRef.ActorRefOps
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.Logs.LogCategory
@ -99,7 +99,7 @@ private class AsyncPaymentTriggerer(context: ActorContext[Command]) {
case Watch(replyTo, remoteNodeId, paymentHash, timeout) => case Watch(replyTo, remoteNodeId, paymentHash, timeout) =>
peers.get(remoteNodeId) match { peers.get(remoteNodeId) match {
case None => case None =>
val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(remoteNodeId, timeout_opt = None)).onFailure(SupervisorStrategy.stop)) val notifier = context.spawnAnonymous(PeerReadyNotifier(remoteNodeId, timeout_opt = None))
context.watchWith(notifier, NotifierStopped(remoteNodeId)) context.watchWith(notifier, NotifierStopped(remoteNodeId))
notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)) notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult))
val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash))) val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash)))

View File

@ -16,16 +16,17 @@
package fr.acinq.eclair.payment.relay package fr.acinq.eclair.payment.relay
import akka.actor.ActorRef
import akka.actor.typed.Behavior import akka.actor.typed.Behavior
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.io.PeerReadyNotifier
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams}
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket}
@ -44,6 +45,7 @@ object ChannelRelay {
// @formatter:off // @formatter:off
sealed trait Command sealed trait Command
private case object DoRelay extends Command private case object DoRelay extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command
private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command
// @formatter:on // @formatter:on
@ -57,7 +59,7 @@ object ChannelRelay {
def apply(nodeParams: NodeParams, def apply(nodeParams: NodeParams,
register: ActorRef, register: ActorRef,
channels: Map[ByteVector32, Relayer.OutgoingChannel], channels: Map[ByteVector32, Relayer.OutgoingChannel],
originNode:PublicKey, originNode: PublicKey,
relayId: UUID, relayId: UUID,
r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] = r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] =
Behaviors.setup { context => Behaviors.setup { context =>
@ -67,9 +69,8 @@ object ChannelRelay {
paymentHash_opt = Some(r.add.paymentHash), paymentHash_opt = Some(r.add.paymentHash),
nodeAlias_opt = Some(nodeParams.alias))) { nodeAlias_opt = Some(nodeParams.alias))) {
val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode) val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode)
context.self ! DoRelay
val confidence = (r.add.endorsement + 0.5) / 8 val confidence = (r.add.endorsement + 0.5) / 8
new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).relay(Seq.empty) new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).start()
} }
} }
@ -77,7 +78,7 @@ object ChannelRelay {
* This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard
* errors that we should return upstream. * errors that we should return upstream.
*/ */
def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { private def translateLocalError(error: ChannelException, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = {
(error, channelUpdate_opt) match { (error, channelUpdate_opt) match {
case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate))
case (_: ExpiryTooBig, _) => ExpiryTooFar() case (_: ExpiryTooBig, _) => ExpiryTooFar()
@ -121,13 +122,57 @@ class ChannelRelay private(nodeParams: NodeParams,
private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure)
private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse)
private val nextBlindingKey_opt = r.payload match {
case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding)
case _: IntermediatePayload.ChannelRelay.Standard => None
}
/** Channel id explicitly requested in the onion payload. */
private val requestedChannelId_opt = r.payload.outgoing match {
case Left(_) => None
case Right(outgoingChannelId) => channels.collectFirst {
case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId
case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId
}
}
private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match {
case Left(walletNodeId) => (None, Some(walletNodeId))
case Right(shortChannelId) => (Some(shortChannelId), None)
}
private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException])
def start(): Behavior[Command] = {
walletNodeId_opt match {
case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => wakeUp(walletNodeId)
case _ =>
context.self ! DoRelay
relay(Seq.empty)
}
}
private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = {
context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId)
val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel)
context.log.info("rejecting htlc: failed to wake-up remote peer")
safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true))
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) =>
context.self ! DoRelay
relay(Seq.empty)
}
}
def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = {
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
case DoRelay => case DoRelay =>
if (previousFailures.isEmpty) { if (previousFailures.isEmpty) {
context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse("")) val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId)
context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse(""))
} }
context.log.debug("attempting relay previousAttempts={}", previousFailures.size) context.log.debug("attempting relay previousAttempts={}", previousFailures.size)
handleRelay(previousFailures) match { handleRelay(previousFailures) match {
@ -143,7 +188,7 @@ class ChannelRelay private(nodeParams: NodeParams,
} }
} }
def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = private def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] =
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) => case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) =>
context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}") context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}")
@ -156,23 +201,23 @@ class ChannelRelay private(nodeParams: NodeParams,
context.self ! DoRelay context.self ! DoRelay
relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed))
case WrappedAddResponse(r: RES_SUCCESS[_]) => case WrappedAddResponse(_: RES_SUCCESS[_]) =>
context.log.debug("sent htlc to the downstream channel") context.log.debug("sent htlc to the downstream channel")
waitForAddSettled(r.channelId) waitForAddSettled()
} }
def waitForAddSettled(channelId: ByteVector32): Behavior[Command] = private def waitForAddSettled(): Behavior[Command] =
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) => case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) =>
context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId)
Metrics.relayFulfill(confidence) Metrics.relayFulfill(confidence)
val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true) val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true)
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now())) context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now()))
recordRelayDuration(isSuccess = true) recordRelayDuration(isSuccess = true)
safeSendAndStop(upstream.add.channelId, cmd) safeSendAndStop(upstream.add.channelId, cmd)
case WrappedAddResponse(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail)) => case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fail: HtlcResult.Fail)) =>
context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId)
Metrics.relayFail(confidence) Metrics.relayFail(confidence)
Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel) Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel)
val cmd = translateRelayFailure(upstream.add.id, fail) val cmd = translateRelayFailure(upstream.add.id, fail)
@ -180,7 +225,7 @@ class ChannelRelay private(nodeParams: NodeParams,
safeSendAndStop(upstream.add.channelId, cmd) safeSendAndStop(upstream.add.channelId, cmd)
} }
def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = {
val toSend = cmd match { val toSend = cmd match {
case _: CMD_FULFILL_HTLC => cmd case _: CMD_FULFILL_HTLC => cmd
case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match { case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match {
@ -211,49 +256,44 @@ class ChannelRelay private(nodeParams: NodeParams,
* - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_FAIL_HTLC to be sent back upstream
* - a CMD_ADD_HTLC to propagate downstream * - a CMD_ADD_HTLC to propagate downstream
*/ */
def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { private def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = {
val alreadyTried = previousFailures.map(_.channelId) val alreadyTried = previousFailures.map(_.channelId)
selectPreferredChannel(alreadyTried) match { selectPreferredChannel(alreadyTried) match {
case None if previousFailures.nonEmpty => case Some(outgoingChannel) => relayOrFail(outgoingChannel)
// no more channels to try case None =>
val error = previousFailures // No more channels to try.
// we return the error for the initially requested channel if it exists val cmdFail = if (previousFailures.nonEmpty) {
.find(failure => requestedChannelId_opt.contains(failure.channelId)) val error = previousFailures
// otherwise we return the error for the first channel tried // We return the error for the initially requested channel if it exists.
.getOrElse(previousFailures.head) .find(failure => requestedChannelId_opt.contains(failure.channelId))
.failure // Otherwise we return the error for the first channel tried.
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true)) .getOrElse(previousFailures.head)
case outgoingChannel_opt => .failure
relayOrFail(outgoingChannel_opt) CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true)
} else {
CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)
}
RelayFailure(cmdFail)
} }
} }
/** all the channels point to the same next node, we take the first one */
private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId)
/** channel id explicitly requested in the onion payload */
private val requestedChannelId_opt = channels.collectFirst {
case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId
case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId
}
/** /**
* Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is * Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is
* compatible in terms of fees, expiry_delta, etc. * compatible in terms of fees, expiry_delta, etc.
* *
* If no suitable channel is found we default to the originally requested channel. * If no suitable channel is found we default to the originally requested channel.
*/ */
def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { private def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = {
val requestedShortChannelId = r.payload.outgoingChannelId context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt)
context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId)
// we filter out channels that we have already tried // we filter out channels that we have already tried
val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried
// and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta) // and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta)
candidateChannels candidateChannels
.values .values
.map { channel => .map { channel =>
val relayResult = relayOrFail(Some(channel)) val relayResult = relayOrFail(channel)
context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}", context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}",
channel.channelId,
channel.commitments.availableBalanceForSend, channel.commitments.availableBalanceForSend,
channel.commitments.latest.capacity, channel.commitments.latest.capacity,
channel.channelUpdate, channel.channelUpdate,
@ -279,7 +319,7 @@ class ChannelRelay private(nodeParams: NodeParams,
context.log.debug("requested short channel id is our preferred channel") context.log.debug("requested short channel id is our preferred channel")
Some(channel) Some(channel)
} else { } else {
context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend)
Some(channel) Some(channel)
} }
case None => case None =>
@ -300,28 +340,35 @@ class ChannelRelay private(nodeParams: NodeParams,
* channel, because some parameters don't match with our settings for that channel. In that case we directly fail the * channel, because some parameters don't match with our settings for that channel. In that case we directly fail the
* htlc. * htlc.
*/ */
def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { private def relayOrFail(outgoingChannel: OutgoingChannelParams): RelayResult = {
outgoingChannel_opt match { val update = outgoingChannel.channelUpdate
validateRelayParams(outgoingChannel) match {
case Some(fail) =>
RelayFailure(fail)
case None if !update.channelFlags.isEnabled =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(update.messageFlags, update.channelFlags, Some(update))), commit = true))
case None => case None =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true))
case Some(c) if !c.channelUpdate.channelFlags.isEnabled =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(c.channelUpdate.messageFlags, c.channelUpdate.channelFlags, Some(c.channelUpdate))), commit = true))
case Some(c) if r.amountToForward < c.channelUpdate.htlcMinimumMsat =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(c.channelUpdate))), commit = true))
case Some(c) if r.expiryDelta < c.channelUpdate.cltvExpiryDelta =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(c.channelUpdate))), commit = true))
case Some(c) if r.relayFeeMsat < nodeFee(c.channelUpdate.relayFees, r.amountToForward) &&
// fees also do not satisfy the previous channel update for `enforcementDelay` seconds after current update
(TimestampSecond.now() - c.channelUpdate.timestamp > nodeParams.relayParams.enforcementDelay ||
outgoingChannel_opt.flatMap(_.prevChannelUpdate).forall(c => r.relayFeeMsat < nodeFee(c.relayFees, r.amountToForward))) =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(c.channelUpdate))), commit = true))
case Some(c: OutgoingChannel) =>
val origin = Origin.Hot(addResponseAdapter.toClassic, upstream) val origin = Origin.Hot(addResponseAdapter.toClassic, upstream)
val nextBlindingKey_opt = r.payload match { RelaySuccess(outgoingChannel.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true))
case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) }
case _: IntermediatePayload.ChannelRelay.Standard => None }
}
RelaySuccess(c.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true)) private def validateRelayParams(outgoingChannel: OutgoingChannelParams): Option[CMD_FAIL_HTLC] = {
val update = outgoingChannel.channelUpdate
// If our current channel update was recently created, we accept payments that used our previous channel update.
val allowPreviousUpdate = TimestampSecond.now() - update.timestamp <= nodeParams.relayParams.enforcementDelay
val prevUpdate_opt = if (allowPreviousUpdate) outgoingChannel.prevChannelUpdate else None
val htlcMinimumOk = update.htlcMinimumMsat <= r.amountToForward || prevUpdate_opt.exists(_.htlcMinimumMsat <= r.amountToForward)
val expiryDeltaOk = update.cltvExpiryDelta <= r.expiryDelta || prevUpdate_opt.exists(_.cltvExpiryDelta <= r.expiryDelta)
val feesOk = nodeFee(update.relayFees, r.amountToForward) <= r.relayFeeMsat || prevUpdate_opt.exists(u => nodeFee(u.relayFees, r.amountToForward) <= r.relayFeeMsat)
if (!htlcMinimumOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(update))), commit = true))
} else if (!expiryDeltaOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(update))), commit = true))
} else if (!feesOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(update))), commit = true))
} else {
None
} }
} }

View File

@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel._
import fr.acinq.eclair.payment.IncomingPaymentPacket import fr.acinq.eclair.payment.IncomingPaymentPacket
import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId} import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete}
import java.util.UUID import java.util.UUID
import scala.collection.mutable import scala.collection.mutable
@ -70,9 +70,12 @@ object ChannelRelayer {
Behaviors.receiveMessage { Behaviors.receiveMessage {
case Relay(channelRelayPacket, originNode) => case Relay(channelRelayPacket, originNode) =>
val relayId = UUID.randomUUID() val relayId = UUID.randomUUID()
val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId) case Left(walletNodeId) => Some(walletNodeId)
case None => None case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId)
case None => None
}
} }
val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match { val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match {
case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap

View File

@ -26,6 +26,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream}
import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.io.PeerReadyNotifier
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
@ -40,7 +41,7 @@ import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32, randomKey} import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32}
import java.util.UUID import java.util.UUID
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -62,7 +63,7 @@ object NodeRelay {
private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command
private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command
// @formatter:on // @formatter:on
@ -88,7 +89,6 @@ object NodeRelay {
relayId: UUID, relayId: UUID,
nodeRelayPacket: NodeRelayPacket, nodeRelayPacket: NodeRelayPacket,
outgoingPaymentFactory: OutgoingPaymentFactory, outgoingPaymentFactory: OutgoingPaymentFactory,
triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command],
router: ActorRef): Behavior[Command] = router: ActorRef): Behavior[Command] =
Behaviors.setup { context => Behaviors.setup { context =>
val paymentHash = nodeRelayPacket.add.paymentHash val paymentHash = nodeRelayPacket.add.paymentHash
@ -108,7 +108,7 @@ object NodeRelay {
case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket) case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket)
case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None
} }
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, triggerer, router) new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, router)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler) .receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler)
} }
} }
@ -125,14 +125,29 @@ object NodeRelay {
Some(InvalidOnionPayload(UInt64(2), 0)) Some(InvalidOnionPayload(UInt64(2), 0))
} else { } else {
payloadOut match { payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard => // If we're relaying a standard payment to a non-trampoline recipient, we need the payment secret.
if (payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty) { case payloadOut: IntermediatePayload.NodeRelay.Standard if payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty => Some(InvalidOnionPayload(UInt64(8), 0))
Some(InvalidOnionPayload(UInt64(8), 0)) // payment secret field is missing case _: IntermediatePayload.NodeRelay.Standard => None
} else { case _: IntermediatePayload.NodeRelay.ToBlindedPaths => None
None }
} }
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => }
None
/** This function identifies whether the next node is a wallet node directly connected to us, and returns its node_id. */
private def nextWalletNodeId(nodeParams: NodeParams, recipient: Recipient): Option[PublicKey] = {
recipient match {
// These two recipients are only used when we're the payment initiator.
case _: SpontaneousRecipient => None
case _: TrampolineRecipient => None
// When relaying to a trampoline node, the next node may be a wallet node directly connected to us, but we don't
// want to have false positives. Feature branches should check an internal DB/cache to confirm.
case r: ClearRecipient if r.nextTrampolineOnion_opt.nonEmpty => None
// If we're relaying to a non-trampoline recipient, it's never a wallet node.
case _: ClearRecipient => None
// When using blinded paths, we may be the introduction node for a wallet node directly connected to us.
case r: BlindedRecipient => r.blindedHops.head.resolved.route match {
case BlindedPathsResolver.PartialBlindedRoute(walletNodeId: EncodedNodeId.WithPublicKey.Wallet, _, _) => Some(walletNodeId.publicKey)
case _ => None
} }
} }
} }
@ -188,7 +203,6 @@ class NodeRelay private(nodeParams: NodeParams,
paymentSecret: ByteVector32, paymentSecret: ByteVector32,
context: ActorContext[NodeRelay.Command], context: ActorContext[NodeRelay.Command],
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory,
triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command],
router: ActorRef) { router: ActorRef) {
import NodeRelay._ import NodeRelay._
@ -223,44 +237,102 @@ class NodeRelay private(nodeParams: NodeParams,
rejectPayment(upstream, Some(failure)) rejectPayment(upstream, Some(failure))
stopping() stopping()
case None => case None =>
nextPayload match { resolveNextNode(upstream, nextPayload, nextPacket_opt)
// TODO: async payments are not currently supported for blinded recipients. We should update the AsyncPaymentTriggerer to decrypt the blinded path.
case nextPayload: IntermediatePayload.NodeRelay.Standard if nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype) =>
waitForTrigger(upstream, nextPayload, nextPacket_opt)
case _ =>
doSend(upstream, nextPayload, nextPacket_opt)
}
} }
} }
private def waitForTrigger(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { /** Once we've fully received the incoming HTLC set, we must identify the next node before forwarding the payment. */
context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})") private def resolveNextNode(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks nextPayload match {
val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight case payloadOut: IntermediatePayload.NodeRelay.Standard =>
// wait for notification until which ever occurs first: the hold timeout block or the safety block // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient.
val notifierTimeout = Seq(timeoutBlock, safetyBlock).min payloadOut.invoiceFeatures match {
val peerReadyResultAdapter = context.messageAdapter[AsyncPaymentTriggerer.Result](WrappedPeerReadyResult) case Some(features) =>
val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId))
triggerer ! AsyncPaymentTriggerer.Watch(peerReadyResultAdapter, nextPayload.outgoingNodeId, paymentHash, notifierTimeout) val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay
context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash)) val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata)
Behaviors.receiveMessagePartial { context.log.debug("forwarding payment to non-trampoline recipient {}", recipient.nodeId)
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTimeout) => ensureRecipientReady(upstream, recipient, nextPayload, None)
context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout) case None =>
rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
stopping() val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = nextPacket_opt)
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentCanceled) => context.log.debug("forwarding payment to the next trampoline node {}", recipient.nodeId)
context.log.warn(s"payment sender canceled a waiting async payment") ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt)
rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized }
stopping() case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths =>
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTriggered) => // Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to
doSend(upstream, nextPayload, nextPacket_opt) // resolve that to a nodeId in order to reach that introduction node and use the blinded path.
// If we are the introduction node ourselves, we'll also need to decrypt the onion and identify the next node.
context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths)
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case WrappedResolvedPaths(resolved) if resolved.isEmpty =>
context.log.warn("rejecting trampoline payment to blinded paths: no usable blinded path")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedResolvedPaths(resolved) =>
// We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient.
val blindedNodeId = resolved.head.route.blindedNodeIds.last
val recipient = BlindedRecipient.fromPaths(blindedNodeId, Features(payloadOut.invoiceFeatures).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
context.log.debug("forwarding payment to blinded recipient {}", recipient.nodeId)
ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt)
}
}
} }
} }
private def doSend(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { /**
context.log.debug(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})") * The next node may be a mobile wallet directly connected to us: in that case, we'll need to wake them up before
* relaying the payment.
*/
private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
nextWalletNodeId(nodeParams, recipient) match {
case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt)
case _ => relay(upstream, recipient, nextPayload, nextPacket_opt)
}
}
/**
* The next node is the payment recipient. They are directly connected to us and may be offline. We try to wake them
* up and will relay the payment once they're connected and channels are reestablished.
*/
private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.info("trying to wake up next peer (nodeId={})", walletNodeId)
val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
context.log.warn("rejecting payment: failed to wake-up remote peer")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) =>
relay(upstream, recipient, nextPayload, nextPacket_opt)
}
}
}
/** Relay the payment to the next identified node: this is similar to sending an outgoing payment. */
private def relay(upstream: Upstream.Hot.Trampoline, recipient: Recipient, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={})", upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8 val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8
relay(upstream, nextPayload, nextPacket_opt, confidence) val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, recipient.nodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
// If the next node is using trampoline, we assume that they support MPP.
val useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment) || packetOut_opt.nonEmpty
val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic
val payment = if (useMultiPart) {
SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
} else {
SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
}
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart)
payFSM ! payment
sending(upstream, payloadOut, recipient, TimestampMilli.now(), fulfilledUpstream = false)
} }
/** /**
@ -270,7 +342,11 @@ class NodeRelay private(nodeParams: NodeParams,
* @param nextPayload relay instructions. * @param nextPayload relay instructions.
* @param fulfilledUpstream true if we already fulfilled the payment upstream. * @param fulfilledUpstream true if we already fulfilled the payment upstream.
*/ */
private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = private def sending(upstream: Upstream.Hot.Trampoline,
nextPayload: IntermediatePayload.NodeRelay,
recipient: Recipient,
startedAt: TimestampMilli,
fulfilledUpstream: Boolean): Behavior[Command] =
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse { rejectExtraHtlcPartialFunction orElse {
// this is the fulfill that arrives from downstream channels // this is the fulfill that arrives from downstream channels
@ -279,7 +355,7 @@ class NodeRelay private(nodeParams: NodeParams,
// We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream).
context.log.debug("got preimage from downstream") context.log.debug("got preimage from downstream")
fulfillPayment(upstream, paymentPreimage) fulfillPayment(upstream, paymentPreimage)
sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) sending(upstream, nextPayload, recipient, startedAt, fulfilledUpstream = true)
} else { } else {
// we don't want to fulfill multiple times // we don't want to fulfill multiple times
Behaviors.same Behaviors.same
@ -311,80 +387,6 @@ class NodeRelay private(nodeParams: NodeParams,
} }
} }
private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic
private def relay(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket], confidence: Double): Behavior[Command] = {
val displayNodeId = payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => randomKey().publicKey
}
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard =>
// If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient.
payloadOut.invoiceFeatures match {
case Some(features) =>
val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId))
val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata)
context.log.debug("sending the payment to non-trampoline recipient (MPP={})", recipient.features.hasFeature(Features.BasicMultiPartPayment))
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment))
case None =>
context.log.debug("sending the payment to the next trampoline node")
val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = packetOut_opt)
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true)
}
case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths =>
context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths)
waitForResolvedPaths(upstream, payloadOut, paymentCfg, routeParams)
}
}
private def relayToRecipient(upstream: Upstream.Hot.Trampoline,
payloadOut: IntermediatePayload.NodeRelay,
recipient: Recipient,
paymentCfg: SendPaymentConfig,
routeParams: RouteParams,
useMultiPart: Boolean): Behavior[Command] = {
val payment =
if (useMultiPart) {
SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
} else {
SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
}
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart)
payFSM ! payment
sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false)
}
/**
* Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to resolve
* that to a nodeId in order to reach that introduction node and use the blinded path.
*/
private def waitForResolvedPaths(upstream: Upstream.Hot.Trampoline,
payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths,
paymentCfg: SendPaymentConfig,
routeParams: RouteParams): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedResolvedPaths(resolved) if resolved.isEmpty =>
context.log.warn(s"rejecting trampoline payment to blinded paths: no usable blinded path")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedResolvedPaths(resolved) =>
val features = Features(payloadOut.invoiceFeatures).invoiceFeatures()
// We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient.
val blindedNodeId = resolved.head.route.blindedNodeIds.last
val recipient = BlindedRecipient.fromPaths(blindedNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment))
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment))
}
private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = { private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = {
case Relay(nodeRelayPacket, _) => case Relay(nodeRelayPacket, _) =>
rejectExtraHtlc(nodeRelayPacket.add) rejectExtraHtlc(nodeRelayPacket.add)

View File

@ -16,7 +16,6 @@
package fr.acinq.eclair.payment.relay package fr.acinq.eclair.payment.relay
import akka.actor.typed
import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior} import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.ByteVector32
@ -58,7 +57,7 @@ object NodeRelayer {
* NB: the payment secret used here is different from the invoice's payment secret and ensures we can * NB: the payment secret used here is different from the invoice's payment secret and ensures we can
* group together HTLCs that the previous trampoline node sent in the same MPP. * group together HTLCs that the previous trampoline node sent in the same MPP.
*/ */
def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
Behaviors.setup { context => Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) { Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
Behaviors.receiveMessage { Behaviors.receiveMessage {
@ -73,15 +72,15 @@ object NodeRelayer {
case None => case None =>
val relayId = UUID.randomUUID() val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId") context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString) val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, router), relayId.toString)
context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId) context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket, originNode) handler ! NodeRelay.Relay(nodeRelayPacket, originNode)
apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler)) apply(nodeParams, register, outgoingPaymentFactory, router, children + (childKey -> handler))
} }
case RelayComplete(childHandler, paymentHash, paymentSecret) => case RelayComplete(childHandler, paymentHash, paymentSecret) =>
// we do a back-and-forth between parent and child before stopping the child to prevent a race condition // we do a back-and-forth between parent and child before stopping the child to prevent a race condition
childHandler ! NodeRelay.Stop childHandler ! NodeRelay.Stop
apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children - PaymentKey(paymentHash, paymentSecret)) apply(nodeParams, register, outgoingPaymentFactory, router, children - PaymentKey(paymentHash, paymentSecret))
case GetPendingPayments(replyTo) => case GetPendingPayments(replyTo) =>
replyTo ! children replyTo ! children
Behaviors.same Behaviors.same

View File

@ -49,7 +49,7 @@ import scala.util.Random
* It also receives channel HTLC events (fulfill / failed) and relays those to the appropriate handlers. * It also receives channel HTLC events (fulfill / failed) and relays those to the appropriate handlers.
* It also maintains an up-to-date view of local channel balances. * It also maintains an up-to-date view of local channel balances.
*/ */
class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging { class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging {
import Relayer._ import Relayer._
@ -58,7 +58,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym
private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner") private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner")
private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer") private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), triggerer, router)).onFailure(SupervisorStrategy.resume), name = "node-relayer") private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), router)).onFailure(SupervisorStrategy.resume), name = "node-relayer")
def receive: Receive = { def receive: Receive = {
case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init
@ -120,8 +120,8 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym
object Relayer extends Logging { object Relayer extends Logging {
def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None): Props = def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None): Props =
Props(new Relayer(nodeParams, router, register, paymentHandler, triggerer, initialized)) Props(new Relayer(nodeParams, router, register, paymentHandler, initialized))
// @formatter:off // @formatter:off
case class RelayFees(feeBase: MilliSatoshi, feeProportionalMillionths: Long) { case class RelayFees(feeBase: MilliSatoshi, feeProportionalMillionths: Long) {

View File

@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData
import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs} import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs}
import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams} import fr.acinq.eclair.{EncodedNodeId, Logs, MilliSatoshiLong, NodeParams, ShortChannelId}
import scodec.bits.ByteVector import scodec.bits.ByteVector
import scala.annotation.tailrec import scala.annotation.tailrec
@ -45,8 +45,8 @@ object BlindedPathsResolver {
override val firstNodeId: PublicKey = introductionNodeId override val firstNodeId: PublicKey = introductionNodeId
} }
/** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */ /** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */
case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute {
override val firstNodeId: PublicKey = nextNodeId override val firstNodeId: PublicKey = nextNodeId.publicKey
} }
// @formatter:on // @formatter:on
@ -111,8 +111,14 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
feeProportionalMillionths = nextFeeProportionalMillionths, feeProportionalMillionths = nextFeeProportionalMillionths,
cltvExpiryDelta = nextCltvExpiryDelta cltvExpiryDelta = nextCltvExpiryDelta
) )
register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId) paymentRelayData.outgoing match {
waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) case Left(outgoingNodeId) =>
// The next node seems to be a wallet node directly connected to us.
validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
case Right(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId)
waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
}
} }
} }
case encodedNodeId: EncodedNodeId.WithPublicKey => case encodedNodeId: EncodedNodeId.WithPublicKey =>
@ -129,7 +135,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
} }
/** Resolve the next node in the blinded path when we are the introduction node. */ /** Resolve the next node in the blinded path when we are the introduction node. */
private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo, private def waitForNextNodeId(outgoingChannelId: ShortChannelId,
nextPaymentInfo: OfferTypes.PaymentInfo,
paymentRelayData: BlindedRouteData.PaymentRelayData, paymentRelayData: BlindedRouteData.PaymentRelayData,
nextBlinding: PublicKey, nextBlinding: PublicKey,
nextBlindedNodes: Seq[RouteBlinding.BlindedNode], nextBlindedNodes: Seq[RouteBlinding.BlindedNode],
@ -137,29 +144,42 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
resolved: Seq[ResolvedPath]): Behavior[Command] = resolved: Seq[ResolvedPath]): Behavior[Command] =
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {
case WrappedNodeId(None) => case WrappedNodeId(None) =>
context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId) context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId)
resolveBlindedPaths(toResolve, resolved) resolveBlindedPaths(toResolve, resolved)
case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId => case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId =>
// The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice. // The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice.
context.log.warn("ignoring blinded path starting at our node relaying to ourselves") context.log.warn("ignoring blinded path starting at our node relaying to ourselves")
resolveBlindedPaths(toResolve, resolved) resolveBlindedPaths(toResolve, resolved)
case WrappedNodeId(Some(nodeId)) => case WrappedNodeId(Some(nodeId)) =>
// Note that we default to private fees if we don't have a channel yet with that node. validateRelay(EncodedNodeId.WithPublicKey.Plain(nodeId), nextPaymentInfo, paymentRelayData, nextBlinding, nextBlindedNodes, toResolve, resolved)
// The announceChannel parameter is ignored if we already have a channel.
val relayFees = getRelayFees(nodeParams, nodeId, announceChannel = false)
val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase &&
paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths &&
paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta
if (shouldRelay) {
context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId)
val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo)
resolveBlindedPaths(toResolve, resolved :+ path)
} else {
context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta)
resolveBlindedPaths(toResolve, resolved)
}
} }
private def validateRelay(nextNodeId: EncodedNodeId.WithPublicKey,
nextPaymentInfo: OfferTypes.PaymentInfo,
paymentRelayData: BlindedRouteData.PaymentRelayData,
nextBlinding: PublicKey,
nextBlindedNodes: Seq[RouteBlinding.BlindedNode],
toResolve: Seq[PaymentBlindedRoute],
resolved: Seq[ResolvedPath]): Behavior[Command] = {
// Note that we default to private fees if we don't have a channel yet with that node.
// The announceChannel parameter is ignored if we already have a channel.
val relayFees = getRelayFees(nodeParams, nextNodeId.publicKey, announceChannel = false)
val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase &&
paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths &&
paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta &&
nextPaymentInfo.feeBase >= 0.msat &&
nextPaymentInfo.feeProportionalMillionths >= 0 &&
nextPaymentInfo.cltvExpiryDelta.toInt >= 0
if (shouldRelay) {
context.log.debug("unwrapped blinded path starting at our node: next_node={}", nextNodeId.publicKey)
val path = ResolvedPath(PartialBlindedRoute(nextNodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo)
resolveBlindedPaths(toResolve, resolved :+ path)
} else {
context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta)
resolveBlindedPaths(toResolve, resolved)
}
}
/** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */ /** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */
private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] = private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] =
Behaviors.receiveMessagePartial { Behaviors.receiveMessagePartial {

View File

@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.router.Router.ChannelHop
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, randomKey}
import scodec.bits.ByteVector import scodec.bits.ByteVector
object BlindedRouteCreation { object BlindedRouteCreation {
@ -77,7 +77,7 @@ object BlindedRouteCreation {
Total: 24 to 36 bytes Total: 24 to 36 bytes
*/ */
val targetLength = 36 val targetLength = 36
val paddedPayloads = payloads.map(tlvs =>{ val paddedPayloads = payloads.map(tlvs => {
val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length
tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0))) tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0)))
}) })
@ -95,4 +95,19 @@ object BlindedRouteCreation {
Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload))
} }
/** Create a blinded route where the recipient is a wallet node. */
def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta
val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount),
RouteBlindingEncryptedDataTlv.PathId(pathId),
)).require.bytes
val intermediatePayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(hop.nextNodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(hop.cltvExpiryDelta, hop.params.relayFees.feeProportionalMillionths, hop.params.relayFees.feeBase),
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount),
)).require.bytes
Sphinx.RouteBlinding.create(randomKey(), Seq(hop.nodeId, hop.nextNodeId), Seq(intermediatePayload, finalPayload))
}
} }

View File

@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData
import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs._ import fr.acinq.eclair.wire.protocol.TlvCodecs._
import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64}
import scodec.bits.{BitVector, ByteVector} import scodec.bits.{BitVector, ByteVector}
/** /**
@ -227,7 +227,8 @@ object PaymentOnion {
object IntermediatePayload { object IntermediatePayload {
sealed trait ChannelRelay extends IntermediatePayload { sealed trait ChannelRelay extends IntermediatePayload {
// @formatter:off // @formatter:off
def outgoingChannelId: ShortChannelId /** The outgoing channel, or the nodeId of one of our peers. */
def outgoing: Either[PublicKey, ShortChannelId]
def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi
def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry
// @formatter:on // @formatter:on
@ -238,7 +239,7 @@ object PaymentOnion {
// @formatter:off // @formatter:off
val amountOut = records.get[AmountToForward].get.amount val amountOut = records.get[AmountToForward].get.amount
val cltvOut = records.get[OutgoingCltv].get.cltv val cltvOut = records.get[OutgoingCltv].get.cltv
override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId)
override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut
override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut
// @formatter:on // @formatter:on
@ -258,12 +259,12 @@ object PaymentOnion {
} }
/** /**
* @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv. * @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv.
* @param nextBlinding blinding point that must be forwarded to the next hop. * @param nextBlinding blinding point that must be forwarded to the next hop.
*/ */
case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay { case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay {
// @formatter:off // @formatter:off
override val outgoingChannelId = paymentRelayData.outgoingChannelId override val outgoing = paymentRelayData.outgoing
override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount)
override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv) override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv)
// @formatter:on // @formatter:on

View File

@ -98,7 +98,11 @@ object BlindedRouteData {
} }
case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) {
val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId // This is usually a channel, unless the next node is a mobile wallet connected to our node.
val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match {
case Some(r) => Right(r.shortChannelId)
case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey)
}
val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get
val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get
val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty)
@ -110,7 +114,9 @@ object BlindedRouteData {
} }
def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = {
if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) // Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id.
if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4)))
if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10)))
if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12)))
if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6)))

View File

@ -26,7 +26,7 @@ import fr.acinq.eclair.channel.{ChannelFlags, LocalParams, Origin, Upstream}
import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager}
import fr.acinq.eclair.db.RevokedHtlcInfoCleaner import fr.acinq.eclair.db.RevokedHtlcInfoCleaner
import fr.acinq.eclair.io.MessageRelay.RelayAll import fr.acinq.eclair.io.MessageRelay.RelayAll
import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection} import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection, PeerReadyNotifier}
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios} import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios}
@ -231,7 +231,8 @@ object TestConstants {
maxAttempts = 2, maxAttempts = 2,
), ),
purgeInvoicesInterval = None, purgeInvoicesInterval = None,
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
) )
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(
@ -401,7 +402,8 @@ object TestConstants {
maxAttempts = 2, maxAttempts = 2,
), ),
purgeInvoicesInterval = None, purgeInvoicesInterval = None,
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
) )
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(

View File

@ -66,8 +66,8 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe
val bobRegister = system.actorOf(Props(new TestRegister())) val bobRegister = system.actorOf(Props(new TestRegister()))
val alicePaymentHandler = system.actorOf(Props(new PaymentHandler(aliceParams, aliceRegister, TestProbe().ref))) val alicePaymentHandler = system.actorOf(Props(new PaymentHandler(aliceParams, aliceRegister, TestProbe().ref)))
val bobPaymentHandler = system.actorOf(Props(new PaymentHandler(bobParams, bobRegister, TestProbe().ref))) val bobPaymentHandler = system.actorOf(Props(new PaymentHandler(bobParams, bobRegister, TestProbe().ref)))
val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler, TestProbe().ref)) val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler))
val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler, TestProbe().ref)) val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler))
val wallet = new DummyOnChainWallet() val wallet = new DummyOnChainWallet()
val alice: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(aliceParams, wallet, bobParams.nodeId, alice2blockchain.ref, aliceRelayer, FakeTxPublisherFactory(alice2blockchain)), alicePeer.ref) val alice: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(aliceParams, wallet, bobParams.nodeId, alice2blockchain.ref, aliceRelayer, FakeTxPublisherFactory(alice2blockchain)), alicePeer.ref)
val bob: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(bobParams, wallet, aliceParams.nodeId, bob2blockchain.ref, bobRelayer, FakeTxPublisherFactory(bob2blockchain)), bobPeer.ref) val bob: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(bobParams, wallet, aliceParams.nodeId, bob2blockchain.ref, bobRelayer, FakeTxPublisherFactory(bob2blockchain)), bobPeer.ref)

View File

@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding
val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol) val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol)
assert(payloadBob.outgoingChannelId == ShortChannelId(1)) assert(payloadBob.outgoing.contains(ShortChannelId(1)))
assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat) assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat)
assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100)) assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100))
assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat)) assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat))
@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding
val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave) val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave)
assert(payloadCarol.outgoingChannelId == ShortChannelId(2)) assert(payloadCarol.outgoing.contains(ShortChannelId(2)))
assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat) assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat)
assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025)) assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025))
assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat)) assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat))
@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding
val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve) val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve)
assert(payloadDave.outgoingChannelId == ShortChannelId(3)) assert(payloadDave.outgoing.contains(ShortChannelId(3)))
assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat) assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat)
assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000)) assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000))
assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat)) assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat))

View File

@ -90,13 +90,12 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat
val bitcoinClient = new TestBitcoinCoreClient() val bitcoinClient = new TestBitcoinCoreClient()
val wallet = new SingleKeyOnChainWallet() val wallet = new SingleKeyOnChainWallet()
val watcher = TestProbe("watcher") val watcher = TestProbe("watcher")
val triggerer = TestProbe("payment-triggerer")
val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command] val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command]
val register = system.actorOf(Register.props(), "register") val register = system.actorOf(Register.props(), "register")
val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router") val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router")
val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager") val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager")
val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler") val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler")
val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler, triggerer.ref.toTyped), "relayer") val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler), "relayer")
val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient) val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient)
val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory) val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory)
val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume)) val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume))

View File

@ -19,8 +19,10 @@ package fr.acinq.eclair.io
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe} import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe}
import akka.actor.typed.ActorRef import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps}
import akka.testkit.TestProbe import akka.testkit.TestProbe
import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair.TestConstants.{Alice, Bob}
@ -33,8 +35,8 @@ import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient}
import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey} import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey}
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.HexStringSyntax import scodec.bits.HexStringSyntax
import scala.concurrent.duration.DurationInt import scala.concurrent.duration.DurationInt
@ -43,19 +45,30 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
val aliceId: PublicKey = Alice.nodeParams.nodeId val aliceId: PublicKey = Alice.nodeParams.nodeId
val bobId: PublicKey = Bob.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId
case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], peerReadyManager: TestProbe, probe: TypedProbe[Status])
override def withFixture(test: OneArgTest): Outcome = { override def withFixture(test: OneArgTest): Outcome = {
val peerReadyManager = TestProbe("peer-ready-manager")(system.classicSystem)
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped)
val switchboard = TestProbe("switchboard")(system.classicSystem) val switchboard = TestProbe("switchboard")(system.classicSystem)
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped)
val register = TestProbe("register")(system.classicSystem) val register = TestProbe("register")(system.classicSystem)
val router = TypedProbe[Router.GetNodeId]("router") val router = TypedProbe[Router.GetNodeId]("router")
val peerConnection = TypedProbe[Nothing]("peerConnection") val peerConnection = TypedProbe[Nothing]("peerConnection")
val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val peer = TypedProbe[Peer.RelayOnionMessage]("peer")
val probe = TypedProbe[Status]("probe") val probe = TypedProbe[Status]("probe")
val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref)) val nodeParams = Alice.nodeParams
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref))
try { try {
withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, peerReadyManager, probe)))
} finally { } finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped)
testKit.stop(relay) testKit.stop(relay)
} }
} }
@ -86,6 +99,23 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message)
} }
test("relay after waking up next node", Tag(wakeUpEnabled)) { f =>
import f._
val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty)
val messageId = randomBytes32()
relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None)
val register = peerReadyManager.expectMsgType[PeerReadyManager.Register]
assert(register.remoteNodeId == bobId)
register.replyTo ! PeerReadyManager.Registered(bobId, otherAttempts = 0)
val request = switchboard.expectMsgType[GetPeerInfo]
assert(request.remoteNodeId == bobId)
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty)
assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message)
}
test("can't open new connection") { f => test("can't open new connection") { f =>
import f._ import f._
@ -99,6 +129,15 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound)) probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound))
} }
test("can't wake up next node", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty)
val messageId = randomBytes32()
relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, Some(probe.ref))
probe.expectMessage(Disconnected(messageId))
}
test("no channel with previous node") { f => test("no channel with previous node") { f =>
import f._ import f._

View File

@ -0,0 +1,55 @@
/*
* Copyright 2024 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.io
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.randomKey
import org.scalatest.funsuite.AnyFunSuiteLike
class PeerReadyManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with AnyFunSuiteLike {
test("watch pending notifiers") {
val manager = testKit.spawn(PeerReadyManager())
val remoteNodeId1 = randomKey().publicKey
val notifier1a = TestProbe[PeerReadyManager.Registered]()
val notifier1b = TestProbe[PeerReadyManager.Registered]()
manager ! PeerReadyManager.Register(notifier1a.ref, remoteNodeId1)
assert(notifier1a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
manager ! PeerReadyManager.Register(notifier1b.ref, remoteNodeId1)
assert(notifier1b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 1)
val remoteNodeId2 = randomKey().publicKey
val notifier2a = TestProbe[PeerReadyManager.Registered]()
val notifier2b = TestProbe[PeerReadyManager.Registered]()
// Later attempts aren't affected by previously completed attempts.
manager ! PeerReadyManager.Register(notifier2a.ref, remoteNodeId2)
assert(notifier2a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
notifier2a.stop()
val probe = TestProbe[Set[PublicKey]]()
probe.awaitAssert({
manager ! PeerReadyManager.List(probe.ref)
assert(probe.expectMessageType[Set[PublicKey]] == Set(remoteNodeId1))
})
manager ! PeerReadyManager.Register(notifier2b.ref, remoteNodeId2)
assert(notifier2b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
}
}

View File

@ -33,17 +33,20 @@ import scala.concurrent.duration.DurationInt
class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {
case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result]) case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result])
override def withFixture(test: OneArgTest): Outcome = { override def withFixture(test: OneArgTest): Outcome = {
val remoteNodeId = randomKey().publicKey val remoteNodeId = randomKey().publicKey
val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager")
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard")
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val peer = TestProbe[Peer.GetPeerChannels]("peer") val peer = TestProbe[Peer.GetPeerChannels]("peer")
val probe = TestProbe[PeerReadyNotifier.Result]() val probe = TestProbe[PeerReadyNotifier.Result]()
try { try {
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe))) withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe)))
} finally { } finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
} }
} }
@ -53,7 +56,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis)))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
probe.expectMessage(PeerUnavailable(remoteNodeId)) probe.expectMessage(PeerUnavailable(remoteNodeId))
} }
@ -62,6 +65,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId)
// We haven't reached the timeout yet. // We haven't reached the timeout yet.
@ -78,6 +82,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set.empty) request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set.empty)
probe.expectMessage(PeerReadyNotifier.PeerReady(remoteNodeId, peer.ref.toClassic, Seq.empty)) probe.expectMessage(PeerReadyNotifier.PeerReady(remoteNodeId, peer.ref.toClassic, Seq.empty))
@ -88,6 +93,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic))
@ -115,6 +121,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic))
peer.expectNoMessage(100 millis) peer.expectNoMessage(100 millis)
@ -137,6 +144,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = None)) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = None))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerNotFound(remoteNodeId) request1.replyTo ! Peer.PeerNotFound(remoteNodeId)
peer.expectNoMessage(100 millis) peer.expectNoMessage(100 millis)
@ -161,6 +169,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 5)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set.empty) request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set.empty)
peer.expectNoMessage(100 millis) peer.expectNoMessage(100 millis)
@ -185,6 +194,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(1 second)))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(1 second))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic))
peer.expectMessageType[Peer.GetPeerChannels] peer.expectMessageType[Peer.GetPeerChannels]
@ -196,6 +206,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100)))))
notifier ! NotifyWhenPeerReady(probe.ref) notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 2)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic))
peer.expectMessageType[Peer.GetPeerChannels] peer.expectMessageType[Peer.GetPeerChannels]

View File

@ -85,7 +85,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward == amount_bc) assert(relay_b.amountToForward == amount_bc)
assert(relay_b.outgoingCltv == expiry_bc) assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
@ -95,7 +95,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_c.amountToForward == amount_cd) assert(relay_c.amountToForward == amount_cd)
assert(relay_c.outgoingCltv == expiry_cd) assert(relay_c.outgoingCltv == expiry_cd)
assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId))
assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.relayFeeMsat == fee_c)
assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta)
@ -105,7 +105,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_d.amountToForward == amount_de) assert(relay_d.amountToForward == amount_de)
assert(relay_d.outgoingCltv == expiry_de) assert(relay_d.outgoingCltv == expiry_de)
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.relayFeeMsat == fee_d)
assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta)
@ -175,7 +175,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward >= amount_bc) assert(relay_b.amountToForward >= amount_bc)
assert(relay_b.outgoingCltv == expiry_bc) assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard])
@ -185,7 +185,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_c.amountToForward >= amount_cd) assert(relay_c.amountToForward >= amount_cd)
assert(relay_c.outgoingCltv == expiry_cd) assert(relay_c.outgoingCltv == expiry_cd)
assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId))
assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.relayFeeMsat == fee_c)
assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta)
assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
@ -196,7 +196,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_d.amountToForward >= amount_de) assert(relay_d.amountToForward >= amount_de)
assert(relay_d.outgoingCltv == expiry_de) assert(relay_d.outgoingCltv == expiry_de)
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.relayFeeMsat == fee_d)
assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta)
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
@ -238,7 +238,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward >= amount_bc) assert(relay_b.amountToForward >= amount_bc)
assert(relay_b.outgoingCltv == expiry_bc) assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard])
@ -547,7 +547,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
// A smaller amount is sent to d, who doesn't know that it's invalid. // A smaller amount is sent to d, who doesn't know that it's invalid.
val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0)
val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional))
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.amountToForward < amount_de) assert(relay_d.amountToForward < amount_de)
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding
@ -569,7 +569,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1) val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1)
val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0)
val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional))
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount)) assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount))
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding

View File

@ -57,7 +57,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
case class FixtureParam(nodeParams: NodeParams, register: TestProbe, sender: TestProbe, eventListener: TestProbe) { case class FixtureParam(nodeParams: NodeParams, register: TestProbe, sender: TestProbe, eventListener: TestProbe) {
def createRelayer(nodeParams1: NodeParams): (ActorRef, ActorRef) = { def createRelayer(nodeParams1: NodeParams): (ActorRef, ActorRef) = {
val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref.toTyped)) val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref))
// we need ensure the post-htlc-restart child actor is initialized // we need ensure the post-htlc-restart child actor is initialized
sender.send(relayer, Relayer.GetChildActors(sender.ref)) sender.send(relayer, Relayer.GetChildActors(sender.ref))
(relayer, sender.expectMsgType[Relayer.ChildActors].postRestartCleaner) (relayer, sender.expectMsgType[Relayer.ChildActors].postRestartCleaner)

View File

@ -1,17 +1,18 @@
package fr.acinq.eclair.payment.relay package fr.acinq.eclair.payment.relay
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.{ActorRef, Behavior}
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.channel.NEGOTIATING import fr.acinq.eclair.channel.NEGOTIATING
import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard} import fr.acinq.eclair.io.{Peer, PeerConnected, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._ import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._
import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey} import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey}
import org.scalatest.Outcome import org.scalatest.Outcome
@ -23,8 +24,20 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command])
object DummyPeerReadyManager {
def apply(): Behavior[PeerReadyManager.Command] = {
Behaviors.receiveMessagePartial {
case PeerReadyManager.Register(replyTo, remoteNodeId) =>
replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
Behaviors.same
}
}
}
override def withFixture(test: OneArgTest): Outcome = { override def withFixture(test: OneArgTest): Outcome = {
val remoteNodeId = TestConstants.Alice.nodeParams.nodeId val remoteNodeId = TestConstants.Alice.nodeParams.nodeId
val peerReadyManager = testKit.spawn(DummyPeerReadyManager())
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager)
val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard")
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val peer = TestProbe[Peer.GetPeerChannels]("peer") val peer = TestProbe[Peer.GetPeerChannels]("peer")
@ -33,6 +46,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
try { try {
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer))) withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer)))
} finally { } finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
} }
} }
@ -170,7 +184,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
val probe2 = TestProbe[Result]() val probe2 = TestProbe[Result]()
triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101)) triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101))
val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo] val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId2, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic))
// First remote node times out // First remote node times out
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100))) system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100)))
@ -192,6 +206,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
test("triggerer treats an unexpected stop of the notifier as a cancel") { f => test("triggerer treats an unexpected stop of the notifier as a cancel") { f =>
import f._ import f._
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId)

View File

@ -19,6 +19,7 @@ package fr.acinq.eclair.payment.relay
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed import akka.actor.typed
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import com.softwaremill.quicklens.ModifyPimp import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
@ -29,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket
import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.blockchain.fee.FeeratePerKw
import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket
import fr.acinq.eclair.payment.relay.ChannelRelayer._ import fr.acinq.eclair.payment.relay.ChannelRelayer._
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec}
@ -39,19 +41,26 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload.ChannelRel
import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _} import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _}
import org.scalatest.Inside.inside import org.scalatest.Inside.inside
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.HexStringSyntax import scodec.bits.HexStringSyntax
import scala.concurrent.duration.DurationInt
class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {
import ChannelRelayerSpec._ import ChannelRelayerSpec._
val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any])
override def withFixture(test: OneArgTest): Outcome = { override def withFixture(test: OneArgTest): Outcome = {
// we are node B in the route A -> B -> C -> .... // we are node B in the route A -> B -> C -> ....
val nodeParams = TestConstants.Bob.nodeParams val nodeParams = TestConstants.Bob.nodeParams
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val register = TestProbe[Any]("register") val register = TestProbe[Any]("register")
val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic))
try { try {
@ -157,7 +166,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
import f._ import f._
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
val payload = createBlindedPayload(u.channelUpdate, isIntroduction = false) val payload = createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction = false)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! WrappedLocalChannelUpdate(u)
@ -166,6 +175,34 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7)
} }
test("relay blinded payment (wake up wallet node)", Tag(wakeUpEnabled)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
Seq(true, false).foreach(isIntroduction => {
val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
// We try to wake-up the next node.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo]
assert(wakeUp.remoteNodeId == outgoingNodeId)
wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty)
expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7)
})
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
test("relay with retries") { f => test("relay with retries") { f =>
import f._ import f._
@ -270,7 +307,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
Seq(true, false).foreach { isIntroduction => Seq(true, false).foreach { isIntroduction =>
// The outgoing channel is disabled, so we won't be able to relay the payment. // The outgoing channel is disabled, so we won't be able to relay the payment.
val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false) val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false)
val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
@ -293,6 +330,31 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
} }
} }
test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = true)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
// We try to wake-up the next node, but we timeout before they connect.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId)
val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket))))
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
test("relay when expiry larger than our requirements") { f => test("relay when expiry larger than our requirements") { f =>
import f._ import f._
@ -519,7 +581,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
Seq(true, false).foreach { isIntroduction => Seq(true, false).foreach { isIntroduction =>
testCases.foreach { htlcResult => testCases.foreach { htlcResult =>
val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0) val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0)
channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry, 0) val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry, 0)
@ -653,13 +715,16 @@ object ChannelRelayerSpec {
localAlias2 -> channelId2, localAlias2 -> channelId2,
) )
def createBlindedPayload(update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { def createBlindedPayload(outgoing: Either[PublicKey, ShortChannelId], update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = {
val tlvs = TlvStream[OnionPaymentPayloadTlv](Set( val tlvs = TlvStream[OnionPaymentPayloadTlv](Set(
Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")), Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")),
if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None, if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None,
).flatten[OnionPaymentPayloadTlv]) ).flatten[OnionPaymentPayloadTlv])
val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingChannelId(update.shortChannelId), outgoing match {
case Left(nodeId) => RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeId))
case Right(scid) => RouteBlindingEncryptedDataTlv.OutgoingChannelId(scid)
},
RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat), RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat),
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat),
) )

View File

@ -20,35 +20,37 @@ import akka.actor.Status
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed.ActorRef import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.ActorContext import akka.actor.typed.scaladsl.ActorContext
import akka.actor.typed.scaladsl.adapter._ import akka.actor.typed.scaladsl.adapter._
import com.softwaremill.quicklens.ModifyPimp import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto}
import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir
import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream}
import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop
import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket}
import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.Invoice.ExtraEdge
import fr.acinq.eclair.payment.OutgoingPaymentPacket.NodePayload
import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.{AsyncPaymentCanceled, AsyncPaymentTimeout, AsyncPaymentTriggered, Watch}
import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment}
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode
import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient} import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient}
import fr.acinq.eclair.router.Router.RouteRequest import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, RouteRequest}
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router} import fr.acinq.eclair.router.{BalanceTooLow, BlindedRouteCreation, RouteNotFound, Router}
import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.OfferTypes._
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload}
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints}
import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} import fr.acinq.eclair.{Alias, BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes32, randomKey}
import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag} import org.scalatest.{Outcome, Tag}
import scodec.bits.{ByteVector, HexStringSyntax} import scodec.bits.{ByteVector, HexStringSyntax}
@ -65,11 +67,14 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
import NodeRelayerSpec._ import NodeRelayerSpec._
case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent], triggerer: TestProbe[AsyncPaymentTriggerer.Command]) { val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) {
def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = {
val parent = TestProbe[NodeRelayer.Command]("parent-relayer") val parent = TestProbe[NodeRelayer.Command]("parent-relayer")
val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this) val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this)
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, router.ref.toClassic))
(nodeRelay, parent) (nodeRelay, parent)
} }
} }
@ -92,21 +97,21 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
override def withFixture(test: OneArgTest): Outcome = { override def withFixture(test: OneArgTest): Outcome = {
val nodeParams = TestConstants.Bob.nodeParams val nodeParams = TestConstants.Bob.nodeParams
.modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.multiPartPaymentExpiry).setTo(5 seconds)
.modify(_.features).setToIf(test.tags.contains("async_payments"))(Features(AsyncPaymentPrototype -> Optional))
.modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val router = TestProbe[Any]("router") val router = TestProbe[Any]("router")
val register = TestProbe[Any]("register") val register = TestProbe[Any]("register")
val eventListener = TestProbe[PaymentEvent]("event-listener") val eventListener = TestProbe[PaymentEvent]("event-listener")
system.eventStream ! EventStream.Subscribe(eventListener.ref) system.eventStream ! EventStream.Subscribe(eventListener.ref)
val mockPayFSM = TestProbe[Any]("pay-fsm") val mockPayFSM = TestProbe[Any]("pay-fsm")
val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer") withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener)))
withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener, triggerer)))
} }
test("create child handlers for new payments") { f => test("create child handlers for new payments") { f =>
import f._ import f._
val probe = TestProbe[Any]() val probe = TestProbe[Any]()
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), triggerer.ref, router.ref.toClassic)) val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), router.ref.toClassic))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(Map.empty) probe.expectMessage(Map.empty)
@ -145,7 +150,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f) val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f)
{ {
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(Map.empty) probe.expectMessage(Map.empty)
} }
@ -153,7 +158,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val (paymentHash1, paymentSecret1, child1) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentHash1, paymentSecret1, child1) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]())
val (paymentHash2, paymentSecret2, child2) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentHash2, paymentSecret2, child2) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]())
val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref) val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref)
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(children) probe.expectMessage(children)
@ -169,7 +174,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val (paymentSecret1, child1) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentSecret1, child1) = (randomBytes32(), TestProbe[NodeRelay.Command]())
val (paymentSecret2, child2) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentSecret2, child2) = (randomBytes32(), TestProbe[NodeRelay.Command]())
val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref) val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref)
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(children) probe.expectMessage(children)
@ -179,7 +184,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref)) probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref))
} }
{ {
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic))
parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey) parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey)
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]]
@ -228,7 +233,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket) createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey) nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey)
// the extra payment will be rejected // the extra payment will be rejected
@ -257,7 +262,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket) createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey) nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey)
val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
@ -270,7 +275,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None), PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId), IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket) createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey) nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey)
val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
@ -335,115 +340,6 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
register.expectNoMessage(100 millis) register.expectNoMessage(100 millis)
} }
test("fail to relay when not triggered before the hold timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish notification that peer is unavailable at the timeout height
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) < asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncTimeoutHeight(nodeParams))
peerWatch.replyTo ! AsyncPaymentTimeout
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment when triggered while waiting", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish notification that peer is ready before the safety interval before the current incoming payment expires (and before the timeout height)
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams))
peerWatch.replyTo ! AsyncPaymentTriggered
// upstream payment relayed
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
validateOutgoingPayment(outgoingPayment)
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
// A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream.
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
// Once all the downstream payments have settled, we should emit the relayed event.
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)).toSet == incomingAsyncPayment.map(i => (i.add.amountMsat, i.add.channelId)).toSet)
assert(relayEvent.outgoing.nonEmpty)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay when not triggered before the incoming expiry safety timeout", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment
// publish notification that peer is unavailable at the cancel-safety-before-timeout-block threshold before the current incoming payment expires (and before the timeout height)
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams))
peerWatch.replyTo ! AsyncPaymentTimeout
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("fail to relay payment when canceled by sender before timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// fail the payment if waiting when payment sender sends cancel message
nodeRelayer ! NodeRelay.WrappedPeerReadyResult(AsyncPaymentCanceled)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment immediately when the async payment feature is disabled") { f => test("relay the payment immediately when the async payment feature is disabled") { f =>
import f._ import f._
@ -827,26 +723,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
} }
} }
def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = {
val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes
PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty))
}
test("relay to blinded paths without multi-part") { f => test("relay to blinded paths without multi-part") { f =>
import f._ import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, None)
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(createPaymentBlindedRoute(outgoingNodeId)))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode]
assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.amount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry) assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -856,7 +741,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p => incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId) assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -865,7 +750,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent() nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent) validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1) assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete] parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis) register.expectNoMessage(100 millis)
@ -874,18 +759,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
test("relay to blinded paths with multi-part") { f => test("relay to blinded paths with multi-part") { f =>
import f._ import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), None)
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), payerKey, chain)
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), Seq(createPaymentBlindedRoute(outgoingNodeId)))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.totalAmount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry) assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -895,7 +774,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p => incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId) assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -904,25 +783,89 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent() nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent) validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1) assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete] parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis) register.expectNoMessage(100 millis)
} }
test("relay to blinded path with wake-up", Tag(wakeUpEnabled)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams)
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
// The remote node is a wallet node: we try to wake them up before relaying the payment.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo]
assert(wakeUp.remoteNodeId == outgoingNodeId)
wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty)
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
assert(outgoingPayment.recipient.totalAmount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry)
assert(outgoingPayment.recipient.isInstanceOf[BlindedRecipient])
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay to blinded path when wake-up fails", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams)
val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
// The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId)
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
mockPayFSM.expectNoMessage(100 millis)
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true))
}
}
test("relay to compact blinded paths") { f => test("relay to compact blinded paths") { f =>
import f._ import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId)
val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L))
val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
@ -932,7 +875,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
getNodeId.replyTo ! Some(outgoingNodeId) getNodeId.replyTo ! Some(outgoingNodeId)
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode]
assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.amount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry) assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -942,7 +885,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p => incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId) assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -951,7 +894,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent() nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent) validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1) assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete] parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis) register.expectNoMessage(100 millis)
@ -960,16 +903,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
test("fail to relay to compact blinded paths with unknown scid") { f => test("fail to relay to compact blinded paths with unknown scid") { f =>
import f._ import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId)
val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L))
val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
@ -980,7 +915,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
mockPayFSM.expectNoMessage(100 millis) mockPayFSM.expectNoMessage(100 millis)
incomingMultiPart.foreach { p => incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId) assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true))
@ -1008,7 +943,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient])
val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient]
assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret
assert(recipient.nextTrampolineOnion_opt.contains(nextTrampolinePacket)) assert(recipient.nextTrampolineOnion_opt.nonEmpty)
// The recipient is able to decrypt the trampoline onion.
recipient.nextTrampolineOnion_opt.foreach(onion => assert(IncomingPaymentPacket.decryptOnion(paymentHash, outgoingNodeKey, onion).isRight))
} }
def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = { def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = {
@ -1025,10 +962,7 @@ object NodeRelayerSpec {
val paymentPreimage = randomBytes32() val paymentPreimage = randomBytes32()
val paymentHash = Crypto.sha256(paymentPreimage) val paymentHash = Crypto.sha256(paymentPreimage)
val paymentSecret = randomBytes32()
// This is the result of decrypting the incoming trampoline onion packet.
// It should be forwarded to the next trampoline node.
val nextTrampolinePacket = OnionRoutingPacket(0, hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619", randomBytes(400), randomBytes32())
val outgoingAmount = 40_000_000 msat val outgoingAmount = 40_000_000 msat
val outgoingExpiry = CltvExpiry(490000) val outgoingExpiry = CltvExpiry(490000)
@ -1054,6 +988,12 @@ object NodeRelayerSpec {
def createSuccessEvent(): PaymentSent = def createSuccessEvent(): PaymentSent =
PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None))) PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None)))
def createTrampolinePacket(amount: MilliSatoshi, expiry: CltvExpiry): OnionRoutingPacket = {
val payload = NodePayload(outgoingNodeId, FinalPayload.Standard.createPayload(amount, amount, expiry, paymentSecret))
val Right(onion) = OutgoingPaymentPacket.buildOnion(Seq(payload), paymentHash, None)
onion.packet
}
def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry, endorsementIn: Int = 7): RelayToTrampolinePacket = { def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry, endorsementIn: Int = 7): RelayToTrampolinePacket = {
val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None) val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None)
val tlvs = TlvStream[UpdateAddHtlcTlv](UpdateAddHtlcTlv.Endorsement(endorsementIn)) val tlvs = TlvStream[UpdateAddHtlcTlv](UpdateAddHtlcTlv.Endorsement(endorsementIn))
@ -1061,7 +1001,7 @@ object NodeRelayerSpec {
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, tlvs), UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, tlvs),
outerPayload, outerPayload,
IntermediatePayload.NodeRelay.Standard(amountOut, expiryOut, outgoingNodeId), IntermediatePayload.NodeRelay.Standard(amountOut, expiryOut, outgoingNodeId),
nextTrampolinePacket) createTrampolinePacket(amountOut, expiryOut))
} }
def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): RelayToTrampolinePacket = { def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): RelayToTrampolinePacket = {
@ -1071,7 +1011,46 @@ object NodeRelayerSpec {
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None, 1.0), UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None), FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId), IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId),
nextTrampolinePacket) createTrampolinePacket(outgoingAmount, expiryOut))
}
def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = {
val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes
PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty))
}
/** Create payments to a blinded path that starts at a remote node. */
def createIncomingPaymentsToRemoteBlindedPath(features: Features[Bolt12Feature], scidDir_opt: Option[EncodedNodeId.ShortChannelIdDir]): Seq[RelayToBlindedPathsPacket] = {
val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash)
val request = InvoiceRequest(offer, outgoingAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash)
val paymentBlindedRoute = scidDir_opt match {
case Some(scidDir) =>
val nonCompact = createPaymentBlindedRoute(outgoingNodeId)
nonCompact.copy(route = nonCompact.route.copy(introductionNodeId = scidDir))
case None =>
createPaymentBlindedRoute(outgoingNodeId)
}
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(paymentBlindedRoute))
incomingMultiPart.map(incoming => {
val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice)
RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload)
})
}
/** Create payments to a blinded path that starts at our node and relays to a wallet node. */
def createIncomingPaymentsToWalletBlindedPath(nodeParams: NodeParams): Seq[RelayToBlindedPathsPacket] = {
val features: Features[Bolt12Feature] = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional)
val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), randomKey(), Block.RegtestGenesisBlock.hash)
val edge = ExtraEdge(nodeParams.nodeId, outgoingNodeId, Alias(561), 2_000_000 msat, 250, CltvExpiryDelta(144), 1 msat, None)
val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, outgoingNodeId, HopRelayParams.FromHint(edge))
val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, outgoingExpiry).route
val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(outgoingAmount, Seq(hop), CltvExpiryDelta(12))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(PaymentBlindedRoute(route, paymentInfo)))
incomingMultiPart.map(incoming => {
val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice)
RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload)
})
} }
} }

View File

@ -55,11 +55,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat
val router = TestProbe[Any]("router") val router = TestProbe[Any]("router")
val register = TestProbe[Any]("register") val register = TestProbe[Any]("register")
val paymentHandler = TestProbe[Any]("payment-handler") val paymentHandler = TestProbe[Any]("payment-handler")
val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer")
val probe = TestProbe[Any]() val probe = TestProbe[Any]()
// we can't spawn top-level actors with akka typed // we can't spawn top-level actors with akka typed
testKit.spawn(Behaviors.setup[Any] { context => testKit.spawn(Behaviors.setup[Any] { context =>
val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic, triggerer.ref)) val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic))
probe.ref ! relayer probe.ref ! relayer
Behaviors.empty[Any] Behaviors.empty[Any]
}) })

View File

@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.send.BlindedPathsResolver.{FullBlindedRoute, Part
import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams}
import fr.acinq.eclair.router.{BlindedRouteCreation, Router} import fr.acinq.eclair.router.{BlindedRouteCreation, Router}
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey}
import org.scalatest.Outcome import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import scodec.bits.HexStringSyntax import scodec.bits.HexStringSyntax
@ -151,6 +151,31 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
} }
} }
test("resolve route starting at our node (wallet node)") { f =>
import f._
val probe = TestProbe()
val walletNodeId = randomKey().publicKey
val edge = ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(561), 5_000_000 msat, 200, CltvExpiryDelta(144), 1 msat, None)
val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(edge))
val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route
val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, Seq(hop), CltvExpiryDelta(12))
val resolver = testKit.spawn(BlindedPathsResolver(nodeParams, randomBytes32(), router.ref, register.ref))
resolver ! Resolve(probe.ref, Seq(PaymentBlindedRoute(route, paymentInfo)))
// We are the introduction node: we decrypt the payload and discover that the next node is a wallet node.
val resolved = probe.expectMsgType[Seq[ResolvedPath]]
assert(resolved.size == 1)
assert(resolved.head.route.isInstanceOf[PartialBlindedRoute])
val partialRoute = resolved.head.route.asInstanceOf[PartialBlindedRoute]
assert(partialRoute.firstNodeId == walletNodeId)
assert(partialRoute.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(walletNodeId))
assert(partialRoute.blindedNodes == route.subsequentNodes)
assert(partialRoute.nextBlinding != route.blindingKey)
// We don't need to resolve the nodeId.
register.expectNoMessage(100 millis)
router.expectNoMessage(100 millis)
}
test("ignore blinded paths that cannot be resolved") { f => test("ignore blinded paths that cannot be resolved") { f =>
import f._ import f._
@ -181,8 +206,9 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
val probe = TestProbe() val probe = TestProbe()
val scid = RealShortChannelId(BlockHeight(750_000), 3, 7) val scid = RealShortChannelId(BlockHeight(750_000), 3, 7)
val edgeLowFees = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) val nextNodeId = randomKey().publicKey
val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) val edgeLowFees = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None)
val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None)
val toResolve = Seq( val toResolve = Seq(
// We don't allow paying blinded routes to ourselves. // We don't allow paying blinded routes to ourselves.
BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
@ -190,6 +216,8 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes with low cltv_expiry_delta. // We reject blinded routes with low cltv_expiry_delta.
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes with low fees, even when the next node seems to be a wallet node.
BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes that cannot be decrypted. // We reject blinded routes that cannot be decrypted.
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey) BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey)
).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty))) ).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty)))

View File

@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite {
val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded) val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded)
assert(payload.amountOut == 561.msat) assert(payload.amountOut == 561.msat)
assert(payload.cltvOut == CltvExpiry(42)) assert(payload.cltvOut == CltvExpiry(42))
assert(payload.outgoingChannelId == ShortChannelId(1105)) assert(payload.outgoing.contains(ShortChannelId(1105)))
val encoded = perHopPayloadCodec.encode(expected).require.bytes val encoded = perHopPayloadCodec.encode(expected).require.bytes
assert(encoded == bin) assert(encoded == bin)
} }
@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite {
val decoded = perHopPayloadCodec.decode(bin.bits).require.value val decoded = perHopPayloadCodec.decode(bin.bits).require.value
assert(decoded == expected) assert(decoded == expected)
val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey)
assert(payload.outgoingChannelId == ShortChannelId(42)) assert(payload.outgoing.contains(ShortChannelId(42)))
assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.amountToForward(10_000 msat) == 9990.msat)
assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856))
assert(payload.paymentRelayData.allowedFeatures.isEmpty) assert(payload.paymentRelayData.allowedFeatures.isEmpty)
@ -119,6 +119,20 @@ class PaymentOnionSpec extends AnyFunSuite {
} }
} }
test("encode/decode channel relay blinded per-hop-payload (with wallet node_id)") {
val walletNodeId = PublicKey(hex"0221cd519eba9c8b840a5e40b65dc2c040e159a766979723ed770efceb97260ec8")
val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat),
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat),
)
val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey)
assert(payload.outgoing == Left(walletNodeId))
assert(payload.amountToForward(10_000 msat) == 9990.msat)
assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856))
assert(payload.paymentRelayData.allowedFeatures.isEmpty)
}
test("encode/decode node relay per-hop payload") { test("encode/decode node relay per-hop payload") {
val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")
val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId)) val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId))
@ -292,6 +306,8 @@ class PaymentOnionSpec extends AnyFunSuite {
TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs), TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs),
// Missing encrypted outgoing channel. // Missing encrypted outgoing channel.
TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Forbidden encrypted outgoing plain node_id.
TestCase(ForbiddenTlv(UInt64(4)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(randomKey().publicKey)), RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Missing encrypted payment relay data. // Missing encrypted payment relay data.
TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Missing encrypted payment constraint. // Missing encrypted payment constraint.