lnwire: add custom records field to type UpdateFulfillHtlc

- Introduce the field `CustomRecords` to the type `UpdateFulfillHtlc`.
- Encode and decode the new field into the `ExtraData` field of the
`update_fulfill_htlc` wire message.
- Empty `ExtraData` field is set to `nil`.
This commit is contained in:
ffranr 2024-05-03 16:22:05 +01:00 committed by Oliver Gugger
parent 81f6a8060f
commit 8d1059f41c
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
5 changed files with 188 additions and 6 deletions

View File

@ -38,7 +38,6 @@ func TestNetworkResultSerialization(t *testing.T) {
ChanID: chanID,
ID: 2,
PaymentPreimage: preimage,
ExtraData: make([]byte, 0),
}
fail := &lnwire.UpdateFailHTLC{

View File

@ -1442,6 +1442,29 @@ func TestLightningWireProtocol(t *testing.T) {
)
}
v[0] = reflect.ValueOf(*req)
},
MsgUpdateFulfillHTLC: func(v []reflect.Value, r *rand.Rand) {
req := &UpdateFulfillHTLC{
ID: r.Uint64(),
}
_, err := r.Read(req.ChanID[:])
require.NoError(t, err)
_, err = r.Read(req.PaymentPreimage[:])
require.NoError(t, err)
req.CustomRecords = randCustomRecords(t, r)
// Generate some random TLV records 50% of the time.
if r.Int31()%2 == 0 {
req.ExtraData = []byte{
0x01, 0x03, 1, 2, 3,
0x02, 0x03, 4, 5, 6,
}
}
v[0] = reflect.ValueOf(*req)
},
}

View File

@ -23,6 +23,10 @@ type UpdateFulfillHTLC struct {
// HTLC.
PaymentPreimage [32]byte
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -49,12 +53,31 @@ var _ Message = (*UpdateFulfillHTLC)(nil)
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
// msgExtraData is a temporary variable used to read the message extra
// data field from the reader.
var msgExtraData ExtraOpaqueData
if err := ReadElements(r,
&c.ChanID,
&c.ID,
c.PaymentPreimage[:],
&c.ExtraData,
&msgExtraData,
); err != nil {
return err
}
// Extract custom records from the extra data field.
customRecords, _, extraData, err := ParseAndExtractCustomRecords(
msgExtraData,
)
if err != nil {
return err
}
c.CustomRecords = customRecords
c.ExtraData = extraData
return nil
}
// Encode serializes the target UpdateFulfillHTLC into the passed io.Writer
@ -74,7 +97,14 @@ func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
return WriteBytes(w, c.ExtraData)
// Combine the custom records and the extra data, then encode the
// result as a byte slice.
extraData, err := MergeAndEncode(nil, c.ExtraData, c.CustomRecords)
if err != nil {
return err
}
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -0,0 +1,129 @@
package lnwire
import (
"bytes"
"fmt"
"testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// testCaseUpdateFulfill is a test case for the UpdateFulfillHTLC message.
type testCaseUpdateFulfill struct {
// Msg is the message to be encoded and decoded.
Msg UpdateFulfillHTLC
// ExpectEncodeError is a flag that indicates whether we expect the
// encoding of the message to fail.
ExpectEncodeError bool
}
// generateTestCases generates a set of UpdateFulfillHTLC message test cases.
func generateUpdateFulfillTestCases(t *testing.T) []testCaseUpdateFulfill {
// Firstly, we'll set basic values for the message fields.
//
// Generate random channel ID.
chanIDBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var chanID ChannelID
copy(chanID[:], chanIDBytes)
// Generate random payment preimage.
paymentPreimageBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var paymentPreimage [32]byte
copy(paymentPreimage[:], paymentPreimageBytes)
// Define custom records.
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
recordValue1, err := generateRandomBytes(10)
require.NoError(t, err)
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
recordValue2, err := generateRandomBytes(10)
require.NoError(t, err)
customRecords := CustomRecords{
recordKey1: recordValue1,
recordKey2: recordValue2,
}
// Construct an instance of extra data that contains records with TLV
// types below the minimum custom records threshold and that lack
// corresponding fields in the message struct. Content should persist in
// the extra data field after encoding and decoding.
var (
recordBytes45 = []byte("recordBytes45")
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
recordBytes45,
)
recordBytes55 = []byte("recordBytes55")
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
recordBytes55,
)
)
var extraData ExtraOpaqueData
err = extraData.PackRecords(
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
)
require.NoError(t, err)
return []testCaseUpdateFulfill{
{
Msg: UpdateFulfillHTLC{
ChanID: chanID,
ID: 42,
PaymentPreimage: paymentPreimage,
},
},
{
Msg: UpdateFulfillHTLC{
ChanID: chanID,
ID: 42,
PaymentPreimage: paymentPreimage,
CustomRecords: customRecords,
ExtraData: extraData,
},
},
}
}
// TestUpdateFulfillHtlcEncodeDecode tests UpdateFulfillHTLC message encoding
// and decoding for all supported field values.
func TestUpdateFulfillHtlcEncodeDecode(t *testing.T) {
t.Parallel()
// Generate test cases.
testCases := generateUpdateFulfillTestCases(t)
// Execute test cases.
for tcIdx, tc := range testCases {
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
// Encode test case message.
var buf bytes.Buffer
err := tc.Msg.Encode(&buf, 0)
// Check if we expect an encoding error.
if tc.ExpectEncodeError {
require.Error(t, err)
return
}
require.NoError(t, err)
// Decode the encoded message bytes message.
var actualMsg UpdateFulfillHTLC
decodeReader := bytes.NewReader(buf.Bytes())
err = actualMsg.Decode(decodeReader, 0)
require.NoError(t, err)
// Compare the two messages to ensure equality.
require.Equal(t, tc.Msg, actualMsg)
})
}
}

View File

@ -2202,8 +2202,9 @@ func messageSummary(msg lnwire.Message) string {
msg.ID, msg.Reason)
case *lnwire.UpdateFulfillHTLC:
return fmt.Sprintf("chan_id=%v, id=%v, pre_image=%x",
msg.ChanID, msg.ID, msg.PaymentPreimage[:])
return fmt.Sprintf("chan_id=%v, id=%v, pre_image=%x, "+
"custom_records=%v", msg.ChanID, msg.ID,
msg.PaymentPreimage[:], msg.CustomRecords)
case *lnwire.CommitSig:
return fmt.Sprintf("chan_id=%v, num_htlcs=%v", msg.ChanID,