diff --git a/eclair-core/eclair-cli b/eclair-core/eclair-cli index 6d87b1aaf..c9286529c 100755 --- a/eclair-core/eclair-cli +++ b/eclair-core/eclair-cli @@ -33,7 +33,7 @@ and COMMAND is one of the available commands: - connect - disconnect - peers - - allnodes + - nodes - audit === Channel === diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 14fc11386..7cce0a2d7 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -98,9 +98,10 @@ eclair { claim-main = 12 // target for the claim main transaction (tx that spends main channel output back to wallet) } - // maximum local vs remote feerate mismatch; 1.0 means 100% - // actual check is abs((local feerate - remote fee rate) / (local fee rate + remote fee rate)/2) > fee rate mismatch - max-feerate-mismatch = 1.56 // will allow remote fee rates up to 8x bigger or smaller than our local fee rate + feerate-tolerance { + ratio-low = 0.5 // will allow remote fee rates as low as half our local feerate + ratio-high = 10.0 // will allow remote fee rates as high as 10 times our local feerate + } close-on-offline-feerate-mismatch = true // do not change this unless you know what you are doing // funder will send an UpdateFee message if the difference between current commitment fee and actual current network fee is greater diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index e9b346cce..2dcba4e93 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -86,7 +86,9 @@ trait Eclair { def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] - def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] + def peers()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] + + def nodes(nodeIds_opt: Option[Set[PublicKey]] = None)(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] def receive(description: String, amount_opt: Option[MilliSatoshi], expire_opt: Option[Long], fallbackAddress_opt: Option[String], paymentPreimage_opt: Option[ByteVector32])(implicit timeout: Timeout): Future[PaymentRequest] @@ -108,7 +110,7 @@ trait Eclair { def networkFees(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[NetworkFee]] - def channelStats()(implicit timeout: Timeout): Future[Seq[Stats]] + def channelStats(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[Stats]] def networkStats()(implicit timeout: Timeout): Future[Option[NetworkStats]] @@ -118,8 +120,6 @@ trait Eclair { def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] - def allNodes()(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] - def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] def allUpdates(nodeId_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[ChannelUpdate]] @@ -174,11 +174,17 @@ class EclairImpl(appKit: Kit) extends Eclair { sendToChannels[ChannelCommandResponse](channels, CMD_UPDATE_RELAY_FEE(feeBaseMsat, feeProportionalMillionths)) } - override def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] = for { + override def peers()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] = for { peers <- (appKit.switchboard ? Symbol("peers")).mapTo[Iterable[ActorRef]] peerinfos <- Future.sequence(peers.map(peer => (peer ? GetPeerInfo).mapTo[PeerInfo])) } yield peerinfos + override def nodes(nodeIds_opt: Option[Set[PublicKey]])(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] = { + (appKit.router ? Symbol("nodes")) + .mapTo[Iterable[NodeAnnouncement]] + .map(_.filter(n => nodeIds_opt.forall(_.contains(n.nodeId)))) + } + override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = toRemoteNode_opt match { case Some(pk) => for { channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys) @@ -194,8 +200,6 @@ class EclairImpl(appKit: Kit) extends Eclair { sendToChannel[RES_GETINFO](channel, CMD_GETINFO) } - override def allNodes()(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] = (appKit.router ? Symbol("nodes")).mapTo[Iterable[NodeAnnouncement]] - override def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] = { (appKit.router ? Symbol("channels")).mapTo[Iterable[ChannelAnnouncement]].map(_.map(c => ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2))) } @@ -312,7 +316,10 @@ class EclairImpl(appKit: Kit) extends Eclair { Future(appKit.nodeParams.db.audit.listNetworkFees(filter.from, filter.to)) } - override def channelStats()(implicit timeout: Timeout): Future[Seq[Stats]] = Future(appKit.nodeParams.db.audit.stats) + override def channelStats(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[Stats]] = { + val filter = getDefaultTimestampFilters(from_opt, to_opt) + Future(appKit.nodeParams.db.audit.stats(filter.from, filter.to)) + } override def networkStats()(implicit timeout: Timeout): Future[Option[NetworkStats]] = (appKit.router ? GetNetworkStats).mapTo[GetNetworkStatsResponse].map(_.stats) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 12d4dcd07..6e2fa1460 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -27,7 +27,7 @@ import com.google.common.io.Files import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi} import fr.acinq.eclair.NodeParams.WatcherType -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, OnChainFeeConf} +import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeerateTolerance, OnChainFeeConf} import fr.acinq.eclair.channel.Channel import fr.acinq.eclair.crypto.KeyManager import fr.acinq.eclair.db._ @@ -143,7 +143,9 @@ object NodeParams { "update-fee_min-diff-ratio" -> "on-chain-fees.update-fee-min-diff-ratio", // v0.3.3 "global-features" -> "features", - "local-features" -> "features" + "local-features" -> "features", + // v0.4.1 + "on-chain-fees.max-feerate-mismatch" -> "on-chain-fees.feerate-tolerance.ratio-low / on-chain-fees.feerate-tolerance.ratio-high" ) deprecatedKeyPaths.foreach { case (old, new_) => require(!config.hasPath(old), s"configuration key '$old' has been replaced by '$new_'") @@ -246,7 +248,7 @@ object NodeParams { onChainFeeConf = OnChainFeeConf( feeTargets = feeTargets, feeEstimator = feeEstimator, - maxFeerateMismatch = config.getDouble("on-chain-fees.max-feerate-mismatch"), + maxFeerateMismatch = FeerateTolerance(config.getDouble("on-chain-fees.feerate-tolerance.ratio-low"), config.getDouble("on-chain-fees.feerate-tolerance.ratio-high")), closeOnOfflineMismatch = config.getBoolean("on-chain-fees.close-on-offline-feerate-mismatch"), updateFeeMinDiffRatio = config.getDouble("on-chain-fees.update-fee-min-diff-ratio") ), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala index bdd0a24e1..eff88c1ee 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/fee/FeeEstimator.scala @@ -18,12 +18,14 @@ package fr.acinq.eclair.blockchain.fee trait FeeEstimator { - def getFeeratePerKb(target: Int) : Long + def getFeeratePerKb(target: Int): Long - def getFeeratePerKw(target: Int) : Long + def getFeeratePerKw(target: Int): Long } case class FeeTargets(fundingBlockTarget: Int, commitmentBlockTarget: Int, mutualCloseBlockTarget: Int, claimMainBlockTarget: Int) -case class OnChainFeeConf(feeTargets: FeeTargets, feeEstimator: FeeEstimator, maxFeerateMismatch: Double, closeOnOfflineMismatch: Boolean, updateFeeMinDiffRatio: Double) \ No newline at end of file +case class FeerateTolerance(ratioLow: Double, ratioHigh: Double) + +case class OnChainFeeConf(feeTargets: FeeTargets, feeEstimator: FeeEstimator, maxFeerateMismatch: FeerateTolerance, closeOnOfflineMismatch: Boolean, updateFeeMinDiffRatio: Double) \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 814062f60..b164661f4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -35,10 +35,8 @@ import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.transactions._ import fr.acinq.eclair.wire._ import scodec.bits.ByteVector -import ChannelVersion._ import scala.collection.immutable.Queue -import scala.compat.Platform import scala.concurrent.ExecutionContext import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -641,7 +639,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId handleCommandError(AddHtlcFailed(d.channelId, c.paymentHash, error, origin(c), Some(d.channelUpdate), Some(c)), c) case Event(c: CMD_ADD_HTLC, d: DATA_NORMAL) => - Commitments.sendAdd(d.commitments, c, origin(c), nodeParams.currentBlockHeight) match { + Commitments.sendAdd(d.commitments, c, origin(c), nodeParams.currentBlockHeight, nodeParams.onChainFeeConf) match { case Success((commitments1, add)) => if (c.commit) self ! CMD_SIGN context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.shortChannelId, commitments1)) @@ -650,7 +648,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(add: UpdateAddHtlc, d: DATA_NORMAL) => - Commitments.receiveAdd(d.commitments, add) match { + Commitments.receiveAdd(d.commitments, add, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(add)) } @@ -722,7 +720,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(fee: UpdateFee, d: DATA_NORMAL) => - Commitments.receiveFee(d.commitments, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets, fee, nodeParams.onChainFeeConf.maxFeerateMismatch) match { + Commitments.receiveFee(d.commitments, fee, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(fee)) } @@ -810,14 +808,14 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case Event(c@CMD_CLOSE(localScriptPubKey_opt), d: DATA_NORMAL) => val localScriptPubKey = localScriptPubKey_opt.getOrElse(d.commitments.localParams.defaultFinalScriptPubKey) - if (d.localShutdown.isDefined) + if (d.localShutdown.isDefined) { handleCommandError(ClosingAlreadyInProgress(d.channelId), c) - else if (Commitments.localHasUnsignedOutgoingHtlcs(d.commitments)) - // TODO: simplistic behavior, we could also sign-then-close - handleCommandError(CannotCloseWithUnsignedOutgoingHtlcs(d.channelId), c) - else if (!Closing.isValidFinalScriptPubkey(localScriptPubKey)) + } else if (Commitments.localHasUnsignedOutgoingHtlcs(d.commitments)) { + // TODO: simplistic behavior, we could also sign-then-close + handleCommandError(CannotCloseWithUnsignedOutgoingHtlcs(d.channelId), c) + } else if (!Closing.isValidFinalScriptPubkey(localScriptPubKey)) { handleCommandError(InvalidFinalScript(d.channelId), c) - else { + } else { val shutdown = Shutdown(d.channelId, localScriptPubKey) handleCommandSuccess(sender, d.copy(localShutdown = Some(shutdown))) storing() sending shutdown } @@ -1067,7 +1065,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } case Event(fee: UpdateFee, d: DATA_SHUTDOWN) => - Commitments.receiveFee(d.commitments, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets, fee, nodeParams.onChainFeeConf.maxFeerateMismatch) match { + Commitments.receiveFee(d.commitments, fee, nodeParams.onChainFeeConf) match { case Success(commitments1) => stay using d.copy(commitments = commitments1) case Failure(cause) => handleLocalError(cause, d, Some(fee)) } @@ -1796,13 +1794,18 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleCurrentFeerate(c: CurrentFeerates, d: HasCommitments) = { val networkFeeratePerKw = c.feeratesPerKw.feePerBlock(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) val currentFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw - d.commitments.localParams.isFunder match { - case true if Helpers.shouldUpdateFee(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.updateFeeMinDiffRatio) => - self ! CMD_UPDATE_FEE(networkFeeratePerKw, commit = true) - stay - case false if Helpers.isFeeDiffTooHigh(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) => - handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = networkFeeratePerKw, remoteFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw), d, Some(c)) - case _ => stay + val shouldUpdateFee = d.commitments.localParams.isFunder && + Helpers.shouldUpdateFee(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.updateFeeMinDiffRatio) + val shouldClose = !d.commitments.localParams.isFunder && + Helpers.isFeeDiffTooHigh(networkFeeratePerKw, currentFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) && + d.commitments.hasPendingOrProposedHtlcs // we close only if we have HTLCs potentially at risk + if (shouldUpdateFee) { + self ! CMD_UPDATE_FEE(networkFeeratePerKw, commit = true) + stay + } else if (shouldClose) { + handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = networkFeeratePerKw, remoteFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw), d, Some(c)) + } else { + stay } } @@ -1816,13 +1819,16 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleOfflineFeerate(c: CurrentFeerates, d: HasCommitments) = { val networkFeeratePerKw = c.feeratesPerKw.feePerBlock(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) val currentFeeratePerKw = d.commitments.localCommit.spec.feeratePerKw - // if the fees are too high we risk to not be able to confirm our current commitment - if (networkFeeratePerKw > currentFeeratePerKw && Helpers.isFeeDiffTooHigh(currentFeeratePerKw, networkFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) { + // if the network fees are too high we risk to not be able to confirm our current commitment + val shouldClose = networkFeeratePerKw > currentFeeratePerKw && + Helpers.isFeeDiffTooHigh(networkFeeratePerKw, currentFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch) && + d.commitments.hasPendingOrProposedHtlcs // we close only if we have HTLCs potentially at risk + if (shouldClose) { if (nodeParams.onChainFeeConf.closeOnOfflineMismatch) { - log.warning(s"closing OFFLINE ${d.channelId} due to fee mismatch: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") + log.warning(s"closing OFFLINE channel due to fee mismatch: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") handleLocalError(FeerateTooDifferent(d.channelId, localFeeratePerKw = currentFeeratePerKw, remoteFeeratePerKw = networkFeeratePerKw), d, Some(c)) } else { - log.warning(s"channel ${d.channelId} is OFFLINE but its fee mismatch is over the threshold: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") + log.warning(s"channel is OFFLINE but its fee mismatch is over the threshold: currentFeeratePerKw=$currentFeeratePerKw networkFeeratePerKw=$networkFeeratePerKw") stay } } else { @@ -2117,7 +2123,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId val remotePerCommitmentPoint = d.remoteChannelReestablish.myCurrentPerCommitmentPoint val remoteCommitPublished = Helpers.Closing.claimRemoteCommitMainOutput(keyManager, d.commitments, remotePerCommitmentPoint, commitTx, nodeParams.onChainFeeConf.feeEstimator, nodeParams.onChainFeeConf.feeTargets) val nextData = DATA_CLOSING(d.commitments, fundingTx = None, waitingSince = now, Nil, futureRemoteCommitPublished = Some(remoteCommitPublished)) - goto(CLOSING) using nextData storing() calling(doPublish(remoteCommitPublished)) + goto(CLOSING) using nextData storing() calling (doPublish(remoteCommitPublished)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index 81c9c94af..451348ece 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -17,10 +17,9 @@ package fr.acinq.eclair.channel import akka.event.LoggingAdapter -import fr.acinq.eclair.channel.ChannelVersion._ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, sha256} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto} -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets} +import fr.acinq.eclair.blockchain.fee.OnChainFeeConf import fr.acinq.eclair.channel.Monitoring.Metrics import fr.acinq.eclair.crypto.{Generators, KeyManager, ShaChain, Sphinx} import fr.acinq.eclair.payment.relay.{Origin, Relayer} @@ -36,7 +35,9 @@ import scala.util.{Failure, Success, Try} case class LocalChanges(proposed: List[UpdateMessage], signed: List[UpdateMessage], acked: List[UpdateMessage]) { def all: List[UpdateMessage] = proposed ++ signed ++ acked } -case class RemoteChanges(proposed: List[UpdateMessage], acked: List[UpdateMessage], signed: List[UpdateMessage]) +case class RemoteChanges(proposed: List[UpdateMessage], acked: List[UpdateMessage], signed: List[UpdateMessage]) { + def all: List[UpdateMessage] = proposed ++ signed ++ acked +} case class Changes(ourChanges: LocalChanges, theirChanges: RemoteChanges) case class HtlcTxAndSigs(txinfo: TransactionWithInputInfo, localSig: ByteVector64, remoteSig: ByteVector64) case class PublishableTxs(commitTx: CommitTx, htlcTxsAndSigs: List[HtlcTxAndSigs]) @@ -68,6 +69,10 @@ case class Commitments(channelVersion: ChannelVersion, def hasNoPendingHtlcs: Boolean = localCommit.spec.htlcs.isEmpty && remoteCommit.spec.htlcs.isEmpty && remoteNextCommitInfo.isRight + def hasPendingOrProposedHtlcs: Boolean = !hasNoPendingHtlcs || + localChanges.all.exists(_.isInstanceOf[UpdateAddHtlc]) || + remoteChanges.all.exists(_.isInstanceOf[UpdateAddHtlc]) + def timedOutOutgoingHtlcs(blockheight: Long): Set[UpdateAddHtlc] = { def expired(add: UpdateAddHtlc) = blockheight >= add.cltvExpiry.toLong @@ -171,7 +176,7 @@ object Commitments { * @param cmd add HTLC command * @return either Left(failure, error message) where failure is a failure message (see BOLT #4 and the Failure Message class) or Right((new commitments, updateAddHtlc) */ - def sendAdd(commitments: Commitments, cmd: CMD_ADD_HTLC, origin: Origin, blockHeight: Long): Try[(Commitments, UpdateAddHtlc)] = { + def sendAdd(commitments: Commitments, cmd: CMD_ADD_HTLC, origin: Origin, blockHeight: Long, feeConf: OnChainFeeConf): Try[(Commitments, UpdateAddHtlc)] = { // our counterparty needs a reasonable amount of time to pull the funds from downstream before we can get refunded (see BOLT 2 and BOLT 11 for a calculation and rationale) val minExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(blockHeight) if (cmd.cltvExpiry < minExpiry) { @@ -189,6 +194,13 @@ object Commitments { return Failure(HtlcValueTooSmall(commitments.channelId, minimum = htlcMinimum, actual = cmd.amount)) } + // we allowed mismatches between our feerates and our remote's as long as commitments didn't contain any HTLC at risk + // we need to verify that we're not disagreeing on feerates anymore before offering new HTLCs + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, commitments.localCommit.spec.feeratePerKw, feeConf.maxFeerateMismatch)) { + return Failure(FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = commitments.localCommit.spec.feeratePerKw)) + } + // let's compute the current commitment *as seen by them* with this change taken into account val add = UpdateAddHtlc(commitments.channelId, commitments.localNextHtlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) // we increment the local htlc index and add an entry to the origins map @@ -230,7 +242,7 @@ object Commitments { Success(commitments1, add) } - def receiveAdd(commitments: Commitments, add: UpdateAddHtlc): Try[Commitments] = Try { + def receiveAdd(commitments: Commitments, add: UpdateAddHtlc, feeConf: OnChainFeeConf): Try[Commitments] = Try { if (add.id != commitments.remoteNextHtlcId) { throw UnexpectedHtlcId(commitments.channelId, expected = commitments.remoteNextHtlcId, actual = add.id) } @@ -241,6 +253,13 @@ object Commitments { throw HtlcValueTooSmall(commitments.channelId, minimum = htlcMinimum, actual = add.amountMsat) } + // we allowed mismatches between our feerates and our remote's as long as commitments didn't contain any HTLC at risk + // we need to verify that we're not disagreeing on feerates anymore before accepting new HTLCs + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, commitments.localCommit.spec.feeratePerKw, feeConf.maxFeerateMismatch)) { + throw FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = commitments.localCommit.spec.feeratePerKw) + } + // let's compute the current commitment *as seen by us* including this change val commitments1 = addRemoteProposal(commitments, add).copy(remoteNextHtlcId = commitments.remoteNextHtlcId + 1) val reduced = CommitmentSpec.reduce(commitments1.localCommit.spec, commitments1.localChanges.acked, commitments1.remoteChanges.proposed) @@ -390,16 +409,16 @@ object Commitments { } } - def receiveFee(commitments: Commitments, feeEstimator: FeeEstimator, feeTargets: FeeTargets, fee: UpdateFee, maxFeerateMismatch: Double)(implicit log: LoggingAdapter): Try[Commitments] = { + def receiveFee(commitments: Commitments, fee: UpdateFee, feeConf: OnChainFeeConf)(implicit log: LoggingAdapter): Try[Commitments] = { if (commitments.localParams.isFunder) { Failure(FundeeCannotSendUpdateFee(commitments.channelId)) } else if (fee.feeratePerKw < fr.acinq.eclair.MinimumFeeratePerKw) { Failure(FeerateTooSmall(commitments.channelId, remoteFeeratePerKw = fee.feeratePerKw)) } else { Metrics.RemoteFeeratePerKw.withoutTags().record(fee.feeratePerKw) - val localFeeratePerKw = feeEstimator.getFeeratePerKw(target = feeTargets.commitmentBlockTarget) + val localFeeratePerKw = feeConf.feeEstimator.getFeeratePerKw(target = feeConf.feeTargets.commitmentBlockTarget) log.info("remote feeratePerKw={}, local feeratePerKw={}, ratio={}", fee.feeratePerKw, localFeeratePerKw, fee.feeratePerKw.toDouble / localFeeratePerKw) - if (Helpers.isFeeDiffTooHigh(fee.feeratePerKw, localFeeratePerKw, maxFeerateMismatch)) { + if (Helpers.isFeeDiffTooHigh(localFeeratePerKw, fee.feeratePerKw, feeConf.maxFeerateMismatch) && commitments.hasPendingOrProposedHtlcs) { Failure(FeerateTooDifferent(commitments.channelId, localFeeratePerKw = localFeeratePerKw, remoteFeeratePerKw = fee.feeratePerKw)) } else { // NB: we check that the funder can afford this new fee even if spec allows to do it at next signature @@ -453,6 +472,7 @@ object Commitments { // NB: IN/OUT htlcs are inverted because this is the remote commit log.info(s"built remote commit number=${remoteCommit.index + 1} toLocalMsat=${spec.toLocal.toLong} toRemoteMsat=${spec.toRemote.toLong} htlc_in={} htlc_out={} feeratePerKw=${spec.feeratePerKw} txid=${remoteCommitTx.tx.txid} tx={}", spec.htlcs.collect(outgoing).map(_.id).mkString(","), spec.htlcs.collect(incoming).map(_.id).mkString(","), remoteCommitTx.tx) + Metrics.recordHtlcsInFlight(spec, remoteCommit.spec) // don't sign if they don't get paid val commitSig = CommitSig( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index feef8ea0b..3f3f3f55d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, ripemd160, sha256} import fr.acinq.bitcoin.Script._ import fr.acinq.bitcoin._ import fr.acinq.eclair.blockchain.EclairWallet -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets} +import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeerateTolerance} import fr.acinq.eclair.channel.Channel.REFRESH_CHANNEL_UPDATE_INTERVAL import fr.acinq.eclair.crypto.{Generators, KeyManager} import fr.acinq.eclair.db.ChannelsDb @@ -31,7 +31,6 @@ import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.transactions._ import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, ShortChannelId, addressToPublicKeyScript, _} -import fr.acinq.eclair.channel.ChannelVersion.USE_STATIC_REMOTEKEY_BIT import scodec.bits.ByteVector import scala.concurrent.Await @@ -129,7 +128,7 @@ object Helpers { } val localFeeratePerKw = nodeParams.onChainFeeConf.feeEstimator.getFeeratePerKw(target = nodeParams.onChainFeeConf.feeTargets.commitmentBlockTarget) - if (isFeeDiffTooHigh(open.feeratePerKw, localFeeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) throw FeerateTooDifferent(open.temporaryChannelId, localFeeratePerKw, open.feeratePerKw) + if (isFeeDiffTooHigh(localFeeratePerKw, open.feeratePerKw, nodeParams.onChainFeeConf.maxFeerateMismatch)) throw FeerateTooDifferent(open.temporaryChannelId, localFeeratePerKw, open.feeratePerKw) // only enforce dust limit check on mainnet if (nodeParams.chainHash == Block.LivenetGenesisBlock.hash) { if (open.dustLimitSatoshis < Channel.MIN_DUSTLIMIT) throw DustLimitTooSmall(open.temporaryChannelId, open.dustLimitSatoshis, Channel.MIN_DUSTLIMIT) @@ -189,25 +188,20 @@ object Helpers { } /** - * @param referenceFeePerKw reference fee rate per kiloweight - * @param currentFeePerKw current fee rate per kiloweight - * @return the "normalized" difference between i.e local and remote fee rate: |reference - current| / avg(current, reference) + * To avoid spamming our peers with fee updates every time there's a small variation, we only update the fee when the + * difference exceeds a given ratio (updateFeeMinDiffRatio). */ - def feeRateMismatch(referenceFeePerKw: Long, currentFeePerKw: Long): Double = - Math.abs((2.0 * (referenceFeePerKw - currentFeePerKw)) / (currentFeePerKw + referenceFeePerKw)) - - def shouldUpdateFee(commitmentFeeratePerKw: Long, networkFeeratePerKw: Long, updateFeeMinDiffRatio: Double): Boolean = - feeRateMismatch(networkFeeratePerKw, commitmentFeeratePerKw) > updateFeeMinDiffRatio + def shouldUpdateFee(currentFeeratePerKw: Long, nextFeeratePerKw: Long, updateFeeMinDiffRatio: Double): Boolean = + currentFeeratePerKw == 0 || Math.abs((currentFeeratePerKw - nextFeeratePerKw).toDouble / currentFeeratePerKw) > updateFeeMinDiffRatio /** - * @param referenceFeePerKw reference fee rate per kiloweight - * @param currentFeePerKw current fee rate per kiloweight - * @param maxFeerateMismatchRatio maximum fee rate mismatch ratio - * @return true if the difference between current and reference fee rates is too high. - * the actual check is |reference - current| / avg(current, reference) > mismatch ratio + * @param referenceFeePerKw reference fee rate per kiloweight + * @param currentFeePerKw current fee rate per kiloweight + * @param maxFeerateMismatch maximum fee rate mismatch tolerated + * @return true if the difference between proposed and reference fee rates is too high. */ - def isFeeDiffTooHigh(referenceFeePerKw: Long, currentFeePerKw: Long, maxFeerateMismatchRatio: Double): Boolean = - feeRateMismatch(referenceFeePerKw, currentFeePerKw) > maxFeerateMismatchRatio + def isFeeDiffTooHigh(referenceFeePerKw: Long, currentFeePerKw: Long, maxFeerateMismatch: FeerateTolerance): Boolean = + currentFeePerKw < referenceFeePerKw * maxFeerateMismatch.ratioLow || referenceFeePerKw * maxFeerateMismatch.ratioHigh < currentFeePerKw /** * @param remoteFeeratePerKw remote fee rate per kiloweight @@ -645,9 +639,9 @@ object Helpers { ) case _ => claimRemoteCommitMainOutput(keyManager, commitments, remoteCommit.remotePerCommitmentPoint, tx, feeEstimator, feeTargets).copy( - claimHtlcSuccessTxs = txes.toList.collect { case c: ClaimHtlcSuccessTx => c.tx }, - claimHtlcTimeoutTxs = txes.toList.collect { case c: ClaimHtlcTimeoutTx => c.tx } - ) + claimHtlcSuccessTxs = txes.toList.collect { case c: ClaimHtlcSuccessTx => c.tx }, + claimHtlcTimeoutTxs = txes.toList.collect { case c: ClaimHtlcTimeoutTx => c.tx } + ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala index f1c9db6b0..11fdec12c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Monitoring.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.channel +import fr.acinq.eclair.transactions.{CommitmentSpec, DirectedHtlc} import kamon.Kamon import kamon.tag.TagSet @@ -25,15 +26,35 @@ object Monitoring { val ChannelsCount = Kamon.gauge("channels.count") val ChannelErrors = Kamon.counter("channels.errors") val ChannelLifecycleEvents = Kamon.counter("channels.lifecycle") + val HtlcsInFlight = Kamon.histogram("channels.htlc-in-flight", "Per-channel HTLCs in flight") + val HtlcsInFlightGlobal = Kamon.gauge("channels.htlc-in-flight-global", "Global HTLCs in flight across all channels") + val HtlcValueInFlight = Kamon.histogram("channels.htlc-value-in-flight", "Per-channel HTLC value in flight") + val HtlcValueInFlightGlobal = Kamon.gauge("channels.htlc-value-in-flight-global", "Global HTLC value in flight across all channels") val LocalFeeratePerKw = Kamon.gauge("channels.local-feerate-per-kw") val RemoteFeeratePerKw = Kamon.histogram("channels.remote-feerate-per-kw") + + def recordHtlcsInFlight(remoteSpec: CommitmentSpec, previousRemoteSpec: CommitmentSpec): Unit = { + for (direction <- Tags.Directions.Incoming :: Tags.Directions.Outgoing :: Nil) { + // NB: IN/OUT htlcs are inverted because this is the remote commit + val filter = if (direction == Tags.Directions.Incoming) DirectedHtlc.outgoing else DirectedHtlc.incoming + // NB: we need the `toSeq` because otherwise duplicate amounts would be removed (since htlcs are sets) + val htlcs = remoteSpec.htlcs.collect(filter).toSeq.map(_.amountMsat) + val previousHtlcs = previousRemoteSpec.htlcs.collect(filter).toSeq.map(_.amountMsat) + HtlcsInFlight.withTag(Tags.Direction, direction).record(htlcs.length) + HtlcsInFlightGlobal.withTag(Tags.Direction, direction).increment(htlcs.length - previousHtlcs.length) + val (value, previousValue) = (htlcs.sum.truncateToSatoshi.toLong, previousHtlcs.sum.truncateToSatoshi.toLong) + HtlcValueInFlight.withTag(Tags.Direction, direction).record(value) + HtlcValueInFlightGlobal.withTag(Tags.Direction, direction).increment(value - previousValue) + } + } } object Tags { - val Event = TagSet.Empty - val Fatal = TagSet.Empty - val Origin = TagSet.Empty - val State = TagSet.Empty + val Direction = "direction" + val Event = "event" + val Fatal = "fatal" + val Origin = "origin" + val State = "state" object Events { val Created = "created" @@ -46,6 +67,11 @@ object Monitoring { val Remote = "remote" } + object Directions { + val Incoming = "incoming" + val Outgoing = "outgoing" + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala new file mode 100644 index 000000000..38bbcb272 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Monitoring.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2020 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.crypto + +import kamon.Kamon + +object Monitoring { + + object Metrics { + val OnionPayloadFormat = Kamon.counter("crypto.sphinx.onion-payload-format") + } + + object Tags { + val LegacyOnion = "legacy" + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index d18bcf1e2..c96942e57 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, Crypto} +import fr.acinq.eclair.crypto.Monitoring.{Metrics, Tags} import fr.acinq.eclair.wire import fr.acinq.eclair.wire.{FailureMessage, FailureMessageCodecs, Onion, OnionCodecs} import grizzled.slf4j.Logging @@ -88,11 +89,13 @@ object Sphinx extends Logging { case 0 => // The 1.0 BOLT spec used 65-bytes frames inside the onion payload. // The first byte of the frame (called `realm`) is set to 0x00, followed by 32 bytes of per-hop data, followed by a 32-bytes mac. + Metrics.OnionPayloadFormat.withTag(Tags.LegacyOnion, value = true).increment() 65 case _ => // The 1.1 BOLT spec changed the frame format to use variable-length per-hop payloads. // The first bytes contain a varint encoding the length of the payload data (not including the trailing mac). // Since messages are always smaller than 65535 bytes, this varint will either be 1 or 3 bytes long. + Metrics.OnionPayloadFormat.withTag(Tags.LegacyOnion, value = false).increment() MacLength + OnionCodecs.payloadLengthDecoder.decode(payload.bits).require.value.toInt } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala index c3a1b9eea..978f16fce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala @@ -45,7 +45,7 @@ trait AuditDb extends Closeable { def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] - def stats: Seq[Stats] + def stats(from: Long, to: Long): Seq[Stats] } @@ -53,4 +53,4 @@ case class ChannelLifecycleEvent(channelId: ByteVector32, remoteNodeId: PublicKe case class NetworkFee(remoteNodeId: PublicKey, channelId: ByteVector32, txId: ByteVector32, fee: Satoshi, txType: String, timestamp: Long) -case class Stats(channelId: ByteVector32, avgPaymentAmount: Satoshi, paymentCount: Int, relayFee: Satoshi, networkFee: Satoshi) +case class Stats(channelId: ByteVector32, direction: String, avgPaymentAmount: Satoshi, paymentCount: Int, relayFee: Satoshi, networkFee: Satoshi) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 67566c8c7..293c20b8e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -298,35 +298,46 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { q } - override def stats: Seq[Stats] = { - val networkFees = listNetworkFees(0, System.currentTimeMillis + 1).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => + override def stats(from: Long, to: Long): Seq[Stats] = { + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) } - val relayed = listRelayed(0, System.currentTimeMillis + 1).foldLeft(Map.empty[ByteVector32, Seq[PaymentRelayed]]) { case (relayedByChannelId, e) => - val relayedTo = e match { - case c: ChannelPaymentRelayed => Set(c.toChannelId) - case t: TrampolinePaymentRelayed => t.outgoing.map(_.channelId).toSet + case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) => + // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. + val current = e match { + case c: ChannelPaymentRelayed => Map( + c.fromChannelId -> (Relayed(c.amountIn, 0 msat, "IN") +: previous.getOrElse(c.fromChannelId, Nil)), + c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)) + ) + case t: TrampolinePaymentRelayed => + // We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were + // sent from/to the same channel, we group them). + val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq + val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) => + val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels + (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) + }.toSeq + (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } } - val updated = relayedTo.map(channelId => (channelId, relayedByChannelId.getOrElse(channelId, Nil) :+ e)).toMap - relayedByChannelId ++ updated + previous ++ current } // Channels opened by our peers won't have any entry in the network_fees table, but we still want to compute stats for them. val allChannels = networkFees.keySet ++ relayed.keySet - allChannels.map(channelId => { + allChannels.toSeq.flatMap(channelId => { val networkFee = networkFees.getOrElse(channelId, 0 sat) - val r = relayed.getOrElse(channelId, Nil) - val paymentCount = r.length - if (paymentCount == 0) { - Stats(channelId, 0 sat, 0, 0 sat, networkFee) - } else { - val avgPaymentAmount = r.map(_.amountOut).sum / paymentCount - val relayFee = r.map { - case c: ChannelPaymentRelayed => c.amountIn - c.amountOut - case t: TrampolinePaymentRelayed => (t.amountIn - t.amountOut) * t.outgoing.count(_.channelId == channelId) / t.outgoing.length - }.sum - Stats(channelId, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee) + val (in, out) = relayed.getOrElse(channelId, Nil).partition(_.direction == "IN") + ((in, "IN") :: (out, "OUT") :: Nil).map { case (r, direction) => + val paymentCount = r.length + if (paymentCount == 0) { + Stats(channelId, direction, 0 sat, 0, 0 sat, networkFee) + } else { + val avgPaymentAmount = r.map(_.amount).sum / paymentCount + val relayFee = r.map(_.fee).sum + Stats(channelId, direction, avgPaymentAmount.truncateToSatoshi, paymentCount, relayFee.truncateToSatoshi, networkFee) + } } - }).toSeq + }) } // used by mobile apps diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index adfb65091..32ac71676 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -219,6 +219,7 @@ object ChannelRelayer { case (_: ExpiryTooBig, _) => ExpiryTooFar case (_: InsufficientFunds, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) case (_: TooManyAcceptedHtlcs, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) + case (_: FeerateTooDifferent, Some(channelUpdate)) => TemporaryChannelFailure(channelUpdate) case (_: ChannelUnavailable, Some(channelUpdate)) if !Announcements.isEnabled(channelUpdate.channelFlags) => ChannelDisabled(channelUpdate.messageFlags, channelUpdate.channelFlags, channelUpdate) case (_: ChannelUnavailable, None) => PermanentChannelFailure case _ => TemporaryNodeFailure diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index 11aebce41..52a6e4951 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -20,13 +20,13 @@ import java.util.UUID import akka.actor.{Actor, ActorRef, DiagnosticActorLogging, PoisonPill, Props} import akka.event.Logging.MDC +import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{ByteVector32, Crypto} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle} @@ -108,18 +108,13 @@ class NodeRelayer(nodeParams: NodeParams, router: ActorRef, commandBuffer: Actor case None => log.error("could not find pending incoming payment: payment will not be relayed: please investigate") } - case ff: Relayer.ForwardFulfill => ff.to match { - case Origin.TrampolineRelayed(_, Some(paymentSender)) => - paymentSender ! ff - val paymentHash = Crypto.sha256(ff.paymentPreimage) - pendingOutgoing.get(paymentHash).foreach(p => if (!p.fulfilledUpstream) { - // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). - log.debug("trampoline payment successfully relayed") - fulfillPayment(p.upstream, ff.paymentPreimage) - context become main(pendingIncoming, pendingOutgoing + (paymentHash -> p.copy(fulfilledUpstream = true))) - }) - case _ => log.error(s"unexpected non-trampoline fulfill: $ff") - } + case PreimageReceived(paymentHash, paymentPreimage) => + log.debug("trampoline payment successfully relayed") + pendingOutgoing.get(paymentHash).foreach(p => if (!p.fulfilledUpstream) { + // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). + fulfillPayment(p.upstream, paymentPreimage) + context become main(pendingIncoming, pendingOutgoing + (paymentHash -> p.copy(fulfilledUpstream = true))) + }) case PaymentSent(id, paymentHash, paymentPreimage, _, _, parts) => // We may have already fulfilled upstream, but we can now emit an accurate relayed event and clean-up resources. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index 5b4ad08dd..fca20fd20 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -77,9 +77,9 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, comm context.system.eventStream.subscribe(self, classOf[AvailableBalanceChanged]) context.system.eventStream.subscribe(self, classOf[ShortChannelIdAssigned]) - private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, commandBuffer, initialized)) - private val channelRelayer = context.actorOf(ChannelRelayer.props(nodeParams, self, register, commandBuffer)) - private val nodeRelayer = context.actorOf(NodeRelayer.props(nodeParams, router, commandBuffer, register)) + private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, commandBuffer, initialized), "post-restart-htlc-cleaner") + private val channelRelayer = context.actorOf(ChannelRelayer.props(nodeParams, self, register, commandBuffer), "channel-relayer") + private val nodeRelayer = context.actorOf(NodeRelayer.props(nodeParams, router, commandBuffer, register), "node-relayer") override def receive: Receive = main(Map.empty, new mutable.HashMap[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId]) @@ -163,7 +163,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, comm commandBuffer ! CommandBuffer.CommandSend(originChannelId, cmd) context.system.eventStream.publish(ChannelPaymentRelayed(amountIn, amountOut, ff.htlc.paymentHash, originChannelId, ff.htlc.channelId)) case Origin.TrampolineRelayed(_, None) => postRestartCleaner forward ff - case Origin.TrampolineRelayed(_, Some(_)) => nodeRelayer forward ff + case Origin.TrampolineRelayed(_, Some(paymentSender)) => paymentSender ! ff } case ff: ForwardFail => ff.to match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index a31da8026..6c5e8f050 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -71,7 +71,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val maxFee = routeParams.getMaxFee(r.totalAmount) log.debug("sending {} with maximum fee {}", r.totalAmount, maxFee) val d = PaymentProgress(sender, r, r.maxAttempts, Map.empty, Ignore.empty, Nil) - router ! createRouteRequest(nodeParams, r.totalAmount, maxFee, routeParams, d) + router ! createRouteRequest(nodeParams, r.totalAmount, maxFee, routeParams, d, cfg) goto(WAIT_FOR_ROUTES) using d } @@ -93,7 +93,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, // remaining amount. In that case we discard these routes and send a new request to the router. log.info("discarding routes, another child payment failed so we need to recompute them (amount = {}, maximum fee = {})", toSend, maxFee) val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d) + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg) stay } @@ -107,7 +107,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, val (toSend, maxFee) = remainingToSend(nodeParams, d.request, d.pending.values) log.debug("retry sending {} with maximum fee {} without ignoring channels ({})", toSend, maxFee, d.ignore.channels.map(_.shortChannelId).mkString(",")) val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d).copy(ignore = d.ignore.emptyChannels()) + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d, cfg).copy(ignore = d.ignore.emptyChannels()) retriedFailedChannels = true stay using d.copy(remainingAttempts = (d.remainingAttempts - 1).max(0), ignore = d.ignore.emptyChannels()) } else { @@ -147,7 +147,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, log.debug("child payment failed, retry sending {} with maximum fee {}", toSend, maxFee) val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures) - router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d1) + router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d1, cfg) goto(WAIT_FOR_ROUTES) using d1 } @@ -222,6 +222,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, } private def gotoSucceededOrStop(d: PaymentSucceeded): State = { + d.sender ! PreimageReceived(paymentHash, d.preimage) if (d.pending.isEmpty) { myStop(d.sender, Right(cfg.createPaymentSent(d.preimage, d.parts))) } else @@ -303,6 +304,12 @@ object MultiPartPaymentLifecycle { routeParams.getOrElse(RouteCalculation.getDefaultRouteParams(nodeParams.routerConf)).copy(randomize = randomize) } + /** + * The payment FSM will wait for all child payments to settle before emitting payment events, but the preimage will be + * shared as soon as it's received to unblock other actors that may need it. + */ + case class PreimageReceived(paymentHash: ByteVector32, paymentPreimage: ByteVector32) + // @formatter:off sealed trait State case object WAIT_FOR_PAYMENT_REQUEST extends State @@ -362,7 +369,7 @@ object MultiPartPaymentLifecycle { */ case class PaymentSucceeded(sender: ActorRef, request: SendMultiPartPayment, preimage: ByteVector32, parts: Seq[PartialPayment], pending: Set[UUID]) extends Data - private def createRouteRequest(nodeParams: NodeParams, toSend: MilliSatoshi, maxFee: MilliSatoshi, routeParams: RouteParams, d: PaymentProgress): RouteRequest = + private def createRouteRequest(nodeParams: NodeParams, toSend: MilliSatoshi, maxFee: MilliSatoshi, routeParams: RouteParams, d: PaymentProgress, cfg: SendPaymentConfig): RouteRequest = RouteRequest( nodeParams.nodeId, d.request.targetNodeId, @@ -372,7 +379,8 @@ object MultiPartPaymentLifecycle { d.ignore, Some(routeParams), allowMultiPart = true, - d.pending.values.toSeq) + d.pending.values.toSeq, + Some(cfg.paymentContext)) private def createChildPayment(route: Route, request: SendMultiPartPayment): SendPaymentToRoute = { val finalPayload = Onion.createMultiPartPayload(route.amount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 09611b051..a8bbb7dd1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -26,10 +26,11 @@ import fr.acinq.eclair.channel.{Channel, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentError._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} -import fr.acinq.eclair.router.Router.{Hop, NodeHop, Route, RouteParams} +import fr.acinq.eclair.router.RouteNotFound +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, NodeParams, randomBytes32} @@ -81,19 +82,33 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor case pf: PaymentFailed => pending.get(pf.id).foreach(pp => { val decryptedFailures = pf.failures.collect { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(_, f)) => f } - val canRetry = decryptedFailures.contains(TrampolineFeeInsufficient) || decryptedFailures.contains(TrampolineExpiryTooSoon) - pp.remainingAttempts match { - case (trampolineFees, trampolineExpiryDelta) :: remainingAttempts if canRetry => - log.info(s"retrying trampoline payment with trampoline fees=$trampolineFees and expiry delta=$trampolineExpiryDelta") - sendTrampolinePayment(pf.id, pp.r, trampolineFees, trampolineExpiryDelta) - context become main(pending + (pf.id -> pp.copy(remainingAttempts = remainingAttempts))) - case _ => - pp.sender ! pf - context.system.eventStream.publish(pf) - context become main(pending - pf.id) + val shouldRetry = decryptedFailures.contains(TrampolineFeeInsufficient) || decryptedFailures.contains(TrampolineExpiryTooSoon) + if (shouldRetry) { + pp.remainingAttempts match { + case (trampolineFees, trampolineExpiryDelta) :: remaining => + log.info(s"retrying trampoline payment with trampoline fees=$trampolineFees and expiry delta=$trampolineExpiryDelta") + sendTrampolinePayment(pf.id, pp.r, trampolineFees, trampolineExpiryDelta) + context become main(pending + (pf.id -> pp.copy(remainingAttempts = remaining))) + case Nil => + log.info("trampoline node couldn't find a route after all retries") + val trampolineRoute = Seq( + NodeHop(nodeParams.nodeId, pp.r.trampolineNodeId, nodeParams.expiryDeltaBlocks, 0 msat), + NodeHop(pp.r.trampolineNodeId, pp.r.recipientNodeId, pp.r.trampolineAttempts.last._2, pp.r.trampolineAttempts.last._1) + ) + val localFailure = pf.copy(failures = Seq(LocalFailure(trampolineRoute, RouteNotFound))) + pp.sender ! localFailure + context.system.eventStream.publish(localFailure) + context become main(pending - pf.id) + } + } else { + pp.sender ! pf + context.system.eventStream.publish(pf) + context become main(pending - pf.id) } }) + case _: PreimageReceived => // we received the preimage, but we wait for the PaymentSent event that will contain more data + case ps: PaymentSent => pending.get(ps.id).foreach(pp => { pp.sender ! ps context.system.eventStream.publish(ps) @@ -314,6 +329,8 @@ object PaymentInitiator { def fullRoute(route: Route): Seq[Hop] = route.hops ++ additionalHops def createPaymentSent(preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipientAmount, recipientNodeId, parts) + + def paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 4b3568715..565e96f22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -78,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A log.debug("sending {} to route {}", c.finalPayload.amount, c.printRoute()) val send = SendPayment(c.targetNodeId, c.finalPayload, maxAttempts = 1, assistedRoutes = c.assistedRoutes) c.route.fold( - hops => router ! FinalizeRoute(c.finalPayload.amount, hops, c.assistedRoutes), + hops => router ! FinalizeRoute(c.finalPayload.amount, hops, c.assistedRoutes, paymentContext = Some(cfg.paymentContext)), route => self ! RouteResponse(route :: Nil) ) if (cfg.storeInDb) { @@ -92,7 +92,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A span.tag(Tags.TotalAmount, c.finalPayload.totalAmount.toLong) span.tag(Tags.Expiry, c.finalPayload.expiry.toLong) log.debug("sending {} to {}", c.finalPayload.amount, c.targetNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, routeParams = c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, routeParams = c.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, System.currentTimeMillis, cfg.paymentRequest, OutgoingPaymentStatus.Pending)) } @@ -149,6 +149,10 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A val failure = res match { case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId (failure=$failureMessage)") + failureMessage match { + case failureMessage: Update => handleUpdate(nodeId, failureMessage, data) + case _ => + } RemoteFailure(cfg.fullRoute(route), e) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}") @@ -168,44 +172,14 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") val ignore1 = if (Announcements.checkSig(failureMessage.update, nodeId)) { - route.getChannelUpdateForNode(nodeId) match { - case Some(u) if u.shortChannelId != failureMessage.update.shortChannelId => - // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine - log.info(s"received an update for a different channel than the one we asked: requested=${u.shortChannelId} actual=${failureMessage.update.shortChannelId} update=${failureMessage.update}") - case Some(u) if Announcements.areSame(u, failureMessage.update) => - // node returned the exact same update we used, this can happen e.g. if the channel is imbalanced - // in that case, let's temporarily exclude the channel from future routes, giving it time to recover - log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") - val nextNodeId = route.hops.find(_.nodeId == nodeId).get.nextNodeId - router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, failures) => - // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it - log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failureMessage.update}") - val nextNodeId = route.hops.find(_.nodeId == nodeId).get.nextNodeId - router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) => - log.info(s"got a new update for shortChannelId=${u.shortChannelId}: old=$u new=${failureMessage.update}") - case None => - log.error(s"couldn't find a channel update for node=$nodeId, this should never happen") - } - // in any case, we forward the update to the router - router ! failureMessage.update - // we also update assisted routes, because they take precedence over the router's routing table - val assistedRoutes1 = c.assistedRoutes.map(_.map { - case extraHop: ExtraHop if extraHop.shortChannelId == failureMessage.update.shortChannelId => extraHop.copy( - cltvExpiryDelta = failureMessage.update.cltvExpiryDelta, - feeBase = failureMessage.update.feeBaseMsat, - feeProportionalMillionths = failureMessage.update.feeProportionalMillionths - ) - case extraHop => extraHop - }) + val assistedRoutes1 = handleUpdate(nodeId, failureMessage, data) // let's try again, router will have updated its state - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), assistedRoutes1, ignore, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), assistedRoutes1, ignore, c.routeParams, paymentContext = Some(cfg.paymentContext)) ignore } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, ignore + nodeId, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.getMaxFee(nodeParams), c.assistedRoutes, ignore + nodeId, c.routeParams, paymentContext = Some(cfg.paymentContext)) ignore + nodeId } goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(cfg.fullRoute(route), e), ignore1) @@ -262,10 +236,58 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A private def retry(failure: PaymentFailure, data: WaitingForComplete): FSM.State[PaymentLifecycle.State, PaymentLifecycle.Data] = { val ignore1 = PaymentFailure.updateIgnored(failure, data.ignore) - router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.finalPayload.amount, data.c.getMaxFee(nodeParams), data.c.assistedRoutes, ignore1, data.c.routeParams) + router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.finalPayload.amount, data.c.getMaxFee(nodeParams), data.c.assistedRoutes, ignore1, data.c.routeParams, paymentContext = Some(cfg.paymentContext)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.sender, data.c, data.failures :+ failure, ignore1) } + /** + * Apply the channel update to our routing table. + * + * @return updated routing hints if applicable. + */ + private def handleUpdate(nodeId: PublicKey, failure: Update, data: WaitingForComplete): Seq[Seq[ExtraHop]] = { + data.route.getChannelUpdateForNode(nodeId) match { + case Some(u) if u.shortChannelId != failure.update.shortChannelId => + // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine + log.info(s"received an update for a different channel than the one we asked: requested=${u.shortChannelId} actual=${failure.update.shortChannelId} update=${failure.update}") + case Some(u) if Announcements.areSame(u, failure.update) => + // node returned the exact same update we used, this can happen e.g. if the channel is imbalanced + // in that case, let's temporarily exclude the channel from future routes, giving it time to recover + log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") + val nextNodeId = data.route.hops.find(_.nodeId == nodeId).get.nextNodeId + router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) + case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, data.failures) => + // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it + log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failure.update}") + val nextNodeId = data.route.hops.find(_.nodeId == nodeId).get.nextNodeId + router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) + case Some(u) => + log.info(s"got a new update for shortChannelId=${u.shortChannelId}: old=$u new=${failure.update}") + case None => + log.error(s"couldn't find a channel update for node=$nodeId, this should never happen") + } + // in any case, we forward the update to the router: if the channel is disabled, the router will remove it from its routing table + router ! failure.update + // we return updated assisted routes: they take precedence over the router's routing table + if (Announcements.isEnabled(failure.update.channelFlags)) { + data.c.assistedRoutes.map(_.map { + case extraHop: ExtraHop if extraHop.shortChannelId == failure.update.shortChannelId => extraHop.copy( + cltvExpiryDelta = failure.update.cltvExpiryDelta, + feeBase = failure.update.feeBaseMsat, + feeProportionalMillionths = failure.update.feeProportionalMillionths + ) + case extraHop => extraHop + }) + } else { + // if the channel is disabled, we temporarily exclude it: this is necessary because the routing hint doesn't contain + // channel flags to indicate that it's disabled + data.c.assistedRoutes.flatMap(r => RouteCalculation.toChannelDescs(r, data.c.targetNodeId)) + .find(_.shortChannelId == failure.update.shortChannelId) + .foreach(desc => router ! ExcludeChannel(desc)) // we want the exclusion to be router-wide so that sister payments in the case of MPP are aware the channel is faulty + data.c.assistedRoutes + } + } + private def myStop(): State = { stateSpan.foreach(_.finish()) span.finish() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index 503711227..bf1be1308 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -17,9 +17,10 @@ package fr.acinq.eclair.router import akka.actor.{ActorContext, ActorRef, Status} -import akka.event.LoggingAdapter +import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} +import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} @@ -37,62 +38,75 @@ import scala.util.{Failure, Random, Success, Try} object RouteCalculation { - def finalizeRoute(d: Data, fr: FinalizeRoute)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { - implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + def finalizeRoute(d: Data, fr: FinalizeRoute)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + Logs.withMdc(log)(Logs.mdc( + category_opt = Some(LogCategory.PAYMENT), + parentPaymentId_opt = fr.paymentContext.map(_.parentId), + paymentId_opt = fr.paymentContext.map(_.id), + paymentHash_opt = fr.paymentContext.map(_.paymentHash))) { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val assistedChannels: Map[ShortChannelId, AssistedChannel] = fr.assistedRoutes.flatMap(toAssistedChannels(_, fr.hops.last, fr.amount)).toMap - val extraEdges = assistedChannels.values.map(ac => - GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) - ).toSet - val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } - // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - fr.hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { - case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => - // select the largest edge (using balance when available, otherwise capacity). - val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) - val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) - ctx.sender ! RouteResponse(Route(fr.amount, hops) :: Nil) - case _ => // some nodes in the supplied route aren't connected in our graph - ctx.sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) + val assistedChannels: Map[ShortChannelId, AssistedChannel] = fr.assistedRoutes.flatMap(toAssistedChannels(_, fr.hops.last, fr.amount)).toMap + val extraEdges = assistedChannels.values.map(ac => + GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) + ).toSet + val g = extraEdges.foldLeft(d.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } + // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs + fr.hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { + case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => + // select the largest edge (using balance when available, otherwise capacity). + val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) + val hops = selectedEdges.map(d => ChannelHop(d.desc.a, d.desc.b, d.update)) + ctx.sender ! RouteResponse(Route(fr.amount, hops) :: Nil) + case _ => // some nodes in the supplied route aren't connected in our graph + ctx.sender ! Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels")) + } + d } - d } - def handleRouteRequest(d: Data, routerConf: RouterConf, currentBlockHeight: Long, r: RouteRequest)(implicit ctx: ActorContext, log: LoggingAdapter): Data = { - implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors + def handleRouteRequest(d: Data, routerConf: RouterConf, currentBlockHeight: Long, r: RouteRequest)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + Logs.withMdc(log)(Logs.mdc( + category_opt = Some(LogCategory.PAYMENT), + parentPaymentId_opt = r.paymentContext.map(_.parentId), + paymentId_opt = r.paymentContext.map(_.id), + paymentHash_opt = r.paymentContext.map(_.paymentHash))) { + implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - // we convert extra routing info provided in the payment request to fake channel_update - // it takes precedence over all other channel_updates we know - val assistedChannels: Map[ShortChannelId, AssistedChannel] = r.assistedRoutes.flatMap(toAssistedChannels(_, r.target, r.amount)).toMap - val extraEdges = assistedChannels.values.map(ac => - GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) - ).toSet - val ignoredEdges = r.ignore.channels ++ d.excludedChannels - val params = r.routeParams.getOrElse(getDefaultRouteParams(routerConf)) - val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 + // we convert extra routing info provided in the payment request to fake channel_update + // it takes precedence over all other channel_updates we know + val assistedChannels: Map[ShortChannelId, AssistedChannel] = r.assistedRoutes.flatMap(toAssistedChannels(_, r.target, r.amount)) + .filterNot { case (_, ac) => ac.extraHop.nodeId == r.source } // we ignore routing hints for our own channels, we have more accurate information + .toMap + val extraEdges = assistedChannels.values.map(ac => + GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop, ac.htlcMaximum), htlcMaxToCapacity(ac.htlcMaximum), Some(ac.htlcMaximum)) + ).toSet + val ignoredEdges = r.ignore.channels ++ d.excludedChannels + val params = r.routeParams.getOrElse(getDefaultRouteParams(routerConf)) + val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 - log.info(s"finding routes ${r.source}->${r.target} with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), r.ignore.nodes.map(_.value).mkString(","), r.ignore.channels.mkString(","), d.excludedChannels.mkString(",")) - log.info("finding routes with randomize={} params={}", params.randomize, params) - val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(r.amount)) - KamonExt.time(Metrics.FindRouteDuration.withTags(tags.withTag(Tags.NumberOfRoutes, routesToFind.toLong))) { - val result = if (r.allowMultiPart) { - findMultiPartRoute(d.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight) - } else { - findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight) - } - result match { - case Success(routes) => - Metrics.RouteResults.withTags(tags).record(routes.length) - routes.foreach(route => Metrics.RouteLength.withTags(tags).record(route.length)) - ctx.sender ! RouteResponse(routes) - case Failure(t) => - val failure = if (isNeighborBalanceTooLow(d.graph, r)) BalanceTooLow else t - Metrics.FindRouteErrors.withTags(tags.withTag(Tags.Error, failure.getClass.getSimpleName)).increment() - ctx.sender ! Status.Failure(failure) + log.info(s"finding routes ${r.source}->${r.target} with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), r.ignore.nodes.map(_.value).mkString(","), r.ignore.channels.mkString(","), d.excludedChannels.mkString(",")) + log.info("finding routes with randomize={} params={}", params.randomize, params) + val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(r.amount)) + KamonExt.time(Metrics.FindRouteDuration.withTags(tags.withTag(Tags.NumberOfRoutes, routesToFind.toLong))) { + val result = if (r.allowMultiPart) { + findMultiPartRoute(d.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight) + } else { + findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight) + } + result match { + case Success(routes) => + Metrics.RouteResults.withTags(tags).record(routes.length) + routes.foreach(route => Metrics.RouteLength.withTags(tags).record(route.length)) + ctx.sender ! RouteResponse(routes) + case Failure(t) => + val failure = if (isNeighborBalanceTooLow(d.graph, r)) BalanceTooLow else t + Metrics.FindRouteErrors.withTags(tags.withTag(Tags.Error, failure.getClass.getSimpleName)).increment() + ctx.sender ! Status.Failure(failure) + } } + d } - - d } private def toFakeUpdate(extraHop: ExtraHop, htlcMaximum: MilliSatoshi): ChannelUpdate = { @@ -116,6 +130,11 @@ object RouteCalculation { }._2 } + def toChannelDescs(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey): Seq[ChannelDesc] = { + val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId + extraRoute.zip(nextNodeIds).map { case (hop, nextNodeId) => ChannelDesc(hop.shortChannelId, hop.nodeId, nextNodeId) } + } + /** Bolt 11 routing hints don't include the channel's capacity, so we round up the maximum htlc amount. */ private def htlcMaxToCapacity(htlcMaximum: MilliSatoshi): Satoshi = htlcMaximum.truncateToSatoshi + 1.sat diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 0d3548932..a001dd46b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -16,6 +16,8 @@ package fr.acinq.eclair.router +import java.util.UUID + import akka.Done import akka.actor.{ActorRef, Props} import akka.event.DiagnosticLoggingAdapter @@ -31,6 +33,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Monitoring.{Metrics, Tags} @@ -363,9 +366,18 @@ object Router { ignore: Ignore = Ignore.empty, routeParams: Option[RouteParams] = None, allowMultiPart: Boolean = false, - pendingPayments: Seq[Route] = Nil) + pendingPayments: Seq[Route] = Nil, + paymentContext: Option[PaymentContext] = None) - case class FinalizeRoute(amount: MilliSatoshi, hops: Seq[PublicKey], assistedRoutes: Seq[Seq[ExtraHop]] = Nil) + case class FinalizeRoute(amount: MilliSatoshi, + hops: Seq[PublicKey], + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + paymentContext: Option[PaymentContext] = None) + + /** + * Useful for having appropriate logging context at hand when finding routes + */ + case class PaymentContext(id: UUID, parentId: UUID, paymentHash: ByteVector32) case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop]) { require(hops.nonEmpty, "route cannot be empty") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala index 14823c4c9..c3a187b7a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala @@ -475,25 +475,29 @@ object Validation { } } - def handleAvailableBalanceChanged(d: Data, e: AvailableBalanceChanged): Data = { - val (channels1, graph1) = d.channels.get(e.shortChannelId) match { + def handleAvailableBalanceChanged(d: Data, e: AvailableBalanceChanged)(implicit log: LoggingAdapter): Data = { + val desc = ChannelDesc(e.shortChannelId, e.commitments.localParams.nodeId, e.commitments.remoteParams.nodeId) + val (publicChannels1, graph1) = d.channels.get(e.shortChannelId) match { case Some(pc) => val pc1 = pc.updateBalances(e.commitments) - val desc = ChannelDesc(e.shortChannelId, e.commitments.localParams.nodeId, e.commitments.remoteParams.nodeId) + log.debug("public channel balance updated: {}", pc1) val update_opt = if (e.commitments.localParams.nodeId == pc1.ann.nodeId1) pc1.update_1_opt else pc1.update_2_opt val graph1 = update_opt.map(u => d.graph.addEdge(desc, u, pc1.capacity, pc1.getBalanceSameSideAs(u))).getOrElse(d.graph) (d.channels + (e.shortChannelId -> pc1), graph1) case None => (d.channels, d.graph) } - val privateChannels1 = d.privateChannels.get(e.shortChannelId) match { + val (privateChannels1, graph2) = d.privateChannels.get(e.shortChannelId) match { case Some(pc) => val pc1 = pc.updateBalances(e.commitments) - d.privateChannels + (e.shortChannelId -> pc1) + log.debug("private channel balance updated: {}", pc1) + val update_opt = if (e.commitments.localParams.nodeId == pc1.nodeId1) pc1.update_1_opt else pc1.update_2_opt + val graph2 = update_opt.map(u => graph1.addEdge(desc, u, pc1.capacity, pc1.getBalanceSameSideAs(u))).getOrElse(graph1) + (d.privateChannels + (e.shortChannelId -> pc1), graph2) case None => - d.privateChannels + (d.privateChannels, graph1) } - d.copy(channels = channels1, privateChannels = privateChannels1, graph = graph1) + d.copy(channels = publicChannels1, privateChannels = privateChannels1, graph = graph2) } } diff --git a/eclair-core/src/main/scala/kamon/Kamon.scala b/eclair-core/src/main/scala/kamon/Kamon.scala index 0bf5454f6..173bb4c1c 100644 --- a/eclair-core/src/main/scala/kamon/Kamon.scala +++ b/eclair-core/src/main/scala/kamon/Kamon.scala @@ -39,6 +39,8 @@ object Kamon { def increment(a: Int) = this + def increment(a: Long) = this + def decrement() = this def record(a: Long) = this diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 4d2f35f67..d54bfb5ae 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -35,6 +35,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentRequest, SendPa import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort import fr.acinq.eclair.router.Router.{GetNetworkStats, GetNetworkStatsResponse, PublicChannel} import fr.acinq.eclair.router.{Announcements, NetworkStats, Router, Stats} +import fr.acinq.eclair.wire.{Color, NodeAnnouncement} import org.mockito.Mockito import org.mockito.scalatest.IdiomaticMockito import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -150,6 +151,54 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I assertThrows[IllegalArgumentException](Await.result(eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) } + test("return node announcements") { f => + import f._ + + val eclair = new EclairImpl(kit) + val remoteNodeAnn1 = NodeAnnouncement(randomBytes64, Features.empty, 42L, randomKey.publicKey, Color(42, 42, 42), "LN-rocks", Nil) + val remoteNodeAnn2 = NodeAnnouncement(randomBytes64, Features.empty, 43L, randomKey.publicKey, Color(43, 43, 43), "LN-papers", Nil) + val allNodes = Seq( + NodeAnnouncement(randomBytes64, Features.empty, 561L, randomKey.publicKey, Color(0, 0, 0), "some-node", Nil), + remoteNodeAnn1, + remoteNodeAnn2, + NodeAnnouncement(randomBytes64, Features.empty, 1105L, randomKey.publicKey, Color(0, 0, 0), "some-other-node", Nil) + ) + + { + val fRes = eclair.nodes() + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.toSet === allNodes.toSet) + true + case _ => false + }) + } + { + val fRes = eclair.nodes(Some(Set(remoteNodeAnn1.nodeId, remoteNodeAnn2.nodeId))) + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.toSet === Set(remoteNodeAnn1, remoteNodeAnn2)) + true + case _ => false + }) + } + { + val fRes = eclair.nodes(Some(Set(randomKey.publicKey))) + router.expectMsg(Symbol("nodes")) + router.reply(allNodes) + awaitCond(fRes.value match { + case Some(Success(nodes)) => + assert(nodes.isEmpty) + true + case _ => false + }) + } + } + test("allupdates can filter by nodeId") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala index 38e7ee00d..c4fdfedfc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala @@ -46,6 +46,7 @@ class StartupSpec extends AnyFunSuite { test("check configuration") { assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair"))).isSuccess) assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair").withFallback(ConfigFactory.parseMap(Map("max-feerate-mismatch" -> 42).asJava)))).isFailure) + assert(Try(makeNodeParamsWithDefaults(ConfigFactory.load().getConfig("eclair").withFallback(ConfigFactory.parseMap(Map("on-chain-fees.max-feerate-mismatch" -> 1.56).asJava)))).isFailure) } test("NodeParams should fail if the alias is illegal (over 32 bytes)") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index b71976bba..469df7b0b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -22,9 +22,9 @@ import java.util.concurrent.atomic.AtomicLong import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, Script} import fr.acinq.eclair.FeatureSupport.Optional -import fr.acinq.eclair.Features.{ChannelRangeQueries, ChannelRangeQueriesExtended, InitialRoutingSync, OptionDataLossProtect, VariableLengthOnion} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.NodeParams.BITCOIND -import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeeratesPerKw, OnChainFeeConf} +import fr.acinq.eclair.blockchain.fee._ import fr.acinq.eclair.crypto.LocalKeyManager import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer @@ -84,7 +84,7 @@ object TestConstants { onChainFeeConf = OnChainFeeConf( feeTargets = FeeTargets(6, 2, 2, 6), feeEstimator = new TestFeeEstimator, - maxFeerateMismatch = 1.5, + maxFeerateMismatch = FeerateTolerance(0.5, 8.0), closeOnOfflineMismatch = true, updateFeeMinDiffRatio = 0.1 ), @@ -170,7 +170,7 @@ object TestConstants { onChainFeeConf = OnChainFeeConf( feeTargets = FeeTargets(6, 2, 2, 6), feeEstimator = new TestFeeEstimator, - maxFeerateMismatch = 1.0, + maxFeerateMismatch = FeerateTolerance(0.75, 1.5), closeOnOfflineMismatch = true, updateFeeMinDiffRatio = 0.1 ), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala index 33630a92d..08165dd85 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/CommitmentsSpec.scala @@ -20,6 +20,8 @@ import java.util.UUID import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{DeterministicWallet, Satoshi, Transaction} +import fr.acinq.eclair.TestConstants.TestFeeEstimator +import fr.acinq.eclair.blockchain.fee.{FeeTargets, FeerateTolerance, OnChainFeeConf} import fr.acinq.eclair.channel.Commitments._ import fr.acinq.eclair.channel.Helpers.Funding import fr.acinq.eclair.channel.states.StateTestsHelperMethods @@ -42,6 +44,8 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with implicit val log: akka.event.LoggingAdapter = akka.event.NoLogging + val feeConfNoMismatch = OnChainFeeConf(FeeTargets(6, 2, 2, 6), new TestFeeEstimator, FeerateTolerance(0.00001, 100000.0), closeOnOfflineMismatch = false, 1.0) + override def withFixture(test: OneArgTest): Outcome = { val setup = init() import setup._ @@ -66,8 +70,8 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a - htlcOutputFee) val (_, cmdAdd) = makeCmdAdd(a - htlcOutputFee - 1000.msat, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) - val Success(bc1) = receiveAdd(bc0, add) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) val Success((_, commit1)) = sendCommit(ac1, alice.underlyingActor.nodeParams.keyManager) val Success((bc2, _)) = receiveCommit(bc1, commit1, bob.underlyingActor.nodeParams.keyManager) // we don't take into account the additional HTLC fee since Alice's balance is below the trim threshold. @@ -94,11 +98,11 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (payment_preimage, cmdAdd) = makeCmdAdd(p, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) - val Success(bc1) = receiveAdd(bc0, add) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b) assert(bc1.availableBalanceForReceive == a - p - fee) @@ -179,11 +183,11 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (_, cmdAdd) = makeCmdAdd(p, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add)) = sendAdd(ac0, cmdAdd, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) - val Success(bc1) = receiveAdd(bc0, add) + val Success(bc1) = receiveAdd(bc0, add, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b) assert(bc1.availableBalanceForReceive == a - p - fee) @@ -267,29 +271,29 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(bc0.availableBalanceForReceive == a) val (payment_preimage1, cmdAdd1) = makeCmdAdd(p1, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac1, add1)) = sendAdd(ac0, cmdAdd1, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac1, add1)) = sendAdd(ac0, cmdAdd1, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac1.availableBalanceForSend == a - p1 - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac1.availableBalanceForReceive == b) val (_, cmdAdd2) = makeCmdAdd(p2, bob.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((ac2, add2)) = sendAdd(ac1, cmdAdd2, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((ac2, add2)) = sendAdd(ac1, cmdAdd2, Local(UUID.randomUUID, None), currentBlockHeight, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac2.availableBalanceForSend == a - p1 - fee - p2 - fee) // as soon as htlc is sent, alice sees its balance decrease (more than the payment amount because of the commitment fees) assert(ac2.availableBalanceForReceive == b) val (payment_preimage3, cmdAdd3) = makeCmdAdd(p3, alice.underlyingActor.nodeParams.nodeId, currentBlockHeight) - val Success((bc1, add3)) = sendAdd(bc0, cmdAdd3, Local(UUID.randomUUID, None), currentBlockHeight) + val Success((bc1, add3)) = sendAdd(bc0, cmdAdd3, Local(UUID.randomUUID, None), currentBlockHeight, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc1.availableBalanceForSend == b - p3) // bob doesn't pay the fee assert(bc1.availableBalanceForReceive == a) - val Success(bc2) = receiveAdd(bc1, add1) + val Success(bc2) = receiveAdd(bc1, add1, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc2.availableBalanceForSend == b - p3) assert(bc2.availableBalanceForReceive == a - p1 - fee) - val Success(bc3) = receiveAdd(bc2, add2) + val Success(bc3) = receiveAdd(bc2, add2, bob.underlyingActor.nodeParams.onChainFeeConf) assert(bc3.availableBalanceForSend == b - p3) assert(bc3.availableBalanceForReceive == a - p1 - fee - p2 - fee) - val Success(ac3) = receiveAdd(ac2, add3) + val Success(ac3) = receiveAdd(ac2, add3, alice.underlyingActor.nodeParams.onChainFeeConf) assert(ac3.availableBalanceForSend == a - p1 - fee - p2 - fee) assert(ac3.availableBalanceForReceive == b - p3) @@ -398,7 +402,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val isFunder = true val c = CommitmentsSpec.makeCommitments(100000000 msat, 50000000 msat, 2500, 546 sat, isFunder) val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val Success((c1, _)) = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val Success((c1, _)) = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(c1.availableBalanceForSend === 0.msat) // We should be able to handle a fee increase. @@ -406,7 +410,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with // Now we shouldn't be able to send until we receive enough to handle the updated commit tx fee (even trimmed HTLCs shouldn't be sent). val (_, cmdAdd1) = makeCmdAdd(100 msat, randomKey.publicKey, f.currentBlockHeight) - val Failure(e) = sendAdd(c2, cmdAdd1, Local(UUID.randomUUID, None), f.currentBlockHeight) + val Failure(e) = sendAdd(c2, cmdAdd1, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(e.isInstanceOf[InsufficientFunds]) } @@ -414,7 +418,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (isFunder <- Seq(true, false)) { val c = CommitmentsSpec.makeCommitments(702000000 msat, 52000000 msat, 2679, 546 sat, isFunder) val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(result.isSuccess, result) } } @@ -423,7 +427,7 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (isFunder <- Seq(true, false)) { val c = CommitmentsSpec.makeCommitments(31000000 msat, 702000000 msat, 2679, 546 sat, isFunder) val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, c.availableBalanceForReceive, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) + receiveAdd(c, add, feeConfNoMismatch) } } @@ -444,13 +448,13 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (_ <- 0 to t.pendingHtlcs) { val amount = Random.nextInt(maxPendingHtlcAmount.toLong.toInt).msat.max(1 msat) val (_, cmdAdd) = makeCmdAdd(amount, randomKey.publicKey, f.currentBlockHeight) - sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) match { + sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) match { case Success((cc, _)) => c = cc case Failure(e) => fail(s"$t -> could not setup initial htlcs: $e") } } val (_, cmdAdd) = makeCmdAdd(c.availableBalanceForSend, randomKey.publicKey, f.currentBlockHeight) - val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight) + val result = sendAdd(c, cmdAdd, Local(UUID.randomUUID, None), f.currentBlockHeight, feeConfNoMismatch) assert(result.isSuccess, s"$t -> $result") } } @@ -472,13 +476,13 @@ class CommitmentsSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with for (_ <- 0 to t.pendingHtlcs) { val amount = Random.nextInt(maxPendingHtlcAmount.toLong.toInt).msat.max(1 msat) val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, amount, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) match { + receiveAdd(c, add, feeConfNoMismatch) match { case Success(cc) => c = cc case Failure(e) => fail(s"$t -> could not setup initial htlcs: $e") } } val add = UpdateAddHtlc(randomBytes32, c.remoteNextHtlcId, c.availableBalanceForReceive, randomBytes32, CltvExpiry(f.currentBlockHeight), TestConstants.emptyOnionPacket) - receiveAdd(c, add) match { + receiveAdd(c, add, feeConfNoMismatch) match { case Success(_) => () case Failure(e) => fail(s"$t -> $e") } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala index 0cdcd6ebe..81a811552 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.wire.{AcceptChannel, ChannelTlv, Error, Init, OpenChannel import fr.acinq.eclair.{ActivatedFeature, CltvExpiryDelta, Features, LongToBtcAmount, TestConstants, TestKitBaseClass, ToMilliSatoshiConversion} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} -import scodec.bits.{ByteVector, HexStringSyntax} +import scodec.bits.ByteVector import scala.concurrent.duration._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 27406acd8..0db22f9ea 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -24,8 +24,9 @@ import akka.testkit.TestProbe import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto, ScriptFlags, Transaction} import fr.acinq.eclair.Features.StaticRemoteKey -import fr.acinq.eclair.TestConstants.{Alice, Bob} +import fr.acinq.eclair.TestConstants.{Alice, Bob, TestFeeEstimator} import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.blockchain.fee.FeeratesPerKw import fr.acinq.eclair.channel.Channel._ @@ -40,7 +41,6 @@ import fr.acinq.eclair.transactions.DirectedHtlc.{incoming, outgoing} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions.{HtlcSuccessTx, htlcSuccessWeight, htlcTimeoutWeight, weight2fee} import fr.acinq.eclair.wire.{AnnouncementSignatures, ChannelUpdate, ClosingSigned, CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} -import fr.acinq.eclair._ import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits._ @@ -345,6 +345,31 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with alice2bob.expectNoMsg(200 millis) } + test("recv CMD_ADD_HTLC (channel feerate mismatch)") { f => + import f._ + + val sender = TestProbe() + bob.feeEstimator.setFeerate(FeeratesPerKw.single(20000)) + sender.send(bob, CurrentFeerates(FeeratesPerKw.single(20000))) + bob2alice.expectNoMsg(100 millis) // we don't close because the commitment doesn't contain any HTLC + + val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] + val upstream = Upstream.Local(UUID.randomUUID()) + val add = CMD_ADD_HTLC(500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, upstream) + sender.send(bob, add) + val error = FeerateTooDifferent(channelId(bob), 20000, 10000) + sender.expectMsg(Failure(AddHtlcFailed(channelId(bob), add.paymentHash, error, Origin.Local(upstream.id, Some(sender.ref)), Some(initialState.channelUpdate), Some(add)))) + bob2alice.expectNoMsg(100 millis) // we don't close the channel, we can simply avoid using it while we disagree on feerate + + // we now agree on feerate so we can send HTLCs + bob.feeEstimator.setFeerate(FeeratesPerKw.single(11000)) + sender.send(bob, CurrentFeerates(FeeratesPerKw.single(11000))) + bob2alice.expectNoMsg(100 millis) + sender.send(bob, add) + sender.expectMsg(ChannelCommandResponse.Ok) + bob2alice.expectMsgType[UpdateAddHtlc] + } + test("recv CMD_ADD_HTLC (after having sent Shutdown)") { f => import f._ val sender = TestProbe() @@ -1174,9 +1199,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with localChanges = initialState.commitments.localChanges.copy(initialState.commitments.localChanges.proposed :+ fulfill)))) } - test("recv CMD_FULFILL_HTLC") { testReceiveCmdFulfillHtlc _ } + test("recv CMD_FULFILL_HTLC") { + testReceiveCmdFulfillHtlc _ + } - test("recv CMD_FULFILL_HTLC (static_remotekey)", Tag("static_remotekey")) { testReceiveCmdFulfillHtlc _ } + test("recv CMD_FULFILL_HTLC (static_remotekey)", Tag("static_remotekey")) { + testReceiveCmdFulfillHtlc _ + } test("recv CMD_FULFILL_HTLC (unknown htlc id)") { f => import f._ @@ -1232,9 +1261,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with assert(forward.htlc === htlc) } - test("recv UpdateFulfillHtlc") { testUpdateFulfillHtlc _ } + test("recv UpdateFulfillHtlc") { + testUpdateFulfillHtlc _ + } - test("recv UpdateFulfillHtlc (static_remotekey)", Tag("(static_remotekey)")) { testUpdateFulfillHtlc _ } + test("recv UpdateFulfillHtlc (static_remotekey)", Tag("(static_remotekey)")) { + testUpdateFulfillHtlc _ + } test("recv UpdateFulfillHtlc (sender has not signed htlc)") { f => import f._ @@ -1308,9 +1341,13 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with } - test("recv CMD_FAIL_HTLC") { testCmdFailHtlc _ } + test("recv CMD_FAIL_HTLC") { + testCmdFailHtlc _ + } - test("recv CMD_FAIL_HTLC (static_remotekey)", Tag("static_remotekey")) { testCmdFailHtlc _ } + test("recv CMD_FAIL_HTLC (static_remotekey)", Tag("static_remotekey")) { + testCmdFailHtlc _ + } test("recv CMD_FAIL_HTLC (unknown htlc id)") { f => import f._ @@ -1397,8 +1434,12 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with relayerA.expectNoMsg() } - test("recv UpdateFailHtlc") { testUpdateFailHtlc _ } - test("recv UpdateFailHtlc (static_remotekey)", Tag("static_remotekey")) { testUpdateFailHtlc _ } + test("recv UpdateFailHtlc") { + testUpdateFailHtlc _ + } + test("recv UpdateFailHtlc (static_remotekey)", Tag("static_remotekey")) { + testUpdateFailHtlc _ + } test("recv UpdateFailMalformedHtlc") { f => import f._ @@ -1536,19 +1577,19 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("recv UpdateFee") { f => import f._ val initialData = bob.stateData.asInstanceOf[DATA_NORMAL] - val fee1 = UpdateFee(ByteVector32.Zeroes, 12000) - bob ! fee1 - val fee2 = UpdateFee(ByteVector32.Zeroes, 14000) - bob ! fee2 - awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee2), remoteNextHtlcId = 0))) + val fee = UpdateFee(ByteVector32.Zeroes, 12000) + bob ! fee + awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee), remoteNextHtlcId = 0))) } test("recv UpdateFee (two in a row)") { f => import f._ val initialData = bob.stateData.asInstanceOf[DATA_NORMAL] - val fee = UpdateFee(ByteVector32.Zeroes, 12000) - bob ! fee - awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee), remoteNextHtlcId = 0))) + val fee1 = UpdateFee(ByteVector32.Zeroes, 12000) + bob ! fee1 + val fee2 = UpdateFee(ByteVector32.Zeroes, 14000) + bob ! fee2 + awaitCond(bob.stateData == initialData.copy(commitments = initialData.commitments.copy(remoteChanges = initialData.commitments.remoteChanges.copy(proposed = initialData.commitments.remoteChanges.proposed :+ fee2), remoteNextHtlcId = 0))) } test("recv UpdateFee (when sender is not funder)") { f => @@ -1585,16 +1626,20 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("recv UpdateFee (local/remote feerates are too different)") { f => import f._ + bob.feeEstimator.setFeerate(FeeratesPerKw(1000, 2000, 6000, 12000, 36000, 72000, 140000)) val tx = bob.stateData.asInstanceOf[DATA_NORMAL].commitments.localCommit.publishableTxs.commitTx.tx val sender = TestProbe() - // Alice will use $localFeeRate when performing the checks for update_fee - val localFeeRate = bob.feeEstimator.getFeeratePerKw(bob.feeTargets.commitmentBlockTarget) - assert(localFeeRate === 2000) - val remoteFeeUpdate = 85000 - sender.send(bob, UpdateFee(ByteVector32.Zeroes, remoteFeeUpdate)) + val localFeerate = bob.feeEstimator.getFeeratePerKw(bob.feeTargets.commitmentBlockTarget) + assert(localFeerate === 2000) + val remoteFeerate = 4000 + sender.send(bob, UpdateFee(ByteVector32.Zeroes, remoteFeerate)) + bob2alice.expectNoMsg(250 millis) // we don't close because the commitment doesn't contain any HTLC + + // when we try to add an HTLC, we still disagree on the feerate so we close + alice2bob.send(bob, UpdateAddHtlc(ByteVector32.Zeroes, 0, 2500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket)) val error = bob2alice.expectMsgType[Error] - assert(new String(error.data.toArray) === s"local/remote feerates are too different: remoteFeeratePerKw=$remoteFeeUpdate localFeeratePerKw=$localFeeRate") + assert(new String(error.data.toArray).contains("local/remote feerates are too different")) awaitCond(bob.stateName == CLOSING) // channel should be advertised as down assert(channelUpdateListener.expectMsgType[LocalChannelDown].channelId === bob.stateData.asInstanceOf[DATA_CLOSING].channelId) @@ -2006,10 +2051,15 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with bob2alice.expectNoMsg(500 millis) } - test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different)") { f => + test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different, with HTLCs)") { f => import f._ + + addHtlc(10000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + val sender = TestProbe() - val event = CurrentFeerates(FeeratesPerKw.single(100)) + bob.feeEstimator.setFeerate(FeeratesPerKw.single(14000)) + val event = CurrentFeerates(FeeratesPerKw.single(14000)) sender.send(bob, event) bob2alice.expectMsgType[Error] bob2blockchain.expectMsgType[PublishAsap] // commit tx @@ -2018,6 +2068,24 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with awaitCond(bob.stateName == CLOSING) } + test("recv CurrentFeerate (when fundee, commit-fee/network-fee are very different, without HTLCs)") { f => + import f._ + + val sender = TestProbe() + bob.feeEstimator.setFeerate(FeeratesPerKw.single(1000)) + val event = CurrentFeerates(FeeratesPerKw.single(1000)) + sender.send(bob, event) + bob2alice.expectNoMsg(250 millis) // we don't close because the commitment doesn't contain any HTLC + + // when we try to add an HTLC, we still disagree on the feerate so we close + alice2bob.send(bob, UpdateAddHtlc(ByteVector32.Zeroes, 0, 2500000 msat, randomBytes32, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket)) + bob2alice.expectMsgType[Error] + bob2blockchain.expectMsgType[PublishAsap] // commit tx + bob2blockchain.expectMsgType[PublishAsap] // main delayed + bob2blockchain.expectMsgType[WatchConfirmed] + awaitCond(bob.stateName == CLOSING) + } + test("recv BITCOIN_FUNDING_SPENT (their commit w/ htlc)") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index bbfb001bf..ecbc33102 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -469,9 +469,23 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("handle feerate changes while offline (funder scenario)") { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + + testHandleFeerateFunder(f, shouldClose = true) + } + + test("handle feerate changes while offline without HTLCs (funder scenario)") { f => + testHandleFeerateFunder(f, shouldClose = false) + } + + def testHandleFeerateFunder(f: FixtureParam, shouldClose: Boolean): Unit = { + import f._ // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) @@ -480,32 +494,44 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val aliceStateData = alice.stateData.asInstanceOf[DATA_NORMAL] val aliceCommitTx = aliceStateData.commitments.localCommit.publishableTxs.commitTx.tx - val localFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // alice is funder - sender.send(alice, CurrentFeerates(highFeerate)) - alice2blockchain.expectMsg(PublishAsap(aliceCommitTx)) + sender.send(alice, CurrentFeerates(networkFeerate)) + if (shouldClose) { + alice2blockchain.expectMsg(PublishAsap(aliceCommitTx)) + } else { + alice2blockchain.expectNoMsg() + } } test("handle feerate changes while offline (don't close on mismatch)", Tag("disable-offline-mismatch")) { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) awaitCond(bob.stateName == OFFLINE) val aliceStateData = alice.stateData.asInstanceOf[DATA_NORMAL] - val localFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = aliceStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / alice.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // this time Alice will ignore feerate changes for the offline channel - sender.send(alice, CurrentFeerates(highFeerate)) + sender.send(alice, CurrentFeerates(networkFeerate)) alice2blockchain.expectNoMsg() alice2bob.expectNoMsg() } @@ -546,9 +572,23 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with test("handle feerate changes while offline (fundee scenario)") { f => import f._ - val sender = TestProbe() + + // we only close channels on feerate mismatch if there are HTLCs at risk in the commitment + addHtlc(125000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + + testHandleFeerateFundee(f, shouldClose = true) + } + + test("handle feerate changes while offline without HTLCs (fundee scenario)") { f => + testHandleFeerateFundee(f, shouldClose = false) + } + + def testHandleFeerateFundee(f: FixtureParam, shouldClose: Boolean): Unit = { + import f._ // we simulate a disconnection + val sender = TestProbe() sender.send(alice, INPUT_DISCONNECTED) sender.send(bob, INPUT_DISCONNECTED) awaitCond(alice.stateName == OFFLINE) @@ -557,13 +597,19 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val bobStateData = bob.stateData.asInstanceOf[DATA_NORMAL] val bobCommitTx = bobStateData.commitments.localCommit.publishableTxs.commitTx.tx - val localFeeratePerKw = bobStateData.commitments.localCommit.spec.feeratePerKw - val tooHighFeeratePerKw = ((bob.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch + 6) * localFeeratePerKw).toLong - val highFeerate = FeeratesPerKw.single(tooHighFeeratePerKw) + val currentFeeratePerKw = bobStateData.commitments.localCommit.spec.feeratePerKw + // we receive a feerate update that makes our current feerate too low compared to the network's (we multiply by 1.1 + // to ensure the network's feerate is 10% above our threshold). + val networkFeeratePerKw = (1.1 * currentFeeratePerKw / bob.underlyingActor.nodeParams.onChainFeeConf.maxFeerateMismatch.ratioLow).toLong + val networkFeerate = FeeratesPerKw.single(networkFeeratePerKw) // bob is fundee - sender.send(bob, CurrentFeerates(highFeerate)) - bob2blockchain.expectMsg(PublishAsap(bobCommitTx)) + sender.send(bob, CurrentFeerates(networkFeerate)) + if (shouldClose) { + bob2blockchain.expectMsg(PublishAsap(bobCommitTx)) + } else { + bob2blockchain.expectNoMsg() + } } test("re-send channel_update at reconnection for private channels") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 953e51ae8..03a03a670 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -90,7 +90,6 @@ class SqliteAuditDbSpec extends AnyFunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val n1 = randomKey.publicKey val n2 = randomKey.publicKey val n3 = randomKey.publicKey val n4 = randomKey.publicKey @@ -99,24 +98,36 @@ class SqliteAuditDbSpec extends AnyFunSuite { val c2 = randomBytes32 val c3 = randomBytes32 val c4 = randomBytes32 + val c5 = randomBytes32 + val c6 = randomBytes32 - db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, randomBytes32, c1)) - db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, randomBytes32, c2)) - db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, randomBytes32)), Seq(PaymentRelayed.Part(20000 msat, c4)))) - db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, randomBytes32)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4)))) + db.add(ChannelPaymentRelayed(46000 msat, 44000 msat, randomBytes32, c6, c1)) + db.add(ChannelPaymentRelayed(41000 msat, 40000 msat, randomBytes32, c6, c1)) + db.add(ChannelPaymentRelayed(43000 msat, 42000 msat, randomBytes32, c5, c1)) + db.add(ChannelPaymentRelayed(42000 msat, 40000 msat, randomBytes32, c5, c2)) + db.add(ChannelPaymentRelayed(45000 msat, 40000 msat, randomBytes32, c5, c6)) + db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(25000 msat, c6)), Seq(PaymentRelayed.Part(20000 msat, c4)))) + db.add(TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(46000 msat, c6)), Seq(PaymentRelayed.Part(16000 msat, c2), PaymentRelayed.Part(10000 msat, c4), PaymentRelayed.Part(14000 msat, c4)))) db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 200 sat, "funding")) db.add(NetworkFeePaid(null, n2, c2, Transaction(0, Seq.empty, Seq.empty, 0), 300 sat, "mutual")) db.add(NetworkFeePaid(null, n3, c3, Transaction(0, Seq.empty, Seq.empty, 0), 400 sat, "funding")) db.add(NetworkFeePaid(null, n4, c4, Transaction(0, Seq.empty, Seq.empty, 0), 500 sat, "funding")) - assert(db.stats.toSet === Set( - Stats(channelId = c1, avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat), - Stats(channelId = c2, avgPaymentAmount = 40 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat), - Stats(channelId = c3, avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), - Stats(channelId = c4, avgPaymentAmount = 30 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat) + // NB: we only count a relay fee for the outgoing channel, no the incoming one. + assert(db.stats(0, System.currentTimeMillis + 1).toSet === Set( + Stats(channelId = c1, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c1, direction = "OUT", avgPaymentAmount = 42 sat, paymentCount = 3, relayFee = 4 sat, networkFee = 0 sat), + Stats(channelId = c2, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat), + Stats(channelId = c2, direction = "OUT", avgPaymentAmount = 28 sat, paymentCount = 2, relayFee = 4 sat, networkFee = 500 sat), + Stats(channelId = c3, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), + Stats(channelId = c3, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 400 sat), + Stats(channelId = c4, direction = "IN", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 500 sat), + Stats(channelId = c4, direction = "OUT", avgPaymentAmount = 22 sat, paymentCount = 2, relayFee = 9 sat, networkFee = 500 sat), + Stats(channelId = c5, direction = "IN", avgPaymentAmount = 43 sat, paymentCount = 3, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c5, direction = "OUT", avgPaymentAmount = 0 sat, paymentCount = 0, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c6, direction = "IN", avgPaymentAmount = 39 sat, paymentCount = 4, relayFee = 0 sat, networkFee = 0 sat), + Stats(channelId = c6, direction = "OUT", avgPaymentAmount = 40 sat, paymentCount = 1, relayFee = 5 sat, networkFee = 0 sat) )) } @@ -148,7 +159,7 @@ class SqliteAuditDbSpec extends AnyFunSuite { }) // Test starts here. val start = System.currentTimeMillis - assert(db.stats.nonEmpty) + assert(db.stats(0, start + 1).nonEmpty) val end = System.currentTimeMillis fail(s"took ${end - start}ms") } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index bceb57f93..95e3bac4d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -46,7 +46,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import MultiPartPaymentLifecycleSpec._ - case class FixtureParam(paymentId: UUID, + case class FixtureParam(cfg: SendPaymentConfig, nodeParams: NodeParams, payFsm: TestFSMRef[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data, MultiPartPaymentLifecycle], router: TestProbe, @@ -64,7 +64,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS } val paymentHandler = TestFSMRef(new TestMultiPartPaymentLifecycle().asInstanceOf[MultiPartPaymentLifecycle]) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - withFixture(test.toNoArgTest(FixtureParam(id, nodeParams, paymentHandler, router, sender, childPayFsm, eventListener))) + withFixture(test.toNoArgTest(FixtureParam(cfg, nodeParams, paymentHandler, router, sender, childPayFsm, eventListener))) } test("successful first attempt (single part)") { f => @@ -74,7 +74,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 1, routeParams = Some(routeParams.copy(randomize = true))) sender.send(payFsm, payment) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true)) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) assert(payFsm.stateName === WAIT_FOR_ROUTES) val singleRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) @@ -100,7 +100,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val payment = SendMultiPartPayment(randomBytes32, e, 1200000 msat, expiry, 1, routeParams = Some(routeParams.copy(randomize = false))) sender.send(payFsm, payment) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, 1200000 msat, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true)) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, 1200000 msat, maxFee, routeParams = Some(routeParams.copy(randomize = false)), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) assert(payFsm.stateName === WAIT_FOR_ROUTES) val routes = Seq( @@ -156,7 +156,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.hops, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) // We retry ignoring the failing channel. - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = true)), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))))) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = Some(routeParams.copy(randomize = true)), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ac_1 :: hop_ce :: Nil), Route(600000 msat, hop_ad :: hop_de :: Nil)))) childPayFsm.expectMsgType[SendPaymentToRoute] childPayFsm.expectMsgType[SendPaymentToRoute] @@ -183,13 +183,13 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(RemoteFailure(failedRoute1.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) // When we retry, we ignore the failing node and we let the router know about the remaining pending route. - router.expectMsg(RouteRequest(nodeParams.nodeId, e, failedRoute1.amount, maxFee - failedRoute1.fee, ignore = Ignore(Set(b), Set.empty), pendingPayments = Seq(failedRoute2), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)))) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, failedRoute1.amount, maxFee - failedRoute1.fee, ignore = Ignore(Set(b), Set.empty), pendingPayments = Seq(failedRoute2), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)), paymentContext = Some(cfg.paymentContext))) // The second part fails while we're still waiting for new routes. childPayFsm.send(payFsm, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.hops, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure))))) // We receive a response to our first request, but it's now obsolete: we re-sent a new route request that takes into // account the latest failures. router.send(payFsm, RouteResponse(Seq(Route(failedRoute1.amount, hop_ac_1 :: hop_ce :: Nil)))) - router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, ignore = Ignore(Set(b), Set.empty), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)))) + router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, ignore = Ignore(Set(b), Set.empty), allowMultiPart = true, routeParams = Some(routeParams.copy(randomize = true)), paymentContext = Some(cfg.paymentContext))) awaitCond(payFsm.stateData.asInstanceOf[PaymentProgress].pending.isEmpty) childPayFsm.expectNoMsg(100 millis) @@ -225,7 +225,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab_1, a, b))), pendingPayments = Seq(pendingRoute), allowMultiPart = true, - routeParams = Some(routeParams.copy(randomize = true))) + routeParams = Some(routeParams.copy(randomize = true)), + paymentContext = Some(cfg.paymentContext)) router.expectMsg(expectedRouteRequest) router.send(payFsm, Status.Failure(RouteNotFound)) router.expectMsg(expectedRouteRequest.copy(ignore = Ignore.empty)) @@ -270,7 +271,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS router.send(payFsm, Status.Failure(RouteNotFound)) val result = sender.expectMsgType[PaymentFailed] - assert(result.id === paymentId) + assert(result.id === cfg.id) assert(result.paymentHash === paymentHash) assert(result.failures === Seq(LocalFailure(Nil, RouteNotFound))) @@ -347,9 +348,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS awaitCond(payFsm.stateName === PAYMENT_ABORTED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.fee, randomBytes32, Some(successRoute.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.fee, randomBytes32, Some(successRoute.hops))))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) + assert(result.id === cfg.id) assert(result.paymentHash === paymentHash) assert(result.paymentPreimage === paymentPreimage) assert(result.parts.length === 1 && result.parts.head.id === successId) @@ -375,7 +377,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (childId, route) :: (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toList - childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.fee, randomBytes32, Some(route.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.fee, randomBytes32, Some(route.hops))))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) awaitCond(payFsm.stateName === PAYMENT_SUCCEEDED) sender.watch(payFsm) @@ -401,9 +404,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val partialPayments = pending.map { case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.fee, randomBytes32, Some(route.hops)) } - partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(paymentId, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) + partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) + sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] - assert(result.id === paymentId) + assert(result.id === cfg.id) assert(result.paymentHash === paymentHash) assert(result.paymentPreimage === paymentPreimage) assert(result.parts.toSet === partialPayments.toSet) @@ -434,7 +438,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS } val result = sender.expectMsgType[PaymentFailed] - assert(result.id === paymentId) + assert(result.id === cfg.id) assert(result.paymentHash === paymentHash) assert(result.failures.nonEmpty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala index 668777480..9555bfee5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/NodeRelayerSpec.scala @@ -26,9 +26,8 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.PaymentRequest.{ExtraHop, PaymentRequestFeatures} import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM -import fr.acinq.eclair.payment.relay.{CommandBuffer, NodeRelayer, Origin, Relayer} -import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment -import fr.acinq.eclair.payment.send.PaymentError +import fr.acinq.eclair.payment.relay.{CommandBuffer, NodeRelayer} +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} @@ -281,18 +280,12 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPayment = outgoingPayFSM.expectMsgType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) - // A first downstream HTLC is fulfilled. - val ff1 = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff1) - outgoingPayFSM.expectMsg(ff1) - // We should immediately forward the fulfill upstream. + // A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream. + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) - // A second downstream HTLC is fulfilled. - val ff2 = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff2) - outgoingPayFSM.expectMsg(ff2) - // We should not fulfill a second time upstream. + // If the payment FSM sends us duplicate preimage events, we should not fulfill a second time upstream. + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) commandBuffer.expectNoMsg(100 millis) // Once all the downstream payments have settled, we should emit the relayed event. @@ -315,9 +308,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { val outgoingPayment = outgoingPayFSM.expectMsgType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) val incomingAdd = incomingSinglePart.add commandBuffer.expectMsg(CommandBuffer.CommandSend(incomingAdd.channelId, CMD_FULFILL_HTLC(incomingAdd.id, paymentPreimage, commit = true))) @@ -351,9 +342,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { assert(outgoingPayment.routeParams.isDefined) assert(outgoingPayment.assistedRoutes === hints) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) outgoingPayFSM.send(nodeRelayer, createSuccessEvent(outgoingCfg.id)) @@ -383,9 +372,7 @@ class NodeRelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { assert(outgoingPayment.routeParams.isDefined) assert(outgoingPayment.assistedRoutes === hints) - val ff = createDownstreamFulfill(outgoingPayFSM.ref) - relayer.send(nodeRelayer, ff) - outgoingPayFSM.expectMsg(ff) + outgoingPayFSM.send(nodeRelayer, PreimageReceived(paymentHash, paymentPreimage)) incomingMultiPart.foreach(p => commandBuffer.expectMsg(CommandBuffer.CommandSend(p.add.channelId, CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)))) outgoingPayFSM.send(nodeRelayer, createSuccessEvent(outgoingCfg.id)) @@ -447,11 +434,6 @@ object NodeRelayerSpec { val incomingSinglePart = createValidIncomingPacket(incomingAmount, incomingAmount, CltvExpiry(500000), outgoingAmount, outgoingExpiry) - def createDownstreamFulfill(payFSM: ActorRef): Relayer.ForwardFulfill = { - val origin = Origin.TrampolineRelayed(null, Some(payFSM)) - Relayer.ForwardRemoteFulfill(UpdateFulfillHtlc(randomBytes32, Random.nextInt(100), paymentPreimage), origin, null) - } - def createSuccessEvent(id: UUID): PaymentSent = PaymentSent(id, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(id, outgoingAmount, 10 msat, randomBytes32, None))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 1c1e7f697..3e7117f49 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator} +import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router.{MultiPartParams, NodeHop, RouteParams} import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload} import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} @@ -302,13 +303,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg1.totalAmount === finalAmount + 21000.msat) // Simulate a failure which should trigger a retry. - val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) - multiPartPayFsm.send(initiator, failed) + multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient))))) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg2.totalAmount === finalAmount + 25000.msat) // Simulate a failure that exhausts payment attempts. + val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure)))) multiPartPayFsm.send(initiator, failed) sender.expectMsg(failed) eventListener.expectMsg(failed) @@ -316,6 +317,34 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike eventListener.expectNoMsg(100 millis) } + test("retry trampoline payment and fail (route not found)") { f => + import f._ + val features = PaymentRequestFeatures(VariableLengthOnion.optional, PaymentSecret.optional, BasicMultiPartPayment.optional, TrampolinePayment.optional) + val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, "Some phoenix invoice", features = Some(features)) + val trampolineAttempts = (21000 msat, CltvExpiryDelta(12)) :: (25000 msat, CltvExpiryDelta(24)) :: Nil + val req = SendTrampolinePaymentRequest(finalAmount, pr, b, trampolineAttempts, CltvExpiryDelta(9)) + sender.send(initiator, req) + sender.expectMsgType[UUID] + + val cfg = multiPartPayFsm.expectMsgType[SendPaymentConfig] + val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(msg1.totalAmount === finalAmount + 21000.msat) + // Trampoline node couldn't find a route for the given fee. + val failed = PaymentFailed(cfg.parentId, pr.paymentHash, Seq(RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient)))) + multiPartPayFsm.send(initiator, failed) + multiPartPayFsm.expectMsgType[SendPaymentConfig] + val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(msg2.totalAmount === finalAmount + 25000.msat) + // Trampoline node couldn't find a route even with the increased fee. + multiPartPayFsm.send(initiator, failed) + + val failure = sender.expectMsgType[PaymentFailed] + assert(failure.failures === Seq(LocalFailure(Seq(NodeHop(nodeParams.nodeId, b, nodeParams.expiryDeltaBlocks, 0 msat), NodeHop(b, c, CltvExpiryDelta(24), 25000 msat)), RouteNotFound))) + eventListener.expectMsg(failure) + sender.expectNoMsg(100 millis) + eventListener.expectNoMsg(100 millis) + } + test("forward trampoline payment with pre-defined route") { f => import f._ val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, priv_c.privateKey, "Some invoice") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index f20b150fe..2d4510158 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -62,10 +62,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultExternalId = UUID.randomUUID().toString val defaultPaymentRequest = SendPaymentRequest(defaultAmountMsat, defaultPaymentHash, d, 1, externalId = Some(defaultExternalId)) - def defaultRouteRequest(source: PublicKey, target: PublicKey): RouteRequest = RouteRequest(source, target, defaultAmountMsat, defaultMaxFee) + def defaultRouteRequest(source: PublicKey, target: PublicKey, cfg: SendPaymentConfig): RouteRequest = RouteRequest(source, target, defaultAmountMsat, defaultMaxFee, paymentContext = Some(cfg.paymentContext)) - case class PaymentFixture(id: UUID, - parentId: UUID, + case class PaymentFixture(cfg: SendPaymentConfig, nodeParams: NodeParams, paymentFSM: TestFSMRef[PaymentLifecycle.State, PaymentLifecycle.Data, PaymentLifecycle], routerForwarder: TestProbe, @@ -83,12 +82,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) system.eventStream.subscribe(eventListener.ref, classOf[PaymentEvent]) - PaymentFixture(id, parentId, nodeParams, paymentFSM, routerForwarder, register, sender, monitor, eventListener) + PaymentFixture(cfg, nodeParams, paymentFSM, routerForwarder, register, sender, monitor, eventListener) } test("send to route") { _ => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ // pre-computed route going from A to D val route = Route(defaultAmountMsat, ChannelHop(a, b, update_ab) :: ChannelHop(b, c, update_bc) :: ChannelHop(c, d, update_cd) :: Nil) @@ -112,12 +112,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("send to route (node_id only)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ // pre-computed route going from A to D val request = SendPaymentToRoute(Left(Seq(a, b, c, d)), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, d))) + routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, d), paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -148,13 +149,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("send to route (routing hints)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val recipient = randomKey.publicKey val routingHint = Seq(Seq(ExtraHop(c, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144)))) val request = SendPaymentToRoute(Left(Seq(a, b, c, recipient)), FinalLegacyPayload(defaultAmountMsat, defaultExpiry), routingHint) sender.send(paymentFSM, request) - routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, recipient), routingHint)) + routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, Seq(a, b, c, recipient), routingHint, paymentContext = Some(cfg.paymentContext))) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) routerForwarder.forward(routerFixture.router) @@ -171,6 +173,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (route not found)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(f, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) @@ -186,6 +189,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (route too expensive)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5, routeParams = Some(RouteParams(randomize = false, 100 msat, 0.0, 20, CltvExpiryDelta(2016), None, MultiPartParams(10000 msat, 5)))) sender.send(paymentFSM, request) @@ -200,10 +204,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (unparsable failure)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) - routerForwarder.expectMsg(defaultRouteRequest(a, d)) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData @@ -216,7 +221,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, Relayer.ForwardRemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, randomBytes32), defaultOrigin, UpdateAddHtlc(ByteVector32.Zeroes, 0, defaultAmountMsat, defaultPaymentHash, defaultExpiry, TestConstants.emptyOnionPacket))) // unparsable message // then the payment lifecycle will ask for a new route excluding all intermediate nodes - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignore = Ignore(Set(c), Set.empty))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set(c), Set.empty))) // let's simulate a response by the router with another route sender.send(paymentFSM, RouteResponse(route :: Nil)) @@ -235,13 +240,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (local error)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, _, _, _) = paymentFSM.stateData @@ -250,20 +256,21 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, Status.Failure(AddHtlcFailed(ByteVector32.Zeroes, defaultPaymentHash, ChannelUnavailable(ByteVector32.Zeroes), Local(id, Some(paymentFSM.underlying.self)), None, None))) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, _, _, _) = paymentFSM.stateData @@ -272,19 +279,20 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFailMalformedHtlc(ByteVector32.Zeroes, 0, randomBytes32, FailureMessageCodecs.BADONION)) // then the payment lifecycle will ask for a new route excluding the channel - routerForwarder.expectMsg(defaultRouteRequest(a, d).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_ab, a, b))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) } test("payment failed (TemporaryChannelFailure)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route) = paymentFSM.stateData @@ -299,7 +307,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(update_bc) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(a, d)) + routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) routerForwarder.forward(routerFixture.router) // we allow 2 tries, so we send a 2nd request to the router assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(route.hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(Nil, RouteNotFound) :: Nil) @@ -308,13 +316,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (Update)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData @@ -329,7 +338,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update @@ -349,7 +358,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // but it will still forward the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified_2) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) // this time the router can't find a route: game over @@ -357,9 +366,34 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } + test("payment failed (Update in last attempt)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + import cfg._ + + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 1) + sender.send(paymentFSM, request) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData + register.expectMsg(ForwardShortId(channelId_ab, cmd1)) + + // the node replies with a temporary failure containing the same update as the one we already have (likely a balance issue) + val failure = TemporaryChannelFailure(update_bc) + sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))) + // we should temporarily exclude that channel + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c))) + routerForwarder.expectMsg(update_bc) + + // this was a single attempt payment + sender.expectMsgType[PaymentFailed] + } + test("payment failed (Update in assisted route)") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ // we build an assisted route for channel bc and cd val assistedRoutes = Seq(Seq( @@ -372,7 +406,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(assistedRoutes = assistedRoutes)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes)) routerForwarder.forward(routerFixture.router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData @@ -391,7 +425,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { ExtraHop(b, channelId_bc, update_bc.feeBaseMsat, update_bc.feeProportionalMillionths, channelUpdate_bc_modified.cltvExpiryDelta), ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) )) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(assistedRoutes = assistedRoutes1)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes1)) routerForwarder.forward(routerFixture.router) // router answers with a new route, taking into account the new update @@ -401,16 +435,44 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(cmd2.cltvExpiry > cmd1.cltvExpiry) } + test("payment failed (Update disabled in assisted route)") { routerFixture => + val payFixture = createPaymentLifecycle() + import payFixture._ + import cfg._ + + // we build an assisted route for channel cd + val assistedRoutes = Seq(Seq(ExtraHop(c, channelId_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta))) + val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 1, assistedRoutes = assistedRoutes) + sender.send(paymentFSM, request) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) + + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(assistedRoutes = assistedRoutes)) + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, _) = paymentFSM.stateData + register.expectMsg(ForwardShortId(channelId_ab, cmd1)) + + // we disable the channel + val channelUpdate_cd_disabled = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(42), update_cd.htlcMinimumMsat, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.htlcMaximumMsat.get, enable = false) + val failure = ChannelDisabled(channelUpdate_cd_disabled.messageFlags, channelUpdate_cd_disabled.channelFlags, channelUpdate_cd_disabled) + val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1) + sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)) + + routerForwarder.expectMsg(channelUpdate_cd_disabled) + routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_cd.shortChannelId, c, d))) + } + def testPermanentFailure(router: ActorRef, failure: FailureMessage): Unit = { val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil, _) = paymentFSM.stateData - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d)) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(router) awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) val WaitingForComplete(_, _, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData @@ -420,7 +482,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) - routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_bc, b, c))))) + routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg).copy(ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_bc, b, c))))) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route @@ -441,6 +503,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment succeeded") { routerFixture => val payFixture = createPaymentLifecycle() import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 5) sender.send(paymentFSM, request) @@ -565,6 +628,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("disable database and events") { routerFixture => val payFixture = createPaymentLifecycle(storeInDb = false, publishEvent = false) import payFixture._ + import cfg._ val request = SendPayment(d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), 3) sender.send(paymentFSM, request) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 753f8aa1c..34d7c8b64 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.payment.OutgoingPacket.{buildCommand, buildOnion, buildPa import fr.acinq.eclair.payment.relay.Origin._ import fr.acinq.eclair.payment.relay.Relayer._ import fr.acinq.eclair.payment.relay.{CommandBuffer, Relayer} +import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.PreimageReceived import fr.acinq.eclair.router.Router.{ChannelHop, Ignore, NodeHop} import fr.acinq.eclair.router.{Announcements, _} import fr.acinq.eclair.wire.Onion.{ChannelRelayTlvPayload, FinalLegacyPayload, FinalTlvPayload, PerHopPayload} @@ -474,6 +475,9 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, InsufficientFunds(channelId_bc, origin.amountOut, 100 sat, 0 sat, 0 sat), origin, Some(channelUpdate_bc), None))) assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(TemporaryChannelFailure(channelUpdate_bc))) + sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, FeerateTooDifferent(channelId_bc, 1000, 300), origin, Some(channelUpdate_bc), None))) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(TemporaryChannelFailure(channelUpdate_bc))) + val channelUpdate_bc_disabled = channelUpdate_bc.copy(channelFlags = 2) sender.send(relayer, Status.Failure(AddHtlcFailed(channelId_bc, paymentHash, ChannelUnavailable(channelId_bc), origin, Some(channelUpdate_bc_disabled), None))) assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message.reason === Right(ChannelDisabled(channelUpdate_bc_disabled.messageFlags, channelUpdate_bc_disabled.channelFlags, channelUpdate_bc_disabled))) @@ -550,8 +554,9 @@ class RelayerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { } sender.send(relayer, forwardFulfill) - // the FSM responsible for the payment should receive the fulfill. + // the FSM responsible for the payment should receive the fulfill and emit a preimage event. payFSM.expectMsg(forwardFulfill) + system.actorSelection(relayer.path.child("node-relayer")).tell(PreimageReceived(paymentHash, preimage), payFSM.ref) // the payment should be immediately fulfilled upstream. val upstream1 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 58c489a44..1ceaa867f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -595,6 +595,21 @@ class RouterSpec extends BaseRouterSpec { assert(balances.contains(edge_ab.balance_opt)) assert(edge_ba.balance_opt === None) } + + { + // Private channels should also update the graph when HTLCs are relayed through them. + val balances = Set(33000000 msat, 5000000 msat) + val commitments = CommitmentsSpec.makeCommitments(33000000 msat, 5000000 msat, a, g, announceChannel = false) + sender.send(router, AvailableBalanceChanged(sender.ref, null, channelId_ag, commitments)) + sender.send(router, Symbol("data")) + val data = sender.expectMsgType[Data] + val channel_ag = data.privateChannels(channelId_ag) + assert(Set(channel_ag.meta.balance1, channel_ag.meta.balance2) === balances) + // And the graph should be updated too. + val edge_ag = data.graph.getEdge(ChannelDesc(channelId_ag, a, g)).get + assert(edge_ag.capacity == channel_ag.capacity) + assert(edge_ag.balance_opt === Some(33000000 msat)) + } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala index 1c0f93d4d..d85db2d40 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TransactionsSpec.scala @@ -407,7 +407,7 @@ class TransactionsSpec extends AnyFunSuite with Logging { """[^$]+""").r // this regex extracts htlc direction and amounts val htlcRegex = - """.*HTLC ([a-z]+) amount ([0-9]+).*""".r + """.*HTLC [0-9] ([a-z]+) amount ([0-9]+).*""".r val dustLimit = 546 sat case class TestSetup(name: String, dustLimit: Satoshi, spec: CommitmentSpec, expectedFee: Satoshi) @@ -422,7 +422,7 @@ class TransactionsSpec extends AnyFunSuite with Logging { } }).toSet TestSetup(name, dustLimit, CommitmentSpec(htlcs = htlcs, feeratePerKw = feerate_per_kw.toLong, toLocal = MilliSatoshi(to_local_msat.toLong), toRemote = MilliSatoshi(to_remote_msat.toLong)), Satoshi(fee.toLong)) - }) + }).toSeq // simple non-reg test making sure we are not missing tests assert(tests.size === 15, "there were 15 tests at ec99f893f320e8c88f564c1c8566f3454f0f1f5f") diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala index 21f97cf9c..5b5430443 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/ExtraDirectives.scala @@ -42,6 +42,7 @@ trait ExtraDirectives extends Directives { val channelIdFormParam_opt = "channelId".as[Option[ByteVector32]](sha256HashUnmarshaller) val channelIdsFormParam_opt = "channelIds".as[Option[List[ByteVector32]]](sha256HashesUnmarshaller) val nodeIdFormParam_opt = "nodeId".as[Option[PublicKey]](publicKeyUnmarshaller) + val nodeIdsFormParam_opt = "nodeIds".as[Option[Set[PublicKey]]](publicKeysUnmarshaller) val paymentHashFormParam_opt = "paymentHash".as[Option[ByteVector32]](sha256HashUnmarshaller) val fromFormParam_opt = "from".as[Long] val toFormParam_opt = "to".as[Long] diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala index 9ea740b86..7d1cc3cd0 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/FormParamExtractors.scala @@ -38,6 +38,10 @@ object FormParamExtractors { PublicKey(ByteVector.fromValidHex(str)) } + implicit val publicKeysUnmarshaller: Deserializer[Option[String], Option[Set[PublicKey]]] = strictDeserializer { bin => + bin.split(",").map(str => PublicKey(ByteVector.fromValidHex(str))).toSet + } + implicit val binaryDataUnmarshaller: Deserializer[Option[String], Option[ByteVector]] = strictDeserializer { str => ByteVector.fromValidHex(str) } diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index df7a5f349..9ac7779ac 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -115,7 +115,12 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("peers") { - complete(eclairApi.peersInfo()) + complete(eclairApi.peers()) + } ~ + path("nodes") { + formFields(nodeIdsFormParam_opt) { nodeIds_opt => + complete(eclairApi.nodes(nodeIds_opt)) + } } ~ path("channels") { formFields(nodeIdFormParam_opt) { toRemoteNodeId_opt => @@ -127,9 +132,6 @@ trait Service extends ExtraDirectives with Logging { complete(eclairApi.channelInfo(channelIdentifier)) } } ~ - path("allnodes") { - complete(eclairApi.allNodes()) - } ~ path("allchannels") { complete(eclairApi.allChannels()) } ~ @@ -228,7 +230,9 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("channelstats") { - complete(eclairApi.channelStats()) + formFields(fromFormParam_opt.?, toFormParam_opt.?) { (from_opt, to_opt) => + complete(eclairApi.channelStats(from_opt, to_opt)) + } } ~ path("usablebalances") { complete(eclairApi.usableBalances()) diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 53c5d8bae..216611eda 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -123,7 +123,7 @@ class ApiServiceSpec extends AnyFunSuiteLike with ScalatestRouteTest with RouteT test("'peers' should ask the switchboard for current known peers") { val mockEclair = mock[Eclair] val service = new MockService(mockEclair) - mockEclair.peersInfo()(any[Timeout]) returns Future.successful(List( + mockEclair.peers()(any[Timeout]) returns Future.successful(List( PeerInfo( nodeId = aliceNodeId, state = "CONNECTED", @@ -142,7 +142,7 @@ class ApiServiceSpec extends AnyFunSuiteLike with ScalatestRouteTest with RouteT assert(handled) assert(status == OK) val response = responseAs[String] - mockEclair.peersInfo()(any[Timeout]).wasCalled(once) + mockEclair.peers()(any[Timeout]).wasCalled(once) matchTestJson("peers", response) } }