btcd/btcec/schnorr/musig2/keys_test.go
2022-10-20 17:57:57 -07:00

394 lines
8.9 KiB
Go

// Copyright 2013-2022 The btcsuite developers
package musig2
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path"
"strings"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
)
const (
	// keySortTestVectorFileName is the name of the JSON file, relative to
	// testVectorBaseDir, that holds the key sorting vectors.
	keySortTestVectorFileName = "key_sort_vectors.json"

	// keyAggTestVectorFileName is the name of the JSON file, relative to
	// testVectorBaseDir, that holds the key aggregation vectors.
	keyAggTestVectorFileName = "key_agg_vectors.json"

	// keyTweakTestVectorFileName is the name of the JSON file, relative to
	// testVectorBaseDir, that holds the key tweaking vectors.
	keyTweakTestVectorFileName = "tweak_vectors.json"
)
// keySortTestVector mirrors the key sorting vector file: a list of input
// public keys and the same keys in the order sortKeys is expected to produce.
type keySortTestVector struct {
	// PubKeys holds the hex-encoded input keys in their original order.
	PubKeys []string `json:"pubkeys"`

	// SortedKeys holds the hex-encoded keys in their expected sorted order.
	SortedKeys []string `json:"sorted_pubkeys"`
}
// TestMusig2KeySort tests that keys are properly sorted according to the
// musig2 test vectors.
func TestMusig2KeySort(t *testing.T) {
	t.Parallel()

	testVectorPath := path.Join(
		testVectorBaseDir, keySortTestVectorFileName,
	)
	testVectorBytes, err := os.ReadFile(testVectorPath)
	require.NoError(t, err)

	var testCase keySortTestVector
	require.NoError(t, json.Unmarshal(testVectorBytes, &testCase))

	// Parse the unsorted input keys and run them through sortKeys.
	keys := make([]*btcec.PublicKey, len(testCase.PubKeys))
	for i, keyStr := range testCase.PubKeys {
		pubKey, err := btcec.ParsePubKey(mustParseHex(keyStr))
		require.NoError(t, err)

		keys[i] = pubKey
	}

	sortedKeys := sortKeys(keys)

	// Size the expected slice from SortedKeys, the list it is filled
	// from. Sizing it from PubKeys would panic or leave nil entries if the
	// two lists in the vector file ever differed in length.
	expectedKeys := make([]*btcec.PublicKey, len(testCase.SortedKeys))
	for i, keyStr := range testCase.SortedKeys {
		pubKey, err := btcec.ParsePubKey(mustParseHex(keyStr))
		require.NoError(t, err)

		expectedKeys[i] = pubKey
	}

	// testify takes (expected, actual); keep that order so failure output
	// labels the two sides correctly.
	require.Equal(t, expectedKeys, sortedKeys)
}
type (
	// keyAggValidTest is one successful key aggregation case. Indices
	// selects, in order, the input keys from keyAggTestVectors.PubKeys,
	// and Expected is the hex encoding of the schnorr-serialized
	// aggregate key.
	keyAggValidTest struct {
		Indices  []int  `json:"key_indices"`
		Expected string `json:"expected"`
	}

	// keyAggError holds the "type", "signer" and "contrib" fields of an
	// error descriptor. Nothing in this file reads it.
	//
	// NOTE(review): Contring looks like a misspelling of Contrib; the
	// field name is left as-is so any existing references keep compiling.
	keyAggError struct {
		Type     string `json:"type"`
		Signer   int    `json:"signer"`
		Contring string `json:"contrib"`
	}

	// keyAggInvalidTest is one failing key aggregation case. Indices
	// selects the input keys, TweakIndices selects tweaks from
	// keyAggTestVectors.Tweaks, IsXOnly carries the x-only flag for each
	// selected tweak by position, and Comment names the expected failure.
	keyAggInvalidTest struct {
		Indices      []int  `json:"key_indices"`
		TweakIndices []int  `json:"tweak_indices"`
		IsXOnly      []bool `json:"is_xonly"`
		Comment      string `json:"comment"`
	}

	// keyAggTestVectors is the top-level layout of the key aggregation
	// vector file: shared hex-encoded keys and tweaks, plus the valid and
	// error cases that index into them.
	keyAggTestVectors struct {
		PubKeys      []string            `json:"pubkeys"`
		Tweaks       []string            `json:"tweaks"`
		ValidCases   []keyAggValidTest   `json:"valid_test_cases"`
		InvalidCases []keyAggInvalidTest `json:"error_test_cases"`
	}
)
// keysFromIndices parses the hex-encoded public keys that indices selects from
// pubKeys, keeping the order given by indices. On the first parse failure it
// returns a nil slice together with that error.
func keysFromIndices(t *testing.T, indices []int,
	pubKeys []string) ([]*btcec.PublicKey, error) {

	t.Helper()

	parsedKeys := make([]*btcec.PublicKey, 0, len(indices))
	for _, keyIdx := range indices {
		rawKey := mustParseHex(pubKeys[keyIdx])

		pubKey, err := btcec.ParsePubKey(rawKey)
		if err != nil {
			return nil, err
		}

		parsedKeys = append(parsedKeys, pubKey)
	}

	return parsedKeys, nil
}
// tweaksFromIndices builds one KeyTweakDesc per entry of indices. Entry i gets
// the hex tweak tweaks[indices[i]], copied into a 32-byte array, and the flag
// isXonly[i]. Note that the flag is looked up by position in indices, not by
// tweak index.
func tweaksFromIndices(t *testing.T, indices []int,
	tweaks []string, isXonly []bool) []KeyTweakDesc {

	t.Helper()

	descs := make([]KeyTweakDesc, 0, len(indices))
	for pos, tweakIdx := range indices {
		var tweakBytes [32]byte
		copy(tweakBytes[:], mustParseHex(tweaks[tweakIdx]))

		descs = append(descs, KeyTweakDesc{
			Tweak:   tweakBytes,
			IsXOnly: isXonly[pos],
		})
	}

	return descs
}
// TestMuSig2KeyAggTestVectors tests that this implementation of musig2 key
// aggregation lines up with the secp256k1-zkp test vectors.
func TestMuSig2KeyAggTestVectors(t *testing.T) {
	t.Parallel()

	testVectorPath := path.Join(
		testVectorBaseDir, keyAggTestVectorFileName,
	)
	testVectorBytes, err := os.ReadFile(testVectorPath)
	require.NoError(t, err)

	var testCases keyAggTestVectors
	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))

	for i, testCase := range testCases.ValidCases {
		testCase := testCase

		t.Run(fmt.Sprintf("test_case=%v", i), func(t *testing.T) {
			// Assemble the set of keys we'll pass in based on
			// their key index. We don't use sorting to ensure we
			// send the keys in the exact same order as the test
			// vectors do. Parsing inside the subtest attributes
			// any failure to this case rather than the parent.
			inputKeys, err := keysFromIndices(
				t, testCase.Indices, testCases.PubKeys,
			)
			require.NoError(t, err)

			uniqueKeyIndex := secondUniqueKeyIndex(inputKeys, false)
			opts := []KeyAggOption{WithUniqueKeyIndex(uniqueKeyIndex)}
			combinedKey, _, _, err := AggregateKeys(
				inputKeys, false, opts...,
			)
			require.NoError(t, err)

			// testify takes (expected, actual).
			require.Equal(
				t, mustParseHex(testCase.Expected),
				schnorr.SerializePubKey(combinedKey.FinalKey),
			)
		})
	}

	for _, testCase := range testCases.InvalidCases {
		testCase := testCase

		testName := fmt.Sprintf("invalid_%v",
			strings.ToLower(testCase.Comment))
		t.Run(testName, func(t *testing.T) {
			// For each test, we'll extract the set of input keys
			// as well as the tweaks since this set of cases also
			// exercises error cases related to the set of tweaks.
			inputKeys, err := keysFromIndices(
				t, testCase.Indices, testCases.PubKeys,
			)

			// Some vectors are expected to fail while parsing the
			// input keys, before aggregation is attempted. The
			// case's comment identifies which parse error is
			// expected.
			if err != nil {
				switch testCase.Comment {
				case "Invalid public key":
					require.ErrorIs(
						t, err,
						secp.ErrPubKeyNotOnCurve,
					)

				case "Public key exceeds field size":
					require.ErrorIs(
						t, err, secp.ErrPubKeyXTooBig,
					)

				case "First byte of public key is not 2 or 3":
					require.ErrorIs(
						t, err,
						secp.ErrPubKeyInvalidFormat,
					)

				default:
					t.Fatalf("uncaught err: %v", err)
				}

				return
			}

			// The remaining vectors fail during aggregation
			// because of their tweaks.
			var tweaks []KeyTweakDesc
			if len(testCase.TweakIndices) != 0 {
				tweaks = tweaksFromIndices(
					t, testCase.TweakIndices, testCases.Tweaks,
					testCase.IsXOnly,
				)
			}

			uniqueKeyIndex := secondUniqueKeyIndex(inputKeys, false)
			opts := []KeyAggOption{
				WithUniqueKeyIndex(uniqueKeyIndex),
			}
			if len(tweaks) != 0 {
				opts = append(opts, WithKeyTweaks(tweaks...))
			}

			_, _, _, err = AggregateKeys(
				inputKeys, false, opts...,
			)
			require.Error(t, err)

			switch testCase.Comment {
			case "Tweak is out of range":
				require.ErrorIs(t, err, ErrTweakedKeyOverflows)

			case "Intermediate tweaking result is point at infinity":
				require.ErrorIs(t, err, ErrTweakedKeyIsInfinity)

			default:
				t.Fatalf("uncaught err: %v", err)
			}
		})
	}
}
type (
	// keyTweakInvalidTest is one entry of the tweak vector file's
	// error_test_cases list.
	//
	// IsXOnly must use the same "is_xonly" key as keyTweakValidTest (and
	// keyAggInvalidTest). A mismatched tag makes the flags silently decode
	// as an empty slice.
	keyTweakInvalidTest struct {
		Indices      []int  `json:"key_indices"`
		NonceIndices []int  `json:"nonce_indices"`
		TweakIndices []int  `json:"tweak_indices"`
		IsXOnly      []bool `json:"is_xonly"`
		SignerIndex  int    `json:"signer_index"`
		Comment      string `json:"comment"`
	}

	// keyTweakValidTest is one successful tweak/sign case.
	// Indices, NonceIndices and TweakIndices select entries from the
	// shared lists in keyTweakVector. IsXOnly carries the x-only flag for
	// each selected tweak by position. Expected is the hex-encoded 32-byte
	// partial signature scalar.
	keyTweakValidTest struct {
		Indices      []int  `json:"key_indices"`
		NonceIndices []int  `json:"nonce_indices"`
		TweakIndices []int  `json:"tweak_indices"`
		IsXOnly      []bool `json:"is_xonly"`
		SignerIndex  int    `json:"signer_index"`
		Expected     string `json:"expected"`
		Comment      string `json:"comment"`
	}

	// keyTweakVector is the top-level layout of the tweak vector file.
	// Every string value is hex encoded.
	keyTweakVector struct {
		// PrivKey is the signer's secret key.
		PrivKey string `json:"sk"`

		// PubKeys are the keys selected by each case's Indices.
		PubKeys []string `json:"pubkeys"`

		// PrivNonce is the signer's secret nonce.
		PrivNonce string `json:"secnonce"`

		// PubNonces are the nonces selected by each case's
		// NonceIndices.
		PubNonces []string `json:"pnonces"`

		// AggNnoce holds the "aggnonce" value.
		//
		// NOTE(review): the field name is a misspelling of AggNonce;
		// it is kept so any existing references keep compiling.
		AggNnoce string `json:"aggnonce"`

		// Tweaks are the tweaks selected by each case's TweakIndices.
		Tweaks []string `json:"tweaks"`

		// Msg is the message being signed.
		Msg string `json:"msg"`

		ValidCases   []keyTweakValidTest   `json:"valid_test_cases"`
		InvalidCases []keyTweakInvalidTest `json:"error_test_cases"`
	}
)
// pubNoncesFromIndices decodes the hex-encoded public nonces that
// nonceIndices selects from pubNonces into PubNonceSize-byte arrays, keeping
// the order given by nonceIndices. Each value is copied with copy, so a
// decoded value shorter than PubNonceSize is zero-padded and a longer one is
// truncated.
func pubNoncesFromIndices(t *testing.T, nonceIndices []int,
	pubNonces []string) [][PubNonceSize]byte {

	// Mark this as a helper, like keysFromIndices and tweaksFromIndices,
	// so failures are reported at the caller's line.
	t.Helper()

	nonces := make([][PubNonceSize]byte, len(nonceIndices))
	for i, idx := range nonceIndices {
		copy(nonces[i][:], mustParseHex(pubNonces[idx]))
	}

	return nonces
}
// TestMuSig2TweakTestVectors tests that we properly handle the various edge
// cases related to tweaking public keys. Each valid case signs the shared
// message with the shared secret key and nonce, applying the selected tweaks,
// and checks the resulting partial signature against the vector.
func TestMuSig2TweakTestVectors(t *testing.T) {
	t.Parallel()

	testVectorPath := path.Join(
		testVectorBaseDir, keyTweakTestVectorFileName,
	)
	testVectorBytes, err := os.ReadFile(testVectorPath)
	require.NoError(t, err)

	var testCases keyTweakVector
	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))

	// The signing key, secret nonce and message are shared by every case.
	// The public key returned by PrivKeyFromBytes is not needed.
	privKey, _ := btcec.PrivKeyFromBytes(mustParseHex(testCases.PrivKey))

	var msg [32]byte
	copy(msg[:], mustParseHex(testCases.Msg))

	var secNonce [SecNonceSize]byte
	copy(secNonce[:], mustParseHex(testCases.PrivNonce))

	for _, testCase := range testCases.ValidCases {
		testCase := testCase

		testName := fmt.Sprintf("valid_%v",
			strings.ToLower(testCase.Comment))
		t.Run(testName, func(t *testing.T) {
			pubKeys, err := keysFromIndices(
				t, testCase.Indices, testCases.PubKeys,
			)
			require.NoError(t, err)

			var tweaks []KeyTweakDesc
			if len(testCase.TweakIndices) != 0 {
				tweaks = tweaksFromIndices(
					t, testCase.TweakIndices,
					testCases.Tweaks, testCase.IsXOnly,
				)
			}

			pubNonces := pubNoncesFromIndices(
				t, testCase.NonceIndices, testCases.PubNonces,
			)

			combinedNonce, err := AggregateNonces(pubNonces)
			require.NoError(t, err)

			var opts []SignOption
			if len(tweaks) != 0 {
				opts = append(opts, WithTweaks(tweaks...))
			}

			partialSig, err := Sign(
				secNonce, privKey, combinedNonce, pubKeys,
				msg, opts...,
			)

			// Check the error before touching partialSig. On
			// failure it is presumably nil, and dereferencing it
			// below would panic instead of reporting the error.
			require.NoError(t, err)

			var partialSigBytes [32]byte
			partialSig.S.PutBytesUnchecked(partialSigBytes[:])

			// Compare as hex strings for readable failure output.
			// Re-encoding Expected puts both sides in the same
			// lowercase form. testify takes (expected, actual).
			require.Equal(
				t, hex.EncodeToString(mustParseHex(testCase.Expected)),
				hex.EncodeToString(partialSigBytes[:]),
			)
		})
	}
}