diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 3f0bb3f6c..dac2c69c2 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -57,6 +57,7 @@ eclair { // Do not enable option_anchor_outputs unless you really know what you're doing. option_anchor_outputs = disabled option_anchors_zero_fee_htlc_tx = optional + option_route_blinding = disabled option_shutdown_anysegwit = optional option_dual_fund = disabled option_onion_messages = optional diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index e6e260b0b..9d9383cba 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -221,6 +221,11 @@ object Features { val mandatory = 22 } + case object RouteBlinding extends Feature with InitFeature with NodeFeature with InvoiceFeature { + val rfcName = "option_route_blinding" + val mandatory = 24 + } + case object ShutdownAnySegwit extends Feature with InitFeature with NodeFeature { val rfcName = "option_shutdown_anysegwit" val mandatory = 26 @@ -285,6 +290,7 @@ object Features { StaticRemoteKey, AnchorOutputs, AnchorOutputsZeroFeeHtlcTx, + RouteBlinding, ShutdownAnySegwit, DualFunding, OnionMessages, @@ -303,6 +309,7 @@ object Features { BasicMultiPartPayment -> (PaymentSecret :: Nil), AnchorOutputs -> (StaticRemoteKey :: Nil), AnchorOutputsZeroFeeHtlcTx -> (StaticRemoteKey :: Nil), + RouteBlinding -> (VariableLengthOnion :: Nil), TrampolinePaymentPrototype -> (PaymentSecret :: Nil), KeySend -> (VariableLengthOnion :: Nil) ) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala index b053b61d5..0f0c09268 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala @@ -29,6 +29,7 @@ import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Feature import scodec.bits.ByteVector import java.util.UUID +import scala.concurrent.duration.FiniteDuration /** * Created by PM on 20/05/2016. @@ -183,7 +184,7 @@ final case class CMD_ADD_HTLC(replyTo: ActorRef, amount: MilliSatoshi, paymentHa sealed trait HtlcSettlementCommand extends HasOptionalReplyToCommand { def id: Long } final case class CMD_FULFILL_HTLC(id: Long, r: ByteVector32, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand final case class CMD_FAIL_HTLC(id: Long, reason: Either[ByteVector, FailureMessage], commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand -final case class CMD_FAIL_MALFORMED_HTLC(id: Long, onionHash: ByteVector32, failureCode: Int, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand +final case class CMD_FAIL_MALFORMED_HTLC(id: Long, onionHash: ByteVector32, failureCode: Int, delay_opt: Option[FiniteDuration] = None, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand final case class CMD_UPDATE_FEE(feeratePerKw: FeeratePerKw, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HasOptionalReplyToCommand final case class CMD_SIGN(replyTo_opt: Option[ActorRef] = None) extends HasOptionalReplyToCommand diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index 0bb17af97..8ee2fe1fa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -399,14 +399,20 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder, val } case Event(c: CMD_FAIL_MALFORMED_HTLC, d: DATA_NORMAL) => - Commitments.sendFailMalformed(d.commitments, c) match { - case Right((commitments1, fail)) => - if (c.commit) self ! CMD_SIGN() - context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.shortIds, commitments1)) - handleCommandSuccess(c, d.copy(commitments = commitments1)) sending fail - case Left(cause) => - // we acknowledge the command right away in case of failure - handleCommandError(cause, c).acking(d.channelId, c) + c.delay_opt match { + case Some(delay) => + log.debug("delaying CMD_FAIL_MALFORMED_HTLC with id={} for {}", c.id, delay) + context.system.scheduler.scheduleOnce(delay, self, c.copy(delay_opt = None)) + stay() + case None => Commitments.sendFailMalformed(d.commitments, c) match { + case Right((commitments1, fail)) => + if (c.commit) self ! CMD_SIGN() + context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.shortIds, commitments1)) + handleCommandSuccess(c, d.copy(commitments = commitments1)) sending fail + case Left(cause) => + // we acknowledge the command right away in case of failure + handleCommandError(cause, c).acking(d.channelId, c) + } } case Event(fail: UpdateFailHtlc, d: DATA_NORMAL) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 11c87ae57..940c3da55 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, UInt64, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, UInt64, randomBytes32, randomKey} import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} @@ -77,8 +77,7 @@ object IncomingPaymentPacket { private[payment] def decryptEncryptedRecipientData(add: UpdateAddHtlc, privateKey: PrivateKey, payload: TlvStream[OnionPaymentPayloadTlv], encryptedRecipientData: ByteVector): Either[FailureMessage, DecodedEncryptedRecipientData] = { if (add.blinding_opt.isDefined && payload.get[OnionPaymentPayloadTlv.BlindingPoint].isDefined) { - // TODO: return an unparseable error - Left(InvalidOnionPayload(UInt64(12), 0)) + Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) } else { add.blinding_opt.orElse(payload.get[OnionPaymentPayloadTlv.BlindingPoint].map(_.publicKey)) match { case Some(blinding) => RouteBlindingEncryptedDataCodecs.decode(privateKey, blinding, encryptedRecipientData) match { @@ -86,15 +85,13 @@ object IncomingPaymentPacket { // There are two possibilities in this case: // - the blinding point is invalid: the sender or the previous node is buggy or malicious // - the encrypted data is invalid: the sender, the previous node or the recipient must be buggy or malicious - // TODO: return an unparseable error - Left(InvalidOnionPayload(UInt64(12), 0)) + Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case Right(decoded) => Right(DecodedEncryptedRecipientData(decoded.tlvs, decoded.nextBlinding)) } case None => // The sender is trying to use route blinding, but we didn't receive the blinding point used to derive // the decryption key. The sender or the previous peer is buggy or malicious. - // TODO: return an unparseable error - Left(InvalidOnionPayload(UInt64(12), 0)) + Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) } } } @@ -110,7 +107,7 @@ object IncomingPaymentPacket { * @param privateKey this node's private key * @return whether the payment is to be relayed or if our node is the final recipient (or an error). */ - def decrypt(add: UpdateAddHtlc, privateKey: PrivateKey)(implicit log: LoggingAdapter): Either[FailureMessage, IncomingPaymentPacket] = { + def decrypt(add: UpdateAddHtlc, privateKey: PrivateKey, features: Features[Feature])(implicit log: LoggingAdapter): Either[FailureMessage, IncomingPaymentPacket] = { // We first derive the decryption key used to peel the onion. val outerOnionDecryptionKey = add.blinding_opt match { case Some(blinding) => Sphinx.RouteBlinding.derivePrivateKey(privateKey, blinding) @@ -119,25 +116,25 @@ object IncomingPaymentPacket { decryptOnion(add.paymentHash, outerOnionDecryptionKey, add.onionRoutingPacket).flatMap { case DecodedOnionPacket(payload, Some(nextPacket)) => payload.get[OnionPaymentPayloadTlv.EncryptedRecipientData] match { - case Some(OnionPaymentPayloadTlv.EncryptedRecipientData(encryptedRecipientData)) => - decryptEncryptedRecipientData(add, privateKey, payload, encryptedRecipientData).flatMap { + case Some(_) if !features.hasFeature(Features.RouteBlinding) => Left(InvalidOnionPayload(UInt64(10), 0)) + case Some(encrypted) => + decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket) } - case None if add.blinding_opt.isDefined => Left(InvalidOnionPayload(UInt64(12), 0)) + case None if add.blinding_opt.isDefined => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case None => IntermediatePayload.ChannelRelay.Standard.validate(payload).left.map(_.failureMessage).map { payload => ChannelRelayPacket(add, payload, nextPacket) } } case DecodedOnionPacket(payload, None) => payload.get[OnionPaymentPayloadTlv.EncryptedRecipientData] match { - case Some(OnionPaymentPayloadTlv.EncryptedRecipientData(encryptedRecipientData)) => - decryptEncryptedRecipientData(add, privateKey, payload, encryptedRecipientData).flatMap { - case DecodedEncryptedRecipientData(blindedPayload, _) => - // TODO: receiving through blinded routes is not supported yet. - FinalPayload.Blinded.validate(payload, blindedPayload).left.map(_.failureMessage).flatMap(_ => Left(InvalidOnionPayload(UInt64(12), 0))) + case Some(_) if !features.hasFeature(Features.RouteBlinding) => Left(InvalidOnionPayload(UInt64(10), 0)) + case Some(encrypted) => + decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { + case DecodedEncryptedRecipientData(blindedPayload, _) => validateBlindedFinalPayload(add, payload, blindedPayload) } - case None if add.blinding_opt.isDefined => Left(InvalidOnionPayload(UInt64(12), 0)) + case None if add.blinding_opt.isDefined => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case None => // We check if the payment is using trampoline: if it is, we may not be the final recipient. payload.get[OnionPaymentPayloadTlv.TrampolineOnion] match { @@ -146,7 +143,7 @@ object IncomingPaymentPacket { // blinding point and use it to derive the decryption key for the blinded trampoline onion. decryptOnion(add.paymentHash, privateKey, trampolinePacket).flatMap { case DecodedOnionPacket(innerPayload, Some(next)) => validateNodeRelay(add, payload, innerPayload, next) - case DecodedOnionPacket(innerPayload, None) => validateFinalPayload(add, payload, innerPayload) + case DecodedOnionPacket(innerPayload, None) => validateTrampolineFinalPayload(add, payload, innerPayload) } case None => validateFinalPayload(add, payload) } @@ -156,10 +153,9 @@ object IncomingPaymentPacket { private def validateBlindedChannelRelayPayload(add: UpdateAddHtlc, payload: TlvStream[OnionPaymentPayloadTlv], blindedPayload: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey, nextPacket: OnionRoutingPacket): Either[FailureMessage, ChannelRelayPacket] = { IntermediatePayload.ChannelRelay.Blinded.validate(payload, blindedPayload, nextBlinding).left.map(_.failureMessage).flatMap { - // TODO: return an unparseable error - case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionPayload(UInt64(12), 0)) - case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionPayload(UInt64(12), 0)) - case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionPayload(UInt64(12), 0)) + case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload => Right(ChannelRelayPacket(add, payload, nextPacket)) } } @@ -172,7 +168,17 @@ object IncomingPaymentPacket { } } - private def validateFinalPayload(add: UpdateAddHtlc, outerPayload: TlvStream[OnionPaymentPayloadTlv], innerPayload: TlvStream[OnionPaymentPayloadTlv]): Either[FailureMessage, FinalPacket] = { + private def validateBlindedFinalPayload(add: UpdateAddHtlc, payload: TlvStream[OnionPaymentPayloadTlv], blindedPayload: TlvStream[RouteBlindingEncryptedDataTlv]): Either[FailureMessage, FinalPacket] = { + FinalPayload.Blinded.validate(payload, blindedPayload).left.map(_.failureMessage).flatMap { + case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + // TODO: receiving through blinded routes is not supported yet. + case _ => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + } + } + + private def validateTrampolineFinalPayload(add: UpdateAddHtlc, outerPayload: TlvStream[OnionPaymentPayloadTlv], innerPayload: TlvStream[OnionPaymentPayloadTlv]): Either[FailureMessage, FinalPacket] = { // The outer payload cannot use route blinding, but the inner payload may (but it's not supported yet). FinalPayload.Standard.validate(outerPayload).left.map(_.failureMessage).flatMap { outerPayload => FinalPayload.Standard.validate(innerPayload).left.map(_.failureMessage).flatMap { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 736b47f5f..a961e1852 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -23,6 +23,7 @@ import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} @@ -32,6 +33,8 @@ import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Logs, NodeParams, TimestampSecond, channel, nodeFee} import java.util.UUID +import scala.concurrent.duration.DurationLong +import scala.util.Random object ChannelRelay { @@ -77,10 +80,21 @@ object ChannelRelay { } } - def translateRelayFailure(originHtlcId: Long, fail: HtlcResult.Fail): channel.Command with channel.HtlcSettlementCommand = { + def translateRelayFailure(originHtlcId: Long, fail: HtlcResult.Fail, relayPacket_opt: Option[IncomingPaymentPacket.ChannelRelayPacket]): channel.Command with channel.HtlcSettlementCommand = { fail match { case f: HtlcResult.RemoteFail => CMD_FAIL_HTLC(originHtlcId, Left(f.fail.reason), commit = true) - case f: HtlcResult.RemoteFailMalformed => CMD_FAIL_MALFORMED_HTLC(originHtlcId, f.fail.onionHash, f.fail.failureCode, commit = true) + case f: HtlcResult.RemoteFailMalformed => relayPacket_opt match { + case Some(IncomingPaymentPacket.ChannelRelayPacket(add, payload: IntermediatePayload.ChannelRelay.Blinded, _)) => + // Bolt 2: + // - if it is part of a blinded route: + // - MUST return an `update_fail_malformed_htlc` error using the `invalid_onion_blinding` failure code, with the `sha256_of_onion` of the onion it received. + // - If its onion payload contains `current_blinding_point`: + // - SHOULD add a random delay before sending `update_fail_malformed_htlc`. + val delay_opt = payload.records.get[OnionPaymentPayloadTlv.BlindingPoint].map(_ => Random.nextLong(1000).millis) + CMD_FAIL_MALFORMED_HTLC(originHtlcId, Sphinx.hash(add.onionRoutingPacket), InvalidOnionBlinding(ByteVector32.Zeroes).code, delay_opt, commit = true) + case _ => + CMD_FAIL_MALFORMED_HTLC(originHtlcId, f.fail.onionHash, f.fail.failureCode, commit = true) + } case _: HtlcResult.OnChainFail => CMD_FAIL_HTLC(originHtlcId, Right(PermanentChannelFailure), commit = true) case HtlcResult.ChannelFailureBeforeSigned => CMD_FAIL_HTLC(originHtlcId, Right(PermanentChannelFailure), commit = true) case f: HtlcResult.DisconnectedBeforeSigned => CMD_FAIL_HTLC(originHtlcId, Right(TemporaryChannelFailure(f.channelUpdate)), commit = true) @@ -154,7 +168,7 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(RES_ADD_SETTLED(o: Origin.ChannelRelayedHot, _, fail: HtlcResult.Fail)) => context.log.info("relaying fail to upstream") Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel) - val cmd = translateRelayFailure(o.originHtlcId, fail) + val cmd = translateRelayFailure(o.originHtlcId, fail, Some(r)) safeSendAndStop(o.originChannelId, cmd) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala index 4ace89917..fe9f6bf32 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.payment.Monitoring.Tags import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentFailed, PaymentSent} import fr.acinq.eclair.transactions.DirectedHtlc.outgoing import fr.acinq.eclair.wire.protocol.{FailureMessage, TemporaryNodeFailure, UpdateAddHtlc} -import fr.acinq.eclair.{CustomCommitmentsPlugin, Logs, MilliSatoshiLong, NodeParams, TimestampMilli} +import fr.acinq.eclair.{CustomCommitmentsPlugin, Feature, Features, Logs, MilliSatoshiLong, NodeParams, TimestampMilli} import scala.concurrent.Promise import scala.util.Try @@ -67,7 +67,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial val brokenHtlcs: BrokenHtlcs = { val channels = listLocalChannels(nodeParams.db.channels) val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs(nodeParams, log) }.flatten - val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey) ++ nonStandardIncomingHtlcs + val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey, nodeParams.features) ++ nonStandardIncomingHtlcs val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn, nodeParams, log) }.flatten.toMap val relayedOut: Map[Origin, Set[(ByteVector32, Long)]] = getHtlcsRelayedOut(channels, htlcsIn) ++ nonStandardRelayedOutHtlcs @@ -235,7 +235,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial case Origin.ChannelRelayedCold(originChannelId, originHtlcId, _, _) => log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing 1 HTLC upstream") Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment() - val cmd = ChannelRelay.translateRelayFailure(originHtlcId, fail) + val cmd = ChannelRelay.translateRelayFailure(originHtlcId, fail, None) PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, originChannelId, cmd) case Origin.TrampolineRelayedCold(origins) => log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing ${origins.length} HTLCs upstream") @@ -334,14 +334,14 @@ object PostRestartHtlcCleaner { } /** @return incoming HTLCs that have been *cross-signed* (that potentially have been relayed). */ - private def getIncomingHtlcs(channels: Seq[PersistentChannelData], paymentsDb: IncomingPaymentsDb, privateKey: PrivateKey)(implicit log: LoggingAdapter): Seq[IncomingHtlc] = { + private def getIncomingHtlcs(channels: Seq[PersistentChannelData], paymentsDb: IncomingPaymentsDb, privateKey: PrivateKey, features: Features[Feature])(implicit log: LoggingAdapter): Seq[IncomingHtlc] = { // We are interested in incoming HTLCs, that have been *cross-signed* (otherwise they wouldn't have been relayed). // They signed it first, so the HTLC will first appear in our commitment tx, and later on in their commitment when // we subsequently sign it. That's why we need to look in *their* commitment with direction=OUT. channels .flatMap(_.commitments.remoteCommit.spec.htlcs) .collect(outgoing) - .map(IncomingPaymentPacket.decrypt(_, privateKey)) + .map(IncomingPaymentPacket.decrypt(_, privateKey, features)) .collect(decryptedIncomingHtlcs(paymentsDb)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index c7341e7fa..1d8eb709b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -33,7 +33,8 @@ import fr.acinq.eclair.{Logs, MilliSatoshi, NodeParams} import grizzled.slf4j.Logging import scala.concurrent.Promise -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{DurationLong, FiniteDuration} +import scala.util.Random /** * Created by PM on 01/02/2017. @@ -62,7 +63,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym def receive: Receive = { case RelayForward(add) => log.debug(s"received forwarding request for htlc #${add.id} from channelId=${add.channelId}") - IncomingPaymentPacket.decrypt(add, nodeParams.privateKey) match { + IncomingPaymentPacket.decrypt(add, nodeParams.privateKey, nodeParams.features) match { case Right(p: IncomingPaymentPacket.FinalPacket) => log.debug(s"forwarding htlc #${add.id} to payment-handler") paymentHandler forward p @@ -77,7 +78,13 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym } case Left(badOnion: BadOnion) => log.warning(s"couldn't parse onion: reason=${badOnion.message}") - val cmdFail = CMD_FAIL_MALFORMED_HTLC(add.id, badOnion.onionHash, badOnion.code, commit = true) + val delay_opt = badOnion match { + // We are the introduction point of a blinded path: we add a non-negligible delay to make it look like it + // could come from a downstream node. + case InvalidOnionBlinding(_) if add.blinding_opt.isEmpty => Some(500.millis + Random.nextLong(1500).millis) + case _ => None + } + val cmdFail = CMD_FAIL_MALFORMED_HTLC(add.id, badOnion.onionHash, badOnion.code, delay_opt, commit = true) log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} reason=malformed onionHash=${cmdFail.onionHash} failureCode=${cmdFail.failureCode}") PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) case Left(failure) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/CommandCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/CommandCodecs.scala index 0328e79d2..6c79f8366 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/CommandCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/CommandCodecs.scala @@ -23,6 +23,8 @@ import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.failureMessageCodec import scodec.Codec import scodec.codecs._ +import scala.concurrent.duration.FiniteDuration + object CommandCodecs { val cmdFulfillCodec: Codec[CMD_FULFILL_HTLC] = @@ -41,6 +43,8 @@ object CommandCodecs { (("id" | int64) :: ("onionHash" | bytes32) :: ("failureCode" | uint16) :: + // No need to delay commands after a restart, we've been offline which already created a random delay. + ("delay_opt" | provide(Option.empty[FiniteDuration])) :: ("commit" | provide(false)) :: ("replyTo_opt" | provide(Option.empty[ActorRef]))).as[CMD_FAIL_MALFORMED_HTLC] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/FailureMessage.scala index 1e9612c22..e9bbb5cff 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/FailureMessage.scala @@ -49,6 +49,7 @@ case object RequiredNodeFeatureMissing extends Perm with Node { def message = "p case class InvalidOnionVersion(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion version was not understood by the processing node" } case class InvalidOnionHmac(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion HMAC was incorrect when it reached the processing node" } case class InvalidOnionKey(onionHash: ByteVector32) extends BadOnion with Perm { def message = "ephemeral key was unparsable by the processing node" } +case class InvalidOnionBlinding(onionHash: ByteVector32) extends BadOnion with Perm { def message = "the blinded onion didn't match the processing node's requirements" } case class TemporaryChannelFailure(update: ChannelUpdate) extends Update { def message = s"channel ${update.shortChannelId} is currently unavailable" } case object PermanentChannelFailure extends Perm { def message = "channel is permanently unavailable" } case object RequiredChannelFeatureMissing extends Perm { def message = "channel requires features not present in the onion" } @@ -120,6 +121,7 @@ object FailureMessageCodecs { .typecase(21, provide(ExpiryTooFar)) .typecase(PERM | 22, (("tag" | varint) :: ("offset" | uint16)).as[InvalidOnionPayload]) .typecase(23, provide(PaymentTimeout)) + .typecase(BADONION | PERM | 24, sha256.as[InvalidOnionBlinding]) // TODO: @t-bast: once fully spec-ed, these should probably include a NodeUpdate and use a different ID. // We should update Phoenix and our nodes at the same time, or first update Phoenix to understand both new and old errors. .typecase(NODE | 51, provide(TrampolineFeeInsufficient)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index e586eebe6..42c3d467e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -145,6 +145,19 @@ object OnionPaymentPayloadTlv { /** Id of the next node. */ case class OutgoingNodeId(nodeId: PublicKey) extends OnionPaymentPayloadTlv + /** + * Route blinding lets the recipient provide some encrypted data for each intermediate node in the blinded part of the + * route. This data cannot be decrypted or modified by the sender and usually contains information to locate the next + * node without revealing it to the sender. + */ + case class EncryptedRecipientData(data: ByteVector) extends OnionPaymentPayloadTlv + + /** Blinding ephemeral public key for the introduction node of a blinded route. */ + case class BlindingPoint(publicKey: PublicKey) extends OnionPaymentPayloadTlv + + /** Total amount in blinded multi-part payments. */ + case class TotalAmount(totalAmount: MilliSatoshi) extends OnionPaymentPayloadTlv + /** * When payment metadata is included in a Bolt 11 invoice, we should send it as-is to the recipient. * This lets recipients generate invoices without having to store anything on their side until the invoice is paid. @@ -168,16 +181,6 @@ object OnionPaymentPayloadTlv { /** Pre-image included by the sender of a payment in case of a donation */ case class KeySend(paymentPreimage: ByteVector32) extends OnionPaymentPayloadTlv - - /** - * Route blinding lets the recipient provide some encrypted data for each intermediate node in the blinded part of the - * route. This data cannot be decrypted or modified by the sender and usually contains information to locate the next - * node without revealing it to the sender. - */ - case class EncryptedRecipientData(data: ByteVector) extends OnionPaymentPayloadTlv - - /** Blinding ephemeral public key for the introduction node of a blinded route. */ - case class BlindingPoint(publicKey: PublicKey) extends OnionPaymentPayloadTlv } object PaymentOnion { @@ -264,9 +267,17 @@ object PaymentOnion { object Blinded { def validate(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey): Either[InvalidTlvPayload, Blinded] = { - if (records.get[AmountToForward].nonEmpty) return Left(ForbiddenTlv(UInt64(2))) - if (records.get[OutgoingCltv].nonEmpty) return Left(ForbiddenTlv(UInt64(4))) if (records.get[EncryptedRecipientData].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) + // Bolt 4: MUST return an error if the payload contains other tlv fields than `encrypted_recipient_data` and `current_blinding_point`. + if (records.unknown.nonEmpty) return Left(ForbiddenTlv(records.unknown.head.tag)) + records.records.find { + case _: EncryptedRecipientData => false + case _: BlindingPoint => false + case _ => true + } match { + case Some(_) => return Left(ForbiddenTlv(UInt64(0))) + case None => // no forbidden tlv found + } BlindedRouteData.validatePaymentRelayData(blindedRecords).map(blindedRecords => Blinded(records, blindedRecords, nextBlinding)) } } @@ -388,7 +399,7 @@ object PaymentOnion { */ case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv]) extends FinalPayload { override val amount = records.get[AmountToForward].get.amount - override val totalAmount = amount // TODO: get from total_amount_msat tlv + override val totalAmount = records.get[TotalAmount].map(_.totalAmount).getOrElse(amount) override val expiry = records.get[OutgoingCltv].get.cltv val pathId_opt = blindedRecords.get[RouteBlindingEncryptedDataTlv.PathId].map(_.data) val paymentConstraints = blindedRecords.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get @@ -403,6 +414,20 @@ object PaymentOnion { def validate(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, Blinded] = { if (records.get[AmountToForward].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) if (records.get[OutgoingCltv].isEmpty) return Left(MissingRequiredTlv(UInt64(4))) + if (records.get[EncryptedRecipientData].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) + // Bolt 4: MUST return an error if the payload contains other tlv fields than `encrypted_recipient_data`, `current_blinding_point`, `amt_to_forward`, `outgoing_cltv_value` and `total_amount_msat`. + if (records.unknown.nonEmpty) return Left(ForbiddenTlv(records.unknown.head.tag)) + records.records.find { + case _: AmountToForward => false + case _: OutgoingCltv => false + case _: EncryptedRecipientData => false + case _: BlindingPoint => false + case _: TotalAmount => false + case _ => true + } match { + case Some(_) => return Left(ForbiddenTlv(UInt64(0))) + case None => // no forbidden tlv found + } BlindedRouteData.validPaymentRecipientData(blindedRecords).map(blindedRecords => Blinded(records, blindedRecords)) } } @@ -447,6 +472,8 @@ object PaymentOnionCodecs { private val paymentMetadata: Codec[PaymentMetadata] = variableSizeBytesLong(varintoverflow, "payment_metadata" | bytes).as[PaymentMetadata] + private val totalAmount: Codec[TotalAmount] = ("total_amount_msat" | ltmillisatoshi).as[TotalAmount] + private val invoiceFeatures: Codec[InvoiceFeatures] = variableSizeBytesLong(varintoverflow, bytes).as[InvoiceFeatures] private val invoiceRoutingInfo: Codec[InvoiceRoutingInfo] = variableSizeBytesLong(varintoverflow, list(listOfN(uint8, Bolt11Invoice.Codecs.extraHopCodec))).as[InvoiceRoutingInfo] @@ -463,6 +490,7 @@ object PaymentOnionCodecs { .typecase(UInt64(10), encryptedRecipientData) .typecase(UInt64(12), blindingPoint) .typecase(UInt64(16), paymentMetadata) + .typecase(UInt64(18), totalAmount) // Types below aren't specified - use cautiously when deploying (be careful with backwards-compatibility). .typecase(UInt64(66097), invoiceFeatures) .typecase(UInt64(66098), outgoingNodeId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 8b91689a8..4e779dbc6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -78,6 +78,7 @@ object BlindedRouteData { if (records.get[PathId].isDefined) return Left(ForbiddenTlv(UInt64(6))) if (records.get[PaymentRelay].isDefined) return Left(ForbiddenTlv(UInt64(10))) if (records.get[PaymentConstraints].isDefined) return Left(ForbiddenTlv(UInt64(12))) + if (records.get[AllowedFeatures].exists(!_.features.isEmpty)) return Left(ForbiddenTlv(UInt64(14))) // we don't support custom blinded relay features yet Right(records) } @@ -92,6 +93,7 @@ object BlindedRouteData { if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) + if (records.get[AllowedFeatures].exists(!_.features.isEmpty)) return Left(ForbiddenTlv(UInt64(14))) // we don't support custom blinded relay features yet Right(records) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index 1bf74af7a..ea05d2274 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -99,8 +99,8 @@ class FeaturesSpec extends AnyFunSuite { for ((testCase, valid) <- testCases) { if (valid) { - assert(validateFeatureGraph(Features(testCase)) == None) - assert(validateFeatureGraph(Features(testCase.bytes)) == None) + assert(validateFeatureGraph(Features(testCase)).isEmpty) + assert(validateFeatureGraph(Features(testCase.bytes)).isEmpty) } else { assert(validateFeatureGraph(Features(testCase)).nonEmpty) assert(validateFeatureGraph(Features(testCase.bytes)).nonEmpty) @@ -235,7 +235,7 @@ class FeaturesSpec extends AnyFunSuite { hex"" -> Features.empty, hex"0100" -> Features(VariableLengthOnion -> Mandatory), hex"028a8a" -> Features(DataLossProtect -> Optional, InitialRoutingSync -> Optional, ChannelRangeQueries -> Optional, VariableLengthOnion -> Optional, ChannelRangeQueriesExtended -> Optional, PaymentSecret -> Optional, BasicMultiPartPayment -> Optional), - hex"09004200" -> Features(Map(VariableLengthOnion -> Optional, PaymentSecret -> Mandatory, ShutdownAnySegwit -> Optional), Set(UnknownFeature(24))), + hex"09004200" -> Features(Map(VariableLengthOnion -> Optional, PaymentSecret -> Mandatory, RouteBlinding -> Mandatory, ShutdownAnySegwit -> Optional)), hex"80010080000000000000000000000000000000000000" -> Features(Map.empty[Feature, FeatureSupport], Set(UnknownFeature(151), UnknownFeature(160), UnknownFeature(175))) ) @@ -264,7 +264,7 @@ class FeaturesSpec extends AnyFunSuite { val features = fromConfiguration(conf) assert(features.toByteVector == hex"028a8a") assert(Features(hex"028a8a") == features) - assert(validateFeatureGraph(features) == None) + assert(validateFeatureGraph(features).isEmpty) assert(features.hasFeature(DataLossProtect, Some(Optional))) assert(features.hasFeature(InitialRoutingSync, Some(Optional))) assert(features.hasFeature(ChannelRangeQueries, Some(Optional))) @@ -287,7 +287,7 @@ class FeaturesSpec extends AnyFunSuite { val features = fromConfiguration(conf) assert(features.toByteVector == hex"068a") assert(Features(hex"068a") == features) - assert(validateFeatureGraph(features) == None) + assert(validateFeatureGraph(features).isEmpty) assert(features.hasFeature(DataLossProtect, Some(Optional))) assert(features.hasFeature(InitialRoutingSync, Some(Optional))) assert(!features.hasFeature(InitialRoutingSync, Some(Mandatory))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index a4ef930b3..505e626bc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -1771,6 +1771,20 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with localChanges = initialState.commitments.localChanges.copy(initialState.commitments.localChanges.proposed :+ fail)))) } + test("recv CMD_FAIL_MALFORMED_HTLC (with delay)") { f => + import f._ + val (_, htlc) = addHtlc(50000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + + // actual test begins + val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] + bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM | 24, delay_opt = Some(50 millis)) + val fail = bob2alice.expectMsgType[UpdateFailMalformedHtlc] + awaitCond(bob.stateData == initialState.copy( + commitments = initialState.commitments.copy( + localChanges = initialState.commitments.localChanges.copy(initialState.commitments.localChanges.proposed :+ fail)))) + } + test("recv CMD_FAIL_MALFORMED_HTLC (unknown htlc id)") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt11InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt11InvoiceSpec.scala index 1ef0ff3e6..03929d2f7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt11InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt11InvoiceSpec.scala @@ -213,7 +213,7 @@ class Bolt11InvoiceSpec extends AnyFunSuite { assert(invoice.prefix == "lntbs") assert(invoice.amount_opt.contains(250000000 msat)) assert(invoice.paymentHash.bytes == hex"4ffb6e9eabe93a88eb927ead43ae74172d9fbc3d858cede1e80871a5eb8bd863") - assert(invoice.features == Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional )) + assert(invoice.features == Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional)) assert(invoice.createdAt == TimestampSecond(1660836433)) assert(invoice.nodeId == PublicKey(hex"02e899d99662f2e64ea0eeaecb53c4628fa40a22d7185076e42e8a3d67fcb7b8e6")) assert(invoice.description == Left("yolo")) @@ -492,8 +492,8 @@ class Bolt11InvoiceSpec extends AnyFunSuite { Features(bin" 0000110000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), Features(bin" 0000100000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), Features(bin" 0010000000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), + Features(bin" 000001000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), // those are useful for nonreg testing of the areSupported method (which needs to be updated with every new supported mandatory bit) - Features(bin" 000001000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = false), Features(bin" 000100000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), Features(bin"00000010000000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), Features(bin"00001000000000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = false) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index aeaf59e05..b12f08401 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -75,7 +75,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { def testPeelOnion(packet_b: OnionRoutingPacket): Unit = { val add_b = UpdateAddHtlc(randomBytes32(), 0, amount_ab, paymentHash, expiry_ab, packet_b, None) - val Right(relay_b@ChannelRelayPacket(add_b2, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey) + val Right(relay_b@ChannelRelayPacket(add_b2, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(add_b2 == add_b) assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward == amount_bc) @@ -85,7 +85,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Right(relay_c@ChannelRelayPacket(add_c2, payload_c, packet_d)) = decrypt(add_c, priv_c.privateKey) + val Right(relay_c@ChannelRelayPacket(add_c2, payload_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(add_c2 == add_c) assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward == amount_cd) @@ -95,7 +95,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) val add_d = UpdateAddHtlc(randomBytes32(), 2, amount_cd, paymentHash, expiry_cd, packet_d, None) - val Right(relay_d@ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey) + val Right(relay_d@ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) assert(add_d2 == add_d) assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward == amount_de) @@ -105,7 +105,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) val add_e = UpdateAddHtlc(randomBytes32(), 2, amount_de, paymentHash, expiry_de, packet_e, None) - val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey) + val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(add_e2 == add_e) assert(payload_e.amount == finalAmount) assert(payload_e.totalAmount == finalAmount) @@ -137,7 +137,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // let's peel the onion val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion, None) - val Right(FinalPacket(add_b2, payload_b)) = decrypt(add_b, priv_b.privateKey) + val Right(FinalPacket(add_b2, payload_b)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(add_b2 == add_b) assert(payload_b.amount == finalAmount) assert(payload_b.totalAmount == finalAmount) @@ -161,12 +161,12 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(firstExpiry == expiry_ab) val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) - val Right(ChannelRelayPacket(add_b2, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey) + val Right(ChannelRelayPacket(add_b2, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) assert(add_b2 == add_b) assert(payload_b == IntermediatePayload.ChannelRelay.Standard(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc)) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Right(NodeRelayPacket(add_c2, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey) + val Right(NodeRelayPacket(add_c2, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(add_c2 == add_c) assert(outer_c.amount == amount_bc) assert(outer_c.totalAmount == amount_bc) @@ -184,7 +184,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(amount_d == amount_cd) assert(expiry_d == expiry_cd) val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) - val Right(NodeRelayPacket(add_d2, outer_d, inner_d, packet_e)) = decrypt(add_d, priv_d.privateKey) + val Right(NodeRelayPacket(add_d2, outer_d, inner_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) assert(add_d2 == add_d) assert(outer_d.amount == amount_cd) assert(outer_d.totalAmount == amount_cd) @@ -202,7 +202,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(amount_e == amount_de) assert(expiry_e == expiry_de) val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None) - val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey) + val Right(FinalPacket(add_e2, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) assert(add_e2 == add_e) assert(payload_e == FinalPayload.Standard(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(paymentSecret, finalAmount * 3), OnionPaymentPayloadTlv.PaymentMetadata(hex"010203")))) } @@ -225,10 +225,10 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(firstExpiry == expiry_ab) val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Right(NodeRelayPacket(_, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey) + val Right(NodeRelayPacket(_, outer_c, inner_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(outer_c.amount == amount_bc) assert(outer_c.totalAmount == amount_bc) assert(outer_c.expiry == expiry_bc) @@ -245,7 +245,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(amount_d == amount_cd) assert(expiry_d == expiry_cd) val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) - val Right(NodeRelayPacket(_, outer_d, inner_d, _)) = decrypt(add_d, priv_d.privateKey) + val Right(NodeRelayPacket(_, outer_d, inner_d, _)) = decrypt(add_d, priv_d.privateKey, Features.empty) assert(outer_d.amount == amount_cd) assert(outer_d.totalAmount == amount_cd) assert(outer_d.expiry == expiry_cd) @@ -269,7 +269,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt when the onion is invalid") { val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reverse), None) - val Left(failure) = decrypt(add, priv_b.privateKey) + val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } @@ -277,78 +277,78 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, None)) val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reverse))) val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) - val Left(failure) = decrypt(add_c, priv_c.privateKey) + val Left(failure) = decrypt(add_c, priv_c.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt when payment hash doesn't match associated data") { val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash.reverse, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) - val Left(failure) = decrypt(add, priv_b.privateKey) + val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt at the final node when amount has been modified by next-to-last node") { val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet, None) - val Left(failure) = decrypt(add, priv_b.privateKey) + val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(firstAmount - 100.msat)) } test("fail to decrypt at the final node when expiry has been modified by next-to-last node") { val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet, None) - val Left(failure) = decrypt(add, priv_b.privateKey) + val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))) } test("fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline") { val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey) - val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) + val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) // c forwards the trampoline payment to d. val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey) + val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). val invalidTotalAmount = amount_de + 100.msat val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, invalidTotalAmount, expiry_de, randomBytes32(), packet_e)) - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey) + val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(invalidTotalAmount)) } test("fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline") { val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey) - val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) + val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) // c forwards the trampoline payment to d. val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) - val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey) + val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). val invalidExpiry = expiry_de - CltvExpiryDelta(12) val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, invalidExpiry, randomBytes32(), packet_e)) - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey) + val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(invalidExpiry)) } test("fail to decrypt at intermediate trampoline node when amount is invalid") { val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) // A trampoline relay is very similar to a final node: it can validate that the HTLC amount matches the onion outer amount. - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc - 100.msat, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey) + val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc - 100.msat, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(amount_bc - 100.msat)) } test("fail to decrypt at intermediate trampoline node when expiry is invalid") { val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) - val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) // A trampoline relay is very similar to a final node: it can validate that the HTLC expiry matches the onion outer expiry. - val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc - CltvExpiryDelta(12), packet_c, None), priv_c.privateKey) + val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc - CltvExpiryDelta(12), packet_c, None), priv_c.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(expiry_bc - CltvExpiryDelta(12))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/CommandCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/CommandCodecsSpec.scala index 592dc1d54..b38cee448 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/CommandCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/CommandCodecsSpec.scala @@ -49,30 +49,30 @@ class CommandCodecsSpec extends AnyFunSuite { } test("backward compatibility") { - val data32 = randomBytes32() val data123 = randomBytes(123) val legacyCmdFulfillCodec = - (("id" | int64) :: + ("id" | int64) :: ("r" | bytes32) :: - ("commit" | provide(false))) + ("commit" | provide(false)) assert(CommandCodecs.cmdFulfillCodec.decode(legacyCmdFulfillCodec.encode(42 :: data32 :: true :: HNil).require).require == DecodeResult(CMD_FULFILL_HTLC(42, data32, commit = false, None), BitVector.empty)) val legacyCmdFailCodec = - (("id" | int64) :: + ("id" | int64) :: ("reason" | either(bool, varsizebinarydata, failureMessageCodec)) :: - ("commit" | provide(false))) + ("commit" | provide(false)) assert(CommandCodecs.cmdFailCodec.decode(legacyCmdFailCodec.encode(42 :: Left(data123) :: true :: HNil).require).require == DecodeResult(CMD_FAIL_HTLC(42, Left(data123), commit = false, None), BitVector.empty)) val legacyCmdFailMalformedCodec = - (("id" | int64) :: + ("id" | int64) :: ("onionHash" | bytes32) :: ("failureCode" | uint16) :: - ("commit" | provide(false))) + ("commit" | provide(false)) assert(CommandCodecs.cmdFailMalformedCodec.decode(legacyCmdFailMalformedCodec.encode(42 :: data32 :: 456 :: true :: HNil).require).require == - DecodeResult(CMD_FAIL_MALFORMED_HTLC(42, data32, 456, commit = false, None), BitVector.empty)) + DecodeResult(CMD_FAIL_MALFORMED_HTLC(42, data32, 456, None, commit = false, None), BitVector.empty)) } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index c64d50b9b..4c15adb92 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -195,6 +195,29 @@ class PaymentOnionSpec extends AnyFunSuite { assert(perHopPayloadCodec.decode(bin.bits).require.value == tlvs) } + test("encode/decode final blinded per-hop payload") { + val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.PathId(hex"2a2a2a2a"), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), + ) + val testCases = Map( + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef")) -> hex"0d 02020231 04012a 0a04deadbeef", + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2"))) -> hex"30 02020231 04012a 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")), TotalAmount(1105 msat)) -> hex"34 02020231 04012a 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2 12020451", + ) + + for ((expected, bin) <- testCases) { + val decoded = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded == expected) + val Right(payload) = FinalPayload.Blinded.validate(decoded, blindedTlvs) + assert(payload.amount == 561.msat) + assert(payload.expiry == CltvExpiry(42)) + assert(payload.pathId_opt.contains(hex"2a2a2a2a")) + val encoded = perHopPayloadCodec.encode(expected).require.bytes + assert(encoded == bin) + } + } + test("decode multi-part final per-hop payload") { val Right(multiPart) = FinalPayload.Standard.validate(perHopPayloadCodec.decode(hex"2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451".bits).require.value) assert(multiPart.amount == 561.msat) @@ -232,9 +255,13 @@ class PaymentOnionSpec extends AnyFunSuite { val testCases = Seq( // Forbidden non-encrypted amount. - TestCase(ForbiddenTlv(UInt64(2)), hex"0e 02020231 0a080123456789abcdef", validBlindedTlvs), + TestCase(ForbiddenTlv(UInt64(0)), hex"0e 02020231 0a080123456789abcdef", validBlindedTlvs), // Forbidden non-encrypted expiry. - TestCase(ForbiddenTlv(UInt64(4)), hex"0d 04012a 0a080123456789abcdef", validBlindedTlvs), + TestCase(ForbiddenTlv(UInt64(0)), hex"0d 04012a 0a080123456789abcdef", validBlindedTlvs), + // Forbidden outgoing channel id. + TestCase(ForbiddenTlv(UInt64(0)), hex"14 06080000000000000451 0a080123456789abcdef", validBlindedTlvs), + // Forbidden unknown tlv. + TestCase(ForbiddenTlv(UInt64(51)), hex"0e 0a080123456789abcdef 33020102", validBlindedTlvs), // Missing encrypted data. TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs), // Missing encrypted outgoing channel. @@ -279,6 +306,26 @@ class PaymentOnionSpec extends AnyFunSuite { } } + test("decode invalid final blinded per-hop payload") { + val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.PathId(hex"2a2a2a2a"), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), + ) + val testCases = Seq( + (MissingRequiredTlv(UInt64(2)), hex"0d 04012a 0a080123456789abcdef"), // missing amount + (MissingRequiredTlv(UInt64(4)), hex"0e 02020231 0a080123456789abcdef"), // missing expiry + (MissingRequiredTlv(UInt64(10)), hex"07 02020231 04012a"), // missing encrypted data + (ForbiddenTlv(UInt64(0)), hex"1b 02020231 04012a 06080000000000000451 0a080123456789abcdef"), // forbidden outgoing_channel_id + (ForbiddenTlv(UInt64(0)), hex"35 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451 0a080123456789abcdef"), // forbidden payment_data + (ForbiddenTlv(UInt64(0)), hex"17 02020231 04012a 0a080123456789abcdef 1004deadbeef"), // forbidden payment_metadata + (ForbiddenTlv(UInt64(65535)), hex"17 02020231 04012a 0a080123456789abcdef fdffff0206c1"), // forbidden unknown tlv + ) + + for ((expectedErr, bin) <- testCases) { + assert(FinalPayload.Blinded.validate(perHopPayloadCodec.decode(bin.bits).require.value, blindedTlvs) == Left(expectedErr)) + } + } + test("decode invalid per-hop payload") { val testCases = Seq( // Invalid fixed-size (legacy) payload. diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala index b1545a98e..92266201c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala @@ -53,6 +53,19 @@ class RouteBlindingSpec extends AnyFunSuiteLike { } } + test("reject non-empty allowed features for intermediate nodes") { + { + val encoded = hex"02080000000000000231 0a060090000000fa 0c06000b699105dc 0e0101" + val decoded = blindedRouteDataCodec.decode(encoded.bits).require.value + assert(BlindedRouteData.validatePaymentRelayData(decoded) == Left(ForbiddenTlv(UInt64(14)))) + } + { + val encoded = hex"01020000 042102edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145 0e020100" + val decoded = blindedRouteDataCodec.decode(encoded.bits).require.value + assert(BlindedRouteData.validateMessageRelayData(decoded) == Left(ForbiddenTlv(UInt64(14)))) + } + } + test("decode encrypted route blinding data") { val sessionKey = randomKey() val nodePrivKeys = Seq(randomKey(), randomKey(), randomKey(), randomKey(), randomKey())