Add BasicArithmetic trait (#329)

* Add BasicArithmetic to MilliSatoshis

* Add BasicArithmetic

* Add LnCurrencyUnit, CurrencyUnit, Number to BasicArithmetic

* Add tests for BasicArithmetic

* Make logger in unit test class protected and annotated

* Use BitcoinSUnitTest class

* Address code review
This commit is contained in:
Torkel Rogstad 2019-02-16 22:19:00 +01:00 committed by Chris Stewart
parent 15160eccc6
commit 696e9f4c45
9 changed files with 268 additions and 53 deletions

View file

@ -0,0 +1,63 @@
package org.bitcoins.core.number
import org.bitcoins.core.gen.NumberGenerator
import org.bitcoins.core.util.BitcoinSUnitTest
import org.scalatest.prop.PropertyChecks
class BasicArithmeticSpec extends BitcoinSUnitTest {
// We have to wrap BasicArithmetic instead of doing
// an anonymous class, that causes overloading confusion
private case class NumWrapper(underlying: BigInt)
extends BasicArithmetic[NumWrapper] {
override def +(n: NumWrapper): NumWrapper =
NumWrapper(underlying + n.underlying)
override def -(n: NumWrapper): NumWrapper =
NumWrapper(underlying - n.underlying)
override def *(factor: BigInt): NumWrapper =
NumWrapper(underlying * factor)
override def *(factor: NumWrapper): NumWrapper =
NumWrapper(underlying * factor.underlying)
}
private val numWrapperGen = for {
int <- NumberGenerator.bigInts
} yield NumWrapper(int)
behavior of "BasicArithmetic"
it must "multiply safely and unsafely with an int" in {
PropertyChecks.forAll(NumberGenerator.bigInts, numWrapperGen) { (i, num) =>
val unsafe = num * i
val safe = num.multiplySafe(i)
assert(safe.toOption.contains(unsafe))
}
}
it must "multiply safely and unsafely with itself" in {
PropertyChecks.forAll(numWrapperGen, numWrapperGen) { (first, second) =>
val unsafe = first * second
val safe = first.multiplySafe(second)
assert(safe.toOption.contains(unsafe))
}
}
it must "add safely and unsafely" in {
PropertyChecks.forAll(numWrapperGen, numWrapperGen) { (first, second) =>
val unsafe = first + second
val safe = first.addSafe(second)
assert(safe.toOption.contains(unsafe))
}
}
it must "subtract safely and unsafely" in {
PropertyChecks.forAll(numWrapperGen, numWrapperGen) { (first, second) =>
val unsafe = first - second
val safe = first.subtractSafe(second)
assert(safe.toOption.contains(unsafe))
}
}
}

View file

@ -1,14 +1,11 @@
package org.bitcoins.core.protocol.ln.currency
import org.bitcoins.core.gen.CurrencyUnitGenerator
import org.bitcoins.core.gen.{CurrencyUnitGenerator, NumberGenerator}
import org.bitcoins.core.gen.ln.LnCurrencyUnitGen
import org.scalacheck.Gen
import org.bitcoins.core.util.BitcoinSUnitTest
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, MustMatchers}
import org.slf4j.LoggerFactory
class MilliSatoshisTest extends FlatSpec with MustMatchers {
private val logger = LoggerFactory.getLogger(this.getClass)
class MilliSatoshisTest extends BitcoinSUnitTest {
behavior of "MilliSatoshis"
it must "convert pico bitcoins to msat correctly" in {
@ -28,6 +25,63 @@ class MilliSatoshisTest extends FlatSpec with MustMatchers {
MilliSatoshis.fromPico(PicoBitcoins(110)) must be(MilliSatoshis(11))
}
it must "add millisatoshis" in {
PropertyChecks.forAll(LnCurrencyUnitGen.milliSatoshisPair) {
case (first, second) =>
val bigInt = first.toBigInt + second.toBigInt
assert((first + second).toBigInt == bigInt)
}
}
private val msatWithNum = for {
msat <- LnCurrencyUnitGen.milliSatoshis
num <- NumberGenerator.bigIntsUInt64Range.filter(_ > 0)
} yield (msat, num)
it must "multiply millisatoshis with an int" in {
PropertyChecks.forAll(msatWithNum) {
case (msat, bigint) =>
val underlyingCalc = msat.toBigInt * bigint
assert((msat * bigint).toBigInt == underlyingCalc)
}
}
it must "multiply millisatoshis with itself" in {
PropertyChecks.forAll(LnCurrencyUnitGen.milliSatoshisPair) {
case (first, second) =>
val safe = first.multiplySafe(second)
val unsafe = first * second
assert(safe.toOption.contains(unsafe))
val underlying = first.toBigInt * second.toBigInt
assert(unsafe.toBigInt == underlying)
}
}
it must "subtract msats after adding them" in {
PropertyChecks.forAll(LnCurrencyUnitGen.milliSatoshisPair) {
case (first, second) =>
val added = first + second
val subtracted = added - second
assert(subtracted == first)
}
}
it must "subtract msats" in {
PropertyChecks.forAll(LnCurrencyUnitGen.milliSatoshisPair) {
case (first, second) =>
val subtracted = first subtractSafe second
val isPositive = (first.toBigInt - second.toBigInt) >= 0
assert(subtracted.isSuccess == isPositive)
if (subtracted.isSuccess) {
val underlyingCalc = first.toBigInt - second.toBigInt
assert(subtracted.get.toBigInt == underlyingCalc)
}
}
}
it must "covert from a ln currency unit -> millisatoshis -> lnCurrencyUnit" in {
PropertyChecks.forAll(LnCurrencyUnitGen.positivePicoBitcoin) { pb =>

View file

@ -2,12 +2,15 @@ package org.bitcoins.core.util
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, MustMatchers}
import org.slf4j.LoggerFactory
import org.slf4j.{Logger, LoggerFactory}
/** A wrapper for boiler plate testing procesures in bitcoin-s */
abstract class BitcoinSUnitTest extends FlatSpec with MustMatchers with PropertyChecks {
abstract class BitcoinSUnitTest
extends FlatSpec
with MustMatchers
with PropertyChecks {
lazy val logger = LoggerFactory.getLogger(getClass)
lazy protected val logger: Logger = LoggerFactory.getLogger(getClass)
/** The configuration for property based tests in our testing suite
* See: http://www.scalatest.org/user_guide/writing_scalacheck_style_properties

View file

@ -1,13 +1,15 @@
package org.bitcoins.core.currency
import org.bitcoins.core.consensus.Consensus
import org.bitcoins.core.number.{BaseNumbers, Int64}
import org.bitcoins.core.number.{BaseNumbers, BasicArithmetic, Int64}
import org.bitcoins.core.protocol.NetworkElement
import org.bitcoins.core.serializers.RawSatoshisSerializer
import org.bitcoins.core.util.Factory
import scodec.bits.ByteVector
sealed abstract class CurrencyUnit extends NetworkElement {
sealed abstract class CurrencyUnit
extends NetworkElement
with BasicArithmetic[CurrencyUnit] {
type A
def satoshis: Satoshis
@ -32,15 +34,19 @@ sealed abstract class CurrencyUnit extends NetworkElement {
def ==(c: CurrencyUnit): Boolean = satoshis == c.satoshis
def +(c: CurrencyUnit): CurrencyUnit = {
override def +(c: CurrencyUnit): CurrencyUnit = {
Satoshis(satoshis.underlying + c.satoshis.underlying)
}
def -(c: CurrencyUnit): CurrencyUnit = {
override def -(c: CurrencyUnit): CurrencyUnit = {
Satoshis(satoshis.underlying - c.satoshis.underlying)
}
def *(c: CurrencyUnit): CurrencyUnit = {
override def *(factor: BigInt): CurrencyUnit = {
Satoshis(satoshis.underlying * factor)
}
override def *(c: CurrencyUnit): CurrencyUnit = {
Satoshis(satoshis.underlying * c.satoshis.underlying)
}
@ -48,7 +54,7 @@ sealed abstract class CurrencyUnit extends NetworkElement {
Satoshis(-satoshis.underlying)
}
override def bytes = satoshis.bytes
override def bytes: ByteVector = satoshis.bytes
def toBigDecimal: BigDecimal
@ -58,7 +64,9 @@ sealed abstract class CurrencyUnit extends NetworkElement {
sealed abstract class Satoshis extends CurrencyUnit {
override type A = Int64
override def bytes = RawSatoshisSerializer.write(this)
override def toString: String = s"$toLong sat"
override def bytes: ByteVector = RawSatoshisSerializer.write(this)
override def satoshis: Satoshis = this
@ -66,7 +74,7 @@ sealed abstract class Satoshis extends CurrencyUnit {
def toBigInt: BigInt = BigInt(toLong)
def toLong = underlying.toLong
def toLong: Long = underlying.toLong
def ==(satoshis: Satoshis): Boolean = underlying == satoshis.underlying
}
@ -80,7 +88,6 @@ object Satoshis extends Factory[Satoshis] with BaseNumbers[Satoshis] {
override def fromBytes(bytes: ByteVector): Satoshis =
RawSatoshisSerializer.read(bytes)
def apply(int64: Int64): Satoshis = SatoshisImpl(int64)
private case class SatoshisImpl(underlying: Int64) extends Satoshis
@ -89,9 +96,11 @@ object Satoshis extends Factory[Satoshis] with BaseNumbers[Satoshis] {
sealed abstract class Bitcoins extends CurrencyUnit {
override type A = BigDecimal
override def toString: String = s"$toBigDecimal BTC"
override def toBigDecimal: BigDecimal = underlying
override def hex = satoshis.hex
override def hex: String = satoshis.hex
override def satoshis: Satoshis = {
val sat = underlying * CurrencyUnits.btcToSatoshiScalar

View file

@ -0,0 +1,47 @@
package org.bitcoins.core.number
import scala.util.Try
/**
* @define mulSafe
* Some classes have restrictions on upper bounds
* for it's underlying value. This might cause the `*`
* operator to throw. This method wraps it in a `Try`
* block.
*/
trait BasicArithmetic[N] {
def +(n: N): N
/**
* Some classes have restrictions on upper bounds
* for it's underlying value. This might cause the `+`
* operator to throw. This method wraps it in a `Try`
* block.
*/
def addSafe(n: N): Try[N] = Try { this + n }
def -(n: N): N
/**
* Some classes have restrictions on lower bounds
* for it's underlying value. This might cause the `-`
* operator to throw. This method wraps it in a `Try`
* block.
*/
def subtractSafe(n: N): Try[N] = Try { this - n }
def *(factor: BigInt): N
/**
* $mulSafe
*/
def multiplySafe(factor: BigInt): Try[N] = Try { this * factor }
def *(factor: N): N
/**
* $mulSafe
*/
def multiplySafe(factor: N): Try[N] = Try { this * factor }
}

View file

@ -14,7 +14,9 @@ import scala.util.{Failure, Success, Try}
* This is useful for dealing with codebases/protocols that rely on C's
* unsigned integer types
*/
sealed abstract class Number[T <: Number[T]] extends NetworkElement {
sealed abstract class Number[T <: Number[T]]
extends NetworkElement
with BasicArithmetic[T] {
type A = BigInt
/** The underlying scala number used to to hold the number */
@ -33,9 +35,12 @@ sealed abstract class Number[T <: Number[T]] extends NetworkElement {
/** Factory function to create the underlying T, for instance a UInt32 */
def apply: A => T
def +(num: T): T = apply(checkResult(underlying + num.underlying))
def -(num: T): T = apply(checkResult(underlying - num.underlying))
def *(num: T): T = apply(checkResult(underlying * num.underlying))
override def +(num: T): T = apply(checkResult(underlying + num.underlying))
override def -(num: T): T = apply(checkResult(underlying - num.underlying))
override def *(factor: BigInt): T = apply(checkResult(underlying * factor))
override def *(num: T): T = apply(checkResult(underlying * num.underlying))
def >(num: T): Boolean = underlying > num.underlying
def >=(num: T): Boolean = underlying >= num.underlying
def <(num: T): Boolean = underlying < num.underlying
@ -64,6 +69,11 @@ sealed abstract class Number[T <: Number[T]] extends NetworkElement {
def |(num: T): T = apply(checkResult(underlying | num.underlying))
def &(num: T): T = apply(checkResult(underlying & num.underlying))
def unary_- : T = apply(-underlying)
/**
* Checks if the given result is within the range
* of this number type
*/
private def checkResult(result: BigInt): A = {
require((result & andMask) == result,
"Result was out of bounds, got: " + result)
@ -81,7 +91,7 @@ sealed abstract class Number[T <: Number[T]] extends NetworkElement {
}
}
override def bytes = BitcoinSUtil.decodeHex(hex)
override def bytes: ByteVector = BitcoinSUtil.decodeHex(hex)
}
/**
@ -115,7 +125,7 @@ sealed abstract class UInt5 extends UnsignedNumber[UInt5] {
sealed abstract class UInt8 extends UnsignedNumber[UInt8] {
override def apply: A => UInt8 = UInt8(_)
override def hex = BitcoinSUtil.encodeHex(toInt.toShort).slice(2, 4)
override def hex: String = BitcoinSUtil.encodeHex(toInt.toShort).slice(2, 4)
override def andMask = 0xff
@ -130,7 +140,7 @@ sealed abstract class UInt8 extends UnsignedNumber[UInt8] {
*/
sealed abstract class UInt32 extends UnsignedNumber[UInt32] {
override def apply: A => UInt32 = UInt32(_)
override def hex = BitcoinSUtil.encodeHex(toLong).slice(8, 16)
override def hex: String = BitcoinSUtil.encodeHex(toLong).slice(8, 16)
override def andMask = 0xffffffffL
}
@ -139,16 +149,16 @@ sealed abstract class UInt32 extends UnsignedNumber[UInt32] {
* Represents a uint64_t in C
*/
sealed abstract class UInt64 extends UnsignedNumber[UInt64] {
override def hex = encodeHex(underlying)
override def hex: String = encodeHex(underlying)
override def apply: A => UInt64 = UInt64(_)
override def andMask = 0xffffffffffffffffL
/**
* The converts a [[BigInt]] to a 8 byte hex representation
* Converts a [[BigInt]] to a 8 byte hex representation.
* [[BigInt]] will only allocate 1 byte for numbers like 1 which require 1 byte, giving us the hex representation 01
* this function pads the hex chars to be 0000000000000001
* @param bigInt
* @return
* @param bigInt The number to encode
* @return The hex encoded number
*/
private def encodeHex(bigInt: BigInt): String = {
val hex = BitcoinSUtil.encodeHex(bigInt)
@ -168,7 +178,7 @@ sealed abstract class UInt64 extends UnsignedNumber[UInt64] {
sealed abstract class Int32 extends SignedNumber[Int32] {
override def apply: A => Int32 = Int32(_)
override def andMask = 0xffffffff
override def hex = BitcoinSUtil.encodeHex(toInt)
override def hex: String = BitcoinSUtil.encodeHex(toInt)
}
/**
@ -177,7 +187,7 @@ sealed abstract class Int32 extends SignedNumber[Int32] {
sealed abstract class Int64 extends SignedNumber[Int64] {
override def apply: A => Int64 = Int64(_)
override def andMask = 0xffffffffffffffffL
override def hex = BitcoinSUtil.encodeHex(toLong)
override def hex: String = BitcoinSUtil.encodeHex(toLong)
}
/**
@ -200,7 +210,7 @@ object UInt5 extends Factory[UInt5] with BaseNumbers[UInt5] {
lazy val zero = UInt5(0.toByte)
lazy val one = UInt5(1.toByte)
lazy val min = zero
lazy val min: UInt5 = zero
lazy val max = UInt5(31.toByte)
def apply(byte: Byte): UInt5 = fromByte(byte)
@ -208,7 +218,7 @@ object UInt5 extends Factory[UInt5] with BaseNumbers[UInt5] {
def apply(bigInt: BigInt): UInt5 = {
require(
bigInt.toByteArray.size == 1,
bigInt.toByteArray.length == 1,
s"To create a uint5 from a BigInt it must be less than 32. Got: ${bigInt.toString}")
UInt5.fromByte(bigInt.toByteArray.head)
@ -230,7 +240,7 @@ object UInt5 extends Factory[UInt5] with BaseNumbers[UInt5] {
}
def toUInt5s(bytes: ByteVector): Vector[UInt5] = {
bytes.toArray.map(toUInt5(_)).toVector
bytes.toArray.map(toUInt5).toVector
}
}
@ -242,7 +252,7 @@ object UInt8 extends Factory[UInt8] with BaseNumbers[UInt8] {
lazy val zero = UInt8(0.toShort)
lazy val one = UInt8(1.toShort)
lazy val min = zero
lazy val min: UInt8 = zero
lazy val max = UInt8(255.toShort)
def apply(short: Short): UInt8 = UInt8(BigInt(short))
@ -268,7 +278,7 @@ object UInt8 extends Factory[UInt8] with BaseNumbers[UInt8] {
def toByte(uInt8: UInt8): Byte = uInt8.underlying.toByte
def toBytes(us: Seq[UInt8]): ByteVector = {
ByteVector(us.map(toByte(_)))
ByteVector(us.map(toByte))
}
def toUInt8s(bytes: ByteVector): Vector[UInt8] = {
@ -296,7 +306,7 @@ object UInt32 extends Factory[UInt32] with BaseNumbers[UInt32] {
lazy val zero = UInt32(0)
lazy val one = UInt32(1)
lazy val min = zero
lazy val min: UInt32 = zero
lazy val max = UInt32(4294967295L)
override def fromBytes(bytes: ByteVector): UInt32 = {
@ -333,7 +343,7 @@ object UInt64 extends Factory[UInt64] with BaseNumbers[UInt64] {
lazy val zero = UInt64(BigInt(0))
lazy val one = UInt64(BigInt(1))
lazy val min = zero
lazy val min: UInt64 = zero
lazy val max = UInt64(BigInt("18446744073709551615"))
override def fromBytes(bytes: ByteVector): UInt64 = {

View file

@ -1,7 +1,7 @@
package org.bitcoins.core.protocol.ln.currency
import org.bitcoins.core.currency.Satoshis
import org.bitcoins.core.number.{BaseNumbers, Int64, UInt5}
import org.bitcoins.core.number.{BaseNumbers, BasicArithmetic, Int64, UInt5}
import org.bitcoins.core.protocol.NetworkElement
import org.bitcoins.core.protocol.ln._
import org.bitcoins.core.util.Bech32
@ -9,7 +9,9 @@ import scodec.bits.ByteVector
import scala.util.{Failure, Try}
sealed abstract class LnCurrencyUnit extends NetworkElement {
sealed abstract class LnCurrencyUnit
extends NetworkElement
with BasicArithmetic[LnCurrencyUnit] {
def character: Char
def >=(ln: LnCurrencyUnit): Boolean = {
@ -33,15 +35,18 @@ sealed abstract class LnCurrencyUnit extends NetworkElement {
def ==(ln: LnCurrencyUnit): Boolean =
toPicoBitcoinValue == ln.toPicoBitcoinValue
def +(ln: LnCurrencyUnit): LnCurrencyUnit = {
override def +(ln: LnCurrencyUnit): LnCurrencyUnit = {
PicoBitcoins(toPicoBitcoinValue + ln.toPicoBitcoinValue)
}
def -(ln: LnCurrencyUnit): LnCurrencyUnit = {
override def -(ln: LnCurrencyUnit): LnCurrencyUnit = {
PicoBitcoins(toPicoBitcoinValue - ln.toPicoBitcoinValue)
}
def *(ln: LnCurrencyUnit): LnCurrencyUnit = {
override def *(factor: BigInt): LnCurrencyUnit =
PicoBitcoins(toPicoBitcoinValue * factor)
override def *(ln: LnCurrencyUnit): LnCurrencyUnit = {
PicoBitcoins(toPicoBitcoinValue * ln.toPicoBitcoinValue)
}

View file

@ -1,7 +1,7 @@
package org.bitcoins.core.protocol.ln.currency
import org.bitcoins.core.currency.{CurrencyUnit, Satoshis}
import org.bitcoins.core.number.UInt64
import org.bitcoins.core.number.{BasicArithmetic, UInt64}
import org.bitcoins.core.protocol.NetworkElement
import scodec.bits.ByteVector
@ -13,7 +13,9 @@ import scala.math.BigDecimal.RoundingMode
*
* @see [[https://github.com/lightningnetwork/lightning-rfc/blob/master/02-peer-protocol.md#adding-an-htlc-update_add_htlc BOLT2]]
*/
sealed abstract class MilliSatoshis extends NetworkElement {
sealed abstract class MilliSatoshis
extends NetworkElement
with BasicArithmetic[MilliSatoshis] {
require(toBigInt >= 0, s"Millisatoshis cannot be negative, got $toBigInt")
protected def underlying: BigInt
@ -25,7 +27,7 @@ sealed abstract class MilliSatoshis extends NetworkElement {
* 10 msat
* }}}
*/
override def toString: String = s"$toLong msat"
override def toString: String = s"$toBigInt msat"
def toBigInt: BigInt = underlying
@ -62,27 +64,43 @@ sealed abstract class MilliSatoshis extends NetworkElement {
}
def ==(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit == ms.toLnCurrencyUnit
toBigInt == ms.toBigInt
}
def !=(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit != ms.toLnCurrencyUnit
toBigInt != ms.toBigInt
}
def >=(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit >= ms.toLnCurrencyUnit
toBigInt >= ms.toBigInt
}
def >(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit > ms.toLnCurrencyUnit
toBigInt > ms.toBigInt
}
def <(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit < ms.toLnCurrencyUnit
toBigInt < ms.toBigInt
}
def <=(ms: MilliSatoshis): Boolean = {
toLnCurrencyUnit <= ms.toLnCurrencyUnit
toBigInt <= ms.toBigInt
}
override def +(ms: MilliSatoshis): MilliSatoshis = {
MilliSatoshis(toBigInt + ms.toBigInt)
}
override def -(ms: MilliSatoshis): MilliSatoshis = {
MilliSatoshis(toBigInt - ms.toBigInt)
}
override def *(factor: BigInt): MilliSatoshis = {
MilliSatoshis(toBigInt * factor)
}
override def *(factor: MilliSatoshis): MilliSatoshis = {
MilliSatoshis(toBigInt * factor.toBigInt)
}
def toUInt64: UInt64 = {

View file

@ -54,6 +54,12 @@ trait LnCurrencyUnitGen {
for {
i64 <- NumberGenerator.uInt64
} yield MilliSatoshis(i64.toBigInt)
def milliSatoshisPair: Gen[(MilliSatoshis, MilliSatoshis)] =
for {
first <- milliSatoshis
second <- milliSatoshis
} yield (first, second)
}
object LnCurrencyUnitGen extends LnCurrencyUnitGen