From 21cb454664667e6d79e5a1a95b529b9a86a06025 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 29 May 2024 19:57:37 +0200 Subject: [PATCH] lnwire: add CustomRecords to shutdown message --- lnwire/custom_records.go | 32 +++++++ lnwire/custom_records_test.go | 37 ++++++++ lnwire/shutdown.go | 86 ++++++++++++++--- lnwire/shutdown_test.go | 144 +++++++++++++++++++++++++++++ lnwire/update_add_htlc_test.go | 30 +----- lnwire/update_fulfill_htlc_test.go | 29 +----- 6 files changed, 288 insertions(+), 70 deletions(-) create mode 100644 lnwire/custom_records_test.go create mode 100644 lnwire/shutdown_test.go diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index ab11e3ce0..dcd7361c1 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -42,6 +42,38 @@ func NewCustomRecordsFromTlvTypeMap(tlvMap tlv.TypeMap) (CustomRecords, return customRecords, nil } +// FilteredCustomRecords returns a new CustomRecords instance containing only +// the records from the given tlv.TypeMap that are in the custom records TLV +// type range. Filtered out records that aren't in the custom records TLV type +// range are returned in a new tlv.TypeMap. +func FilteredCustomRecords(typeMap tlv.TypeMap) (CustomRecords, tlv.TypeMap, + error) { + + // Any records from the extra data TLV map which are in the custom + // records TLV type range will be included in the custom records field. + customRecordsTlvMap := make(tlv.TypeMap, len(typeMap)) + remainder := make(tlv.TypeMap) + for k, v := range typeMap { + // Skip records that are not in the custom records TLV type + // range. + if k < MinCustomRecordsTlvType { + remainder[k] = v + + continue + } + + // Include the record in the custom records map. + customRecordsTlvMap[k] = v + } + + cr, err := NewCustomRecordsFromTlvTypeMap(customRecordsTlvMap) + if err != nil { + return nil, nil, err + } + + return cr, remainder, nil +} + // ParseCustomRecords creates a new CustomRecords instance from a // tlv.Blob. func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) { diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go new file mode 100644 index 000000000..6f97079f6 --- /dev/null +++ b/lnwire/custom_records_test.go @@ -0,0 +1,37 @@ +package lnwire + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" +) + +// TestFilteredCustomRecords tests that we can separate TLV types into custom +// and standard records. +func TestFilteredCustomRecords(t *testing.T) { + // Create a new tlv.TypeMap with some records, both in the standard and + // in the custom range. + typeMap := make(tlv.TypeMap) + typeMap[45] = []byte{1, 2, 3} + typeMap[55] = []byte{4, 5, 6} + typeMap[65] = []byte{7, 8, 9} + typeMap[65536] = []byte{11, 22, 33} + typeMap[65537] = []byte{44, 55, 66} + typeMap[65538] = []byte{77, 88, 99} + + customRecords, remainder, err := FilteredCustomRecords(typeMap) + require.NoError(t, err) + + require.Len(t, customRecords, 3) + require.Len(t, remainder, 3) + + require.Contains(t, maps.Keys(customRecords), uint64(65536)) + require.Contains(t, maps.Keys(customRecords), uint64(65537)) + require.Contains(t, maps.Keys(customRecords), uint64(65538)) + + require.Contains(t, maps.Keys(remainder), tlv.Type(45)) + require.Contains(t, maps.Keys(remainder), tlv.Type(55)) + require.Contains(t, maps.Keys(remainder), tlv.Type(65)) +} diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index c5455651b..4d3a1ea56 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "fmt" "io" "github.com/lightningnetwork/lnd/tlv" @@ -38,6 +39,11 @@ type Shutdown struct { // co-op sign offer. ShutdownNonce ShutdownNonceTLV + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the Shutdown + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -56,7 +62,7 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) -// Decode deserializes a serialized Shutdown stored in the passed io.Reader +// Decode deserializes a serialized Shutdown from the passed io.Reader, // observing the specified protocol version. // // This is part of the lnwire.Message interface. @@ -80,10 +86,32 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // Set the corresponding TLV types if they were included in the stream. if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil { s.ShutdownNonce = tlv.SomeRecordT(musigNonce) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(typeMap, s.ShutdownNonce.TlvType()) } - if len(tlvRecords) != 0 { - s.ExtraData = tlvRecords + // Set the custom records field to the remaining TLV records, but only + // those that actually are in the custom TLV type range. + customRecords, filtered, err := FilteredCustomRecords(typeMap) + if err != nil { + return err + } + s.CustomRecords = customRecords + + // Set extra data to nil if we didn't parse anything out of it so that + // we can use assert.Equal in tests. + if len(filtered) == 0 { + s.ExtraData = make([]byte, 0) + } else { + // Encode the remaining records into the extra data field. These + // records are not in the custom records TLV type range and do + // not have associated fields in the UpdateAddHTLC struct. + s.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(filtered) + if err != nil { + return err + } } return nil @@ -94,17 +122,6 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := make([]tlv.RecordProducer, 0, 1) - s.ShutdownNonce.WhenSome( - func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) { - recordProducers = append(recordProducers, &nonce) - }, - ) - err := EncodeMessageExtraData(&s.ExtraData, recordProducers...) - if err != nil { - return err - } - if err := WriteChannelID(w, s.ChannelID); err != nil { return err } @@ -113,7 +130,46 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, s.ExtraData) + // Construct a slice of all the records that we should include in the + // message extra data field. We will start by including any records from + // the extra data field. + msgExtraDataRecords, err := s.ExtraData.RecordProducers() + if err != nil { + return err + } + + s.ShutdownNonce.WhenSome( + func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) { + msgExtraDataRecords = append( + msgExtraDataRecords, &nonce, + ) + }, + ) + + // Include custom records in the extra data wire field if they are + // present. Ensure that the custom records are validated before + // encoding them. + if err := s.CustomRecords.Validate(); err != nil { + return fmt.Errorf("custom records validation error: %w", err) + } + + // Extend the message extra data records slice with TLV records from + // the custom records field. + recordProducers, err := s.CustomRecords.ExtendRecordProducers( + msgExtraDataRecords, + ) + if err != nil { + return err + } + + // We will now construct the message extra data field that will be + // encoded into the byte writer. + var msgExtraData ExtraOpaqueData + if err := msgExtraData.PackRecords(recordProducers...); err != nil { + return err + } + + return WriteBytes(w, msgExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/shutdown_test.go b/lnwire/shutdown_test.go new file mode 100644 index 000000000..b8ddfdea2 --- /dev/null +++ b/lnwire/shutdown_test.go @@ -0,0 +1,144 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// testCaseShutdown is a test case for the Shutdown message. +type testCaseShutdown struct { + // Msg is the message to be encoded and decoded. + Msg Shutdown + + // ExpectEncodeError is a flag that indicates whether we expect the + // encoding of the message to fail. + ExpectEncodeError bool +} + +// generateShutdownTestCases generates a set of Shutdown message test cases. +func generateShutdownTestCases(t *testing.T) []testCaseShutdown { + // Firstly, we'll set basic values for the message fields. + // + // Generate random channel ID. + chanIDBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + // Generate random payment preimage. + paymentPreimageBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var paymentPreimage [32]byte + copy(paymentPreimage[:], paymentPreimageBytes) + + deliveryAddr, err := generateRandomBytes(16) + require.NoError(t, err) + + // Define custom records. + recordKey1 := uint64(MinCustomRecordsTlvType + 1) + recordValue1, err := generateRandomBytes(10) + require.NoError(t, err) + + recordKey2 := uint64(MinCustomRecordsTlvType + 2) + recordValue2, err := generateRandomBytes(10) + require.NoError(t, err) + + customRecords := CustomRecords{ + recordKey1: recordValue1, + recordKey2: recordValue2, + } + + dummyPubKey, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" + + "8236c39", + ) + require.NoError(t, err) + + muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey)) + require.NoError(t, err) + + // Construct an instance of extra data that contains records with TLV + // types below the minimum custom records threshold and that lack + // corresponding fields in the message struct. Content should persist in + // the extra data field after encoding and decoding. + var ( + recordBytes45 = []byte("recordBytes45") + tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45]( + recordBytes45, + ) + + recordBytes55 = []byte("recordBytes55") + tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55]( + recordBytes55, + ) + ) + + var extraData ExtraOpaqueData + err = extraData.PackRecords( + []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}..., + ) + require.NoError(t, err) + + // Define test cases. + testCases := make([]testCaseShutdown, 0) + + testCases = append(testCases, testCaseShutdown{ + Msg: Shutdown{ + ChannelID: chanID, + CustomRecords: customRecords, + ExtraData: extraData, + Address: deliveryAddr, + }, + }, testCaseShutdown{ + Msg: Shutdown{ + ChannelID: chanID, + CustomRecords: customRecords, + ExtraData: extraData, + Address: deliveryAddr, + ShutdownNonce: SomeShutdownNonce(muSig2Nonce.PubNonce), + }, + }) + + return testCases +} + +// TestShutdownEncodeDecode tests Shutdown message encoding and decoding for all +// supported field values. +func TestShutdownEncodeDecode(t *testing.T) { + t.Parallel() + + // Generate test cases. + testCases := generateShutdownTestCases(t) + + // Execute test cases. + for tcIdx, tc := range testCases { + t.Log("Running test case", tcIdx) + + // Encode test case message. + var buf bytes.Buffer + err := tc.Msg.Encode(&buf, 0) + + // Check if we expect an encoding error. + if tc.ExpectEncodeError { + require.Error(t, err) + continue + } + require.NoError(t, err) + + // Decode the encoded message bytes message. + var actualMsg Shutdown + decodeReader := bytes.NewReader(buf.Bytes()) + err = actualMsg.Decode(decodeReader, 0) + require.NoError(t, err) + + // Compare the two messages to ensure equality one field at a + // time. + require.Equal(t, tc.Msg, actualMsg) + } +} diff --git a/lnwire/update_add_htlc_test.go b/lnwire/update_add_htlc_test.go index 5a0515461..861eca6d2 100644 --- a/lnwire/update_add_htlc_test.go +++ b/lnwire/update_add_htlc_test.go @@ -184,33 +184,7 @@ func TestUpdateAddHtlcEncodeDecode(t *testing.T) { err = actualMsg.Decode(decodeReader, 0) require.NoError(t, err) - // Compare the two messages to ensure equality one field at a - // time. - require.Equal(t, tc.Msg.ChanID, actualMsg.ChanID) - require.Equal(t, tc.Msg.ID, actualMsg.ID) - require.Equal(t, tc.Msg.Amount, actualMsg.Amount) - require.Equal(t, tc.Msg.PaymentHash, actualMsg.PaymentHash) - require.Equal(t, tc.Msg.OnionBlob, actualMsg.OnionBlob) - require.Equal(t, tc.Msg.BlindingPoint, actualMsg.BlindingPoint) - - // Check that the custom records field is as expected. - if len(tc.Msg.CustomRecords) == 0 { - require.Len(t, actualMsg.CustomRecords, 0) - } else { - require.Equal( - t, tc.Msg.CustomRecords, - actualMsg.CustomRecords, - ) - } - - // Check that the extra data field is as expected. - if len(tc.Msg.ExtraData) == 0 { - require.Len(t, actualMsg.ExtraData, 0) - } else { - require.Equal( - t, tc.Msg.ExtraData, - actualMsg.ExtraData, - ) - } + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) } } diff --git a/lnwire/update_fulfill_htlc_test.go b/lnwire/update_fulfill_htlc_test.go index 705798df4..8de040645 100644 --- a/lnwire/update_fulfill_htlc_test.go +++ b/lnwire/update_fulfill_htlc_test.go @@ -117,32 +117,7 @@ func TestUpdateFulfillHtlcEncodeDecode(t *testing.T) { err = actualMsg.Decode(decodeReader, 0) require.NoError(t, err) - // Compare the two messages to ensure equality one field at a - // time. - require.Equal(t, tc.Msg.ChanID, actualMsg.ChanID) - require.Equal(t, tc.Msg.ID, actualMsg.ID) - require.Equal( - t, tc.Msg.PaymentPreimage, actualMsg.PaymentPreimage, - ) - - // Check that the custom records field is as expected. - if len(tc.Msg.CustomRecords) == 0 { - require.Len(t, actualMsg.CustomRecords, 0) - } else { - require.Equal( - t, tc.Msg.CustomRecords, - actualMsg.CustomRecords, - ) - } - - // Check that the extra data field is as expected. - if len(tc.Msg.ExtraData) == 0 { - require.Len(t, actualMsg.ExtraData, 0) - } else { - require.Equal( - t, tc.Msg.ExtraData, - actualMsg.ExtraData, - ) - } + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) } }