diff --git a/eclair-core/eclair-cli b/eclair-core/eclair-cli index 6d87b1aaf..c9286529c 100755 --- a/eclair-core/eclair-cli +++ b/eclair-core/eclair-cli @@ -33,7 +33,7 @@ and COMMAND is one of the available commands: - connect - disconnect - peers - - allnodes + - nodes - audit === Channel === diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 32891730b..7cce0a2d7 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -44,6 +44,8 @@ eclair { gossip_queries = optional gossip_queries_ex = optional var_onion_optin = optional + payment_secret = optional + basic_mpp = optional } override-features = [ // optional per-node features # { @@ -96,9 +98,10 @@ eclair { claim-main = 12 // target for the claim main transaction (tx that spends main channel output back to wallet) } - // maximum local vs remote feerate mismatch; 1.0 means 100% - // actual check is abs((local feerate - remote fee rate) / (local fee rate + remote fee rate)/2) > fee rate mismatch - max-feerate-mismatch = 1.56 // will allow remote fee rates up to 8x bigger or smaller than our local fee rate + feerate-tolerance { + ratio-low = 0.5 // will allow remote fee rates as low as half our local feerate + ratio-high = 10.0 // will allow remote fee rates as high as 10 times our local feerate + } close-on-offline-feerate-mismatch = true // do not change this unless you know what you are doing // funder will send an UpdateFee message if the difference between current commitment fee and actual current network fee is greater diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index e9b346cce..2dcba4e93 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -86,7 +86,9 @@ trait Eclair { def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] - def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] + def peers()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] + + def nodes(nodeIds_opt: Option[Set[PublicKey]] = None)(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] def receive(description: String, amount_opt: Option[MilliSatoshi], expire_opt: Option[Long], fallbackAddress_opt: Option[String], paymentPreimage_opt: Option[ByteVector32])(implicit timeout: Timeout): Future[PaymentRequest] @@ -108,7 +110,7 @@ trait Eclair { def networkFees(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[NetworkFee]] - def channelStats()(implicit timeout: Timeout): Future[Seq[Stats]] + def channelStats(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[Stats]] def networkStats()(implicit timeout: Timeout): Future[Option[NetworkStats]] @@ -118,8 +120,6 @@ trait Eclair { def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] - def allNodes()(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] - def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] def allUpdates(nodeId_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[ChannelUpdate]] @@ -174,11 +174,17 @@ class EclairImpl(appKit: Kit) extends Eclair { sendToChannels[ChannelCommandResponse](channels, CMD_UPDATE_RELAY_FEE(feeBaseMsat, feeProportionalMillionths)) } - override def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] = for { + override def peers()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] = for { peers <- (appKit.switchboard ? Symbol("peers")).mapTo[Iterable[ActorRef]] peerinfos <- Future.sequence(peers.map(peer => (peer ? GetPeerInfo).mapTo[PeerInfo])) } yield peerinfos + override def nodes(nodeIds_opt: Option[Set[PublicKey]])(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] = { + (appKit.router ? Symbol("nodes")) + .mapTo[Iterable[NodeAnnouncement]] + .map(_.filter(n => nodeIds_opt.forall(_.contains(n.nodeId)))) + } + override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = toRemoteNode_opt match { case Some(pk) => for { channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys) @@ -194,8 +200,6 @@ class EclairImpl(appKit: Kit) extends Eclair { sendToChannel[RES_GETINFO](channel, CMD_GETINFO) } - override def allNodes()(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] = (appKit.router ? Symbol("nodes")).mapTo[Iterable[NodeAnnouncement]] - override def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] = { (appKit.router ? Symbol("channels")).mapTo[Iterable[ChannelAnnouncement]].map(_.map(c => ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2))) } @@ -312,7 +316,10 @@ class EclairImpl(appKit: Kit) extends Eclair { Future(appKit.nodeParams.db.audit.listNetworkFees(filter.from, filter.to)) } - override def channelStats()(implicit timeout: Timeout): Future[Seq[Stats]] = Future(appKit.nodeParams.db.audit.stats) + override def channelStats(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[Stats]] = { + val filter = getDefaultTimestampFilters(from_opt, to_opt) + Future(appKit.nodeParams.db.audit.stats(filter.from, filter.to)) + } override def networkStats()(implicit timeout: Timeout): Future[Option[NetworkStats]] = (appKit.router ? GetNetworkStats).mapTo[GetNetworkStatsResponse].map(_.stats) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 12d4dcd07..6e2fa1460 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -27,7 +27,7 @@ import com.google.common.io.Files import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi} import fr.acinq.eclair.NodeParams.WatcherType -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, OnChainFeeConf} +import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeerateTolerance, OnChainFeeConf} import fr.acinq.eclair.channel.Channel import fr.acinq.eclair.crypto.KeyManager import fr.acinq.eclair.db._ @@ -143,7 +143,9 @@ object NodeParams { "update-fee_min-diff-ratio" -> "on-chain-fees.update-fee-min-diff-ratio", // v0.3.3 "global-features" -> "features", - "local-features" -> "features" + "local-features" -> "features", + // v0.4.1 + "on-chain-fees.max-feerate-mismatch" -> "on-chain-fees.feerate-tolerance.ratio-low / on-chain-fees.feerate-tolerance.ratio-high" ) deprecatedKeyPaths.foreach { case (old, new_) => require(!config.hasPath(old), s"configuration key '$old' has been replaced by '$new_'") @@ -246,7 +248,7 @@ object NodeParams { onChainFeeConf = OnChainFeeConf( feeTargets = feeTargets, feeEstimator = feeEstimator, - maxFeerateMismatch = config.getDouble("on-chain-fees.max-feerate-mismatch"), + maxFeerateMismatch = FeerateTolerance(config.getDouble("on-chain-fees.feerate-tolerance.ratio-low"), config.getDouble("on-chain-fees.feerate-tolerance.ratio-high")), closeOnOfflineMismatch = config.getBoolean("on-chain-fees.close-on-offline-feerate-mismatch"), updateFeeMinDiffRatio = config.getDouble("on-chain-fees.update-fee-min-diff-ratio") ), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 845d3d16f..96503b18c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -252,7 +252,7 @@ class Setup(datadir: File, _ <- postRestartCleanUpInitialized.future switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, watcher, relayer, paymentHandler, wallet), "switchboard", SupervisorStrategy.Resume)) clientSpawner = system.actorOf(SimpleSupervisor.props(ClientSpawner.props(nodeParams, switchboard, router), "client-spawner", SupervisorStrategy.Restart)) - paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, relayer, register), "payment-initiator", SupervisorStrategy.Restart)) + paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, register), "payment-initiator", SupervisorStrategy.Restart)) kit = Kit( nodeParams = nodeParams, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala index bdd0a24e1..eff88c1ee 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala @@ -18,12 +18,14 @@ package fr.acinq.eclair.blockchain.fee trait FeeEstimator { - def getFeeratePerKb(target: Int) : Long + def getFeeratePerKb(target: Int): Long - def getFeeratePerKw(target: Int) : Long + def getFeeratePerKw(target: Int): Long } case class FeeTargets(fundingBlockTarget: Int, commitmentBlockTarget: Int, mutualCloseBlockTarget: Int, claimMainBlockTarget: Int) -case class OnChainFeeConf(feeTargets: FeeTargets, feeEstimator: FeeEstimator, maxFeerateMismatch: Double, closeOnOfflineMismatch: Boolean, updateFeeMinDiffRatio: Double) \ No newline at end of file +case class FeerateTolerance(ratioLow: Double, ratioHigh: Double) + +case class OnChainFeeConf(feeTargets: FeeTargets, feeEstimator: FeeEstimator, maxFeerateMismatch: FeerateTolerance, closeOnOfflineMismatch: Boolean, updateFeeMinDiffRatio: Double) \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 568e099d3..c86fccebc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -35,10 +35,8 @@ import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.transactions._ import fr.acinq.eclair.wire._ import scodec.bits.ByteVector -import ChannelVersion._ import scala.collection.immutable.Queue -import scala.compat.Platform import scala.concurrent.ExecutionContext import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -653,7 +651,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId handleCommandError(AddHtlcFailed(d.channelId, c.paymentHash, error, origin(c), Some(d.channelUpdate), Some(c)), c) case Event(c: CMD_ADD_HTLC, d: DATA_NORMAL) => - Commitments.sendAdd(d.commitments, c, origin(c), nodeParams.currentBlockHeight) match { + Commitments.sendAdd(d.commitments, c, origin(c), nodeParams.currentBlockHeight, nodeParams.onChainFeeConf) match { case Success((commitments1, add)) => if (c.commit) self ! CMD_SIGN context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.shortChannelId, commitments1)) @@ -662,7 +660,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(add: UpdateAddHtlc, d: DATA_NORMAL) => - Commitments.receiveAdd(d.commitments, add) match { + Commitments.receiveAdd(d.commitments, add, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(add)) } @@ -734,7 +732,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(fee: UpdateFee, d: DATA_NORMAL) => - Commitments.receiveFee(d.commitments, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets, fee, nodeParams.onChainFeeConf.maxFeerateMismatch) match { + Commitments.receiveFee(d.commitments, fee, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(fee)) } @@ -822,14 +820,14 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case Event(c@CMD_CLOSE(localScriptPubKey_opt), d: DATA_NORMAL) => val localScriptPubKey = localScriptPubKey_opt.getOrElse(d.commitments.localParams.defaultFinalScriptPubKey) - if (d.localShutdown.isDefined) + if (d.localShutdown.isDefined) { handleCommandError(ClosingAlreadyInProgress(d.channelId), c) - else if (Commitments.localHasUnsignedOutgoingHtlcs(d.commitments)) - // TODO: simplistic behavior, we could also sign-then-close - handleCommandError(CannotCloseWithUnsignedOutgoingHtlcs(d.channelId), c) - else if (!Closing.isValidFinalScriptPubkey(localScriptPubKey)) + } else if (Commitments.localHasUnsignedOutgoingHtlcs(d.commitments)) { + // TODO: simplistic behavior, we could also sign-then-close + handleCommandError(CannotCloseWithUnsignedOutgoingHtlcs(d.channelId), c) + } else if (!Closing.isValidFinalScriptPubkey(localScriptPubKey)) { handleCommandError(InvalidFinalScript(d.channelId), c) - else { + } else { val shutdown = Shutdown(d.channelId, localScriptPubKey) handleCommandSuccess(sender, d.copy(localShutdown = Some(shutdown))) storing() sending shutdown } @@ -1079,7 +1077,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(fee: UpdateFee, d: DATA_SHUTDOWN) => - Commitments.receiveFee(d.commitments, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets, fee, nodeParams.onChainFeeConf.maxFeerateMismatch) match { + Commitments.receiveFee(d.commitments, fee, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(fee)) } @@ -1827,13 +1825,18 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleCurrentFeerate(c: CurrentFeerates, d: HasCommitments) = { val networkFeeratePerKw = c.feeratesPerKw.feePerBlock(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) val currentFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw - d.commitments.localParams.isFunder match { - case true if Helpers.shouldUpdateFee(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.updateFeeMinDiffRatio) => - self ! CMD_UPDATE_FEE(networkFeeratePerKw, commit = true) - stay - case false if Helpers.isFeeDiffTooHigh(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) => - handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = networkFeeratePerKw, remoteFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw), d, Some(c)) - case _ => stay + val shouldUpdateFee = d.commitments.localParams.isFunder && + Helpers.shouldUpdateFee(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.updateFeeMinDiffRatio) + val shouldClose = !d.commitments.localParams.isFunder && + Helpers.isFeeDiffTooHigh(networkFeeratePerKw, currentFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) && + d.commitments.hasPendingOrProposedHtlcs // we close only if we have HTLCs potentially at risk + if (shouldUpdateFee) { + self ! CMD_UPDATE_FEE(networkFeeratePerKw, commit = true) + stay + } else if (shouldClose) { + handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = networkFeeratePerKw, remoteFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw), d, Some(c)) + } else { + stay } } @@ -1847,13 +1850,16 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleOfflineFeerate(c: CurrentFeerates, d: HasCommitments) = { val networkFeeratePerKw = c.feeratesPerKw.feePerBlock(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) val currentFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw - // if the fees are too high we risk to not be able to confirm our current commitment - if (networkFeeratePerKw > currentFeeratePerKw && Helpers.isFeeDiffTooHigh(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) { + // if the network fees are too high we risk to not be able to confirm our current commitment + val shouldClose = networkFeeratePerKw > currentFeeratePerKw && + Helpers.isFeeDiffTooHigh(networkFeeratePerKw, currentFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) && + d.commitments.hasPendingOrProposedHtlcs // we close only if we have HTLCs potentially at risk + if (shouldClose) { if (nodeParams.onChainFeeConf.closeOnOfflineMismatch) { - log.warning(s"closing OFFLINE ${d.channelId} due to fee mismatch: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") + log.warning(s"closing OFFLINE channel due to fee mismatch: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = currentFeeratePerKw, remoteFeeratePerKw = networkFeeratePerKw), d, Some(c)) } else { - log.warning(s"channel ${d.channelId} is OFFLINE but its fee mismatch is over the threshold: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") + log.warning(s"channel is OFFLINE but its fee mismatch is over the threshold: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") stay } } else { @@ -2173,7 +2179,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId val remotePerCommitmentPoint = d.remoteChannelReestablish.myCurrentPerCommitmentPoint val remoteCommitPublished = Helpers.Closing.claimRemoteCommitMainOutput(keyManager, d.commitments, remotePerCommitmentPoint, commitTx, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets) val nextData = DATA_CLOSING(d.commitments, fundingTx = None, waitingSince = now, Nil, futureRemoteCommitPublished = Some(remoteCommitPublished)) - goto(CLOSING) using nextData storing() calling(doPublish(remoteCommitPublished)) + goto(CLOSING) using nextData storing() calling (doPublish(remoteCommitPublished)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index 7dffd2dc5..cb8da19e6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -17,10 +17,9 @@ package fr.acinq.eclair.channel import akka.event.LoggingAdapter -import fr.acinq.eclair.channel.ChannelVersion._ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, sha256} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto} -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets} +import fr.acinq.eclair.blockchain.fee.OnChainFeeConf import fr.acinq.eclair.channel.Monitoring.Metrics import fr.acinq.eclair.crypto.{Generators, KeyManager, ShaChain, Sphinx} import fr.acinq.eclair.payment.relay.{Origin, Relayer} @@ -37,7 +36,9 @@ import scala.util.{Failure, Success, Try} case class LocalChanges(proposed: List[UpdateMessage], signed: List[UpdateMessage], acked: List[UpdateMessage]) { def all: List[UpdateMessage] = proposed ++ signed ++ acked } -case class RemoteChanges(proposed: List[UpdateMessage], acked: List[UpdateMessage], signed: List[UpdateMessage]) +case class RemoteChanges(proposed: List[UpdateMessage], acked: List[UpdateMessage], signed: List[UpdateMessage]) { + def all: List[UpdateMessage] = proposed ++ signed ++ acked +} case class Changes(ourChanges: LocalChanges, theirChanges: RemoteChanges) case class HtlcTxAndSigs(txinfo: TransactionWithInputInfo, localSig: ByteVector64, remoteSig: ByteVector64) case class PublishableTxs(commitTx: CommitTx, htlcTxsAndSigs: List[HtlcTxAndSigs]) @@ -70,6 +71,10 @@ case class Commitments(channelVersion: ChannelVersion, def hasNoPendingHtlcs: Boolean = localCommit.spec.htlcs.isEmpty && remoteCommit.spec.htlcs.isEmpty && remoteNextCommitInfo.isRight + def hasPendingOrProposedHtlcs: Boolean = !hasNoPendingHtlcs || + localChanges.all.exists(_.isInstanceOf[UpdateAddHtlc]) || + remoteChanges.all.exists(_.isInstanceOf[UpdateAddHtlc]) + def timedOutOutgoingHtlcs(blockheight: Long): Set[UpdateAddHtlc] = { def expired(add: UpdateAddHtlc) = blockheight >= add.cltvExpiry.toLong @@ -179,7 +184,7 @@ object Commitments { * @param cmd add HTLC command * @return either Left(failure, error message) where failure is a failure message (see BOLT #4 and the Failure Message class) or Right((new commitments, updateAddHtlc) */ - def sendAdd(commitments: Commitments, cmd: CMD_ADD_HTLC, origin: Origin, blockHeight: Long): Try[(Commitments, UpdateAddHtlc)] = { + def sendAdd(commitments: Commitments, cmd: CMD_ADD_HTLC, origin: Origin, blockHeight: Long, feeConf: OnChainFeeConf): Try[(Commitments, UpdateAddHtlc)] = { // our counterparty needs a reasonable amount of time to pull the funds from downstream before we can get refunded (see BOLT 2 and BOLT 11 for a calculation and rationale) val minExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(blockHeight) if (cmd.cltvExpiry < minExpiry) { @@ -197,6 +202,13 @@ object Commitments { return Failure(HtlcValueTooSmall(commitments.channelId, minimum = htlcMinimum, actual = cmd.amount)) } + // we allowed mismatches between our feerates and our remote's as long as commitments didn't contain any HTLC at risk + // we need to verify that we're not disagreeing on feerates anymore before offering new HTLCs + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, commitments.localCommit.spec.feeratePerKw, feeConf.maxFeerateMismatch)) { + return Failure(FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = commitments.localCommit.spec.feeratePerKw)) + } + // let's compute the current commitment *as seen by them* with this change taken into account val add = UpdateAddHtlc(commitments.channelId, commitments.localNextHtlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) // we increment the local htlc index and add an entry to the origins map @@ -238,7 +250,7 @@ object Commitments { Success(commitments1, add) } - def receiveAdd(commitments: Commitments, add: UpdateAddHtlc): Try[Commitments] = Try { + def receiveAdd(commitments: Commitments, add: UpdateAddHtlc, feeConf: OnChainFeeConf): Try[Commitments] = Try { if (add.id != commitments.remoteNextHtlcId) { throw UnexpectedHtlcId(commitments.channelId, expected = commitments.remoteNextHtlcId, actual = add.id) } @@ -249,6 +261,13 @@ object Commitments { throw HtlcValueTooSmall(commitments.channelId, minimum = htlcMinimum, actual = add.amountMsat) } + // we allowed mismatches between our feerates and our remote's as long as commitments didn't contain any HTLC at risk + // we need to verify that we're not disagreeing on feerates anymore before accepting new HTLCs + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, commitments.localCommit.spec.feeratePerKw, feeConf.maxFeerateMismatch)) { + throw FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = commitments.localCommit.spec.feeratePerKw) + } + // let's compute the current commitment *as seen by us* including this change val commitments1 = addRemoteProposal(commitments, add).copy(remoteNextHtlcId = commitments.remoteNextHtlcId + 1) val reduced = CommitmentSpec.reduce(commitments1.localCommit.spec, commitments1.localChanges.acked, commitments1.remoteChanges.proposed) @@ -398,16 +417,16 @@ object Commitments { } } - def receiveFee(commitments: Commitments, feeEstimator: FeeEstimator, feeTargets: FeeTargets, fee: UpdateFee, maxFeerateMismatch: Double)(implicit log: LoggingAdapter): Try[Commitments] = { + def receiveFee(commitments: Commitments, fee: UpdateFee, feeConf: OnChainFeeConf)(implicit log: LoggingAdapter): Try[Commitments] = { if (commitments.localParams.isFunder) { Failure(FundeeCannotSendUpdateFee(commitments.channelId)) } else if (fee.feeratePerKw < fr.acinq.eclair.MinimumFeeratePerKw) { Failure(FeerateTooSmall(commitments.channelId, remoteFeeratePerKw = fee.feeratePerKw)) } else { Metrics.RemoteFeeratePerKw.withoutTags().record(fee.feeratePerKw) - val localFeeratePerKw = feeEstimator.getFeeratePerKw(target = feeTargets.commitmentBlockTarget) + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) log.info("remote feeratePerKw={}, local feeratePerKw={}, ratio={}", fee.feeratePerKw, localFeeratePerKw, fee.feeratePerKw.toDouble / localFeeratePerKw) - if (Helpers.isFeeDiffTooHigh(fee.feeratePerKw, localFeeratePerKw, maxFeerateMismatch)) { + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, fee.feeratePerKw, feeConf.maxFeerateMismatch) && commitments.hasPendingOrProposedHtlcs) { Failure(FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = fee.feeratePerKw)) } else { // NB: we check that the funder can afford this new fee even if spec allows to do it at next signature @@ -461,6 +480,7 @@ object Commitments { // NB: IN/OUT htlcs are inverted because this is the remote commit log.info(s"built remote commit number=${remoteCommit.index + 1} toLocalMsat=${spec.toLocal.toLong} toRemoteMsat=${spec.toRemote.toLong} htlc_in={} htlc_out={} feeratePerKw=${spec.feeratePerKw} txid=${remoteCommitTx.tx.txid} tx={}", spec.htlcs.collect(outgoing).map(_.id).mkString(","), spec.htlcs.collect(incoming).map(_.id).mkString(","), remoteCommitTx.tx) + Metrics.recordHtlcsInFlight(spec, remoteCommit.spec) // don't sign if they don't get paid val commitSig = CommitSig( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index 8f7a0427f..6f9c57baa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, ripemd160, sha256} import fr.acinq.bitcoin.Script._ import fr.acinq.bitcoin._ import fr.acinq.eclair.blockchain.EclairWallet -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets} +import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeerateTolerance} import fr.acinq.eclair.channel.Channel.REFRESH_CHANNEL_UPDATE_INTERVAL import fr.acinq.eclair.crypto.{ChaCha20Poly1305, Generators, KeyManager} import fr.acinq.eclair.db.ChannelsDb @@ -32,7 +32,6 @@ import fr.acinq.eclair.transactions._ import fr.acinq.eclair.wire.OpenChannelTlv.ChannelVersionTlv import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, ShortChannelId, addressToPublicKeyScript, _} -import fr.acinq.eclair.channel.ChannelVersion.USE_STATIC_REMOTEKEY_BIT import scodec.bits.ByteVector import scala.concurrent.Await @@ -143,7 +142,7 @@ object Helpers { } val localFeeratePerKw = nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) - if (isFeeDiffTooHigh(open.feeratePerKw, localFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) throw FeerateTooDifferent(open.temporaryChannelId, localFeeratePerKw, open.feeratePerKw) + if (isFeeDiffTooHigh(localFeeratePerKw, open.feeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) throw FeerateTooDifferent(open.temporaryChannelId, localFeeratePerKw, open.feeratePerKw) // only enforce dust limit check on mainnet if (nodeParams.chainHash == Block.LivenetGenesisBlock.hash) { if (open.dustLimitSatoshis < Channel.MIN_DUSTLIMIT) throw DustLimitTooSmall(open.temporaryChannelId, open.dustLimitSatoshis, Channel.MIN_DUSTLIMIT) @@ -207,25 +206,20 @@ object Helpers { } /** - * @param referenceFeePerKw reference fee rate per kiloweight - * @param currentFeePerKw current fee rate per kiloweight - * @return the "normalized" difference between i.e local and remote fee rate: |reference - current| / avg(current, reference) + * To avoid spamming our peers with fee updates every time there's a small variation, we only update the fee when the + * difference exceeds a given ratio (updateFeeMinDiffRatio). */ - def feeRateMismatch(referenceFeePerKw: Long, currentFeePerKw: Long): Double = - Math.abs((2.0 * (referenceFeePerKw - currentFeePerKw)) / (currentFeePerKw + referenceFeePerKw)) - - def shouldUpdateFee(commitmentFeeratePerKw: Long, networkFeeratePerKw: Long, updateFeeMinDiffRatio: Double): Boolean = - feeRateMismatch(networkFeeratePerKw, commitmentFeeratePerKw) > updateFeeMinDiffRatio + def shouldUpdateFee(currentFeeratePerKw: Long, nextFeeratePerKw: Long, updateFeeMinDiffRatio: Double): Boolean = + currentFeeratePerKw == 0 || Math.abs((currentFeeratePerKw - nextFeeratePerKw).toDouble / currentFeeratePerKw) > updateFeeMinDiffRatio /** - * @param referenceFeePerKw reference fee rate per kiloweight - * @param currentFeePerKw current fee rate per kiloweight - * @param maxFeerateMismatchRatio maximum fee rate mismatch ratio - * @return true if the difference between current and reference fee rates is too high. - * the actual check is |reference - current| / avg(current, reference) > mismatch ratio + * @param referenceFeePerKw reference fee rate per kiloweight + * @param currentFeePerKw current fee rate per kiloweight + * @param maxFeerateMismatch maximum fee rate mismatch tolerated + * @return true if the difference between proposed and reference fee rates is too high. */ - def isFeeDiffTooHigh(referenceFeePerKw: Long, currentFeePerKw: Long, maxFeerateMismatchRatio: Double): Boolean = - feeRateMismatch(referenceFeePerKw, currentFeePerKw) > maxFeerateMismatchRatio + def isFeeDiffTooHigh(referenceFeePerKw: Long, currentFeePerKw: Long, maxFeerateMismatch: FeerateTolerance): Boolean = + currentFeePerKw < referenceFeePerKw * maxFeerateMismatch.ratioLow || referenceFeePerKw * maxFeerateMismatch.ratioHigh < currentFeePerKw /** * @param remoteFeeratePerKw remote fee rate per kiloweight @@ -663,9 +657,9 @@ object Helpers { ) case _ => claimRemoteCommitMainOutput(keyManager, commitments, remoteCommit.remotePerCommitmentPoint, tx, feeEstimator, feeTargets).copy( - claimHtlcSuccessTxs = txes.toList.collect { case c: ClaimHtlcSuccessTx => c.tx }, - claimHtlcTimeoutTxs = txes.toList.collect { case c: ClaimHtlcTimeoutTx => c.tx } - ) + claimHtlcSuccessTxs = txes.toList.collect { case c: ClaimHtlcSuccessTx => c.tx }, + claimHtlcTimeoutTxs = txes.toList.collect { case c: ClaimHtlcTimeoutTx => c.tx } + ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala index f1c9db6b0..11fdec12c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.channel +import fr.acinq.eclair.transactions.{CommitmentSpec, DirectedHtlc} import kamon.Kamon import kamon.tag.TagSet @@ -25,15 +26,35 @@ object Monitoring { val ChannelsCount = Kamon.gauge("channels.count") val ChannelErrors = Kamon.counter("channels.errors") val ChannelLifecycleEvents = Kamon.counter("channels.lifecycle") + val HtlcsInFlight = Kamon.histogram("channels.htlc-in-flight", "Per-channel HTLCs in flight") + val HtlcsInFlightGlobal = Kamon.gauge("channels.htlc-in-flight-global", "Global HTLCs in flight across all channels") + val HtlcValueInFlight = Kamon.histogram("channels.htlc-value-in-flight", "Per-channel HTLC value in flight") + val HtlcValueInFlightGlobal = Kamon.gauge("channels.htlc-value-in-flight-global", "Global HTLC value in flight across all channels") val LocalFeeratePerKw = Kamon.gauge("channels.local-feerate-per-kw") val RemoteFeeratePerKw = Kamon.histogram("channels.remote-feerate-per-kw") + + def recordHtlcsInFlight(remoteSpec: CommitmentSpec, previousRemoteSpec: CommitmentSpec): Unit = { + for (direction <- Tags.Directions.Incoming :: Tags.Directions.Outgoing :: Nil) { + // NB: IN/OUT htlcs are inverted because this is the remote commit + val filter = if (direction == Tags.Directions.Incoming) DirectedHtlc.outgoing else DirectedHtlc.incoming + // NB: we need the `toSeq` because otherwise duplicate amounts would be removed (since htlcs are sets) + val htlcs = remoteSpec.htlcs.collect(filter).toSeq.map(_.amountMsat) + val previousHtlcs = previousRemoteSpec.htlcs.collect(filter).toSeq.map(_.amountMsat) + HtlcsInFlight.withTag(Tags.Direction, direction).record(htlcs.length) + HtlcsInFlightGlobal.withTag(Tags.Direction, direction).increment(htlcs.length - previousHtlcs.length) + val (value, previousValue) = (htlcs.sum.truncateToSatoshi.toLong, previousHtlcs.sum.truncateToSatoshi.toLong) + HtlcValueInFlight.withTag(Tags.Direction, direction).record(value) + HtlcValueInFlightGlobal.withTag(Tags.Direction, direction).increment(value - previousValue) + } + } } object Tags { - val Event = TagSet.Empty - val Fatal = TagSet.Empty - val Origin = TagSet.Empty - val State = TagSet.Empty + val Direction = "direction" + val Event = "event" + val Fatal = "fatal" + val Origin = "origin" + val State = "state" object Events { val Created = "created" @@ -46,6 +67,11 @@ object Monitoring { val Remote = "remote" } + object Directions { + val Incoming = "incoming" + val Outgoing = "outgoing" + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala new file mode 100644 index 000000000..38bbcb272 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2020 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.crypto + +import kamon.Kamon + +object Monitoring { + + object Metrics { + val OnionPayloadFormat = Kamon.counter("crypto.sphinx.onion-payload-format") + } + + object Tags { + val LegacyOnion = "legacy" + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index d18bcf1e2..c96942e57 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, Crypto} +import fr.acinq.eclair.crypto.Monitoring.{Metrics, Tags} import fr.acinq.eclair.wire import fr.acinq.eclair.wire.{FailureMessage, FailureMessageCodecs, Onion, OnionCodecs} import grizzled.slf4j.Logging @@ -88,11 +89,13 @@ object Sphinx extends Logging { case 0 => // The 1.0 BOLT spec used 65-bytes frames inside the onion payload. // The first byte of the frame (called `realm`) is set to 0x00, followed by 32 bytes of per-hop data, followed by a 32-bytes mac. + Metrics.OnionPayloadFormat.withTag(Tags.LegacyOnion, value = true).increment() 65 case _ => // The 1.1 BOLT spec changed the frame format to use variable-length per-hop payloads. // The first bytes contain a varint encoding the length of the payload data (not including the trailing mac). // Since messages are always smaller than 65535 bytes, this varint will either be 1 or 3 bytes long. + Metrics.OnionPayloadFormat.withTag(Tags.LegacyOnion, value = false).increment() MacLength + OnionCodecs.payloadLengthDecoder.decode(payload.bits).require.value.toInt } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala index c3a1b9eea..978f16fce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala @@ -45,7 +45,7 @@ trait AuditDb extends Closeable { def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] - def stats: Seq[Stats] + def stats(from: Long, to: Long): Seq[Stats] } @@ -53,4 +53,4 @@ case class ChannelLifecycleEvent(channelId: ByteVector32, remoteNodeId: PublicKe case class NetworkFee(remoteNodeId: PublicKey, channelId: ByteVector32, txId: ByteVector32, fee: Satoshi, txType: String, timestamp: Long) -case class Stats(channelId: ByteVector32, avgPaymentAmount: Satoshi, paymentCount: Int, relayFee: Satoshi, networkFee: Satoshi) +case class Stats(channelId: ByteVector32, direction: String, avgPaymentAmount: Satoshi, paymentCount: Int, relayFee: Satoshi, networkFee: Satoshi) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 67566c8c7..293c20b8e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -298,35 +298,46 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { q } - override def stats: Seq[Stats] = { - val networkFees = listNetworkFees(0, System.currentTimeMillis + 1).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => + override def stats(from: Long, to: Long): Seq[Stats] = { + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) } - val relayed = listRelayed(0, System.currentTimeMillis + 1).foldLeft(Map.empty[ByteVector32, Seq[PaymentRelayed]]) { case (relayedByChannelId, e) => - val relayedTo = e match { - case c: ChannelPaymentRelayed => Set(c.toChannelId) - case t: TrampolinePaymentRelayed => t.outgoing.map(_.channelId).toSet + case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) => + // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. + val current = e match { + case c: ChannelPaymentRelayed => Map( + c.fromChannelId -> (Relayed(c.amountIn, 0 msat, "IN") +: previous.getOrElse(c.fromChannelId, Nil)), + c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)) + ) + case t: TrampolinePaymentRelayed => + // We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were + // sent from/to the same channel, we group them). + val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq + val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) => + val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels + (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) + }.toSeq + (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } } - val updated = relayedTo.map(channelId => (channelId, relayedByChannelId.getOrElse(channelId, Nil) :+ e)).toMap - relayedByChannelId ++ updated + previous ++ current } // Channels opened by our peers won't have any entry in the network_fees table, but we still want to compute stats for them. val allChannels = networkFees.keySet ++ relayed.keySet - allChannels.map(channelId => { + allChannels.toSeq.flatMap(channelId => { val networkFee = networkFees.getOrElse(channelId, 0 sat) - val r = relayed.getOrElse(channelId, Nil) - val paymentCount = r.length - if (paymentCount == 0) { - Stats(channelId, 0 sat, 0, 0 sat, networkFee) - } else { - val avgPaymentAmount = r.map(_.amountOut).sum / paymentCount - val relayFee = r.map { - case c: ChannelPaymentRelayed => c.amountIn - c.amountOut - case t: TrampolinePaymentRelayed => (t.amountIn - t.amountOut) * t.outgoing.count(_.channelId == channelId) / t.outgoing.length - }.sum - Stats(channelId, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee) + val (in, out) = relayed.getOrElse(channelId, Nil).partition(_.direction == "IN") + ((in, "IN") :: (out, "OUT") :: Nil).map { case (r, direction) => + val paymentCount = r.length + if (paymentCount == 0) { + Stats(channelId, direction, 0 sat, 0, 0 sat, networkFee) + } else { + val avgPaymentAmount = r.map(_.amount).sum / paymentCount + val relayFee = r.map(_.fee).sum + Stats(channelId, direction, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee) + } } - }).toSeq + }) } // used by mobile apps diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index ad2a86437..1d57d08aa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -27,9 +27,14 @@ object Monitoring { val PaymentParts = Kamon.histogram("payment.parts", "Number of HTLCs per payment (MPP)") val PaymentFailed = Kamon.counter("payment.failed", "Number of failed payment") val PaymentError = Kamon.counter("payment.error", "Non-fatal errors encountered during payment attempts") + val PaymentAttempt = Kamon.histogram("payment.attempt", "Number of attempts before a payment succeeds") val SentPaymentDuration = Kamon.timer("payment.duration.sent", "Outgoing payment duration") val ReceivedPaymentDuration = Kamon.timer("payment.duration.received", "Incoming payment duration") + // The goal of this metric is to measure whether retrying MPP payments on failing channels yields useful results. + // Once enough data has been collected, we will update the MultiPartPaymentLifecycle logic accordingly. + val RetryFailedChannelsResult = Kamon.counter("payment.mpp.retry-failed-channels-result") + def recordPaymentRelayFailed(failureType: String, relayType: String): Unit = Metrics.PaymentFailed .withTag(Tags.Direction, Tags.Directions.Relayed) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index fe3a433f1..6a4d2747d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -23,7 +23,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.MilliSatoshi import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Announcements -import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop} +import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop, Ignore} import fr.acinq.eclair.wire.Node /** @@ -162,43 +162,43 @@ object PaymentFailure { .isDefined /** Update the set of nodes and channels to ignore in retries depending on the failure we received. */ - def updateIgnored(failure: PaymentFailure, ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]): (Set[PublicKey], Set[ChannelDesc]) = failure match { + def updateIgnored(failure: PaymentFailure, ignore: Ignore): Ignore = failure match { case RemoteFailure(hops, Sphinx.DecryptedFailurePacket(nodeId, _)) if nodeId == hops.last.nextNodeId => // The failure came from the final recipient: the payment should be aborted without penalizing anyone in the route. - (ignoreNodes, ignoreChannels) + ignore case RemoteFailure(_, Sphinx.DecryptedFailurePacket(nodeId, _: Node)) => - (ignoreNodes + nodeId, ignoreChannels) + ignore + nodeId case RemoteFailure(_, Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => if (Announcements.checkSig(failureMessage.update, nodeId)) { // We were using an outdated channel update, we should retry with the new one and nobody should be penalized. - (ignoreNodes, ignoreChannels) + ignore } else { // This node is fishy, it gave us a bad signature, so let's filter it out. - (ignoreNodes + nodeId, ignoreChannels) + ignore + nodeId } case RemoteFailure(hops, Sphinx.DecryptedFailurePacket(nodeId, _)) => // Let's ignore the channel outgoing from nodeId. hops.collectFirst { case hop: ChannelHop if hop.nodeId == nodeId => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId) } match { - case Some(faultyChannel) => (ignoreNodes, ignoreChannels + faultyChannel) - case None => (ignoreNodes, ignoreChannels) + case Some(faultyChannel) => ignore + faultyChannel + case None => ignore } case UnreadableRemoteFailure(hops) => // We don't know which node is sending garbage, let's blacklist all nodes except the one we are directly connected to and the final recipient. - val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1) - (ignoreNodes ++ blacklist, ignoreChannels) + val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1).toSet + ignore ++ blacklist case LocalFailure(hops, _) => hops.headOption match { case Some(hop: ChannelHop) => val faultyChannel = ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId) - (ignoreNodes, ignoreChannels + faultyChannel) - case _ => (ignoreNodes, ignoreChannels) + ignore + faultyChannel + case _ => ignore } } /** Update the set of nodes and channels to ignore in retries depending on the failures we received. */ - def updateIgnored(failures: Seq[PaymentFailure], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]): (Set[PublicKey], Set[ChannelDesc]) = { - failures.foldLeft((ignoreNodes, ignoreChannels)) { case ((nodes, channels), failure) => updateIgnored(failure, nodes, channels) } + def updateIgnored(failures: Seq[PaymentFailure], ignore: Ignore): Ignore = { + failures.foldLeft(ignore) { case (current, failure) => updateIgnored(failure, current) } } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index adfb65091..32ac71676 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -219,6 +219,7 @@ object ChannelRelayer { case (_: ExpiryTooBig, _) => ExpiryTooFar case (_: InsufficientFunds, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) case (_: TooManyAcceptedHtlcs, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) + case (_: FeerateTooDifferent, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) case (_: ChannelUnavailable, Some(channelUpdate)) if !Announcements.isEnabled(channelUpdate.channelFlags) => ChannelDisabled(channelUpdate.messageFlags, channelUpdate.channelFlags, channelUpdate) case (_: ChannelUnavailable, None) => PermanentChannelFailure case _ => TemporaryNodeFailure diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index b5eea704c..58b0fa701 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -20,20 +20,20 @@ import java.util.UUID import akka.actor.{Actor, ActorRef, DiagnosticActorLogging, PoisonPill, Props} import akka.event.Logging.MDC +import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{ByteVector32, Crypto} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment -import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentError, PaymentLifecycle} +import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle} import fr.acinq.eclair.router.Router.RouteParams -import fr.acinq.eclair.router.{RouteCalculation, RouteNotFound} +import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, Logs, MilliSatoshi, NodeParams, nodeFee, randomBytes32, _} +import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, nodeFee, randomBytes32, _} import scala.collection.immutable.Queue @@ -46,7 +46,7 @@ import scala.collection.immutable.Queue * It aggregates incoming HTLCs (in case multi-part was used upstream) and then forwards the requested amount (using the * router to find a route to the remote node and potentially splitting the payment using multi-part). */ -class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) extends Actor with DiagnosticActorLogging { +class NodeRelayer(nodeParams: NodeParams, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) extends Actor with DiagnosticActorLogging { import NodeRelayer._ @@ -109,18 +109,13 @@ class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, c case None => log.error("could not find pending incoming payment: payment will not be relayed: please investigate") } - case ff: Relayer.ForwardFulfill => ff.to match { - case Origin.TrampolineRelayed(_, Some(paymentSender)) => - paymentSender ! ff - val paymentHash = Crypto.sha256(ff.paymentPreimage) - pendingOutgoing.get(paymentHash).foreach(p => if (!p.fulfilledUpstream) { - // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). - log.debug("trampoline payment successfully relayed") - fulfillPayment(p.upstream, ff.paymentPreimage) - context become main(pendingIncoming, pendingOutgoing + (paymentHash -> p.copy(fulfilledUpstream = true))) - }) - case _ => log.error(s"unexpected non-trampoline fulfill: $ff") - } + case PreimageReceived(paymentHash, paymentPreimage) => + log.debug("trampoline payment successfully relayed") + pendingOutgoing.get(paymentHash).foreach(p => if (!p.fulfilledUpstream) { + // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). + fulfillPayment(p.upstream, paymentPreimage) + context become main(pendingIncoming, pendingOutgoing + (paymentHash -> p.copy(fulfilledUpstream = true))) + }) case PaymentSent(id, paymentHash, paymentPreimage, _, _, parts) => // We may have already fulfilled upstream, but we can now emit an accurate relayed event and clean-up resources. @@ -148,7 +143,7 @@ class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, c def spawnOutgoingPayFSM(cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = { if (multiPart) { - context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, relayer, router, register)) + context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, register)) } else { context.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register)) } @@ -158,15 +153,21 @@ class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, c val paymentId = UUID.randomUUID() val paymentCfg = SendPaymentConfig(paymentId, paymentId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, Nil) val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) + // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient. payloadOut.invoiceFeatures match { - case Some(_) => - log.debug("relaying trampoline payment to non-trampoline recipient") + case Some(features) => val routingHints = payloadOut.invoiceRoutingInfo.map(_.map(_.toSeq).toSeq).getOrElse(Nil) - // TODO: @t-bast: MPP is disabled for trampoline to non-trampoline payments until we improve the splitting algorithm for nodes with a lot of channels. - val payFSM = spawnOutgoingPayFSM(paymentCfg, multiPart = false) - val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret) - val payment = SendPayment(payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams)) - payFSM ! payment + payloadOut.paymentSecret match { + case Some(paymentSecret) if Features(features).hasFeature(Features.BasicMultiPartPayment) => + log.debug("relaying trampoline payment to non-trampoline recipient using MPP") + val payment = SendMultiPartPayment(paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams)) + spawnOutgoingPayFSM(paymentCfg, multiPart = true) ! payment + case _ => + log.debug("relaying trampoline payment to non-trampoline recipient without MPP") + val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret) + val payment = SendPayment(payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams)) + spawnOutgoingPayFSM(paymentCfg, multiPart = false) ! payment + } case None => log.debug("relaying trampoline payment to next trampoline node") val payFSM = spawnOutgoingPayFSM(paymentCfg, multiPart = true) @@ -209,7 +210,7 @@ class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, c object NodeRelayer { - def props(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) = Props(new NodeRelayer(nodeParams, relayer, router, commandBuffer, register)) + def props(nodeParams: NodeParams, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) = Props(new NodeRelayer(nodeParams, router, commandBuffer, register)) /** * We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next @@ -260,15 +261,11 @@ object NodeRelayer { * should return upstream. */ private def translateError(failures: Seq[PaymentFailure], outgoingNodeId: PublicKey): Option[FailureMessage] = { - def tooManyRouteNotFound(failures: Seq[PaymentFailure]): Boolean = { - val routeNotFoundCount = failures.collect { case f@LocalFailure(_, RouteNotFound) => f }.length - routeNotFoundCount > failures.length / 2 - } - + val routeNotFound = failures.collectFirst { case f@LocalFailure(_, RouteNotFound) => f }.nonEmpty failures match { case Nil => None - case LocalFailure(_, PaymentError.BalanceTooLow) :: Nil => Some(TemporaryNodeFailure) // we don't have enough outgoing liquidity at the moment - case _ if tooManyRouteNotFound(failures) => Some(TrampolineFeeInsufficient) // if we couldn't find routes, it's likely that the fee/cltv was insufficient + case LocalFailure(_, BalanceTooLow) :: Nil => Some(TemporaryNodeFailure) // we don't have enough outgoing liquidity at the moment + case _ if routeNotFound => Some(TrampolineFeeInsufficient) // if we couldn't find routes, it's likely that the fee/cltv was insufficient case _ => // Otherwise, we try to find a downstream error that we could decrypt. val outgoingNodeFailure = failures.collectFirst { case RemoteFailure(_, e) if e.originNode == outgoingNodeId => e.failureMessage } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index f5ba1b041..e509499ed 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -77,9 +77,9 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, comm context.system.eventStream.subscribe(self, classOf[AvailableBalanceChanged]) context.system.eventStream.subscribe(self, classOf[ShortChannelIdAssigned]) - private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, commandBuffer, initialized)) - private val channelRelayer = context.actorOf(ChannelRelayer.props(nodeParams, self, register, commandBuffer)) - private val nodeRelayer = context.actorOf(NodeRelayer.props(nodeParams, self, router, commandBuffer, register)) + private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, commandBuffer, initialized), "post-restart-htlc-cleaner") + private val channelRelayer = context.actorOf(ChannelRelayer.props(nodeParams, self, register, commandBuffer), "channel-relayer") + private val nodeRelayer = context.actorOf(NodeRelayer.props(nodeParams, router, commandBuffer, register), "node-relayer") override def receive: Receive = main(Map.empty, new mutable.HashMap[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId]) @@ -167,7 +167,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, comm commandBuffer ! CommandBuffer.CommandSend(originChannelId, cmd) context.system.eventStream.publish(ChannelPaymentRelayed(amountIn, amountOut, ff.htlc.paymentHash, originChannelId, ff.htlc.channelId)) case Origin.TrampolineRelayed(_, None) => postRestartCleaner forward ff - case Origin.TrampolineRelayed(_, Some(_)) => nodeRelayer forward ff + case Origin.TrampolineRelayed(_, Some(paymentSender)) => paymentSender ! ff } case ff: ForwardFail => ff.to match { @@ -206,7 +206,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, comm object Relayer extends Logging { def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, commandBuffer: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None) = - Props(classOf[Relayer], nodeParams, router, register, commandBuffer, paymentHandler, initialized) + Props(new Relayer(nodeParams, router, register, commandBuffer, paymentHandler, initialized)) type ChannelUpdates = Map[ShortChannelId, OutgoingChannel] type NodeChannels = mutable.HashMap[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index 6e2434a2c..6c5e8f050 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -19,29 +19,23 @@ package fr.acinq.eclair.payment.send import java.util.UUID import java.util.concurrent.TimeUnit -import akka.actor.{ActorRef, FSM, Props} +import akka.actor.{ActorRef, FSM, Props, Status} import akka.event.Logging.MDC import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.eclair.channel.{Commitments, Upstream} -import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.channel.Upstream import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannel, OutgoingChannels} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig -import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment +import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute +import fr.acinq.eclair.router.RouteCalculation import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.router._ import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, ToMilliSatoshiConversion} +import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, LongToBtcAmount, MilliSatoshi, NodeParams} import kamon.Kamon import kamon.context.Context -import scodec.bits.ByteVector - -import scala.annotation.tailrec -import scala.util.Random /** * Created by t-bast on 18/07/2019. @@ -51,7 +45,7 @@ import scala.util.Random * Sender for a multi-part payment (see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#basic-multi-part-payments). * The payment will be split into multiple sub-payments that will be sent in parallel. */ -class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, relayer: ActorRef, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] { +class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] { import MultiPartPaymentLifecycle._ @@ -60,6 +54,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val id = cfg.id val paymentHash = cfg.paymentHash val start = System.currentTimeMillis + private var retriedFailedChannels = false private val span = Kamon.spanBuilder("multi-part-payment") .tag(Tags.ParentId, cfg.parentId.toString) @@ -72,109 +67,95 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, when(WAIT_FOR_PAYMENT_REQUEST) { case Event(r: SendMultiPartPayment, _) => - router ! GetNetworkStats - goto(WAIT_FOR_NETWORK_STATS) using WaitingForNetworkStats(sender, r) + val routeParams = r.getRouteParams(nodeParams, randomize = false) // we don't randomize the first attempt, regardless of configuration choices + val maxFee = routeParams.getMaxFee(r.totalAmount) + log.debug("sending {} with maximum fee {}", r.totalAmount, maxFee) + val d = PaymentProgress(sender, r, r.maxAttempts, Map.empty, Ignore.empty, Nil) + router ! createRouteRequest(nodeParams, r.totalAmount, maxFee, routeParams, d, cfg) + goto(WAIT_FOR_ROUTES) using d } - when(WAIT_FOR_NETWORK_STATS) { - case Event(s: GetNetworkStatsResponse, d: WaitingForNetworkStats) => - log.debug("network stats: {}", s.stats.map(_.capacity)) - // If we don't have network stats it's ok, we'll use data about our local channels instead. - // We tell the router to compute those stats though: in case our payment attempt fails, they will be available for - // another payment attempt. - if (s.stats.isEmpty) { - router ! TickComputeNetworkStats - } - relayer ! GetOutgoingChannels() - goto(WAIT_FOR_CHANNEL_BALANCES) using WaitingForChannelBalances(d.sender, d.request, s.stats) - } - - when(WAIT_FOR_CHANNEL_BALANCES) { - case Event(OutgoingChannels(channels), d: WaitingForChannelBalances) => - log.debug("trying to send {} with local channels: {}", d.request.totalAmount, channels.map(_.toUsableBalance).mkString(",")) - val (remaining, payments) = splitPayment(nodeParams, d.request.totalAmount, channels, d.networkStats, d.request, randomize = false) - if (remaining > 0.msat) { - log.warning(s"cannot send ${d.request.totalAmount} with our current balance") - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(Nil, PaymentError.BalanceTooLow))) - goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, LocalFailure(Nil, PaymentError.BalanceTooLow) :: Nil, Set.empty) - } else { - val pending = setFees(d.request.routeParams, payments, payments.size) + when(WAIT_FOR_ROUTES) { + case Event(RouteResponse(routes), d: PaymentProgress) => + log.info("{} routes found (attempt={}/{})", routes.length, d.request.maxAttempts - d.remainingAttempts + 1, d.request.maxAttempts) + // We may have already succeeded sending parts of the payment and only need to take care of the rest. + val (toSend, maxFee) = remainingToSend(nodeParams, d.request, d.pending.values) + if (routes.map(_.amount).sum == toSend) { + val childPayments = routes.map(route => (UUID.randomUUID(), route)).toMap Kamon.runWithContextEntry(parentPaymentIdKey, cfg.parentId) { Kamon.runWithSpan(span, finishSpan = true) { - pending.foreach { case (childId, payment) => spawnChildPaymentFsm(childId) ! payment } + childPayments.foreach { case (childId, route) => spawnChildPaymentFsm(childId) ! createChildPayment(route, d.request) } } } - goto(PAYMENT_IN_PROGRESS) using PaymentProgress(d.sender, d.request, d.networkStats, channels.length, 0 msat, d.request.maxAttempts - 1, pending, Set.empty, Nil) + goto(PAYMENT_IN_PROGRESS) using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), pending = d.pending ++ childPayments) + } else { + // If a child payment failed while we were waiting for routes, the routes we received don't cover the whole + // remaining amount. In that case we discard these routes and send a new request to the router. + log.info("discarding routes, another child payment failed so we need to recompute them (amount = {}, maximum fee = {})", toSend, maxFee) + val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg) + stay } + + case Event(Status.Failure(t), d: PaymentProgress) => + log.warning("router error: {}", t.getMessage) + if (d.ignore.channels.nonEmpty) { + // If no route can be found, we will retry once with the channels that we previously ignored. + // Channels are mostly ignored for temporary reasons, likely because they didn't have enough balance to forward + // the payment. When we're retrying an MPP split, it may make sense to retry those ignored channels because with + // a different split, they may have enough balance to forward the payment. + val (toSend, maxFee) = remainingToSend(nodeParams, d.request, d.pending.values) + log.debug("retry sending {} with maximum fee {} without ignoring channels ({})", toSend, maxFee, d.ignore.channels.map(_.shortChannelId).mkString(",")) + val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg).copy(ignore = d.ignore.emptyChannels()) + retriedFailedChannels = true + stay using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), ignore = d.ignore.emptyChannels()) + } else { + val failure = LocalFailure(Nil, t) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure)).increment() + gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures :+ failure, d.pending.keySet)) + } + + case Event(pf: PaymentFailed, d: PaymentProgress) => + if (isFinalRecipientFailure(pf, d)) { + gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) + } else { + val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) + stay using d.copy(pending = d.pending - pf.id, ignore = ignore1, failures = d.failures ++ pf.failures) + } + + // The recipient released the preimage without receiving the full payment amount. + // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. + case Event(ps: PaymentSent, d: PaymentProgress) => + require(ps.parts.length == 1, "child payment must contain only one part") + // As soon as we get the preimage we can consider that the whole payment succeeded (we have a proof of payment). + gotoSucceededOrStop(PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id)) } when(PAYMENT_IN_PROGRESS) { - case Event(pf: PaymentFailed, d: PaymentProgress) => handleChildFailure(pf, d) match { - case Some(paymentAborted) => - goto(PAYMENT_ABORTED) using paymentAborted - case None => - // Get updated local channels (will take into account the child payments that are in-flight). - relayer ! GetOutgoingChannels() - val failedPayment = d.pending(pf.id) - val shouldBlacklist = shouldBlacklistChannel(pf) - if (shouldBlacklist) { - log.debug(s"ignoring channel ${getFirstHopShortChannelId(failedPayment)} to ${failedPayment.routePrefix.head.nextNodeId}") - } - val ignoreChannels = if (shouldBlacklist) d.ignoreChannels + getFirstHopShortChannelId(failedPayment) else d.ignoreChannels - val remainingAttempts = if (shouldBlacklist && Random.nextDouble() * math.log(d.channelsCount) > 2.0) { - // When we have a lot of channels, many of them may end up being a bad route prefix for the destination we're - // trying to reach. This is a cheap error that is detected quickly (RouteNotFound), so we don't want to count - // it in our payment attempts to avoid failing too fast. - // However we don't want to test all of our channels either which would be expensive, so we only probabilistically - // count the failure in our payment attempts. - // With the log-scale used, here are the probabilities and the corresponding number of retries: - // * 10 channels -> refund 13% of failures -> with 5 initial retries we will actually try 5/(1-0.13) = ~6 times - // * 20 channels -> refund 32% of failures -> with 5 initial retries we will actually try 5/(1-0.32) = ~7 times - // * 50 channels -> refund 50% of failures -> with 5 initial retries we will actually try 5/(1-0.50) = ~10 times - // * 100 channels -> refund 56% of failures -> with 5 initial retries we will actually try 5/(1-0.56) = ~11 times - // * 1000 channels -> refund 70% of failures -> with 5 initial retries we will actually try 5/(1-0.70) = ~17 times - // NB: this hack won't be necessary once multi-part is directly handled by the router. - d.remainingAttempts + 1 - } else { - d.remainingAttempts - } - goto(RETRY_WITH_UPDATED_BALANCES) using d.copy(toSend = d.toSend + failedPayment.finalPayload.amount, pending = d.pending - pf.id, failures = d.failures ++ pf.failures, ignoreChannels = ignoreChannels, remainingAttempts = remainingAttempts) - } - - case Event(ps: PaymentSent, d: PaymentProgress) => - require(ps.parts.length == 1, "child payment must contain only one part") - // As soon as we get the preimage we can consider that the whole payment succeeded (we have a proof of payment). - goto(PAYMENT_SUCCEEDED) using PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id) - } - - when(RETRY_WITH_UPDATED_BALANCES) { - case Event(OutgoingChannels(channels), d: PaymentProgress) => - log.debug("trying to send {} with local channels: {}", d.toSend, channels.map(_.toUsableBalance).mkString(",")) - val filteredChannels = channels.filter(c => !d.ignoreChannels.contains(c.channelUpdate.shortChannelId)) - val (remaining, payments) = splitPayment(nodeParams, d.toSend, filteredChannels, d.networkStats, d.request, randomize = true) // we randomize channel selection when we retry - if (remaining > 0.msat) { - log.warning(s"cannot send ${d.toSend} with our current balance") - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(Nil, PaymentError.BalanceTooLow))) - goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, d.failures :+ LocalFailure(Nil, PaymentError.BalanceTooLow), d.pending.keySet) + case Event(pf: PaymentFailed, d: PaymentProgress) => + if (isFinalRecipientFailure(pf, d)) { + gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) + } else if (d.remainingAttempts == 0) { + val failure = LocalFailure(Nil, PaymentError.RetryExhausted) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure)).increment() + gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures :+ failure, d.pending.keySet - pf.id)) } else { - val pending = setFees(d.request.routeParams, payments, payments.size + d.pending.size) - pending.foreach { case (childId, payment) => spawnChildPaymentFsm(childId) ! payment } - goto(PAYMENT_IN_PROGRESS) using d.copy(toSend = 0 msat, remainingAttempts = d.remainingAttempts - 1, pending = d.pending ++ pending, channelsCount = channels.length) + val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) + val stillPending = d.pending - pf.id + val (toSend, maxFee) = remainingToSend(nodeParams, d.request, stillPending.values) + log.debug("child payment failed, retry sending {} with maximum fee {}", toSend, maxFee) + val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry + val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures) + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d1, cfg) + goto(WAIT_FOR_ROUTES) using d1 } - case Event(pf: PaymentFailed, d: PaymentProgress) => handleChildFailure(pf, d) match { - case Some(paymentAborted) => - goto(PAYMENT_ABORTED) using paymentAborted - case None => - val failedPayment = d.pending(pf.id) - val ignoreChannels = if (shouldBlacklistChannel(pf)) d.ignoreChannels + getFirstHopShortChannelId(failedPayment) else d.ignoreChannels - stay using d.copy(toSend = d.toSend + failedPayment.finalPayload.amount, pending = d.pending - pf.id, failures = d.failures ++ pf.failures, ignoreChannels = ignoreChannels) - } - case Event(ps: PaymentSent, d: PaymentProgress) => require(ps.parts.length == 1, "child payment must contain only one part") // As soon as we get the preimage we can consider that the whole payment succeeded (we have a proof of payment). - goto(PAYMENT_SUCCEEDED) using PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id) + Metrics.PaymentAttempt.withTag(Tags.MultiPart, value = true).record(d.request.maxAttempts - d.remainingAttempts) + gotoSucceededOrStop(PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id)) } when(PAYMENT_ABORTED) { @@ -192,7 +173,10 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, case Event(ps: PaymentSent, d: PaymentAborted) => require(ps.parts.length == 1, "child payment must contain only one part") log.warning(s"payment recipient fulfilled incomplete multi-part payment (id=${ps.parts.head.id})") - goto(PAYMENT_SUCCEEDED) using PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending - ps.parts.head.id) + gotoSucceededOrStop(PaymentSucceeded(d.sender, d.request, ps.paymentPreimage, ps.parts, d.pending - ps.parts.head.id)) + + case Event(_: RouteResponse, _) => stay + case Event(_: Status.Failure, _) => stay } when(PAYMENT_SUCCEEDED) { @@ -216,20 +200,9 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, } else { stay using d.copy(pending = pending) } - } - onTransition { - case _ -> PAYMENT_ABORTED => nextStateData match { - case d: PaymentAborted if d.pending.isEmpty => - myStop(d.sender, Left(PaymentFailed(id, paymentHash, d.failures))) - case _ => - } - - case _ -> PAYMENT_SUCCEEDED => nextStateData match { - case d: PaymentSucceeded if d.pending.isEmpty => - myStop(d.sender, Right(cfg.createPaymentSent(d.preimage, d.parts))) - case _ => - } + case Event(_: RouteResponse, _) => stay + case Event(_: Status.Failure, _) => stay } def spawnChildPaymentFsm(childId: UUID): ActorRef = { @@ -241,6 +214,21 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, context.actorOf(PaymentLifecycle.props(nodeParams, childCfg, router, register)) } + private def gotoAbortedOrStop(d: PaymentAborted): State = { + if (d.pending.isEmpty) { + myStop(d.sender, Left(PaymentFailed(id, paymentHash, d.failures))) + } else + goto(PAYMENT_ABORTED) using d + } + + private def gotoSucceededOrStop(d: PaymentSucceeded): State = { + d.sender ! PreimageReceived(paymentHash, d.preimage) + if (d.pending.isEmpty) { + myStop(d.sender, Right(cfg.createPaymentSent(d.preimage, d.parts))) + } else + goto(PAYMENT_SUCCEEDED) using d + } + def myStop(origin: ActorRef, event: Either[PaymentFailed, PaymentSent]): State = { event match { case Left(paymentFailed) => @@ -255,6 +243,9 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, .withTag(Tags.MultiPart, Tags.MultiPartType.Parent) .withTag(Tags.Success, value = event.isRight) .record(System.currentTimeMillis - start, TimeUnit.MILLISECONDS) + if (retriedFailedChannels) { + Metrics.RetryFailedChannelsResult.withTag(Tags.Success, event.isRight).increment() + } span.finish() stop(FSM.Normal) } @@ -264,21 +255,13 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, if (cfg.publishEvent) context.system.eventStream.publish(e) } - def handleChildFailure(pf: PaymentFailed, d: PaymentProgress): Option[PaymentAborted] = { - val isFromFinalRecipient = pf.failures.collectFirst { case f: RemoteFailure if f.e.originNode == d.request.targetNodeId => true }.isDefined - if (isFromFinalRecipient) { - Some(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) - } else if (d.remainingAttempts == 0) { - val failure = LocalFailure(Nil, PaymentError.RetryExhausted) - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure)) - Some(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures :+ failure, d.pending.keySet - pf.id)) - } else { - None - } - } - override def mdc(currentMessage: Any): MDC = { - Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT), parentPaymentId_opt = Some(cfg.parentId), paymentId_opt = Some(id), paymentHash_opt = Some(paymentHash)) + Logs.mdc( + category_opt = Some(Logs.LogCategory.PAYMENT), + parentPaymentId_opt = Some(cfg.parentId), + paymentId_opt = Some(id), + paymentHash_opt = Some(paymentHash), + remoteNodeId_opt = Some(cfg.recipientNodeId)) } initialize() @@ -289,7 +272,7 @@ object MultiPartPaymentLifecycle { val parentPaymentIdKey = Context.key[UUID]("parentPaymentId", UUID.fromString("00000000-0000-0000-0000-000000000000")) - def props(nodeParams: NodeParams, cfg: SendPaymentConfig, relayer: ActorRef, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, relayer, router, register)) + def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, register)) /** * Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding @@ -316,53 +299,52 @@ object MultiPartPaymentLifecycle { additionalTlvs: Seq[OnionTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil) { require(totalAmount > 0.msat, s"total amount must be > 0") + + def getRouteParams(nodeParams: NodeParams, randomize: Boolean): RouteParams = + routeParams.getOrElse(RouteCalculation.getDefaultRouteParams(nodeParams.routerConf)).copy(randomize = randomize) } + /** + * The payment FSM will wait for all child payments to settle before emitting payment events, but the preimage will be + * shared as soon as it's received to unblock other actors that may need it. + */ + case class PreimageReceived(paymentHash: ByteVector32, paymentPreimage: ByteVector32) + // @formatter:off sealed trait State - case object WAIT_FOR_PAYMENT_REQUEST extends State - case object WAIT_FOR_NETWORK_STATS extends State - case object WAIT_FOR_CHANNEL_BALANCES extends State + case object WAIT_FOR_PAYMENT_REQUEST extends State + case object WAIT_FOR_ROUTES extends State case object PAYMENT_IN_PROGRESS extends State - case object RETRY_WITH_UPDATED_BALANCES extends State case object PAYMENT_ABORTED extends State case object PAYMENT_SUCCEEDED extends State + // @formatter:on sealed trait Data + /** - * During initialization, we wait for a multi-part payment request containing the total amount to send. + * During initialization, we wait for a multi-part payment request containing the total amount to send and the maximum + * fee budget. */ case object WaitingForRequest extends Data + /** - * During initialization, we collect network statistics to help us decide how to best split a big payment. - * - * @param sender the sender of the payment request. - * @param request payment request containing the total amount to send. - */ - case class WaitingForNetworkStats(sender: ActorRef, request: SendMultiPartPayment) extends Data - /** - * During initialization, we request our local channels balances. - * - * @param sender the sender of the payment request. - * @param request payment request containing the total amount to send. - * @param networkStats network statistics help us decide how to best split a big payment. - */ - case class WaitingForChannelBalances(sender: ActorRef, request: SendMultiPartPayment, networkStats: Option[NetworkStats]) extends Data - /** - * While the payment is in progress, we listen to child payment failures. When we receive such failures, we request - * our up-to-date local channels balances and retry the failed child payments with a potentially different route. + * While the payment is in progress, we listen to child payment failures. When we receive such failures, we retry the + * failed amount with different routes. * * @param sender the sender of the payment request. * @param request payment request containing the total amount to send. - * @param networkStats network statistics help us decide how to best split a big payment. - * @param channelsCount number of local channels. - * @param toSend remaining amount that should be split and sent. * @param remainingAttempts remaining attempts (after child payments fail). * @param pending pending child payments (payment sent, we are waiting for a fulfill or a failure). - * @param ignoreChannels channels that should be ignored (previously returned a permanent error). + * @param ignore channels and nodes that should be ignored (previously returned a permanent error). * @param failures previous child payment failures. */ - case class PaymentProgress(sender: ActorRef, request: SendMultiPartPayment, networkStats: Option[NetworkStats], channelsCount: Int, toSend: MilliSatoshi, remainingAttempts: Int, pending: Map[UUID, SendPayment], ignoreChannels: Set[ShortChannelId], failures: Seq[PaymentFailure]) extends Data + case class PaymentProgress(sender: ActorRef, + request: SendMultiPartPayment, + remainingAttempts: Int, + pending: Map[UUID, Route], + ignore: Ignore, + failures: Seq[PaymentFailure]) extends Data + /** * When we exhaust our retry attempts without success, we abort the payment. * Once we're in that state, we wait for all the pending child payments to settle. @@ -373,6 +355,7 @@ object MultiPartPaymentLifecycle { * @param pending pending child payments (we are waiting for them to be failed downstream). */ case class PaymentAborted(sender: ActorRef, request: SendMultiPartPayment, failures: Seq[PaymentFailure], pending: Set[UUID]) extends Data + /** * Once we receive a first fulfill for a child payment, we can consider that the whole payment succeeded (because we * received the payment preimage that we can use as a proof of payment). @@ -385,160 +368,34 @@ object MultiPartPaymentLifecycle { * @param pending pending child payments (we are waiting for them to be fulfilled downstream). */ case class PaymentSucceeded(sender: ActorRef, request: SendMultiPartPayment, preimage: ByteVector32, parts: Seq[PartialPayment], pending: Set[UUID]) extends Data - // @formatter:on - /** If the payment failed immediately with a RouteNotFound, the channel we selected should be ignored in retries. */ - private def shouldBlacklistChannel(pf: PaymentFailed): Boolean = pf.failures match { - case LocalFailure(_, RouteNotFound) :: Nil => true - case _ => false + private def createRouteRequest(nodeParams: NodeParams, toSend: MilliSatoshi, maxFee: MilliSatoshi, routeParams: RouteParams, d: PaymentProgress, cfg: SendPaymentConfig): RouteRequest = + RouteRequest( + nodeParams.nodeId, + d.request.targetNodeId, + toSend, + maxFee, + d.request.assistedRoutes, + d.ignore, + Some(routeParams), + allowMultiPart = true, + d.pending.values.toSeq, + Some(cfg.paymentContext)) + + private def createChildPayment(route: Route, request: SendMultiPartPayment): SendPaymentToRoute = { + val finalPayload = Onion.createMultiPartPayload(route.amount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs) + SendPaymentToRoute(Right(route), finalPayload) } - def getFirstHopShortChannelId(payment: SendPayment): ShortChannelId = { - require(payment.routePrefix.nonEmpty, "multi-part payment must have a route prefix") - payment.routePrefix.head.lastUpdate.shortChannelId - } + /** When we receive an error from the final recipient, we should fail the whole payment, it's useless to retry. */ + private def isFinalRecipientFailure(pf: PaymentFailed, d: PaymentProgress): Boolean = pf.failures.collectFirst { + case f: RemoteFailure if f.e.originNode == d.request.targetNodeId => true + }.getOrElse(false) - /** - * If fee limits are provided, we need to divide them between all child payments. Otherwise we could end up paying - * N * maxFee (where N is the number of child payments). - * Note that payment retries may mess up this calculation and make us pay a bit more than our fee limit. - * - * TODO: @t-bast: the router should expose a GetMultiRouteRequest API; this is where fee calculations will be more - * accurate and path-finding will be more efficient. - */ - private def setFees(routeParams: Option[RouteParams], payments: Seq[SendPayment], paymentsCount: Int): Map[UUID, SendPayment] = - payments.map(p => { - val payment = routeParams match { - case Some(routeParams) => p.copy(routeParams = Some(routeParams.copy(maxFeeBase = routeParams.maxFeeBase / paymentsCount))) - case None => p - } - (UUID.randomUUID(), payment) - }).toMap - - private def createChildPayment(nodeParams: NodeParams, request: SendMultiPartPayment, childAmount: MilliSatoshi, channel: OutgoingChannel): SendPayment = { - SendPayment( - request.targetNodeId, - Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs), - request.maxAttempts, - request.assistedRoutes, - request.routeParams, - ChannelHop(nodeParams.nodeId, channel.nextNodeId, channel.channelUpdate) :: Nil) - } - - /** Compute the maximum amount we should send in a single child payment. */ - private def computeThreshold(networkStats: Option[NetworkStats], localChannels: Seq[OutgoingChannel]): MilliSatoshi = { - import com.google.common.math.Quantiles.median - - import scala.collection.JavaConverters.asJavaCollectionConverter - // We use network statistics with a random factor to decide on the maximum amount for child payments. - // The current choice of parameters is completely arbitrary and could be made configurable. - // We could also learn from previous payment failures to dynamically tweak that value. - val maxAmount = networkStats.map(_.capacity.percentile75.toMilliSatoshi * ((75.0 + Random.nextInt(25)) / 100)) - // If network statistics aren't available, we'll use our local channels to choose a value. - maxAmount.getOrElse({ - val localBalanceMedian = median().compute(localChannels.map(b => java.lang.Long.valueOf(b.commitments.availableBalanceForSend.toLong)).asJavaCollection) - MilliSatoshi(localBalanceMedian.toLong) - }) - } - - /** - * Split a payment to a remote node inside a given channel. - * - * @param nodeParams node params. - * @param toSend total amount to send (may exceed the channel capacity if we have other channels available). - * @param request parent payment request. - * @param maxChildAmount maximum amount of each child payment inside that channel. - * @param maxFeeBase maximum base fee (for the future payment route). - * @param maxFeePct maximum proportional fee (for the future payment route). - * @param channel channel to use. - * @param channelCommitments channel commitments. - * @param channelPayments already-constructed child payments inside this channel. - * @return child payments to send through this channel. - */ - @tailrec - private def splitInsideChannel(nodeParams: NodeParams, - toSend: MilliSatoshi, - request: SendMultiPartPayment, - maxChildAmount: MilliSatoshi, - maxFeeBase: MilliSatoshi, - maxFeePct: Double, - channel: OutgoingChannel, - channelCommitments: Commitments, - channelPayments: Seq[SendPayment]): Seq[SendPayment] = { - // We can't use all the available balance because we need to take the fees for each child payment into account and - // we don't know the exact fee before-hand because we don't know the rest of the route yet (so we assume the worst - // case where the max fee is used). - val previousFees = channelPayments.map(p => maxFeeBase.max(p.finalPayload.amount * maxFeePct)) - val totalPreviousFee = previousFees.sum - val withFeeBase = channelCommitments.availableBalanceForSend - maxFeeBase - totalPreviousFee - val withFeePct = channelCommitments.availableBalanceForSend * (1 - maxFeePct) - totalPreviousFee - val childAmount = Seq(maxChildAmount, toSend - channelPayments.map(_.finalPayload.amount).sum, withFeeBase, withFeePct).min - if (childAmount <= 0.msat) { - channelPayments - } else if (previousFees.nonEmpty && childAmount < previousFees.max) { - // We avoid sending tiny HTLCs: that would be a waste of fees. - channelPayments - } else { - val childPayment = createChildPayment(nodeParams, request, childAmount, channel) - // Splitting into multiple HTLCs in the same channel will also increase the size of the CommitTx (and thus its - // fee), which decreases the available balance. - // We need to take that into account when trying to send multiple payments through the same channel, which is - // why we simulate adding the HTLC to the commitments. - val fakeOnion = OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(Sphinx.PaymentPacket.PayloadLength)(0), ByteVector32.Zeroes) - val add = UpdateAddHtlc(channelCommitments.channelId, channelCommitments.localNextHtlcId + channelPayments.size, childAmount, ByteVector32.Zeroes, CltvExpiry(0), fakeOnion) - val updatedCommitments = channelCommitments.addLocalProposal(add) - splitInsideChannel(nodeParams, toSend, request, maxChildAmount, maxFeeBase, maxFeePct, channel, updatedCommitments, childPayment +: channelPayments) - } - } - - /** - * Split a payment into many child payments. - * - * @param toSend amount to split. - * @param localChannels local channels balances. - * @param request payment request containing the total amount to send and routing hints and parameters. - * @param randomize randomize the channel selection. - * @return the child payments that should be then sent to PaymentLifecycle actors. - */ - def splitPayment(nodeParams: NodeParams, toSend: MilliSatoshi, localChannels: Seq[OutgoingChannel], networkStats: Option[NetworkStats], request: SendMultiPartPayment, randomize: Boolean): (MilliSatoshi, Seq[SendPayment]) = { - require(toSend > 0.msat, "amount to send must be greater than 0") - - val maxFeePct = request.routeParams.map(_.maxFeePct).getOrElse(nodeParams.routerConf.searchMaxFeePct) - val maxFeeBase = request.routeParams.map(_.maxFeeBase).getOrElse(nodeParams.routerConf.searchMaxFeeBase.toMilliSatoshi) - - @tailrec - def split(remaining: MilliSatoshi, payments: Seq[SendPayment], channels: Seq[OutgoingChannel], splitInsideChannel: (MilliSatoshi, OutgoingChannel) => Seq[SendPayment]): Seq[SendPayment] = channels match { - case Nil => payments - case _ if remaining == 0.msat => payments - case _ if remaining < 0.msat => throw new RuntimeException(s"payment splitting error: remaining amount must not be negative ($remaining): sending $toSend to ${request.targetNodeId} with local channels=${localChannels.map(_.toUsableBalance)}, current channels=${channels.map(_.toUsableBalance)}, network=${networkStats.map(_.capacity)}, fees=($maxFeeBase, $maxFeePct)") - case channel :: rest if channel.commitments.availableBalanceForSend == 0.msat => split(remaining, payments, rest, splitInsideChannel) - case channel :: rest => - val childPayments = splitInsideChannel(remaining, channel) - split(remaining - childPayments.map(_.finalPayload.amount).sum, payments ++ childPayments, rest, splitInsideChannel) - } - - // If we have direct channels to the target, we use them without splitting the payment inside each channel. - val channelsToTarget = localChannels.filter(p => p.nextNodeId == request.targetNodeId).sortBy(_.commitments.availableBalanceForSend) - val directPayments = split(toSend, Seq.empty, channelsToTarget, (remaining: MilliSatoshi, channel: OutgoingChannel) => { - // When using direct channels to the destination, it doesn't make sense to use retries so we set maxAttempts to 1. - createChildPayment(nodeParams, request.copy(maxAttempts = 1), remaining.min(channel.commitments.availableBalanceForSend), channel) :: Nil - }) - - // Otherwise we need to split the amount based on network statistics and pessimistic fees estimates. - // Note that this will be handled more gracefully once this logic is migrated inside the router. - val channels = if (randomize) { - Random.shuffle(localChannels.filter(p => p.nextNodeId != request.targetNodeId)) - } else { - localChannels.filter(p => p.nextNodeId != request.targetNodeId).sortBy(_.commitments.availableBalanceForSend) - } - val remotePayments = split(toSend - directPayments.map(_.finalPayload.amount).sum, Seq.empty, channels, (remaining: MilliSatoshi, channel: OutgoingChannel) => { - // We re-generate a split threshold for each channel to randomize the amounts. - val maxChildAmount = computeThreshold(networkStats, localChannels) - splitInsideChannel(nodeParams, remaining, request, maxChildAmount, maxFeeBase, maxFeePct, channel, channel.commitments, Nil) - }) - - val childPayments = directPayments ++ remotePayments - (toSend - childPayments.map(_.finalPayload.amount).sum, childPayments) + private def remainingToSend(nodeParams: NodeParams, request: SendMultiPartPayment, pending: Iterable[Route]): (MilliSatoshi, MilliSatoshi) = { + val sentAmount = pending.map(_.amount).sum + val sentFees = pending.map(_.fee).sum + (request.totalAmount - sentAmount, request.getRouteParams(nodeParams, randomize = false).getMaxFee(request.totalAmount) - sentFees) } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentError.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentError.scala index d660dfde7..e3ecc28be 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentError.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentError.scala @@ -41,8 +41,6 @@ object PaymentError { // @formatter:on // @formatter:off - /** Outbound capacity is too low. */ - case object BalanceTooLow extends PaymentError /** Payment attempts exhausted without success. */ case object RetryExhausted extends PaymentError // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index ab937c617..acf16f74a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -26,10 +26,11 @@ import fr.acinq.eclair.channel.{Channel, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentError._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} -import fr.acinq.eclair.router.Router.{Hop, NodeHop, Route, RouteParams} +import fr.acinq.eclair.router.RouteNotFound +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, NodeParams, randomBytes32} @@ -37,7 +38,7 @@ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatos /** * Created by PM on 29/08/2016. */ -class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorRef, register: ActorRef) extends Actor with ActorLogging { +class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends Actor with ActorLogging { import PaymentInitiator._ @@ -83,19 +84,33 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR case pf: PaymentFailed => pending.get(pf.id).foreach(pp => { val decryptedFailures = pf.failures.collect { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(_, f)) => f } - val canRetry = decryptedFailures.contains(TrampolineFeeInsufficient) || decryptedFailures.contains(TrampolineExpiryTooSoon) - pp.remainingAttempts match { - case (trampolineFees, trampolineExpiryDelta) :: remainingAttempts if canRetry => - log.info(s"retrying trampoline payment with trampoline fees=$trampolineFees and expiry delta=$trampolineExpiryDelta") - sendTrampolinePayment(pf.id, pp.r, trampolineFees, trampolineExpiryDelta) - context become main(pending + (pf.id -> pp.copy(remainingAttempts = remainingAttempts))) - case _ => - pp.sender ! pf - context.system.eventStream.publish(pf) - context become main(pending - pf.id) + val shouldRetry = decryptedFailures.contains(TrampolineFeeInsufficient) || decryptedFailures.contains(TrampolineExpiryTooSoon) + if (shouldRetry) { + pp.remainingAttempts match { + case (trampolineFees, trampolineExpiryDelta) :: remaining => + log.info(s"retrying trampoline payment with trampoline fees=$trampolineFees and expiry delta=$trampolineExpiryDelta") + sendTrampolinePayment(pf.id, pp.r, trampolineFees, trampolineExpiryDelta) + context become main(pending + (pf.id -> pp.copy(remainingAttempts = remaining))) + case Nil => + log.info("trampoline node couldn't find a route after all retries") + val trampolineRoute = Seq( + NodeHop(nodeParams.nodeId, pp.r.trampolineNodeId, nodeParams.expiryDeltaBlocks, 0 msat), + NodeHop(pp.r.trampolineNodeId, pp.r.recipientNodeId, pp.r.trampolineAttempts.last._2, pp.r.trampolineAttempts.last._1) + ) + val localFailure = pf.copy(failures = Seq(LocalFailure(trampolineRoute, RouteNotFound))) + pp.sender ! localFailure + context.system.eventStream.publish(localFailure) + context become main(pending - pf.id) + } + } else { + pp.sender ! pf + context.system.eventStream.publish(pf) + context become main(pending - pf.id) } }) + case _: PreimageReceived => // we received the preimage, but we wait for the PaymentSent event that will contain more data + case ps: PaymentSent => pending.get(ps.id).foreach(pp => { pp.sender ! ps context.system.eventStream.publish(ps) @@ -116,12 +131,12 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR val trampolineSecret = r.trampolineSecret.getOrElse(randomBytes32) sender ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(trampolineSecret)) val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(SendTrampolinePaymentRequest(r.recipientAmount, r.paymentRequest, trampoline, Seq((r.trampolineFees, r.trampolineExpiryDelta)), r.finalExpiryDelta), r.trampolineFees, r.trampolineExpiryDelta) - payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionTlv.TrampolineOnion(trampolineOnion))), r.paymentRequest.routingInfo) + payFsm forward SendPaymentToRoute(Left(r.route), Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionTlv.TrampolineOnion(trampolineOnion))), r.paymentRequest.routingInfo) case Nil => sender ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) r.paymentRequest.paymentSecret match { - case Some(paymentSecret) => payFsm forward SendPaymentToRoute(r.route, Onion.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, paymentSecret), r.paymentRequest.routingInfo) - case None => payFsm forward SendPaymentToRoute(r.route, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.paymentRequest.routingInfo) + case Some(paymentSecret) => payFsm forward SendPaymentToRoute(Left(r.route), Onion.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, paymentSecret), r.paymentRequest.routingInfo) + case None => payFsm forward SendPaymentToRoute(Left(r.route), FinalLegacyPayload(r.recipientAmount, finalExpiry), r.paymentRequest.routingInfo) } case _ => sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, TrampolineMultiNodeNotSupported) :: Nil) @@ -130,7 +145,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR def spawnPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(PaymentLifecycle.props(nodeParams, paymentCfg, router, register)) - def spawnMultiPartPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, paymentCfg, relayer, router, register)) + def spawnMultiPartPaymentFsm(paymentCfg: SendPaymentConfig): ActorRef = context.actorOf(MultiPartPaymentLifecycle.props(nodeParams, paymentCfg, router, register)) private def buildTrampolinePayment(r: SendTrampolinePaymentRequest, trampolineFees: MilliSatoshi, trampolineExpiryDelta: CltvExpiryDelta): (MilliSatoshi, CltvExpiry, OnionRoutingPacket) = { val trampolineRoute = Seq( @@ -163,7 +178,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR object PaymentInitiator { - def props(nodeParams: NodeParams, router: ActorRef, relayer: ActorRef, register: ActorRef) = Props(new PaymentInitiator(nodeParams, router, relayer, register)) + def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef) = Props(new PaymentInitiator(nodeParams, router, register)) case class PendingPayment(sender: ActorRef, remainingAttempts: Seq[(MilliSatoshi, CltvExpiryDelta)], r: SendTrampolinePaymentRequest) @@ -316,6 +331,8 @@ object PaymentInitiator { def fullRoute(route: Route): Seq[Hop] = route.hops ++ additionalHops def createPaymentSent(preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipientAmount, recipientNodeId, parts) + + def paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index fd831c527..565e96f22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -75,44 +75,39 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A span.tag(Tags.Amount, c.finalPayload.amount.toLong) span.tag(Tags.TotalAmount, c.finalPayload.totalAmount.toLong) span.tag(Tags.Expiry, c.finalPayload.expiry.toLong) - log.debug("sending {} to route {}", c.finalPayload.amount, c.hops.mkString("->")) - val send = SendPayment(c.hops.last, c.finalPayload, maxAttempts = 1) - router ! FinalizeRoute(c.finalPayload.amount, c.hops, c.assistedRoutes) + log.debug("sending {} to route {}", c.finalPayload.amount, c.printRoute()) + val send = SendPayment(c.targetNodeId, c.finalPayload, maxAttempts = 1, assistedRoutes = c.assistedRoutes) + c.route.fold( + hops => router ! FinalizeRoute(c.finalPayload.amount, hops, c.assistedRoutes, paymentContext = Some(cfg.paymentContext)), + route => self ! RouteResponse(route :: Nil) + ) if (cfg.storeInDb) { paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, System.currentTimeMillis, cfg.paymentRequest, OutgoingPaymentStatus.Pending)) } - goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, Nil, Set.empty, Set.empty) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, Nil, Ignore.empty) case Event(c: SendPayment, WaitingForRequest) => span.tag(Tags.TargetNodeId, c.targetNodeId.toString()) span.tag(Tags.Amount, c.finalPayload.amount.toLong) span.tag(Tags.TotalAmount, c.finalPayload.totalAmount.toLong) span.tag(Tags.Expiry, c.finalPayload.expiry.toLong) - log.debug("sending {} to {}{}", c.finalPayload.amount, c.targetNodeId, c.routePrefix.mkString(" with route prefix ", "->", "")) - // We don't want the router to try cycling back to nodes that are at the beginning of the route. - val ignoredNodes = c.routePrefix.map(_.nodeId).toSet - if (c.routePrefix.lastOption.exists(_.nextNodeId == c.targetNodeId)) { - // If the sender already provided a route to the target, no need to involve the router. - self ! RouteResponse(Seq(Route(c.finalPayload.amount, Nil, allowEmpty = true))) - } else { - router ! RouteRequest(c.getRouteRequestStart(nodeParams), c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, routeParams = c.routeParams, ignoreNodes = ignoredNodes) - } + log.debug("sending {} to {}", c.finalPayload.amount, c.targetNodeId) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, routeParams = c.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, System.currentTimeMillis, cfg.paymentRequest, OutgoingPaymentStatus.Pending)) } - goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, Nil, ignoredNodes, Set.empty) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, Nil, Ignore.empty) } when(WAITING_FOR_ROUTE) { - case Event(RouteResponse(routes), WaitingForRoute(s, c, failures, ignoreNodes, ignoreChannels)) => - val hops = c.routePrefix ++ routes.head.hops - log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}") - val firstHop = hops.head - val (cmd, sharedSecrets) = OutgoingPacket.buildCommand(cfg.upstream, paymentHash, hops, c.finalPayload) - register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd) - goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, Route(c.finalPayload.amount, hops)) + case Event(RouteResponse(route +: _), WaitingForRoute(s, c, failures, ignore)) => + log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") + val (cmd, sharedSecrets) = OutgoingPacket.buildCommand(cfg.upstream, paymentHash, route.hops, c.finalPayload) + register ! Register.ForwardShortId(route.hops.head.lastUpdate.shortChannelId, cmd) + goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignore, route) - case Event(Status.Failure(t), WaitingForRoute(s, _, failures, _, _)) => + case Event(Status.Failure(t), WaitingForRoute(s, _, failures, _)) => + log.warning("router error: {}", t.getMessage) Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(Nil, t))).increment() onFailure(s, PaymentFailed(id, paymentHash, failures :+ LocalFailure(Nil, t))) myStop() @@ -121,7 +116,8 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A when(WAITING_FOR_PAYMENT_COMPLETE) { case Event(ChannelCommandResponse.Ok, _) => stay - case Event(fulfill: Relayer.ForwardFulfill, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => + case Event(fulfill: Relayer.ForwardFulfill, WaitingForComplete(s, c, cmd, failures, _, _, route)) => + Metrics.PaymentAttempt.withTag(Tags.MultiPart, value = false).record(failures.size + 1) val p = PartialPayment(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, fulfill.htlc.channelId, Some(cfg.fullRoute(route))) onSuccess(s, cfg.createPaymentSent(fulfill.paymentPreimage, p :: Nil)) myStop() @@ -134,7 +130,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } stay - case Event(fail: UpdateFailHtlc, data@WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, route)) => + case Event(fail: UpdateFailHtlc, data@WaitingForComplete(s, c, _, failures, sharedSecrets, ignore, route)) => (Sphinx.FailurePacket.decrypt(fail.reason, sharedSecrets) match { case success@Success(e) => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(Nil, e))).increment() @@ -153,6 +149,10 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A val failure = res match { case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId (failure=$failureMessage)") + failureMessage match { + case failureMessage: Update => handleUpdate(nodeId, failureMessage, data) + case _ => + } RemoteFailure(cfg.fullRoute(route), e) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}") @@ -162,7 +162,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A onFailure(s, PaymentFailed(id, paymentHash, failures :+ failure)) myStop() case Failure(t) => - log.warning(s"cannot parse returned error: ${t.getMessage}, route=${route.hops.map(_.nextNodeId)}") + log.warning(s"cannot parse returned error: ${t.getMessage}, route=${route.printNodes()}") val failure = UnreadableRemoteFailure(cfg.fullRoute(route)) retry(failure, data) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => @@ -171,48 +171,18 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A retry(failure, data) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") - val ignoreNodes1 = if (Announcements.checkSig(failureMessage.update, nodeId)) { - route.getChannelUpdateForNode(nodeId) match { - case Some(u) if u.shortChannelId != failureMessage.update.shortChannelId => - // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine - log.info(s"received an update for a different channel than the one we asked: requested=${u.shortChannelId} actual=${failureMessage.update.shortChannelId} update=${failureMessage.update}") - case Some(u) if Announcements.areSame(u, failureMessage.update) => - // node returned the exact same update we used, this can happen e.g. if the channel is imbalanced - // in that case, let's temporarily exclude the channel from future routes, giving it time to recover - log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") - val nextNodeId = route.hops.find(_.nodeId == nodeId).get.nextNodeId - router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, failures) => - // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it - log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failureMessage.update}") - val nextNodeId = route.hops.find(_.nodeId == nodeId).get.nextNodeId - router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) => - log.info(s"got a new update for shortChannelId=${u.shortChannelId}: old=$u new=${failureMessage.update}") - case None => - log.error(s"couldn't find a channel update for node=$nodeId, this should never happen") - } - // in any case, we forward the update to the router - router ! failureMessage.update - // we also update assisted routes, because they take precedence over the router's routing table - val assistedRoutes1 = c.assistedRoutes.map(_.map { - case extraHop: ExtraHop if extraHop.shortChannelId == failureMessage.update.shortChannelId => extraHop.copy( - cltvExpiryDelta = failureMessage.update.cltvExpiryDelta, - feeBase = failureMessage.update.feeBaseMsat, - feeProportionalMillionths = failureMessage.update.feeProportionalMillionths - ) - case extraHop => extraHop - }) + val ignore1 = if (Announcements.checkSig(failureMessage.update, nodeId)) { + val assistedRoutes1 = handleUpdate(nodeId, failureMessage, data) // let's try again, router will have updated its state - router ! RouteRequest(c.getRouteRequestStart(nodeParams), c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), assistedRoutes1, ignoreNodes, ignoreChannels, c.routeParams) - ignoreNodes + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), assistedRoutes1, ignore, c.routeParams, paymentContext = Some(cfg.paymentContext)) + ignore } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(c.getRouteRequestStart(nodeParams), c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) - ignoreNodes + nodeId + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, ignore + nodeId, c.routeParams, paymentContext = Some(cfg.paymentContext)) + ignore + nodeId } - goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(cfg.fullRoute(route), e), ignoreNodes1, ignoreChannels) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(cfg.fullRoute(route), e), ignore1) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") val failure = RemoteFailure(cfg.fullRoute(route), e) @@ -228,10 +198,9 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A self ! Status.Failure(new RuntimeException("first hop returned an UpdateFailMalformedHtlc message")) stay - case Event(Status.Failure(t), data@WaitingForComplete(s, c, _, failures, _, _, _, hops)) => + case Event(Status.Failure(t), data@WaitingForComplete(s, c, _, failures, _, _, hops)) => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(cfg.fullRoute(hops), t))).increment() val isFatal = failures.size + 1 >= c.maxAttempts || // retries exhausted - c.routePrefix.nonEmpty || // first hop was selected by the sender and failed, it doesn't make sense to retry t.isInstanceOf[HtlcsTimedoutDownstream] // htlc timed out so retrying won't help, we need to re-compute cltvs if (isFatal) { onFailure(s, PaymentFailed(id, paymentHash, failures :+ LocalFailure(cfg.fullRoute(hops), t))) @@ -266,9 +235,57 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } private def retry(failure: PaymentFailure, data: WaitingForComplete): FSM.State[PaymentLifecycle.State, PaymentLifecycle.Data] = { - val (ignoreNodes1, ignoreChannels1) = PaymentFailure.updateIgnored(failure, data.ignoreNodes, data.ignoreChannels) - router ! RouteRequest(data.c.getRouteRequestStart(nodeParams), data.c.targetNodeId, data.c.finalPayload.amount, data.c.getMaxFee(nodeParams), data.c.assistedRoutes, ignoreNodes1, ignoreChannels1, data.c.routeParams) - goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.sender, data.c, data.failures :+ failure, ignoreNodes1, ignoreChannels1) + val ignore1 = PaymentFailure.updateIgnored(failure, data.ignore) + router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.finalPayload.amount, data.c.getMaxFee(nodeParams), data.c.assistedRoutes, ignore1, data.c.routeParams, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.sender, data.c, data.failures :+ failure, ignore1) + } + + /** + * Apply the channel update to our routing table. + * + * @return updated routing hints if applicable. + */ + private def handleUpdate(nodeId: PublicKey, failure: Update, data: WaitingForComplete): Seq[Seq[ExtraHop]] = { + data.route.getChannelUpdateForNode(nodeId) match { + case Some(u) if u.shortChannelId != failure.update.shortChannelId => + // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine + log.info(s"received an update for a different channel than the one we asked: requested=${u.shortChannelId} actual=${failure.update.shortChannelId} update=${failure.update}") + case Some(u) if Announcements.areSame(u, failure.update) => + // node returned the exact same update we used, this can happen e.g. if the channel is imbalanced + // in that case, let's temporarily exclude the channel from future routes, giving it time to recover + log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") + val nextNodeId = data.route.hops.find(_.nodeId == nodeId).get.nextNodeId + router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) + case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, data.failures) => + // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it + log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failure.update}") + val nextNodeId = data.route.hops.find(_.nodeId == nodeId).get.nextNodeId + router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) + case Some(u) => + log.info(s"got a new update for shortChannelId=${u.shortChannelId}: old=$u new=${failure.update}") + case None => + log.error(s"couldn't find a channel update for node=$nodeId, this should never happen") + } + // in any case, we forward the update to the router: if the channel is disabled, the router will remove it from its routing table + router ! failure.update + // we return updated assisted routes: they take precedence over the router's routing table + if (Announcements.isEnabled(failure.update.channelFlags)) { + data.c.assistedRoutes.map(_.map { + case extraHop: ExtraHop if extraHop.shortChannelId == failure.update.shortChannelId => extraHop.copy( + cltvExpiryDelta = failure.update.cltvExpiryDelta, + feeBase = failure.update.feeBaseMsat, + feeProportionalMillionths = failure.update.feeProportionalMillionths + ) + case extraHop => extraHop + }) + } else { + // if the channel is disabled, we temporarily exclude it: this is necessary because the routing hint doesn't contain + // channel flags to indicate that it's disabled + data.c.assistedRoutes.flatMap(r => RouteCalculation.toChannelDescs(r, data.c.targetNodeId)) + .find(_.shortChannelId == failure.update.shortChannelId) + .foreach(desc => router ! ExcludeChannel(desc)) // we want the exclusion to be router-wide so that sister payments in the case of MPP are aware the channel is faulty + data.c.assistedRoutes + } } private def myStop(): State = { @@ -312,12 +329,15 @@ object PaymentLifecycle { /** * Send a payment to a pre-defined route without running the path-finding algorithm. * - * @param hops payment route to use. + * @param route payment route to use. * @param finalPayload onion payload for the target node. */ - case class SendPaymentToRoute(hops: Seq[PublicKey], finalPayload: FinalPayload, assistedRoutes: Seq[Seq[ExtraHop]] = Nil) { - require(hops.nonEmpty, s"payment route must not be empty") - val targetNodeId = hops.last + case class SendPaymentToRoute(route: Either[Seq[PublicKey], Route], finalPayload: FinalPayload, assistedRoutes: Seq[Seq[ExtraHop]] = Nil) { + require(route.fold(_.nonEmpty, _.hops.nonEmpty), "payment route must not be empty") + + val targetNodeId = route.fold(_.last, _.hops.last.nextNodeId) + + def printRoute(): String = route.fold(nodes => nodes, _.hops.map(_.nextNodeId)).mkString("->") } /** @@ -329,32 +349,24 @@ object PaymentLifecycle { * @param maxAttempts maximum number of retries. * @param assistedRoutes routing hints (usually from a Bolt 11 invoice). * @param routeParams parameters to fine-tune the routing algorithm. - * @param routePrefix when provided, the payment route will start with these hops. Path-finding will run only to - * find how to route from the last node of the route prefix to the target node. */ case class SendPayment(targetNodeId: PublicKey, finalPayload: FinalPayload, maxAttempts: Int, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - routeParams: Option[RouteParams] = None, - routePrefix: Seq[ChannelHop] = Nil) { + routeParams: Option[RouteParams] = None) { require(finalPayload.amount > 0.msat, s"amount must be > 0") def getMaxFee(nodeParams: NodeParams): MilliSatoshi = routeParams.getOrElse(RouteCalculation.getDefaultRouteParams(nodeParams.routerConf)).getMaxFee(finalPayload.amount) - /** Returns the node from which the path-finding algorithm should start. */ - def getRouteRequestStart(nodeParams: NodeParams): PublicKey = routePrefix match { - case Nil => nodeParams.nodeId - case prefix => prefix.last.nextNodeId - } } // @formatter:off sealed trait Data case object WaitingForRequest extends Data - case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]) extends Data - case class WaitingForComplete(sender: ActorRef, c: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc], route: Route) extends Data + case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data + case class WaitingForComplete(sender: ActorRef, c: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data sealed trait State case object WAITING_FOR_REQUEST extends State diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala index 73ff63ba3..b6fc2513c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Monitoring.scala @@ -28,6 +28,7 @@ object Monitoring { val FindRouteDuration = Kamon.timer("router.find-route.duration", "Path-finding duration") val FindRouteErrors = Kamon.counter("router.find-route.errors", "Path-finding errors") val RouteLength = Kamon.histogram("router.find-route.length", "Path-finding result length") + val RouteResults = Kamon.histogram("router.find-route.results", "Path-finding number of routes found") object QueryChannelRange { val Blocks = Kamon.histogram("router.gossip.query-channel-range.blocks", "Number of blocks requested in query-channel-range") @@ -71,6 +72,7 @@ object Monitoring { val Announced = "announced" val Direction = "direction" val Error = "error" + val MultiPart = "multiPart" val NumberOfRoutes = "numRoutes" object Directions { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index ce3440bbd..bf1be1308 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -17,9 +17,10 @@ package fr.acinq.eclair.router import akka.actor.{ActorContext, ActorRef, Status} -import akka.event.LoggingAdapter +import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} +import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} @@ -28,6 +29,7 @@ import fr.acinq.eclair.router.Monitoring.{Metrics, Tags} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.ChannelUpdate import fr.acinq.eclair.{ShortChannelId, _} +import kamon.tag.TagSet import scala.annotation.tailrec import scala.collection.mutable @@ -36,54 +38,75 @@ import scala.util.{Failure, Random, Success, Try} object RouteCalculation { - def finalizeRoute(d: Data, fr: FinalizeRoute)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { - implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + def finalizeRoute(d: Data, fr: FinalizeRoute)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + Logs.withMdc(log)(Logs.mdc( + category_opt = Some(LogCategory.PAYMENT), + parentPaymentId_opt = fr.paymentContext.map(_.parentId), + paymentId_opt = fr.paymentContext.map(_.id), + paymentHash_opt = fr.paymentContext.map(_.paymentHash))) { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val assistedChannels: Map[ShortChannelId, AssistedChannel] = fr.assistedRoutes.flatMap(toAssistedChannels(_, fr.hops.last, fr.amount)).toMap - val extraEdges = assistedChannels.values.map(ac => - GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) - ).toSet - val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } - // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - fr.hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { - case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => - // select the largest edge (using balance when available, otherwise capacity). - val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) - val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) - ctx.sender ! RouteResponse(Route(fr.amount, hops) :: Nil) - case _ => // some nodes in the supplied route aren't connected in our graph - ctx.sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) + val assistedChannels: Map[ShortChannelId, AssistedChannel] = fr.assistedRoutes.flatMap(toAssistedChannels(_, fr.hops.last, fr.amount)).toMap + val extraEdges = assistedChannels.values.map(ac => + GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) + ).toSet + val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } + // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs + fr.hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { + case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => + // select the largest edge (using balance when available, otherwise capacity). + val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) + val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) + ctx.sender ! RouteResponse(Route(fr.amount, hops) :: Nil) + case _ => // some nodes in the supplied route aren't connected in our graph + ctx.sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) + } + d } - d } - def handleRouteRequest(d: Data, routerConf: RouterConf, currentBlockHeight: Long, r: RouteRequest)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { - implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + def handleRouteRequest(d: Data, routerConf: RouterConf, currentBlockHeight: Long, r: RouteRequest)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + Logs.withMdc(log)(Logs.mdc( + category_opt = Some(LogCategory.PAYMENT), + parentPaymentId_opt = r.paymentContext.map(_.parentId), + paymentId_opt = r.paymentContext.map(_.id), + paymentHash_opt = r.paymentContext.map(_.paymentHash))) { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - // we convert extra routing info provided in the payment request to fake channel_update - // it takes precedence over all other channel_updates we know - val assistedChannels: Map[ShortChannelId, AssistedChannel] = r.assistedRoutes.flatMap(toAssistedChannels(_, r.target, r.amount)).toMap - val extraEdges = assistedChannels.values.map(ac => - GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) - ).toSet - val ignoredEdges = r.ignoreChannels ++ d.excludedChannels - val defaultRouteParams: RouteParams = getDefaultRouteParams(routerConf) - val params = r.routeParams.getOrElse(defaultRouteParams) - val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 + // we convert extra routing info provided in the payment request to fake channel_update + // it takes precedence over all other channel_updates we know + val assistedChannels: Map[ShortChannelId, AssistedChannel] = r.assistedRoutes.flatMap(toAssistedChannels(_, r.target, r.amount)) + .filterNot { case (_, ac) => ac.extraHop.nodeId == r.source } // we ignore routing hints for our own channels, we have more accurate information + .toMap + val extraEdges = assistedChannels.values.map(ac => + GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) + ).toSet + val ignoredEdges = r.ignore.channels ++ d.excludedChannels + val params = r.routeParams.getOrElse(getDefaultRouteParams(routerConf)) + val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 - log.info(s"finding a route ${r.source}->${r.target} with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), r.ignoreNodes.map(_.value).mkString(","), r.ignoreChannels.mkString(","), d.excludedChannels.mkString(",")) - log.info(s"finding a route with randomize={} params={}", routesToFind > 1, params) - KamonExt.time(Metrics.FindRouteDuration.withTag(Tags.NumberOfRoutes, routesToFind).withTag(Tags.Amount, Tags.amountBucket(r.amount))) { - findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignoreNodes, params, currentBlockHeight) match { - case Success(routes) => - Metrics.RouteLength.withTag(Tags.Amount, Tags.amountBucket(r.amount)).record(routes.head.length) - ctx.sender ! RouteResponse(routes) - case Failure(t) => - Metrics.FindRouteErrors.withTag(Tags.Amount, Tags.amountBucket(r.amount)).withTag(Tags.Error, t.getClass.getSimpleName).increment() - ctx.sender ! Status.Failure(t) + log.info(s"finding routes ${r.source}->${r.target} with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), r.ignore.nodes.map(_.value).mkString(","), r.ignore.channels.mkString(","), d.excludedChannels.mkString(",")) + log.info("finding routes with randomize={} params={}", params.randomize, params) + val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(r.amount)) + KamonExt.time(Metrics.FindRouteDuration.withTags(tags.withTag(Tags.NumberOfRoutes, routesToFind.toLong))) { + val result = if (r.allowMultiPart) { + findMultiPartRoute(d.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight) + } else { + findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight) + } + result match { + case Success(routes) => + Metrics.RouteResults.withTags(tags).record(routes.length) + routes.foreach(route => Metrics.RouteLength.withTags(tags).record(route.length)) + ctx.sender ! RouteResponse(routes) + case Failure(t) => + val failure = if (isNeighborBalanceTooLow(d.graph, r)) BalanceTooLow else t + Metrics.FindRouteErrors.withTags(tags.withTag(Tags.Error, failure.getClass.getSimpleName)).increment() + ctx.sender ! Status.Failure(failure) + } } + d } - d } private def toFakeUpdate(extraHop: ExtraHop, htlcMaximum: MilliSatoshi): ChannelUpdate = { @@ -107,6 +130,11 @@ object RouteCalculation { }._2 } + def toChannelDescs(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey): Seq[ChannelDesc] = { + val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId + extraRoute.zip(nextNodeIds).map { case (hop, nextNodeId) => ChannelDesc(hop.shortChannelId, hop.nodeId, nextNodeId) } + } + /** Bolt 11 routing hints don't include the channel's capacity, so we round up the maximum htlc amount. */ private def htlcMaxToCapacity(htlcMaximum: MilliSatoshi): Satoshi = htlcMaximum.truncateToSatoshi + 1.sat @@ -357,4 +385,14 @@ object RouteCalculation { amountOk && feeOk } + /** + * Checks if we are directly connected to the target but don't have enough balance in our local channels to send the + * requested amount. We could potentially relay the payment by using indirect routes, but since we're connected to + * the target node it means we'd like to reach it via direct channels as much as possible. + */ + private def isNeighborBalanceTooLow(g: DirectedGraph, r: RouteRequest): Boolean = { + val neighborEdges = g.getEdgesBetween(r.source, r.target) + neighborEdges.nonEmpty && neighborEdges.map(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi)).sum < r.amount + } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 43a62fe5a..a001dd46b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -16,6 +16,8 @@ package fr.acinq.eclair.router +import java.util.UUID + import akka.Done import akka.actor.{ActorRef, Props} import akka.event.DiagnosticLoggingAdapter @@ -31,6 +33,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Monitoring.{Metrics, Tags} @@ -341,19 +344,44 @@ object Router { } } + case class Ignore(nodes: Set[PublicKey], channels: Set[ChannelDesc]) { + // @formatter:off + def +(ignoreNode: PublicKey): Ignore = copy(nodes = nodes + ignoreNode) + def ++(ignoreNodes: Set[PublicKey]): Ignore = copy(nodes = nodes ++ ignoreNodes) + def +(ignoreChannel: ChannelDesc): Ignore = copy(channels = channels + ignoreChannel) + def emptyNodes(): Ignore = copy(nodes = Set.empty) + def emptyChannels(): Ignore = copy(channels = Set.empty) + // @formatter:on + } + + object Ignore { + def empty: Ignore = Ignore(Set.empty, Set.empty) + } + case class RouteRequest(source: PublicKey, target: PublicKey, amount: MilliSatoshi, maxFee: MilliSatoshi, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - ignoreNodes: Set[PublicKey] = Set.empty, - ignoreChannels: Set[ChannelDesc] = Set.empty, - routeParams: Option[RouteParams] = None) + ignore: Ignore = Ignore.empty, + routeParams: Option[RouteParams] = None, + allowMultiPart: Boolean = false, + pendingPayments: Seq[Route] = Nil, + paymentContext: Option[PaymentContext] = None) - case class FinalizeRoute(amount: MilliSatoshi, hops: Seq[PublicKey], assistedRoutes: Seq[Seq[ExtraHop]] = Nil) + case class FinalizeRoute(amount: MilliSatoshi, + hops: Seq[PublicKey], + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + paymentContext: Option[PaymentContext] = None) + + /** + * Useful for having appropriate logging context at hand when finding routes + */ + case class PaymentContext(id: UUID, parentId: UUID, paymentHash: ByteVector32) + + case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop]) { + require(hops.nonEmpty, "route cannot be empty") - case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop], allowEmpty: Boolean = false) { - require(allowEmpty || hops.nonEmpty, "route cannot be empty") val length = hops.length lazy val fee: MilliSatoshi = { val amountToSend = hops.drop(1).reverse.foldLeft(amount) { case (amount1, hop) => amount1 + hop.fee(amount1) } @@ -362,6 +390,11 @@ object Router { /** This method retrieves the channel update that we used when we built the route. */ def getChannelUpdateForNode(nodeId: PublicKey): Option[ChannelUpdate] = hops.find(_.nodeId == nodeId).map(_.lastUpdate) + + def printNodes(): String = hops.map(_.nextNodeId).mkString("->") + + def printChannels(): String = hops.map(_.lastUpdate.shortChannelId).mkString("->") + } case class RouteResponse(routes: Seq[Route]) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouterExceptions.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouterExceptions.scala index d0b1984b2..706411fd3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouterExceptions.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouterExceptions.scala @@ -24,4 +24,6 @@ class RouterException(message: String) extends RuntimeException(message) object RouteNotFound extends RouterException("route not found") +object BalanceTooLow extends RouterException("balance too low") + object CannotRouteToSelf extends RouterException("cannot route to self") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala index 14823c4c9..c3a187b7a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala @@ -475,25 +475,29 @@ object Validation { } } - def handleAvailableBalanceChanged(d: Data, e: AvailableBalanceChanged): Data = { - val (channels1, graph1) = d.channels.get(e.shortChannelId) match { + def handleAvailableBalanceChanged(d: Data, e: AvailableBalanceChanged)(implicit log: LoggingAdapter): Data = { + val desc = ChannelDesc(e.shortChannelId, e.commitments.localParams.nodeId, e.commitments.remoteParams.nodeId) + val (publicChannels1, graph1) = d.channels.get(e.shortChannelId) match { case Some(pc) => val pc1 = pc.updateBalances(e.commitments) - val desc = ChannelDesc(e.shortChannelId, e.commitments.localParams.nodeId, e.commitments.remoteParams.nodeId) + log.debug("public channel balance updated: {}", pc1) val update_opt = if (e.commitments.localParams.nodeId == pc1.ann.nodeId1) pc1.update_1_opt else pc1.update_2_opt val graph1 = update_opt.map(u => d.graph.addEdge(desc, u, pc1.capacity, pc1.getBalanceSameSideAs(u))).getOrElse(d.graph) (d.channels + (e.shortChannelId -> pc1), graph1) case None => (d.channels, d.graph) } - val privateChannels1 = d.privateChannels.get(e.shortChannelId) match { + val (privateChannels1, graph2) = d.privateChannels.get(e.shortChannelId) match { case Some(pc) => val pc1 = pc.updateBalances(e.commitments) - d.privateChannels + (e.shortChannelId -> pc1) + log.debug("private channel balance updated: {}", pc1) + val update_opt = if (e.commitments.localParams.nodeId == pc1.nodeId1) pc1.update_1_opt else pc1.update_2_opt + val graph2 = update_opt.map(u => graph1.addEdge(desc, u, pc1.capacity, pc1.getBalanceSameSideAs(u))).getOrElse(graph1) + (d.privateChannels + (e.shortChannelId -> pc1), graph2) case None => - d.privateChannels + (d.privateChannels, graph1) } - d.copy(channels = channels1, privateChannels = privateChannels1, graph = graph1) + d.copy(channels = publicChannels1, privateChannels = privateChannels1, graph = graph2) } } diff --git a/eclair-core/src/main/scala/kamon/Kamon.scala b/eclair-core/src/main/scala/kamon/Kamon.scala index 28834e351..173bb4c1c 100644 --- a/eclair-core/src/main/scala/kamon/Kamon.scala +++ b/eclair-core/src/main/scala/kamon/Kamon.scala @@ -21,7 +21,7 @@ object Kamon { def withoutTags() = this - def withTags(args: TagSet, a: Boolean) = this + def withTags(args: TagSet) = this def withTags(a: TagSet, b: TagSet, c: Boolean) = this @@ -39,6 +39,8 @@ object Kamon { def increment(a: Int) = this + def increment(a: Long) = this + def decrement() = this def record(a: Long) = this diff --git a/eclair-core/src/main/scala/kamon/tag/TagSet.scala b/eclair-core/src/main/scala/kamon/tag/TagSet.scala index badd2d4c6..ab067bb70 100644 --- a/eclair-core/src/main/scala/kamon/tag/TagSet.scala +++ b/eclair-core/src/main/scala/kamon/tag/TagSet.scala @@ -1,6 +1,10 @@ package kamon.tag -trait TagSet +trait TagSet { + def withTag(t: String, s: Boolean) = this + def withTag(a: String, value: Long) = this + def withTag(a: String, value: String) = this +} object TagSet extends TagSet { def Empty: TagSet = this def of(t: String, s: String) = this diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 14192ae85..34f386667 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -35,6 +35,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentRequest, SendPa import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort import fr.acinq.eclair.router.Router.{GetNetworkStats, GetNetworkStatsResponse, PublicChannel} import fr.acinq.eclair.router.{Announcements, NetworkStats, Router, Stats} +import fr.acinq.eclair.wire.{Color, NodeAnnouncement} import org.mockito.Mockito import org.mockito.scalatest.IdiomaticMockito import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -150,6 +151,54 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I assertThrows[IllegalArgumentException](Await.result(eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) } + test("return node announcements") { f => + import f._ + + val eclair = new EclairImpl(kit) + val remoteNodeAnn1 = NodeAnnouncement(randomBytes64, Features.empty, 42L, randomKey.publicKey, Color(42, 42, 42), "LN-rocks", Nil) + val remoteNodeAnn2 = NodeAnnouncement(randomBytes64, Features.empty, 43L, randomKey.publicKey, Color(43, 43, 43), "LN-papers", Nil) + val allNodes = Seq( + NodeAnnouncement(randomBytes64, Features.empty, 561L, randomKey.publicKey, Color(0, 0, 0), "some-node", Nil), + remoteNodeAnn1, + remoteNodeAnn2, + NodeAnnouncement(randomBytes64, Features.empty, 1105L, randomKey.publicKey, Color(0, 0, 0), "some-other-node", Nil) + ) + + { + val fRes = eclair.nodes() + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.toSet === allNodes.toSet) + true + case _ => false + }) + } + { + val fRes = eclair.nodes(Some(Set(remoteNodeAnn1.nodeId, remoteNodeAnn2.nodeId))) + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.toSet === Set(remoteNodeAnn1, remoteNodeAnn2)) + true + case _ => false + }) + } + { + val fRes = eclair.nodes(Some(Set(randomKey.publicKey))) + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.isEmpty) + true + case _ => false + }) + } + } + ignore("allupdates can filter by nodeId") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala index b44bca1ac..c4fdfedfc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala @@ -22,10 +22,10 @@ import com.typesafe.config.{Config, ConfigFactory} import fr.acinq.bitcoin.Block import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.FeatureSupport.Mandatory -import fr.acinq.eclair.Features.{BasicMultiPartPayment, ChannelRangeQueries, ChannelRangeQueriesExtended, InitialRoutingSync, OptionDataLossProtect, PaymentSecret, VariableLengthOnion} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.crypto.LocalKeyManager -import scodec.bits.ByteVector import org.scalatest.funsuite.AnyFunSuite +import scodec.bits.ByteVector import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -46,6 +46,7 @@ class StartupSpec extends AnyFunSuite { test("check configuration") { assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair"))).isSuccess) assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair").withFallback(ConfigFactory.parseMap(Map("max-feerate-mismatch" -> 42).asJava)))).isFailure) + assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair").withFallback(ConfigFactory.parseMap(Map("on-chain-fees.max-feerate-mismatch" -> 1.56).asJava)))).isFailure) } test("NodeParams should fail if the alias is illegal (over 32 bytes)") { @@ -84,6 +85,10 @@ class StartupSpec extends AnyFunSuite { } test("NodeParams should fail if features are inconsistent") { + // Because of https://github.com/ACINQ/eclair/issues/1434, we need to remove the default features when falling back + // to the default configuration. + def finalizeConf(testCfg: Config): Config = testCfg.withFallback(defaultConf.withoutPath("features")) + val legalFeaturesConf = ConfigFactory.parseMap(Map( s"features.${OptionDataLossProtect.rfcName}" -> "optional", s"features.${InitialRoutingSync.rfcName}" -> "optional", @@ -105,9 +110,9 @@ class StartupSpec extends AnyFunSuite { s"features.${BasicMultiPartPayment.rfcName}" -> "optional" ).asJava) - assert(Try(makeNodeParamsWithDefaults(legalFeaturesConf.withFallback(defaultConf))).isSuccess) - assert(Try(makeNodeParamsWithDefaults(illegalButAllowedFeaturesConf.withFallback(defaultConf))).isSuccess) - assert(Try(makeNodeParamsWithDefaults(illegalFeaturesConf.withFallback(defaultConf))).isFailure) + assert(Try(makeNodeParamsWithDefaults(finalizeConf(legalFeaturesConf))).isSuccess) + assert(Try(makeNodeParamsWithDefaults(finalizeConf(illegalButAllowedFeaturesConf))).isSuccess) + assert(Try(makeNodeParamsWithDefaults(finalizeConf(illegalFeaturesConf))).isFailure) } test("parse human readable override features") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index b71976bba..469df7b0b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -22,9 +22,9 @@ import java.util.concurrent.atomic.AtomicLong import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, Script} import fr.acinq.eclair.FeatureSupport.Optional -import fr.acinq.eclair.Features.{ChannelRangeQueries, ChannelRangeQueriesExtended, InitialRoutingSync, OptionDataLossProtect, VariableLengthOnion} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.NodeParams.BITCOIND -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeeratesPerKw, OnChainFeeConf} +import fr.acinq.eclair.blockchain.fee._ import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer @@ -84,7 +84,7 @@ object TestConstants { onChainFeeConf = OnChainFeeConf( feeTargets = FeeTargets(6, 2, 2, 6), feeEstimator = new TestFeeEstimator, - maxFeerateMismatch = 1.5, + maxFeerateMismatch = FeerateTolerance(0.5, 8.0), closeOnOfflineMismatch = true, updateFeeMinDiffRatio = 0.1 ), @@ -170,7 +170,7 @@ object TestConstants { onChainFeeConf = OnChainFeeConf( feeTargets = FeeTargets(6, 2, 2, 6), feeEstimator = new TestFeeEstimator, - maxFeerateMismatch = 1.0, + maxFeerateMismatch = FeerateTolerance(0.75, 1.5), closeOnOfflineMismatch = true, updateFeeMinDiffRatio = 0.1 ), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala index 3cf93faa7..239499b2c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala @@ -20,6 +20,8 @@ import java.util.UUID import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{DeterministicWallet, Satoshi, Transaction} +import fr.acinq.eclair.TestConstants.TestFeeEstimator +import fr.acinq.eclair.blockchain.fee.{FeeTargets, FeerateTolerance, OnChainFeeConf} import fr.acinq.eclair.channel.Commitments._ import fr.acinq.eclair.channel.Helpers.Funding import fr.acinq.eclair.channel.states.StateTestsHelperMethods @@ -42,6 +44,8 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with implicit val log: akka.event.LoggingAdapter = akka.event.NoLogging + val feeConfNoMismatch = OnChainFeeConf(FeeTargets(6, 2, 2, 6), new TestFeeEstimator, FeerateTolerance(0.00001, 100000.0), closeOnOfflineMismatch = false, 1.0) + override def withFixture(test: OneArgTest): Outcome = { val setup = init() import setup._ @@ -66,8 +70,8 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a - htlcOutputFee) val (_, cmdAdd) = makeCmdAdd(a - htlcOutputFee - 1000.msat, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) - val Success(bc1) = receiveAdd(bc0, add) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) val Success((_, commit1)) = sendCommit(ac1, alice.underlyingActor.nodeParams.keyManager) val Success((bc2, _)) = receiveCommit(bc1, commit1, bob.underlyingActor.nodeParams.keyManager) // we don't take into account the additional HTLC fee since Alice's balance is below the trim threshold. @@ -94,11 +98,11 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (payment_preimage, cmdAdd) = makeCmdAdd(p, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) - val Success(bc1) = receiveAdd(bc0, add) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b) assert(bc1.availableBalanceForReceive == a - p - fee) @@ -179,11 +183,11 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (_, cmdAdd) = makeCmdAdd(p, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) - val Success(bc1) = receiveAdd(bc0, add) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b) assert(bc1.availableBalanceForReceive == a - p - fee) @@ -267,29 +271,29 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (payment_preimage1, cmdAdd1) = makeCmdAdd(p1, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add1)) = sendAdd(ac0, cmdAdd1, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add1)) = sendAdd(ac0, cmdAdd1, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p1 - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) val (_, cmdAdd2) = makeCmdAdd(p2, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac2, add2)) = sendAdd(ac1, cmdAdd2, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac2, add2)) = sendAdd(ac1, cmdAdd2, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac2.availableBalanceForSend == a - p1 - fee - p2 - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac2.availableBalanceForReceive == b) val (payment_preimage3, cmdAdd3) = makeCmdAdd(p3, alice.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((bc1, add3)) = sendAdd(bc0, cmdAdd3, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((bc1, add3)) = sendAdd(bc0, cmdAdd3, Local(UUID.randomUUID, None), currentBlockHeight, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b - p3) // bob doesn't pay the fee assert(bc1.availableBalanceForReceive == a) - val Success(bc2) = receiveAdd(bc1, add1) + val Success(bc2) = receiveAdd(bc1, add1, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc2.availableBalanceForSend == b - p3) assert(bc2.availableBalanceForReceive == a - p1 - fee) - val Success(bc3) = receiveAdd(bc2, add2) + val Success(bc3) = receiveAdd(bc2, add2, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc3.availableBalanceForSend == b - p3) assert(bc3.availableBalanceForReceive == a - p1 - fee - p2 - fee) - val Success(ac3) = receiveAdd(ac2, add3) + val Success(ac3) = receiveAdd(ac2, add3, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac3.availableBalanceForSend == a - p1 - fee - p2 - fee) assert(ac3.availableBalanceForReceive == b - p3) @@ -398,7 +402,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val isFunder = true val c = CommitmentsSpec.makeCommitments(100000000 msat, 50000000 msat, 2500, 546 sat, isFunder) val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val Success((c1, _)) = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val Success((c1, _)) = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(c1.availableBalanceForSend === 0.msat) // We should be able to handle a fee increase. @@ -406,7 +410,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with // Now we shouldn't be able to send until we receive enough to handle the updated commit tx fee (even trimmed HTLCs shouldn't be sent). val (_, cmdAdd1) = makeCmdAdd(100 msat, randomKey.publicKey, f.currentBlockHeight) - val Failure(e) = sendAdd(c2, cmdAdd1, Local(UUID.randomUUID, None), f.currentBlockHeight) + val Failure(e) = sendAdd(c2, cmdAdd1, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(e.isInstanceOf[InsufficientFunds]) } @@ -414,7 +418,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (isFunder <- Seq(true, false)) { val c = CommitmentsSpec.makeCommitments(702000000 msat, 52000000 msat, 2679, 546 sat, isFunder) val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(result.isSuccess, result) } } @@ -423,7 +427,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (isFunder <- Seq(true, false)) { val c = CommitmentsSpec.makeCommitments(31000000 msat, 702000000 msat, 2679, 546 sat, isFunder) val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, c.availableBalanceForReceive, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) + receiveAdd(c, add, feeConfNoMismatch) } } @@ -444,13 +448,13 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (_ <- 0 to t.pendingHtlcs) { val amount = Random.nextInt(maxPendingHtlcAmount.toLong.toInt).msat.max(1 msat) val (_, cmdAdd) = makeCmdAdd(amount, randomKey.publicKey, f.currentBlockHeight) - sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) match { + sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) match { case Success((cc, _)) => c = cc case Failure(e) => fail(s"$t -> could not setup initial htlcs: $e") } } val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(result.isSuccess, s"$t -> $result") } } @@ -472,13 +476,13 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (_ <- 0 to t.pendingHtlcs) { val amount = Random.nextInt(maxPendingHtlcAmount.toLong.toInt).msat.max(1 msat) val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, amount, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) match { + receiveAdd(c, add, feeConfNoMismatch) match { case Success(cc) => c = cc case Failure(e) => fail(s"$t -> could not setup initial htlcs: $e") } } val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, c.availableBalanceForReceive, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) match { + receiveAdd(c, add, feeConfNoMismatch) match { case Success(_) => () case Failure(e) => fail(s"$t -> $e") } @@ -496,8 +500,8 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val bc0 = bob.stateData.asInstanceOf[DATA_NORMAL].commitments val (_, cmdAdd) = makeCmdAdd(p, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) - val Success(bc1) = receiveAdd(bc0, add) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, feeConfNoMismatch) + val Success(bc1) = receiveAdd(bc0, add, feeConfNoMismatch) val Success((ac2, commit1)) = sendCommit(ac1, alice.underlyingActor.nodeParams.keyManager) val Success((bc2, revocation1)) = receiveCommit(bc1, commit1, bob.underlyingActor.nodeParams.keyManager) val Success((ac3, _)) = receiveRevocation(ac2, revocation1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala index 9d34131b9..b1cb7b70c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.wire.{AcceptChannel, ChannelTlv, Error, Init, OpenChannel import fr.acinq.eclair.{ActivatedFeature, CltvExpiryDelta, Features, LongToBtcAmount, TestConstants, TestKitBaseClass, ToMilliSatoshiConversion} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} -import scodec.bits.{ByteVector, HexStringSyntax} +import scodec.bits.ByteVector import scala.concurrent.duration._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index d5cc5659c..9d8ecb0ba 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -24,8 +24,9 @@ import akka.testkit.TestProbe import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto, ScriptFlags, Transaction} import fr.acinq.eclair.Features.StaticRemoteKey -import fr.acinq.eclair.TestConstants.{Alice, Bob} +import fr.acinq.eclair.TestConstants.{Alice, Bob, TestFeeEstimator} import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.blockchain.fee.FeeratesPerKw import fr.acinq.eclair.channel.Channel._ @@ -40,7 +41,6 @@ import fr.acinq.eclair.transactions.DirectedHtlc.{incoming, outgoing} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions.{HtlcSuccessTx, htlcSuccessWeight, htlcTimeoutWeight, weight2fee} import fr.acinq.eclair.wire.{AnnouncementSignatures, ChannelUpdate, ClosingSigned, CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} -import fr.acinq.eclair._ import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits._ @@ -383,6 +383,31 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with alice2bob.expectNoMsg(200 millis) } + test("recv CMD_ADD_HTLC (channel feerate mismatch)") { f => + import f._ + + val sender = TestProbe() + bob.feeEstimator.setFeerate(FeeratesPerKw.single(20000)) + sender.send(bob, CurrentFeerates(FeeratesPerKw.single(20000))) + bob2alice.expectNoMsg(100 millis) // we don't close because the commitment doesn't contain any HTLC + + val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] + val upstream = Upstream.Local(UUID.randomUUID()) + val add = CMD_ADD_HTLC(500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, upstream) + sender.send(bob, add) + val error = FeerateTooDifferent(channelId(bob), 20000, 10000) + sender.expectMsg(Failure(AddHtlcFailed(channelId(bob), add.paymentHash, error, Origin.Local(upstream.id, Some(sender.ref)), Some(initialState.channelUpdate), Some(add)))) + bob2alice.expectNoMsg(100 millis) // we don't close the channel, we can simply avoid using it while we disagree on feerate + + // we now agree on feerate so we can send HTLCs + bob.feeEstimator.setFeerate(FeeratesPerKw.single(11000)) + sender.send(bob, CurrentFeerates(FeeratesPerKw.single(11000))) + bob2alice.expectNoMsg(100 millis) + sender.send(bob, add) + sender.expectMsg(ChannelCommandResponse.Ok) + bob2alice.expectMsgType[UpdateAddHtlc] + } + test("recv CMD_ADD_HTLC (after having sent Shutdown)") { f => import f._ val sender = TestProbe() @@ -1273,9 +1298,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with localChanges = initialState.commitments.localChanges.copy(initialState.commitments.localChanges.proposed :+ fulfill)))) } - test("recv CMD_FULFILL_HTLC") { testReceiveCmdFulfillHtlc _ } + test("recv CMD_FULFILL_HTLC") { + testReceiveCmdFulfillHtlc _ + } - test("recv CMD_FULFILL_HTLC (static_remotekey)", Tag("static_remotekey")) { testReceiveCmdFulfillHtlc _ } + test("recv CMD_FULFILL_HTLC (static_remotekey)", Tag("static_remotekey")) { + testReceiveCmdFulfillHtlc _ + } test("recv CMD_FULFILL_HTLC (unknown htlc id)") { f => import f._ @@ -1331,9 +1360,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(forward.htlc === htlc) } - test("recv UpdateFulfillHtlc") { testUpdateFulfillHtlc _ } + test("recv UpdateFulfillHtlc") { + testUpdateFulfillHtlc _ + } - test("recv UpdateFulfillHtlc (static_remotekey)", Tag("(static_remotekey)")) { testUpdateFulfillHtlc _ } + test("recv UpdateFulfillHtlc (static_remotekey)", Tag("(static_remotekey)")) { + testUpdateFulfillHtlc _ + } test("recv UpdateFulfillHtlc (sender has not signed htlc)") { f => import f._ @@ -1407,9 +1440,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with } - test("recv CMD_FAIL_HTLC") { testCmdFailHtlc _ } + test("recv CMD_FAIL_HTLC") { + testCmdFailHtlc _ + } - test("recv CMD_FAIL_HTLC (static_remotekey)", Tag("static_remotekey")) { testCmdFailHtlc _ } + test("recv CMD_FAIL_HTLC (static_remotekey)", Tag("static_remotekey")) { + testCmdFailHtlc _ + } test("recv CMD_FAIL_HTLC (unknown htlc id)") { f => import f._ @@ -1496,8 +1533,12 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with relayerA.expectNoMsg() } - test("recv UpdateFailHtlc") { testUpdateFailHtlc _ } - test("recv UpdateFailHtlc (static_remotekey)", Tag("static_remotekey")) { testUpdateFailHtlc _ } + test("recv UpdateFailHtlc") { + testUpdateFailHtlc _ + } + test("recv UpdateFailHtlc (static_remotekey)", Tag("static_remotekey")) { + testUpdateFailHtlc _ + } test("recv UpdateFailMalformedHtlc") { f => import f._ @@ -1635,19 +1676,19 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("recv UpdateFee") { f => import f._ val initialData = bob.stateData.asInstanceOf[DATA_NORMAL] - val fee1 = UpdateFee(ByteVector32.Zeroes, 12000) - bob ! fee1 - val fee2 = UpdateFee(ByteVector32.Zeroes, 14000) - bob ! fee2 - awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee2), remoteNextHtlcId = 0))) + val fee = UpdateFee(ByteVector32.Zeroes, 12000) + bob ! fee + awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee), remoteNextHtlcId = 0))) } test("recv UpdateFee (two in a row)") { f => import f._ val initialData = bob.stateData.asInstanceOf[DATA_NORMAL] - val fee = UpdateFee(ByteVector32.Zeroes, 12000) - bob ! fee - awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee), remoteNextHtlcId = 0))) + val fee1 = UpdateFee(ByteVector32.Zeroes, 12000) + bob ! fee1 + val fee2 = UpdateFee(ByteVector32.Zeroes, 14000) + bob ! fee2 + awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee2), remoteNextHtlcId = 0))) } test("recv UpdateFee (when sender is not funder)") { f => @@ -1684,16 +1725,20 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("recv UpdateFee (local/remote feerates are too different)") { f => import f._ + bob.feeEstimator.setFeerate(FeeratesPerKw(1000, 2000, 6000, 12000, 36000, 72000, 140000)) val tx = bob.stateData.asInstanceOf[DATA_NORMAL].commitments.localCommit.publishableTxs.commitTx.tx val sender = TestProbe() - // Alice will use $localFeeRate when performing the checks for update_fee - val localFeeRate = bob.feeEstimator.getFeeratePerKw(bob.feeTargets.commitmentBlockTarget) - assert(localFeeRate === 2000) - val remoteFeeUpdate = 85000 - sender.send(bob, UpdateFee(ByteVector32.Zeroes, remoteFeeUpdate)) + val localFeerate = bob.feeEstimator.getFeeratePerKw(bob.feeTargets.commitmentBlockTarget) + assert(localFeerate === 2000) + val remoteFeerate = 4000 + sender.send(bob, UpdateFee(ByteVector32.Zeroes, remoteFeerate)) + bob2alice.expectNoMsg(250 millis) // we don't close because the commitment doesn't contain any HTLC + + // when we try to add an HTLC, we still disagree on the feerate so we close + alice2bob.send(bob, UpdateAddHtlc(ByteVector32.Zeroes, 0, 2500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket)) val error = bob2alice.expectMsgType[Error] - assert(new String(error.data.toArray) === s"local/remote feerates are too different: remoteFeeratePerKw=$remoteFeeUpdate localFeeratePerKw=$localFeeRate") + assert(new String(error.data.toArray).contains("local/remote feerates are too different")) awaitCond(bob.stateName == CLOSING) // channel should be advertised as down assert(channelUpdateListener.expectMsgType[LocalChannelDown].channelId === bob.stateData.asInstanceOf[DATA_CLOSING].channelId) @@ -2105,10 +2150,15 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with bob2alice.expectNoMsg(500 millis) } - test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different)") { f => + test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different, with HTLCs)") { f => import f._ + + addHtlc(10000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + val sender = TestProbe() - val event = CurrentFeerates(FeeratesPerKw.single(100)) + bob.feeEstimator.setFeerate(FeeratesPerKw.single(14000)) + val event = CurrentFeerates(FeeratesPerKw.single(14000)) sender.send(bob, event) bob2alice.expectMsgType[Error] bob2blockchain.expectMsgType[PublishAsap] // commit tx @@ -2117,6 +2167,24 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with awaitCond(bob.stateName == CLOSING) } + test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different, without HTLCs)") { f => + import f._ + + val sender = TestProbe() + bob.feeEstimator.setFeerate(FeeratesPerKw.single(1000)) + val event = CurrentFeerates(FeeratesPerKw.single(1000)) + sender.send(bob, event) + bob2alice.expectNoMsg(250 millis) // we don't close because the commitment doesn't contain any HTLC + + // when we try to add an HTLC, we still disagree on the feerate so we close + alice2bob.send(bob, UpdateAddHtlc(ByteVector32.Zeroes, 0, 2500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket)) + bob2alice.expectMsgType[Error] + bob2blockchain.expectMsgType[PublishAsap] // commit tx + bob2blockchain.expectMsgType[PublishAsap] // main delayed + bob2blockchain.expectMsgType[WatchConfirmed] + awaitCond(bob.stateName == CLOSING) + } + test("recv BITCOIN_FUNDING_SPENT (their commit w/ htlc)") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index bbfb001bf..ecbc33102 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -469,9 +469,23 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("handle feerate changes while offline (funder scenario)") { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + + testHandleFeerateFunder(f, shouldClose = true) + } + + test("handle feerate changes while offline without HTLCs (funder scenario)") { f => + testHandleFeerateFunder(f, shouldClose = false) + } + + def testHandleFeerateFunder(f: FixtureParam, shouldClose: Boolean): Unit = { + import f._ // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) @@ -480,32 +494,44 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val aliceStateData = alice.stateData.asInstanceOf[DATA_NORMAL] val aliceCommitTx = aliceStateData.commitments.localCommit.publishableTxs.commitTx.tx - val localFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // alice is funder - sender.send(alice, CurrentFeerates(highFeerate)) - alice2blockchain.expectMsg(PublishAsap(aliceCommitTx)) + sender.send(alice, CurrentFeerates(networkFeerate)) + if (shouldClose) { + alice2blockchain.expectMsg(PublishAsap(aliceCommitTx)) + } else { + alice2blockchain.expectNoMsg() + } } test("handle feerate changes while offline (don't close on mismatch)", Tag("disable-offline-mismatch")) { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) awaitCond(bob.stateName == OFFLINE) val aliceStateData = alice.stateData.asInstanceOf[DATA_NORMAL] - val localFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // this time Alice will ignore feerate changes for the offline channel - sender.send(alice, CurrentFeerates(highFeerate)) + sender.send(alice, CurrentFeerates(networkFeerate)) alice2blockchain.expectNoMsg() alice2bob.expectNoMsg() } @@ -546,9 +572,23 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("handle feerate changes while offline (fundee scenario)") { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + + testHandleFeerateFundee(f, shouldClose = true) + } + + test("handle feerate changes while offline without HTLCs (fundee scenario)") { f => + testHandleFeerateFundee(f, shouldClose = false) + } + + def testHandleFeerateFundee(f: FixtureParam, shouldClose: Boolean): Unit = { + import f._ // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) @@ -557,13 +597,19 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val bobStateData = bob.stateData.asInstanceOf[DATA_NORMAL] val bobCommitTx = bobStateData.commitments.localCommit.publishableTxs.commitTx.tx - val localFeeratePerKw = bobStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((bob.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = bobStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / bob.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // bob is fundee - sender.send(bob, CurrentFeerates(highFeerate)) - bob2blockchain.expectMsg(PublishAsap(bobCommitTx)) + sender.send(bob, CurrentFeerates(networkFeerate)) + if (shouldClose) { + bob2blockchain.expectMsg(PublishAsap(bobCommitTx)) + } else { + bob2blockchain.expectNoMsg() + } } test("re-send channel_update at reconnection for private channels") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 953e51ae8..03a03a670 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -90,7 +90,6 @@ class SqliteAuditDbSpec extends AnyFunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val n1 = randomKey.publicKey val n2 = randomKey.publicKey val n3 = randomKey.publicKey val n4 = randomKey.publicKey @@ -99,24 +98,36 @@ class SqliteAuditDbSpec extends AnyFunSuite { val c2 = randomBytes32 val c3 = randomBytes32 val c4 = randomBytes32 + val c5 = randomBytes32 + val c6 = randomBytes32 - db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, randomBytes32, c2)) - db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, randomBytes32)), Seq(PaymentRelayed.Part(20000 msat, c4)))) - db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, randomBytes32)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4)))) + db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1)) + db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1)) + db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1)) + db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2)) + db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6)) + db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4)))) + db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4)))) db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding")) db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual")) db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding")) db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding")) - assert(db.stats.toSet === Set( - Stats(channelId = c1, avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat), - Stats(channelId = c2, avgPaymentAmount = 40 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat), - Stats(channelId = c3, avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), - Stats(channelId = c4, avgPaymentAmount = 30 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat) + // NB: we only count a relay fee for the outgoing channel, no the incoming one. + assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set( + Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat), + Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat), + Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat), + Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), + Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), + Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat), + Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat), + Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat) )) } @@ -148,7 +159,7 @@ class SqliteAuditDbSpec extends AnyFunSuite { }) // Test starts here. val start = System.currentTimeMillis - assert(db.stats.nonEmpty) + assert(db.stats(0, start + 1).nonEmpty) val end = System.currentTimeMillis fail(s"took ${end - start}ms") } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 26270c6a8..95e3bac4d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -18,29 +18,25 @@ package fr.acinq.eclair.payment import java.util.UUID -import akka.actor.ActorRef +import akka.actor.{ActorRef, Status} import akka.testkit.{TestFSMRef, TestProbe} -import fr.acinq.bitcoin.{Block, Crypto, Satoshi} -import fr.acinq.eclair.TestConstants.TestFeeEstimator +import fr.acinq.bitcoin.{Block, Crypto} import fr.acinq.eclair._ -import fr.acinq.eclair.blockchain.fee.FeeratesPerKw -import fr.acinq.eclair.channel.{ChannelFlags, Commitments, CommitmentsSpec, Upstream} +import fr.acinq.eclair.channel.{AddHtlcFailed, ChannelFlags, ChannelUnavailable, Upstream} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.PaymentSent.PartialPayment -import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannel, OutgoingChannels} +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._ +import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig -import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment -import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentError} +import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute +import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.router._ import fr.acinq.eclair.wire._ +import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike -import org.scalatest.{Outcome, Tag} -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration._ -import scala.util.Random /** * Created by t-bast on 18/07/2019. @@ -50,11 +46,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import MultiPartPaymentLifecycleSpec._ - case class FixtureParam(paymentId: UUID, + case class FixtureParam(cfg: SendPaymentConfig, nodeParams: NodeParams, payFsm: TestFSMRef[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data, MultiPartPaymentLifecycle], router: TestProbe, - relayer: TestProbe, sender: TestProbe, childPayFsm: TestProbe, eventListener: TestProbe) @@ -63,445 +58,396 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val id = UUID.randomUUID() val cfg = SendPaymentConfig(id, id, Some("42"), paymentHash, finalAmount, finalRecipient, Upstream.Local(id), None, storeInDb = true, publishEvent = true, Nil) val nodeParams = TestConstants.Alice.nodeParams - nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator].setFeerate(FeeratesPerKw.single(500)) - val (childPayFsm, router, relayer, sender, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) - class TestMultiPartPaymentLifecycle extends MultiPartPaymentLifecycle(nodeParams, cfg, relayer.ref, router.ref, TestProbe().ref) { + val (childPayFsm, router, sender, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe()) + class TestMultiPartPaymentLifecycle extends MultiPartPaymentLifecycle(nodeParams, cfg, router.ref, TestProbe().ref) { override def spawnChildPaymentFsm(childId: UUID): ActorRef = childPayFsm.ref } val paymentHandler = TestFSMRef(new TestMultiPartPaymentLifecycle().asInstanceOf[MultiPartPaymentLifecycle]) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - withFixture(test.toNoArgTest(FixtureParam(id, nodeParams, paymentHandler, router, relayer, sender, childPayFsm, eventListener))) + withFixture(test.toNoArgTest(FixtureParam(cfg, nodeParams, paymentHandler, router, sender, childPayFsm, eventListener))) } - def initPayment(f: FixtureParam, request: SendMultiPartPayment, networkStats: NetworkStats, localChannels: OutgoingChannels): Unit = { - import f._ - sender.send(payFsm, request) - router.expectMsg(GetNetworkStats) - router.send(payFsm, GetNetworkStatsResponse(Some(networkStats))) - relayer.expectMsg(GetOutgoingChannels()) - relayer.send(payFsm, localChannels) - } - - def waitUntilAmountSent(f: FixtureParam, amount: MilliSatoshi): Unit = { - Iterator.iterate(0 msat)(sent => { - sent + f.childPayFsm.expectMsgType[SendPayment].finalPayload.amount - }).takeWhile(sent => sent < amount) - } - - test("get network statistics and usable balances before paying") { f => + test("successful first attempt (single part)") { f => import f._ assert(payFsm.stateName === WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(randomBytes32, b, 1500 * 1000 msat, expiry, 1) + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 1, routeParams = Some(routeParams.copy(randomize = true))) sender.send(payFsm, payment) - router.expectMsg(GetNetworkStats) - assert(payFsm.stateName === WAIT_FOR_NETWORK_STATS) - router.send(payFsm, GetNetworkStatsResponse(Some(emptyStats))) - relayer.expectMsg(GetOutgoingChannels()) - awaitCond(payFsm.stateName === WAIT_FOR_CHANNEL_BALANCES) - assert(payFsm.stateData.asInstanceOf[WaitingForChannelBalances].networkStats === Some(emptyStats)) + + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + assert(payFsm.stateName === WAIT_FOR_ROUTES) + + val singleRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + router.send(payFsm, RouteResponse(Seq(singleRoute))) + val childPayment = childPayFsm.expectMsgType[SendPaymentToRoute] + assert(childPayment.route === Right(singleRoute)) + assert(childPayment.finalPayload.expiry === expiry) + assert(childPayment.finalPayload.paymentSecret === Some(payment.paymentSecret)) + assert(childPayment.finalPayload.amount === finalAmount) + assert(childPayment.finalPayload.totalAmount === finalAmount) + assert(payFsm.stateName === PAYMENT_IN_PROGRESS) + + val result = fulfillPendingPayments(f, 1) + assert(result.amountWithFees === finalAmount + 100.msat) + assert(result.trampolineFees === 0.msat) + assert(result.nonTrampolineFees === 100.msat) } - test("get network statistics not available") { f => + test("successful first attempt (multiple parts)") { f => import f._ assert(payFsm.stateName === WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(randomBytes32, b, 2500 * 1000 msat, expiry, 1) + val payment = SendMultiPartPayment(randomBytes32, e, 1200000 msat, expiry, 1, routeParams = Some(routeParams.copy(randomize = false))) sender.send(payFsm, payment) - router.expectMsg(GetNetworkStats) - assert(payFsm.stateName === WAIT_FOR_NETWORK_STATS) - router.send(payFsm, GetNetworkStatsResponse(None)) - // If network stats aren't available we'll use local channel balance information instead. - // We should ask the router to compute statistics (for next payment attempts). - router.expectMsg(TickComputeNetworkStats) - relayer.expectMsg(GetOutgoingChannels()) - awaitCond(payFsm.stateName === WAIT_FOR_CHANNEL_BALANCES) - assert(payFsm.stateData.asInstanceOf[WaitingForChannelBalances].networkStats === None) - relayer.send(payFsm, localChannels()) - awaitCond(payFsm.stateName === PAYMENT_IN_PROGRESS) - waitUntilAmountSent(f, payment.totalAmount) - val payments = payFsm.stateData.asInstanceOf[PaymentProgress].pending.values - assert(payments.size > 1) - } + router.expectMsg(RouteRequest(nodeParams.nodeId, e, 1200000 msat, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + assert(payFsm.stateName === WAIT_FOR_ROUTES) - test("send to peer node via multiple channels") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, b, 2000 * 1000 msat, expiry, 3) - // When sending to a peer node, we should not filter out unannounced channels. - val channels = OutgoingChannels(Seq( - OutgoingChannel(c, channelUpdate_ac_2, makeCommitments(1000 * 1000 msat, 0)), - OutgoingChannel(c, channelUpdate_ac_3, makeCommitments(1500 * 1000 msat, 0)), - OutgoingChannel(b, channelUpdate_ab_1.copy(channelFlags = ChannelFlags.Empty), makeCommitments(1000 * 1000 msat, 0, announceChannel = false)), - OutgoingChannel(b, channelUpdate_ab_2.copy(channelFlags = ChannelFlags.Empty), makeCommitments(1500 * 1000 msat, 0, announceChannel = false)), - OutgoingChannel(d, channelUpdate_ad_1, makeCommitments(1000 * 1000 msat, 0)))) - // Network statistics should be ignored when sending to peer. - initPayment(f, payment, emptyStats, channels) - - // The payment should be split in two, using direct channels with b. - // MaxAttempts should be set to 1 when using direct channels to the destination. - childPayFsm.expectMsgAllOf( - SendPayment(b, Onion.createMultiPartPayload(1000 * 1000 msat, payment.totalAmount, expiry, payment.paymentSecret), 1, routePrefix = Seq(ChannelHop(nodeParams.nodeId, b, channelUpdate_ab_1.copy(channelFlags = ChannelFlags.Empty)))), - SendPayment(b, Onion.createMultiPartPayload(1000 * 1000 msat, payment.totalAmount, expiry, payment.paymentSecret), 1, routePrefix = Seq(ChannelHop(nodeParams.nodeId, b, channelUpdate_ab_2.copy(channelFlags = ChannelFlags.Empty)))) + val routes = Seq( + Route(500000 msat, hop_ab_1 :: hop_be :: Nil), + Route(700000 msat, hop_ac_1 :: hop_ce :: Nil) ) - childPayFsm.expectNoMsg(50 millis) - val childIds = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.toSeq - assert(childIds.length === 2) + router.send(payFsm, RouteResponse(routes)) + val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil + assert(childPayments.map(_.route).toSet === routes.map(r => Right(r)).toSet) + assert(childPayments.map(_.finalPayload.expiry).toSet === Set(expiry)) + assert(childPayments.map(_.finalPayload.paymentSecret.get).toSet === Set(payment.paymentSecret)) + assert(childPayments.map(_.finalPayload.amount).toSet === Set(500000 msat, 700000 msat)) + assert(childPayments.map(_.finalPayload.totalAmount).toSet === Set(1200000 msat)) + assert(payFsm.stateName === PAYMENT_IN_PROGRESS) - val pp1 = PartialPayment(childIds.head, 1000 * 1000 msat, 0 msat, randomBytes32, None) - val pp2 = PartialPayment(childIds(1), 1000 * 1000 msat, 0 msat, randomBytes32, None) - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, b, Seq(pp1))) - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, b, Seq(pp2))) - val expectedMsg = PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, finalRecipient, Seq(pp1, pp2)) - sender.expectMsg(expectedMsg) - eventListener.expectMsg(expectedMsg) - - assert(expectedMsg.recipientAmount === finalAmount) - assert(expectedMsg.amountWithFees === (2000 * 1000).msat) - assert(expectedMsg.trampolineFees === (1000 * 1000).msat) - assert(expectedMsg.nonTrampolineFees === 0.msat) - assert(expectedMsg.feesPaid === expectedMsg.trampolineFees) + val result = fulfillPendingPayments(f, 2) + assert(result.amountWithFees === 1200200.msat) + assert(result.trampolineFees === 200000.msat) + assert(result.nonTrampolineFees === 200.msat) } - test("send to peer node via single big channel") { f => + test("send custom tlv records") { f => import f._ - val payment = SendMultiPartPayment(randomBytes32, b, 1000 * 1000 msat, expiry, 1) - // Network statistics should be ignored when sending to peer (otherwise we should have split into multiple payments). - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(100), d => Satoshi(d.toLong))), localChannels(0)) - childPayFsm.expectMsg(SendPayment(b, Onion.createMultiPartPayload(payment.totalAmount, payment.totalAmount, expiry, payment.paymentSecret), 1, routePrefix = Seq(ChannelHop(nodeParams.nodeId, b, channelUpdate_ab_1)))) - childPayFsm.expectNoMsg(50 millis) - } - test("send to peer node via remote channels") { f => - import f._ - // d only has a single channel with capacity 1000 sat, we try to send more. - val payment = SendMultiPartPayment(randomBytes32, d, 2000 * 1000 msat, expiry, 1) - val testChannels = localChannels() - val balanceToTarget = testChannels.channels.filter(_.nextNodeId == d).map(_.commitments.availableBalanceForSend).sum - assert(balanceToTarget < (1000 * 1000).msat) // the commit tx fee prevents us from completely emptying our channel - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(500), d => Satoshi(d.toLong))), testChannels) - waitUntilAmountSent(f, payment.totalAmount) - val payments = payFsm.stateData.asInstanceOf[PaymentProgress].pending.values - assert(payments.size > 1) - val directPayments = payments.filter(p => p.routePrefix.head.nextNodeId == d) - assert(directPayments.size === 1) - assert(directPayments.head.finalPayload.amount === balanceToTarget) - } - - test("send to remote node without splitting") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 300 * 1000 msat, expiry, 1) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1500), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - payFsm.stateData.asInstanceOf[PaymentProgress].pending.foreach { - case (id, payment) => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, e, Seq(PartialPayment(id, payment.finalPayload.amount, 5 msat, randomBytes32, None)))) - } - - val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.amountWithFees === payment.totalAmount + result.nonTrampolineFees) - assert(result.parts.length === 1) - } - - test("send to remote node via multiple channels") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3200 * 1000 msat, expiry, 3) - // A network capacity of 1000 sat should split the payment in at least 3 parts. - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - - val payments = Iterator.iterate(0 msat)(sent => { - val child = childPayFsm.expectMsgType[SendPayment] - assert(child.targetNodeId === e) - assert(child.maxAttempts === 3) - assert(child.finalPayload.expiry === expiry) - assert(child.finalPayload.paymentSecret === Some(payment.paymentSecret)) - assert(child.finalPayload.totalAmount === payment.totalAmount) - assert(child.routePrefix.length === 1 && child.routePrefix.head.nodeId === nodeParams.nodeId) - assert(sent + child.finalPayload.amount <= payment.totalAmount) - sent + child.finalPayload.amount - }).takeWhile(sent => sent != payment.totalAmount).toSeq - assert(payments.length > 2) - assert(payments.length < 10) - childPayFsm.expectNoMsg(50 millis) - - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - val partialPayments = pending.map { - case (id, payment) => PartialPayment(id, payment.finalPayload.amount, 1 msat, randomBytes32, Some(hop_ac_1 :: hop_ab_2 :: Nil)) - } - partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) - val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.paymentHash === paymentHash) - assert(result.paymentPreimage === paymentPreimage) - assert(result.parts === partialPayments) - assert(result.recipientAmount === finalAmount) - assert(result.amountWithFees > (3200 * 1000).msat) - assert(result.trampolineFees === (2200 * 1000).msat) - assert(result.nonTrampolineFees === partialPayments.map(_.feesPaid).sum) - } - - test("send to remote node via single big channel") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3500 * 1000 msat, expiry, 3) - // When splitting inside a channel, we need to take the fees of the commit tx into account (multiple outgoing HTLCs - // will increase the size of the commit tx and thus its fee. - val feeRatePerKw = 100 - // A network capacity of 1500 sat should split the payment in at least 2 parts. - // We have a single big channel inside which we'll send multiple payments. - val localChannel = OutgoingChannels(Seq(OutgoingChannel(b, channelUpdate_ab_1, makeCommitments(5000 * 1000 msat, feeRatePerKw)))) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1500), d => Satoshi(d.toLong))), localChannel) - waitUntilAmountSent(f, payment.totalAmount) - - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - assert(pending.size >= 2) - val partialPayments = pending.map { - case (id, payment) => PartialPayment(id, payment.finalPayload.amount, 1 msat, randomBytes32, None) - } - partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, payment.totalAmount, e, Seq(pp)))) - val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.paymentHash === paymentHash) - assert(result.paymentPreimage === paymentPreimage) - assert(result.parts === partialPayments) - assert(result.amountWithFees - result.nonTrampolineFees === (3500 * 1000).msat) - assert(result.recipientNodeId === finalRecipient) // the recipient is obtained from the config, not from the request (which may be to the first trampoline node) - assert(result.nonTrampolineFees === partialPayments.map(_.feesPaid).sum) - } - - test("send to remote trampoline node") { f => - import f._ + // We include a bunch of additional tlv records. val trampolineTlv = OnionTlv.TrampolineOnion(OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(400)(0), randomBytes32)) - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 3, additionalTlvs = Seq(trampolineTlv)) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - pending.foreach { - case (_, p) => assert(p.finalPayload.asInstanceOf[Onion.FinalTlvPayload].records.get[OnionTlv.TrampolineOnion] === Some(trampolineTlv)) - } - } - - test("split fees between child payments") { f => - import f._ - val routeParams = RouteParams(randomize = false, 100 msat, 0.05, 20, CltvExpiryDelta(144), None, MultiPartParams(10000 msat, 5)) - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 3, routeParams = Some(routeParams)) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, 3000 * 1000 msat) - - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - assert(pending.size >= 2) - pending.foreach { - case (_, p) => - assert(p.routeParams.get.maxFeeBase < 50.msat) - assert(p.routeParams.get.maxFeePct == 0.05) // fee percent doesn't need to change - } - } - - test("skip empty channels") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 3) - val testChannels = localChannels() - val testChannels1 = testChannels.copy(channels = testChannels.channels ++ Seq( - OutgoingChannel(b, channelUpdate_ab_1.copy(shortChannelId = ShortChannelId(42)), makeCommitments(0 msat, 10)), - OutgoingChannel(e, channelUpdate_ab_1.copy(shortChannelId = ShortChannelId(43)), makeCommitments(0 msat, 10) - ))) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), testChannels1) - waitUntilAmountSent(f, payment.totalAmount) - payFsm.stateData.asInstanceOf[PaymentProgress].pending.foreach { - case (id, p) => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, payment.totalAmount, e, Seq(PartialPayment(id, p.finalPayload.amount, 5 msat, randomBytes32, None)))) - } - - val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.amountWithFees > payment.totalAmount) - } - - test("retry after error") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 3) - val testChannels = localChannels() - // A network capacity of 1000 sat should split the payment in at least 3 parts. - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), testChannels) - waitUntilAmountSent(f, payment.totalAmount) - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - assert(pending.size > 2) - - // Simulate a local channel failure and a remote failure. - val faultyLocalChannelId = getFirstHopShortChannelId(pending.head._2) - val faultyLocalPayments = pending.filter { case (_, p) => getFirstHopShortChannelId(p) == faultyLocalChannelId } - val faultyRemotePayment = pending.filter { case (_, p) => getFirstHopShortChannelId(p) != faultyLocalChannelId }.head - faultyLocalPayments.keys.foreach(id => { - childPayFsm.send(payFsm, PaymentFailed(id, paymentHash, LocalFailure(Nil, RouteNotFound) :: Nil)) + val userCustomTlv = GenericTlv(UInt64(561), hex"deadbeef") + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount + 1000.msat, expiry, 1, routeParams = Some(routeParams), additionalTlvs = Seq(trampolineTlv), userCustomTlvs = Seq(userCustomTlv)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(501000 msat, hop_ac_1 :: hop_ce :: Nil)))) + val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil + childPayments.map(_.finalPayload.asInstanceOf[Onion.FinalTlvPayload]).foreach(p => { + assert(p.records.get[OnionTlv.TrampolineOnion] === Some(trampolineTlv)) + assert(p.records.unknown.toSeq === Seq(userCustomTlv)) }) - childPayFsm.send(payFsm, PaymentFailed(faultyRemotePayment._1, paymentHash, UnreadableRemoteFailure(Nil) :: Nil)) - // We should ask for updated balance to take into account pending payments. - relayer.expectMsg(GetOutgoingChannels()) - relayer.send(payFsm, testChannels.copy(channels = testChannels.channels.dropRight(2))) - - // The channel that lead to a RouteNotFound should be ignored. - assert(payFsm.stateData.asInstanceOf[PaymentProgress].ignoreChannels === Set(faultyLocalChannelId)) - - // New payments should be sent that match the failed amount. - waitUntilAmountSent(f, faultyRemotePayment._2.finalPayload.amount + faultyLocalPayments.values.map(_.finalPayload.amount).sum) - val stateData = payFsm.stateData.asInstanceOf[PaymentProgress] - assert(stateData.failures.toSet === Set(LocalFailure(Nil, RouteNotFound), UnreadableRemoteFailure(Nil))) - assert(stateData.pending.values.forall(p => getFirstHopShortChannelId(p) != faultyLocalChannelId)) + val result = fulfillPendingPayments(f, 2) + assert(result.trampolineFees === 1000.msat) } - test("cannot send (not enough capacity on local channels)") { f => + test("successful retry") { f => import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 3) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), OutgoingChannels(Seq( - OutgoingChannel(b, channelUpdate_ab_1, makeCommitments(1000 * 1000 msat, 10)), - OutgoingChannel(c, channelUpdate_ac_2, makeCommitments(1000 * 1000 msat, 10)), - OutgoingChannel(d, channelUpdate_ad_1, makeCommitments(1000 * 1000 msat, 10)))) - ) + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 3, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + val failingRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + router.send(payFsm, RouteResponse(Seq(failingRoute))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMsg(100 millis) + + val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.hops, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) + // We retry ignoring the failing channel. + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = true)), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ac_1 :: hop_ce :: Nil), Route(600000 msat, hop_ad :: hop_de :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(childId)) + + val result = fulfillPendingPayments(f, 2) + assert(result.amountWithFees === 1000200.msat) + assert(result.trampolineFees === 0.msat) + assert(result.nonTrampolineFees === 200.msat) + } + + test("retry failures while waiting for routes") { f => + import f._ + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 3, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ab_2 :: hop_be :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMsg(100 millis) + + val (failedId1, failedRoute1) :: (failedId2, failedRoute2) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(RemoteFailure(failedRoute1.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) + + // When we retry, we ignore the failing node and we let the router know about the remaining pending route. + router.expectMsg(RouteRequest(nodeParams.nodeId, e, failedRoute1.amount, maxFee - failedRoute1.fee, ignore = Ignore(Set(b), Set.empty), pendingPayments = Seq(failedRoute2), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)), paymentContext = Some(cfg.paymentContext))) + // The second part fails while we're still waiting for new routes. + childPayFsm.send(payFsm, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) + // We receive a response to our first request, but it's now obsolete: we re-sent a new route request that takes into + // account the latest failures. + router.send(payFsm, RouteResponse(Seq(Route(failedRoute1.amount, hop_ac_1 :: hop_ce :: Nil)))) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, ignore = Ignore(Set(b), Set.empty), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)), paymentContext = Some(cfg.paymentContext))) + awaitCond(payFsm.stateData.asInstanceOf[PaymentProgress].pending.isEmpty) + childPayFsm.expectNoMsg(100 millis) + + // We receive new routes that work. + router.send(payFsm, RouteResponse(Seq(Route(300000 msat, hop_ac_1 :: hop_ce :: Nil), Route(700000 msat, hop_ad :: hop_de :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + + val result = fulfillPendingPayments(f, 2) + assert(result.amountWithFees === 1000200.msat) + assert(result.nonTrampolineFees === 200.msat) + } + + test("retry without ignoring channels") { f => + import f._ + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 3, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ab_1 :: hop_be :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMsg(100 millis) + + val (failedId, failedRoute) :: (_, pendingRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.hops, AddHtlcFailed(randomBytes32, paymentHash, ChannelUnavailable(randomBytes32), null, None, None))))) + + // If the router doesn't find routes, we will retry without ignoring the channel: it may work with a different split + // of the amount to send. + val expectedRouteRequest = RouteRequest( + nodeParams.nodeId, e, + failedRoute.amount, maxFee - failedRoute.fee, + ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), + pendingPayments = Seq(pendingRoute), + allowMultiPart = true, + routeParams = Some(routeParams.copy(randomize = true)), + paymentContext = Some(cfg.paymentContext)) + router.expectMsg(expectedRouteRequest) + router.send(payFsm, Status.Failure(RouteNotFound)) + router.expectMsg(expectedRouteRequest.copy(ignore = Ignore.empty)) + + router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + + val result = fulfillPendingPayments(f, 2) + assert(result.amountWithFees === 1000200.msat) + } + + test("abort after too many failed attempts") { f => + import f._ + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 2, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + + val (failedId1, failedRoute1) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head + childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.hops)))) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ad :: hop_de :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + + assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(failedId1)) + val (failedId2, failedRoute2) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head + val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(UnreadableRemoteFailure(failedRoute2.hops)))) + assert(result.failures.length >= 3) + assert(result.failures.contains(LocalFailure(Nil, RetryExhausted))) + } + + test("abort if no routes found") { f => + import f._ + + sender.watch(payFsm) + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, Status.Failure(RouteNotFound)) + val result = sender.expectMsgType[PaymentFailed] - assert(result.id === paymentId) + assert(result.id === cfg.id) assert(result.paymentHash === paymentHash) + assert(result.failures === Seq(LocalFailure(Nil, RouteNotFound))) + + sender.expectTerminated(payFsm) + sender.expectNoMsg(100 millis) + router.expectNoMsg(100 millis) + childPayFsm.expectNoMsg(100 millis) + } + + test("abort if recipient sends error") { f => + import f._ + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + + val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.hops, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600000 msat, 0)))))) assert(result.failures.length === 1) - assert(result.failures.head.asInstanceOf[LocalFailure].t === PaymentError.BalanceTooLow) } - test("cannot send (fee rate too high)") { f => + test("abort if recipient sends error during retry") { f => import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 2500 * 1000 msat, expiry, 3) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), OutgoingChannels(Seq( - OutgoingChannel(b, channelUpdate_ab_1, makeCommitments(1500 * 1000 msat, 1000)), - OutgoingChannel(c, channelUpdate_ac_2, makeCommitments(1500 * 1000 msat, 1000)), - OutgoingChannel(d, channelUpdate_ad_1, makeCommitments(1500 * 1000 msat, 1000)))) - ) - val result = sender.expectMsgType[PaymentFailed] - assert(result.id === paymentId) - assert(result.paymentHash === paymentHash) - assert(result.failures.length === 1) - assert(result.failures.head.asInstanceOf[LocalFailure].t === PaymentError.BalanceTooLow) + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + + val (failedId1, failedRoute1) :: (failedId2, failedRoute2) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.hops)))) + router.expectMsgType[RouteRequest] + + val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + assert(result.failures.length === 2) } - test("payment timeout") { f => + test("receive partial success after retriable failure (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 5) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - val (childId1, _) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - // If we receive a timeout failure, we directly abort the payment instead of retrying. - childPayFsm.send(payFsm, PaymentFailed(childId1, paymentHash, RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(e, PaymentTimeout)) :: Nil)) - relayer.expectNoMsg(50 millis) - awaitCond(payFsm.stateName === PAYMENT_ABORTED) - } + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] - test("failure received from final recipient") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 5) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - val (childId1, _) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head + val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.hops)))) + router.expectMsgType[RouteRequest] - // If we receive a failure from the final node, we directly abort the payment instead of retrying. - childPayFsm.send(payFsm, PaymentFailed(childId1, paymentHash, RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(3000 * 1000 msat, 42))) :: Nil)) - relayer.expectNoMsg(50 millis) - awaitCond(payFsm.stateName === PAYMENT_ABORTED) - } - - test("fail after too many attempts") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 3000 * 1000 msat, expiry, 2) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - val (childId1, childPayment1) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - - // We retry one failure. - val failures = Seq(UnreadableRemoteFailure(hop_ab_1 :: Nil), UnreadableRemoteFailure(hop_ac_1 :: hop_ab_2 :: Nil)) - childPayFsm.send(payFsm, PaymentFailed(childId1, paymentHash, failures.slice(0, 1))) - relayer.expectMsg(GetOutgoingChannels()) - relayer.send(payFsm, localChannels()) - waitUntilAmountSent(f, childPayment1.finalPayload.amount) - - // But another failure occurs... - val (childId2, _) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - childPayFsm.send(payFsm, PaymentFailed(childId2, paymentHash, failures.slice(1, 2))) - relayer.expectNoMsg(50 millis) - awaitCond(payFsm.stateName === PAYMENT_ABORTED) - - // And then all other payments time out. - payFsm.stateData.asInstanceOf[PaymentAborted].pending.foreach(childId => childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Nil))) - val result = sender.expectMsgType[PaymentFailed] - assert(result.id === paymentId) - assert(result.paymentHash === paymentHash) - assert(result.failures.length === 3) - assert(result.failures.slice(0, 2) === failures) - assert(result.failures.last.asInstanceOf[LocalFailure].t === PaymentError.RetryExhausted) - } - - test("receive partial failure after success (recipient spec violation)") { f => - import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 4000 * 1000 msat, expiry, 2) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(1500), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - - // If one of the payments succeeds, the recipient MUST succeed them all: we can consider the whole payment succeeded. - val (id1, payment1) = pending.head - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, payment.totalAmount, e, Seq(PartialPayment(id1, payment1.finalPayload.amount, 0 msat, randomBytes32, None)))) - awaitCond(payFsm.stateName === PAYMENT_SUCCEEDED) - - // A partial failure should simply be ignored. - val (id2, payment2) = pending.tail.head - childPayFsm.send(payFsm, PaymentFailed(id2, paymentHash, Nil)) - - pending.tail.tail.foreach { - case (id, p) => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, payment.totalAmount, e, Seq(PartialPayment(id, p.finalPayload.amount, 0 msat, randomBytes32, None)))) - } - val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.amountWithFees === payment.totalAmount - payment2.finalPayload.amount) + val result = fulfillPendingPayments(f, 1) + assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount + assert(result.nonTrampolineFees === successRoute.fee) // we paid the fee for only one of the partial payments + assert(result.parts.length === 1 && result.parts.head.id === successId) } test("receive partial success after abort (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(randomBytes32, e, 5000 * 1000 msat, expiry, 1) - initPayment(f, payment, emptyStats.copy(capacity = Stats.generate(Seq(2000), d => Satoshi(d.toLong))), localChannels()) - waitUntilAmountSent(f, payment.totalAmount) - val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending - // One of the payments failed and we configured maxAttempts = 1, so we abort. - val (id1, _) = pending.head - childPayFsm.send(payFsm, PaymentFailed(id1, paymentHash, Nil)) + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + + val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) awaitCond(payFsm.stateName === PAYMENT_ABORTED) - // The in-flight HTLC set doesn't pay the full amount, so the recipient MUST not fulfill any of those. - // But if he does, it's too bad for him as we have obtained a cheaper proof of payment. - val (id2, payment2) = pending.tail.head - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, payment.totalAmount, e, Seq(PartialPayment(id2, payment2.finalPayload.amount, 5 msat, randomBytes32, None)))) - awaitCond(payFsm.stateName === PAYMENT_SUCCEEDED) - - // Even if all other child payments fail, we obtained the preimage so the payment is a success from our point of view. - pending.tail.tail.foreach { - case (id, _) => childPayFsm.send(payFsm, PaymentFailed(id, paymentHash, Nil)) - } + sender.watch(payFsm) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.fee, randomBytes32, Some(successRoute.hops))))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) - assert(result.amountWithFees === payment2.finalPayload.amount + 5.msat) - assert(result.nonTrampolineFees === 5.msat) + assert(result.id === cfg.id) + assert(result.paymentHash === paymentHash) + assert(result.paymentPreimage === paymentPreimage) + assert(result.parts.length === 1 && result.parts.head.id === successId) + assert(result.recipientAmount === finalAmount) + assert(result.recipientNodeId === finalRecipient) + assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount + assert(result.nonTrampolineFees === successRoute.fee) // we paid the fee for only one of the partial payments + + sender.expectTerminated(payFsm) + sender.expectNoMsg(100 millis) + router.expectNoMsg(100 millis) + childPayFsm.expectNoMsg(100 millis) } - test("split payment", Tag("fuzzy")) { f => - // The fees for a single HTLC will be 100 * 172 / 1000 = 17 satoshis. - val testChannels = localChannels(100) - for (_ <- 1 to 100) { - // We have a total of 6500 satoshis across all channels. We try to send lower amounts to take fees into account. - val toSend = ((1 + Random.nextInt(3500)) * 1000).msat - val networkStats = emptyStats.copy(capacity = Stats.generate(Seq(400 + Random.nextInt(1600)), d => Satoshi(d.toLong))) - val routeParams = RouteParams(randomize = true, Random.nextInt(1000).msat, Random.nextInt(10).toDouble / 100, 20, CltvExpiryDelta(144), None, MultiPartParams(10000 msat, 5)) - val request = SendMultiPartPayment(randomBytes32, e, toSend, CltvExpiry(561), 1, Nil, Some(routeParams)) - val fuzzParams = s"(sending $toSend with network capacity ${networkStats.capacity.percentile75.toMilliSatoshi}, fee base ${routeParams.maxFeeBase} and fee percentage ${routeParams.maxFeePct})" - val (remaining, payments) = splitPayment(f.nodeParams, toSend, testChannels.channels, Some(networkStats), request, randomize = true) - assert(remaining === 0.msat, fuzzParams) - assert(payments.nonEmpty, fuzzParams) - assert(payments.map(_.finalPayload.amount).sum === toSend, fuzzParams) + test("receive partial failure after success (recipient spec violation)") { f => + import f._ + + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 5, routeParams = Some(routeParams)) + sender.send(payFsm, payment) + router.expectMsgType[RouteRequest] + router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectMsgType[SendPaymentToRoute] + + val (childId, route) :: (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.fee, randomBytes32, Some(route.hops))))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) + awaitCond(payFsm.stateName === PAYMENT_SUCCEEDED) + + sender.watch(payFsm) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + val result = sender.expectMsgType[PaymentSent] + assert(result.parts.length === 1 && result.parts.head.id === childId) + assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount + assert(result.nonTrampolineFees === route.fee) // we paid the fee for only one of the partial payments + + sender.expectTerminated(payFsm) + sender.expectNoMsg(100 millis) + router.expectNoMsg(100 millis) + childPayFsm.expectNoMsg(100 millis) + } + + def fulfillPendingPayments(f: FixtureParam, childCount: Int): PaymentSent = { + import f._ + + sender.watch(payFsm) + val pending = payFsm.stateData.asInstanceOf[PaymentProgress].pending + assert(pending.size === childCount) + + val partialPayments = pending.map { + case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.fee, randomBytes32, Some(route.hops)) } + partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) + val result = sender.expectMsgType[PaymentSent] + assert(result.id === cfg.id) + assert(result.paymentHash === paymentHash) + assert(result.paymentPreimage === paymentPreimage) + assert(result.parts.toSet === partialPayments.toSet) + assert(result.recipientAmount === finalAmount) + assert(result.recipientNodeId === finalRecipient) + + sender.expectTerminated(payFsm) + sender.expectNoMsg(100 millis) + router.expectNoMsg(100 millis) + childPayFsm.expectNoMsg(100 millis) + + result + } + + def abortAfterFailure(f: FixtureParam, childFailure: PaymentFailed): PaymentFailed = { + import f._ + + sender.watch(payFsm) + val pendingCount = payFsm.stateData.asInstanceOf[PaymentProgress].pending.size + childPayFsm.send(payFsm, childFailure) // this failure should trigger an abort + if (pendingCount > 1) { + awaitCond(payFsm.stateName === PAYMENT_ABORTED) + assert(payFsm.stateData.asInstanceOf[PaymentAborted].pending.size === pendingCount - 1) + // Fail all remaining child payments. + payFsm.stateData.asInstanceOf[PaymentAborted].pending.foreach(childId => + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(hop_ab_1 :: hop_be :: Nil, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + ) + } + + val result = sender.expectMsgType[PaymentFailed] + assert(result.id === cfg.id) + assert(result.paymentHash === paymentHash) + assert(result.failures.nonEmpty) + + sender.expectTerminated(payFsm) + sender.expectNoMsg(100 millis) + router.expectNoMsg(100 millis) + childPayFsm.expectNoMsg(100 millis) + + result } } @@ -513,6 +459,8 @@ object MultiPartPaymentLifecycleSpec { val expiry = CltvExpiry(1105) val finalAmount = 1000000 msat val finalRecipient = randomKey.publicKey + val routeParams = RouteParams(randomize = false, 15000 msat, 0.01, 6, CltvExpiryDelta(1008), None, MultiPartParams(1000 msat, 5)) + val maxFee = 15000 msat // max fee for the defaultAmount /** * We simulate a multi-part-friendly network: @@ -527,35 +475,29 @@ object MultiPartPaymentLifecycleSpec { val a :: b :: c :: d :: e :: Nil = Seq.fill(5)(randomKey.publicKey) val channelId_ab_1 = ShortChannelId(1) val channelId_ab_2 = ShortChannelId(2) + val channelId_be = ShortChannelId(3) val channelId_ac_1 = ShortChannelId(11) val channelId_ac_2 = ShortChannelId(12) - val channelId_ac_3 = ShortChannelId(13) - val channelId_ad_1 = ShortChannelId(21) - val defaultChannelUpdate = ChannelUpdate(randomBytes64, Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0, 1, ChannelFlags.AnnounceChannel, CltvExpiryDelta(12), 1 msat, 0 msat, 0, Some(2000 * 1000 msat)) - val channelUpdate_ab_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_1, cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 100 msat, feeProportionalMillionths = 70) - val channelUpdate_ab_2 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_2, cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 100 msat, feeProportionalMillionths = 70) - val channelUpdate_ac_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ac_1, cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 150 msat, feeProportionalMillionths = 40) - val channelUpdate_ac_2 = defaultChannelUpdate.copy(shortChannelId = channelId_ac_2, cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 150 msat, feeProportionalMillionths = 40) - val channelUpdate_ac_3 = defaultChannelUpdate.copy(shortChannelId = channelId_ac_3, cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 150 msat, feeProportionalMillionths = 40) - val channelUpdate_ad_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ad_1, cltvExpiryDelta = CltvExpiryDelta(6), feeBaseMsat = 200 msat, feeProportionalMillionths = 50) - - // With a fee rate of 10, the fees for a single HTLC will be 10 * 172 / 1000 = 1 satoshi. - def localChannels(feeRatePerKw: Long = 10): OutgoingChannels = OutgoingChannels(Seq( - OutgoingChannel(b, channelUpdate_ab_1, makeCommitments(1000 * 1000 msat, feeRatePerKw)), - OutgoingChannel(b, channelUpdate_ab_2, makeCommitments(1500 * 1000 msat, feeRatePerKw)), - OutgoingChannel(c, channelUpdate_ac_1, makeCommitments(500 * 1000 msat, feeRatePerKw)), - OutgoingChannel(c, channelUpdate_ac_2, makeCommitments(1000 * 1000 msat, feeRatePerKw)), - OutgoingChannel(c, channelUpdate_ac_3, makeCommitments(1500 * 1000 msat, feeRatePerKw)), - OutgoingChannel(d, channelUpdate_ad_1, makeCommitments(1000 * 1000 msat, feeRatePerKw)))) + val channelId_ce = ShortChannelId(13) + val channelId_ad = ShortChannelId(21) + val channelId_de = ShortChannelId(22) + val defaultChannelUpdate = ChannelUpdate(randomBytes64, Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0, 1, ChannelFlags.AnnounceChannel, CltvExpiryDelta(12), 1 msat, 100 msat, 0, Some(2000000 msat)) + val channelUpdate_ab_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_1) + val channelUpdate_ab_2 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_2) + val channelUpdate_be = defaultChannelUpdate.copy(shortChannelId = channelId_be) + val channelUpdate_ac_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ac_1) + val channelUpdate_ac_2 = defaultChannelUpdate.copy(shortChannelId = channelId_ac_2) + val channelUpdate_ce = defaultChannelUpdate.copy(shortChannelId = channelId_ce) + val channelUpdate_ad = defaultChannelUpdate.copy(shortChannelId = channelId_ad) + val channelUpdate_de = defaultChannelUpdate.copy(shortChannelId = channelId_de) val hop_ab_1 = ChannelHop(a, b, channelUpdate_ab_1) val hop_ab_2 = ChannelHop(a, b, channelUpdate_ab_2) + val hop_be = ChannelHop(b, e, channelUpdate_be) val hop_ac_1 = ChannelHop(a, c, channelUpdate_ac_1) - - val emptyStats = NetworkStats(0, 0, Stats.generate(Seq(0), d => Satoshi(d.toLong)), Stats.generate(Seq(0), d => CltvExpiryDelta(d.toInt)), Stats.generate(Seq(0), d => MilliSatoshi(d.toLong)), Stats.generate(Seq(0), d => d.toLong)) - - // We are only interested in availableBalanceForSend so we can put dummy values for the rest. - def makeCommitments(canSend: MilliSatoshi, feeRatePerKw: Long, announceChannel: Boolean = true): Commitments = - CommitmentsSpec.makeCommitments(canSend, 0 msat, feeRatePerKw, 0 sat, announceChannel = announceChannel) + val hop_ac_2 = ChannelHop(a, c, channelUpdate_ac_2) + val hop_ce = ChannelHop(c, e, channelUpdate_ce) + val hop_ad = ChannelHop(a, d, channelUpdate_ad) + val hop_de = ChannelHop(d, e, channelUpdate_de) } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala index 544771617..9555bfee5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala @@ -26,12 +26,11 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.PaymentRequest.{ExtraHop, PaymentRequestFeatures} import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM -import fr.acinq.eclair.payment.relay.{CommandBuffer, NodeRelayer, Origin, Relayer} -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment -import fr.acinq.eclair.payment.send.PaymentError +import fr.acinq.eclair.payment.relay.{CommandBuffer, NodeRelayer} +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment -import fr.acinq.eclair.router.RouteNotFound +import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, nodeFee, randomBytes, randomBytes32, randomKey} import org.scalatest.Outcome @@ -56,16 +55,16 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { within(30 seconds) { val nodeParams = TestConstants.Bob.nodeParams val outgoingPayFSM = TestProbe() - val (relayer, router, commandBuffer, register, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) + val (router, commandBuffer, register, eventListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe()) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - class TestNodeRelayer extends NodeRelayer(nodeParams, relayer.ref, router.ref, commandBuffer.ref, register.ref) { + class TestNodeRelayer extends NodeRelayer(nodeParams, router.ref, commandBuffer.ref, register.ref) { override def spawnOutgoingPayFSM(cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = { outgoingPayFSM.ref ! cfg outgoingPayFSM.ref } } val nodeRelayer = TestActorRef(new TestNodeRelayer().asInstanceOf[NodeRelayer]) - withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelayer, relayer, outgoingPayFSM, commandBuffer, eventListener))) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelayer, TestProbe(), outgoingPayFSM, commandBuffer, eventListener))) } } @@ -219,7 +218,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPaymentId = outgoingPayFSM.expectMsgType[SendPaymentConfig].id outgoingPayFSM.expectMsgType[SendMultiPartPayment] - outgoingPayFSM.send(nodeRelayer, PaymentFailed(outgoingPaymentId, paymentHash, LocalFailure(Nil, PaymentError.BalanceTooLow) :: Nil)) + outgoingPayFSM.send(nodeRelayer, PaymentFailed(outgoingPaymentId, paymentHash, LocalFailure(Nil, BalanceTooLow) :: Nil)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure), commit = true)))) commandBuffer.expectNoMsg(100 millis) eventListener.expectNoMsg(100 millis) @@ -249,7 +248,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPaymentId = outgoingPayFSM.expectMsgType[SendPaymentConfig].id outgoingPayFSM.expectMsgType[SendMultiPartPayment] - val failures = RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(Nil) :: LocalFailure(Nil, RouteNotFound) :: Nil + val failures = RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(Nil) :: Nil outgoingPayFSM.send(nodeRelayer, PaymentFailed(outgoingPaymentId, paymentHash, failures)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FAIL_HTLC(p.add.id, Right(FinalIncorrectHtlcAmount(42 msat)), commit = true)))) commandBuffer.expectNoMsg(100 millis) @@ -281,18 +280,12 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPayment = outgoingPayFSM.expectMsgType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) - // A first downstream HTLC is fulfilled. - val ff1 = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff1) - outgoingPayFSM.expectMsg(ff1) - // We should immediately forward the fulfill upstream. + // A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream. + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) - // A second downstream HTLC is fulfilled. - val ff2 = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff2) - outgoingPayFSM.expectMsg(ff2) - // We should not fulfill a second time upstream. + // If the payment FSM sends us duplicate preimage events, we should not fulfill a second time upstream. + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) commandBuffer.expectNoMsg(100 millis) // Once all the downstream payments have settled, we should emit the relayed event. @@ -315,9 +308,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPayment = outgoingPayFSM.expectMsgType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) val incomingAdd = incomingSinglePart.add commandBuffer.expectMsg(CommandBuffer.CommandSend(incomingAdd.channelId, CMD_FULFILL_HTLC(incomingAdd.id, paymentPreimage, commit = true))) @@ -329,8 +320,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { commandBuffer.expectNoMsg(100 millis) } - // TODO: re-activate this test once we have better MPP split to remote legacy recipients - ignore("relay to non-trampoline recipient supporting multi-part") { f => + test("relay to non-trampoline recipient supporting multi-part") { f => import f._ // Receive an upstream multi-part payment. @@ -352,9 +342,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { assert(outgoingPayment.routeParams.isDefined) assert(outgoingPayment.assistedRoutes === hints) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) outgoingPayFSM.send(nodeRelayer, createSuccessEvent(outgoingCfg.id)) @@ -378,16 +366,13 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingCfg = outgoingPayFSM.expectMsgType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.TrampolineRelayed(incomingMultiPart.map(_.add))) val outgoingPayment = outgoingPayFSM.expectMsgType[SendPayment] - assert(outgoingPayment.routePrefix === Nil) assert(outgoingPayment.finalPayload.amount === outgoingAmount) assert(outgoingPayment.finalPayload.expiry === outgoingExpiry) assert(outgoingPayment.targetNodeId === outgoingNodeId) assert(outgoingPayment.routeParams.isDefined) assert(outgoingPayment.assistedRoutes === hints) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) outgoingPayFSM.send(nodeRelayer, createSuccessEvent(outgoingCfg.id)) @@ -449,11 +434,6 @@ object NodeRelayerSpec { val incomingSinglePart = createValidIncomingPacket(incomingAmount, incomingAmount, CltvExpiry(500000), outgoingAmount, outgoingExpiry) - def createDownstreamFulfill(payFSM: ActorRef): Relayer.ForwardFulfill = { - val origin = Origin.TrampolineRelayed(null, Some(payFSM)) - Relayer.ForwardRemoteFulfill(UpdateFulfillHtlc(randomBytes32, Random.nextInt(100), paymentPreimage), origin, null) - } - def createSuccessEvent(id: UUID): PaymentSent = PaymentSent(id, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(id, outgoingAmount, 10 msat, randomBytes32, None))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 41fd44da5..2154ae36c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator} +import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router.{MultiPartParams, NodeHop, RouteParams} import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload} import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} @@ -69,7 +70,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val (sender, payFsm, multiPartPayFsm) = (TestProbe(), TestProbe(), TestProbe()) val eventListener = TestProbe() system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - class TestPaymentInitiator extends PaymentInitiator(nodeParams, TestProbe().ref, TestProbe().ref, TestProbe().ref) { + class TestPaymentInitiator extends PaymentInitiator(nodeParams, TestProbe().ref, TestProbe().ref) { // @formatter:off override def spawnPaymentFsm(cfg: SendPaymentConfig): ActorRef = { payFsm.ref ! cfg @@ -116,7 +117,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, SendPaymentToRouteRequest(finalAmount, finalAmount, None, None, pr, Channel.MIN_CLTV_EXPIRY_DELTA, Seq(a, b, c), None, 0 msat, CltvExpiryDelta(0), Nil)) val payment = sender.expectMsgType[SendPaymentToRouteResponse] payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(pr), storeInDb = true, publishEvent = true, Nil)) - payFsm.expectMsg(SendPaymentToRoute(Seq(a, b, c), FinalLegacyPayload(finalAmount, Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)))) + payFsm.expectMsg(SendPaymentToRoute(Left(Seq(a, b, c)), FinalLegacyPayload(finalAmount, Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)))) } test("forward legacy payment") { f => @@ -162,7 +163,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val payment = sender.expectMsgType[SendPaymentToRouteResponse] payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(pr), storeInDb = true, publishEvent = true, Nil)) val msg = payFsm.expectMsgType[SendPaymentToRoute] - assert(msg.hops === Seq(a, b, c)) + assert(msg.route === Left(Seq(a, b, c))) assert(msg.finalPayload.amount === finalAmount / 2) assert(msg.finalPayload.paymentSecret === pr.paymentSecret) assert(msg.finalPayload.totalAmount === finalAmount) @@ -302,13 +303,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg1.totalAmount === finalAmount + 21000.msat) // Simulate a failure which should trigger a retry. - val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) - multiPartPayFsm.send(initiator, failed) + multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg2.totalAmount === finalAmount + 25000.msat) // Simulate a failure that exhausts payment attempts. + val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure)))) multiPartPayFsm.send(initiator, failed) sender.expectMsg(failed) eventListener.expectMsg(failed) @@ -316,6 +317,34 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike eventListener.expectNoMsg(100 millis) } + test("retry trampoline payment and fail (route not found)") { f => + import f._ + val features = PaymentRequestFeatures(VariableLengthOnion.optional, PaymentSecret.optional, BasicMultiPartPayment.optional, TrampolinePayment.optional) + val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, "Some phoenix invoice", features = Some(features)) + val trampolineAttempts = (21000 msat, CltvExpiryDelta(12)) :: (25000 msat, CltvExpiryDelta(24)) :: Nil + val req = SendTrampolinePaymentRequest(finalAmount, pr, b, trampolineAttempts, CltvExpiryDelta(9)) + sender.send(initiator, req) + sender.expectMsgType[UUID] + + val cfg = multiPartPayFsm.expectMsgType[SendPaymentConfig] + val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(msg1.totalAmount === finalAmount + 21000.msat) + // Trampoline node couldn't find a route for the given fee. + val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) + multiPartPayFsm.send(initiator, failed) + multiPartPayFsm.expectMsgType[SendPaymentConfig] + val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(msg2.totalAmount === finalAmount + 25000.msat) + // Trampoline node couldn't find a route even with the increased fee. + multiPartPayFsm.send(initiator, failed) + + val failure = sender.expectMsgType[PaymentFailed] + assert(failure.failures === Seq(LocalFailure(Seq(NodeHop(nodeParams.nodeId, b, nodeParams.expiryDeltaBlocks, 0 msat), NodeHop(b, c, CltvExpiryDelta(24), 25000 msat)), RouteNotFound))) + eventListener.expectMsg(failure) + sender.expectNoMsg(100 millis) + eventListener.expectNoMsg(100 millis) + } + test("forward trampoline payment with pre-defined route") { f => import f._ val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, "Some invoice") @@ -326,7 +355,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(payment.trampolineSecret.nonEmpty) payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(pr), storeInDb = true, publishEvent = true, Seq(NodeHop(b, c, CltvExpiryDelta(0), 0 msat)))) val msg = payFsm.expectMsgType[SendPaymentToRoute] - assert(msg.hops === Seq(a, b)) + assert(msg.route === Left(Seq(a, b))) assert(msg.finalPayload.amount === finalAmount + trampolineFees) assert(msg.finalPayload.paymentSecret === payment.trampolineSecret) assert(msg.finalPayload.totalAmount === finalAmount + trampolineFees) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 20a665e38..2d4510158 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -62,10 +62,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultExternalId = UUID.randomUUID().toString val defaultPaymentRequest = SendPaymentRequest(defaultAmountMsat, defaultPaymentHash, d, 1, externalId = Some(defaultExternalId)) - def defaultRouteRequest(source: PublicKey, target: PublicKey): RouteRequest = RouteRequest(source, target, defaultAmountMsat, defaultMaxFee) + def defaultRouteRequest(source: PublicKey, target: PublicKey, cfg: SendPaymentConfig): RouteRequest = RouteRequest(source, target, defaultAmountMsat, defaultMaxFee, paymentContext = Some(cfg.paymentContext)) - case class PaymentFixture(id: UUID, - parentId: UUID, + case class PaymentFixture(cfg: SendPaymentConfig, nodeParams: NodeParams, paymentFSM: TestFSMRef[PaymentLifecycle.State, PaymentLifecycle.Data, PaymentLifecycle], routerForwarder: TestProbe, @@ -83,18 +82,43 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - PaymentFixture(id, parentId, nodeParams, paymentFSM, routerForwarder, register, sender, monitor, eventListener) + PaymentFixture(cfg, nodeParams, paymentFSM, routerForwarder, register, sender, monitor, eventListener) } - test("send to route") { routerFixture => + test("send to route") { _ => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ // pre-computed route going from A to D - val request = SendPaymentToRoute(Seq(a, b, c, d), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) + val route = Route(defaultAmountMsat, ChannelHop(a, b, update_ab) :: ChannelHop(b, c, update_bc) :: ChannelHop(c, d, update_cd) :: Nil) + val request = SendPaymentToRoute(Right(route), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) + sender.send(paymentFSM, request) + routerForwarder.expectNoMsg(100 millis) // we don't need the router, we have the pre-computed route + val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) + + val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) + awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = nodeParams.db.payments.getOutgoingPayment(id) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, parentId, Some(defaultExternalId), defaultPaymentHash, PaymentType.Standard, defaultAmountMsat, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) + sender.send(paymentFSM, Relayer.ForwardRemoteFulfill(UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentPreimage), defaultOrigin, UpdateAddHtlc(ByteVector32.Zeroes, 0, defaultAmountMsat, defaultPaymentHash, defaultExpiry, TestConstants.emptyOnionPacket))) + + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id === parentId) + assert(ps.parts.head.route === Some(route.hops)) + awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + } + + test("send to route (node_id only)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + import cfg._ + + // pre-computed route going from A to D + val request = SendPaymentToRoute(Left(Seq(a, b, c, d)), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, d))) + routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, d), paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -113,7 +137,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val brokenRoute = SendPaymentToRoute(Seq(randomKey.publicKey, randomKey.publicKey, randomKey.publicKey), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) + val brokenRoute = SendPaymentToRoute(Left(Seq(randomKey.publicKey, randomKey.publicKey, randomKey.publicKey)), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, brokenRoute) routerForwarder.expectMsgType[FinalizeRoute] routerForwarder.forward(routerFixture.router) @@ -125,13 +149,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("send to route (routing hints)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val recipient = randomKey.publicKey val routingHint = Seq(Seq(ExtraHop(c, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144)))) - val request = SendPaymentToRoute(Seq(a, b, c, recipient), FinalLegacyPayload(defaultAmountMsat, defaultExpiry), routingHint) + val request = SendPaymentToRoute(Left(Seq(a, b, c, recipient)), FinalLegacyPayload(defaultAmountMsat, defaultExpiry), routingHint) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, recipient), routingHint)) + routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, recipient), routingHint, paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -145,54 +170,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) } - test("send with route prefix") { _ => - val payFixture = createPaymentLifecycle() - import payFixture._ - - val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) - sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(c, d).copy(ignoreNodes = Set(a, b))) - val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - - routerForwarder.send(paymentFSM, RouteResponse(Route(defaultAmountMsat, Seq(ChannelHop(c, d, update_cd))) :: Nil)) - val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - } - - test("send with whole route prefix") { _ => - val payFixture = createPaymentLifecycle() - import payFixture._ - - val request = SendPayment(c, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) - sender.send(paymentFSM, request) - routerForwarder.expectNoMsg(50 millis) // we don't need the router when we already have the whole route - val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) - val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - } - - test("send with route prefix and retry") { _ => - val payFixture = createPaymentLifecycle() - import payFixture._ - - val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3, routePrefix = Seq(ChannelHop(a, b, update_ab), ChannelHop(b, c, update_bc))) - sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(c, d).copy(ignoreNodes = Set(a, b))) - val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - - routerForwarder.send(paymentFSM, RouteResponse(Route(defaultAmountMsat, Seq(ChannelHop(c, d, update_cd))) :: Nil)) - val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - - sender.send(paymentFSM, UpdateFailHtlc(randomBytes32, 0, randomBytes(Sphinx.FailurePacket.PacketLength))) - routerForwarder.expectMsg(defaultRouteRequest(c, d).copy(ignoreNodes = Set(a, b, c))) - val Transition(_, WAITING_FOR_PAYMENT_COMPLETE, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) - assert(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) - } - test("payment failed (route not found)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(f, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) @@ -208,6 +189,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (route too expensive)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5, routeParams = Some(RouteParams(randomize = false, 100 msat, 0.0, 20, CltvExpiryDelta(2016), None, MultiPartParams(10000 msat, 5)))) sender.send(paymentFSM, request) @@ -222,29 +204,30 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (unparsable failure)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(a, d)) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, _, ignoreNodes1, _, route) = paymentFSM.stateData - assert(ignoreNodes1.isEmpty) + val WaitingForComplete(_, _, cmd1, Nil, _, ignore1, route) = paymentFSM.stateData + assert(ignore1.nodes.isEmpty) register.expectMsg(ForwardShortId(channelId_ab, cmd1)) sender.send(paymentFSM, Relayer.ForwardRemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, randomBytes32), defaultOrigin, UpdateAddHtlc(ByteVector32.Zeroes, 0, defaultAmountMsat, defaultPaymentHash, defaultExpiry, TestConstants.emptyOnionPacket))) // unparsable message // then the payment lifecycle will ask for a new route excluding all intermediate nodes - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignoreNodes = Set(c))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set(c), Set.empty))) // let's simulate a response by the router with another route sender.send(paymentFSM, RouteResponse(route :: Nil)) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd2, _, _, ignoreNodes2, _, _) = paymentFSM.stateData - assert(ignoreNodes2 === Set(c)) + val WaitingForComplete(_, _, cmd2, _, _, ignore2, _) = paymentFSM.stateData + assert(ignore2.nodes === Set(c)) // and reply a 2nd time with an unparsable failure register.expectMsg(ForwardShortId(channelId_ab, cmd2)) sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) // unparsable message @@ -257,59 +240,62 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (local error)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, _, _, _, _) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, _, _, _) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) sender.send(paymentFSM, Status.Failure(AddHtlcFailed(ByteVector32.Zeroes, defaultPaymentHash, ChannelUnavailable(ByteVector32.Zeroes), Local(id, Some(paymentFSM.underlying.self)), None, None))) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, _, _, _, _) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, _, _, _) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) sender.send(paymentFSM, UpdateFailMalformedHtlc(ByteVector32.Zeroes, 0, randomBytes32, FailureMessageCodecs.BADONION)) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(a, d).copy(ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) } test("payment failed (TemporaryChannelFailure)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, route) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) val failure = TemporaryChannelFailure(update_bc) @@ -321,7 +307,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(update_bc) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(a, d)) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) routerForwarder.forward(routerFixture.router) // we allow 2 tries, so we send a 2nd request to the router assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(route.hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(Nil, RouteNotFound) :: Nil) @@ -330,16 +316,17 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (Update)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, route1) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) // we change the cltv expiry @@ -351,12 +338,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd2, _, sharedSecrets2, _, _, route2) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd2, _, sharedSecrets2, _, route2) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd2)) // we change the cltv expiry one more time @@ -371,7 +358,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // but it will still forward the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified_2) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) // this time the router can't find a route: game over @@ -379,9 +366,34 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } + test("payment failed (Update in last attempt)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + import cfg._ + + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 1) + sender.send(paymentFSM, request) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData + register.expectMsg(ForwardShortId(channelId_ab, cmd1)) + + // the node replies with a temporary failure containing the same update as the one we already have (likely a balance issue) + val failure = TemporaryChannelFailure(update_bc) + sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) + // we should temporarily exclude that channel + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c))) + routerForwarder.expectMsg(update_bc) + + // this was a single attempt payment + sender.expectMsgType[PaymentFailed] + } + test("payment failed (Update in assisted route)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ // we build an assisted route for channel bc and cd val assistedRoutes = Seq(Seq( @@ -393,11 +405,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(assistedRoutes = assistedRoutes)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, _) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) // we change the cltv expiry @@ -413,36 +425,64 @@ class PaymentLifecycleSpec extends BaseRouterSpec { ExtraHop(b, channelId_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, channelUpdate_bc_modified.cltvExpiryDelta), ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) )) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(assistedRoutes = assistedRoutes1)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes1)) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd2, _, _, _, _, _) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd2, _, _, _, _) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd2)) assert(cmd2.cltvExpiry > cmd1.cltvExpiry) } + test("payment failed (Update disabled in assisted route)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + import cfg._ + + // we build an assisted route for channel cd + val assistedRoutes = Seq(Seq(ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta))) + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 1, assistedRoutes = assistedRoutes) + sender.send(paymentFSM, request) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) + + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes)) + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData + register.expectMsg(ForwardShortId(channelId_ab, cmd1)) + + // we disable the channel + val channelUpdate_cd_disabled = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(42), update_cd.htlcMinimumMsat, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.htlcMaximumMsat.get, enable = false) + val failure = ChannelDisabled(channelUpdate_cd_disabled.messageFlags, channelUpdate_cd_disabled.channelFlags, channelUpdate_cd_disabled) + val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1) + sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)) + + routerForwarder.expectMsg(channelUpdate_cd_disabled) + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_cd.shortChannelId, c, d))) + } + def testPermanentFailure(router: ActorRef, failure: FailureMessage): Unit = { val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) - val WaitingForRoute(_, _, Nil, _, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) - val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _, route1) = paymentFSM.stateData + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData register.expectMsg(ForwardShortId(channelId_ab, cmd1)) sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) // payment lifecycle forwards the embedded channelUpdate to the router awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignoreChannels = Set(ChannelDesc(channelId_bc, b, c)))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_bc, b, c))))) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route @@ -463,6 +503,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment succeeded") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) @@ -569,9 +610,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { ) for ((failure, expectedNodes, expectedChannels) <- testCases) { - val (ignoreNodes, ignoreChannels) = PaymentFailure.updateIgnored(failure, Set.empty[PublicKey], Set.empty[ChannelDesc]) - assert(ignoreNodes === expectedNodes, failure) - assert(ignoreChannels === expectedChannels, failure) + val ignore = PaymentFailure.updateIgnored(failure, Ignore.empty) + assert(ignore.nodes === expectedNodes, failure) + assert(ignore.channels === expectedChannels, failure) } val failures = Seq( @@ -579,14 +620,15 @@ class PaymentLifecycleSpec extends BaseRouterSpec { RemoteFailure(route_abcd, Sphinx.DecryptedFailurePacket(b, UnknownNextPeer)), LocalFailure(route_abcd, new RuntimeException("fatal")) ) - val (ignoreNodes, ignoreChannels) = PaymentFailure.updateIgnored(failures, Set.empty[PublicKey], Set.empty[ChannelDesc]) - assert(ignoreNodes === Set(c)) - assert(ignoreChannels === Set(ChannelDesc(channelId_ab, a, b), ChannelDesc(channelId_bc, b, c))) + val ignore = PaymentFailure.updateIgnored(failures, Ignore.empty) + assert(ignore.nodes === Set(c)) + assert(ignore.channels === Set(ChannelDesc(channelId_ab, a, b), ChannelDesc(channelId_bc, b, c))) } test("disable database and events") { routerFixture => val payFixture = createPaymentLifecycle(storeInDb = false, publishEvent = false) import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3) sender.send(paymentFSM, request) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 61ba0f43e..34d7c8b64 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -28,7 +28,8 @@ import fr.acinq.eclair.payment.OutgoingPacket.{buildCommand, buildOnion, buildPa import fr.acinq.eclair.payment.relay.Origin._ import fr.acinq.eclair.payment.relay.Relayer._ import fr.acinq.eclair.payment.relay.{CommandBuffer, Relayer} -import fr.acinq.eclair.router.Router.{ChannelHop, GetNetworkStats, GetNetworkStatsResponse, NodeHop, TickComputeNetworkStats} +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.PreimageReceived +import fr.acinq.eclair.router.Router.{ChannelHop, Ignore, NodeHop} import fr.acinq.eclair.router.{Announcements, _} import fr.acinq.eclair.wire.Onion.{ChannelRelayTlvPayload, FinalLegacyPayload, FinalTlvPayload, PerHopPayload} import fr.acinq.eclair.wire._ @@ -177,9 +178,13 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { sender.send(relayer, ForwardAdd(add_ab2)) // A multi-part payment FSM should start to relay the payment. - router.expectMsg(GetNetworkStats) - router.send(router.lastSender, GetNetworkStatsResponse(None)) - router.expectMsg(TickComputeNetworkStats) + val routeRequest1 = router.expectMsgType[Router.RouteRequest] + assert(routeRequest1.source === b) + assert(routeRequest1.target === c) + assert(routeRequest1.amount === finalAmount) + assert(routeRequest1.allowMultiPart) + assert(routeRequest1.ignore === Ignore.empty) + router.send(router.lastSender, Router.RouteResponse(Router.Route(finalAmount, ChannelHop(b, c, channelUpdate_bc) :: Nil) :: Nil)) // first try val fwd1 = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] @@ -191,6 +196,10 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, HtlcValueTooHighInFlight(channelId_bc, UInt64(1000000000L), 1516977616L msat), origin1, Some(channelUpdate_bc), originalCommand = Some(fwd1.message)))) // second try + val routeRequest2 = router.expectMsgType[Router.RouteRequest] + assert(routeRequest2.ignore.channels.map(_.shortChannelId) === Set(channelUpdate_bc.shortChannelId)) + router.send(router.lastSender, Router.RouteResponse(Router.Route(finalAmount, ChannelHop(b, c, channelUpdate_bc) :: Nil) :: Nil)) + val fwd2 = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] assert(fwd2.shortChannelId === channelUpdate_bc.shortChannelId) assert(fwd2.message.upstream.asInstanceOf[Upstream.TrampolineRelayed].adds === Seq(add_ab1, add_ab2)) @@ -466,6 +475,9 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, InsufficientFunds(channelId_bc, origin.amountOut, 100 sat, 0 sat, 0 sat), origin, Some(channelUpdate_bc), None))) assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(TemporaryChannelFailure(channelUpdate_bc))) + sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, FeerateTooDifferent(channelId_bc, 1000, 300), origin, Some(channelUpdate_bc), None))) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(TemporaryChannelFailure(channelUpdate_bc))) + val channelUpdate_bc_disabled = channelUpdate_bc.copy(channelFlags = 2) sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, ChannelUnavailable(channelId_bc), origin, Some(channelUpdate_bc_disabled), None))) assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(ChannelDisabled(channelUpdate_bc_disabled.messageFlags, channelUpdate_bc_disabled.channelFlags, channelUpdate_bc_disabled))) @@ -527,7 +539,10 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val add_ab2 = UpdateAddHtlc(channelId_ab, 565, cmd2.amount, cmd2.paymentHash, cmd2.cltvExpiry, cmd2.onion) sender.send(relayer, ForwardAdd(add_ab1)) sender.send(relayer, ForwardAdd(add_ab2)) - router.expectMsg(GetNetworkStats) // A multi-part payment FSM is started to relay the payment downstream. + + // A multi-part payment FSM is started to relay the payment downstream. + val routeRequest = router.expectMsgType[Router.RouteRequest] + assert(routeRequest.allowMultiPart) // We simulate a fake htlc fulfill for the downstream channel. val payFSM = TestProbe() @@ -539,8 +554,9 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { } sender.send(relayer, forwardFulfill) - // the FSM responsible for the payment should receive the fulfill. + // the FSM responsible for the payment should receive the fulfill and emit a preimage event. payFSM.expectMsg(forwardFulfill) + system.actorSelection(relayer.path.child("node-relayer")).tell(PreimageReceived(paymentHash, preimage), payFSM.ref) // the payment should be immediately fulfilled upstream. val upstream1 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 562a7634f..1ceaa867f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -402,19 +402,25 @@ class RouterSpec extends BaseRouterSpec { val sender = TestProbe() // Via private channels. - sender.send(router, RouteRequest(a, h, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE)) + sender.send(router, RouteRequest(a, g, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE)) sender.expectMsgType[RouteResponse] - sender.send(router, RouteRequest(a, h, 50000000 msat, Long.MaxValue.msat)) - sender.expectMsg(Failure(RouteNotFound)) + sender.send(router, RouteRequest(a, g, 50000000 msat, Long.MaxValue.msat)) + sender.expectMsg(Failure(BalanceTooLow)) + sender.send(router, RouteRequest(a, g, 50000000 msat, Long.MaxValue.msat, allowMultiPart = true)) + sender.expectMsg(Failure(BalanceTooLow)) // Via public channels. - sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE)) + sender.send(router, RouteRequest(a, b, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE)) sender.expectMsgType[RouteResponse] val commitments1 = CommitmentsSpec.makeCommitments(10000000 msat, 20000000 msat, a, b, announceChannel = true) sender.send(router, LocalChannelUpdate(sender.ref, null, channelId_ab, b, Some(chan_ab), update_ab, commitments1)) - sender.send(router, RouteRequest(a, d, 12000000 msat, Long.MaxValue.msat)) - sender.expectMsg(Failure(RouteNotFound)) - sender.send(router, RouteRequest(a, d, 5000000 msat, Long.MaxValue.msat)) + sender.send(router, RouteRequest(a, b, 12000000 msat, Long.MaxValue.msat)) + sender.expectMsg(Failure(BalanceTooLow)) + sender.send(router, RouteRequest(a, b, 12000000 msat, Long.MaxValue.msat, allowMultiPart = true)) + sender.expectMsg(Failure(BalanceTooLow)) + sender.send(router, RouteRequest(a, b, 5000000 msat, Long.MaxValue.msat)) + sender.expectMsgType[RouteResponse] + sender.send(router, RouteRequest(a, b, 5000000 msat, Long.MaxValue.msat, allowMultiPart = true)) sender.expectMsgType[RouteResponse] } @@ -589,6 +595,21 @@ class RouterSpec extends BaseRouterSpec { assert(balances.contains(edge_ab.balance_opt)) assert(edge_ba.balance_opt === None) } + + { + // Private channels should also update the graph when HTLCs are relayed through them. + val balances = Set(33000000 msat, 5000000 msat) + val commitments = CommitmentsSpec.makeCommitments(33000000 msat, 5000000 msat, a, g, announceChannel = false) + sender.send(router, AvailableBalanceChanged(sender.ref, null, channelId_ag, commitments)) + sender.send(router, Symbol("data")) + val data = sender.expectMsgType[Data] + val channel_ag = data.privateChannels(channelId_ag) + assert(Set(channel_ag.meta.balance1, channel_ag.meta.balance2) === balances) + // And the graph should be updated too. + val edge_ag = data.graph.getEdge(ChannelDesc(channelId_ag, a, g)).get + assert(edge_ag.capacity == channel_ag.capacity) + assert(edge_ag.balance_opt === Some(33000000 msat)) + } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala index 1c0f93d4d..d85db2d40 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala @@ -407,7 +407,7 @@ class TransactionsSpec extends AnyFunSuite with Logging { """[^$]+""").r // this regex extracts htlc direction and amounts val htlcRegex = - """.*HTLC ([a-z]+) amount ([0-9]+).*""".r + """.*HTLC [0-9] ([a-z]+) amount ([0-9]+).*""".r val dustLimit = 546 sat case class TestSetup(name: String, dustLimit: Satoshi, spec: CommitmentSpec, expectedFee: Satoshi) @@ -422,7 +422,7 @@ class TransactionsSpec extends AnyFunSuite with Logging { } }).toSet TestSetup(name, dustLimit, CommitmentSpec(htlcs = htlcs, feeratePerKw = feerate_per_kw.toLong, toLocal = MilliSatoshi(to_local_msat.toLong), toRemote = MilliSatoshi(to_remote_msat.toLong)), Satoshi(fee.toLong)) - }) + }).toSeq // simple non-reg test making sure we are not missing tests assert(tests.size === 15, "there were 15 tests at ec99f893f320e8c88f564c1c8566f3454f0f1f5f") diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala index 21f97cf9c..5b5430443 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala @@ -42,6 +42,7 @@ trait ExtraDirectives extends Directives { val channelIdFormParam_opt = "channelId".as[Option[ByteVector32]](sha256HashUnmarshaller) val channelIdsFormParam_opt = "channelIds".as[Option[List[ByteVector32]]](sha256HashesUnmarshaller) val nodeIdFormParam_opt = "nodeId".as[Option[PublicKey]](publicKeyUnmarshaller) + val nodeIdsFormParam_opt = "nodeIds".as[Option[Set[PublicKey]]](publicKeysUnmarshaller) val paymentHashFormParam_opt = "paymentHash".as[Option[ByteVector32]](sha256HashUnmarshaller) val fromFormParam_opt = "from".as[Long] val toFormParam_opt = "to".as[Long] diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala index 9ea740b86..7d1cc3cd0 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala @@ -38,6 +38,10 @@ object FormParamExtractors { PublicKey(ByteVector.fromValidHex(str)) } + implicit val publicKeysUnmarshaller: Deserializer[Option[String], Option[Set[PublicKey]]] = strictDeserializer { bin => + bin.split(",").map(str => PublicKey(ByteVector.fromValidHex(str))).toSet + } + implicit val binaryDataUnmarshaller: Deserializer[Option[String], Option[ByteVector]] = strictDeserializer { str => ByteVector.fromValidHex(str) } diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index df7a5f349..9ac7779ac 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -115,7 +115,12 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("peers") { - complete(eclairApi.peersInfo()) + complete(eclairApi.peers()) + } ~ + path("nodes") { + formFields(nodeIdsFormParam_opt) { nodeIds_opt => + complete(eclairApi.nodes(nodeIds_opt)) + } } ~ path("channels") { formFields(nodeIdFormParam_opt) { toRemoteNodeId_opt => @@ -127,9 +132,6 @@ trait Service extends ExtraDirectives with Logging { complete(eclairApi.channelInfo(channelIdentifier)) } } ~ - path("allnodes") { - complete(eclairApi.allNodes()) - } ~ path("allchannels") { complete(eclairApi.allChannels()) } ~ @@ -228,7 +230,9 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("channelstats") { - complete(eclairApi.channelStats()) + formFields(fromFormParam_opt.?, toFormParam_opt.?) { (from_opt, to_opt) => + complete(eclairApi.channelStats(from_opt, to_opt)) + } } ~ path("usablebalances") { complete(eclairApi.usableBalances()) diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 36e737d53..9a1af6b37 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -123,7 +123,7 @@ class ApiServiceSpec extends AnyFunSuiteLike with ScalatestRouteTest with RouteT test("'peers' should ask the switchboard for current known peers") { val mockEclair = mock[Eclair] val service = new MockService(mockEclair) - mockEclair.peersInfo()(any[Timeout]) returns Future.successful(List( + mockEclair.peers()(any[Timeout]) returns Future.successful(List( PeerInfo( nodeId = aliceNodeId, state = "CONNECTED", @@ -142,7 +142,7 @@ class ApiServiceSpec extends AnyFunSuiteLike with ScalatestRouteTest with RouteT assert(handled) assert(status == OK) val response = responseAs[String] - mockEclair.peersInfo()(any[Timeout]).wasCalled(once) + mockEclair.peers()(any[Timeout]).wasCalled(once) matchTestJson("peers", response) } }