1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-02-23 14:40:34 +01:00

Add initial support for async payment trampoline relay (#2435)

This commit is contained in:
Richard Myers 2022-09-29 10:08:05 +02:00 committed by GitHub
parent 1b36697802
commit afdaf4619d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 262 additions and 18 deletions

View file

@ -70,6 +70,7 @@ eclair {
option_zeroconf = disabled
keysend = disabled
trampoline_payment_prototype = disabled
async_payment_prototype = disabled
}
// The following section lets you customize features for specific nodes.
// The overrides will be applied on top of the default features settings.
@ -149,6 +150,13 @@ eclair {
// Delay enforcement of channel fee updates
enforcement-delay = 10 minutes
}
async-payments {
// Maximum number of blocks to hold an async payment while waiting to receive a trigger from the receiver
hold-timeout-blocks = 1008
// Number of blocks before the incoming HTLC expires that an async payment must be triggered by the receiver
cancel-safety-before-timeout-blocks = 144
}
}
on-chain-fees {

View file

@ -281,6 +281,12 @@ object Features {
val mandatory = 148
}
// TODO: @remyers update feature bits once spec-ed (currently reserved here: https://github.com/lightning/bolts/pull/989)
case object AsyncPaymentPrototype extends Feature with InitFeature with InvoiceFeature {
val rfcName = "async_payment_prototype"
val mandatory = 152
}
val knownFeatures: Set[Feature] = Set(
DataLossProtect,
InitialRoutingSync,
@ -303,7 +309,8 @@ object Features {
PaymentMetadata,
ZeroConf,
KeySend,
TrampolinePaymentPrototype
TrampolinePaymentPrototype,
AsyncPaymentPrototype
)
// Features may depend on other features, as specified in Bolt 9.
@ -315,7 +322,8 @@ object Features {
AnchorOutputsZeroFeeHtlcTx -> (StaticRemoteKey :: Nil),
RouteBlinding -> (VariableLengthOnion :: Nil),
TrampolinePaymentPrototype -> (PaymentSecret :: Nil),
KeySend -> (VariableLengthOnion :: Nil)
KeySend -> (VariableLengthOnion :: Nil),
AsyncPaymentPrototype -> (TrampolinePaymentPrototype :: Nil)
)
case class FeatureException(message: String) extends IllegalArgumentException(message)

View file

@ -30,7 +30,7 @@ import fr.acinq.eclair.db._
import fr.acinq.eclair.io.MessageRelay.{NoRelay, RelayAll, RelayChannelsOnly, RelayPolicy}
import fr.acinq.eclair.io.PeerConnection
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{RelayFees, RelayParams}
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Announcements.AddressException
import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios}
import fr.acinq.eclair.router.PathFindingExperimentConf
@ -419,6 +419,12 @@ object NodeParams extends Logging {
None
}
val asyncPaymentCancelSafetyBeforeTimeoutBlocks = CltvExpiryDelta(config.getInt("relay.async-payments.cancel-safety-before-timeout-blocks"))
require(asyncPaymentCancelSafetyBeforeTimeoutBlocks >= expiryDelta, "relay.async-payments.cancel-safety-before-timeout-blocks must not be less than channel.expiry-delta-blocks; this may lead to undesired channel closure")
val asyncPaymentHoldTimeoutBlocks = config.getInt("relay.async-payments.hold-timeout-blocks")
require(asyncPaymentHoldTimeoutBlocks >= (asyncPaymentCancelSafetyBeforeTimeoutBlocks + expiryDelta).toInt, "relay.async-payments.hold-timeout-blocks must not be less than relay.async-payments.cancel-safety-before-timeout-blocks + channel.expiry-delta-blocks; otherwise it will have no effect")
NodeParams(
nodeKeyManager = nodeKeyManager,
channelKeyManager = channelKeyManager,
@ -488,7 +494,8 @@ object NodeParams extends Logging {
publicChannelFees = getRelayFees(config.getConfig("relay.fees.public-channels")),
privateChannelFees = getRelayFees(config.getConfig("relay.fees.private-channels")),
minTrampolineFees = getRelayFees(config.getConfig("relay.fees.min-trampoline")),
enforcementDelay = FiniteDuration(config.getDuration("relay.fees.enforcement-delay").getSeconds, TimeUnit.SECONDS)
enforcementDelay = FiniteDuration(config.getDuration("relay.fees.enforcement-delay").getSeconds, TimeUnit.SECONDS),
asyncPaymentsParams = AsyncPaymentsParams(asyncPaymentHoldTimeoutBlocks, asyncPaymentCancelSafetyBeforeTimeoutBlocks)
),
db = database,
autoReconnect = config.getBoolean("auto-reconnect"),

View file

@ -120,6 +120,8 @@ case class PaymentMetadataReceived(paymentHash: ByteVector32, paymentMetadata: B
case class PaymentSettlingOnChain(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) extends PaymentEvent
case class WaitingToRelayPayment(remoteNodeId: PublicKey, paymentHash: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) extends PaymentEvent
sealed trait PaymentFailure {
// @formatter:off
def amount: MilliSatoshi

View file

@ -23,6 +23,7 @@ import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import com.softwaremill.quicklens.ModifyPimp
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC}
import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
@ -39,7 +40,7 @@ import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, UInt64, nodeFee, randomBytes32}
import fr.acinq.eclair.{BlockHeight, CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, UInt64, nodeFee, randomBytes32}
import java.util.UUID
import scala.collection.immutable.Queue
@ -54,12 +55,15 @@ object NodeRelay {
sealed trait Command
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command
case object Stop extends Command
case object RelayAsyncPayment extends Command
case object CancelAsyncPayment extends Command
private case class WrappedMultiPartExtraPaymentReceived(mppExtraReceived: MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]) extends Command
private case class WrappedMultiPartPaymentFailed(mppFailed: MultiPartPaymentFSM.MultiPartPaymentFailed) extends Command
private case class WrappedMultiPartPaymentSucceeded(mppSucceeded: MultiPartPaymentFSM.MultiPartPaymentSucceeded) extends Command
private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command
private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
private case class WrappedCurrentBlockHeight(currentBlockHeight: BlockHeight) extends Command
// @formatter:on
trait OutgoingPaymentFactory {
@ -200,10 +204,45 @@ class NodeRelay private(nodeParams: NodeParams,
rejectPayment(upstream, Some(failure))
stopping()
case None =>
doSend(upstream, nextPayload, nextPacket)
if (nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype)) {
waitForTrigger(upstream, nextPayload, nextPacket)
} else {
doSend(upstream, nextPayload, nextPacket)
}
}
}
private def waitForTrigger(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket: OnionRoutingPacket): Behavior[Command] = {
context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})")
// a trigger must be received before waiting more than `holdTimeoutBlocks`
val timeoutBlock: BlockHeight = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
// a trigger must be received `cancelSafetyBeforeTimeoutBlocks` before the incoming payment cltv expiry
val safetyBlock: BlockHeight = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
val messageAdapter = context.messageAdapter[CurrentBlockHeight](cbc => WrappedCurrentBlockHeight(cbc.blockHeight))
context.system.eventStream ! EventStream.Subscribe[CurrentBlockHeight](messageAdapter)
// TODO: send the WaitingToRelayPayment message to an actor that watches for the payment receiver to come back online before sending the RelayAsyncPayment message
context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash))
Behaviors.receiveMessagePartial {
case WrappedCurrentBlockHeight(blockHeight) if blockHeight >= safetyBlock =>
context.log.warn(s"rejecting async payment at block $blockHeight; was not triggered ${nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout} safety blocks before upstream cltv expiry at ${upstream.expiryIn}")
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
case WrappedCurrentBlockHeight(blockHeight) if blockHeight >= timeoutBlock =>
context.log.warn(s"rejecting async payment at block $blockHeight; was not triggered after waiting ${nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks} blocks")
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
case WrappedCurrentBlockHeight(blockHeight) =>
Behaviors.same
case CancelAsyncPayment =>
context.log.warn(s"payment sender canceled a waiting async payment")
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
case RelayAsyncPayment =>
doSend(upstream, nextPayload, nextPacket)
}
}
private def doSend(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket: OnionRoutingPacket): Behavior[Command] = {
context.log.info(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})")
relay(upstream, nextPayload, nextPacket)

