1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 01:43:22 +01:00

Wake up wallet nodes before relaying messages or payments (#2865)

We refactor `NodeRelay.scala` to re-order some steps. The steps are:

1. Fully receive the incoming payment
2. Resolve the next node (unwrap blinded paths if needed)
3. Wake-up the next node if necessary (mobile wallet)
4. Relay outgoing payment

Note that we introduce a wake-up step, that can be extended to include
mobile notifications. We introduce that same wake-up step in channel
relay and message relay. We also allow relaying data to contain a wallet
`node_id` instead of an scid. When that's the case, we start by waking
up that wallet node before we try relaying onion messages or payments.

This wake-up step doesn't contain any logic right now apart from waiting
for the peer to connect, if it isn't connected already. But it can easily be
extended to send a mobile notification to prompt the wallet to connect.
This commit is contained in:
Bastien Teinturier 2024-08-28 09:43:11 +02:00 committed by GitHub
parent c440007b52
commit fcd88b0a0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 1029 additions and 594 deletions

View File

@ -318,6 +318,13 @@ eclair {
max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us
}
// When relaying payments or messages to mobile peers who are disconnected, we may try to wake them up using a mobile
// notification system, or we attempt connecting to the last known address.
peer-wake-up {
enabled = false
timeout = 60 seconds
}
auto-reconnect = true
initial-random-reconnect-delay = 5 seconds // we add a random delay before the first reconnection attempt, capped by this value
max-reconnect-interval = 1 hour // max interval between two reconnection attempts, after the exponential backoff period

View File

@ -28,7 +28,7 @@ import fr.acinq.eclair.crypto.Noise.KeyPair
import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager, OnChainKeyManager}
import fr.acinq.eclair.db._
import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy}
import fr.acinq.eclair.io.PeerConnection
import fr.acinq.eclair.io.{PeerConnection, PeerReadyNotifier}
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Announcements.AddressException
@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
blockchainWatchdogSources: Seq[String],
onionMessageConfig: OnionMessageConfig,
purgeInvoicesInterval: Option[FiniteDuration],
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) {
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey
val nodeId: PublicKey = nodeKeyManager.nodeId
@ -611,7 +612,11 @@ object NodeParams extends Logging {
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(
batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"),
interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS)
)
),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(
enabled = config.getBoolean("peer-wake-up.enabled"),
timeout = FiniteDuration(config.getDuration("peer-wake-up.timeout").getSeconds, TimeUnit.SECONDS)
),
)
}
}

View File

@ -360,7 +360,8 @@ class Setup(val datadir: File,
offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager")
paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume))
triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer")
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, triggerer, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume))
peerReadyManager = system.spawn(Behaviors.supervise(PeerReadyManager()).onFailure(typed.SupervisorStrategy.restart), name = "peer-ready-manager")
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume))
_ = relayer ! PostRestartHtlcCleaner.Init(channels)
// Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system,
// we want to make sure the handler for post-restart broken HTLCs has finished initializing.

View File

@ -44,29 +44,18 @@ object MessageRelay {
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command
case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command
private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
sealed trait Status {
val messageId: ByteVector32
}
sealed trait Status { val messageId: ByteVector32 }
case class Sent(messageId: ByteVector32) extends Status
sealed trait Failure extends Status
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure {
override def toString: String = s"Relay prevented by policy $policy"
}
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure {
override def toString: String = s"Can't connect to peer: ${failure.toString}"
}
case class Disconnected(messageId: ByteVector32) extends Failure {
override def toString: String = "Peer is not connected"
}
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown channel: $channelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
}
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" }
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" }
case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" }
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" }
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" }
sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy
@ -106,7 +95,7 @@ private class MessageRelay(nodeParams: NodeParams,
def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match {
case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf =>
withNextNodeId(msg, nodeParams.nodeId)
withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId))
case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(msg, outgoingChannelId)
@ -114,7 +103,7 @@ private class MessageRelay(nodeParams: NodeParams,
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid)
case Right(encodedNodeId: EncodedNodeId.WithPublicKey) =>
withNextNodeId(msg, encodedNodeId.publicKey)
withNextNodeId(msg, encodedNodeId)
}
}
@ -127,34 +116,39 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped
case WrappedOptionalNodeId(Some(nextNodeId)) =>
log.info("found outgoing node {} for channel {}", nextNodeId, channelId)
withNextNodeId(msg, nextNodeId)
withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId))
}
}
private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) {
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment()
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
} else {
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nextNodeId)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg, nextNodeId)
}
private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = {
nextNodeId match {
case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId =>
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment()
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
case EncodedNodeId.WithPublicKey.Plain(nodeId) =>
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nodeId)
case RelayAll =>
switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg, nodeId)
}
case EncodedNodeId.WithPublicKey.Wallet(nodeId) =>
val notifier = context.spawnAnonymous(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
waitForWalletNodeUp(msg, nodeId)
}
}
@ -197,4 +191,18 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped
}
}
private def waitForWalletNodeUp(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) =>
log.info("successfully woke up {}: relaying onion message", nextNodeId)
r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Behaviors.stopped
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, Tags.Reasons.ConnectionFailure).increment()
log.info("could not wake up {}: onion message cannot be relayed", nextNodeId)
replyTo_opt.foreach(_ ! Disconnected(messageId))
Behaviors.stopped
}
}
}

View File

