mirror of
https://github.com/ACINQ/eclair.git
synced 2025-02-24 14:50:46 +01:00
Improve CustomCommitmentsPlugin methods (#1613)
* Add node params and logger parameters to plugin htlc methods * Refactor helper function to make it available for plugins
This commit is contained in:
parent
848b433836
commit
ed61b577df
3 changed files with 32 additions and 26 deletions
|
@ -16,6 +16,7 @@
|
|||
|
||||
package fr.acinq.eclair
|
||||
|
||||
import akka.event.LoggingAdapter
|
||||
import fr.acinq.bitcoin.ByteVector32
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.eclair.channel.Origin
|
||||
|
@ -51,12 +52,12 @@ trait CustomCommitmentsPlugin extends PluginParams {
|
|||
* expire. If your plugin defines non-standard HTLCs, and they need to be automatically failed, they should be
|
||||
* returned by this method.
|
||||
*/
|
||||
def getIncomingHtlcs: Seq[IncomingHtlc]
|
||||
def getIncomingHtlcs(nodeParams: NodeParams, log: LoggingAdapter): Seq[IncomingHtlc]
|
||||
|
||||
/**
|
||||
* Outgoing HTLC sets that are still pending may either succeed or fail: we need to watch them to properly forward the
|
||||
* result upstream to preserve channels. If you have non-standard HTLCs that may be in this situation, they should be
|
||||
* returned by this method.
|
||||
*/
|
||||
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]]
|
||||
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc], nodeParams: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]]
|
||||
}
|
|
@ -65,9 +65,9 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
|
|||
// result upstream to preserve channels.
|
||||
val brokenHtlcs: BrokenHtlcs = {
|
||||
val channels = listLocalChannels(nodeParams.db.channels)
|
||||
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs }.flatten
|
||||
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs(nodeParams, log) }.flatten
|
||||
val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey) ++ nonStandardIncomingHtlcs
|
||||
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn) }.flatten.toMap
|
||||
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn, nodeParams, log) }.flatten.toMap
|
||||
val relayedOut: Map[Origin, Set[(ByteVector32, Long)]] = getHtlcsRelayedOut(channels, htlcsIn) ++ nonStandardRelayedOutHtlcs
|
||||
|
||||
val notRelayed = htlcsIn.filterNot(htlcIn => relayedOut.keys.exists(origin => matchesOrigin(htlcIn.add, origin)))
|
||||
|
@ -329,9 +329,24 @@ object PostRestartHtlcCleaner {
|
|||
private def isPendingUpstream(channelId: ByteVector32, htlcId: Long, htlcsIn: Seq[IncomingHtlc]): Boolean =
|
||||
htlcsIn.exists(htlc => htlc.add.channelId == channelId && htlc.add.id == htlcId)
|
||||
|
||||
def groupByOrigin(htlcsOut: Seq[(Origin, ByteVector32, Long)], htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] =
|
||||
htlcsOut
|
||||
.groupBy { case (origin, _, _) => origin }
|
||||
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
|
||||
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
|
||||
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
|
||||
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
|
||||
// instant whereas the uncooperative close of the downstream channel will take time.
|
||||
.filterKeys {
|
||||
case _: Origin.Local => true
|
||||
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
|
||||
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
|
||||
}
|
||||
.toMap
|
||||
|
||||
/** @return pending outgoing HTLCs, grouped by their upstream origin. */
|
||||
private def getHtlcsRelayedOut(channels: Seq[HasCommitments], htlcsIn: Seq[IncomingHtlc])(implicit log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = {
|
||||
channels
|
||||
val htlcsOut = channels
|
||||
.flatMap { c =>
|
||||
// Filter out HTLCs that will never reach the blockchain or have already been timed-out on-chain.
|
||||
val htlcsToIgnore: Set[Long] = c match {
|
||||
|
@ -361,18 +376,7 @@ object PostRestartHtlcCleaner {
|
|||
}
|
||||
c.commitments.originChannels.collect { case (outgoingHtlcId, origin) if !htlcsToIgnore.contains(outgoingHtlcId) => (origin, c.channelId, outgoingHtlcId) }
|
||||
}
|
||||
.groupBy { case (origin, _, _) => origin }
|
||||
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
|
||||
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
|
||||
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
|
||||
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
|
||||
// instant whereas the uncooperative close of the downstream channel will take time.
|
||||
.filterKeys {
|
||||
case _: Origin.Local => true
|
||||
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
|
||||
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
|
||||
}
|
||||
.toMap
|
||||
groupByOrigin(htlcsOut, htlcsIn)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.UUID
|
|||
|
||||
import akka.Done
|
||||
import akka.actor.ActorRef
|
||||
import akka.event.LoggingAdapter
|
||||
import akka.testkit.TestProbe
|
||||
import fr.acinq.bitcoin.Crypto.PublicKey
|
||||
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Satoshi}
|
||||
|
@ -555,9 +556,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
|
|||
val nonRelayedHtlc2In = buildHtlcIn(1L, channelId_ab_1, relayedPaymentHash)
|
||||
|
||||
val pluginParams = new CustomCommitmentsPlugin {
|
||||
def name = "test with incoming HTLC from remote"
|
||||
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
|
||||
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
|
||||
override def name = "test with incoming HTLC from remote"
|
||||
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
|
||||
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
|
||||
}
|
||||
|
||||
val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
|
||||
|
@ -602,9 +603,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
|
|||
val nonRelayedHtlcIn = buildHtlcIn(1L, channelId_ab_2, relayedPaymentHash)
|
||||
|
||||
val pluginParams = new CustomCommitmentsPlugin {
|
||||
def name = "test with outgoing HTLC to remote"
|
||||
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
|
||||
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
|
||||
override def name = "test with outgoing HTLC to remote"
|
||||
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
|
||||
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
|
||||
}
|
||||
|
||||
val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
|
||||
|
@ -628,9 +629,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
|
|||
val relayedHtlc1In = buildHtlcIn(0L, channelId_ab_1, trampolineRelayedPaymentHash)
|
||||
|
||||
val pluginParams = new CustomCommitmentsPlugin {
|
||||
def name = "test with incoming HTLC from remote"
|
||||
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
|
||||
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
|
||||
override def name = "test with incoming HTLC from remote"
|
||||
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
|
||||
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
|
||||
}
|
||||
|
||||
val cmd1 = CMD_FAIL_HTLC(id = 0L, reason = Left(ByteVector.empty), replyTo_opt = None)
|
||||
|
|
Loading…
Add table
Reference in a new issue