// lnd/lnwire/commit_sig_test.go
// 2024-09-18 10:25:42 +02:00
//
// 169 lines
// 4.2 KiB
// Go

package lnwire
import (
"bytes"
"fmt"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// commitSigTestCase is a test case for the CommitSig message.
type commitSigTestCase struct {
	// Msg is the message to be encoded and decoded.
	Msg CommitSig

	// ExpectEncodeError is a flag that indicates whether we expect the
	// encoding of the message to fail.
	ExpectEncodeError bool
}
// generateCommitSigTestCases generates a set of CommitSig message test cases.
//
// The returned cases are:
//   - a message with every optional field populated (partial signature,
//     custom records and extra data holding unknown TLV records),
//   - a message without a partial signature,
//   - a message without custom records,
//   - a message whose custom records use a TLV type below
//     MinCustomRecordsTlvType, for which encoding is expected to fail.
//
// Field contents come from generateRandomBytes, so they differ between calls.
// Any failure while building the fixtures aborts the test via require.
func generateCommitSigTestCases(t *testing.T) []commitSigTestCase {
	// Firstly, we'll set basic values for the message fields.
	//
	// Generate random channel ID.
	chanIDBytes, err := generateRandomBytes(32)
	require.NoError(t, err)

	var chanID ChannelID
	copy(chanID[:], chanIDBytes)

	// Generate random commit sig.
	commitSigBytes, err := generateRandomBytes(64)
	require.NoError(t, err)

	sig, err := NewSigFromSchnorrRawSignature(commitSigBytes)
	require.NoError(t, err)

	// Derive the partial signature scalar from the raw signature bytes.
	// SetByteSlice reduces the value modulo the group order, so its
	// overflow flag is deliberately ignored. Any random input yields a
	// usable scalar for a round-trip test.
	sigScalar := new(btcec.ModNScalar)
	sigScalar.SetByteSlice(sig.RawBytes())

	// Reuse the random signature bytes as nonce material. copy fills at
	// most len(commitSigBytes) bytes, and any remaining bytes of the nonce
	// stay zero. That is fine because the test only checks that the value
	// survives an encode/decode round trip.
	var nonce [musig2.PubNonceSize]byte
	copy(nonce[:], commitSigBytes)

	sigWithNonce := NewPartialSigWithNonce(nonce, *sigScalar)
	partialSig := MaybePartialSigWithNonce(sigWithNonce)

	// Define custom records.
	recordKey1 := uint64(MinCustomRecordsTlvType + 1)
	recordValue1, err := generateRandomBytes(10)
	require.NoError(t, err)

	recordKey2 := uint64(MinCustomRecordsTlvType + 2)
	recordValue2, err := generateRandomBytes(10)
	require.NoError(t, err)

	customRecords := CustomRecords{
		recordKey1: recordValue1,
		recordKey2: recordValue2,
	}

	// Construct an instance of extra data that contains records with TLV
	// types below the minimum custom records threshold and that lack
	// corresponding fields in the message struct. Content should persist in
	// the extra data field after encoding and decoding.
	var (
		recordBytes45 = []byte("recordBytes45")
		tlvRecord45   = tlv.NewPrimitiveRecord[tlv.TlvType45](
			recordBytes45,
		)

		recordBytes55 = []byte("recordBytes55")
		tlvRecord55   = tlv.NewPrimitiveRecord[tlv.TlvType55](
			recordBytes55,
		)
	)

	var extraData ExtraOpaqueData
	err = extraData.PackRecords(
		[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
	)
	require.NoError(t, err)

	// Custom records keyed just below MinCustomRecordsTlvType lie outside
	// the custom range. The test case that uses them expects encoding to
	// fail.
	invalidCustomRecords := CustomRecords{
		MinCustomRecordsTlvType - 1: recordValue1,
	}

	return []commitSigTestCase{
		{
			Msg: CommitSig{
				ChanID:        chanID,
				CommitSig:     sig,
				PartialSig:    partialSig,
				CustomRecords: customRecords,
				ExtraData:     extraData,
			},
		},
		// Add a test case where the partial signature field is not
		// populated.
		{
			Msg: CommitSig{
				ChanID:        chanID,
				CommitSig:     sig,
				CustomRecords: customRecords,
			},
		},
		// Add a test case where the custom records field is not
		// populated.
		{
			Msg: CommitSig{
				ChanID:     chanID,
				CommitSig:  sig,
				PartialSig: partialSig,
			},
		},
		// Add a case where the custom records are invalid because
		// their TLV type is below MinCustomRecordsTlvType, so encoding
		// must fail.
		{
			Msg: CommitSig{
				ChanID:        chanID,
				CommitSig:     sig,
				PartialSig:    partialSig,
				CustomRecords: invalidCustomRecords,
			},
			ExpectEncodeError: true,
		},
	}
}
// TestCommitSigEncodeDecode verifies that every CommitSig produced by
// generateCommitSigTestCases survives an encode/decode round trip unchanged.
// Cases flagged with ExpectEncodeError must instead be rejected by Encode.
func TestCommitSigEncodeDecode(t *testing.T) {
	t.Parallel()

	cases := generateCommitSigTestCases(t)

	for idx, testCase := range cases {
		name := fmt.Sprintf("testcase-%d", idx)

		// Subtests run sequentially, so capturing testCase here is
		// safe.
		t.Run(name, func(t *testing.T) {
			var encoded bytes.Buffer
			encodeErr := testCase.Msg.Encode(&encoded, 0)

			// Invalid messages only need to be rejected; there is
			// nothing to decode afterwards.
			if testCase.ExpectEncodeError {
				require.Error(t, encodeErr)
				return
			}
			require.NoError(t, encodeErr)

			var decoded CommitSig
			require.NoError(
				t, decoded.Decode(bytes.NewReader(encoded.Bytes()), 0),
			)

			// The wire format does not carry the signature type,
			// so restore it before comparing.
			decoded.CommitSig.ForceSchnorr()

			require.Equal(t, testCase.Msg, decoded)
		})
	}
}