@ -17,36 +17,104 @@
package fr.acinq.eclair.io
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.receptionist.{Receptionist, ServiceKey}
import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps}
import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler}
import akka.actor.typed.{ActorRef, Behavior}
import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy}
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.{BlockHeight, Logs, channel}
import scala.concurrent.duration.{DurationInt, FiniteDuration}
/**
* This actor tracks the set of pending [[PeerReadyNotifier]].
* It can be used to ensure that notifications are only sent once, even if there are multiple parallel operations
* waiting for that peer to come online.
*/
object PeerReadyManager {
val PeerReadyManagerServiceKey: ServiceKey[Register] = ServiceKey[Register]("peer-ready-manager")
// @formatter:off
sealed trait Command
case class Register(replyTo: ActorRef[Registered], remoteNodeId: PublicKey) extends Command
case class List(replyTo: ActorRef[Set[PublicKey]]) extends Command
private case class Completed(remoteNodeId: PublicKey, actor: ActorRef[Registered]) extends Command
// @formatter:on
/**
* @param otherAttempts number of already pending [[PeerReadyNotifier]] instances for that peer.
*/
case class Registered(remoteNodeId: PublicKey, otherAttempts: Int)
def apply(): Behavior[Command] = {
Behaviors.setup { context =>
context.system.receptionist ! Receptionist.Register(PeerReadyManagerServiceKey, context.self)
watch(Map.empty, context)
}
}
private def watch(pending: Map[PublicKey, Set[ActorRef[Registered]]], context: ActorContext[Command]): Behavior[Command] = {
Behaviors.receiveMessage {
case Register(replyTo, remoteNodeId) =>
context.watchWith(replyTo, Completed(remoteNodeId, replyTo))
pending.get(remoteNodeId) match {
case Some(attempts) =>
replyTo ! Registered(remoteNodeId, otherAttempts = attempts.size)
val attempts1 = attempts + replyTo
watch(pending + (remoteNodeId -> attempts1), context)
case None =>
replyTo ! Registered(remoteNodeId, otherAttempts = 0)
watch(pending + (remoteNodeId -> Set(replyTo)), context)
}
case Completed(remoteNodeId, actor) =>
pending.get(remoteNodeId) match {
case Some(attempts) =>
val attempts1 = attempts - actor
if (attempts1.isEmpty) {
watch(pending - remoteNodeId, context)
} else {
watch(pending + (remoteNodeId -> attempts1), context)
}
case None =>
Behaviors.same
}
case List(replyTo) =>
replyTo ! pending.keySet
Behaviors.same
}
}
}
/**
* This actor waits for a given peer to be online and ready to process payments.
* It automatically stops after the timeout provided.
* It automatically stops after the timeout provided if the peer doesn't connect.
* There may be multiple instances of this actor running in parallel for the same peer, which is fine because they
* may use different timeouts.
* Having separate actor instances for each caller guarantees that the caller will always receive a response.
*/
object PeerReadyNotifier {
case class WakeUpConfig(enabled: Boolean, timeout: FiniteDuration)
// @formatter:off
sealed trait Command
case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command
private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command
private final case class WrappedRegistered(registered: PeerReadyManager.Registered) extends Command
private case object PeerNotConnected extends Command
private case class SomePeerConnected(nodeId: PublicKey) extends Command
private case class SomePeerDisconnected(nodeId: PublicKey) extends Command
private case object PeerConnected extends Command
private case object PeerDisconnected extends Command
private case class WrappedPeerInfo(peer: ActorRef[Peer.GetPeerChannels], channelCount: Int) extends Command
private case class NewBlockNotTimedOut(currentBlockHeight: BlockHeight) extends Command
private case object CheckChannelsReady extends Command
private case class WrappedPeerChannels(wrapped: Peer.PeerChannels) extends Command
private case object Timeout extends Command
private case object ToBeIgnored extends Command
sealed trait Result
sealed trait Result { def remoteNodeId: PublicKey }
case class PeerReady(remoteNodeId: PublicKey, peer: akka.actor.ActorRef, channelInfos: Seq[Peer.ChannelInfo]) extends Result { val channelsCount: Int = channelInfos.size }
case class PeerUnavailable(remoteNodeId: PublicKey) extends Result
@ -66,102 +134,40 @@ object PeerReadyNotifier {
case cbc => NewBlockNotTimedOut(cbc.blockHeight)
})
}
// In case the peer is not currently connected, we will wait for them to connect instead of regularly
// polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments.
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => SomePeerConnected(e.nodeId)))
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => SomePeerDisconnected(e.nodeId)))
findSwitchboard(replyTo, remoteNodeId, context, timers)
// The actor should never throw, but for extra safety we wrap it with a supervisor.
Behaviors.supervise {
start(replyTo, remoteNodeId, context, timers)
}.onFailure(SupervisorStrategy.stop)
}
}
}
}
}
private def findSwitchboard(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing))
private def start(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
// We start by registering ourself to see if other instances are running.
context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing))
Behaviors.receiveMessagePartial {
case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) =>
case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) =>
listings.headOption match {
case Some(switchboard) =>
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers)
case Some(peerReadyManager) =>
peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId)
Behaviors.same
case None =>
context.log.error("no switchboard found")
context.log.error("no peer-ready-manager found")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
}
}
private def waitForPeerConnected(replyTo: ActorRef[Result], remoteNodeId: PublicKey, switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] {
// We receive this when we don't have any channel to the given peer and are not currently connected to them.
// In that case we still want to wait for a connection, because we may want to open a channel to them.
case _: Peer.PeerNotFound => PeerNotConnected
case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected
case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size)
}
// We check whether the peer is already connected.
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.receiveMessagePartial {
case PeerNotConnected =>
context.log.debug("peer is not connected yet")
Behaviors.same
case SomePeerConnected(nodeId) =>
if (nodeId == remoteNodeId) {
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
}
Behaviors.same
case SomePeerDisconnected(_) =>
Behaviors.same
case WrappedPeerInfo(peer, channelCount) =>
if (channelCount == 0) {
context.log.info("peer is ready with no channels")
replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty)
Behaviors.stopped
} else {
context.log.debug("peer is connected with {} channels", channelCount)
waitForChannelsReady(replyTo, remoteNodeId, peer, switchboard, context, timers)
}
case NewBlockNotTimedOut(currentBlockHeight) =>
context.log.debug("waiting for peer to connect at block {}", currentBlockHeight)
Behaviors.same
case WrappedRegistered(registered) =>
context.log.info("checking if peer is ready ({} other attempts)", registered.otherAttempts)
val isFirstAttempt = registered.otherAttempts == 0
// In case the peer is not currently connected, we will wait for them to connect instead of regularly
// polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments.
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored))
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored))
new PeerReadyNotifier(replyTo, remoteNodeId, isFirstAttempt, context, timers).findSwitchboard()
case Timeout =>
context.log.info("timed out waiting for peer to connect")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
}
private def waitForChannelsReady(replyTo: ActorRef[Result], remoteNodeId: PublicKey, peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = {
timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second)
Behaviors.receiveMessagePartial {
case CheckChannelsReady =>
context.log.debug("checking channel states")
peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels))
Behaviors.same
case WrappedPeerChannels(peerChannels) =>
if (peerChannels.channels.map(_.state).forall(isChannelReady)) {
replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels)
Behaviors.stopped
} else {
context.log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state)))
Behaviors.same
}
case NewBlockNotTimedOut(currentBlockHeight) =>
context.log.debug("waiting for channels to be ready at block {}", currentBlockHeight)
Behaviors.same
case SomePeerConnected(_) =>
Behaviors.same
case SomePeerDisconnected(nodeId) =>
if (nodeId == remoteNodeId) {
context.log.debug("peer disconnected, waiting for them to reconnect")
timers.cancel(ChannelsReadyTimerKey)
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers)
} else {
Behaviors.same
}
case Timeout =>
context.log.info("timed out waiting for channels to be ready")
context.log.info("timed out finding peer-ready-manager actor")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
@ -199,3 +205,109 @@ object PeerReadyNotifier {
}
}
private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result],
remoteNodeId: PublicKey,
isFirstAttempt: Boolean,
context: ActorContext[PeerReadyNotifier.Command],
timers: TimerScheduler[PeerReadyNotifier.Command]) {
import PeerReadyNotifier._
private val log = context.log
private def findSwitchboard(): Behavior[Command] = {
context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing))
Behaviors.receiveMessagePartial {
case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) =>
listings.headOption match {
case Some(switchboard) =>
waitForPeerConnected(switchboard)
case None =>
log.error("no switchboard found")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
}
case Timeout =>
log.info("timed out finding switchboard actor")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = {
val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] {
// We receive this when we don't have any channel to the given peer and are not currently connected to them.
// In that case we still want to wait for a connection, because we may want to open a channel to them.
case _: Peer.PeerNotFound => PeerNotConnected
case info: Peer.PeerInfo if info.state != Peer.CONNECTED => PeerNotConnected
case info: Peer.PeerInfo => WrappedPeerInfo(info.peer.toTyped, info.channels.size)
}
// We check whether the peer is already connected.
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.receiveMessagePartial {
case PeerNotConnected =>
log.debug("peer is not connected yet")
Behaviors.same
case PeerConnected =>
switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId)
Behaviors.same
case PeerDisconnected =>
Behaviors.same
case WrappedPeerInfo(peer, channelCount) =>
if (channelCount == 0) {
log.info("peer is ready with no channels")
replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty)
Behaviors.stopped
} else {
log.debug("peer is connected with {} channels", channelCount)
waitForChannelsReady(peer, switchboard)
}
case NewBlockNotTimedOut(currentBlockHeight) =>
log.debug("waiting for peer to connect at block {}", currentBlockHeight)
Behaviors.same
case Timeout =>
log.info("timed out waiting for peer to connect")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = {
timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second)
Behaviors.receiveMessagePartial {
case CheckChannelsReady =>
log.debug("checking channel states")
peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels))
Behaviors.same
case WrappedPeerChannels(peerChannels) =>
if (peerChannels.channels.map(_.state).forall(isChannelReady)) {
replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels)
Behaviors.stopped
} else {
log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state)))
Behaviors.same
}
case NewBlockNotTimedOut(currentBlockHeight) =>
log.debug("waiting for channels to be ready at block {}", currentBlockHeight)
Behaviors.same
case PeerConnected =>
Behaviors.same
case PeerDisconnected =>
log.debug("peer disconnected, waiting for them to reconnect")
timers.cancel(ChannelsReadyTimerKey)
waitForPeerConnected(switchboard)
case Timeout =>
log.info("timed out waiting for channels to be ready")
replyTo ! PeerUnavailable(remoteNodeId)
Behaviors.stopped
case ToBeIgnored =>
Behaviors.same
}
}
}

View File

@ -119,6 +119,7 @@ object Monitoring {
val Failure = "failure"
object FailureType {
val WakeUp = "WakeUp"
val Remote = "Remote"
val Malformed = "MalformedHtlc"

View File

@ -126,7 +126,7 @@ object IncomingPaymentPacket {
decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap {
case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) =>
validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap {
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf =>
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) =>
decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features)
case relayPacket => Right(relayPacket)
}

View File

@ -19,7 +19,7 @@ package fr.acinq.eclair.payment.relay
import akka.actor.typed.ActorRef.ActorRefOps
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy}
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.Logs.LogCategory
@ -99,7 +99,7 @@ private class AsyncPaymentTriggerer(context: ActorContext[Command]) {
case Watch(replyTo, remoteNodeId, paymentHash, timeout) =>
peers.get(remoteNodeId) match {
case None =>
val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(remoteNodeId, timeout_opt = None)).onFailure(SupervisorStrategy.stop))
val notifier = context.spawnAnonymous(PeerReadyNotifier(remoteNodeId, timeout_opt = None))
context.watchWith(notifier, NotifierStopped(remoteNodeId))
notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult))
val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash)))

View File

