lnwire: use assertEqualFunc in onion failure harness

Simplifies the code slightly and improves the error message printed if
the original and deserialized messages do not match.
This commit is contained in:
Matt Morehouse 2024-11-12 16:11:17 -06:00
parent d0e6a7a37b
commit 75bdf2d252
No known key found for this signature in database
GPG Key ID: CC8ECA224831C982

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"encoding/binary" "encoding/binary"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -554,16 +553,12 @@ func prefixWithFailCode(data []byte, code FailCode) []byte {
return data return data
} }
// equalFunc is a function used to determine whether two deserialized messages
// are equivalent.
type equalFunc func(x, y any) bool
// onionFailureHarnessCustom performs the actual fuzz testing of the appropriate // onionFailureHarnessCustom performs the actual fuzz testing of the appropriate
// onion failure message. This function will check that the passed-in message // onion failure message. This function will check that the passed-in message
// passes wire length checks, is a valid message once deserialized, and passes a // passes wire length checks, is a valid message once deserialized, and passes a
// sequence of serialization and deserialization checks. // sequence of serialization and deserialization checks.
func onionFailureHarnessCustom(t *testing.T, data []byte, code FailCode, func onionFailureHarnessCustom(t *testing.T, data []byte, code FailCode,
eq equalFunc) { assertEqual assertEqualFunc) {
data = prefixWithFailCode(data, code) data = prefixWithFailCode(data, code)
@ -589,12 +584,7 @@ func onionFailureHarnessCustom(t *testing.T, data []byte, code FailCode,
newMsg, err := DecodeFailureMessage(&b, 0) newMsg, err := DecodeFailureMessage(&b, 0)
require.NoError(t, err, "failed to decode serialized failure message") require.NoError(t, err, "failed to decode serialized failure message")
require.True( assertEqual(t, msg, newMsg)
t, eq(msg, newMsg),
"original message and deserialized message are not equal: "+
"%v != %v",
msg, newMsg,
)
// Now verify that encoding/decoding full packets works as expected. // Now verify that encoding/decoding full packets works as expected.
@ -628,17 +618,15 @@ func onionFailureHarnessCustom(t *testing.T, data []byte, code FailCode,
pktMsg, err := DecodeFailure(&pktBuf, 0) pktMsg, err := DecodeFailure(&pktBuf, 0)
require.NoError(t, err, "failed to decode failure packet") require.NoError(t, err, "failed to decode failure packet")
require.True( assertEqual(t, msg, pktMsg)
t, eq(msg, pktMsg),
"original message and decoded packet message are not equal: "+
"%v != %v",
msg, pktMsg,
)
} }
func onionFailureHarness(t *testing.T, data []byte, code FailCode) { func onionFailureHarness(t *testing.T, data []byte, code FailCode) {
t.Helper() t.Helper()
onionFailureHarnessCustom(t, data, code, reflect.DeepEqual) assertEq := func(t *testing.T, x, y any) {
require.Equal(t, x, y)
}
onionFailureHarnessCustom(t, data, code, assertEq)
} }
func FuzzFailIncorrectDetails(f *testing.F) { func FuzzFailIncorrectDetails(f *testing.F) {
@ -646,7 +634,7 @@ func FuzzFailIncorrectDetails(f *testing.F) {
// Since FailIncorrectDetails.Decode can leave extraOpaqueData // Since FailIncorrectDetails.Decode can leave extraOpaqueData
// as nil while FailIncorrectDetails.Encode writes an empty // as nil while FailIncorrectDetails.Encode writes an empty
// slice, we need to use a custom equality function. // slice, we need to use a custom equality function.
eq := func(x, y any) bool { assertEq := func(t *testing.T, x, y any) {
msg1, ok := x.(*FailIncorrectDetails) msg1, ok := x.(*FailIncorrectDetails)
require.True( require.True(
t, ok, "msg1 was not FailIncorrectDetails", t, ok, "msg1 was not FailIncorrectDetails",
@ -657,16 +645,18 @@ func FuzzFailIncorrectDetails(f *testing.F) {
t, ok, "msg2 was not FailIncorrectDetails", t, ok, "msg2 was not FailIncorrectDetails",
) )
return msg1.amount == msg2.amount && require.Equal(t, msg1.amount, msg2.amount)
msg1.height == msg2.height && require.Equal(t, msg1.height, msg2.height)
bytes.Equal( require.True(
t, bytes.Equal(
msg1.extraOpaqueData, msg1.extraOpaqueData,
msg2.extraOpaqueData, msg2.extraOpaqueData,
) ),
)
} }
onionFailureHarnessCustom( onionFailureHarnessCustom(
t, data, CodeIncorrectOrUnknownPaymentDetails, eq, t, data, CodeIncorrectOrUnknownPaymentDetails, assertEq,
) )
}) })
} }