lnwire: add custom records field to type UpdateAddHtlc

- Introduce the field `CustomRecords` to the type `UpdateAddHtlc`.
- Encode and decode the new field into the `ExtraData` field of
  the `update_add_htlc` wire message.
This commit is contained in:
ffranr 2024-04-13 12:29:41 +01:00 committed by Oliver Gugger
parent af50694643
commit 81f6a8060f
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
4 changed files with 272 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package lnwire
import (
"bytes"
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
@ -134,6 +135,27 @@ func randPubKey() (*btcec.PublicKey, error) {
return priv.PubKey(), nil
}
// pubkeyFromHex parses a Bitcoin public key from a hex encoded string.
func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) {
pubKeyBytes, err := hex.DecodeString(keyHex)
if err != nil {
return nil, err
}
return btcec.ParsePubKey(pubKeyBytes)
}
// generateRandomBytes returns a slice of n random bytes.
func generateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := crand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func randRawKey() ([33]byte, error) {
var n [33]byte
@ -389,6 +411,37 @@ func TestEmptyMessageUnknownType(t *testing.T) {
}
}
// randCustomRecords generates a random set of custom records for testing.
func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords {
var (
customRecords = CustomRecords{}
// We'll generate a random number of records, between 1 and 10.
numRecords = r.Intn(9) + 1
)
// For each record, we'll generate a random key and value.
for i := 0; i < numRecords; i++ {
// Keys must be equal to or greater than
// MinCustomRecordsTlvType.
keyOffset := uint64(r.Intn(100))
key := MinCustomRecordsTlvType + keyOffset
// Values are byte slices of any length.
value := make([]byte, r.Intn(100))
_, err := r.Read(value)
require.NoError(t, err)
customRecords[key] = value
}
// Validate the custom records as a sanity check.
err := customRecords.Validate()
require.NoError(t, err)
return customRecords
}
// TestLightningWireProtocol uses the testing/quick package to create a series
// of fuzz tests to attempt to break a primary scenario which is implemented as
// property based testing scenario.
@ -1369,6 +1422,8 @@ func TestLightningWireProtocol(t *testing.T) {
_, err = r.Read(req.OnionBlob[:])
require.NoError(t, err)
req.CustomRecords = randCustomRecords(t, r)
// Generate a blinding point 50% of the time, since not
// all update adds will use route blinding.
if r.Int31()%2 == 0 {

View File

@ -72,6 +72,11 @@ type UpdateAddHTLC struct {
// next hop for this htlc.
BlindingPoint BlindingPointRecord
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the UpdateAddHTLC
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -92,6 +97,10 @@ var _ Message = (*UpdateAddHTLC)(nil)
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
// msgExtraData is a temporary variable used to read the message extra
// data field from the reader.
var msgExtraData ExtraOpaqueData
if err := ReadElements(r,
&c.ChanID,
&c.ID,
@ -99,26 +108,28 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
c.PaymentHash[:],
&c.Expiry,
c.OnionBlob[:],
&c.ExtraData,
&msgExtraData,
); err != nil {
return err
}
// Extract TLV records from the extra data field.
blindingRecord := c.BlindingPoint.Zero()
tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord)
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
msgExtraData, &blindingRecord,
)
if err != nil {
return err
}
if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil {
// Assign the parsed records back to the message.
if parsed.Contains(blindingRecord.TlvType()) {
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
}
// Set extra data to nil if we didn't parse anything out of it so that
// we can use assert.Equal in tests.
if len(tlvMap) == 0 {
c.ExtraData = nil
}
c.CustomRecords = customRecords
c.ExtraData = extraData
return nil
}
@ -154,19 +165,18 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
// Only include blinding point in extra data if present.
var records []tlv.RecordProducer
c.BlindingPoint.WhenSome(
func(b tlv.RecordT[BlindingPointTlvType, *btcec.PublicKey]) {
records = append(records, &b)
},
)
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
*btcec.PublicKey]) {
records = append(records, &b)
})
err := EncodeMessageExtraData(&c.ExtraData, records...)
extraData, err := MergeAndEncode(records, c.ExtraData, c.CustomRecords)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -0,0 +1,188 @@
package lnwire
import (
"bytes"
"fmt"
"testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// testCase is a test case for the UpdateAddHTLC message.
type testCase struct {
// Msg is the message to be encoded and decoded.
Msg UpdateAddHTLC
// ExpectEncodeError is a flag that indicates whether we expect the
// encoding of the message to fail.
ExpectEncodeError bool
}
// generateTestCases generates a set of UpdateAddHTLC message test cases.
func generateTestCases(t *testing.T) []testCase {
// Firstly, we'll set basic values for the message fields.
//
// Generate random channel ID.
chanIDBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var chanID ChannelID
copy(chanID[:], chanIDBytes)
// Generate random payment hash.
paymentHashBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var paymentHash [32]byte
copy(paymentHash[:], paymentHashBytes)
// Generate random onion blob.
onionBlobBytes, err := generateRandomBytes(OnionPacketSize)
require.NoError(t, err)
var onionBlob [OnionPacketSize]byte
copy(onionBlob[:], onionBlobBytes)
// Define the blinding point.
blinding, err := pubkeyFromHex(
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
"8236c39",
)
require.NoError(t, err)
blindingPoint := tlv.SomeRecordT(
tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding),
)
// Define custom records.
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
recordValue1, err := generateRandomBytes(10)
require.NoError(t, err)
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
recordValue2, err := generateRandomBytes(10)
require.NoError(t, err)
customRecords := CustomRecords{
recordKey1: recordValue1,
recordKey2: recordValue2,
}
// Construct an instance of extra data that contains records with TLV
// types below the minimum custom records threshold and that lack
// corresponding fields in the message struct. Content should persist in
// the extra data field after encoding and decoding.
var (
recordBytes45 = []byte("recordBytes45")
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
recordBytes45,
)
recordBytes55 = []byte("recordBytes55")
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
recordBytes55,
)
)
var extraData ExtraOpaqueData
err = extraData.PackRecords(
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
)
require.NoError(t, err)
invalidCustomRecords := CustomRecords{
MinCustomRecordsTlvType - 1: recordValue1,
}
return []testCase{
{
Msg: UpdateAddHTLC{
ChanID: chanID,
ID: 42,
Amount: MilliSatoshi(1000),
PaymentHash: paymentHash,
Expiry: 43,
OnionBlob: onionBlob,
BlindingPoint: blindingPoint,
CustomRecords: customRecords,
ExtraData: extraData,
},
},
// Add a test case where the blinding point field is not
// populated.
{
Msg: UpdateAddHTLC{
ChanID: chanID,
ID: 42,
Amount: MilliSatoshi(1000),
PaymentHash: paymentHash,
Expiry: 43,
OnionBlob: onionBlob,
CustomRecords: customRecords,
},
},
// Add a test case where the custom records field is not
// populated.
{
Msg: UpdateAddHTLC{
ChanID: chanID,
ID: 42,
Amount: MilliSatoshi(1000),
PaymentHash: paymentHash,
Expiry: 43,
OnionBlob: onionBlob,
BlindingPoint: blindingPoint,
},
},
// Add a case where the custom records are invalid.
{
Msg: UpdateAddHTLC{
ChanID: chanID,
ID: 42,
Amount: MilliSatoshi(1000),
PaymentHash: paymentHash,
Expiry: 43,
OnionBlob: onionBlob,
BlindingPoint: blindingPoint,
CustomRecords: invalidCustomRecords,
},
ExpectEncodeError: true,
},
}
}
// TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and
// decoding for all supported field values.
func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
t.Parallel()
// Generate test cases.
testCases := generateTestCases(t)
// Execute test cases.
for tcIdx, tc := range testCases {
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
// Encode test case message.
var buf bytes.Buffer
err := tc.Msg.Encode(&buf, 0)
// Check if we expect an encoding error.
if tc.ExpectEncodeError {
require.Error(t, err)
return
}
require.NoError(t, err)
// Decode the encoded message bytes message.
var actualMsg UpdateAddHTLC
decodeReader := bytes.NewReader(buf.Bytes())
err = actualMsg.Decode(decodeReader, 0)
require.NoError(t, err)
// Compare the two messages to ensure equality.
require.Equal(t, tc.Msg, actualMsg)
})
}
}

View File

@ -2193,9 +2193,9 @@ func messageSummary(msg lnwire.Message) string {
)
return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+
"hash=%x, blinding_point=%x", msg.ChanID, msg.ID,
msg.Amount, msg.Expiry, msg.PaymentHash[:],
blindingPoint)
"hash=%x, blinding_point=%x, custom_records=%v",
msg.ChanID, msg.ID, msg.Amount, msg.Expiry,
msg.PaymentHash[:], blindingPoint, msg.CustomRecords)
case *lnwire.UpdateFailHTLC:
return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,