@ -16,16 +16,17 @@
package fr.acinq.eclair.payment.relay
import akka.actor.ActorRef
import akka.actor.typed.Behavior
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.io.PeerReadyNotifier
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams}
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket}
@ -44,6 +45,7 @@ object ChannelRelay {
// @formatter:off
sealed trait Command
private case object DoRelay extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command
private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command
// @formatter:on
@ -57,7 +59,7 @@ object ChannelRelay {
def apply(nodeParams: NodeParams,
register: ActorRef,
channels: Map[ByteVector32, Relayer.OutgoingChannel],
originNode:PublicKey,
originNode: PublicKey,
relayId: UUID,
r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] =
Behaviors.setup { context =>
@ -67,9 +69,8 @@ object ChannelRelay {
paymentHash_opt = Some(r.add.paymentHash),
nodeAlias_opt = Some(nodeParams.alias))) {
val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode)
context.self ! DoRelay
val confidence = (r.add.endorsement + 0.5) / 8
new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).relay(Seq.empty)
new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).start()
}
}
@ -77,7 +78,7 @@ object ChannelRelay {
* This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard
* errors that we should return upstream.
*/
def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = {
private def translateLocalError(error: ChannelException, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = {
(error, channelUpdate_opt) match {
case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate))
case (_: ExpiryTooBig, _) => ExpiryTooFar()
@ -121,13 +122,57 @@ class ChannelRelay private(nodeParams: NodeParams,
private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure)
private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse)
private val nextBlindingKey_opt = r.payload match {
case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding)
case _: IntermediatePayload.ChannelRelay.Standard => None
}
/** Channel id explicitly requested in the onion payload. */
private val requestedChannelId_opt = r.payload.outgoing match {
case Left(_) => None
case Right(outgoingChannelId) => channels.collectFirst {
case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId
case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId
}
}
private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match {
case Left(walletNodeId) => (None, Some(walletNodeId))
case Right(shortChannelId) => (Some(shortChannelId), None)
}
private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException])
def start(): Behavior[Command] = {
walletNodeId_opt match {
case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => wakeUp(walletNodeId)
case _ =>
context.self ! DoRelay
relay(Seq.empty)
}
}
private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = {
context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId)
val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel)
context.log.info("rejecting htlc: failed to wake-up remote peer")
safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true))
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) =>
context.self ! DoRelay
relay(Seq.empty)
}
}
def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case DoRelay =>
if (previousFailures.isEmpty) {
context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse(""))
val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId)
context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse(""))
}
context.log.debug("attempting relay previousAttempts={}", previousFailures.size)
handleRelay(previousFailures) match {
@ -143,7 +188,7 @@ class ChannelRelay private(nodeParams: NodeParams,
}
}
def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] =
private def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) =>
context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}")
@ -156,23 +201,23 @@ class ChannelRelay private(nodeParams: NodeParams,
context.self ! DoRelay
relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed))
case WrappedAddResponse(r: RES_SUCCESS[_]) =>
case WrappedAddResponse(_: RES_SUCCESS[_]) =>
context.log.debug("sent htlc to the downstream channel")
waitForAddSettled(r.channelId)
waitForAddSettled()
}
def waitForAddSettled(channelId: ByteVector32): Behavior[Command] =
private def waitForAddSettled(): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) =>
context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId)
context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId)
Metrics.relayFulfill(confidence)
val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true)
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now()))
recordRelayDuration(isSuccess = true)
safeSendAndStop(upstream.add.channelId, cmd)
case WrappedAddResponse(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail)) =>
context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId)
case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fail: HtlcResult.Fail)) =>
context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId)
Metrics.relayFail(confidence)
Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel)
val cmd = translateRelayFailure(upstream.add.id, fail)
@ -180,7 +225,7 @@ class ChannelRelay private(nodeParams: NodeParams,
safeSendAndStop(upstream.add.channelId, cmd)
}
def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = {
private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = {
val toSend = cmd match {
case _: CMD_FULFILL_HTLC => cmd
case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match {
@ -211,49 +256,44 @@ class ChannelRelay private(nodeParams: NodeParams,
* - a CMD_FAIL_HTLC to be sent back upstream
* - a CMD_ADD_HTLC to propagate downstream
*/
def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = {
private def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = {
val alreadyTried = previousFailures.map(_.channelId)
selectPreferredChannel(alreadyTried) match {
case None if previousFailures.nonEmpty =>
// no more channels to try
val error = previousFailures
// we return the error for the initially requested channel if it exists
.find(failure => requestedChannelId_opt.contains(failure.channelId))
// otherwise we return the error for the first channel tried
.getOrElse(previousFailures.head)
.failure
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true))
case outgoingChannel_opt =>
relayOrFail(outgoingChannel_opt)
case Some(outgoingChannel) => relayOrFail(outgoingChannel)
case None =>
// No more channels to try.
val cmdFail = if (previousFailures.nonEmpty) {
val error = previousFailures
// We return the error for the initially requested channel if it exists.
.find(failure => requestedChannelId_opt.contains(failure.channelId))
// Otherwise we return the error for the first channel tried.
.getOrElse(previousFailures.head)
.failure
CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true)
} else {
CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)
}
RelayFailure(cmdFail)
}
}
/** all the channels point to the same next node, we take the first one */
private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId)
/** channel id explicitly requested in the onion payload */
private val requestedChannelId_opt = channels.collectFirst {
case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId
case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId
}
/**
* Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is
* compatible in terms of fees, expiry_delta, etc.
*
* If no suitable channel is found we default to the originally requested channel.
*/
def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = {
val requestedShortChannelId = r.payload.outgoingChannelId
context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId)
private def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = {
context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt)
// we filter out channels that we have already tried
val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried
// and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta)
candidateChannels
.values
.map { channel =>
val relayResult = relayOrFail(Some(channel))
context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}",
val relayResult = relayOrFail(channel)
context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}",
channel.channelId,
channel.commitments.availableBalanceForSend,
channel.commitments.latest.capacity,
channel.channelUpdate,
@ -279,7 +319,7 @@ class ChannelRelay private(nodeParams: NodeParams,
context.log.debug("requested short channel id is our preferred channel")
Some(channel)
} else {
context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend)
context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend)
Some(channel)
}
case None =>
@ -300,28 +340,35 @@ class ChannelRelay private(nodeParams: NodeParams,
* channel, because some parameters don't match with our settings for that channel. In that case we directly fail the
* htlc.
*/
def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = {
outgoingChannel_opt match {
private def relayOrFail(outgoingChannel: OutgoingChannelParams): RelayResult = {
val update = outgoingChannel.channelUpdate
validateRelayParams(outgoingChannel) match {
case Some(fail) =>
RelayFailure(fail)
case None if !update.channelFlags.isEnabled =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(update.messageFlags, update.channelFlags, Some(update))), commit = true))
case None =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true))
case Some(c) if !c.channelUpdate.channelFlags.isEnabled =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(c.channelUpdate.messageFlags, c.channelUpdate.channelFlags, Some(c.channelUpdate))), commit = true))
case Some(c) if r.amountToForward < c.channelUpdate.htlcMinimumMsat =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(c.channelUpdate))), commit = true))
case Some(c) if r.expiryDelta < c.channelUpdate.cltvExpiryDelta =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(c.channelUpdate))), commit = true))
case Some(c) if r.relayFeeMsat < nodeFee(c.channelUpdate.relayFees, r.amountToForward) &&
// fees also do not satisfy the previous channel update for `enforcementDelay` seconds after current update
(TimestampSecond.now() - c.channelUpdate.timestamp > nodeParams.relayParams.enforcementDelay ||
outgoingChannel_opt.flatMap(_.prevChannelUpdate).forall(c => r.relayFeeMsat < nodeFee(c.relayFees, r.amountToForward))) =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(c.channelUpdate))), commit = true))
case Some(c: OutgoingChannel) =>
val origin = Origin.Hot(addResponseAdapter.toClassic, upstream)
val nextBlindingKey_opt = r.payload match {
case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding)
case _: IntermediatePayload.ChannelRelay.Standard => None
}
RelaySuccess(c.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true))
RelaySuccess(outgoingChannel.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true))
}
}
private def validateRelayParams(outgoingChannel: OutgoingChannelParams): Option[CMD_FAIL_HTLC] = {
val update = outgoingChannel.channelUpdate
// If our current channel update was recently created, we accept payments that used our previous channel update.
val allowPreviousUpdate = TimestampSecond.now() - update.timestamp <= nodeParams.relayParams.enforcementDelay
val prevUpdate_opt = if (allowPreviousUpdate) outgoingChannel.prevChannelUpdate else None
val htlcMinimumOk = update.htlcMinimumMsat <= r.amountToForward || prevUpdate_opt.exists(_.htlcMinimumMsat <= r.amountToForward)
val expiryDeltaOk = update.cltvExpiryDelta <= r.expiryDelta || prevUpdate_opt.exists(_.cltvExpiryDelta <= r.expiryDelta)
val feesOk = nodeFee(update.relayFees, r.amountToForward) <= r.relayFeeMsat || prevUpdate_opt.exists(u => nodeFee(u.relayFees, r.amountToForward) <= r.relayFeeMsat)
if (!htlcMinimumOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(update))), commit = true))
} else if (!expiryDeltaOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(update))), commit = true))
} else if (!feesOk) {
Some(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(update))), commit = true))
} else {
None
}
}

View File

