Represent and handle SIGHASH_DEFAULT correctly in TaprootKeyPath (#4488)

* Represent and handle SIGHASH_DEFAULT correctly in TaprootKeyPath

* Prevent construction of invalid TaprootKeyPath, fix tests

* Have SIGHASH_DEFAULT be SIGHASH_ALL in preTaproot cases
This commit is contained in:
benthecarman 2022-07-11 07:22:08 -05:00 committed by GitHub
parent ef50becf1b
commit 59732809d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 62 additions and 30 deletions

View File

@ -89,7 +89,7 @@ class ScriptSignatureTest extends BitcoinSJvmTest {
val hex =
"8c493046022100d23459d03ed7e9511a47d13292d3430a04627de6235b6e51a40f9cd386f2abe3022100e7d25b080f0bb8d8d5f878bba7d54ad2fda650ea8d158a33ee3cbd11768191fd004104b0e2c879e4daf7b9ab68350228c159766676a14f5815084ba166432aab46198d4cca98fa3e9981d0a90b2effc514b76279476550ba3663fdcaff94c38420e9d5"
val scriptSig: ScriptSignature = RawScriptSignatureParser.read(hex)
HashType(scriptSig.signatures.head.bytes.last) must be(SIGHASH_ALL(0))
HashType(scriptSig.signatures.head.bytes.last) must be(SIGHASH_DEFAULT)
}
it must "have an empty script signature" in {

View File

@ -111,7 +111,7 @@ class TaprootWitnessTest extends BitcoinSUnitTest {
}
it must "have serialization symmetry" in {
forAll(WitnessGenerators.taprootWitness) { case wit =>
forAll(WitnessGenerators.taprootWitness) { wit =>
val fromBytes = TaprootWitness.fromBytes(wit.bytes)
assert(fromBytes == wit)
assert(TaprootWitness.fromStack(wit.stack.toVector) == wit)

View File

@ -74,11 +74,15 @@ trait TransactionSignatureChecker {
pubKey: SchnorrPublicKey,
witness: TaprootKeyPath,
taprootOptions: TaprootSerializationOptions): ScriptResult = {
checkSchnorrSignature(txSigComponent = txSigComponent,
pubKey = pubKey,
schnorrSignature = witness.signature,
hashType = witness.hashType,
taprootOptions)
if (witness.hashTypeOpt.contains(HashType.sigHashDefault)) {
ScriptErrorSchnorrSigHashType
} else {
checkSchnorrSignature(txSigComponent = txSigComponent,
pubKey = pubKey,
schnorrSignature = witness.signature,
hashType = witness.hashType,
taprootOptions)
}
}
def checkSchnorrSignature(

View File

@ -136,9 +136,6 @@ sealed abstract class TransactionSignatureSerializer {
val sigHashBytes = Int32(hashType.num).bytes.reverse
hashType match {
case SIGHASH_DEFAULT =>
sys.error(
s"SIGHASH_DEFAULT is only available in taproot signature serialization, got=${sigVersion}")
case _: SIGHASH_NONE =>
val sigHashNoneTx: Transaction =
sigHashNone(txWithInputSigsRemoved, inputIndex)
@ -162,7 +159,7 @@ sealed abstract class TransactionSignatureSerializer {
sigHashSingleTx.bytes ++ sigHashBytes
}
case _: SIGHASH_ALL =>
case _: SIGHASH_ALL | SIGHASH_DEFAULT =>
val sigHashAllTx: Transaction = sigHashAll(txWithInputSigsRemoved)
sigHashAllTx.bytes ++ sigHashBytes

View File

@ -249,12 +249,14 @@ object TaprootWitness extends Factory[TaprootWitness] {
}
/** Spending a taproot output via the key path spend */
case class TaprootKeyPath(
case class TaprootKeyPath private (
signature: SchnorrDigitalSignature,
hashType: HashType,
hashTypeOpt: Option[HashType],
annexOpt: Option[ByteVector])
extends TaprootWitness {
val hashType: HashType = hashTypeOpt.getOrElse(HashType.sigHashDefault)
override val stack: Vector[ByteVector] = {
val sig = if (hashType == HashType.sigHashDefault) {
Vector(signature.bytes)
@ -280,6 +282,38 @@ object TaprootKeyPath extends Factory[TaprootKeyPath] {
}
}
def apply(
signature: SchnorrDigitalSignature,
hashType: HashType,
annexOpt: Option[ByteVector]): TaprootKeyPath = {
if (hashType == HashType.sigHashDefault) {
new TaprootKeyPath(signature, None, annexOpt)
} else {
new TaprootKeyPath(signature, Some(hashType), annexOpt)
}
}
def apply(
signature: SchnorrDigitalSignature,
hashTypeOpt: Option[HashType],
annexOpt: Option[ByteVector]): TaprootKeyPath = {
if (hashTypeOpt.contains(HashType.sigHashDefault)) {
new TaprootKeyPath(signature, None, annexOpt)
} else {
new TaprootKeyPath(signature, hashTypeOpt, annexOpt)
}
}
def apply(
signature: SchnorrDigitalSignature,
annexOpt: Option[ByteVector]): TaprootKeyPath = {
TaprootKeyPath(signature, None, annexOpt)
}
def apply(signature: SchnorrDigitalSignature): TaprootKeyPath = {
TaprootKeyPath(signature, None, None)
}
def fromStack(vec: Vector[ByteVector]): TaprootKeyPath = {
val hasAnnex = TaprootScriptPath.hasAnnex(vec)
require(
@ -305,11 +339,11 @@ object TaprootKeyPath extends Factory[TaprootKeyPath] {
//means SIGHASH_DEFAULT is implicitly encoded
//see: https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#Common_signature_message
val sig = SchnorrDigitalSignature.fromBytes(sigBytes)
TaprootKeyPath(sig, HashType.sigHashDefault, annexOpt)
TaprootKeyPath(sig, None, annexOpt)
} else if (sigBytes.length == 65) {
val sig = SchnorrDigitalSignature.fromBytes(sigBytes.dropRight(1))
val hashType = HashType.fromByte(sigBytes.last)
TaprootKeyPath(sig, hashType, annexOpt)
TaprootKeyPath(sig, Some(hashType), annexOpt)
} else {
sys.error(
s"Unknown sig bytes length, should be 64 or 65, got=${sigBytes.length}")

View File

@ -20,7 +20,7 @@ class HashTypeTest extends BitcoinSCryptoTest {
}
it must "find a hash type by its byte value" in {
HashType(0.toByte) must be(SIGHASH_ALL(0))
HashType(0.toByte) must be(SIGHASH_DEFAULT)
HashType(1.toByte) must be(SIGHASH_ALL(1))
HashType(2.toByte) must be(HashType.sigHashNone)
HashType(3.toByte) must be(HashType.sigHashSingle)
@ -41,10 +41,11 @@ class HashTypeTest extends BitcoinSCryptoTest {
}
it must "determine if a given number is of hashType SIGHASH_ALL" in {
HashType.isSigHashAll(0) must be(true)
HashType.isSigHashAll(1) must be(true)
HashType.isSigHashAll(5) must be(true)
HashType.isSigHashAll(90) must be(true)
HashType.isSigHashAll(0) must be(false)
HashType.isSigHashAll(HashType.sigHashNone.num) must be(false)
HashType.isSigHashAll(HashType.sigHashSingle.num) must be(false)
}

View File

@ -24,7 +24,8 @@ object HashType extends Factory[HashType] {
def fromByte(byte: Byte): HashType = fromBytes(ByteVector.fromByte(byte))
def fromNumber(num: Int): HashType = {
if (isSigHashNone(num)) {
if (isSigHashDefault(num)) SIGHASH_DEFAULT
else if (isSigHashNone(num)) {
if (isSigHashNoneAnyoneCanPay(num)) {
SIGHASH_NONE_ANYONECANPAY(num)
} else {
@ -53,6 +54,8 @@ object HashType extends Factory[HashType] {
case h: HashType => h.byte
}
def isSigHashDefault(num: Int): Boolean = num == 0x00
def isSigHashAllOne(num: Int): Boolean = (num & 0x1f) == 1
def isSigHashNone(num: Int): Boolean = (num & 0x1f) == 2
@ -81,7 +84,8 @@ object HashType extends Factory[HashType] {
isSigHashAnyoneCanPay(num) ||
isSigHashAllAnyoneCanPay(num) ||
isSigHashSingleAnyoneCanPay(num) ||
isSigHashNoneAnyoneCanPay(num))
isSigHashNoneAnyoneCanPay(num) ||
isSigHashDefault(num))
) true
else false
}
@ -109,7 +113,8 @@ object HashType extends Factory[HashType] {
sigHashAnyoneCanPay,
sigHashNoneAnyoneCanPay,
sigHashAllAnyoneCanPay,
sigHashSingleAnyoneCanPay)
sigHashSingleAnyoneCanPay,
sigHashDefault)
lazy val hashTypeBytes: Vector[Byte] = Vector(
sigHashDefaultByte,

View File

@ -278,16 +278,7 @@ sealed abstract class CryptoGenerators {
} yield hash
/** Generates a random [[HashType HashType]] */
def hashType: Gen[HashType] =
Gen.oneOf(
HashType.sigHashAll,
HashType.sigHashNone,
HashType.sigHashSingle,
HashType.sigHashAnyoneCanPay,
HashType.sigHashSingleAnyoneCanPay,
HashType.sigHashNoneAnyoneCanPay,
HashType.sigHashAllAnyoneCanPay
)
def hashType: Gen[HashType] = Gen.oneOf(HashType.hashTypes)
def extVersion: Gen[ExtKeyVersion] = {
Gen.oneOf(ExtKeyVersion.all)