From c6a76af9d3bfc81315436c9a4c26d9f38e768088 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier <31281497+t-bast@users.noreply.github.com> Date: Wed, 31 Mar 2021 08:58:40 +0200 Subject: [PATCH] Introduce actor factories (#1744) This removes unnecessary fields and allows more flexibility in tests. --- .../main/scala/fr/acinq/eclair/Setup.scala | 14 ++- .../main/scala/fr/acinq/eclair/io/Peer.scala | 23 ++-- .../fr/acinq/eclair/io/Switchboard.scala | 17 ++- .../eclair/payment/relay/NodeRelay.scala | 45 ++++---- .../eclair/payment/relay/NodeRelayer.scala | 3 +- .../send/MultiPartPaymentLifecycle.scala | 8 +- .../payment/send/PaymentInitiator.scala | 39 +++++-- .../scala/fr/acinq/eclair/io/PeerSpec.scala | 100 ++++++++++-------- .../fr/acinq/eclair/io/SwitchboardSpec.scala | 20 ++-- .../MultiPartPaymentLifecycleSpec.scala | 13 +-- .../eclair/payment/PaymentInitiatorSpec.scala | 29 ++--- .../payment/relay/NodeRelayerSpec.scala | 20 ++-- 12 files changed, 195 insertions(+), 136 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 6fe748670..70642371d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -37,7 +37,7 @@ import fr.acinq.eclair.channel.Register import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.db.Databases.FileBackup import fr.acinq.eclair.db.{Databases, DbEventHandler, FileBackupHandler} -import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard} +import fr.acinq.eclair.io.{ClientSpawner, Peer, Server, Switchboard} import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.payment.send.{Autoprobe, PaymentInitiator} @@ -290,8 +290,8 @@ class Setup(datadir: File, new ElectrumEclairWallet(electrumWallet, nodeParams.chainHash) } _ = wallet.getReceiveAddress.map(address => logger.info(s"initial wallet address=$address")) - // do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox + // do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox backupHandler = if (config.getBoolean("enable-db-backup")) { nodeParams.db match { case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props( @@ -314,10 +314,14 @@ class Setup(datadir: File, // Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system, // we want to make sure the handler for post-restart broken HTLCs has finished initializing. _ <- postRestartCleanUpInitialized.future - switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, watcher, relayer, wallet), "switchboard", SupervisorStrategy.Resume)) + + channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, wallet) + peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory) + + switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume)) clientSpawner = system.actorOf(SimpleSupervisor.props(ClientSpawner.props(nodeParams.keyPair, nodeParams.socksProxy_opt, nodeParams.peerConnectionConf, switchboard, router), "client-spawner", SupervisorStrategy.Restart)) server = system.actorOf(SimpleSupervisor.props(Server.props(nodeParams.keyPair, nodeParams.peerConnectionConf, switchboard, router, serverBindingAddress, Some(tcpBound)), "server", SupervisorStrategy.Restart)) - paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, register), "payment-initiator", SupervisorStrategy.Restart)) + paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)), "payment-initiator", SupervisorStrategy.Restart)) _ = for (i <- 0 until config.getInt("autoprobe-count")) yield system.actorOf(SimpleSupervisor.props(Autoprobe.props(nodeParams, router, paymentInitiator), s"payment-autoprobe-$i", SupervisorStrategy.Restart)) kit = Kit( @@ -381,11 +385,11 @@ class Setup(datadir: File, } +// @formatter:off object Setup { final case class Seeds(nodeSeed: ByteVector, channelSeed: ByteVector) } -// @formatter:off sealed trait Bitcoin case class Bitcoind(bitcoinClient: BasicBitcoinJsonRPCClient) extends Bitcoin case class Electrum(electrumClient: ActorRef) extends Bitcoin diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 8393976f4..4a0e87c64 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.io -import akka.actor.{Actor, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated} +import akka.actor.{Actor, ActorContext, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated} import akka.event.Logging.MDC import akka.event.{BusLogging, DiagnosticLoggingAdapter} import akka.util.Timeout @@ -48,7 +48,7 @@ import java.net.InetSocketAddress * * Created by PM on 26/08/2016. */ -class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { +class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { import Peer._ @@ -57,7 +57,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe when(INSTANTIATING) { case Event(Init(storedChannels), _) => val channels = storedChannels.map { state => - val channel = spawnChannel(nodeParams, origin_opt = None) + val channel = spawnChannel(origin_opt = None) channel ! INPUT_RESTORED(state) FinalChannelId(state.channelId) -> channel }.toMap @@ -294,12 +294,12 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe (Helpers.getFinalScriptPubKey(wallet, nodeParams.chainHash), None) } val localParams = makeChannelParams(nodeParams, features, finalScript, walletStaticPaymentBasepoint, funder, fundingAmount) - val channel = spawnChannel(nodeParams, origin_opt) + val channel = spawnChannel(origin_opt) (channel, localParams) } - def spawnChannel(nodeParams: NodeParams, origin_opt: Option[ActorRef]): ActorRef = { - val channel = context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt)) + def spawnChannel(origin_opt: Option[ActorRef]): ActorRef = { + val channel = channelFactory.spawn(context, remoteNodeId, origin_opt) context watch channel channel } @@ -353,7 +353,16 @@ object Peer { val UNKNOWN_CHANNEL_MESSAGE: ByteVector = ByteVector.view("unknown channel".getBytes()) // @formatter:on - def props(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet): Props = Props(new Peer(nodeParams, remoteNodeId, watcher, relayer: ActorRef, wallet)) + trait ChannelFactory { + def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef + } + + case class SimpleChannelFactory(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends ChannelFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef = + context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt)) + } + + def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: ChannelFactory): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory)) // @formatter:off diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index 7d5744e02..d1e24113a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.io -import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy} +import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.NodeParams import fr.acinq.eclair.blockchain.EclairWallet @@ -29,7 +29,7 @@ import fr.acinq.eclair.router.Router.RouterConf * Ties network connections to peers. * Created by PM on 14/02/2017. */ -class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends Actor with ActorLogging { +class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) extends Actor with ActorLogging { import Switchboard._ @@ -103,7 +103,7 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, */ def getPeer(remoteNodeId: PublicKey): Option[ActorRef] = context.child(peerActorName(remoteNodeId)) - def createPeer(remoteNodeId: PublicKey): ActorRef = context.actorOf(Peer.props(nodeParams, remoteNodeId, watcher, relayer, wallet), name = peerActorName(remoteNodeId)) + def createPeer(remoteNodeId: PublicKey): ActorRef = peerFactory.spawn(context, remoteNodeId) def createOrGetPeer(remoteNodeId: PublicKey, offlineChannels: Set[HasCommitments]): ActorRef = { getPeer(remoteNodeId) match { @@ -124,7 +124,16 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, object Switchboard { - def props(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) = Props(new Switchboard(nodeParams, watcher, relayer, wallet)) + trait PeerFactory { + def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef + } + + case class SimplePeerFactory(nodeParams: NodeParams, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends PeerFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = + context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory), name = peerActorName(remoteNodeId)) + } + + def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory)) def peerActorName(remoteNodeId: PublicKey): String = s"peer-$remoteNodeId" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index cc81355c8..572395c47 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -29,11 +29,10 @@ import fr.acinq.eclair.payment.OutgoingPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart -import fr.acinq.eclair.payment.relay.NodeRelay.FsmFactory import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment -import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle} +import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle} import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound} import fr.acinq.eclair.wire.protocol._ @@ -60,29 +59,32 @@ object NodeRelay { private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command // @formatter:on - def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] = - Behaviors.setup { context => - Behaviors.withMdc(Logs.mdc( - category_opt = Some(Logs.LogCategory.PAYMENT), - parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment - paymentHash_opt = Some(paymentHash))) { - new NodeRelay(nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)() - } - } + trait OutgoingPaymentFactory { + def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef + } - /** - * This is supposed to be overridden in tests - */ - class FsmFactory { - def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: ActorRef, register: ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = { + case class SimpleOutgoingPaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends OutgoingPaymentFactory { + val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register) + + override def spawnOutgoingPayFSM(context: ActorContext[Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = { if (multiPart) { - context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, register)) + context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, paymentFactory)) } else { context.toClassic.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register)) } } } + def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] = + Behaviors.setup { context => + Behaviors.withMdc(Logs.mdc( + category_opt = Some(Logs.LogCategory.PAYMENT), + parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment + paymentHash_opt = Some(paymentHash))) { + new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)() + } + } + def validateRelay(nodeParams: NodeParams, upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload): Option[FailureMessage] = { val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, payloadOut.amountToForward) if (upstream.amountIn - payloadOut.amountToForward < fee) { @@ -139,12 +141,11 @@ object NodeRelay { */ class NodeRelay private(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], - router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, context: ActorContext[NodeRelay.Command], - fsmFactory: FsmFactory) { + outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) { import NodeRelay._ @@ -285,20 +286,20 @@ class NodeRelay private(nodeParams: NodeParams, case Some(paymentSecret) if Features(features).hasFeature(Features.BasicMultiPartPayment) => context.log.debug("sending the payment to non-trampoline recipient using MPP") val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams)) - val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true) + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) payFSM ! payment payFSM case _ => context.log.debug("sending the payment to non-trampoline recipient without MPP") val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret) val payment = SendPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams)) - val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = false) + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false) payFSM ! payment payFSM } case None => context.log.debug("sending the payment to the next trampoline node") - val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true) + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) val paymentSecret = randomBytes32 // we generate a new secret to protect against probing attacks val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = Some(routeParams), additionalTlvs = Seq(OnionTlv.TrampolineOnion(packetOut))) payFSM ! payment diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index d7bf544be..737e08f1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -66,7 +66,8 @@ object NodeRelayer { case None => val relayId = UUID.randomUUID() context.log.debug(s"spawning a new handler with relayId=$relayId") - val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, router, register, relayId, paymentHash), relayId.toString) + val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register) + val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString) context.log.debug("forwarding incoming htlc to new handler") handler ! NodeRelay.Relay(nodeRelayPacket) apply(nodeParams, router, register, children + (paymentHash -> handler)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index f669cf206..8d140b3a7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit * Sender for a multi-part payment (see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#basic-multi-part-payments). * The payment will be split into multiple sub-payments that will be sent in parallel. */ -class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] { +class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] { import MultiPartPaymentLifecycle._ @@ -202,13 +202,13 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, case Event(_: Status.Failure, _) => stay } - def spawnChildPaymentFsm(childId: UUID): ActorRef = { + private def spawnChildPaymentFsm(childId: UUID): ActorRef = { val upstream = cfg.upstream match { case Upstream.Local(_) => Upstream.Local(childId) case _ => cfg.upstream } val childCfg = cfg.copy(id = childId, publishEvent = false, upstream = upstream) - context.actorOf(PaymentLifecycle.props(nodeParams, childCfg, router, register)) + paymentFactory.spawnOutgoingPayment(context, childCfg) } private def gotoAbortedOrStop(d: PaymentAborted): State = { @@ -265,7 +265,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, object MultiPartPaymentLifecycle { - def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, register)) + def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, paymentFactory)) /** * Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 061397aad..d34650829 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.payment.send -import akka.actor.{Actor, ActorLogging, ActorRef, Props} +import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.Features.BasicMultiPartPayment @@ -39,7 +39,7 @@ import java.util.UUID /** * Created by PM on 29/08/2016. */ -class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends Actor with ActorLogging { +class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentInitiator.MultiPartPaymentFactory) extends Actor with ActorLogging { import PaymentInitiator._ @@ -57,14 +57,16 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor case Some(invoice) if invoice.features.allowMultiPart && nodeParams.features.hasFeature(BasicMultiPartPayment) => invoice.paymentSecret match { case Some(paymentSecret) => - spawnMultiPartPaymentFsm(paymentCfg) ! SendMultiPartPayment(sender, paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs) + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! SendMultiPartPayment(sender, paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs) case None => sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, PaymentSecretMissing) :: Nil) } case _ => val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret) val finalPayload = Onion.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.userCustomTlvs) - spawnPaymentFsm(paymentCfg) ! SendPayment(sender, r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams) + val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + fsm ! SendPayment(sender, r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams) } case r: SendTrampolinePaymentRequest => @@ -122,7 +124,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor val finalExpiry = r.finalExpiry(nodeParams.currentBlockHeight) val additionalHops = r.trampolineNodes.sliding(2).map(hop => NodeHop(hop.head, hop(1), CltvExpiryDelta(0), 0 msat)).toSeq val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.paymentRequest), storeInDb = true, publishEvent = true, additionalHops) - val payFsm = spawnPaymentFsm(paymentCfg) + val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) r.trampolineNodes match { case trampoline :: recipient :: Nil => log.info(s"sending trampoline payment to $recipient with trampoline=$trampoline, trampoline fees=${r.trampolineFees}, expiry delta=${r.trampolineExpiryDelta}") @@ -142,10 +144,6 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor } } - def spawnPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(PaymentLifecycle.props(nodeParams, paymentCfg, router, register)) - - def spawnMultiPartPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, paymentCfg, router, register)) - private def buildTrampolinePayment(r: SendTrampolinePaymentRequest, trampolineFees: MilliSatoshi, trampolineExpiryDelta: CltvExpiryDelta): (MilliSatoshi, CltvExpiry, OnionRoutingPacket) = { val trampolineRoute = Seq( NodeHop(nodeParams.nodeId, r.trampolineNodeId, nodeParams.expiryDelta, 0 msat), @@ -170,14 +168,33 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor // We generate a random secret for this payment to avoid leaking the invoice secret to the first trampoline node. val trampolineSecret = randomBytes32 val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(r, trampolineFees, trampolineExpiryDelta) - spawnMultiPartPaymentFsm(paymentCfg) ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, 1, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionTlv.TrampolineOnion(trampolineOnion))) + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, 1, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionTlv.TrampolineOnion(trampolineOnion))) } } object PaymentInitiator { - def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef) = Props(new PaymentInitiator(nodeParams, router, register)) + trait PaymentFactory { + def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef + } + + trait MultiPartPaymentFactory extends PaymentFactory { + def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef + } + + case class SimplePaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends MultiPartPaymentFactory { + override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { + context.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register)) + } + + override def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { + context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, this)) + } + } + + def props(nodeParams: NodeParams, outgoingPaymentFactory: MultiPartPaymentFactory) = Props(new PaymentInitiator(nodeParams, outgoingPaymentFactory)) case class PendingPayment(sender: ActorRef, remainingAttempts: Seq[(MilliSatoshi, CltvExpiryDelta)], r: SendTrampolinePaymentRequest) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index a27f68d4c..619a552da 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.io -import akka.actor.FSM import akka.actor.Status.Failure +import akka.actor.{ActorContext, ActorRef, FSM} import akka.testkit.{TestFSMRef, TestProbe} import com.google.common.net.HostAndPort import fr.acinq.bitcoin.Crypto.PublicKey @@ -46,14 +46,20 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle val fakeIPAddress: NodeAddress = NodeAddress.fromParts("1.2.3.4", 42000).get - case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: TestProbe, relayer: TestProbe, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe) + case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channel: TestProbe) + + case class FakeChannelFactory(channel: TestProbe) extends ChannelFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef = { + assert(remoteNodeId === Bob.nodeParams.nodeId) + channel.ref + } + } override protected def withFixture(test: OneArgTest): Outcome = { - val watcher = TestProbe() - val relayer = TestProbe() val wallet: EclairWallet = new TestWallet() val remoteNodeId = Bob.nodeParams.nodeId val peerConnection = TestProbe() + val channel = TestProbe() import com.softwaremill.quicklens._ val aliceParams = TestConstants.Alice.nodeParams @@ -68,8 +74,8 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle aliceParams.db.network.addNode(bobAnnouncement) } - val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, watcher.ref, relayer.ref, wallet)) - withFixture(test.toNoArgTest(FixtureParam(aliceParams, remoteNodeId, watcher, relayer, peer, peerConnection))) + val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel))) + withFixture(test.toNoArgTest(FixtureParam(aliceParams, remoteNodeId, peer, peerConnection, channel))) } def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channels: Set[HasCommitments] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features)): Unit = { @@ -198,21 +204,26 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle val peerConnection3 = TestProbe() connect(remoteNodeId, peer, peerConnection, channels = Set(ChannelCodecsSpec.normal)) - peerConnection1.expectMsgType[ChannelReestablish] - // this is just to extract inits - val Peer.ConnectedData(_, _, localInit, remoteInit, _) = peer.stateData + channel.expectMsg(INPUT_RESTORED(ChannelCodecsSpec.normal)) + val (localInit, remoteInit) = { + val inputReconnected = channel.expectMsgType[INPUT_RECONNECTED] + assert(inputReconnected.remote === peerConnection1.ref) + (inputReconnected.localInit, inputReconnected.remoteInit) + } peerConnection2.send(peer, PeerConnection.ConnectionReady(peerConnection2.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = false, localInit, remoteInit)) // peer should kill previous connection peerConnection1.expectMsg(PeerConnection.Kill(PeerConnection.KillReason.ConnectionReplaced)) + channel.expectMsg(INPUT_DISCONNECTED) + channel.expectMsg(INPUT_RECONNECTED(peerConnection2.ref, localInit, remoteInit)) awaitCond(peer.stateData.asInstanceOf[Peer.ConnectedData].peerConnection === peerConnection2.ref) - peerConnection2.expectMsgType[ChannelReestablish] peerConnection3.send(peer, PeerConnection.ConnectionReady(peerConnection3.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = false, localInit, remoteInit)) // peer should kill previous connection peerConnection2.expectMsg(PeerConnection.Kill(PeerConnection.KillReason.ConnectionReplaced)) + channel.expectMsg(INPUT_DISCONNECTED) + channel.expectMsg(INPUT_RECONNECTED(peerConnection3.ref, localInit, remoteInit)) awaitCond(peer.stateData.asInstanceOf[Peer.ConnectedData].peerConnection === peerConnection3.ref) - peerConnection3.expectMsgType[ChannelReestablish] } test("send state transitions to child reconnection actor", Tag("auto_reconnect"), Tag("with_node_announcement")) { f => @@ -251,12 +262,12 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle val open = protocol.OpenChannel(Block.RegtestGenesisBlock.hash, randomBytes32, 25000 sat, 0 msat, 483 sat, UInt64(100), 1000 sat, 1 msat, TestConstants.feeratePerKw, CltvExpiryDelta(144), 10, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, 0) peerConnection.send(peer, open) awaitCond(peer.stateData.channels.nonEmpty) - assert(probe.expectMsgType[ChannelCreated].temporaryChannelId === open.temporaryChannelId) - peerConnection.expectMsgType[AcceptChannel] + assert(channel.expectMsgType[INPUT_INIT_FUNDEE].temporaryChannelId === open.temporaryChannelId) + channel.expectMsg(open) // open_channel messages with the same temporary channel id should simply be ignored peerConnection.send(peer, open.copy(fundingSatoshis = 100000 sat, fundingPubkey = randomKey.publicKey)) - probe.expectNoMsg(100 millis) + channel.expectNoMsg(100 millis) peerConnection.expectNoMsg(100 millis) assert(peer.stateData.channels.size === 1) } @@ -307,59 +318,64 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle import f._ val probe = TestProbe() - system.eventStream.subscribe(probe.ref, classOf[ChannelCreated]) connect(remoteNodeId, peer, peerConnection) assert(peer.stateData.channels.isEmpty) val relayFees = Some(100 msat, 1000) probe.send(peer, Peer.OpenChannel(remoteNodeId, 12300 sat, 0 msat, None, relayFees, None, None)) + val init = channel.expectMsgType[INPUT_INIT_FUNDER] + assert(init.channelVersion === ChannelVersion.STANDARD) + assert(init.fundingAmount === 12300.sat) + assert(init.initialRelayFees_opt === relayFees) awaitCond(peer.stateData.channels.nonEmpty) - - val channelCreated = probe.expectMsgType[ChannelCreated] - assert(channelCreated.initialFeeratePerKw == nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget)) - assert(channelCreated.fundingTxFeeratePerKw.get == nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget)) - - peer.stateData.channels.foreach { case (_, channelRef) => - probe.send(channelRef, CMD_GETINFO(probe.ref)) - val info = probe.expectMsgType[RES_GETINFO] - assert(info.state == WAIT_FOR_ACCEPT_CHANNEL) - val inputInit = info.data.asInstanceOf[DATA_WAIT_FOR_ACCEPT_CHANNEL].initFunder - assert(inputInit.initialRelayFees_opt === relayFees) - } } test("use correct on-chain fee rates when spawning a channel (anchor outputs)", Tag("anchor_outputs")) { f => import f._ val probe = TestProbe() - system.eventStream.subscribe(probe.ref, classOf[ChannelCreated]) connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional))) + assert(peer.stateData.channels.isEmpty) // We ensure the current network feerate is higher than the default anchor output feerate. val feeEstimator = nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator] feeEstimator.setFeerate(FeeratesPerKw.single(TestConstants.anchorOutputsFeeratePerKw * 2)) probe.send(peer, Peer.OpenChannel(remoteNodeId, 15000 sat, 0 msat, None, None, None, None)) - - val channelCreated = probe.expectMsgType[ChannelCreated] - assert(channelCreated.initialFeeratePerKw == TestConstants.anchorOutputsFeeratePerKw) - assert(channelCreated.fundingTxFeeratePerKw.get == feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget)) + val init = channel.expectMsgType[INPUT_INIT_FUNDER] + assert(init.channelVersion.hasAnchorOutputs) + assert(init.fundingAmount === 15000.sat) + assert(init.initialRelayFees_opt === None) + assert(init.initialFeeratePerKw === TestConstants.anchorOutputsFeeratePerKw) + assert(init.fundingTxFeeratePerKw === feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget)) } test("use correct final script if option_static_remotekey is negotiated", Tag("static_remotekey")) { f => import f._ val probe = TestProbe() - connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) // Bob supports option_static_remotekey + connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) probe.send(peer, Peer.OpenChannel(remoteNodeId, 24000 sat, 0 msat, None, None, None, None)) - awaitCond(peer.stateData.channels.nonEmpty) - peer.stateData.channels.foreach { case (_, channelRef) => - probe.send(channelRef, CMD_GETINFO(probe.ref)) - val info = probe.expectMsgType[RES_GETINFO] - assert(info.state == WAIT_FOR_ACCEPT_CHANNEL) - val inputInit = info.data.asInstanceOf[DATA_WAIT_FOR_ACCEPT_CHANNEL].initFunder - assert(inputInit.channelVersion.hasStaticRemotekey) - assert(inputInit.localParams.walletStaticPaymentBasepoint.isDefined) - assert(inputInit.localParams.defaultFinalScriptPubKey === Script.write(Script.pay2wpkh(inputInit.localParams.walletStaticPaymentBasepoint.get))) + val init = channel.expectMsgType[INPUT_INIT_FUNDER] + assert(init.channelVersion.hasStaticRemotekey) + assert(init.localParams.walletStaticPaymentBasepoint.isDefined) + assert(init.localParams.defaultFinalScriptPubKey === Script.write(Script.pay2wpkh(init.localParams.walletStaticPaymentBasepoint.get))) + } + + test("set origin_opt when spawning a channel") { f => + import f._ + + val probe = TestProbe() + val channelFactory = new ChannelFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef = { + assert(origin_opt === Some(probe.ref)) + channel.ref + } } + val peer = TestFSMRef(new Peer(TestConstants.Alice.nodeParams, remoteNodeId, new TestWallet, channelFactory)) + connect(remoteNodeId, peer, peerConnection) + probe.send(peer, Peer.OpenChannel(remoteNodeId, 15000 sat, 100 msat, None, None, None, None)) + val init = channel.expectMsgType[INPUT_INIT_FUNDER] + assert(init.fundingAmount === 15000.sat) + assert(init.pushAmount === 100.msat) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala index 057afbf8a..475032d15 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala @@ -1,23 +1,23 @@ package fr.acinq.eclair.io -import akka.actor.ActorRef +import akka.actor.{ActorContext, ActorRef} import akka.testkit.{TestActorRef, TestProbe} import fr.acinq.bitcoin.ByteVector64 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.TestConstants._ -import fr.acinq.eclair.blockchain.TestWallet import fr.acinq.eclair.channel.ChannelIdAssigned -import fr.acinq.eclair.wire.protocol._ +import fr.acinq.eclair.io.Switchboard.PeerFactory import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec +import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Features, NodeParams, TestKitBaseClass, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits._ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { - class TestSwitchboard(nodeParams: NodeParams, remoteNodeId: PublicKey, remotePeer: TestProbe) extends Switchboard(nodeParams, TestProbe().ref, TestProbe().ref, new TestWallet()) { - override def createPeer(remoteNodeId2: PublicKey): ActorRef = { - assert(remoteNodeId === remoteNodeId2) + case class FakePeerFactory(expectedRemoteNodeId: PublicKey, remotePeer: TestProbe) extends PeerFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = { + assert(expectedRemoteNodeId === remoteNodeId) remotePeer.ref } } @@ -29,7 +29,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { // If we have a channel with that remote peer, we will automatically reconnect. nodeParams.db.channels.addOrUpdateChannel(ChannelCodecsSpec.normal) - val _ = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer)) + val _ = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer))) peer.expectMsg(Peer.Init(Set(ChannelCodecsSpec.normal))) } @@ -40,7 +40,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { val remoteNodeAddress = NodeAddress.fromParts("127.0.0.1", 9735).get nodeParams.db.network.addNode(NodeAnnouncement(ByteVector64.Zeroes, Features.empty, 0, remoteNodeId, Color(0, 0, 0), "alias", remoteNodeAddress :: Nil)) - val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer)) + val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer))) probe.send(switchboard, Peer.Connect(remoteNodeId, None)) peer.expectMsg(Peer.Init(Set.empty)) peer.expectMsg(Peer.Connect(remoteNodeId, None)) @@ -49,7 +49,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { def sendFeatures(nodeParams: NodeParams, remoteNodeId: PublicKey, expectedFeatures: Features, expectedSync: Boolean) = { val peer = TestProbe() val peerConnection = TestProbe() - val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer)) + val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer))) switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId) peerConnection.expectMsg(PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, expectedFeatures, doSync = expectedSync)) } @@ -66,7 +66,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { val peerConnection = TestProbe() val nodeParams = Alice.nodeParams.copy(syncWhitelist = Set.empty) val remoteNodeId = ChannelCodecsSpec.normal.commitments.remoteParams.nodeId - val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer)) + val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer))) // We have a channel with our peer, so we trigger a sync when connecting. switchboard ! ChannelIdAssigned(TestProbe().ref, remoteNodeId, randomBytes32, randomBytes32) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index c2a41431d..796caf71e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.payment -import akka.actor.{ActorRef, Status} +import akka.actor.{ActorContext, ActorRef, Status} import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.{Block, Crypto} import fr.acinq.eclair._ @@ -25,11 +25,11 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.{FailureSummary, FailureType, OutgoingPaymentStatus} import fr.acinq.eclair.payment.OutgoingPacket.Upstream import fr.acinq.eclair.payment.PaymentRequest.ExtraHop -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute +import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router.{Announcements, RouteNotFound} import fr.acinq.eclair.wire.protocol._ @@ -56,15 +56,16 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm: TestProbe, eventListener: TestProbe) + case class FakePaymentFactory(childPayFsm: TestProbe) extends PaymentInitiator.PaymentFactory { + override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = childPayFsm.ref + } + override def withFixture(test: OneArgTest): Outcome = { val id = UUID.randomUUID() val cfg = SendPaymentConfig(id, id, Some("42"), paymentHash, finalAmount, finalRecipient, Upstream.Local(id), None, storeInDb = true, publishEvent = true, Nil) val nodeParams = TestConstants.Alice.nodeParams val (childPayFsm, router, sender, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe()) - class TestMultiPartPaymentLifecycle extends MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, TestProbe().ref) { - override def spawnChildPaymentFsm(childId: UUID): ActorRef = childPayFsm.ref - } - val paymentHandler = TestFSMRef(new TestMultiPartPaymentLifecycle().asInstanceOf[MultiPartPaymentLifecycle]) + val paymentHandler = TestFSMRef(new MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, FakePaymentFactory(childPayFsm))) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) withFixture(test.toNoArgTest(FixtureParam(cfg, nodeParams, paymentHandler, router, sender, childPayFsm, eventListener))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 832863ec1..88a0d84ce 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.payment -import akka.actor.ActorRef +import akka.actor.{ActorContext, ActorRef} import akka.testkit.{TestActorRef, TestProbe} import fr.acinq.bitcoin.Block import fr.acinq.eclair.FeatureSupport.Optional @@ -63,25 +63,26 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike BasicMultiPartPayment -> Optional, ) + case class FakePaymentFactory(payFsm: TestProbe, multiPartPayFsm: TestProbe) extends PaymentInitiator.MultiPartPaymentFactory { + // @formatter:off + override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { + payFsm.ref ! cfg + payFsm.ref + } + override def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { + multiPartPayFsm.ref ! cfg + multiPartPayFsm.ref + } + // @formatter:on + } + override def withFixture(test: OneArgTest): Outcome = { val features = if (test.tags.contains("mpp_disabled")) featuresWithoutMpp else featuresWithMpp val nodeParams = TestConstants.Alice.nodeParams.copy(features = features) val (sender, payFsm, multiPartPayFsm) = (TestProbe(), TestProbe(), TestProbe()) val eventListener = TestProbe() system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - class TestPaymentInitiator extends PaymentInitiator(nodeParams, TestProbe().ref, TestProbe().ref) { - // @formatter:off - override def spawnPaymentFsm(cfg: SendPaymentConfig): ActorRef = { - payFsm.ref ! cfg - payFsm.ref - } - override def spawnMultiPartPaymentFsm(cfg: SendPaymentConfig): ActorRef = { - multiPartPayFsm.ref ! cfg - multiPartPayFsm.ref - } - // @formatter:on - } - val initiator = TestActorRef(new TestPaymentInitiator().asInstanceOf[PaymentInitiator]) + val initiator = TestActorRef(new PaymentInitiator(nodeParams, FakePaymentFactory(payFsm, multiPartPayFsm))) withFixture(test.toNoArgTest(FixtureParam(nodeParams, initiator, payFsm, multiPartPayFsm, sender, eventListener))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 8143c841a..4ccd86703 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -63,23 +63,23 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val eventListener = TestProbe[PaymentEvent]("event-listener") system.eventStream ! EventStream.Subscribe(eventListener.ref) val mockPayFSM = TestProbe[Any]("pay-fsm") - val fsmFactory = if (test.tags.contains("mock-fsm")) { - new NodeRelay.FsmFactory { - override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { + val outgoingPaymentFactory = if (test.tags.contains("mock-fsm")) { + new NodeRelay.OutgoingPaymentFactory { + override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { mockPayFSM.ref ! cfg mockPayFSM.ref.toClassic } } } else { - new NodeRelay.FsmFactory { - override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { - val fsm = super.spawnOutgoingPayFSM(context, nodeParams, router, register, cfg, multiPart) - mockPayFSM.ref ! fsm - fsm + new NodeRelay.OutgoingPaymentFactory { + override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { + val outgoingPayFSM = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router.ref.toClassic, register.ref.toClassic).spawnOutgoingPayFSM(context, cfg, multiPart) + mockPayFSM.ref ! outgoingPayFSM + outgoingPayFSM } } } - val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, router.ref.toClassic, register.ref.toClassic, relayId, paymentHash, fsmFactory)) + val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, paymentHash, outgoingPaymentFactory)) withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelay, parent, router, register, mockPayFSM, eventListener))) } @@ -431,7 +431,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream multi-part payment. incomingMultiPart.dropRight(1).foreach(p => nodeRelayer ! NodeRelay.Relay(p)) - router.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment + mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment nodeRelayer ! NodeRelay.Relay(incomingMultiPart.last)