@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel._
import fr.acinq.eclair.payment.IncomingPaymentPacket
import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId}
import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete}
import java.util.UUID
import scala.collection.mutable
@ -70,9 +70,12 @@ object ChannelRelayer {
Behaviors.receiveMessage {
case Relay(channelRelayPacket, originNode) =>
val relayId = UUID.randomUUID()
val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId)
case None => None
val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match {
case Left(walletNodeId) => Some(walletNodeId)
case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId)
case None => None
}
}
val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match {
case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap

View File

@ -26,6 +26,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream}
import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.io.PeerReadyNotifier
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment._
@ -40,7 +41,7 @@ import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32, randomKey}
import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32}
import java.util.UUID
import java.util.concurrent.TimeUnit
@ -62,7 +63,7 @@ object NodeRelay {
private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command
private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command
// @formatter:on
@ -88,7 +89,6 @@ object NodeRelay {
relayId: UUID,
nodeRelayPacket: NodeRelayPacket,
outgoingPaymentFactory: OutgoingPaymentFactory,
triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command],
router: ActorRef): Behavior[Command] =
Behaviors.setup { context =>
val paymentHash = nodeRelayPacket.add.paymentHash
@ -108,7 +108,7 @@ object NodeRelay {
case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket)
case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None
}
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, triggerer, router)
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, router)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler)
}
}
@ -125,14 +125,29 @@ object NodeRelay {
Some(InvalidOnionPayload(UInt64(2), 0))
} else {
payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard =>
if (payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty) {
Some(InvalidOnionPayload(UInt64(8), 0)) // payment secret field is missing
} else {
None
}
case _: IntermediatePayload.NodeRelay.ToBlindedPaths =>
None
// If we're relaying a standard payment to a non-trampoline recipient, we need the payment secret.
case payloadOut: IntermediatePayload.NodeRelay.Standard if payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty => Some(InvalidOnionPayload(UInt64(8), 0))
case _: IntermediatePayload.NodeRelay.Standard => None
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => None
}
}
}
/** This function identifies whether the next node is a wallet node directly connected to us, and returns its node_id. */
private def nextWalletNodeId(nodeParams: NodeParams, recipient: Recipient): Option[PublicKey] = {
recipient match {
// These two recipients are only used when we're the payment initiator.
case _: SpontaneousRecipient => None
case _: TrampolineRecipient => None
// When relaying to a trampoline node, the next node may be a wallet node directly connected to us, but we don't
// want to have false positives. Feature branches should check an internal DB/cache to confirm.
case r: ClearRecipient if r.nextTrampolineOnion_opt.nonEmpty => None
// If we're relaying to a non-trampoline recipient, it's never a wallet node.
case _: ClearRecipient => None
// When using blinded paths, we may be the introduction node for a wallet node directly connected to us.
case r: BlindedRecipient => r.blindedHops.head.resolved.route match {
case BlindedPathsResolver.PartialBlindedRoute(walletNodeId: EncodedNodeId.WithPublicKey.Wallet, _, _) => Some(walletNodeId.publicKey)
case _ => None
}
}
}
@ -188,7 +203,6 @@ class NodeRelay private(nodeParams: NodeParams,
paymentSecret: ByteVector32,
context: ActorContext[NodeRelay.Command],
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory,
triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command],
router: ActorRef) {
import NodeRelay._
@ -223,44 +237,102 @@ class NodeRelay private(nodeParams: NodeParams,
rejectPayment(upstream, Some(failure))
stopping()
case None =>
nextPayload match {
// TODO: async payments are not currently supported for blinded recipients. We should update the AsyncPaymentTriggerer to decrypt the blinded path.
case nextPayload: IntermediatePayload.NodeRelay.Standard if nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype) =>
waitForTrigger(upstream, nextPayload, nextPacket_opt)
case _ =>
doSend(upstream, nextPayload, nextPacket_opt)
}
resolveNextNode(upstream, nextPayload, nextPacket_opt)
}
}
private def waitForTrigger(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})")
val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
// wait for notification until which ever occurs first: the hold timeout block or the safety block
val notifierTimeout = Seq(timeoutBlock, safetyBlock).min
val peerReadyResultAdapter = context.messageAdapter[AsyncPaymentTriggerer.Result](WrappedPeerReadyResult)
triggerer ! AsyncPaymentTriggerer.Watch(peerReadyResultAdapter, nextPayload.outgoingNodeId, paymentHash, notifierTimeout)
context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash))
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTimeout) =>
context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout)
rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized
stopping()
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentCanceled) =>
context.log.warn(s"payment sender canceled a waiting async payment")
rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized
stopping()
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTriggered) =>
doSend(upstream, nextPayload, nextPacket_opt)
/** Once we've fully received the incoming HTLC set, we must identify the next node before forwarding the payment. */
private def resolveNextNode(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
nextPayload match {
case payloadOut: IntermediatePayload.NodeRelay.Standard =>
// If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient.
payloadOut.invoiceFeatures match {
case Some(features) =>
val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId))
val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata)
context.log.debug("forwarding payment to non-trampoline recipient {}", recipient.nodeId)
ensureRecipientReady(upstream, recipient, nextPayload, None)
case None =>
val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = nextPacket_opt)
context.log.debug("forwarding payment to the next trampoline node {}", recipient.nodeId)
ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt)
}
case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths =>
// Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to
// resolve that to a nodeId in order to reach that introduction node and use the blinded path.
// If we are the introduction node ourselves, we'll also need to decrypt the onion and identify the next node.
context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths)
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case WrappedResolvedPaths(resolved) if resolved.isEmpty =>
context.log.warn("rejecting trampoline payment to blinded paths: no usable blinded path")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedResolvedPaths(resolved) =>
// We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient.
val blindedNodeId = resolved.head.route.blindedNodeIds.last
val recipient = BlindedRecipient.fromPaths(blindedNodeId, Features(payloadOut.invoiceFeatures).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
context.log.debug("forwarding payment to blinded recipient {}", recipient.nodeId)
ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt)
}
}
}
}
private def doSend(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.debug(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})")
/**
* The next node may be a mobile wallet directly connected to us: in that case, we'll need to wake them up before
* relaying the payment.
*/
private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
nextWalletNodeId(nodeParams, recipient) match {
case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt)
case _ => relay(upstream, recipient, nextPayload, nextPacket_opt)
}
}
/**
* The next node is the payment recipient. They are directly connected to us and may be offline. We try to wake them
* up and will relay the payment once they're connected and channels are reestablished.
*/
private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.info("trying to wake up next peer (nodeId={})", walletNodeId)
val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout))))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
context.log.warn("rejecting payment: failed to wake-up remote peer")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) =>
relay(upstream, recipient, nextPayload, nextPacket_opt)
}
}
}
/** Relay the payment to the next identified node: this is similar to sending an outgoing payment. */
private def relay(upstream: Upstream.Hot.Trampoline, recipient: Recipient, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={})", upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8
relay(upstream, nextPayload, nextPacket_opt, confidence)
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, recipient.nodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
// If the next node is using trampoline, we assume that they support MPP.
val useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment) || packetOut_opt.nonEmpty
val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic
val payment = if (useMultiPart) {
SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
} else {
SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
}
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart)
payFSM ! payment
sending(upstream, payloadOut, recipient, TimestampMilli.now(), fulfilledUpstream = false)
}
/**
@ -270,7 +342,11 @@ class NodeRelay private(nodeParams: NodeParams,
* @param nextPayload relay instructions.
* @param fulfilledUpstream true if we already fulfilled the payment upstream.
*/
private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] =
private def sending(upstream: Upstream.Hot.Trampoline,
nextPayload: IntermediatePayload.NodeRelay,
recipient: Recipient,
startedAt: TimestampMilli,
fulfilledUpstream: Boolean): Behavior[Command] =
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
// this is the fulfill that arrives from downstream channels
@ -279,7 +355,7 @@ class NodeRelay private(nodeParams: NodeParams,
// We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream).
context.log.debug("got preimage from downstream")
fulfillPayment(upstream, paymentPreimage)
sending(upstream, nextPayload, startedAt, fulfilledUpstream = true)
sending(upstream, nextPayload, recipient, startedAt, fulfilledUpstream = true)
} else {
// we don't want to fulfill multiple times
Behaviors.same
@ -311,80 +387,6 @@ class NodeRelay private(nodeParams: NodeParams,
}
}
private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic
private def relay(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket], confidence: Double): Behavior[Command] = {
val displayNodeId = payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => randomKey().publicKey
}
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard =>
// If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient.
payloadOut.invoiceFeatures match {
case Some(features) =>
val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId))
val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata)
context.log.debug("sending the payment to non-trampoline recipient (MPP={})", recipient.features.hasFeature(Features.BasicMultiPartPayment))
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment))
case None =>
context.log.debug("sending the payment to the next trampoline node")
val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = packetOut_opt)
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true)
}
case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths =>
context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths)
waitForResolvedPaths(upstream, payloadOut, paymentCfg, routeParams)
}
}
private def relayToRecipient(upstream: Upstream.Hot.Trampoline,
payloadOut: IntermediatePayload.NodeRelay,
recipient: Recipient,
paymentCfg: SendPaymentConfig,
routeParams: RouteParams,
useMultiPart: Boolean): Behavior[Command] = {
val payment =
if (useMultiPart) {
SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
} else {
SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams)
}
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart)
payFSM ! payment
sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false)
}
/**
* Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to resolve
* that to a nodeId in order to reach that introduction node and use the blinded path.
*/
private def waitForResolvedPaths(upstream: Upstream.Hot.Trampoline,
payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths,
paymentCfg: SendPaymentConfig,
routeParams: RouteParams): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedResolvedPaths(resolved) if resolved.isEmpty =>
context.log.warn(s"rejecting trampoline payment to blinded paths: no usable blinded path")
rejectPayment(upstream, Some(UnknownNextPeer()))
stopping()
case WrappedResolvedPaths(resolved) =>
val features = Features(payloadOut.invoiceFeatures).invoiceFeatures()
// We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient.
val blindedNodeId = resolved.head.route.blindedNodeIds.last
val recipient = BlindedRecipient.fromPaths(blindedNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment))
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment))
}
private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = {
case Relay(nodeRelayPacket, _) =>
rejectExtraHtlc(nodeRelayPacket.add)

View File

@ -16,7 +16,6 @@
package fr.acinq.eclair.payment.relay
import akka.actor.typed
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32
@ -58,7 +57,7 @@ object NodeRelayer {
* NB: the payment secret used here is different from the invoice's payment secret and ensures we can
* group together HTLCs that the previous trampoline node sent in the same MPP.
*/
def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
Behaviors.receiveMessage {
@ -73,15 +72,15 @@ object NodeRelayer {
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString)
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, router), relayId.toString)
context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket, originNode)
apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler))
apply(nodeParams, register, outgoingPaymentFactory, router, children + (childKey -> handler))
}
case RelayComplete(childHandler, paymentHash, paymentSecret) =>
// we do a back-and-forth between parent and child before stopping the child to prevent a race condition
childHandler ! NodeRelay.Stop
apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children - PaymentKey(paymentHash, paymentSecret))
apply(nodeParams, register, outgoingPaymentFactory, router, children - PaymentKey(paymentHash, paymentSecret))
case GetPendingPayments(replyTo) =>
replyTo ! children
Behaviors.same

