lnwallet: add unit test for extractPayDescs

This commit is contained in:
yyforyongyu 2023-11-25 01:49:32 +08:00
parent 85f4b13632
commit 81841b7dab
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868

View file

@ -10111,3 +10111,101 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) {
)
require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound)
}
// TestExtractPayDescs asserts that `extractPayDescs` can correctly turn a
// slice of htlcs into two slices of PaymentDescriptors.
func TestExtractPayDescs(t *testing.T) {
t.Parallel()
// Create a testing LightningChannel.
lnChan, _, err := CreateTestChannels(
t, channeldb.SingleFunderTweaklessBit,
)
require.NoError(t, err)
// Create two incoming HTLCs.
incomings := []channeldb.HTLC{
createRandomHTLC(t, true),
createRandomHTLC(t, true),
}
// Create two outgoing HTLCs.
outgoings := []channeldb.HTLC{
createRandomHTLC(t, false),
createRandomHTLC(t, false),
}
// Concatenate incomings and outgoings into a single slice.
htlcs := []channeldb.HTLC{}
htlcs = append(htlcs, incomings...)
htlcs = append(htlcs, outgoings...)
// Run the method under test.
//
// NOTE: we use nil commitment key rings to avoid checking the htlc
// scripts(`genHtlcScript`) as it should be tested independently.
incomingPDs, outgoingPDs, err := lnChan.extractPayDescs(
0, 0, htlcs, nil, nil, true,
)
require.NoError(t, err)
// Assert the incoming PaymentDescriptors are matched.
for i, pd := range incomingPDs {
htlc := incomings[i]
assertPayDescMatchHTLC(t, pd, htlc)
}
// Assert the outgoing PaymentDescriptors are matched.
for i, pd := range outgoingPDs {
htlc := outgoings[i]
assertPayDescMatchHTLC(t, pd, htlc)
}
}
// assertPayDescMatchHTLC compares a PaymentDescriptor to a channeldb.HTLC and
// asserts that the fields are matched.
func assertPayDescMatchHTLC(t *testing.T, pd PaymentDescriptor,
htlc channeldb.HTLC) {
require := require.New(t)
require.EqualValues(htlc.RHash, pd.RHash, "RHash")
require.Equal(htlc.RefundTimeout, pd.Timeout, "Timeout")
require.Equal(htlc.Amt, pd.Amount, "Amount")
require.Equal(htlc.HtlcIndex, pd.HtlcIndex, "HtlcIndex")
require.Equal(htlc.LogIndex, pd.LogIndex, "LogIndex")
require.EqualValues(htlc.OnionBlob[:], pd.OnionBlob, "OnionBlob")
}
// createRandomHTLC creates an HTLC that has random value in every field except
// the `Incoming`.
func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC {
var onionBlob [lnwire.OnionPacketSize]byte
_, err := rand.Read(onionBlob[:])
require.NoError(t, err)
var rHash [lntypes.HashSize]byte
_, err = rand.Read(rHash[:])
require.NoError(t, err)
sig := make([]byte, 64)
_, err = rand.Read(sig)
require.NoError(t, err)
extra := make([]byte, 1000)
_, err = rand.Read(extra)
require.NoError(t, err)
return channeldb.HTLC{
Signature: sig,
RHash: rHash,
Amt: lnwire.MilliSatoshi(rand.Uint64()),
RefundTimeout: rand.Uint32(),
OutputIndex: rand.Int31n(1000),
Incoming: incoming,
OnionBlob: onionBlob,
HtlcIndex: rand.Uint64(),
LogIndex: rand.Uint64(),
ExtraData: extra,
}
}