Correctly decoding bech32 address to spks

This commit is contained in:
Chris Stewart 2017-10-26 14:28:54 -05:00
parent 4643d46256
commit ce5ae1a69a
8 changed files with 150 additions and 12 deletions

View File

@ -430,6 +430,8 @@ object UInt8 extends Factory[UInt8] with BaseNumbers[UInt8] {
def apply(short: Short): UInt8 = UInt8Impl(short)
def apply(byte: Byte): UInt8 = toUInt8(byte)
def isValid(short: Short): Boolean = short >= 0 && short < 256
override def fromBytes(bytes: Seq[Byte]): UInt8 = {
@ -443,9 +445,13 @@ object UInt8 extends Factory[UInt8] with BaseNumbers[UInt8] {
if ((byte & 0x80) == 0x80) {
val r = (byte & 0x7f) + NumberUtil.pow2(7)
UInt8(r.toShort)
} else UInt8(byte)
} else UInt8Impl(byte)
}
def toByte(uInt8: UInt8): Byte = uInt8.underlying.toByte
def toBytes(us: Seq[UInt8]): Seq[Byte] = us.map(toByte(_))
def toUInt8s(bytes: Seq[Byte]): Seq[UInt8] = bytes.map(toUInt8(_))
}

View File

@ -5,8 +5,11 @@ import org.bitcoins.core.crypto.{ECPublicKey, Sha256Hash160Digest}
import org.bitcoins.core.number.{UInt32, UInt8}
import org.bitcoins.core.protocol.transaction.TransactionOutput
import org.bitcoins.core.protocol.script._
import org.bitcoins.core.script.constant.ScriptConstant
import org.bitcoins.core.serializers.script.ScriptParser
import org.bitcoins.core.util._
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
sealed abstract class Address {
@ -133,6 +136,8 @@ sealed abstract class Bech32Address extends BitcoinAddress {
object Bech32Address {
private case class Bech32AddressImpl(hrp: HumanReadablePart, data: Seq[UInt8]) extends Bech32Address
private val logger = BitcoinSLogger.logger
def isValid(bytes: Seq[Byte]): Boolean = ???
def apply(witSPK: WitnessScriptPubKey,
@ -201,31 +206,110 @@ object Bech32Address {
chk
}
def verifyChecksum(hrp: HumanReadablePart, data: Seq[UInt8]): Boolean = {
polyMod(hrpExpand(hrp) ++ data) == 1
def verifyChecksum(hrp: HumanReadablePart, data: Seq[Byte]): Boolean = {
val u8s = UInt8.toUInt8s(data)
polyMod(hrpExpand(hrp) ++ u8s) == 1
}
private val u32Five = UInt32(5)
private val u32Eight = UInt32(8)
/** The separator between the hrp and payload of bech 32 */
private val separator = '1'
/** Converts a byte array from base 8 to base 5 */
def encode(bytes: Seq[UInt8]): Try[Seq[UInt8]] = {
NumberUtil.convertBits(bytes,u32Eight,u32Five,true)
NumberUtil.convertUInt8s(bytes,u32Eight,u32Five,true)
}
/** Decodes a byte array from base 5 to base 8 */
def decode(b: Seq[UInt8]): Try[Seq[UInt8]] = {
NumberUtil.convertBits(b,u32Five,u32Eight,false)
NumberUtil.convertUInt8s(b,u32Five,u32Eight,false)
}
/** Takes a base32 byte array and encodes it to a string */
def encodeToString(b: Seq[UInt8]): String = {
b.map(b => charset(b.underlying)).mkString
}
/** Decodes a base32 string to a base 32 byte array */
def decodeFromString(string: String): Try[Seq[Byte]] = {
val invariant = Try(require(string.exists(charset.contains(_)), "String contained a non base32 character"))
???
/** Decodes bech32 string to a spk */
def fromString(str: String): Try[ScriptPubKey] = {
if (str.size > 90 || str.size < 8) {
Failure(new IllegalArgumentException("bech32 payloads must be betwee 8 and 90 chars, got: " + str.size))
} else if (str(2) != separator) {
Failure(new IllegalArgumentException("Bech32 address did not have the correct separator, got: " + str(2)))
} else {
val (hrp,data) = (str.take(2), str.splitAt(3)._2)
if (hrp.size < 1 || data.size < 6) {
Failure(new IllegalArgumentException("Hrp/data too short"))
} else {
val hrpValid = checkHrpValidity(hrp)
val dataValid = checkDataValidity(data)
val isChecksumValid: Try[Seq[Byte]] = hrpValid.flatMap { h =>
dataValid.flatMap { d =>
if (verifyChecksum(h,d)) {
//remove checksum bytes since it is valid
Success(d.take(d.size - 6))
}
else Failure(new IllegalArgumentException("Checksum was invalid on the bech32 address"))
}
}
isChecksumValid.flatMap { bytes: Seq[Byte] =>
val (v,prog) = (bytes.head,bytes.tail)
val convertedProg = NumberUtil.convertBytes(prog,u32Five,u32Eight,false)
val progBytes = convertedProg.map(UInt8.toBytes(_))
val witVersion = WitnessVersion(v)
progBytes.flatMap { prog =>
val pushOp = BitcoinScriptUtil.calculatePushOp(prog)
WitnessScriptPubKey(Seq(witVersion.version) ++ pushOp ++ Seq(ScriptConstant(prog))) match {
case Some(spk) => Success(spk)
case None => Failure(new IllegalArgumentException("Failed to decode bech32 into a witSPK"))
}
}
}
}
}
}
private def checkHrpValidity(hrp: String): Try[HumanReadablePart] = {
var (isLower,isUpper) = (false,false)
@tailrec
def loop(remaining: List[Char], accum: Seq[UInt8]): Try[Seq[UInt8]] = remaining match {
case h :: t =>
if (h < 33 || h > 126) {
Failure(new IllegalArgumentException("Invalid character range for hrp, got: " + hrp))
} else {
if (h.isLower) {
isLower = true
} else if (h.isUpper) {
isUpper = true
}
loop(t,UInt8(h.toByte) +: accum)
}
case Nil =>
Success(accum.reverse)
}
loop(hrp.toCharArray.toList,Nil).map { _ =>
HumanReadablePart(hrp)
}
}
def checkDataValidity(data: String): Try[Seq[Byte]] = {
@tailrec
def loop(remaining: List[Char], accum: Seq[Byte], hasUpper: Boolean, hasLower: Boolean): Try[Seq[Byte]] = remaining match {
case Nil => Success(accum.reverse)
case h :: t =>
if (!charset.contains(h)) {
Failure(new IllegalArgumentException("Invalid character in data of bech32 address, got: " + h))
} else {
if ((h.isUpper && hasLower) || (h.isLower && hasUpper)) {
Failure(new IllegalArgumentException("Cannot have mixed case for bech32 address"))
} else {
val byte = charset.indexOf(h).toByte
require(byte >= 0 && byte < 32)
loop(t, byte +: accum, h.isUpper || hasUpper, h.isLower || hasLower)
}
}
}
val payload: Try[Seq[Byte]] = loop(data.toCharArray.toList,Nil,false,false)
payload
}
/** https://github.com/bitcoin/bips/blob/master/bip-0173.mediawiki#bech32 */
def charset: Seq[Char] = Seq('q', 'p', 'z', 'r', 'y', '9', 'x', '8',

View File

@ -68,6 +68,8 @@ object CompactSizeUInt extends Factory[CompactSizeUInt] {
else CompactSizeUInt(UInt64(bytes.size),9)
}
def calc(bytes: Seq[Byte]): CompactSizeUInt = calculateCompactSizeUInt(bytes)
/** Responsible for calculating what the [[CompactSizeUInt]] is for this hex string. */
def calculateCompactSizeUInt(hex : String) : CompactSizeUInt = calculateCompactSizeUInt(BitcoinSUtil.decodeHex(hex))
@ -90,6 +92,8 @@ object CompactSizeUInt extends Factory[CompactSizeUInt] {
else CompactSizeUInt(UInt64(bytes.slice(1,9).reverse),9)
}
def parse(bytes: Seq[Byte]): CompactSizeUInt = parseCompactSizeUInt(bytes)
/** Returns the size of a VarInt in the number of bytes
* https://en.bitcoin.it/wiki/Protocol_documentation#Variable_length_integer. */
def parseCompactSizeUIntSize(byte : Byte) : Long = {

View File

@ -491,7 +491,7 @@ sealed trait WitnessScriptPubKey extends ScriptPubKey {
object WitnessScriptPubKey {
/** Witness scripts must begin with one of these operations, see BIP141 */
private val validFirstOps = Seq(OP_0,OP_1,OP_2,OP_3,OP_4,OP_5,OP_6, OP_7, OP_8,
private val validFirstOps: Seq[ScriptNumberOperation] = Seq(OP_0,OP_1,OP_2,OP_3,OP_4,OP_5,OP_6, OP_7, OP_8,
OP_9, OP_10, OP_11, OP_12, OP_13, OP_14, OP_15, OP_16)
def apply(asm: Seq[ScriptToken]): Option[WitnessScriptPubKey] = fromAsm(asm)

View File

@ -16,6 +16,8 @@ sealed trait WitnessVersion extends BitcoinSLogger {
* Either returns the stack and the [[ScriptPubKey]] it needs to be executed against or
* the [[ScriptError]] that was encountered while rebuilding the witness*/
def rebuild(scriptWitness: ScriptWitness, witnessProgram: Seq[ScriptToken]): Either[(Seq[ScriptToken], ScriptPubKey),ScriptError]
def version: ScriptNumberOperation
}
case object WitnessVersion0 extends WitnessVersion {
@ -60,16 +62,22 @@ case object WitnessVersion0 extends WitnessVersion {
Right(ScriptErrorWitnessProgramWrongLength)
}
}
override def version = OP_0
}
/** The witness version that represents all witnesses that have not been allocated yet */
case object UnassignedWitness extends WitnessVersion {
override def rebuild(scriptWitness: ScriptWitness, witnessProgram: Seq[ScriptToken]): Either[(Seq[ScriptToken], ScriptPubKey),ScriptError] =
Right(ScriptErrorDiscourageUpgradeableWitnessProgram)
override def version = OP_16
}
object WitnessVersion {
private val versions = Seq(WitnessVersion0, UnassignedWitness)
def apply(scriptNumberOp: ScriptNumberOperation): WitnessVersion = scriptNumberOp match {
case OP_0 | OP_FALSE => WitnessVersion0
case OP_1 | OP_TRUE | OP_2 | OP_3 | OP_4 | OP_5 | OP_6 | OP_7 | OP_8
@ -82,4 +90,10 @@ object WitnessVersion {
case _ : ScriptConstant | _ : ScriptNumber | _ : ScriptOperation =>
throw new IllegalArgumentException("We can only have witness version that is a script number operation, i.e OP_0 through OP_16")
}
def apply(int: Int): WitnessVersion = int match {
case 0 => WitnessVersion0
case _ => UnassignedWitness
}
}

View File

@ -79,8 +79,8 @@ trait NumberUtil extends BitcoinSLogger {
/** Converts a hex string to a [[Long]]. */
def toLong(hex : String): Long = toLong(BitcoinSUtil.decodeHex(hex))
/** Converts a sequence bytes 'from' base to 'to' base */
def convertBits(data: Seq[UInt8], from: UInt32, to: UInt32, pad: Boolean): Try[Seq[UInt8]] = {
/** Converts a sequence uint8 'from' base to 'to' base */
def convertUInt8s(data: Seq[UInt8], from: UInt32, to: UInt32, pad: Boolean): Try[Seq[UInt8]] = {
var acc: UInt32 = UInt32.zero
var bits: UInt32 = UInt32.zero
var ret: Seq[UInt8] = Nil
@ -114,6 +114,10 @@ trait NumberUtil extends BitcoinSLogger {
Success(ret)
}
}
def convertBytes(data: Seq[Byte], from: UInt32, to: UInt32, pad: Boolean): Try[Seq[UInt8]] = {
convertUInt8s(UInt8.toUInt8s(data),from,to,pad)
}
}
object NumberUtil extends NumberUtil

View File

@ -0,0 +1,13 @@
package org.bitcoins.core.number
import org.bitcoins.core.gen.NumberGenerator
import org.scalacheck.{Prop, Properties}
class UInt8Spec extends Properties("UInt8Spec") {
property("convert uint8 -> byte -> uint8") = {
Prop.forAll(NumberGenerator.uInt8) { case u8: UInt8 =>
UInt8(UInt8.toByte(u8)) == u8
}
}
}

View File

@ -44,16 +44,29 @@ class Bech32Test extends FlatSpec with MustMatchers {
val addr = Bech32Address(p2wpkh,TestNet3)
addr.map(_.value) must be (Success("tb1qw508d6qejxtdg4y5r3zarvary0c5xw7kxpjzsx"))
//decode
val decoded = addr.flatMap(a => Bech32Address.fromString(a.value))
decoded must be (Success(p2wpkh))
val p2wpkhMain = Bech32Address(p2wpkh,MainNet)
p2wpkhMain.map(_.value) must be (Success("bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4"))
val mp2wpkhDecoded = p2wpkhMain.flatMap(a => Bech32Address.fromString(a.value))
mp2wpkhDecoded must be (Success(p2wpkh))
val p2pk = P2PKScriptPubKey(key)
val p2wsh = WitnessScriptPubKeyV0(p2pk)
val addr1 = Bech32Address(p2wsh,TestNet3)
addr1.map(_.value) must be (Success("tb1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3q0sl5k7"))
//decode
val decoded1 = addr1.flatMap(a => Bech32Address.fromString(a.value))
decoded1 must be (Success(p2wsh))
val p2wshMain = Bech32Address(p2wsh,MainNet)
p2wshMain.map(_.value) must be (Success("bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3"))
val mp2wshDecoded = p2wshMain.flatMap(a => Bech32Address.fromString(a.value))
mp2wshDecoded must be (Success(p2wsh))
}
it must "encode 0 byte correctly" in {