View File

@ -49,7 +49,7 @@ import scala.util.Random
* It also receives channel HTLC events (fulfill / failed) and relays those to the appropriate handlers.
* It also maintains an up-to-date view of local channel balances.
*/
class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging {
class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging {
import Relayer._
@ -58,7 +58,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym
private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner")
private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), triggerer, router)).onFailure(SupervisorStrategy.resume), name = "node-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), router)).onFailure(SupervisorStrategy.resume), name = "node-relayer")
def receive: Receive = {
case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init
@ -120,8 +120,8 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym
object Relayer extends Logging {
def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None): Props =
Props(new Relayer(nodeParams, router, register, paymentHandler, triggerer, initialized))
def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None): Props =
Props(new Relayer(nodeParams, router, register, paymentHandler, initialized))
// @formatter:off
case class RelayFees(feeBase: MilliSatoshi, feeProportionalMillionths: Long) {

View File

@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData
import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs}
import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams}
import fr.acinq.eclair.{EncodedNodeId, Logs, MilliSatoshiLong, NodeParams, ShortChannelId}
import scodec.bits.ByteVector
import scala.annotation.tailrec
@ -45,8 +45,8 @@ object BlindedPathsResolver {
override val firstNodeId: PublicKey = introductionNodeId
}
/** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */
case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute {
override val firstNodeId: PublicKey = nextNodeId
case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute {
override val firstNodeId: PublicKey = nextNodeId.publicKey
}
// @formatter:on
@ -111,8 +111,14 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
feeProportionalMillionths = nextFeeProportionalMillionths,
cltvExpiryDelta = nextCltvExpiryDelta
)
register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId)
waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
paymentRelayData.outgoing match {
case Left(outgoingNodeId) =>
// The next node seems to be a wallet node directly connected to us.
validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
case Right(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId)
waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
}
}
}
case encodedNodeId: EncodedNodeId.WithPublicKey =>
@ -129,7 +135,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
}
/** Resolve the next node in the blinded path when we are the introduction node. */
private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo,
private def waitForNextNodeId(outgoingChannelId: ShortChannelId,
nextPaymentInfo: OfferTypes.PaymentInfo,
paymentRelayData: BlindedRouteData.PaymentRelayData,
nextBlinding: PublicKey,
nextBlindedNodes: Seq[RouteBlinding.BlindedNode],
@ -137,29 +144,42 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
resolved: Seq[ResolvedPath]): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedNodeId(None) =>
context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId)
context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId)
resolveBlindedPaths(toResolve, resolved)
case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId =>
// The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice.
context.log.warn("ignoring blinded path starting at our node relaying to ourselves")
resolveBlindedPaths(toResolve, resolved)
case WrappedNodeId(Some(nodeId)) =>
// Note that we default to private fees if we don't have a channel yet with that node.
// The announceChannel parameter is ignored if we already have a channel.
val relayFees = getRelayFees(nodeParams, nodeId, announceChannel = false)
val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase &&
paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths &&
paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta
if (shouldRelay) {
context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId)
val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo)
resolveBlindedPaths(toResolve, resolved :+ path)
} else {
context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta)
resolveBlindedPaths(toResolve, resolved)
}
validateRelay(EncodedNodeId.WithPublicKey.Plain(nodeId), nextPaymentInfo, paymentRelayData, nextBlinding, nextBlindedNodes, toResolve, resolved)
}
private def validateRelay(nextNodeId: EncodedNodeId.WithPublicKey,
nextPaymentInfo: OfferTypes.PaymentInfo,
paymentRelayData: BlindedRouteData.PaymentRelayData,
nextBlinding: PublicKey,
nextBlindedNodes: Seq[RouteBlinding.BlindedNode],
toResolve: Seq[PaymentBlindedRoute],
resolved: Seq[ResolvedPath]): Behavior[Command] = {
// Note that we default to private fees if we don't have a channel yet with that node.
// The announceChannel parameter is ignored if we already have a channel.
val relayFees = getRelayFees(nodeParams, nextNodeId.publicKey, announceChannel = false)
val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase &&
paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths &&
paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta &&
nextPaymentInfo.feeBase >= 0.msat &&
nextPaymentInfo.feeProportionalMillionths >= 0 &&
nextPaymentInfo.cltvExpiryDelta.toInt >= 0
if (shouldRelay) {
context.log.debug("unwrapped blinded path starting at our node: next_node={}", nextNodeId.publicKey)
val path = ResolvedPath(PartialBlindedRoute(nextNodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo)
resolveBlindedPaths(toResolve, resolved :+ path)
} else {
context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta)
resolveBlindedPaths(toResolve, resolved)
}
}
/** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */
private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] =
Behaviors.receiveMessagePartial {

View File

@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.router.Router.ChannelHop
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, randomKey}
import scodec.bits.ByteVector
object BlindedRouteCreation {
@ -77,7 +77,7 @@ object BlindedRouteCreation {
Total: 24 to 36 bytes
*/
val targetLength = 36
val paddedPayloads = payloads.map(tlvs =>{
val paddedPayloads = payloads.map(tlvs => {
val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length
tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0)))
})
@ -95,4 +95,19 @@ object BlindedRouteCreation {
Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload))
}
/** Create a blinded route where the recipient is a wallet node. */
def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta
val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount),
RouteBlindingEncryptedDataTlv.PathId(pathId),
)).require.bytes
val intermediatePayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(hop.nextNodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(hop.cltvExpiryDelta, hop.params.relayFees.feeProportionalMillionths, hop.params.relayFees.feeBase),
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount),
)).require.bytes
Sphinx.RouteBlinding.create(randomKey(), Seq(hop.nodeId, hop.nextNodeId), Seq(intermediatePayload, finalPayload))
}
}

View File

@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs._
import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64}
import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64}
import scodec.bits.{BitVector, ByteVector}
/**
@ -227,7 +227,8 @@ object PaymentOnion {
object IntermediatePayload {
sealed trait ChannelRelay extends IntermediatePayload {
// @formatter:off
def outgoingChannelId: ShortChannelId
/** The outgoing channel, or the nodeId of one of our peers. */
def outgoing: Either[PublicKey, ShortChannelId]
def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi
def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry
// @formatter:on
@ -238,7 +239,7 @@ object PaymentOnion {
// @formatter:off
val amountOut = records.get[AmountToForward].get.amount
val cltvOut = records.get[OutgoingCltv].get.cltv
override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId
override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId)
override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut
override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut
// @formatter:on
@ -258,12 +259,12 @@ object PaymentOnion {
}
/**
* @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv.
* @param nextBlinding blinding point that must be forwarded to the next hop.
* @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv.
* @param nextBlinding blinding point that must be forwarded to the next hop.
*/
case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay {
// @formatter:off
override val outgoingChannelId = paymentRelayData.outgoingChannelId
override val outgoing = paymentRelayData.outgoing
override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount)
override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv)
// @formatter:on

View File

@ -98,7 +98,11 @@ object BlindedRouteData {
}
case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) {
val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId
// This is usually a channel, unless the next node is a mobile wallet connected to our node.
val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match {
case Some(r) => Right(r.shortChannelId)
case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey)
}
val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get
val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get
val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty)
@ -110,7 +114,9 @@ object BlindedRouteData {
}
def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = {
if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
// Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id.
if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4)))
if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10)))
if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12)))
if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6)))

View File

@ -26,7 +26,7 @@ import fr.acinq.eclair.channel.{ChannelFlags, LocalParams, Origin, Upstream}
import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager}
import fr.acinq.eclair.db.RevokedHtlcInfoCleaner
import fr.acinq.eclair.io.MessageRelay.RelayAll
import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection}
import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection, PeerReadyNotifier}
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios}
@ -231,7 +231,8 @@ object TestConstants {
maxAttempts = 2,
),
purgeInvoicesInterval = None,
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis)
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
)
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(
@ -401,7 +402,8 @@ object TestConstants {
maxAttempts = 2,
),
purgeInvoicesInterval = None,
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis)
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
)
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(

View File

@ -66,8 +66,8 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe
val bobRegister = system.actorOf(Props(new TestRegister()))
val alicePaymentHandler = system.actorOf(Props(new PaymentHandler(aliceParams, aliceRegister, TestProbe().ref)))
val bobPaymentHandler = system.actorOf(Props(new PaymentHandler(bobParams, bobRegister, TestProbe().ref)))
val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler, TestProbe().ref))
val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler, TestProbe().ref))
val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler))
val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler))
val wallet = new DummyOnChainWallet()
val alice: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(aliceParams, wallet, bobParams.nodeId, alice2blockchain.ref, aliceRelayer, FakeTxPublisherFactory(alice2blockchain)), alicePeer.ref)
val bob: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(bobParams, wallet, aliceParams.nodeId, bob2blockchain.ref, bobRelayer, FakeTxPublisherFactory(bob2blockchain)), bobPeer.ref)

