1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-02-23 06:35:11 +01:00

Introduce actor factories (#1744)

This removes unnecessary fields and allows more flexibility in tests.
This commit is contained in:
Bastien Teinturier 2021-03-31 08:58:40 +02:00 committed by GitHub
parent e5429ebdf4
commit c6a76af9d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 195 additions and 136 deletions

View file

@ -37,7 +37,7 @@ import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager}
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.{Databases, DbEventHandler, FileBackupHandler}
import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard}
import fr.acinq.eclair.io.{ClientSpawner, Peer, Server, Switchboard}
import fr.acinq.eclair.payment.receive.PaymentHandler
import fr.acinq.eclair.payment.relay.Relayer
import fr.acinq.eclair.payment.send.{Autoprobe, PaymentInitiator}
@ -290,8 +290,8 @@ class Setup(datadir: File,
new ElectrumEclairWallet(electrumWallet, nodeParams.chainHash)
}
_ = wallet.getReceiveAddress.map(address => logger.info(s"initial wallet address=$address"))
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
backupHandler = if (config.getBoolean("enable-db-backup")) {
nodeParams.db match {
case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
@ -314,10 +314,14 @@ class Setup(datadir: File,
// Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system,
// we want to make sure the handler for post-restart broken HTLCs has finished initializing.
_ <- postRestartCleanUpInitialized.future
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, watcher, relayer, wallet), "switchboard", SupervisorStrategy.Resume))
channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, wallet)
peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory)
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
clientSpawner = system.actorOf(SimpleSupervisor.props(ClientSpawner.props(nodeParams.keyPair, nodeParams.socksProxy_opt, nodeParams.peerConnectionConf, switchboard, router), "client-spawner", SupervisorStrategy.Restart))
server = system.actorOf(SimpleSupervisor.props(Server.props(nodeParams.keyPair, nodeParams.peerConnectionConf, switchboard, router, serverBindingAddress, Some(tcpBound)), "server", SupervisorStrategy.Restart))
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, register), "payment-initiator", SupervisorStrategy.Restart))
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)), "payment-initiator", SupervisorStrategy.Restart))
_ = for (i <- 0 until config.getInt("autoprobe-count")) yield system.actorOf(SimpleSupervisor.props(Autoprobe.props(nodeParams, router, paymentInitiator), s"payment-autoprobe-$i", SupervisorStrategy.Restart))
kit = Kit(
@ -381,11 +385,11 @@ class Setup(datadir: File,
}
// @formatter:off
object Setup {
final case class Seeds(nodeSeed: ByteVector, channelSeed: ByteVector)
}
// @formatter:off
sealed trait Bitcoin
case class Bitcoind(bitcoinClient: BasicBitcoinJsonRPCClient) extends Bitcoin
case class Electrum(electrumClient: ActorRef) extends Bitcoin

View file

@ -16,7 +16,7 @@
package fr.acinq.eclair.io
import akka.actor.{Actor, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
import akka.actor.{Actor, ActorContext, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
import akka.event.Logging.MDC
import akka.event.{BusLogging, DiagnosticLoggingAdapter}
import akka.util.Timeout
@ -48,7 +48,7 @@ import java.net.InetSocketAddress
*
* Created by PM on 26/08/2016.
*/
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
import Peer._
@ -57,7 +57,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
when(INSTANTIATING) {
case Event(Init(storedChannels), _) =>
val channels = storedChannels.map { state =>
val channel = spawnChannel(nodeParams, origin_opt = None)
val channel = spawnChannel(origin_opt = None)
channel ! INPUT_RESTORED(state)
FinalChannelId(state.channelId) -> channel
}.toMap
@ -294,12 +294,12 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
(Helpers.getFinalScriptPubKey(wallet, nodeParams.chainHash), None)
}
val localParams = makeChannelParams(nodeParams, features, finalScript, walletStaticPaymentBasepoint, funder, fundingAmount)
val channel = spawnChannel(nodeParams, origin_opt)
val channel = spawnChannel(origin_opt)
(channel, localParams)
}
def spawnChannel(nodeParams: NodeParams, origin_opt: Option[ActorRef]): ActorRef = {
val channel = context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
def spawnChannel(origin_opt: Option[ActorRef]): ActorRef = {
val channel = channelFactory.spawn(context, remoteNodeId, origin_opt)
context watch channel
channel
}
@ -353,7 +353,16 @@ object Peer {
val UNKNOWN_CHANNEL_MESSAGE: ByteVector = ByteVector.view("unknown channel".getBytes())
// @formatter:on
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet): Props = Props(new Peer(nodeParams, remoteNodeId, watcher, relayer: ActorRef, wallet))
trait ChannelFactory {
def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef
}
case class SimpleChannelFactory(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends ChannelFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef =
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
}
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: ChannelFactory): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory))
// @formatter:off