View file

@ -29,7 +29,7 @@ import fr.acinq.eclair.channel._
import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.payment._
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Logs, MilliSatoshi, NodeParams}
import fr.acinq.eclair.{CltvExpiryDelta, Logs, MilliSatoshi, NodeParams}
import grizzled.slf4j.Logging
import scala.concurrent.Promise
@ -126,10 +126,13 @@ object Relayer extends Logging {
require(feeProportionalMillionths >= 0.0, "feeProportionalMillionths must be nonnegative")
}
case class AsyncPaymentsParams(holdTimeoutBlocks: Int, cancelSafetyBeforeTimeout: CltvExpiryDelta)
case class RelayParams(publicChannelFees: RelayFees,
privateChannelFees: RelayFees,
minTrampolineFees: RelayFees,
enforcementDelay: FiniteDuration) {
enforcementDelay: FiniteDuration,
asyncPaymentsParams: AsyncPaymentsParams) {
def defaultFees(announceChannel: Boolean): RelayFees = {
if (announceChannel) {
publicChannelFees

View file

@ -181,6 +181,9 @@ object OnionPaymentPayloadTlv {
/** Pre-image included by the sender of a payment in case of a donation */
case class KeySend(paymentPreimage: ByteVector32) extends OnionPaymentPayloadTlv
/** Only included for intermediate trampoline nodes that should wait before forwarding this payment */
case class AsyncPayment() extends OnionPaymentPayloadTlv
}
object PaymentOnion {
@ -301,6 +304,8 @@ object PaymentOnion {
val paymentMetadata = records.get[PaymentMetadata].map(_.data)
val invoiceFeatures = records.get[InvoiceFeatures].map(_.features)
val invoiceRoutingInfo = records.get[InvoiceRoutingInfo].map(_.extraHops)
// The following fields are only included in the async payment case.
val isAsyncPayment: Boolean = records.get[AsyncPayment].isDefined
}
object Standard {
@ -331,6 +336,11 @@ object PaymentOnion {
).flatten
Standard(TlvStream(tlvs))
}
/** Create a standard trampoline inner payload instructing the trampoline node to wait for a trigger before sending an async payment. */
def createNodeRelayForAsyncPayment(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey): Standard = {
Standard(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), OutgoingNodeId(nextNodeId), AsyncPayment()))
}
}
}
}
@ -475,6 +485,8 @@ object PaymentOnionCodecs {
private val keySend: Codec[KeySend] = variableSizeBytesLong(varintoverflow, bytes32).as[KeySend]
private val asyncPayment: Codec[AsyncPayment] = variableSizeBytesLong(varintoverflow, provide(AsyncPayment())).as[AsyncPayment]
private val onionTlvCodec = discriminated[OnionPaymentPayloadTlv].by(varint)
.typecase(UInt64(2), amountToForward)
.typecase(UInt64(4), outgoingCltv)
@ -489,6 +501,7 @@ object PaymentOnionCodecs {
.typecase(UInt64(66098), outgoingNodeId)
.typecase(UInt64(66099), invoiceRoutingInfo)
.typecase(UInt64(66100), trampolineOnion)
.typecase(UInt64(181324718L), asyncPayment)
.typecase(UInt64(5482373484L), keySend)
val perHopPayloadCodec: Codec[TlvStream[OnionPaymentPayloadTlv]] = TlvCodecs.lengthPrefixedTlvStream[OnionPaymentPayloadTlv](onionTlvCodec).complete

View file

@ -26,7 +26,7 @@ import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyMa
import fr.acinq.eclair.io.MessageRelay.RelayAll
import fr.acinq.eclair.io.{Peer, PeerConnection}
import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig
import fr.acinq.eclair.payment.relay.Relayer.{RelayFees, RelayParams}
import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams}
import fr.acinq.eclair.router.Graph.WeightRatios
import fr.acinq.eclair.router.PathFindingExperimentConf
import fr.acinq.eclair.router.Router.{MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries}
@ -143,7 +143,8 @@ object TestConstants {
minTrampolineFees = RelayFees(
feeBase = 548000 msat,
feeProportionalMillionths = 30),
enforcementDelay = 10 minutes),
enforcementDelay = 10 minutes,
asyncPaymentsParams = AsyncPaymentsParams(1008, CltvExpiryDelta(144))),
db = TestDatabases.inMemoryDb(),
autoReconnect = false,
initialRandomReconnectDelay = 5 seconds,
@ -203,7 +204,7 @@ object TestConstants {
relayPolicy = RelayAll,
timeout = 1 minute
),
purgeInvoicesInterval = None,
purgeInvoicesInterval = None
)
def channelParams: LocalParams = Peer.makeChannelParams(
@ -286,7 +287,8 @@ object TestConstants {
minTrampolineFees = RelayFees(
feeBase = 548000 msat,
feeProportionalMillionths = 30),
enforcementDelay = 10 minutes),
enforcementDelay = 10 minutes,
asyncPaymentsParams = AsyncPaymentsParams(1008, CltvExpiryDelta(144))),
db = TestDatabases.inMemoryDb(),
autoReconnect = false,
initialRandomReconnectDelay = 5 seconds,