View File

@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding
val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol)
assert(payloadBob.outgoingChannelId == ShortChannelId(1))
assert(payloadBob.outgoing.contains(ShortChannelId(1)))
assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat)
assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100))
assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat))
@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding
val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave)
assert(payloadCarol.outgoingChannelId == ShortChannelId(2))
assert(payloadCarol.outgoing.contains(ShortChannelId(2)))
assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat)
assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025))
assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat))
@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data)
val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding
val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve)
assert(payloadDave.outgoingChannelId == ShortChannelId(3))
assert(payloadDave.outgoing.contains(ShortChannelId(3)))
assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat)
assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000))
assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat))

View File

@ -90,13 +90,12 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat
val bitcoinClient = new TestBitcoinCoreClient()
val wallet = new SingleKeyOnChainWallet()
val watcher = TestProbe("watcher")
val triggerer = TestProbe("payment-triggerer")
val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command]
val register = system.actorOf(Register.props(), "register")
val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router")
val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager")
val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler")
val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler, triggerer.ref.toTyped), "relayer")
val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler), "relayer")
val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient)
val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory)
val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume))

View File

@ -19,8 +19,10 @@ package fr.acinq.eclair.io
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe}
import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps}
import akka.testkit.TestProbe
import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.TestConstants.{Alice, Bob}
@ -33,8 +35,8 @@ import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient}
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey}
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.HexStringSyntax
import scala.concurrent.duration.DurationInt
@ -43,19 +45,30 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
val aliceId: PublicKey = Alice.nodeParams.nodeId
val bobId: PublicKey = Bob.nodeParams.nodeId
case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status])
val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], peerReadyManager: TestProbe, probe: TypedProbe[Status])
override def withFixture(test: OneArgTest): Outcome = {
val peerReadyManager = TestProbe("peer-ready-manager")(system.classicSystem)
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped)
val switchboard = TestProbe("switchboard")(system.classicSystem)
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped)
val register = TestProbe("register")(system.classicSystem)
val router = TypedProbe[Router.GetNodeId]("router")
val peerConnection = TypedProbe[Nothing]("peerConnection")
val peer = TypedProbe[Peer.RelayOnionMessage]("peer")
val probe = TypedProbe[Status]("probe")
val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref))
val nodeParams = Alice.nodeParams
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref))
try {
withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe)))
withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, peerReadyManager, probe)))
} finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped)
testKit.stop(relay)
}
}
@ -86,6 +99,23 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message)
}
test("relay after waking up next node", Tag(wakeUpEnabled)) { f =>
import f._
val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty)
val messageId = randomBytes32()
relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None)
val register = peerReadyManager.expectMsgType[PeerReadyManager.Register]
assert(register.remoteNodeId == bobId)
register.replyTo ! PeerReadyManager.Registered(bobId, otherAttempts = 0)
val request = switchboard.expectMsgType[GetPeerInfo]
assert(request.remoteNodeId == bobId)
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty)
assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message)
}
test("can't open new connection") { f =>
import f._
@ -99,6 +129,15 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound))
}
test("can't wake up next node", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty)
val messageId = randomBytes32()
relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, Some(probe.ref))
probe.expectMessage(Disconnected(messageId))
}
test("no channel with previous node") { f =>
import f._

View File

@ -0,0 +1,55 @@
/*
* Copyright 2024 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package fr.acinq.eclair.io
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.randomKey
import org.scalatest.funsuite.AnyFunSuiteLike
class PeerReadyManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with AnyFunSuiteLike {
test("watch pending notifiers") {
val manager = testKit.spawn(PeerReadyManager())
val remoteNodeId1 = randomKey().publicKey
val notifier1a = TestProbe[PeerReadyManager.Registered]()
val notifier1b = TestProbe[PeerReadyManager.Registered]()
manager ! PeerReadyManager.Register(notifier1a.ref, remoteNodeId1)
assert(notifier1a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
manager ! PeerReadyManager.Register(notifier1b.ref, remoteNodeId1)
assert(notifier1b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 1)
val remoteNodeId2 = randomKey().publicKey
val notifier2a = TestProbe[PeerReadyManager.Registered]()
val notifier2b = TestProbe[PeerReadyManager.Registered]()
// Later attempts aren't affected by previously completed attempts.
manager ! PeerReadyManager.Register(notifier2a.ref, remoteNodeId2)
assert(notifier2a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
notifier2a.stop()
val probe = TestProbe[Set[PublicKey]]()
probe.awaitAssert({
manager ! PeerReadyManager.List(probe.ref)
assert(probe.expectMessageType[Set[PublicKey]] == Set(remoteNodeId1))
})
manager ! PeerReadyManager.Register(notifier2b.ref, remoteNodeId2)
assert(notifier2b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0)
}
}

View File

@ -33,17 +33,20 @@ import scala.concurrent.duration.DurationInt
class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {
case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result])
case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result])
override def withFixture(test: OneArgTest): Outcome = {
val remoteNodeId = randomKey().publicKey
val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager")
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard")
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val peer = TestProbe[Peer.GetPeerChannels]("peer")
val probe = TestProbe[PeerReadyNotifier.Result]()
try {
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe)))
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe)))
} finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
}
@ -53,7 +56,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis))))
notifier ! NotifyWhenPeerReady(probe.ref)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
probe.expectMessage(PeerUnavailable(remoteNodeId))
}
@ -62,6 +65,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId)
// We haven't reached the timeout yet.
@ -78,6 +82,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set.empty)
probe.expectMessage(PeerReadyNotifier.PeerReady(remoteNodeId, peer.ref.toClassic, Seq.empty))
@ -88,6 +93,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic))
@ -115,6 +121,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic))
peer.expectNoMessage(100 millis)
@ -137,6 +144,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = None))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerNotFound(remoteNodeId)
peer.expectNoMessage(100 millis)
@ -161,6 +169,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 5)
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set.empty)
peer.expectNoMessage(100 millis)
@ -185,6 +194,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(1 second))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic))
peer.expectMessageType[Peer.GetPeerChannels]
@ -196,6 +206,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load
val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100)))))
notifier ! NotifyWhenPeerReady(probe.ref)
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 2)
val request = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic))
peer.expectMessageType[Peer.GetPeerChannels]

View File

@ -85,7 +85,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward == amount_bc)
assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId)
assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
@ -95,7 +95,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_c.amountToForward == amount_cd)
assert(relay_c.outgoingCltv == expiry_cd)
assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId)
assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId))
assert(relay_c.relayFeeMsat == fee_c)
assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta)
@ -105,7 +105,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_d.amountToForward == amount_de)
assert(relay_d.outgoingCltv == expiry_de)
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId)
assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.relayFeeMsat == fee_d)
assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta)
@ -175,7 +175,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward >= amount_bc)
assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId)
assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard])
@ -185,7 +185,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_c.amountToForward >= amount_cd)
assert(relay_c.outgoingCltv == expiry_cd)
assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId)
assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId))
assert(relay_c.relayFeeMsat == fee_c)
assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta)
assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
@ -196,7 +196,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_d.amountToForward >= amount_de)
assert(relay_d.outgoingCltv == expiry_de)
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId)
assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.relayFeeMsat == fee_d)
assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta)
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
@ -238,7 +238,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength)
assert(relay_b.amountToForward >= amount_bc)
assert(relay_b.outgoingCltv == expiry_bc)
assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId)
assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId))
assert(relay_b.relayFeeMsat == fee_b)
assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta)
assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard])
@ -547,7 +547,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
// A smaller amount is sent to d, who doesn't know that it's invalid.
val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0)
val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional))
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId)
assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.amountToForward < amount_de)
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding
@ -569,7 +569,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1)
val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0)
val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional))
assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId)
assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId))
assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount))
assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded])
val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding

View File

@ -57,7 +57,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
case class FixtureParam(nodeParams: NodeParams, register: TestProbe, sender: TestProbe, eventListener: TestProbe) {
def createRelayer(nodeParams1: NodeParams): (ActorRef, ActorRef) = {
val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref.toTyped))
val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref))
// we need ensure the post-htlc-restart child actor is initialized
sender.send(relayer, Relayer.GetChildActors(sender.ref))
(relayer, sender.expectMsgType[Relayer.ChildActors].postRestartCleaner)

View File

@ -1,17 +1,18 @@
package fr.acinq.eclair.payment.relay
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.{ActorRef, Behavior}
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.channel.NEGOTIATING
import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard}
import fr.acinq.eclair.io.{Peer, PeerConnected, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._
import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey}
import org.scalatest.Outcome
@ -23,8 +24,20 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command])
object DummyPeerReadyManager {
def apply(): Behavior[PeerReadyManager.Command] = {
Behaviors.receiveMessagePartial {
case PeerReadyManager.Register(replyTo, remoteNodeId) =>
replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0)
Behaviors.same
}
}
}
override def withFixture(test: OneArgTest): Outcome = {
val remoteNodeId = TestConstants.Alice.nodeParams.nodeId
val peerReadyManager = testKit.spawn(DummyPeerReadyManager())
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager)
val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard")
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val peer = TestProbe[Peer.GetPeerChannels]("peer")
@ -33,6 +46,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
try {
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer)))
} finally {
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
}
@ -170,7 +184,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
val probe2 = TestProbe[Result]()
triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101))
val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic))
request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId2, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic))
// First remote node times out
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100)))
@ -192,6 +206,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
test("triggerer treats an unexpected stop of the notifier as a cancel") { f =>
import f._
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId)

View File

@ -19,6 +19,7 @@ package fr.acinq.eclair.payment.relay
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory
@ -29,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket
import fr.acinq.eclair.blockchain.fee.FeeratePerKw
import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket
import fr.acinq.eclair.payment.relay.ChannelRelayer._
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec}
@ -39,19 +41,26 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload.ChannelRel
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _}
import org.scalatest.Inside.inside
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.HexStringSyntax
import scala.concurrent.duration.DurationInt
class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {
import ChannelRelayerSpec._
val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any])
override def withFixture(test: OneArgTest): Outcome = {
// we are node B in the route A -> B -> C -> ....
val nodeParams = TestConstants.Bob.nodeParams
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val register = TestProbe[Any]("register")
val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic))
try {
@ -157,7 +166,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
import f._
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
val payload = createBlindedPayload(u.channelUpdate, isIntroduction = false)
val payload = createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction = false)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
@ -166,6 +175,34 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7)
}
test("relay blinded payment (wake up wallet node)", Tag(wakeUpEnabled)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
Seq(true, false).foreach(isIntroduction => {
val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
// We try to wake-up the next node.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo]
assert(wakeUp.remoteNodeId == outgoingNodeId)
wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty)
expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7)
})
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
test("relay with retries") { f =>
import f._
@ -270,7 +307,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
Seq(true, false).foreach { isIntroduction =>
// The outgoing channel is disabled, so we won't be able to relay the payment.
val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false)
val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
@ -293,6 +330,31 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
}
}
test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0)
val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = true)
val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
// We try to wake-up the next node, but we timeout before they connect.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId)
val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket))))
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
}
test("relay when expiry larger than our requirements") { f =>
import f._
@ -519,7 +581,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a
Seq(true, false).foreach { isIntroduction =>
testCases.foreach { htlcResult =>
val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0)
val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0)
channelRelayer ! WrappedLocalChannelUpdate(u)
channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId)
val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry, 0)
@ -653,13 +715,16 @@ object ChannelRelayerSpec {
localAlias2 -> channelId2,
)
def createBlindedPayload(update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = {
def createBlindedPayload(outgoing: Either[PublicKey, ShortChannelId], update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = {
val tlvs = TlvStream[OnionPaymentPayloadTlv](Set(
Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")),
if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None,
).flatten[OnionPaymentPayloadTlv])
val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingChannelId(update.shortChannelId),
outgoing match {
case Left(nodeId) => RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeId))
case Right(scid) => RouteBlindingEncryptedDataTlv.OutgoingChannelId(scid)
},
RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat),
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat),
)

View File

@ -20,35 +20,37 @@ import akka.actor.Status
import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe}
import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.receptionist.Receptionist
import akka.actor.typed.scaladsl.ActorContext
import akka.actor.typed.scaladsl.adapter._
import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto}
import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto}
import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir
import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard}
import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop
import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket}
import fr.acinq.eclair.payment.Invoice.ExtraEdge
import fr.acinq.eclair.payment.OutgoingPaymentPacket.NodePayload
import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.{AsyncPaymentCanceled, AsyncPaymentTimeout, AsyncPaymentTriggered, Watch}
import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment}
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode
import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient}
import fr.acinq.eclair.router.Router.RouteRequest
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router}
import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, RouteRequest}
import fr.acinq.eclair.router.{BalanceTooLow, BlindedRouteCreation, RouteNotFound, Router}
import fr.acinq.eclair.wire.protocol.OfferTypes._
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload}
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey}
import fr.acinq.eclair.{Alias, BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes32, randomKey}
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.{ByteVector, HexStringSyntax}
@ -65,11 +67,14 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
import NodeRelayerSpec._
case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent], triggerer: TestProbe[AsyncPaymentTriggerer.Command]) {
val wakeUpEnabled = "wake_up_enabled"
val wakeUpTimeout = "wake_up_timeout"
case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) {
def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = {
val parent = TestProbe[NodeRelayer.Command]("parent-relayer")
val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this)
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic))
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, router.ref.toClassic))
(nodeRelay, parent)
}
}
@ -92,21 +97,21 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
override def withFixture(test: OneArgTest): Outcome = {
val nodeParams = TestConstants.Bob.nodeParams
.modify(_.multiPartPaymentExpiry).setTo(5 seconds)
.modify(_.features).setToIf(test.tags.contains("async_payments"))(Features(AsyncPaymentPrototype -> Optional))
.modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires
.modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true)
.modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis)
val router = TestProbe[Any]("router")
val register = TestProbe[Any]("register")
val eventListener = TestProbe[PaymentEvent]("event-listener")
system.eventStream ! EventStream.Subscribe(eventListener.ref)
val mockPayFSM = TestProbe[Any]("pay-fsm")
val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer")
withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener, triggerer)))
withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener)))
}
test("create child handlers for new payments") { f =>
import f._
val probe = TestProbe[Any]()
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), triggerer.ref, router.ref.toClassic))
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), router.ref.toClassic))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(Map.empty)
@ -145,7 +150,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f)
{
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic))
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(Map.empty)
}
@ -153,7 +158,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val (paymentHash1, paymentSecret1, child1) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]())
val (paymentHash2, paymentSecret2, child2) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]())
val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref)
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children))
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(children)
@ -169,7 +174,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val (paymentSecret1, child1) = (randomBytes32(), TestProbe[NodeRelay.Command]())
val (paymentSecret2, child2) = (randomBytes32(), TestProbe[NodeRelay.Command]())
val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref)
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children))
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children))
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
probe.expectMessage(children)
@ -179,7 +184,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref))
}
{
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic))
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic))
parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey)
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]]
@ -228,7 +233,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket)
createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey)
// the extra payment will be rejected
@ -257,7 +262,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket)
createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey)
val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
@ -270,7 +275,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0),
PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None),
IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId),
nextTrampolinePacket)
createTrampolinePacket(outgoingAmount, outgoingExpiry))
nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey)
val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
@ -335,115 +340,6 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
register.expectNoMessage(100 millis)
}
test("fail to relay when not triggered before the hold timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish notification that peer is unavailable at the timeout height
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) < asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncTimeoutHeight(nodeParams))
peerWatch.replyTo ! AsyncPaymentTimeout
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment when triggered while waiting", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish notification that peer is ready before the safety interval before the current incoming payment expires (and before the timeout height)
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams))
peerWatch.replyTo ! AsyncPaymentTriggered
// upstream payment relayed
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
validateOutgoingPayment(outgoingPayment)
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
// A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream.
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
// Once all the downstream payments have settled, we should emit the relayed event.
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)).toSet == incomingAsyncPayment.map(i => (i.add.amountMsat, i.add.channelId)).toSet)
assert(relayEvent.outgoing.nonEmpty)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay when not triggered before the incoming expiry safety timeout", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment
// publish notification that peer is unavailable at the cancel-safety-before-timeout-block threshold before the current incoming payment expires (and before the timeout height)
val peerWatch = triggerer.expectMessageType[Watch]
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams))
peerWatch.replyTo ! AsyncPaymentTimeout
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("fail to relay payment when canceled by sender before timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// fail the payment if waiting when payment sender sends cancel message
nodeRelayer ! NodeRelay.WrappedPeerReadyResult(AsyncPaymentCanceled)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment immediately when the async payment feature is disabled") { f =>
import f._
@ -827,26 +723,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
}
}
def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = {
val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes
PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty))
}
test("relay to blinded paths without multi-part") { f =>
import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(createPaymentBlindedRoute(outgoingNodeId)))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, None)
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode]
assert(outgoingPayment.amount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -856,7 +741,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p =>
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -865,7 +750,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
@ -874,18 +759,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
test("relay to blinded paths with multi-part") { f =>
import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), payerKey, chain)
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), Seq(createPaymentBlindedRoute(outgoingNodeId)))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), None)
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
assert(outgoingPayment.recipient.totalAmount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -895,7 +774,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p =>
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -904,25 +783,89 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("relay to blinded path with wake-up", Tag(wakeUpEnabled)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams)
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
// The remote node is a wallet node: we try to wake them up before relaying the payment.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo]
assert(wakeUp.remoteNodeId == outgoingNodeId)
wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty)
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
assert(outgoingPayment.recipient.totalAmount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry)
assert(outgoingPayment.recipient.isInstanceOf[BlindedRecipient])
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay to blinded path when wake-up fails", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f =>
import f._
val peerReadyManager = TestProbe[PeerReadyManager.Register]()
system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
val switchboard = TestProbe[Switchboard.GetPeerInfo]()
system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref)
val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams)
val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
// The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out.
peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0)
assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId)
system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref)
system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref)
mockPayFSM.expectNoMessage(100 millis)
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true))
}
}
test("relay to compact blinded paths") { f =>
import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId)
val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L))
val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir))
val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
@ -932,7 +875,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
getNodeId.replyTo ! Some(outgoingNodeId)
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true)
val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode]
assert(outgoingPayment.amount == outgoingAmount)
assert(outgoingPayment.recipient.expiry == outgoingExpiry)
@ -942,7 +885,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val nodeRelayerAdapters = outgoingPayment.replyTo
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingMultiPart.foreach { p =>
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
@ -951,7 +894,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId)))
assert(relayEvent.outgoing.length == 1)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
@ -960,16 +903,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
test("fail to relay to compact blinded paths with unknown scid") { f =>
import f._
val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32()))
val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain)
val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId)
val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L))
val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute))
val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths(
incoming.innerPayload.amountToForward, outgoingExpiry, invoice
)))
val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir))
val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head)
incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey))
@ -980,7 +915,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
mockPayFSM.expectNoMessage(100 millis)
incomingMultiPart.foreach { p =>
incomingPayments.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true))
@ -1008,7 +943,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient])
val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient]
assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret
assert(recipient.nextTrampolineOnion_opt.contains(nextTrampolinePacket))
assert(recipient.nextTrampolineOnion_opt.nonEmpty)
// The recipient is able to decrypt the trampoline onion.
recipient.nextTrampolineOnion_opt.foreach(onion => assert(IncomingPaymentPacket.decryptOnion(paymentHash, outgoingNodeKey, onion).isRight))
}
def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = {
@ -1025,10 +962,7 @@ object NodeRelayerSpec {
val paymentPreimage = randomBytes32()
val paymentHash = Crypto.sha256(paymentPreimage)
// This is the result of decrypting the incoming trampoline onion packet.
// It should be forwarded to the next trampoline node.
val nextTrampolinePacket = OnionRoutingPacket(0, hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619", randomBytes(400), randomBytes32())
val paymentSecret = randomBytes32()
val outgoingAmount = 40_000_000 msat
val outgoingExpiry = CltvExpiry(490000)
@ -1054,6 +988,12 @@ object NodeRelayerSpec {
def createSuccessEvent(): PaymentSent =
PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None)))
def createTrampolinePacket(amount: MilliSatoshi, expiry: CltvExpiry): OnionRoutingPacket = {
val payload = NodePayload(outgoingNodeId, FinalPayload.Standard.createPayload(amount, amount, expiry, paymentSecret))
val Right(onion) = OutgoingPaymentPacket.buildOnion(Seq(payload), paymentHash, None)
onion.packet
}
def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry, endorsementIn: Int = 7): RelayToTrampolinePacket = {
val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None)
val tlvs = TlvStream[UpdateAddHtlcTlv](UpdateAddHtlcTlv.Endorsement(endorsementIn))
@ -1061,7 +1001,7 @@ object NodeRelayerSpec {
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, tlvs),
outerPayload,
IntermediatePayload.NodeRelay.Standard(amountOut, expiryOut, outgoingNodeId),
nextTrampolinePacket)
createTrampolinePacket(amountOut, expiryOut))
}
def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): RelayToTrampolinePacket = {
@ -1071,7 +1011,46 @@ object NodeRelayerSpec {
UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None, 1.0),
FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None),
IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId),
nextTrampolinePacket)
createTrampolinePacket(outgoingAmount, expiryOut))
}
def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = {
val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes
PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty))
}
/** Create payments to a blinded path that starts at a remote node. */
def createIncomingPaymentsToRemoteBlindedPath(features: Features[Bolt12Feature], scidDir_opt: Option[EncodedNodeId.ShortChannelIdDir]): Seq[RelayToBlindedPathsPacket] = {
val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash)
val request = InvoiceRequest(offer, outgoingAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash)
val paymentBlindedRoute = scidDir_opt match {
case Some(scidDir) =>
val nonCompact = createPaymentBlindedRoute(outgoingNodeId)
nonCompact.copy(route = nonCompact.route.copy(introductionNodeId = scidDir))
case None =>
createPaymentBlindedRoute(outgoingNodeId)
}
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(paymentBlindedRoute))
incomingMultiPart.map(incoming => {
val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice)
RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload)
})
}
/** Create payments to a blinded path that starts at our node and relays to a wallet node. */
def createIncomingPaymentsToWalletBlindedPath(nodeParams: NodeParams): Seq[RelayToBlindedPathsPacket] = {
val features: Features[Bolt12Feature] = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional)
val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash)
val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), randomKey(), Block.RegtestGenesisBlock.hash)
val edge = ExtraEdge(nodeParams.nodeId, outgoingNodeId, Alias(561), 2_000_000 msat, 250, CltvExpiryDelta(144), 1 msat, None)
val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, outgoingNodeId, HopRelayParams.FromHint(edge))
val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, outgoingExpiry).route
val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(outgoingAmount, Seq(hop), CltvExpiryDelta(12))
val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(PaymentBlindedRoute(route, paymentInfo)))
incomingMultiPart.map(incoming => {
val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice)
RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload)
})
}
}

