From 80b642089a80716bcd864102051260459c0a37cf Mon Sep 17 00:00:00 2001 From: Bastien Teinturier <31281497+t-bast@users.noreply.github.com> Date: Wed, 18 Dec 2019 16:39:20 +0100 Subject: [PATCH] Improve CommandSend type (#1260) Add type with upper bound to make `asInstanceOf` unnecessary. Split `HasHtlcId` from `Command`: they are orthogonal traits. --- .../acinq/eclair/channel/ChannelTypes.scala | 10 ++-- .../fr/acinq/eclair/db/PendingRelayDb.scala | 6 +-- .../db/sqlite/SqlitePendingRelayDb.scala | 7 ++- .../eclair/payment/relay/CommandBuffer.scala | 2 +- .../fr/acinq/eclair/wire/CommandCodecs.scala | 5 +- .../eclair/payment/MultiPartHandlerSpec.scala | 50 +++++++++++-------- .../payment/PostRestartHtlcCleanerSpec.scala | 4 +- .../acinq/eclair/wire/CommandCodecsSpec.scala | 14 +++--- 8 files changed, 51 insertions(+), 47 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala index 7f007bf47..2dc6fcf02 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala @@ -118,12 +118,10 @@ object Upstream { } sealed trait Command -sealed trait HasHtlcIdCommand extends Command { - def id: Long -} -final case class CMD_FULFILL_HTLC(id: Long, r: ByteVector32, commit: Boolean = false) extends HasHtlcIdCommand -final case class CMD_FAIL_HTLC(id: Long, reason: Either[ByteVector, FailureMessage], commit: Boolean = false) extends HasHtlcIdCommand -final case class CMD_FAIL_MALFORMED_HTLC(id: Long, onionHash: ByteVector32, failureCode: Int, commit: Boolean = false) extends HasHtlcIdCommand +sealed trait HasHtlcId { def id: Long } +final case class CMD_FULFILL_HTLC(id: Long, r: ByteVector32, commit: Boolean = false) extends Command with HasHtlcId +final case class CMD_FAIL_HTLC(id: Long, reason: Either[ByteVector, FailureMessage], commit: Boolean = false) extends Command with HasHtlcId +final case class CMD_FAIL_MALFORMED_HTLC(id: Long, onionHash: ByteVector32, failureCode: Int, commit: Boolean = false) extends Command with HasHtlcId final case class CMD_ADD_HTLC(amount: MilliSatoshi, paymentHash: ByteVector32, cltvExpiry: CltvExpiry, onion: OnionRoutingPacket, upstream: Upstream, commit: Boolean = false, previousFailures: Seq[AddHtlcFailed] = Seq.empty) extends Command final case class CMD_UPDATE_FEE(feeratePerKw: Long, commit: Boolean = false) extends Command case object CMD_SIGN extends Command diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala index 8dfe79d17..49ed97b34 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.channel.HasHtlcIdCommand +import fr.acinq.eclair.channel.{Command, HasHtlcId} /** * This database stores CMD_FULFILL_HTLC and CMD_FAIL_HTLC that we have received from downstream @@ -33,11 +33,11 @@ import fr.acinq.eclair.channel.HasHtlcIdCommand */ trait PendingRelayDb { - def addPendingRelay(channelId: ByteVector32, cmd: HasHtlcIdCommand) + def addPendingRelay(channelId: ByteVector32, cmd: Command with HasHtlcId) def removePendingRelay(channelId: ByteVector32, htlcId: Long) - def listPendingRelay(channelId: ByteVector32): Seq[HasHtlcIdCommand] + def listPendingRelay(channelId: ByteVector32): Seq[Command with HasHtlcId] def listPendingRelay(): Set[(ByteVector32, Long)] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala index 77f67440a..aa28da074 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala @@ -19,10 +19,9 @@ package fr.acinq.eclair.db.sqlite import java.sql.Connection import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.channel.HasHtlcIdCommand +import fr.acinq.eclair.channel.{Command, HasHtlcId} import fr.acinq.eclair.db.PendingRelayDb import fr.acinq.eclair.wire.CommandCodecs.cmdCodec -import scodec.bits.BitVector import scala.collection.immutable.Queue @@ -40,7 +39,7 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") } - override def addPendingRelay(channelId: ByteVector32, cmd: HasHtlcIdCommand): Unit = { + override def addPendingRelay(channelId: ByteVector32, cmd: Command with HasHtlcId): Unit = { using(sqlite.prepareStatement("INSERT OR IGNORE INTO pending_relay VALUES (?, ?, ?)")) { statement => statement.setBytes(1, channelId.toArray) statement.setLong(2, cmd.id) @@ -57,7 +56,7 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { } } - override def listPendingRelay(channelId: ByteVector32): Seq[HasHtlcIdCommand] = { + override def listPendingRelay(channelId: ByteVector32): Seq[Command with HasHtlcId] = { using(sqlite.prepareStatement("SELECT data FROM pending_relay WHERE channel_id=?")) { statement => statement.setBytes(1, channelId.toArray) val rs = statement.executeQuery() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/CommandBuffer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/CommandBuffer.scala index 9eb3af08a..412380600 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/CommandBuffer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/CommandBuffer.scala @@ -62,7 +62,7 @@ class CommandBuffer(nodeParams: NodeParams, register: ActorRef) extends Actor wi object CommandBuffer { - case class CommandSend(channelId: ByteVector32, cmd: HasHtlcIdCommand) + case class CommandSend[T <: Command with HasHtlcId](channelId: ByteVector32, cmd: T) case class CommandAck(channelId: ByteVector32, htlcId: Long) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala index 5c7c88867..28b2d1e4e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.wire -import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC, HasHtlcIdCommand} +import fr.acinq.eclair.channel._ import fr.acinq.eclair.wire.CommonCodecs._ import fr.acinq.eclair.wire.FailureMessageCodecs.failureMessageCodec import scodec.Codec @@ -40,8 +40,9 @@ object CommandCodecs { ("failureCode" | uint16) :: ("commit" | provide(false))).as[CMD_FAIL_MALFORMED_HTLC] - val cmdCodec: Codec[HasHtlcIdCommand] = discriminated[HasHtlcIdCommand].by(uint16) + val cmdCodec: Codec[Command with HasHtlcId] = discriminated[Command with HasHtlcId].by(uint16) .typecase(0, cmdFulfillCodec) .typecase(1, cmdFailCodec) .typecase(2, cmdFailMalformedCodec) + } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 79b5b6f65..9423853e5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -70,7 +70,7 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, amountMsat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.FinalLegacyPayload(add.amountMsat, add.cltvExpiry))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.isInstanceOf[CMD_FULFILL_HTLC]) + commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) @@ -89,7 +89,7 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, amountMsat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.FinalLegacyPayload(add.amountMsat, add.cltvExpiry))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.isInstanceOf[CMD_FULFILL_HTLC]) + commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) @@ -107,7 +107,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, amountMsat, pr.paymentHash, CltvExpiryDelta(3).toCltvExpiry(nodeParams.currentBlockHeight), TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.FinalLegacyPayload(add.amountMsat, add.cltvExpiry))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(amountMsat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(amountMsat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) eventListener.expectNoMsg(100 milliseconds) @@ -216,7 +217,7 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.FinalLegacyPayload(add.amountMsat, add.cltvExpiry))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.isInstanceOf[CMD_FAIL_HTLC]) + commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]] val Some(incoming) = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) assert(incoming.paymentRequest.isExpired && incoming.status === IncomingPaymentStatus.Expired) } @@ -231,7 +232,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) val Some(incoming) = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) assert(incoming.paymentRequest.isExpired && incoming.status === IncomingPaymentStatus.Expired) } @@ -245,7 +247,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -258,7 +261,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, CltvExpiryDelta(1).toCltvExpiry(nodeParams.currentBlockHeight), TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -271,7 +275,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash.reverse, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -284,7 +289,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 999 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(999 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(999 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -297,7 +303,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 2001 msat, add.cltvExpiry, pr.paymentSecret.get))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(2001 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(2001 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -311,7 +318,8 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun // Invalid payment secret. val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, pr.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handler, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, pr.paymentSecret.get.reverse))) - assert(commandBuffer.expectMsgType[CommandBuffer.CommandSend].cmd.asInstanceOf[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + val cmd = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]].cmd + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } @@ -334,7 +342,7 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun f.sender.send(handler, GetPendingPayments) assert(f.sender.expectMsgType[PendingPayments].paymentHashes.nonEmpty) - val commands = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: f.commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: Nil + val commands = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]] :: f.commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]] :: Nil assert(commands.toSet === Set( CommandBuffer.CommandSend(ByteVector32.One, CMD_FAIL_HTLC(0, Right(PaymentTimeout), commit = true)), CommandBuffer.CommandSend(ByteVector32.One, CMD_FAIL_HTLC(1, Right(PaymentTimeout), commit = true)) @@ -369,12 +377,11 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun f.sender.send(handler, IncomingPacket.FinalPacket(add3, Onion.createMultiPartPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, pr.paymentSecret.get))) f.commandBuffer.expectMsg(CommandBuffer.CommandSend(add2.channelId, CMD_FAIL_HTLC(add2.id, Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)), commit = true))) - val cmd1 = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend] + val cmd1 = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] assert(cmd1.cmd.id === add1.id) assert(cmd1.channelId === add1.channelId) - val fulfill1 = cmd1.cmd.asInstanceOf[CMD_FULFILL_HTLC] - assert(Crypto.sha256(fulfill1.r) === pr.paymentHash) - f.commandBuffer.expectMsg(CommandBuffer.CommandSend(add3.channelId, CMD_FULFILL_HTLC(add3.id, fulfill1.r, commit = true))) + assert(Crypto.sha256(cmd1.cmd.r) === pr.paymentHash) + f.commandBuffer.expectMsg(CommandBuffer.CommandSend(add3.channelId, CMD_FULFILL_HTLC(add3.id, cmd1.cmd.r, commit = true))) f.sender.send(handler, CommandBuffer.CommandAck(add1.channelId, add1.id)) f.commandBuffer.expectMsg(CommandBuffer.CommandAck(add1.channelId, add1.id)) @@ -391,7 +398,7 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun // Extraneous HTLCs should be fulfilled. f.sender.send(handler, MultiPartPaymentFSM.ExtraHtlcReceived(pr.paymentHash, PendingPayment(44, PartialPayment(200 msat, ByteVector32.One, 0)), None)) - f.commandBuffer.expectMsg(CommandBuffer.CommandSend(ByteVector32.One, CMD_FULFILL_HTLC(44, fulfill1.r, commit = true))) + f.commandBuffer.expectMsg(CommandBuffer.CommandSend(ByteVector32.One, CMD_FULFILL_HTLC(44, cmd1.cmd.r, commit = true))) assert(f.eventListener.expectMsgType[PaymentReceived].amount === 200.msat) val received2 = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount === 1200.msat) @@ -421,12 +428,11 @@ class MultiPartHandlerSpec extends TestKit(ActorSystem("test")) with fixture.Fun val add3 = UpdateAddHtlc(ByteVector32.Zeroes, 5, 700 msat, pr.paymentHash, f.defaultExpiry, TestConstants.emptyOnionPacket) f.sender.send(handler, IncomingPacket.FinalPacket(add3, Onion.createMultiPartPayload(add3.amountMsat, 1000 msat, add3.cltvExpiry, pr.paymentSecret.get))) - val cmd1 = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend] + val cmd1 = f.commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] assert(cmd1.channelId === add2.channelId) - val fulfill1 = cmd1.cmd.asInstanceOf[CMD_FULFILL_HTLC] - assert(fulfill1.id === 2) - assert(Crypto.sha256(fulfill1.r) === pr.paymentHash) - f.commandBuffer.expectMsg(CommandBuffer.CommandSend(add3.channelId, CMD_FULFILL_HTLC(5, fulfill1.r, commit = true))) + assert(cmd1.cmd.id === 2) + assert(Crypto.sha256(cmd1.cmd.r) === pr.paymentHash) + f.commandBuffer.expectMsg(CommandBuffer.CommandSend(add3.channelId, CMD_FULFILL_HTLC(5, cmd1.cmd.r, commit = true))) val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0))) === PaymentReceived(pr.paymentHash, PartialPayment(300 msat, ByteVector32.One, 0) :: PartialPayment(700 msat, ByteVector32.Zeroes, 0) :: Nil)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 076daa848..51d285148 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -283,7 +283,7 @@ class PostRestartHtlcCleanerSpec extends TestkitBaseClass { // This downstream HTLC has two upstream HTLCs. sender.send(relayer, buildForwardFail(testCase.downstream_1_1, testCase.upstream_1)) - val fails = commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: Nil + val fails = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]] :: commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FAIL_HTLC]] :: Nil assert(fails.toSet === testCase.upstream_1.origins.map { case (channelId, htlcId) => CommandBuffer.CommandSend(channelId, CMD_FAIL_HTLC(htlcId, Right(TemporaryNodeFailure), commit = true)) }.toSet) @@ -313,7 +313,7 @@ class PostRestartHtlcCleanerSpec extends TestkitBaseClass { // This downstream HTLC has two upstream HTLCs. sender.send(relayer, buildForwardFulfill(testCase.downstream_1_1, testCase.upstream_1, preimage1)) - val fails = commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: commandBuffer.expectMsgType[CommandBuffer.CommandSend] :: Nil + val fails = commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] :: commandBuffer.expectMsgType[CommandBuffer.CommandSend[CMD_FULFILL_HTLC]] :: Nil assert(fails.toSet === testCase.upstream_1.origins.map { case (channelId, htlcId) => CommandBuffer.CommandSend(channelId, CMD_FULFILL_HTLC(htlcId, preimage1, commit = true)) }.toSet) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommandCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommandCodecsSpec.scala index cc2f2953e..e22d42f85 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommandCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommandCodecsSpec.scala @@ -16,22 +16,22 @@ package fr.acinq.eclair.wire -import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC, Command, HasHtlcIdCommand} +import fr.acinq.eclair.channel._ import fr.acinq.eclair.{randomBytes, randomBytes32} import org.scalatest.FunSuite /** - * Created by PM on 31/05/2016. - */ + * Created by PM on 31/05/2016. + */ class CommandCodecsSpec extends FunSuite { test("encode/decode all channel messages") { - val msgs: List[HasHtlcIdCommand] = + val msgs: List[Command with HasHtlcId] = CMD_FULFILL_HTLC(1573L, randomBytes32) :: - CMD_FAIL_HTLC(42456L, Left(randomBytes(145))) :: - CMD_FAIL_HTLC(253, Right(TemporaryNodeFailure)) :: - CMD_FAIL_MALFORMED_HTLC(7984, randomBytes32, FailureMessageCodecs.BADONION) :: Nil + CMD_FAIL_HTLC(42456L, Left(randomBytes(145))) :: + CMD_FAIL_HTLC(253, Right(TemporaryNodeFailure)) :: + CMD_FAIL_MALFORMED_HTLC(7984, randomBytes32, FailureMessageCodecs.BADONION) :: Nil msgs.foreach { msg =>