View file

@ -16,7 +16,7 @@
package fr.acinq.eclair.io
import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.NodeParams
import fr.acinq.eclair.blockchain.EclairWallet
@ -29,7 +29,7 @@ import fr.acinq.eclair.router.Router.RouterConf
* Ties network connections to peers.
* Created by PM on 14/02/2017.
*/
class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends Actor with ActorLogging {
class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) extends Actor with ActorLogging {
import Switchboard._
@ -103,7 +103,7 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,
*/
def getPeer(remoteNodeId: PublicKey): Option[ActorRef] = context.child(peerActorName(remoteNodeId))
def createPeer(remoteNodeId: PublicKey): ActorRef = context.actorOf(Peer.props(nodeParams, remoteNodeId, watcher, relayer, wallet), name = peerActorName(remoteNodeId))
def createPeer(remoteNodeId: PublicKey): ActorRef = peerFactory.spawn(context, remoteNodeId)
def createOrGetPeer(remoteNodeId: PublicKey, offlineChannels: Set[HasCommitments]): ActorRef = {
getPeer(remoteNodeId) match {
@ -124,7 +124,16 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,
object Switchboard {
def props(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) = Props(new Switchboard(nodeParams, watcher, relayer, wallet))
trait PeerFactory {
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
}
case class SimplePeerFactory(nodeParams: NodeParams, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory), name = peerActorName(remoteNodeId))
}
def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))
def peerActorName(remoteNodeId: PublicKey): String = s"peer-$remoteNodeId"

View file

@ -29,11 +29,10 @@ import fr.acinq.eclair.payment.OutgoingPacket.Upstream
import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart
import fr.acinq.eclair.payment.relay.NodeRelay.FsmFactory
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment}
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle}
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle}
import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound}
import fr.acinq.eclair.wire.protocol._
@ -60,26 +59,29 @@ object NodeRelay {
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
// @formatter:on
def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] =
trait OutgoingPaymentFactory {
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef
}
case class SimpleOutgoingPaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends OutgoingPaymentFactory {
val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)
override def spawnOutgoingPayFSM(context: ActorContext[Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
if (multiPart) {
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, paymentFactory))
} else {
context.toClassic.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register))
}
}
}
def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
paymentHash_opt = Some(paymentHash))) {
new NodeRelay(nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)()
}
}
/**
* This is supposed to be overridden in tests
*/
class FsmFactory {
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: ActorRef, register: ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
if (multiPart) {
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, register))
} else {
context.toClassic.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register))
}
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
}
}
@ -139,12 +141,11 @@ object NodeRelay {
*/
class NodeRelay private(nodeParams: NodeParams,
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
router: ActorRef,
register: ActorRef,
relayId: UUID,
paymentHash: ByteVector32,
context: ActorContext[NodeRelay.Command],
fsmFactory: FsmFactory) {
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) {
import NodeRelay._
@ -285,20 +286,20 @@ class NodeRelay private(nodeParams: NodeParams,
case Some(paymentSecret) if Features(features).hasFeature(Features.BasicMultiPartPayment) =>
context.log.debug("sending the payment to non-trampoline recipient using MPP")
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
payFSM ! payment
payFSM
case _ =>
context.log.debug("sending the payment to non-trampoline recipient without MPP")
val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret)
val payment = SendPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = false)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false)
payFSM ! payment
payFSM
}
case None =>
context.log.debug("sending the payment to the next trampoline node")
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
val paymentSecret = randomBytes32 // we generate a new secret to protect against probing attacks
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = Some(routeParams), additionalTlvs = Seq(OnionTlv.TrampolineOnion(packetOut)))
payFSM ! payment