View File

@ -55,11 +55,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat
val router = TestProbe[Any]("router")
val register = TestProbe[Any]("register")
val paymentHandler = TestProbe[Any]("payment-handler")
val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer")
val probe = TestProbe[Any]()
// we can't spawn top-level actors with akka typed
testKit.spawn(Behaviors.setup[Any] { context =>
val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic, triggerer.ref))
val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic))
probe.ref ! relayer
Behaviors.empty[Any]
})

View File

@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.send.BlindedPathsResolver.{FullBlindedRoute, Part
import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams}
import fr.acinq.eclair.router.{BlindedRouteCreation, Router}
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey}
import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey}
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import scodec.bits.HexStringSyntax
@ -151,6 +151,31 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
}
}
test("resolve route starting at our node (wallet node)") { f =>
import f._
val probe = TestProbe()
val walletNodeId = randomKey().publicKey
val edge = ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(561), 5_000_000 msat, 200, CltvExpiryDelta(144), 1 msat, None)
val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(edge))
val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route
val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, Seq(hop), CltvExpiryDelta(12))
val resolver = testKit.spawn(BlindedPathsResolver(nodeParams, randomBytes32(), router.ref, register.ref))
resolver ! Resolve(probe.ref, Seq(PaymentBlindedRoute(route, paymentInfo)))
// We are the introduction node: we decrypt the payload and discover that the next node is a wallet node.
val resolved = probe.expectMsgType[Seq[ResolvedPath]]
assert(resolved.size == 1)
assert(resolved.head.route.isInstanceOf[PartialBlindedRoute])
val partialRoute = resolved.head.route.asInstanceOf[PartialBlindedRoute]
assert(partialRoute.firstNodeId == walletNodeId)
assert(partialRoute.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(walletNodeId))
assert(partialRoute.blindedNodes == route.subsequentNodes)
assert(partialRoute.nextBlinding != route.blindingKey)
// We don't need to resolve the nodeId.
register.expectNoMessage(100 millis)
router.expectNoMessage(100 millis)
}
test("ignore blinded paths that cannot be resolved") { f =>
import f._
@ -181,8 +206,9 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
val probe = TestProbe()
val scid = RealShortChannelId(BlockHeight(750_000), 3, 7)
val edgeLowFees = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None)
val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None)
val nextNodeId = randomKey().publicKey
val edgeLowFees = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None)
val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None)
val toResolve = Seq(
// We don't allow paying blinded routes to ourselves.
BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
@ -190,6 +216,8 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes with low cltv_expiry_delta.
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes with low fees, even when the next node seems to be a wallet node.
BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route,
// We reject blinded routes that cannot be decrypted.
BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey)
).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty)))

