diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 82e96f358..1f4cdd326 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -94,7 +94,7 @@ trait Eclair { def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] - def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] + def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[CommandResponse[CMD_GETINFO]] def peers()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] @@ -212,19 +212,22 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { .map(_.filter(n => nodeIds_opt.forall(_.contains(n.nodeId)))) } - override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = toRemoteNode_opt match { - case Some(pk) => for { - channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys) - channels <- Future.sequence(channelIds.map(channelId => sendToChannel[CMD_GETINFO, RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) - } yield channels - case None => for { - channelIds <- (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys) - channels <- Future.sequence(channelIds.map(channelId => sendToChannel[CMD_GETINFO, RES_GETINFO](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) - } yield channels + override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = { + val futureResponse = toRemoteNode_opt match { + case Some(pk) => (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys) + case None => (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys) + } + + for { + channelIds <- futureResponse + channels <- Future.sequence(channelIds.map(channelId => sendToChannel[CMD_GETINFO, CommandResponse[CMD_GETINFO]](Left(channelId), CMD_GETINFO(ActorRef.noSender)))) + } yield channels.collect { + case properResponse: RES_GETINFO => properResponse + } } - override def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] = { - sendToChannel[CMD_GETINFO, RES_GETINFO](channel, CMD_GETINFO(ActorRef.noSender)) + override def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[CommandResponse[CMD_GETINFO]] = { + sendToChannel[CMD_GETINFO, CommandResponse[CMD_GETINFO]](channel, CMD_GETINFO(ActorRef.noSender)) } override def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 877ac1f92..3df5abe44 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -449,7 +449,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I assert(verifiedMessage.publicKey !== kit.nodeParams.nodeId) } - test("get channel info (all channels)") { f => + test("get channel info (filtered channels)") { f => import f._ val eclair = new EclairImpl(kit) @@ -468,8 +468,8 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I val c1 = register.expectMsgType[Register.Forward[CMD_GETINFO]] register.reply(RES_GETINFO(map(c1.channelId), c1.channelId, NORMAL, ChannelCodecsSpec.normal)) - val c2 = register.expectMsgType[Register.Forward[CMD_GETINFO]] - register.reply(RES_GETINFO(map(c2.channelId), c2.channelId, NORMAL, ChannelCodecsSpec.normal)) + register.expectMsgType[Register.Forward[CMD_GETINFO]] + register.reply(RES_FAILURE(CMD_GETINFO(ActorRef.noSender), new IllegalArgumentException("Non-standard channel"))) val c3 = register.expectMsgType[Register.Forward[CMD_GETINFO]] register.reply(RES_GETINFO(map(c3.channelId), c3.channelId, NORMAL, ChannelCodecsSpec.normal)) register.expectNoMessage() @@ -478,7 +478,6 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I assert(res.value.get.get.toSet === Set( RES_GETINFO(a, a1, NORMAL, ChannelCodecsSpec.normal), - RES_GETINFO(a, a2, NORMAL, ChannelCodecsSpec.normal), RES_GETINFO(b, b1, NORMAL, ChannelCodecsSpec.normal), )) }