View file

@ -66,7 +66,8 @@ object NodeRelayer {
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, router, register, relayId, paymentHash), relayId.toString)
val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register)
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc to new handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, router, register, children + (paymentHash -> handler))

View file

@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit
* Sender for a multi-part payment (see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#basic-multi-part-payments).
* The payment will be split into multiple sub-payments that will be sent in parallel.
*/
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {
import MultiPartPaymentLifecycle._
@ -202,13 +202,13 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
case Event(_: Status.Failure, _) => stay
}
def spawnChildPaymentFsm(childId: UUID): ActorRef = {
private def spawnChildPaymentFsm(childId: UUID): ActorRef = {
val upstream = cfg.upstream match {
case Upstream.Local(_) => Upstream.Local(childId)
case _ => cfg.upstream
}
val childCfg = cfg.copy(id = childId, publishEvent = false, upstream = upstream)
context.actorOf(PaymentLifecycle.props(nodeParams, childCfg, router, register))
paymentFactory.spawnOutgoingPayment(context, childCfg)
}
private def gotoAbortedOrStop(d: PaymentAborted): State = {
@ -265,7 +265,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
object MultiPartPaymentLifecycle {
def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, register))
def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, paymentFactory))
/**
* Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding

View file

@ -16,7 +16,7 @@
package fr.acinq.eclair.payment.send
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, Props}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.Features.BasicMultiPartPayment
@ -39,7 +39,7 @@ import java.util.UUID
/**
* Created by PM on 29/08/2016.
*/
class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends Actor with ActorLogging {
class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentInitiator.MultiPartPaymentFactory) extends Actor with ActorLogging {
import PaymentInitiator._
@ -57,14 +57,16 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
case Some(invoice) if invoice.features.allowMultiPart && nodeParams.features.hasFeature(BasicMultiPartPayment) =>
invoice.paymentSecret match {
case Some(paymentSecret) =>
spawnMultiPartPaymentFsm(paymentCfg) ! SendMultiPartPayment(sender, paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs)
val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg)
fsm ! SendMultiPartPayment(sender, paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs)
case None =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, PaymentSecretMissing) :: Nil)
}
case _ =>
val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret)
val finalPayload = Onion.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.userCustomTlvs)
spawnPaymentFsm(paymentCfg) ! SendPayment(sender, r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams)
val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
fsm ! SendPayment(sender, r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams)
}
case r: SendTrampolinePaymentRequest =>
@ -122,7 +124,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
val finalExpiry = r.finalExpiry(nodeParams.currentBlockHeight)
val additionalHops = r.trampolineNodes.sliding(2).map(hop => NodeHop(hop.head, hop(1), CltvExpiryDelta(0), 0 msat)).toSeq
val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.paymentRequest), storeInDb = true, publishEvent = true, additionalHops)
val payFsm = spawnPaymentFsm(paymentCfg)
val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
r.trampolineNodes match {
case trampoline :: recipient :: Nil =>
log.info(s"sending trampoline payment to $recipient with trampoline=$trampoline, trampoline fees=${r.trampolineFees}, expiry delta=${r.trampolineExpiryDelta}")
@ -142,10 +144,6 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
}
}
def spawnPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(PaymentLifecycle.props(nodeParams, paymentCfg, router, register))
def spawnMultiPartPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, paymentCfg, router, register))
private def buildTrampolinePayment(r: SendTrampolinePaymentRequest, trampolineFees: MilliSatoshi, trampolineExpiryDelta: CltvExpiryDelta): (MilliSatoshi, CltvExpiry, OnionRoutingPacket) = {
val trampolineRoute = Seq(
NodeHop(nodeParams.nodeId, r.trampolineNodeId, nodeParams.expiryDelta, 0 msat),
@ -170,14 +168,33 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
// We generate a random secret for this payment to avoid leaking the invoice secret to the first trampoline node.
val trampolineSecret = randomBytes32
val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(r, trampolineFees, trampolineExpiryDelta)
spawnMultiPartPaymentFsm(paymentCfg) ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, 1, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionTlv.TrampolineOnion(trampolineOnion)))
val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg)
fsm ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, 1, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionTlv.TrampolineOnion(trampolineOnion)))
}
}
object PaymentInitiator {
def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef) = Props(new PaymentInitiator(nodeParams, router, register))
trait PaymentFactory {
def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef
}
trait MultiPartPaymentFactory extends PaymentFactory {
def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef
}
case class SimplePaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends MultiPartPaymentFactory {
override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = {
context.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register))
}
override def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = {
context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, this))
}
}
def props(nodeParams: NodeParams, outgoingPaymentFactory: MultiPartPaymentFactory) = Props(new PaymentInitiator(nodeParams, outgoingPaymentFactory))
case class PendingPayment(sender: ActorRef, remainingAttempts: Seq[(MilliSatoshi, CltvExpiryDelta)], r: SendTrampolinePaymentRequest)

