lnwire: add CustomRecords to shutdown message

This commit is contained in:
Olaoluwa Osuntokun 2024-05-29 19:57:37 +02:00 committed by Oliver Gugger
parent 1f86f38f36
commit 21cb454664
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
6 changed files with 288 additions and 70 deletions

View File

@ -42,6 +42,38 @@ func NewCustomRecordsFromTlvTypeMap(tlvMap tlv.TypeMap) (CustomRecords,
return customRecords, nil
}
// FilteredCustomRecords returns a new CustomRecords instance containing only
// the records from the given tlv.TypeMap that are in the custom records TLV
// type range. Filtered out records that aren't in the custom records TLV type
// range are returned in a new tlv.TypeMap.
func FilteredCustomRecords(typeMap tlv.TypeMap) (CustomRecords, tlv.TypeMap,
error) {
// Any records from the extra data TLV map which are in the custom
// records TLV type range will be included in the custom records field.
customRecordsTlvMap := make(tlv.TypeMap, len(typeMap))
remainder := make(tlv.TypeMap)
for k, v := range typeMap {
// Skip records that are not in the custom records TLV type
// range.
if k < MinCustomRecordsTlvType {
remainder[k] = v
continue
}
// Include the record in the custom records map.
customRecordsTlvMap[k] = v
}
cr, err := NewCustomRecordsFromTlvTypeMap(customRecordsTlvMap)
if err != nil {
return nil, nil, err
}
return cr, remainder, nil
}
// ParseCustomRecords creates a new CustomRecords instance from a
// tlv.Blob.
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {

View File

@ -0,0 +1,37 @@
package lnwire
import (
"testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
// TestFilteredCustomRecords tests that we can separate TLV types into custom
// and standard records.
func TestFilteredCustomRecords(t *testing.T) {
// Create a new tlv.TypeMap with some records, both in the standard and
// in the custom range.
typeMap := make(tlv.TypeMap)
typeMap[45] = []byte{1, 2, 3}
typeMap[55] = []byte{4, 5, 6}
typeMap[65] = []byte{7, 8, 9}
typeMap[65536] = []byte{11, 22, 33}
typeMap[65537] = []byte{44, 55, 66}
typeMap[65538] = []byte{77, 88, 99}
customRecords, remainder, err := FilteredCustomRecords(typeMap)
require.NoError(t, err)
require.Len(t, customRecords, 3)
require.Len(t, remainder, 3)
require.Contains(t, maps.Keys(customRecords), uint64(65536))
require.Contains(t, maps.Keys(customRecords), uint64(65537))
require.Contains(t, maps.Keys(customRecords), uint64(65538))
require.Contains(t, maps.Keys(remainder), tlv.Type(45))
require.Contains(t, maps.Keys(remainder), tlv.Type(55))
require.Contains(t, maps.Keys(remainder), tlv.Type(65))
}

View File

@ -2,6 +2,7 @@ package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
@ -38,6 +39,11 @@ type Shutdown struct {
// co-op sign offer.
ShutdownNonce ShutdownNonceTLV
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the Shutdown
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -56,7 +62,7 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
// interface.
var _ Message = (*Shutdown)(nil)
// Decode deserializes a serialized Shutdown stored in the passed io.Reader
// Decode deserializes a serialized Shutdown from the passed io.Reader,
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
@ -80,10 +86,32 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil {
s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(typeMap, s.ShutdownNonce.TlvType())
}
if len(tlvRecords) != 0 {
s.ExtraData = tlvRecords
// Set the custom records field to the remaining TLV records, but only
// those that actually are in the custom TLV type range.
customRecords, filtered, err := FilteredCustomRecords(typeMap)
if err != nil {
return err
}
s.CustomRecords = customRecords
// Set extra data to nil if we didn't parse anything out of it so that
// we can use assert.Equal in tests.
if len(filtered) == 0 {
s.ExtraData = make([]byte, 0)
} else {
// Encode the remaining records into the extra data field. These
// records are not in the custom records TLV type range and do
// not have associated fields in the UpdateAddHTLC struct.
s.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(filtered)
if err != nil {
return err
}
}
return nil
@ -94,17 +122,6 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
s.ShutdownNonce.WhenSome(
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
recordProducers = append(recordProducers, &nonce)
},
)
err := EncodeMessageExtraData(&s.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteChannelID(w, s.ChannelID); err != nil {
return err
}
@ -113,7 +130,46 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
return WriteBytes(w, s.ExtraData)
// Construct a slice of all the records that we should include in the
// message extra data field. We will start by including any records from
// the extra data field.
msgExtraDataRecords, err := s.ExtraData.RecordProducers()
if err != nil {
return err
}
s.ShutdownNonce.WhenSome(
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
msgExtraDataRecords = append(
msgExtraDataRecords, &nonce,
)
},
)
// Include custom records in the extra data wire field if they are
// present. Ensure that the custom records are validated before
// encoding them.
if err := s.CustomRecords.Validate(); err != nil {
return fmt.Errorf("custom records validation error: %w", err)
}
// Extend the message extra data records slice with TLV records from
// the custom records field.
recordProducers, err := s.CustomRecords.ExtendRecordProducers(
msgExtraDataRecords,
)
if err != nil {
return err
}
// We will now construct the message extra data field that will be
// encoded into the byte writer.
var msgExtraData ExtraOpaqueData
if err := msgExtraData.PackRecords(recordProducers...); err != nil {
return err
}
return WriteBytes(w, msgExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

144
lnwire/shutdown_test.go Normal file
View File

@ -0,0 +1,144 @@
package lnwire
import (
"bytes"
"testing"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// testCaseShutdown is a test case for the Shutdown message.
type testCaseShutdown struct {
// Msg is the message to be encoded and decoded.
Msg Shutdown
// ExpectEncodeError is a flag that indicates whether we expect the
// encoding of the message to fail.
ExpectEncodeError bool
}
// generateShutdownTestCases generates a set of Shutdown message test cases.
func generateShutdownTestCases(t *testing.T) []testCaseShutdown {
// Firstly, we'll set basic values for the message fields.
//
// Generate random channel ID.
chanIDBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var chanID ChannelID
copy(chanID[:], chanIDBytes)
// Generate random payment preimage.
paymentPreimageBytes, err := generateRandomBytes(32)
require.NoError(t, err)
var paymentPreimage [32]byte
copy(paymentPreimage[:], paymentPreimageBytes)
deliveryAddr, err := generateRandomBytes(16)
require.NoError(t, err)
// Define custom records.
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
recordValue1, err := generateRandomBytes(10)
require.NoError(t, err)
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
recordValue2, err := generateRandomBytes(10)
require.NoError(t, err)
customRecords := CustomRecords{
recordKey1: recordValue1,
recordKey2: recordValue2,
}
dummyPubKey, err := pubkeyFromHex(
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
"8236c39",
)
require.NoError(t, err)
muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey))
require.NoError(t, err)
// Construct an instance of extra data that contains records with TLV
// types below the minimum custom records threshold and that lack
// corresponding fields in the message struct. Content should persist in
// the extra data field after encoding and decoding.
var (
recordBytes45 = []byte("recordBytes45")
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
recordBytes45,
)
recordBytes55 = []byte("recordBytes55")
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
recordBytes55,
)
)
var extraData ExtraOpaqueData
err = extraData.PackRecords(
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
)
require.NoError(t, err)
// Define test cases.
testCases := make([]testCaseShutdown, 0)
testCases = append(testCases, testCaseShutdown{
Msg: Shutdown{
ChannelID: chanID,
CustomRecords: customRecords,
ExtraData: extraData,
Address: deliveryAddr,
},
}, testCaseShutdown{
Msg: Shutdown{
ChannelID: chanID,
CustomRecords: customRecords,
ExtraData: extraData,
Address: deliveryAddr,
ShutdownNonce: SomeShutdownNonce(muSig2Nonce.PubNonce),
},
})
return testCases
}
// TestShutdownEncodeDecode tests Shutdown message encoding and decoding for all
// supported field values.
func TestShutdownEncodeDecode(t *testing.T) {
t.Parallel()
// Generate test cases.
testCases := generateShutdownTestCases(t)
// Execute test cases.
for tcIdx, tc := range testCases {
t.Log("Running test case", tcIdx)
// Encode test case message.
var buf bytes.Buffer
err := tc.Msg.Encode(&buf, 0)
// Check if we expect an encoding error.
if tc.ExpectEncodeError {
require.Error(t, err)
continue
}
require.NoError(t, err)
// Decode the encoded message bytes message.
var actualMsg Shutdown
decodeReader := bytes.NewReader(buf.Bytes())
err = actualMsg.Decode(decodeReader, 0)
require.NoError(t, err)
// Compare the two messages to ensure equality one field at a
// time.
require.Equal(t, tc.Msg, actualMsg)
}
}

View File

@ -184,33 +184,7 @@ func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
err = actualMsg.Decode(decodeReader, 0)
require.NoError(t, err)
// Compare the two messages to ensure equality one field at a
// time.
require.Equal(t, tc.Msg.ChanID, actualMsg.ChanID)
require.Equal(t, tc.Msg.ID, actualMsg.ID)
require.Equal(t, tc.Msg.Amount, actualMsg.Amount)
require.Equal(t, tc.Msg.PaymentHash, actualMsg.PaymentHash)
require.Equal(t, tc.Msg.OnionBlob, actualMsg.OnionBlob)
require.Equal(t, tc.Msg.BlindingPoint, actualMsg.BlindingPoint)
// Check that the custom records field is as expected.
if len(tc.Msg.CustomRecords) == 0 {
require.Len(t, actualMsg.CustomRecords, 0)
} else {
require.Equal(
t, tc.Msg.CustomRecords,
actualMsg.CustomRecords,
)
}
// Check that the extra data field is as expected.
if len(tc.Msg.ExtraData) == 0 {
require.Len(t, actualMsg.ExtraData, 0)
} else {
require.Equal(
t, tc.Msg.ExtraData,
actualMsg.ExtraData,
)
}
// Compare the two messages to ensure equality.
require.Equal(t, tc.Msg, actualMsg)
}
}

View File

@ -117,32 +117,7 @@ func TestUpdateFulfillHtlcEncodeDecode(t *testing.T) {
err = actualMsg.Decode(decodeReader, 0)
require.NoError(t, err)
// Compare the two messages to ensure equality one field at a
// time.
require.Equal(t, tc.Msg.ChanID, actualMsg.ChanID)
require.Equal(t, tc.Msg.ID, actualMsg.ID)
require.Equal(
t, tc.Msg.PaymentPreimage, actualMsg.PaymentPreimage,
)
// Check that the custom records field is as expected.
if len(tc.Msg.CustomRecords) == 0 {
require.Len(t, actualMsg.CustomRecords, 0)
} else {
require.Equal(
t, tc.Msg.CustomRecords,
actualMsg.CustomRecords,
)
}
// Check that the extra data field is as expected.
if len(tc.Msg.ExtraData) == 0 {
require.Len(t, actualMsg.ExtraData, 0)
} else {
require.Equal(
t, tc.Msg.ExtraData,
actualMsg.ExtraData,
)
}
// Compare the two messages to ensure equality.
require.Equal(t, tc.Msg, actualMsg)
}
}