1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-02-24 14:50:46 +01:00

Improve CustomCommitmentsPlugin methods (#1613)

* Add node params and logger parameters to plugin htlc methods
* Refactor helper function to make it available for plugins
This commit is contained in:
Anton Kumaigorodski 2020-12-02 18:58:05 +02:00 committed by GitHub
parent 848b433836
commit ed61b577df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 26 deletions

View file

@ -16,6 +16,7 @@
package fr.acinq.eclair
import akka.event.LoggingAdapter
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.channel.Origin
@ -51,12 +52,12 @@ trait CustomCommitmentsPlugin extends PluginParams {
* expire. If your plugin defines non-standard HTLCs, and they need to be automatically failed, they should be
* returned by this method.
*/
def getIncomingHtlcs: Seq[IncomingHtlc]
def getIncomingHtlcs(nodeParams: NodeParams, log: LoggingAdapter): Seq[IncomingHtlc]
/**
* Outgoing HTLC sets that are still pending may either succeed or fail: we need to watch them to properly forward the
* result upstream to preserve channels. If you have non-standard HTLCs that may be in this situation, they should be
* returned by this method.
*/
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]]
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc], nodeParams: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]]
}

View file

@ -65,9 +65,9 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
// result upstream to preserve channels.
val brokenHtlcs: BrokenHtlcs = {
val channels = listLocalChannels(nodeParams.db.channels)
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs }.flatten
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs(nodeParams, log) }.flatten
val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey) ++ nonStandardIncomingHtlcs
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn) }.flatten.toMap
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn, nodeParams, log) }.flatten.toMap
val relayedOut: Map[Origin, Set[(ByteVector32, Long)]] = getHtlcsRelayedOut(channels, htlcsIn) ++ nonStandardRelayedOutHtlcs
val notRelayed = htlcsIn.filterNot(htlcIn => relayedOut.keys.exists(origin => matchesOrigin(htlcIn.add, origin)))
@ -329,9 +329,24 @@ object PostRestartHtlcCleaner {
private def isPendingUpstream(channelId: ByteVector32, htlcId: Long, htlcsIn: Seq[IncomingHtlc]): Boolean =
htlcsIn.exists(htlc => htlc.add.channelId == channelId && htlc.add.id == htlcId)
def groupByOrigin(htlcsOut: Seq[(Origin, ByteVector32, Long)], htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] =
htlcsOut
.groupBy { case (origin, _, _) => origin }
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
// instant whereas the uncooperative close of the downstream channel will take time.
.filterKeys {
case _: Origin.Local => true
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
}
.toMap
/** @return pending outgoing HTLCs, grouped by their upstream origin. */
private def getHtlcsRelayedOut(channels: Seq[HasCommitments], htlcsIn: Seq[IncomingHtlc])(implicit log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = {
channels
val htlcsOut = channels
.flatMap { c =>
// Filter out HTLCs that will never reach the blockchain or have already been timed-out on-chain.
val htlcsToIgnore: Set[Long] = c match {
@ -361,18 +376,7 @@ object PostRestartHtlcCleaner {
}
c.commitments.originChannels.collect { case (outgoingHtlcId, origin) if !htlcsToIgnore.contains(outgoingHtlcId) => (origin, c.channelId, outgoingHtlcId) }
}
.groupBy { case (origin, _, _) => origin }
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
// instant whereas the uncooperative close of the downstream channel will take time.
.filterKeys {
case _: Origin.Local => true
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
}
.toMap
groupByOrigin(htlcsOut, htlcsIn)
}
/**

View file

@ -20,6 +20,7 @@ import java.util.UUID
import akka.Done
import akka.actor.ActorRef
import akka.event.LoggingAdapter
import akka.testkit.TestProbe
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Satoshi}
@ -555,9 +556,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val nonRelayedHtlc2In = buildHtlcIn(1L, channelId_ab_1, relayedPaymentHash)
val pluginParams = new CustomCommitmentsPlugin {
def name = "test with incoming HTLC from remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
override def name = "test with incoming HTLC from remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
}
val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
@ -602,9 +603,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val nonRelayedHtlcIn = buildHtlcIn(1L, channelId_ab_2, relayedPaymentHash)
val pluginParams = new CustomCommitmentsPlugin {
def name = "test with outgoing HTLC to remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
override def name = "test with outgoing HTLC to remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
}
val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
@ -628,9 +629,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val relayedHtlc1In = buildHtlcIn(0L, channelId_ab_1, trampolineRelayedPaymentHash)
val pluginParams = new CustomCommitmentsPlugin {
def name = "test with incoming HTLC from remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
override def name = "test with incoming HTLC from remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
}
val cmd1 = CMD_FAIL_HTLC(id = 0L, reason = Left(ByteVector.empty), replyTo_opt = None)