View file

@ -16,8 +16,8 @@
package fr.acinq.eclair.io
import akka.actor.FSM
import akka.actor.Status.Failure
import akka.actor.{ActorContext, ActorRef, FSM}
import akka.testkit.{TestFSMRef, TestProbe}
import com.google.common.net.HostAndPort
import fr.acinq.bitcoin.Crypto.PublicKey
@ -46,14 +46,20 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle
val fakeIPAddress: NodeAddress = NodeAddress.fromParts("1.2.3.4", 42000).get
case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: TestProbe, relayer: TestProbe, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe)
case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channel: TestProbe)
case class FakeChannelFactory(channel: TestProbe) extends ChannelFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef = {
assert(remoteNodeId === Bob.nodeParams.nodeId)
channel.ref
}
}
override protected def withFixture(test: OneArgTest): Outcome = {
val watcher = TestProbe()
val relayer = TestProbe()
val wallet: EclairWallet = new TestWallet()
val remoteNodeId = Bob.nodeParams.nodeId
val peerConnection = TestProbe()
val channel = TestProbe()
import com.softwaremill.quicklens._
val aliceParams = TestConstants.Alice.nodeParams
@ -68,8 +74,8 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle
aliceParams.db.network.addNode(bobAnnouncement)
}
val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, watcher.ref, relayer.ref, wallet))
withFixture(test.toNoArgTest(FixtureParam(aliceParams, remoteNodeId, watcher, relayer, peer, peerConnection)))
val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel)))
withFixture(test.toNoArgTest(FixtureParam(aliceParams, remoteNodeId, peer, peerConnection, channel)))
}
def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channels: Set[HasCommitments] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features)): Unit = {
@ -198,21 +204,26 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle
val peerConnection3 = TestProbe()
connect(remoteNodeId, peer, peerConnection, channels = Set(ChannelCodecsSpec.normal))
peerConnection1.expectMsgType[ChannelReestablish]
// this is just to extract inits
val Peer.ConnectedData(_, _, localInit, remoteInit, _) = peer.stateData
channel.expectMsg(INPUT_RESTORED(ChannelCodecsSpec.normal))
val (localInit, remoteInit) = {
val inputReconnected = channel.expectMsgType[INPUT_RECONNECTED]
assert(inputReconnected.remote === peerConnection1.ref)
(inputReconnected.localInit, inputReconnected.remoteInit)
}
peerConnection2.send(peer, PeerConnection.ConnectionReady(peerConnection2.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = false, localInit, remoteInit))
// peer should kill previous connection
peerConnection1.expectMsg(PeerConnection.Kill(PeerConnection.KillReason.ConnectionReplaced))
channel.expectMsg(INPUT_DISCONNECTED)
channel.expectMsg(INPUT_RECONNECTED(peerConnection2.ref, localInit, remoteInit))
awaitCond(peer.stateData.asInstanceOf[Peer.ConnectedData].peerConnection === peerConnection2.ref)
peerConnection2.expectMsgType[ChannelReestablish]
peerConnection3.send(peer, PeerConnection.ConnectionReady(peerConnection3.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = false, localInit, remoteInit))
// peer should kill previous connection
peerConnection2.expectMsg(PeerConnection.Kill(PeerConnection.KillReason.ConnectionReplaced))
channel.expectMsg(INPUT_DISCONNECTED)
channel.expectMsg(INPUT_RECONNECTED(peerConnection3.ref, localInit, remoteInit))
awaitCond(peer.stateData.asInstanceOf[Peer.ConnectedData].peerConnection === peerConnection3.ref)
peerConnection3.expectMsgType[ChannelReestablish]
}
test("send state transitions to child reconnection actor", Tag("auto_reconnect"), Tag("with_node_announcement")) { f =>
@ -251,12 +262,12 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle
val open = protocol.OpenChannel(Block.RegtestGenesisBlock.hash, randomBytes32, 25000 sat, 0 msat, 483 sat, UInt64(100), 1000 sat, 1 msat, TestConstants.feeratePerKw, CltvExpiryDelta(144), 10, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, 0)
peerConnection.send(peer, open)
awaitCond(peer.stateData.channels.nonEmpty)
assert(probe.expectMsgType[ChannelCreated].temporaryChannelId === open.temporaryChannelId)
peerConnection.expectMsgType[AcceptChannel]
assert(channel.expectMsgType[INPUT_INIT_FUNDEE].temporaryChannelId === open.temporaryChannelId)
channel.expectMsg(open)
// open_channel messages with the same temporary channel id should simply be ignored
peerConnection.send(peer, open.copy(fundingSatoshis = 100000 sat, fundingPubkey = randomKey.publicKey))
probe.expectNoMsg(100 millis)
channel.expectNoMsg(100 millis)
peerConnection.expectNoMsg(100 millis)
assert(peer.stateData.channels.size === 1)
}
@ -307,59 +318,64 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle
import f._
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[ChannelCreated])
connect(remoteNodeId, peer, peerConnection)
assert(peer.stateData.channels.isEmpty)
val relayFees = Some(100 msat, 1000)
probe.send(peer, Peer.OpenChannel(remoteNodeId, 12300 sat, 0 msat, None, relayFees, None, None))
val init = channel.expectMsgType[INPUT_INIT_FUNDER]
assert(init.channelVersion === ChannelVersion.STANDARD)
assert(init.fundingAmount === 12300.sat)
assert(init.initialRelayFees_opt === relayFees)
awaitCond(peer.stateData.channels.nonEmpty)
val channelCreated = probe.expectMsgType[ChannelCreated]
assert(channelCreated.initialFeeratePerKw == nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget))
assert(channelCreated.fundingTxFeeratePerKw.get == nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget))
peer.stateData.channels.foreach { case (_, channelRef) =>
probe.send(channelRef, CMD_GETINFO(probe.ref))
val info = probe.expectMsgType[RES_GETINFO]
assert(info.state == WAIT_FOR_ACCEPT_CHANNEL)
val inputInit = info.data.asInstanceOf[DATA_WAIT_FOR_ACCEPT_CHANNEL].initFunder
assert(inputInit.initialRelayFees_opt === relayFees)
}
}
test("use correct on-chain fee rates when spawning a channel (anchor outputs)", Tag("anchor_outputs")) { f =>
import f._
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[ChannelCreated])
connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional)))
assert(peer.stateData.channels.isEmpty)
// We ensure the current network feerate is higher than the default anchor output feerate.
val feeEstimator = nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator]
feeEstimator.setFeerate(FeeratesPerKw.single(TestConstants.anchorOutputsFeeratePerKw * 2))
probe.send(peer, Peer.OpenChannel(remoteNodeId, 15000 sat, 0 msat, None, None, None, None))
val channelCreated = probe.expectMsgType[ChannelCreated]
assert(channelCreated.initialFeeratePerKw == TestConstants.anchorOutputsFeeratePerKw)
assert(channelCreated.fundingTxFeeratePerKw.get == feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget))
val init = channel.expectMsgType[INPUT_INIT_FUNDER]
assert(init.channelVersion.hasAnchorOutputs)
assert(init.fundingAmount === 15000.sat)
assert(init.initialRelayFees_opt === None)
assert(init.initialFeeratePerKw === TestConstants.anchorOutputsFeeratePerKw)
assert(init.fundingTxFeeratePerKw === feeEstimator.getFeeratePerKw(nodeParams.onChainFeeConf.feeTargets.fundingBlockTarget))
}
test("use correct final script if option_static_remotekey is negotiated", Tag("static_remotekey")) { f =>
import f._
val probe = TestProbe()
connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) // Bob supports option_static_remotekey
connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional)))
probe.send(peer, Peer.OpenChannel(remoteNodeId, 24000 sat, 0 msat, None, None, None, None))
awaitCond(peer.stateData.channels.nonEmpty)
peer.stateData.channels.foreach { case (_, channelRef) =>
probe.send(channelRef, CMD_GETINFO(probe.ref))
val info = probe.expectMsgType[RES_GETINFO]
assert(info.state == WAIT_FOR_ACCEPT_CHANNEL)
val inputInit = info.data.asInstanceOf[DATA_WAIT_FOR_ACCEPT_CHANNEL].initFunder
assert(inputInit.channelVersion.hasStaticRemotekey)
assert(inputInit.localParams.walletStaticPaymentBasepoint.isDefined)
assert(inputInit.localParams.defaultFinalScriptPubKey === Script.write(Script.pay2wpkh(inputInit.localParams.walletStaticPaymentBasepoint.get)))
val init = channel.expectMsgType[INPUT_INIT_FUNDER]
assert(init.channelVersion.hasStaticRemotekey)
assert(init.localParams.walletStaticPaymentBasepoint.isDefined)
assert(init.localParams.defaultFinalScriptPubKey === Script.write(Script.pay2wpkh(init.localParams.walletStaticPaymentBasepoint.get)))
}
test("set origin_opt when spawning a channel") { f =>
import f._
val probe = TestProbe()
val channelFactory = new ChannelFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef = {
assert(origin_opt === Some(probe.ref))
channel.ref
}
}
val peer = TestFSMRef(new Peer(TestConstants.Alice.nodeParams, remoteNodeId, new TestWallet, channelFactory))
connect(remoteNodeId, peer, peerConnection)
probe.send(peer, Peer.OpenChannel(remoteNodeId, 15000 sat, 100 msat, None, None, None, None))
val init = channel.expectMsgType[INPUT_INIT_FUNDER]
assert(init.fundingAmount === 15000.sat)
assert(init.pushAmount === 100.msat)
}
}

