diff --git a/btcec/btcec.go b/btcec/btcec.go index efde8d6a..f85baba8 100644 --- a/btcec/btcec.go +++ b/btcec/btcec.go @@ -39,3 +39,18 @@ type CurveParams = secp.CurveParams func Params() *CurveParams { return secp.Params() } + +// Generator returns the public key at the Generator Point. +func Generator() *PublicKey { + var ( + result JacobianPoint + k secp.ModNScalar + ) + + k.SetInt(1) + ScalarBaseMultNonConst(&k, &result) + + result.ToAffine() + + return NewPublicKey(&result.X, &result.Y) +} diff --git a/btcec/schnorr/musig2/keys.go b/btcec/schnorr/musig2/keys.go index e61a22f2..8c86c624 100644 --- a/btcec/schnorr/musig2/keys.go +++ b/btcec/schnorr/musig2/keys.go @@ -26,6 +26,10 @@ var ( // ErrTweakedKeyIsInfinity is returned if while tweaking a key, we end // up with the point at infinity. ErrTweakedKeyIsInfinity = fmt.Errorf("tweaked key is infinity point") + + // ErrTweakedKeyOverflows is returned if a tweaking key is larger than + // 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141. + ErrTweakedKeyOverflows = fmt.Errorf("tweaked key is to large") ) // sortableKeys defines a type of slice of public keys that implements the sort @@ -286,7 +290,10 @@ func tweakKey(keyJ btcec.JacobianPoint, parityAcc btcec.ModNScalar, tweak [32]by // Next, map the tweak into a mod n integer so we can use it for // manipulations below. tweakInt := new(btcec.ModNScalar) - tweakInt.SetBytes(&tweak) + overflows := tweakInt.SetBytes(&tweak) + if overflows == 1 { + return keyJ, parityAcc, tweakAcc, ErrTweakedKeyOverflows + } // Next, we'll compute: Q_i = g*Q + t*G, where g is our parityFactor and t // is the tweakInt above. We'll space things out a bit to make it easier to diff --git a/btcec/schnorr/musig2/musig2_test.go b/btcec/schnorr/musig2/musig2_test.go index 309718ca..a032618b 100644 --- a/btcec/schnorr/musig2/musig2_test.go +++ b/btcec/schnorr/musig2/musig2_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/decred/dcrd/dcrec/secp256k1/v4" ) var ( @@ -26,14 +27,50 @@ var ( key3Bytes, _ = hex.DecodeString("3590A94E768F8E1815C2F24B4D80A8E3149" + "316C3518CE7B7AD338368D038CA66") - testKeys = [][]byte{key1Bytes, key2Bytes, key3Bytes} + invalidPk1, _ = hex.DecodeString("00000000000000000000000000000000" + + "00000000000000000000000000000005") + invalidPk2, _ = hex.DecodeString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" + + "FFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30") + invalidTweak, _ = hex.DecodeString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE" + + "BAAEDCE6AF48A03BBFD25E8CD0364141") - keyCombo1, _ = hex.DecodeString("E5830140512195D74C8307E39637CBE5FB730EBEAB80EC514CF88A877CEEEE0B") - keyCombo2, _ = hex.DecodeString("D70CD69A2647F7390973DF48CBFA2CCC407B8B2D60B08C5F1641185C7998A290") - keyCombo3, _ = hex.DecodeString("81A8B093912C9E481408D09776CEFB48AEB8B65481B6BAAFB3C5810106717BEB") - keyCombo4, _ = hex.DecodeString("2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B") + testKeys = [][]byte{key1Bytes, key2Bytes, key3Bytes, invalidPk1, + invalidPk2} + + keyCombo1, _ = hex.DecodeString("E5830140512195D74C8307E39637CBE5FB73" + + "0EBEAB80EC514CF88A877CEEEE0B") + keyCombo2, _ = hex.DecodeString("D70CD69A2647F7390973DF48CBFA2CCC407B" + + "8B2D60B08C5F1641185C7998A290") + keyCombo3, _ = hex.DecodeString("81A8B093912C9E481408D09776CEFB48AEB8" + + "B65481B6BAAFB3C5810106717BEB") + keyCombo4, _ = hex.DecodeString("2EB18851887E7BDC5E830E89B19DDBC28078" + + "F1FA88AAD0AD01CA06FE4F80210B") ) +// getInfinityTweak returns a tweak that, when tweaking the Generator, triggers +// the ErrTweakedKeyIsInfinity error. +func getInfinityTweak() KeyTweakDesc { + generator := btcec.Generator() + + keySet := []*btcec.PublicKey{generator} + + keysHash := keyHashFingerprint(keySet, true) + uniqueKeyIndex := secondUniqueKeyIndex(keySet, true) + + n := &btcec.ModNScalar{} + + n.SetByteSlice(invalidTweak) + + coeff := aggregationCoefficient( + keySet, generator, keysHash, uniqueKeyIndex, + ).Negate().Add(n) + + return KeyTweakDesc{ + Tweak: coeff.Bytes(), + IsXOnly: false, + } +} + const ( keyAggTestVectorName = "key_agg_vectors.json" @@ -44,8 +81,10 @@ var dumpJson = flag.Bool("dumpjson", false, "if true, a JSON version of the "+ "test vectors will be written to the cwd") type jsonKeyAggTestCase struct { - Keys []string `json:"keys"` - ExpectedKey string `json:"expected_key"` + Keys []string `json:"keys"` + Tweaks []jsonTweak `json:"tweaks"` + ExpectedKey string `json:"expected_key"` + ExpectedError string `json:"expected_error"` } // TestMuSig2KeyAggTestVectors tests that this implementation of musig2 key @@ -56,8 +95,11 @@ func TestMuSig2KeyAggTestVectors(t *testing.T) { var jsonCases []jsonKeyAggTestCase testCases := []struct { - keyOrder []int - expectedKey []byte + keyOrder []int + explicitKeys []*btcec.PublicKey + tweaks []KeyTweakDesc + expectedKey []byte + expectedError error }{ // Keys in backwards lexicographical order. { @@ -82,18 +124,58 @@ func TestMuSig2KeyAggTestVectors(t *testing.T) { keyOrder: []int{0, 0, 1, 1}, expectedKey: keyCombo4, }, + + // Invalid public key. + { + keyOrder: []int{0, 3}, + expectedError: secp256k1.ErrPubKeyNotOnCurve, + }, + + // Public key exceeds field size. + { + keyOrder: []int{0, 4}, + expectedError: secp256k1.ErrPubKeyXTooBig, + }, + + // Tweak is out of range. + { + keyOrder: []int{0, 1}, + tweaks: []KeyTweakDesc{ + KeyTweakDesc{ + Tweak: to32ByteSlice(invalidTweak), + IsXOnly: true, + }, + }, + expectedError: ErrTweakedKeyOverflows, + }, + + // Intermediate tweaking result is point at infinity. + { + explicitKeys: []*secp256k1.PublicKey{btcec.Generator()}, + tweaks: []KeyTweakDesc{ + getInfinityTweak(), + }, + expectedError: ErrTweakedKeyIsInfinity, + }, } for i, testCase := range testCases { testName := fmt.Sprintf("%v", testCase.keyOrder) t.Run(testName, func(t *testing.T) { var ( - keys []*btcec.PublicKey - strKeys []string + keys []*btcec.PublicKey + strKeys []string + strTweaks []jsonTweak + jsonError string ) for _, keyIndex := range testCase.keyOrder { keyBytes := testKeys[keyIndex] pub, err := schnorr.ParsePubKey(keyBytes) - if err != nil { + + switch { + case testCase.expectedError != nil && + errors.Is(err, testCase.expectedError): + return + case err != nil: t.Fatalf("unable to parse pubkeys: %v", err) } @@ -101,15 +183,59 @@ func TestMuSig2KeyAggTestVectors(t *testing.T) { strKeys = append(strKeys, hex.EncodeToString(keyBytes)) } - jsonCases = append(jsonCases, jsonKeyAggTestCase{ - Keys: strKeys, - ExpectedKey: hex.EncodeToString(testCase.expectedKey), - }) + for _, explicitKey := range testCase.explicitKeys { + keys = append(keys, explicitKey) + strKeys = append( + strKeys, + hex.EncodeToString( + explicitKey.SerializeCompressed(), + )) + } + + for _, tweak := range testCase.tweaks { + strTweaks = append( + strTweaks, + jsonTweak{ + Tweak: hex.EncodeToString( + tweak.Tweak[:], + ), + XOnly: tweak.IsXOnly, + }) + } + + if testCase.expectedError != nil { + jsonError = testCase.expectedError.Error() + } + + jsonCases = append( + jsonCases, + jsonKeyAggTestCase{ + Keys: strKeys, + Tweaks: strTweaks, + ExpectedKey: hex.EncodeToString( + testCase.expectedKey), + ExpectedError: jsonError, + }) uniqueKeyIndex := secondUniqueKeyIndex(keys, false) - combinedKey, _, _, _ := AggregateKeys( - keys, false, WithUniqueKeyIndex(uniqueKeyIndex), + opts := []KeyAggOption{WithUniqueKeyIndex(uniqueKeyIndex)} + if len(testCase.tweaks) > 0 { + opts = append(opts, WithKeyTweaks(testCase.tweaks...)) + } + + combinedKey, _, _, err := AggregateKeys( + keys, false, opts..., ) + + switch { + case testCase.expectedError != nil && + errors.Is(err, testCase.expectedError): + return + + case err != nil: + t.Fatalf("case #%v, got error %v", i, err) + } + combinedKeyBytes := schnorr.SerializePubKey(combinedKey.FinalKey) if !bytes.Equal(combinedKeyBytes, testCase.expectedKey) { t.Fatalf("case: #%v, invalid aggregation: "+ @@ -910,3 +1036,12 @@ func memsetLoop(a []byte, v uint8) { a[i] = byte(v) } } + +func to32ByteSlice(input []byte) [32]byte { + if len(input) != 32 { + panic("input byte slice has invalid length") + } + var output [32]byte + copy(output[:], input) + return output +}