1
0
Fork 0
mirror of https://github.com/ACINQ/eclair.git synced 2025-03-27 02:37:06 +01:00

Replaced 'find' by 'collect' (#8)

* replaced 'find' by 'collect'

* fixed bug ourChanges->theirChanges
This commit is contained in:
Pierre-Marie Padiou 2016-05-26 14:34:54 +02:00
parent 291920b7f9
commit 07d489c0d0
3 changed files with 38 additions and 55 deletions

View file

@ -75,7 +75,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
when(OPEN_WAIT_FOR_ANCHOR) {
case Event(open_anchor(anchorTxHash, anchorOutputIndex, anchorAmount, theirSig), DATA_OPEN_WAIT_FOR_ANCHOR(ourParams, theirParams, theirRevocationHash, theirNextRevocationHash)) =>
val anchorTxid = anchorTxHash.reverse //see https://github.com/ElementsProject/lightning/issues/17
val anchorOutput = TxOut(Satoshi(anchorAmount), publicKeyScript = Scripts.anchorPubkeyScript(ourParams.commitPubKey, theirParams.commitPubKey))
val anchorOutput = TxOut(Satoshi(anchorAmount), publicKeyScript = Scripts.anchorPubkeyScript(ourParams.commitPubKey, theirParams.commitPubKey))
// they fund the channel with their anchor tx, so the money is theirs
val ourSpec = CommitmentSpec(Set.empty[Htlc], feeRate = ourParams.initialFeeRate, initial_amount_them_msat = anchorAmount * 1000, initial_amount_us_msat = 0, amount_them_msat = anchorAmount * 1000, amount_us_msat = 0)
@ -285,23 +285,25 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
stay using d.copy(theirChanges = theirChanges.copy(proposed = theirChanges.proposed :+ htlc))
case Event(CMD_FULFILL_HTLC(id, r), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _)) =>
findHtlc(theirChanges.acked, id, r) match {
case Some(htlc) =>
theirChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match {
case Some(htlc) if htlc.rHash == bin2sha256(Crypto.sha256(r)) =>
val fulfill = update_fulfill_htlc(id, r)
them ! fulfill
stay using d.copy(ourChanges = ourChanges.copy(proposed = ourChanges.proposed :+ fulfill))
case Some(htlc) => throw new RuntimeException(s"invalid htlc preimage for htlc $id")
case None => throw new RuntimeException(s"unknown htlc id=$id")
}
case Event(fulfill@update_fulfill_htlc(id, r), d@DATA_NORMAL(_, _, _, _, ourCommit, _, ourChanges, theirChanges, _, _)) =>
findHtlc(ourChanges.acked, id, r) match {
case Some(htlc) =>
ourChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match {
case Some(htlc) if htlc.rHash == bin2sha256(Crypto.sha256(r)) =>
stay using d.copy(theirChanges = theirChanges.copy(proposed = theirChanges.proposed :+ fulfill))
case Some(htlc) => throw new RuntimeException(s"invalid htlc preimage for htlc $id")
case None => throw new RuntimeException(s"unknown htlc id=$id") // TODO : we should fail the channel
}
case Event(CMD_FAIL_HTLC(id, reason), d@DATA_NORMAL(_, _, _, _, _, theirCommit, ourChanges, theirChanges, _, _)) =>
findHtlc(theirChanges.acked, id) match {
theirChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match {
case Some(htlc) =>
val fail = update_fail_htlc(id, fail_reason(ByteString.copyFromUtf8(reason)))
them ! fail
@ -310,7 +312,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
}
case Event(fail@update_fail_htlc(id, reason), d@DATA_NORMAL(_, _, _, _, ourCommit, _, ourChanges, theirChanges, _, _)) =>
findHtlc(ourChanges.acked, id) match {
ourChanges.acked.collectFirst { case u: update_add_htlc if u.id == id => u } match {
case Some(htlc) =>
stay using d.copy(theirChanges = theirChanges.copy(proposed = theirChanges.proposed :+ fail))
case None => throw new RuntimeException(s"unknown htlc id=$id") // TODO : we should fail the channel
@ -321,12 +323,11 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
// their commitment now includes all our changes + their acked changes
theirNextRevocationHash_opt match {
case Some(theirNextRevocationHash) =>
val ours1 = ourChanges.copy(proposed = Nil, signed = ourChanges.signed ++ ourChanges.proposed)
val spec = reduce(theirCommit.spec, theirChanges.acked, ourChanges.acked ++ ourChanges.signed ++ ourChanges.proposed)
val theirTx = makeTheirTx(ourParams, theirParams, ourCommit.publishableTx.txIn, theirNextRevocationHash, spec)
val ourSig = sign(ourParams, theirParams, anchorOutput.amount.toLong, theirTx)
them ! update_commit(ourSig)
stay using d.copy(theirCommit = TheirCommit(theirCommit.index + 1, spec, theirNextRevocationHash), ourChanges = ours1, theirNextRevocationHash = None)
stay using d.copy(theirCommit = TheirCommit(theirCommit.index + 1, spec, theirNextRevocationHash), ourChanges = ourChanges.copy(proposed = Nil, signed = ourChanges.signed ++ ourChanges.proposed), theirNextRevocationHash = None)
case None => throw new RuntimeException(s"cannot send two update_commit in a row (must wait for revocation)")
}
@ -334,7 +335,6 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
// we've received a signature
// ack all their changes
// our commitment now includes all theirs changes + our acked changes
val theirs1 = theirChanges.copy(proposed = Nil, acked = theirChanges.acked ++ theirChanges.proposed)
val spec = reduce(ourCommit.spec, ourChanges.acked, theirChanges.acked ++ theirChanges.proposed)
val ourNextRevocationHash = Crypto.sha256(ShaChain.shaChainFromSeed(ourParams.shaSeed, ourCommit.index + 1))
val ourTx = makeOurTx(ourParams, theirParams, ourCommit.publishableTx.txIn, ourNextRevocationHash, spec)
@ -351,7 +351,7 @@ class Channel(val them: ActorRef, val blockchain: ActorRef, val params: OurChann
val ourNextRevocationHash = Crypto.sha256(ShaChain.shaChainFromSeed(ourParams.shaSeed, ourCommit.index + 2))
them ! update_revocation(ourRevocationPreimage, ourNextRevocationHash)
val ourCommit1 = ourCommit.copy(index = ourCommit.index + 1, spec, publishableTx = signedTx)
stay using d.copy(ourCommit = ourCommit1, theirChanges = theirs1)
stay using d.copy(ourCommit = ourCommit1, theirChanges = theirChanges.copy(proposed = Nil, acked = theirChanges.acked ++ theirChanges.proposed))
}
case Event(msg@update_revocation(revocationPreimage, nextRevocationHash), d@DATA_NORMAL(ourParams, theirParams, shaChain, _, ourCommit, theirCommit, ourChanges, theirChanges, _, _)) =>

View file

@ -149,6 +149,7 @@ object TypeDefs {
import TypeDefs._
case class OurChanges(proposed: List[Change], signed: List[Change], acked: List[Change])
case class TheirChanges(proposed: List[Change], acked: List[Change])
case class Changes(ourChanges: OurChanges, theirChanges: TheirChanges)
case class OurCommit(index: Long, spec: CommitmentSpec, publishableTx: Transaction)
case class TheirCommit(index: Long, spec: CommitmentSpec, theirRevocationHash: sha256_hash)

View file

@ -13,47 +13,31 @@ import scala.util.Try
*/
object Helpers {
def isAddHtlc(change: Change) = change match {
case u:update_add_htlc => true
case _ => false
}
def findHtlc(changes: List[Change], id: Long): Option[Change] = changes.find(_ match {
case u:update_add_htlc if u.id == id => true
case _ => false
})
def findHtlc(changes: List[Change], id: Long, r: sha256_hash): Option[Change] = changes.find(_ match {
case u:update_add_htlc if u.id == id && u.rHash == bin2sha256(Crypto.sha256(r)) => true
case u:update_add_htlc if u.id == id => throw new RuntimeException(s"invalid htlc preimage for htlc $id")
case _ => false
})
def removeHtlc(changes: List[Change], id: Long) : List[Change] = changes.filterNot(_ match {
def removeHtlc(changes: List[Change], id: Long): List[Change] = changes.filterNot(_ match {
case u: update_add_htlc if u.id == id => true
case _ => false
})
def addHtlc(spec: CommitmentSpec, direction: Direction, update: update_add_htlc) : CommitmentSpec = {
def addHtlc(spec: CommitmentSpec, direction: Direction, update: update_add_htlc): CommitmentSpec = {
val htlc = Htlc(direction, update.id, update.amountMsat, update.rHash, update.expiry, previousChannelId = None)
direction match {
case OUT => spec.copy(amount_us_msat = spec.amount_us_msat - htlc.amountMsat, htlcs = spec.htlcs + htlc)
case OUT => spec.copy(amount_us_msat = spec.amount_us_msat - htlc.amountMsat, htlcs = spec.htlcs + htlc)
case IN => spec.copy(amount_them_msat = spec.amount_them_msat - htlc.amountMsat, htlcs = spec.htlcs + htlc)
}
}
// OUT means we are sending an update_fulfill_htlc message which means that we are fulfilling an HTLC that they sent
def fulfillHtlc(spec: CommitmentSpec, direction: Direction, update: update_fulfill_htlc) : CommitmentSpec = {
def fulfillHtlc(spec: CommitmentSpec, direction: Direction, update: update_fulfill_htlc): CommitmentSpec = {
spec.htlcs.find(htlc => htlc.id == update.id && htlc.rHash == bin2sha256(Crypto.sha256(update.r))) match {
case Some(htlc) => direction match {
case OUT => spec.copy(amount_us_msat = spec.amount_us_msat + htlc.amountMsat, htlcs = spec.htlcs - htlc)
case OUT => spec.copy(amount_us_msat = spec.amount_us_msat + htlc.amountMsat, htlcs = spec.htlcs - htlc)
case IN => spec.copy(amount_them_msat = spec.amount_them_msat + htlc.amountMsat, htlcs = spec.htlcs - htlc)
}
}
}
// OUT means we are sending an update_fail_htlc message which means that we are failing an HTLC that they sent
def failHtlc(spec: CommitmentSpec, direction: Direction, update: update_fail_htlc) : CommitmentSpec = {
def failHtlc(spec: CommitmentSpec, direction: Direction, update: update_fail_htlc): CommitmentSpec = {
spec.htlcs.find(_.id == update.id) match {
case Some(htlc) => direction match {
case OUT => spec.copy(amount_them_msat = spec.amount_them_msat + htlc.amountMsat, htlcs = spec.htlcs - htlc)
@ -64,26 +48,24 @@ object Helpers {
def reduce(ourCommitSpec: CommitmentSpec, ourChanges: List[Change], theirChanges: List[Change]): CommitmentSpec = {
val spec = ourCommitSpec.copy(htlcs = Set(), amount_us_msat = ourCommitSpec.initial_amount_us_msat, amount_them_msat = ourCommitSpec.initial_amount_them_msat)
val spec1 = ourChanges.filter(isAddHtlc).foldLeft(spec)( (spec, change) => change match {
case u: update_add_htlc => addHtlc(spec, OUT, u)
case u: update_fulfill_htlc => fulfillHtlc(spec, OUT, u)
case u: update_fail_htlc => failHtlc(spec, OUT, u)
})
val spec2 = theirChanges.filter(isAddHtlc).foldLeft(spec1)( (spec, change) => change match {
case u: update_add_htlc => addHtlc(spec, IN, u)
case u: update_fulfill_htlc => fulfillHtlc(spec, IN, u)
case u: update_fail_htlc => failHtlc(spec, IN, u)
})
val spec3 = ourChanges.filterNot(isAddHtlc).foldLeft(spec2)( (spec, change) => change match {
case u: update_add_htlc => addHtlc(spec, OUT, u)
case u: update_fulfill_htlc => fulfillHtlc(spec, OUT, u)
case u: update_fail_htlc => failHtlc(spec, OUT, u)
})
val spec4 = theirChanges.filterNot(isAddHtlc).foldLeft(spec3)( (spec, change) => change match {
case u: update_add_htlc => addHtlc(spec, IN, u)
case u: update_fulfill_htlc => fulfillHtlc(spec, IN, u)
case u: update_fail_htlc => failHtlc(spec, IN, u)
})
val spec1 = ourChanges.foldLeft(spec) {
case (spec, u: update_add_htlc) => addHtlc(spec, OUT, u)
case (spec, _) => spec
}
val spec2 = theirChanges.foldLeft(spec1) {
case (spec, u: update_add_htlc) => addHtlc(spec, IN, u)
case (spec, _) => spec
}
val spec3 = ourChanges.foldLeft(spec2) {
case (spec, u: update_fulfill_htlc) => fulfillHtlc(spec, OUT, u)
case (spec, u: update_fail_htlc) => failHtlc(spec, OUT, u)
case (spec, _) => spec
}
val spec4 = theirChanges.foldLeft(spec3) {
case (spec, u: update_fulfill_htlc) => fulfillHtlc(spec, IN, u)
case (spec, u: update_fail_htlc) => failHtlc(spec, IN, u)
case (spec, _) => spec
}
spec4
}
@ -93,10 +75,10 @@ object Helpers {
def makeTheirTx(ourParams: OurChannelParams, theirParams: TheirChannelParams, inputs: Seq[TxIn], theirRevocationHash: sha256_hash, spec: CommitmentSpec): Transaction =
makeCommitTx(inputs, theirParams.finalPubKey, ourParams.finalPubKey, theirParams.delay, theirRevocationHash, spec)
def sign(ourParams: OurChannelParams, theirParams: TheirChannelParams, anchorAmount: Long, tx: Transaction): signature =
def sign(ourParams: OurChannelParams, theirParams: TheirChannelParams, anchorAmount: Long, tx: Transaction): signature =
bin2signature(Transaction.signInput(tx, 0, multiSig2of2(ourParams.commitPubKey, theirParams.commitPubKey), SIGHASH_ALL, anchorAmount, 1, ourParams.commitPrivKey))
def addSigs(ourParams: OurChannelParams, theirParams: TheirChannelParams, anchorAmount: Long, tx: Transaction, ourSig: signature, theirSig: signature): Transaction = {
def addSigs(ourParams: OurChannelParams, theirParams: TheirChannelParams, anchorAmount: Long, tx: Transaction, ourSig: signature, theirSig: signature): Transaction = {
// TODO : Transaction.sign(...) should handle multisig
val ourSig = Transaction.signInput(tx, 0, multiSig2of2(ourParams.commitPubKey, theirParams.commitPubKey), SIGHASH_ALL, anchorAmount, 1, ourParams.commitPrivKey)
val witness = witness2of2(theirSig, ourSig, theirParams.commitPubKey, ourParams.commitPubKey)