lnwire: use require package for fuzz tests

Simplify code by using the require package instead of t.Fatal().
This commit is contained in:
Matt Morehouse 2023-05-19 11:59:45 -05:00
parent b95faaba45
commit 460ba4ad82
No known key found for this signature in database
GPG Key ID: CC8ECA224831C982

View File

@ -4,8 +4,9 @@ import (
"bytes"
"compress/zlib"
"encoding/binary"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// prefixWithMsgType takes []byte and adds a wire protocol prefix
@ -41,26 +42,15 @@ func harness(t *testing.T, data []byte) {
// We will serialize the message into a new bytes buffer.
var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil {
// Could not serialize message into bytes buffer, panic
t.Fatal(err)
}
_, err = WriteMessage(&b, msg, 0)
require.NoError(t, err)
// Deserialize the message from the serialized bytes buffer, and then
// assert that the original message is equal to the newly deserialized
// message.
newMsg, err := ReadMessage(&b, 0)
if err != nil {
// Could not deserialize message from bytes buffer, panic
t.Fatal(err)
}
if !reflect.DeepEqual(msg, newMsg) {
// Deserialized message and original message are not deeply
// equal.
t.Fatal("original message and deserialized message are not " +
"deeply equal")
}
require.NoError(t, err)
require.Equal(t, msg, newMsg)
}
func FuzzAcceptChannel(f *testing.F) {
@ -83,107 +73,32 @@ func FuzzAcceptChannel(f *testing.F) {
// We will serialize the message into a new bytes buffer.
var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil {
// Could not serialize message into bytes buffer, panic
t.Fatal(err)
}
_, err = WriteMessage(&b, msg, 0)
require.NoError(t, err)
// Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly
// deserialized message.
newMsg, err := ReadMessage(&b, 0)
if err != nil {
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
require.NoError(t, err)
// Now compare every field instead of using reflect.DeepEqual.
// For UpfrontShutdownScript, we only compare bytes. This
// probably takes up more branches than necessary, but that's
// fine for now.
var shouldPanic bool
first, ok := msg.(*AcceptChannel)
if !ok {
t.Fatal("message was not AcceptChannel")
}
second, ok := newMsg.(*AcceptChannel)
if !ok {
t.Fatal("new message was not AcceptChannel")
}
require.IsType(t, &AcceptChannel{}, msg)
first, _ := msg.(*AcceptChannel)
require.IsType(t, &AcceptChannel{}, newMsg)
second, _ := newMsg.(*AcceptChannel)
if !bytes.Equal(first.PendingChannelID[:],
second.PendingChannelID[:]) {
// We can't use require.Equal for UpfrontShutdownScript, since
// we consider the empty slice and nil to be equivalent.
require.True(
t, bytes.Equal(
first.UpfrontShutdownScript,
second.UpfrontShutdownScript,
),
)
first.UpfrontShutdownScript = nil
second.UpfrontShutdownScript = nil
shouldPanic = true
}
if first.DustLimit != second.DustLimit {
shouldPanic = true
}
if first.MaxValueInFlight != second.MaxValueInFlight {
shouldPanic = true
}
if first.ChannelReserve != second.ChannelReserve {
shouldPanic = true
}
if first.HtlcMinimum != second.HtlcMinimum {
shouldPanic = true
}
if first.MinAcceptDepth != second.MinAcceptDepth {
shouldPanic = true
}
if first.CsvDelay != second.CsvDelay {
shouldPanic = true
}
if first.MaxAcceptedHTLCs != second.MaxAcceptedHTLCs {
shouldPanic = true
}
if !first.FundingKey.IsEqual(second.FundingKey) {
shouldPanic = true
}
if !first.RevocationPoint.IsEqual(second.RevocationPoint) {
shouldPanic = true
}
if !first.PaymentPoint.IsEqual(second.PaymentPoint) {
shouldPanic = true
}
if !first.DelayedPaymentPoint.IsEqual(
second.DelayedPaymentPoint) {
shouldPanic = true
}
if !first.HtlcPoint.IsEqual(second.HtlcPoint) {
shouldPanic = true
}
if !first.FirstCommitmentPoint.IsEqual(
second.FirstCommitmentPoint) {
shouldPanic = true
}
if !bytes.Equal(first.UpfrontShutdownScript,
second.UpfrontShutdownScript) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deseralized message " +
"are not equal")
}
require.Equal(t, first, second)
})
}
@ -356,80 +271,34 @@ func FuzzNodeAnnouncement(f *testing.F) {
// We will serialize the message into a new bytes buffer.
var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil {
// Could not serialize message into bytes buffer, panic
t.Fatal(err)
}
_, err = WriteMessage(&b, msg, 0)
require.NoError(t, err)
// Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly
// deserialized message.
newMsg, err := ReadMessage(&b, 0)
if err != nil {
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
require.NoError(t, err)
// Now compare every field instead of using reflect.DeepEqual
// for the Addresses field.
var shouldPanic bool
first, ok := msg.(*NodeAnnouncement)
if !ok {
t.Fatal("message was not NodeAnnouncement")
}
second, ok := newMsg.(*NodeAnnouncement)
if !ok {
t.Fatal("new message was not NodeAnnouncement")
}
if !bytes.Equal(first.Signature[:], second.Signature[:]) {
shouldPanic = true
}
if !reflect.DeepEqual(first.Features, second.Features) {
shouldPanic = true
}
if first.Timestamp != second.Timestamp {
shouldPanic = true
}
if !bytes.Equal(first.NodeID[:], second.NodeID[:]) {
shouldPanic = true
}
if !reflect.DeepEqual(first.RGBColor, second.RGBColor) {
shouldPanic = true
}
if !bytes.Equal(first.Alias[:], second.Alias[:]) {
shouldPanic = true
}
if len(first.Addresses) != len(second.Addresses) {
shouldPanic = true
}
require.IsType(t, &NodeAnnouncement{}, msg)
first, _ := msg.(*NodeAnnouncement)
require.IsType(t, &NodeAnnouncement{}, newMsg)
second, _ := newMsg.(*NodeAnnouncement)
// We can't use require.Equal for Addresses, since the same IP
// can be represented by different underlying bytes. Instead, we
// compare the normalized string representation of each address.
require.Equal(t, len(first.Addresses), len(second.Addresses))
for i := range first.Addresses {
if first.Addresses[i].String() !=
second.Addresses[i].String() {
shouldPanic = true
break
}
require.Equal(
t, first.Addresses[i].String(),
second.Addresses[i].String(),
)
}
first.Addresses = nil
second.Addresses = nil
if !reflect.DeepEqual(first.ExtraOpaqueData,
second.ExtraOpaqueData) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deserialized message " +
"are not equal")
}
require.Equal(t, first, second)
})
}
@ -461,123 +330,32 @@ func FuzzOpenChannel(f *testing.F) {
// We will serialize the message into a new bytes buffer.
var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil {
// Could not serialize message into bytes buffer, panic
t.Fatal(err)
}
_, err = WriteMessage(&b, msg, 0)
require.NoError(t, err)
// Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly
// deserialized message.
newMsg, err := ReadMessage(&b, 0)
if err != nil {
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
require.NoError(t, err)
// Now compare every field instead of using reflect.DeepEqual.
// For UpfrontShutdownScript, we only compare bytes. This
// probably takes up more branches than necessary, but that's
// fine for now.
var shouldPanic bool
first, ok := msg.(*OpenChannel)
if !ok {
t.Fatal("message was not OpenChannel")
}
second, ok := newMsg.(*OpenChannel)
if !ok {
t.Fatal("new message was not OpenChannel")
}
require.IsType(t, &OpenChannel{}, msg)
first, _ := msg.(*OpenChannel)
require.IsType(t, &OpenChannel{}, newMsg)
second, _ := newMsg.(*OpenChannel)
if !first.ChainHash.IsEqual(&second.ChainHash) {
shouldPanic = true
}
// We can't use require.Equal for UpfrontShutdownScript, since
// we consider the empty slice and nil to be equivalent.
require.True(
t, bytes.Equal(
first.UpfrontShutdownScript,
second.UpfrontShutdownScript,
),
)
first.UpfrontShutdownScript = nil
second.UpfrontShutdownScript = nil
if !bytes.Equal(first.PendingChannelID[:],
second.PendingChannelID[:]) {
shouldPanic = true
}
if first.FundingAmount != second.FundingAmount {
shouldPanic = true
}
if first.PushAmount != second.PushAmount {
shouldPanic = true
}
if first.DustLimit != second.DustLimit {
shouldPanic = true
}
if first.MaxValueInFlight != second.MaxValueInFlight {
shouldPanic = true
}
if first.ChannelReserve != second.ChannelReserve {
shouldPanic = true
}
if first.HtlcMinimum != second.HtlcMinimum {
shouldPanic = true
}
if first.FeePerKiloWeight != second.FeePerKiloWeight {
shouldPanic = true
}
if first.CsvDelay != second.CsvDelay {
shouldPanic = true
}
if first.MaxAcceptedHTLCs != second.MaxAcceptedHTLCs {
shouldPanic = true
}
if !first.FundingKey.IsEqual(second.FundingKey) {
shouldPanic = true
}
if !first.RevocationPoint.IsEqual(second.RevocationPoint) {
shouldPanic = true
}
if !first.PaymentPoint.IsEqual(second.PaymentPoint) {
shouldPanic = true
}
if !first.DelayedPaymentPoint.IsEqual(
second.DelayedPaymentPoint) {
shouldPanic = true
}
if !first.HtlcPoint.IsEqual(second.HtlcPoint) {
shouldPanic = true
}
if !first.FirstCommitmentPoint.IsEqual(
second.FirstCommitmentPoint) {
shouldPanic = true
}
if first.ChannelFlags != second.ChannelFlags {
shouldPanic = true
}
if !bytes.Equal(first.UpfrontShutdownScript,
second.UpfrontShutdownScript) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deserialized message " +
"are not equal")
}
require.Equal(t, first, second)
})
}
@ -619,15 +397,10 @@ func FuzzZlibQueryShortChanIDs(f *testing.F) {
var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf)
_, err := zlibWriter.Write(data)
if err != nil {
// Zlib bug?
t.Fatal(err)
}
require.NoError(t, err) // Zlib bug?
if err := zlibWriter.Close(); err != nil {
// Zlib bug?
t.Fatal(err)
}
err = zlibWriter.Close()
require.NoError(t, err) // Zlib bug?
compressedPayload := buf.Bytes()
@ -668,15 +441,10 @@ func FuzzZlibReplyChannelRange(f *testing.F) {
var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf)
_, err := zlibWriter.Write(data)
if err != nil {
// Zlib bug?
t.Fatal(err)
}
require.NoError(t, err) // Zlib bug?
if err := zlibWriter.Close(); err != nil {
// Zlib bug?
t.Fatal(err)
}
err = zlibWriter.Close()
require.NoError(t, err) // Zlib bug?
compressedPayload := buf.Bytes()
@ -834,13 +602,9 @@ func FuzzParseRawSignature(f *testing.F) {
}
sig2, err := NewSigFromRawSignature(sig.ToSignatureBytes())
if err != nil {
t.Fatalf("failed to reparse signature: %v", err)
}
require.NoError(t, err, "failed to reparse signature")
if !reflect.DeepEqual(sig, sig2) {
t.Fatalf("signature mismatch: %v != %v", sig, sig2)
}
require.Equal(t, sig, sig2, "signature mismatch")
})
}
@ -861,21 +625,13 @@ func FuzzConvertFixedSignature(f *testing.F) {
}
sig2, err := NewSigFromSignature(derSig)
if err != nil {
t.Fatalf("failed to parse signature: %v", err)
}
require.NoError(t, err, "failed to parse signature")
derSig2, err := sig2.ToSignature()
if err != nil {
t.Fatalf("failed to reconvert signature to DER: %v",
err)
}
require.NoError(t, err, "failed to reconvert signature to DER")
derBytes := derSig.Serialize()
derBytes2 := derSig2.Serialize()
if !bytes.Equal(derBytes, derBytes2) {
t.Fatalf("signature mismatch: %v != %v", derBytes,
derBytes2)
}
require.Equal(t, derBytes, derBytes2, "signature mismatch")
})
}