diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 4584e9c5d..8dc066a17 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -24,6 +24,7 @@ import akka.util.Timeout import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto, Satoshi} +import fr.acinq.eclair.ApiTypes.ChannelNotFound import fr.acinq.eclair.balance.CheckBalance.GlobalBalance import fr.acinq.eclair.balance.{BalanceActor, ChannelsListener} import fr.acinq.eclair.blockchain.OnChainWallet.OnChainBalance @@ -68,6 +69,8 @@ object SignedMessage { object ApiTypes { type ChannelIdentifier = Either[ByteVector32, ShortChannelId] + + case class ChannelNotFound(identifier: ChannelIdentifier) extends IllegalArgumentException(s"channel ${identifier.fold(_.toString(), _.toString)} not found") } trait Eclair { @@ -179,11 +182,11 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { } override def close(channels: List[ApiTypes.ChannelIdentifier], scriptPubKey_opt: Option[ByteVector], closingFeerates_opt: Option[ClosingFeerates])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, CommandResponse[CMD_CLOSE]]]] = { - sendToChannels[CommandResponse[CMD_CLOSE]](channels, CMD_CLOSE(ActorRef.noSender, scriptPubKey_opt, closingFeerates_opt)) + sendToChannels(channels, CMD_CLOSE(ActorRef.noSender, scriptPubKey_opt, closingFeerates_opt)) } override def forceClose(channels: List[ApiTypes.ChannelIdentifier])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, CommandResponse[CMD_FORCECLOSE]]]] = { - sendToChannels[CommandResponse[CMD_FORCECLOSE]](channels, CMD_FORCECLOSE(ActorRef.noSender)) + sendToChannels(channels, CMD_FORCECLOSE(ActorRef.noSender)) } override def updateRelayFee(nodes: List[PublicKey], feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, CommandResponse[CMD_UPDATE_RELAY_FEE]]]] = { @@ -207,16 +210,16 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = toRemoteNode_opt match { case Some(pk) => for { channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys) - channels <- Future.sequence(channelIds.map(channelId => sendToChannel[RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) + channels <- Future.sequence(channelIds.map(channelId => sendToChannel[CMD_GETINFO, RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) } yield channels case None => for { channelIds <- (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys) - channels <- Future.sequence(channelIds.map(channelId => sendToChannel[RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) + channels <- Future.sequence(channelIds.map(channelId => sendToChannel[CMD_GETINFO, RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) } yield channels } override def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] = { - sendToChannel[RES_GETINFO](channel, CMD_GETINFO(ActorRef.noSender)) + sendToChannel[CMD_GETINFO, RES_GETINFO](channel, CMD_GETINFO(ActorRef.noSender)) } override def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] = { @@ -405,18 +408,20 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash")) } + + /** * Send a request to a channel and expect a response. * * @param channel either a shortChannelId (BOLT encoded) or a channelId (32-byte hex encoded). */ - private def sendToChannel[T: ClassTag](channel: ApiTypes.ChannelIdentifier, request: Any)(implicit timeout: Timeout): Future[T] = (channel match { + private def sendToChannel[C <: Command, R <: CommandResponse[C]](channel: ApiTypes.ChannelIdentifier, request: C)(implicit timeout: Timeout): Future[R] = (channel match { case Left(channelId) => appKit.register ? Register.Forward(ActorRef.noSender, channelId, request) case Right(shortChannelId) => appKit.register ? Register.ForwardShortId(ActorRef.noSender, shortChannelId, request) }).map { - case t: T => t - case t: Register.ForwardFailure[T]@unchecked => throw new RuntimeException(s"channel ${t.fwd.channelId} not found") - case t: Register.ForwardShortIdFailure[T]@unchecked => throw new RuntimeException(s"channel ${t.fwd.shortChannelId} not found") + case t: R@unchecked => t + case t: Register.ForwardFailure[C]@unchecked => throw ChannelNotFound(Left(t.fwd.channelId)) + case t: Register.ForwardShortIdFailure[C]@unchecked => throw ChannelNotFound(Right(t.fwd.shortChannelId)) } /** @@ -424,16 +429,16 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { * * @param channels either shortChannelIds (BOLT encoded) or channelIds (32-byte hex encoded). */ - private def sendToChannels[T: ClassTag](channels: List[ApiTypes.ChannelIdentifier], request: Any)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, T]]] = { - val commands = channels.map(c => sendToChannel[T](c, request).map(r => Right(r)).recover(t => Left(t)).map(r => c -> r)) - Future.foldLeft(commands)(Map.empty[ApiTypes.ChannelIdentifier, Either[Throwable, T]])(_ + _) + private def sendToChannels[C <: Command, R <: CommandResponse[C]](channels: List[ApiTypes.ChannelIdentifier], request: C)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, R]]] = { + val commands = channels.map(c => sendToChannel[C, R](c, request).map(r => Right(r)).recover(t => Left(t)).map(r => c -> r)) + Future.foldLeft(commands)(Map.empty[ApiTypes.ChannelIdentifier, Either[Throwable, R]])(_ + _) } /** Send a request to multiple channels using node ids */ - private def sendToNodes[T: ClassTag](nodeids: List[PublicKey], request: Any)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, T]]] = { + private def sendToNodes[C <: Command, R <: CommandResponse[C]](nodeids: List[PublicKey], request: C)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, R]]] = { for { channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(kv => nodeids.contains(kv._2)).keys) - res <- sendToChannels[T](channelIds.map(Left(_)).toList, request) + res <- sendToChannels[C, R](channelIds.map(Left(_)).toList, request) } yield res } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala index 5928c8df5..b68aecd7a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala @@ -229,13 +229,12 @@ final case class RES_ADD_SETTLED[+O <: Origin, +R <: HtlcResult](origin: O, htlc final case class RES_GETSTATE[+S <: ChannelState](state: S) extends CommandSuccess[CMD_GETSTATE] final case class RES_GETSTATEDATA[+D <: ChannelData](data: D) extends CommandSuccess[CMD_GETSTATEDATA] final case class RES_GETINFO(nodeId: PublicKey, channelId: ByteVector32, state: ChannelState, data: ChannelData) extends CommandSuccess[CMD_GETINFO] -final case class RES_CLOSE(channelId: ByteVector32) extends CommandSuccess[CMD_CLOSE] /** * Those are not response to [[Command]], but to [[fr.acinq.eclair.io.Peer.OpenChannel]] * * If actor A sends a [[fr.acinq.eclair.io.Peer.OpenChannel]] and actor B sends a [[CMD_CLOSE]], then A will receive a - * [[ChannelOpenResponse.ChannelClosed]] whereas B will receive a [[RES_CLOSE]] + * [[ChannelOpenResponse.ChannelClosed]] whereas B will receive a [[RES_SUCCESS]] */ sealed trait ChannelOpenResponse object ChannelOpenResponse { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index 753063019..0e020c6a7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -191,6 +191,7 @@ object ChannelOpenResponseSerializer extends MinimalSerializer({ object CommandResponseSerializer extends MinimalSerializer({ case RES_SUCCESS(_: CloseCommand, channelId) => JString(s"closed channel $channelId") + case RES_SUCCESS(_, _) => JString("ok") case RES_FAILURE(_: Command, ex: Throwable) => JString(ex.getMessage) }) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 61eda8374..d374529e2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -22,6 +22,7 @@ import akka.testkit.TestProbe import akka.util.Timeout import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, SatoshiLong} +import fr.acinq.eclair.ApiTypes.ChannelNotFound import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair.blockchain.DummyOnChainWallet import fr.acinq.eclair.blockchain.fee.{FeeratePerByte, FeeratePerKw} @@ -37,6 +38,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendPay import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort import fr.acinq.eclair.router.Router.{GetNetworkStats, GetNetworkStatsResponse, PredefinedNodeRoute, PublicChannel} import fr.acinq.eclair.router.{Announcements, NetworkStats, Router, Stats} +import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec import fr.acinq.eclair.wire.protocol.{ChannelUpdate, Color, NodeAnnouncement} import org.mockito.Mockito import org.mockito.scalatest.IdiomaticMockito @@ -469,6 +471,121 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I assert(verifiedMessage.publicKey !== kit.nodeParams.nodeId) } + test("get channel info (all channels)") { f => + import f._ + + val eclair = new EclairImpl(kit) + + val a = randomKey().publicKey + val b = randomKey().publicKey + val a1 = randomBytes32() + val a2 = randomBytes32() + val b1 = randomBytes32() + val map = Map(a1 -> a, a2 -> a, b1 -> b) + + val res = eclair.channelsInfo(toRemoteNode_opt = None) + + register.expectMsg(Symbol("channels")) + register.reply(map) + + val c1 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(map(c1.channelId), c1.channelId, NORMAL, ChannelCodecsSpec.normal)) + val c2 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(map(c2.channelId), c2.channelId, NORMAL, ChannelCodecsSpec.normal)) + val c3 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(map(c3.channelId), c3.channelId, NORMAL, ChannelCodecsSpec.normal)) + register.expectNoMessage() + + awaitCond(res.isCompleted) + + assert(res.value.get.get.toSet === Set( + RES_GETINFO(a, a1, NORMAL, ChannelCodecsSpec.normal), + RES_GETINFO(a, a2, NORMAL, ChannelCodecsSpec.normal), + RES_GETINFO(b, b1, NORMAL, ChannelCodecsSpec.normal), + )) + } + + test("get channel info (using node id)") { f => + import f._ + + val eclair = new EclairImpl(kit) + + val a = randomKey().publicKey + val b = randomKey().publicKey + val a1 = randomBytes32() + val a2 = randomBytes32() + val b1 = randomBytes32() + val channels2Nodes = Map(a1 -> a, a2 -> a, b1 -> b) + + val res = eclair.channelsInfo(toRemoteNode_opt = Some(a)) + + register.expectMsg(Symbol("channelsTo")) + register.reply(channels2Nodes) + + val c1 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(channels2Nodes(c1.channelId), c1.channelId, NORMAL, ChannelCodecsSpec.normal)) + val c2 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(channels2Nodes(c2.channelId), c2.channelId, NORMAL, ChannelCodecsSpec.normal)) + register.expectNoMessage() + + awaitCond(res.isCompleted) + + assert(res.value.get.get.toSet === Set( + RES_GETINFO(a, a1, NORMAL, ChannelCodecsSpec.normal), + RES_GETINFO(a, a2, NORMAL, ChannelCodecsSpec.normal) + )) + } + + test("get channel info (using channel id)") { f => + import f._ + + val eclair = new EclairImpl(kit) + + val a = randomKey().publicKey + val b = randomKey().publicKey + val a1 = randomBytes32() + val a2 = randomBytes32() + val b1 = randomBytes32() + val channels2Nodes = Map(a1 -> a, a2 -> a, b1 -> b) + + val res = eclair.channelInfo(Left(a2)) + + val c1 = register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_GETINFO(channels2Nodes(c1.channelId), c1.channelId, NORMAL, ChannelCodecsSpec.normal)) + register.expectNoMessage() + + awaitCond(res.isCompleted) + + assert(res.value.get.get === RES_GETINFO(a, a2, NORMAL, ChannelCodecsSpec.normal)) + } + + test("close channels") { f => + import f._ + + val eclair = new EclairImpl(kit) + + val a = randomKey().publicKey + val b = randomKey().publicKey + val a1 = randomBytes32() + val a2 = randomBytes32() + val b1 = randomBytes32() + + val res = eclair.close(List(Left(a2), Left(b1)), None, None) + + val c1 = register.expectMsgType[Register.Forward[CMD_CLOSE]] + register.reply(RES_SUCCESS(c1.message, c1.channelId)) + val c2 = register.expectMsgType[Register.Forward[CMD_CLOSE]] + register.reply(RES_SUCCESS(c2.message, c2.channelId)) + register.expectNoMessage() + + awaitCond(res.isCompleted) + + assert(res.value.get.get === Map( + Left(a2) -> Right(RES_SUCCESS(CMD_CLOSE(ActorRef.noSender, None, None), a2)), + Left(b1) -> Right(RES_SUCCESS(CMD_CLOSE(ActorRef.noSender, None, None), b1)) + )) + } + test("update relay fees in database") { f => import f._ @@ -482,8 +599,31 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I val a = randomKey().publicKey val b = randomKey().publicKey + val a1 = randomBytes32() + val a2 = randomBytes32() + val b1 = randomBytes32() + val map = Map(a1 -> a, a2 -> a, b1 -> b) - eclair.updateRelayFee(List(a, b), 999 msat, 1234) + val res = eclair.updateRelayFee(List(a, b), 999 msat, 1234) + + register.expectMsg(Symbol("channelsTo")) + register.reply(map) + + val u1 = register.expectMsgType[Register.Forward[CMD_UPDATE_RELAY_FEE]] + register.reply(RES_SUCCESS(u1.message, u1.channelId)) + val u2 = register.expectMsgType[Register.Forward[CMD_UPDATE_RELAY_FEE]] + register.reply(RES_FAILURE(u2.message, CommandUnavailableInThisState(u2.channelId, "CMD_UPDATE_RELAY_FEE", channel.CLOSING))) + val u3 = register.expectMsgType[Register.Forward[CMD_UPDATE_RELAY_FEE]] + register.reply(Register.ForwardFailure(u3)) + register.expectNoMessage() + + awaitCond(res.isCompleted) + + assert(res.value.get.get === Map( + Left(a1) -> Right(RES_SUCCESS(CMD_UPDATE_RELAY_FEE(ActorRef.noSender, 999 msat, 1234, None), a1)), + Left(a2) -> Right(RES_FAILURE(CMD_UPDATE_RELAY_FEE(ActorRef.noSender, 999 msat, 1234, None), CommandUnavailableInThisState(a2, "CMD_UPDATE_RELAY_FEE", channel.CLOSING))), + Left(b1) -> Left(ChannelNotFound(Left(b1))) + )) peersDb.addOrUpdateRelayFees(a, RelayFees(999 msat, 1234)).wasCalled(once) peersDb.addOrUpdateRelayFees(b, RelayFees(999 msat, 1234)).wasCalled(once) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala index 30061a299..a515fb862 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala @@ -16,11 +16,12 @@ package fr.acinq.eclair.json +import akka.actor.ActorRef import fr.acinq.bitcoin.{Btc, ByteVector32, OutPoint, Satoshi, SatoshiLong, Transaction, TxOut} import fr.acinq.eclair._ import fr.acinq.eclair.balance.CheckBalance import fr.acinq.eclair.balance.CheckBalance.{ClosingBalance, GlobalBalance, MainAndHtlcBalance, PossiblyPublishedMainAndHtlcBalance, PossiblyPublishedMainBalance} -import fr.acinq.eclair.channel.Origin +import fr.acinq.eclair.channel.{CMD_UPDATE_RELAY_FEE, CommandResponse, CommandUnavailableInThisState, Origin, RES_FAILURE, RES_SUCCESS} import fr.acinq.eclair.payment.{PaymentRequest, PaymentSettlingOnChain} import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.transactions.{IncomingHtlc, OutgoingHtlc} @@ -273,6 +274,21 @@ class JsonSerializersSpec extends AnyFunSuite with Matchers { JsonSerializers.serialization.write(tsms)(JsonSerializers.formats) shouldBe """{"iso":"2021-10-04T14:32:41.456Z","unix":1633357961}""" } + test("serialize channel command responses") { + val id1 = ByteVector32(hex"e2fc57221cfb1942224082174022f3f70a32005aa209956f9c94c6903f7669ff") + val id2 = ByteVector32(hex"8e3ec6e16436b7dc61b86340192603d05f16d4f8e06c8aaa02fbe2ad63209af3") + val id3 = ByteVector32(hex"74ca7a86e52d597aa2248cd2ff3b24428ede71345262be7fb31afddfe18dc0d8") + val res1 = RES_SUCCESS(CMD_UPDATE_RELAY_FEE(ActorRef.noSender, 420L msat, 986, None), id1) + val res2 = RES_FAILURE(CMD_UPDATE_RELAY_FEE(ActorRef.noSender, 420L msat, 986, None), CommandUnavailableInThisState(id2, "CMD_UPDATE_RELAY_FEE", channel.CLOSING)) + val res3 = ApiTypes.ChannelNotFound(Left(id3)) + val map = Map( + Left(id1) -> Right(res1), + Left(id2) -> Right(res2), + Left(id3) -> Left(res3) + ) + JsonSerializers.serialization.write(map)(JsonSerializers.formats) shouldBe s"""{"e2fc57221cfb1942224082174022f3f70a32005aa209956f9c94c6903f7669ff":"ok","8e3ec6e16436b7dc61b86340192603d05f16d4f8e06c8aaa02fbe2ad63209af3":"cannot execute command=CMD_UPDATE_RELAY_FEE in state=CLOSING","74ca7a86e52d597aa2248cd2ff3b24428ede71345262be7fb31afddfe18dc0d8":"channel 74ca7a86e52d597aa2248cd2ff3b24428ede71345262be7fb31afddfe18dc0d8 not found"}""" + } + /** utility method that strips line breaks in the expected json */ def assertJsonEquals(actual: String, expected: String) = { val cleanedExpected = expected