View file

@ -22,13 +22,16 @@ import akka.actor.typed.ActorRef
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.ActorContext
import akka.actor.typed.scaladsl.adapter._
import com.softwaremill.quicklens.ModifyPimp
import com.typesafe.config.ConfigFactory
import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto}
import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
import fr.acinq.eclair.Features.{BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Invoice.BasicEdge
import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream
import fr.acinq.eclair.payment._
@ -40,9 +43,9 @@ import fr.acinq.eclair.router.Router.RouteRequest
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, UInt64, randomBytes, randomBytes32, randomKey}
import org.scalatest.Outcome
import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, UInt64, randomBytes, randomBytes32, randomKey}
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, Tag}
import scodec.bits.HexStringSyntax
import java.util.UUID
@ -82,7 +85,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
}
override def withFixture(test: OneArgTest): Outcome = {
val nodeParams = TestConstants.Bob.nodeParams.copy(multiPartPaymentExpiry = 5 seconds)
val nodeParams = TestConstants.Bob.nodeParams
.modify(_.multiPartPaymentExpiry).setTo(5 seconds)
.modify(_.features).setToIf(test.tags.contains("async_payments"))(Features(AsyncPaymentPrototype -> Optional))
.modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires
val router = TestProbe[Any]("router")
val register = TestProbe[Any]("register")
val eventListener = TestProbe[PaymentEvent]("event-listener")
@ -289,7 +295,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
import f._
val expiryIn = CltvExpiry(500000)
val expiryOut = CltvExpiry(300000) // not ok (chain heigh = 400000)
val expiryOut = CltvExpiry(300000) // not ok (chain height = 400000)
val p = createValidIncomingPacket(2000000 msat, 2000000 msat, expiryIn, 1000000 msat, expiryOut)
val (nodeRelayer, _) = f.createNodeRelay(p)
nodeRelayer ! NodeRelay.Relay(p)
@ -323,6 +329,146 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
register.expectNoMessage(100 millis)
}
test("fail to relay when not triggered before the hold timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish a block height at the timeout height
assert(asyncTimeoutHeight(nodeParams) < asyncSafetyHeight(incomingAsyncPayment, nodeParams))
system.eventStream ! EventStream.Publish(CurrentBlockHeight(asyncTimeoutHeight(nodeParams)))
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment when triggered while waiting", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// publish a block height one block before the safety interval before the current incoming payment expires (and before the timeout height)
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
system.eventStream ! EventStream.Publish(CurrentBlockHeight(asyncSafetyHeight(incomingAsyncPayment, nodeParams) - 1))
// send trigger to forward the payment
nodeRelayer ! NodeRelay.RelayAsyncPayment
// upstream payment relayed
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingAsyncPayment.map(_.add)))
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
validateOutgoingPayment(outgoingPayment)
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
// A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream.
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
// Once all the downstream payments have settled, we should emit the relayed event.
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.toSet == incomingAsyncPayment.map(i => PaymentRelayed.Part(i.add.amountMsat, i.add.channelId)).toSet)
assert(relayEvent.outgoing.nonEmpty)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay when not triggered before the incoming expiry safety timeout", Tag("async_payments"), Tag("long_hold_timeout")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p))
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment
// publish block height at the cancel-safety-before-timeout-block threshold before the current incoming payment expiry
assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams))
system.eventStream ! EventStream.Publish(CurrentBlockHeight(asyncSafetyHeight(incomingAsyncPayment, nodeParams)))
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure), commit = true))
}
register.expectNoMessage(100 millis)
}
test("fail to relay payment when canceled by sender before timeout", Tag("async_payments")) { f =>
import f._
val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p))
// wait until the NodeRelay is waiting for the trigger
eventListener.expectMessageType[WaitingToRelayPayment]
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger
// fail the payment if waiting when payment sender sends cancel message
nodeRelayer ! NodeRelay.CancelAsyncPayment
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure), commit = true))
}
register.expectNoMessage(100 millis)
}
test("relay the payment immediately when the async payment feature is disabled") { f =>
import f._
assert(!nodeParams.features.hasFeature(AsyncPaymentPrototype))
val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head)
incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p))
// upstream payment relayed
val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig]
validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingAsyncPayment.map(_.add)))
val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment]
validateOutgoingPayment(outgoingPayment)
// those are adapters for pay-fsm messages
val nodeRelayerAdapters = outgoingPayment.replyTo
// A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream.
nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage)
incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]]
assert(fwd.channelId == p.add.channelId)
assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true))
}
// Once all the downstream payments have settled, we should emit the relayed event.
nodeRelayerAdapters ! createSuccessEvent()
val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed]
validateRelayEvent(relayEvent)
assert(relayEvent.incoming.toSet == incomingAsyncPayment.map(i => PaymentRelayed.Part(i.add.amountMsat, i.add.channelId)).toSet)
assert(relayEvent.outgoing.nonEmpty)
parent.expectMessageType[NodeRelayer.RelayComplete]
register.expectNoMessage(100 millis)
}
test("fail to relay when fees are insufficient (single-part)") { f =>
import f._
@ -723,12 +869,19 @@ object NodeRelayerSpec {
val incomingAmount = 5000000 msat
val incomingSecret = randomBytes32()
val incomingMultiPart = Seq(
val incomingMultiPart: Seq[NodeRelayPacket] = Seq(
createValidIncomingPacket(2000000 msat, incomingAmount, CltvExpiry(500000), outgoingAmount, outgoingExpiry),
createValidIncomingPacket(2000000 msat, incomingAmount, CltvExpiry(499999), outgoingAmount, outgoingExpiry),
createValidIncomingPacket(1000000 msat, incomingAmount, CltvExpiry(499999), outgoingAmount, outgoingExpiry)
)
val incomingSinglePart = createValidIncomingPacket(incomingAmount, incomingAmount, CltvExpiry(500000), outgoingAmount, outgoingExpiry)
val incomingAsyncPayment: Seq[NodeRelayPacket] = incomingMultiPart.map(p => p.copy(innerPayload = IntermediatePayload.NodeRelay.Standard.createNodeRelayForAsyncPayment(p.innerPayload.amountToForward, p.innerPayload.outgoingCltv, outgoingNodeId)))
def asyncTimeoutHeight(nodeParams: NodeParams): BlockHeight =
nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
def asyncSafetyHeight(paymentPackets: Seq[NodeRelayPacket], nodeParams: NodeParams): BlockHeight =
(paymentPackets.map(_.outerPayload.expiry).min - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
def createSuccessEvent(): PaymentSent =
PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None)))

View file

@ -343,4 +343,13 @@ class PaymentOnionSpec extends AnyFunSuite {
}
}
test("encode/decode empty AsyncPayment TLV") {
val tlvs = TlvStream[OnionPaymentPayloadTlv](AsyncPayment())
val bin = hex"06 fe0acecbae00"
val encoded = perHopPayloadCodec.encode(tlvs).require.bytes
assert(encoded == bin)
assert(perHopPayloadCodec.decode(bin.bits).require.value == tlvs)
}
}