1
0
mirror of https://github.com/ACINQ/eclair.git synced 2024-11-19 01:43:22 +01:00

Register can forward messages to nodes (#2863)

We add a `ForwardNodeId` command to the `Register` to forward messages
to a `Peer` actor based on its `node_id`.
This commit is contained in:
Bastien Teinturier 2024-06-12 10:38:52 +02:00 committed by GitHub
parent f0e3985d10
commit 741ac492e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 128 additions and 48 deletions

View File

@ -293,8 +293,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GET_CHANNEL_INFO]] = {
val futureResponse = toRemoteNode_opt match {
case Some(pk) => (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys)
case None => (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys)
case Some(pk) => (appKit.register ? Register.GetChannelsTo).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys)
case None => (appKit.register ? Register.GetChannels).mapTo[Map[ByteVector32, ActorRef]].map(_.keys)
}
for {
@ -594,7 +594,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
/** Send a request to multiple channels using node ids */
private def sendToNodes[C <: Command, R <: CommandResponse[C]](nodeids: List[PublicKey], request: C)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, R]]] = {
for {
channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(kv => nodeids.contains(kv._2)).keys)
channelIds <- (appKit.register ? Register.GetChannelsTo).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(kv => nodeids.contains(kv._2)).keys)
res <- sendToChannels[C, R](channelIds.map(Left(_)).toList, request)
} yield res
}

View File