View file

@ -1,23 +1,23 @@
package fr.acinq.eclair.io
import akka.actor.ActorRef
import akka.actor.{ActorContext, ActorRef}
import akka.testkit.{TestActorRef, TestProbe}
import fr.acinq.bitcoin.ByteVector64
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.TestConstants._
import fr.acinq.eclair.blockchain.TestWallet
import fr.acinq.eclair.channel.ChannelIdAssigned
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.io.Switchboard.PeerFactory
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Features, NodeParams, TestKitBaseClass, randomBytes32, randomKey}
import org.scalatest.funsuite.AnyFunSuiteLike
import scodec.bits._
class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
class TestSwitchboard(nodeParams: NodeParams, remoteNodeId: PublicKey, remotePeer: TestProbe) extends Switchboard(nodeParams, TestProbe().ref, TestProbe().ref, new TestWallet()) {
override def createPeer(remoteNodeId2: PublicKey): ActorRef = {
assert(remoteNodeId === remoteNodeId2)
case class FakePeerFactory(expectedRemoteNodeId: PublicKey, remotePeer: TestProbe) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = {
assert(expectedRemoteNodeId === remoteNodeId)
remotePeer.ref
}
}
@ -29,7 +29,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
// If we have a channel with that remote peer, we will automatically reconnect.
nodeParams.db.channels.addOrUpdateChannel(ChannelCodecsSpec.normal)
val _ = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer))
val _ = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer)))
peer.expectMsg(Peer.Init(Set(ChannelCodecsSpec.normal)))
}
@ -40,7 +40,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
val remoteNodeAddress = NodeAddress.fromParts("127.0.0.1", 9735).get
nodeParams.db.network.addNode(NodeAnnouncement(ByteVector64.Zeroes, Features.empty, 0, remoteNodeId, Color(0, 0, 0), "alias", remoteNodeAddress :: Nil))
val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer))
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer)))
probe.send(switchboard, Peer.Connect(remoteNodeId, None))
peer.expectMsg(Peer.Init(Set.empty))
peer.expectMsg(Peer.Connect(remoteNodeId, None))
@ -49,7 +49,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
def sendFeatures(nodeParams: NodeParams, remoteNodeId: PublicKey, expectedFeatures: Features, expectedSync: Boolean) = {
val peer = TestProbe()
val peerConnection = TestProbe()
val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer))
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer)))
switchboard ! PeerConnection.Authenticated(peerConnection.ref, remoteNodeId)
peerConnection.expectMsg(PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, expectedFeatures, doSync = expectedSync))
}
@ -66,7 +66,7 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike {
val peerConnection = TestProbe()
val nodeParams = Alice.nodeParams.copy(syncWhitelist = Set.empty)
val remoteNodeId = ChannelCodecsSpec.normal.commitments.remoteParams.nodeId
val switchboard = TestActorRef(new TestSwitchboard(nodeParams, remoteNodeId, peer))
val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(remoteNodeId, peer)))
// We have a channel with our peer, so we trigger a sync when connecting.
switchboard ! ChannelIdAssigned(TestProbe().ref, remoteNodeId, randomBytes32, randomBytes32)