View File

@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite {
val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded)
assert(payload.amountOut == 561.msat)
assert(payload.cltvOut == CltvExpiry(42))
assert(payload.outgoingChannelId == ShortChannelId(1105))
assert(payload.outgoing.contains(ShortChannelId(1105)))
val encoded = perHopPayloadCodec.encode(expected).require.bytes
assert(encoded == bin)
}
@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite {
val decoded = perHopPayloadCodec.decode(bin.bits).require.value
assert(decoded == expected)
val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey)
assert(payload.outgoingChannelId == ShortChannelId(42))
assert(payload.outgoing.contains(ShortChannelId(42)))
assert(payload.amountToForward(10_000 msat) == 9990.msat)
assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856))
assert(payload.paymentRelayData.allowedFeatures.isEmpty)
@ -119,6 +119,20 @@ class PaymentOnionSpec extends AnyFunSuite {
}
}
test("encode/decode channel relay blinded per-hop-payload (with wallet node_id)") {
val walletNodeId = PublicKey(hex"0221cd519eba9c8b840a5e40b65dc2c040e159a766979723ed770efceb97260ec8")
val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv](
RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat),
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat),
)
val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey)
assert(payload.outgoing == Left(walletNodeId))
assert(payload.amountToForward(10_000 msat) == 9990.msat)
assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856))
assert(payload.paymentRelayData.allowedFeatures.isEmpty)
}
test("encode/decode node relay per-hop payload") {
val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")
val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId))
@ -292,6 +306,8 @@ class PaymentOnionSpec extends AnyFunSuite {
TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs),
// Missing encrypted outgoing channel.
TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Forbidden encrypted outgoing plain node_id.
TestCase(ForbiddenTlv(UInt64(4)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(randomKey().publicKey)), RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Missing encrypted payment relay data.
TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))),
// Missing encrypted payment constraint.