@ -17,44 +17,45 @@
package fr.acinq.eclair.channel
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import akka.actor.{Actor, ActorLogging, ActorRef, Props, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel.Register._
import fr.acinq.eclair.{SubscriptionsComplete, ShortChannelId}
import fr.acinq.eclair.io.PeerCreated
import fr.acinq.eclair.{ShortChannelId, SubscriptionsComplete}
/**
* Created by PM on 26/01/2016.
*/
class Register() extends Actor with ActorLogging {
class Register extends Actor with ActorLogging {
context.system.eventStream.subscribe(self, classOf[PeerCreated])
context.system.eventStream.subscribe(self, classOf[ChannelCreated])
context.system.eventStream.subscribe(self, classOf[AbstractChannelRestored])
context.system.eventStream.subscribe(self, classOf[ChannelIdAssigned])
context.system.eventStream.subscribe(self, classOf[ShortChannelIdAssigned])
context.system.eventStream.publish(SubscriptionsComplete(this.getClass))
// @formatter:off
private case class ChannelTerminated(channel: ActorRef, channelId: ByteVector32)
// @formatter:on
override def receive: Receive = main(Map.empty, Map.empty, Map.empty, Map.empty)
override def receive: Receive = main(Map.empty, Map.empty, Map.empty)
def main(channels: Map[ByteVector32, ActorRef], shortIds: Map[ShortChannelId, ByteVector32], channelsTo: Map[ByteVector32, PublicKey], nodeIdToPeer: Map[PublicKey, ActorRef]): Receive = {
case PeerCreated(peer, remoteNodeId) =>
context.watchWith(peer, PeerTerminated(peer, remoteNodeId))
context become main(channels, shortIds, channelsTo, nodeIdToPeer + (remoteNodeId -> peer))
def main(channels: Map[ByteVector32, ActorRef], shortIds: Map[ShortChannelId, ByteVector32], channelsTo: Map[ByteVector32, PublicKey]): Receive = {
case ChannelCreated(channel, _, remoteNodeId, _, temporaryChannelId, _, _) =>
context.watchWith(channel, ChannelTerminated(channel, temporaryChannelId))
context become main(channels + (temporaryChannelId -> channel), shortIds, channelsTo + (temporaryChannelId -> remoteNodeId))
context become main(channels + (temporaryChannelId -> channel), shortIds, channelsTo + (temporaryChannelId -> remoteNodeId), nodeIdToPeer)
case event: AbstractChannelRestored =>
context.watchWith(event.channel, ChannelTerminated(event.channel, event.channelId))
context become main(channels + (event.channelId -> event.channel), shortIds, channelsTo + (event.channelId -> event.remoteNodeId))
context become main(channels + (event.channelId -> event.channel), shortIds, channelsTo + (event.channelId -> event.remoteNodeId), nodeIdToPeer)
case ChannelIdAssigned(channel, remoteNodeId, temporaryChannelId, channelId) =>
context.unwatch(channel)
context.watchWith(channel, ChannelTerminated(channel, channelId))
context become main(channels + (channelId -> channel) - temporaryChannelId, shortIds, channelsTo + (channelId -> remoteNodeId) - temporaryChannelId)
context become main(channels + (channelId -> channel) - temporaryChannelId, shortIds, channelsTo + (channelId -> remoteNodeId) - temporaryChannelId, nodeIdToPeer)
case scidAssigned: ShortChannelIdAssigned =>
// We map all known scids (real or alias) to the channel_id. The relayer is in charge of deciding whether a real
@ -66,17 +67,24 @@ class Register() extends Actor with ActorLogging {
log.error("duplicate alias={} for channelIds={},{} this should never happen!", scidAssigned.shortIds.localAlias, channelId, scidAssigned.channelId)
case _ => ()
}
context become main(channels, shortIds ++ m, channelsTo)
context become main(channels, shortIds ++ m, channelsTo, nodeIdToPeer)
case ChannelTerminated(_, channelId) =>
val shortChannelIds = shortIds.collect { case (key, value) if value == channelId => key }
context become main(channels - channelId, shortIds -- shortChannelIds, channelsTo - channelId)
context become main(channels - channelId, shortIds -- shortChannelIds, channelsTo - channelId, nodeIdToPeer)
case Symbol("channels") => sender() ! channels
case PeerTerminated(peer, remoteNodeId) =>
// Note that peer actors can be stopped and recreated, which may lead to race conditions between PeerCreated and
// PeerTerminated messages: we only remove that nodeId from the map if the actor matches.
if (nodeIdToPeer.get(remoteNodeId).contains(peer)) {
context become main(channels, shortIds, channelsTo, nodeIdToPeer - remoteNodeId)
} else {
log.debug("ignoring obsolete PeerTerminated event for remoteNodeId={}", remoteNodeId)
}
case Symbol("shortIds") => sender() ! shortIds
case GetChannels => sender() ! channels
case Symbol("channelsTo") => sender() ! channelsTo
case GetChannelsTo => sender() ! channelsTo
case GetNextNodeId(replyTo, shortChannelId) =>
replyTo ! shortIds.get(shortChannelId).flatMap(cid => channelsTo.get(cid))
@ -96,6 +104,12 @@ class Register() extends Actor with ActorLogging {
case Some(channel) => channel.tell(msg, compatReplyTo)
case None => compatReplyTo ! ForwardShortIdFailure(fwd)
}
case fwd@ForwardNodeId(replyTo, nodeId, msg) =>
nodeIdToPeer.get(nodeId) match {
case Some(peer) => peer.tell(msg, replyTo.toClassic)
case None => replyTo ! ForwardNodeIdFailure(fwd)
}
}
}
@ -103,13 +117,24 @@ object Register {
def props(): Props = Props(new Register())
// @formatter:off
private[channel] case class PeerTerminated(peer: ActorRef, nodeId: PublicKey)
private case class ChannelTerminated(channel: ActorRef, channelId: ByteVector32)
// @formatter:on
// @formatter:off
case class Forward[T](replyTo: akka.actor.typed.ActorRef[ForwardFailure[T]], channelId: ByteVector32, message: T)
case class ForwardShortId[T](replyTo: akka.actor.typed.ActorRef[ForwardShortIdFailure[T]], shortChannelId: ShortChannelId, message: T)
case class ForwardNodeId[T](replyTo: akka.actor.typed.ActorRef[ForwardNodeIdFailure[T]], nodeId: PublicKey, message: T)
case class ForwardFailure[T](fwd: Forward[T])
case class ForwardShortIdFailure[T](fwd: ForwardShortId[T])
// @formatter:on
case class ForwardNodeIdFailure[T](fwd: ForwardNodeId[T])
case class GetNextNodeId(replyTo: typed.ActorRef[Option[PublicKey]], shortChannelId: ShortChannelId)
case object GetChannels
case object GetChannelsTo
// @formatter:on
}

View File

@ -72,6 +72,7 @@ class Peer(val nodeParams: NodeParams,
channel ! INPUT_RESTORED(state)
FinalChannelId(state.channelId) -> channel
}.toMap
context.system.eventStream.publish(PeerCreated(self, remoteNodeId))
goto(DISCONNECTED) using DisconnectedData(channels) // when we restart, we will attempt to reconnect right away, but then we'll wait
}
@ -374,7 +375,7 @@ class Peer(val nodeParams: NodeParams,
context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId))
}
def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = {
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = {
require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeid: $remoteNodeId != ${connectionReady.remoteNodeId}")
log.debug("got authenticated connection to address {}", connectionReady.address)
@ -394,7 +395,7 @@ class Peer(val nodeParams: NodeParams,
* We need to ignore [[LightningMessage]] not sent by the current [[PeerConnection]]. This may happen if we switch
* between connections.
*/
def dropStaleMessages(s: StateFunction): StateFunction = {
private def dropStaleMessages(s: StateFunction): StateFunction = {
case Event(msg: LightningMessage, d: ConnectedData) if sender() != d.peerConnection =>
log.warning("dropping message from stale connection: {}", msg)
stay()
@ -402,13 +403,13 @@ class Peer(val nodeParams: NodeParams,
s(e)
}
def spawnChannel(): ActorRef = {
private def spawnChannel(): ActorRef = {
val channel = channelFactory.spawn(context, remoteNodeId)
context watch channel
channel
}
def replyUnknownChannel(peerConnection: ActorRef, unknownChannelId: ByteVector32): Unit = {
private def replyUnknownChannel(peerConnection: ActorRef, unknownChannelId: ByteVector32): Unit = {
val msg = Error(unknownChannelId, "unknown channel")
self ! Peer.OutgoingMessage(msg, peerConnection)
}
@ -416,7 +417,7 @@ class Peer(val nodeParams: NodeParams,
// resume the openChannelInterceptor in case of failure, we always want the open channel request to succeed or fail
private val openChannelInterceptor = context.spawnAnonymous(Behaviors.supervise(OpenChannelInterceptor(context.self.toTyped, nodeParams, remoteNodeId, wallet, pendingChannelsRateLimiter)).onFailure(typed.SupervisorStrategy.resume))
def stopPeer(): State = {
private def stopPeer(): State = {
log.info("removing peer from db")
nodeParams.db.peers.removePeer(remoteNodeId)
stop(FSM.Normal)

View File

@ -25,6 +25,8 @@ import scala.concurrent.duration._
sealed trait PeerEvent
case class PeerCreated(peer: ActorRef, nodeId: PublicKey) extends PeerEvent
case class ConnectionInfo(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init)
case class PeerConnected(peer: ActorRef, nodeId: PublicKey, connectionInfo: ConnectionInfo) extends PeerEvent

View File

@ -515,7 +515,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I
eclair.channelsInfo(toRemoteNode_opt = None).pipeTo(sender.ref)
register.expectMsg(Symbol("channels"))
register.expectMsg(Register.GetChannels)
register.reply(map)
val c1 = register.expectMsgType[Register.Forward[CMD_GET_CHANNEL_INFO]]
@ -544,7 +544,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I
eclair.channelsInfo(toRemoteNode_opt = Some(a)).pipeTo(sender.ref)
register.expectMsg(Symbol("channelsTo"))
register.expectMsg(Register.GetChannelsTo)
register.reply(channels2Nodes)
val c1 = register.expectMsgType[Register.Forward[CMD_GET_CHANNEL_INFO]]
@ -676,7 +676,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I
eclair.updateRelayFee(List(a, b), 999 msat, 1234).pipeTo(sender.ref)
register.expectMsg(Symbol("channelsTo"))
register.expectMsg(Register.GetChannelsTo)
register.reply(map)
val u1 = register.expectMsgType[Register.Forward[CMD_UPDATE_RELAY_FEE]]

View File

@ -1,24 +1,74 @@
package fr.acinq.eclair.channel
import fr.acinq.eclair._
import akka.actor.{ActorRef, Props}
import akka.testkit.TestProbe
import akka.actor.typed.scaladsl.adapter._
import akka.actor.{ActorRef, PoisonPill}
import akka.testkit.{TestActorRef, TestProbe}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import org.scalatest.funsuite.AnyFunSuiteLike
import org.scalatest.ParallelTestExecution
import fr.acinq.eclair._
import fr.acinq.eclair.io.PeerCreated
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, ParallelTestExecution}
class RegisterSpec extends TestKitBaseClass with AnyFunSuiteLike with ParallelTestExecution {
class RegisterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with ParallelTestExecution {
case class CustomChannelRestored(channel: ActorRef, channelId: ByteVector32, peer: ActorRef, remoteNodeId: PublicKey) extends AbstractChannelRestored
test("register processes custom restored events") {
val sender = TestProbe()
val registerRef = system.actorOf(Register.props())
val customRestoredEvent = CustomChannelRestored(TestProbe().ref, randomBytes32(), TestProbe().ref, randomKey().publicKey)
registerRef ! customRestoredEvent
sender.send(registerRef, Symbol("channels"))
sender.expectMsgType[Map[ByteVector32, ActorRef]] == Map(customRestoredEvent.channelId -> customRestoredEvent.channel)
case class FixtureParam(register: TestActorRef[Register], probe: TestProbe)
override def withFixture(test: OneArgTest): Outcome = {
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[SubscriptionsComplete])
val register = TestActorRef(new Register())
probe.expectMsg(SubscriptionsComplete(classOf[Register]))
try {
withFixture(test.toNoArgTest(FixtureParam(register, probe)))
} finally {
system.stop(register)
}
}
test("process custom restored events") { f =>
import f._
val customRestoredEvent = CustomChannelRestored(TestProbe().ref, randomBytes32(), TestProbe().ref, randomKey().publicKey)
system.eventStream.publish(customRestoredEvent)
awaitAssert({
probe.send(register, Register.GetChannels)
probe.expectMsgType[Map[ByteVector32, ActorRef]] == Map(customRestoredEvent.channelId -> customRestoredEvent.channel)
})
}
test("forward messages to peers") { f =>
import f._
val nodeId = randomKey().publicKey
val peer1 = TestProbe()
system.eventStream.publish(PeerCreated(peer1.ref, nodeId))
awaitAssert({
register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "hello")
peer1.expectMsg("hello")
})
// We simulate a race condition, where the peer is recreated but we receive events out of order.
val peer2 = TestProbe()
system.eventStream.publish(PeerCreated(peer2.ref, nodeId))
awaitAssert({
register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "world")
peer2.expectMsg("world")
})
register ! Register.PeerTerminated(peer1.ref, nodeId)
register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "hello again")
peer2.expectMsg("hello again")
peer2.ref ! PoisonPill
awaitAssert({
val fwd = Register.ForwardNodeId(probe.ref.toTyped, nodeId, "d34d")
register ! fwd
probe.expectMsg(Register.ForwardNodeIdFailure(fwd))
})
}
}

View File

@ -503,7 +503,7 @@ class StandardChannelIntegrationSpec extends ChannelIntegrationSpec {
// mine the funding tx
generateBlocks(2)
// get the channelId
sender.send(fundee.register, Symbol("channels"))
sender.send(fundee.register, Register.GetChannels)
val Some((_, fundeeChannel)) = sender.expectMsgType[Map[ByteVector32, ActorRef]].find(_._1 == tempChannelId)
sender.send(fundeeChannel, CMD_GET_CHANNEL_DATA(ActorRef.noSender))
@ -682,7 +682,7 @@ abstract class AnchorChannelIntegrationSpec extends ChannelIntegrationSpec {
// initially all the balance is on C side and F doesn't have an output
val sender = TestProbe()
sender.send(nodes("F").register, Symbol("channelsTo"))
sender.send(nodes("F").register, Register.GetChannelsTo)
// retrieve the channelId of C <--> F
val Some(channelId) = sender.expectMsgType[Map[ByteVector32, PublicKey]].find(_._2 == nodes("C").nodeParams.nodeId).map(_._1)

View File

@ -28,7 +28,7 @@ import fr.acinq.eclair.TestUtils.waitEventStreamSynced
import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher
import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{Watch, WatchFundingConfirmed}
import fr.acinq.eclair.blockchain.bitcoind.rpc.BitcoinCoreClient
import fr.acinq.eclair.channel.{CMD_CLOSE, RES_SUCCESS}
import fr.acinq.eclair.channel.{CMD_CLOSE, RES_SUCCESS, Register}
import fr.acinq.eclair.io.Switchboard
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute}
@ -328,10 +328,10 @@ class MessageIntegrationSpec extends IntegrationSpec {
// We close the channels A -> B -> C but we keep channels with D
// This ensures nodes still have an unrelated channel so we keep them in the network DB.
val probe = TestProbe()
probe.send(nodes("B").register, Symbol("channels"))
probe.send(nodes("B").register, Register.GetChannels)
val channelsB = probe.expectMsgType[Map[ByteVector32, ActorRef]]
assert(channelsB.size == 3)
probe.send(nodes("D").register, Symbol("channels"))
probe.send(nodes("D").register, Register.GetChannels)
val channelsD = probe.expectMsgType[Map[ByteVector32, ActorRef]]
assert(channelsD.size == 3)
channelsB.foreach {

View File

@ -123,7 +123,9 @@ class PeerSpec extends FixtureSpec {
test("restore existing channels") { f =>
import f._
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[PeerCreated])
connect(remoteNodeId, peer, peerConnection, switchboard, channels = Set(ChannelCodecsSpec.normal))
probe.expectMsg(PeerCreated(peer.ref, remoteNodeId))
probe.send(peer, Peer.GetPeerInfo(None))
val peerInfo = probe.expectMsgType[PeerInfo]
assert(peerInfo.peer == peer)