View file

@ -16,7 +16,7 @@
package fr.acinq.eclair.payment
import akka.actor.{ActorRef, Status}
import akka.actor.{ActorContext, ActorRef, Status}
import akka.testkit.{TestFSMRef, TestProbe}
import fr.acinq.bitcoin.{Block, Crypto}
import fr.acinq.eclair._
@ -25,11 +25,11 @@ import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.{FailureSummary, FailureType, OutgoingPaymentStatus}
import fr.acinq.eclair.payment.OutgoingPacket.Upstream
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._
import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.router.{Announcements, RouteNotFound}
import fr.acinq.eclair.wire.protocol._
@ -56,15 +56,16 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
childPayFsm: TestProbe,
eventListener: TestProbe)
case class FakePaymentFactory(childPayFsm: TestProbe) extends PaymentInitiator.PaymentFactory {
override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = childPayFsm.ref
}
override def withFixture(test: OneArgTest): Outcome = {
val id = UUID.randomUUID()
val cfg = SendPaymentConfig(id, id, Some("42"), paymentHash, finalAmount, finalRecipient, Upstream.Local(id), None, storeInDb = true, publishEvent = true, Nil)
val nodeParams = TestConstants.Alice.nodeParams
val (childPayFsm, router, sender, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe())
class TestMultiPartPaymentLifecycle extends MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, TestProbe().ref) {
override def spawnChildPaymentFsm(childId: UUID): ActorRef = childPayFsm.ref
}
val paymentHandler = TestFSMRef(new TestMultiPartPaymentLifecycle().asInstanceOf[MultiPartPaymentLifecycle])
val paymentHandler = TestFSMRef(new MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, FakePaymentFactory(childPayFsm)))
system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent])
withFixture(test.toNoArgTest(FixtureParam(cfg, nodeParams, paymentHandler, router, sender, childPayFsm, eventListener)))
}

