diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 6a4d2747d..a207f0232 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -16,15 +16,16 @@ package fr.acinq.eclair.payment -import java.util.UUID - import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.eclair.MilliSatoshi import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop, Ignore} -import fr.acinq.eclair.wire.Node +import fr.acinq.eclair.wire.{ChannelUpdate, Node} +import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} + +import java.util.UUID /** * Created by PM on 01/02/2017. @@ -201,4 +202,23 @@ object PaymentFailure { failures.foldLeft(ignore) { case (current, failure) => updateIgnored(failure, current) } } + /** Update the invoice routing hints based on more recent channel updates received. */ + def updateRoutingHints(failures: Seq[PaymentFailure], routingHints: Seq[Seq[ExtraHop]]): Seq[Seq[ExtraHop]] = { + // We're only interested in the last channel update received per channel. + val updates = failures.foldLeft(Map.empty[ShortChannelId, ChannelUpdate]) { + case (current, failure) => failure match { + case RemoteFailure(_, Sphinx.DecryptedFailurePacket(_, f: Update)) => current.updated(f.update.shortChannelId, f.update) + case _ => current + } + } + routingHints.map(_.map(extraHop => updates.get(extraHop.shortChannelId) match { + case Some(u) => extraHop.copy( + cltvExpiryDelta = u.cltvExpiryDelta, + feeBase = u.feeBaseMsat, + feeProportionalMillionths = u.feeProportionalMillionths + ) + case None => extraHop + })) + } + } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index 6c5e8f050..a98741f50 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -121,7 +121,8 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id)) } else { val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) - stay using d.copy(pending = d.pending - pf.id, ignore = ignore1, failures = d.failures ++ pf.failures) + val assistedRoutes1 = PaymentFailure.updateRoutingHints(pf.failures, d.request.assistedRoutes) + stay using d.copy(pending = d.pending - pf.id, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(assistedRoutes = assistedRoutes1)) } // The recipient released the preimage without receiving the full payment amount. @@ -142,11 +143,12 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, gotoAbortedOrStop(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures :+ failure, d.pending.keySet - pf.id)) } else { val ignore1 = PaymentFailure.updateIgnored(pf.failures, d.ignore) + val assistedRoutes1 = PaymentFailure.updateRoutingHints(pf.failures, d.request.assistedRoutes) val stillPending = d.pending - pf.id val (toSend, maxFee) = remainingToSend(nodeParams, d.request, stillPending.values) log.debug("child payment failed, retry sending {} with maximum fee {}", toSend, maxFee) val routeParams = d.request.getRouteParams(nodeParams, randomize = true) // we randomize route selection when we retry - val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures) + val d1 = d.copy(pending = stillPending, ignore = ignore1, failures = d.failures ++ pf.failures, request = d.request.copy(assistedRoutes = assistedRoutes1)) router ! createRouteRequest(nodeParams, toSend, maxFee, routeParams, d1, cfg) goto(WAIT_FOR_ROUTES) using d1 } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 95e3bac4d..ce9f50d2c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -16,14 +16,13 @@ package fr.acinq.eclair.payment -import java.util.UUID - import akka.actor.{ActorRef, Status} import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.{Block, Crypto} import fr.acinq.eclair._ import fr.acinq.eclair.channel.{AddHtlcFailed, ChannelFlags, ChannelUnavailable, Upstream} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle._ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted @@ -36,6 +35,7 @@ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} +import java.util.UUID import scala.concurrent.duration._ /** @@ -238,6 +238,64 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(result.amountWithFees === 1000200.msat) } + test("retry with updated routing hints") { f => + import f._ + + // The B -> E channel is private and provided in the invoice routing hints. + val routingHint = ExtraHop(b, hop_be.lastUpdate.shortChannelId, hop_be.lastUpdate.feeBaseMsat, hop_be.lastUpdate.feeProportionalMillionths, hop_be.lastUpdate.cltvExpiryDelta) + val payment = SendMultiPartPayment(randomBytes32, e, finalAmount, expiry, 3, routeParams = Some(routeParams), assistedRoutes = List(List(routingHint))) + sender.send(payFsm, payment) + assert(router.expectMsgType[RouteRequest].assistedRoutes.head.head === routingHint) + val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) + router.send(payFsm, RouteResponse(Seq(route))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMsg(100 millis) + + // B changed his fees and expiry after the invoice was issued. + val channelUpdate = hop_be.lastUpdate.copy(feeBaseMsat = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) + val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.hops, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) + // We update the routing hints accordingly before requesting a new route. + assert(router.expectMsgType[RouteRequest].assistedRoutes.head.head === ExtraHop(b, channelUpdate.shortChannelId, 250 msat, 150, CltvExpiryDelta(24))) + } + + test("update routing hints") { _ => + val routingHints = Seq( + Seq(ExtraHop(a, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), ExtraHop(b, ShortChannelId(2), 0 msat, 100, CltvExpiryDelta(24))), + Seq(ExtraHop(a, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144))) + ) + + def makeChannelUpdate(shortChannelId: ShortChannelId, feeBase: MilliSatoshi, feeProportional: Long, cltvExpiryDelta: CltvExpiryDelta): ChannelUpdate = { + defaultChannelUpdate.copy(shortChannelId = shortChannelId, feeBaseMsat = feeBase, feeProportionalMillionths = feeProportional, cltvExpiryDelta = cltvExpiryDelta) + } + + { + val failures = Seq( + LocalFailure(Nil, ChannelUnavailable(randomBytes32)), + RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48))))), + UnreadableRemoteFailure(Nil) + ) + val routingHints1 = Seq( + Seq(ExtraHop(a, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), ExtraHop(b, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48))), + Seq(ExtraHop(a, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144))) + ) + assert(routingHints1 === PaymentFailure.updateRoutingHints(failures, routingHints)) + } + { + val failures = Seq( + RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(a, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(1), 20 msat, 20, CltvExpiryDelta(20))))), + RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(2), 21 msat, 21, CltvExpiryDelta(21))))), + RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(a, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(3), 22 msat, 22, CltvExpiryDelta(22))))), + RemoteFailure(Nil, Sphinx.DecryptedFailurePacket(a, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23))))) + ) + val routingHints1 = Seq( + Seq(ExtraHop(a, ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23)), ExtraHop(b, ShortChannelId(2), 21 msat, 21, CltvExpiryDelta(21))), + Seq(ExtraHop(a, ShortChannelId(3), 22 msat, 22, CltvExpiryDelta(22))) + ) + assert(routingHints1 === PaymentFailure.updateRoutingHints(failures, routingHints)) + } + } + test("abort after too many failed attempts") { f => import f._