View file

@ -16,7 +16,7 @@
package fr.acinq.eclair.payment
import akka.actor.ActorRef
import akka.actor.{ActorContext, ActorRef}
import akka.testkit.{TestActorRef, TestProbe}
import fr.acinq.bitcoin.Block
import fr.acinq.eclair.FeatureSupport.Optional
@ -63,25 +63,26 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
BasicMultiPartPayment -> Optional,
)
case class FakePaymentFactory(payFsm: TestProbe, multiPartPayFsm: TestProbe) extends PaymentInitiator.MultiPartPaymentFactory {
// @formatter:off
override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = {
payFsm.ref ! cfg
payFsm.ref
}
override def spawnOutgoingMultiPartPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = {
multiPartPayFsm.ref ! cfg
multiPartPayFsm.ref
}
// @formatter:on
}
override def withFixture(test: OneArgTest): Outcome = {
val features = if (test.tags.contains("mpp_disabled")) featuresWithoutMpp else featuresWithMpp
val nodeParams = TestConstants.Alice.nodeParams.copy(features = features)
val (sender, payFsm, multiPartPayFsm) = (TestProbe(), TestProbe(), TestProbe())
val eventListener = TestProbe()
system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent])
class TestPaymentInitiator extends PaymentInitiator(nodeParams, TestProbe().ref, TestProbe().ref) {
// @formatter:off
override def spawnPaymentFsm(cfg: SendPaymentConfig): ActorRef = {
payFsm.ref ! cfg
payFsm.ref
}
override def spawnMultiPartPaymentFsm(cfg: SendPaymentConfig): ActorRef = {
multiPartPayFsm.ref ! cfg
multiPartPayFsm.ref
}
// @formatter:on
}
val initiator = TestActorRef(new TestPaymentInitiator().asInstanceOf[PaymentInitiator])
val initiator = TestActorRef(new PaymentInitiator(nodeParams, FakePaymentFactory(payFsm, multiPartPayFsm)))
withFixture(test.toNoArgTest(FixtureParam(nodeParams, initiator, payFsm, multiPartPayFsm, sender, eventListener)))
}

View file

@ -63,23 +63,23 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val eventListener = TestProbe[PaymentEvent]("event-listener")
system.eventStream ! EventStream.Subscribe(eventListener.ref)
val mockPayFSM = TestProbe[Any]("pay-fsm")
val fsmFactory = if (test.tags.contains("mock-fsm")) {
new NodeRelay.FsmFactory {
override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = {
val outgoingPaymentFactory = if (test.tags.contains("mock-fsm")) {
new NodeRelay.OutgoingPaymentFactory {
override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = {
mockPayFSM.ref ! cfg
mockPayFSM.ref.toClassic
}
}
} else {
new NodeRelay.FsmFactory {
override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = {
val fsm = super.spawnOutgoingPayFSM(context, nodeParams, router, register, cfg, multiPart)
mockPayFSM.ref ! fsm
fsm
new NodeRelay.OutgoingPaymentFactory {
override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = {
val outgoingPayFSM = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router.ref.toClassic, register.ref.toClassic).spawnOutgoingPayFSM(context, cfg, multiPart)
mockPayFSM.ref ! outgoingPayFSM
outgoingPayFSM
}
}
}
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, router.ref.toClassic, register.ref.toClassic, relayId, paymentHash, fsmFactory))
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, paymentHash, outgoingPaymentFactory))
withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelay, parent, router, register, mockPayFSM, eventListener)))
}
@ -431,7 +431,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
// Receive an upstream multi-part payment.
incomingMultiPart.dropRight(1).foreach(p => nodeRelayer ! NodeRelay.Relay(p))
router.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment
nodeRelayer ! NodeRelay.Relay(incomingMultiPart.last)