lnwire: add CustomRecords field to CommitSig

In a future commit, we'll use the new field to ensure that if we add any additional records, they aren't over written by the TLV records that would be encoded.
This commit is contained in:
Olaoluwa Osuntokun 2024-08-29 20:52:29 -05:00
parent 681f44fd16
commit 7dd3a5b361
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
2 changed files with 84 additions and 16 deletions

View File

@ -2,6 +2,7 @@ package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
@ -45,6 +46,11 @@ type CommitSig struct {
// being signed for. In this case, the above Sig type MUST be blank.
PartialSig OptPartialSigWithNonceTLV
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the CommitSig
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -62,8 +68,8 @@ func NewCommitSig() *CommitSig {
// interface.
var _ Message = (*CommitSig)(nil)
// Decode deserializes a serialized CommitSig message stored in the
// passed io.Reader observing the specified protocol version.
// Decode deserializes a serialized CommitSig message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
@ -90,29 +96,57 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil {
c.PartialSig = tlv.SomeRecordT(partialSig)
// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(typeMap, c.PartialSig.TlvType())
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
// Parse through the remaining extra data map to separate the custom
// records, from the set of official records.
tlvTypes := newWireTlvMap(typeMap)
// Set the custom records field to the custom records specific TLV
// record map.
customRecords, err := NewCustomRecordsFromTlvTypeMap(
tlvTypes.customTypes,
)
if err != nil {
return err
}
c.CustomRecords = customRecords
// Set custom records to nil if we didn't parse anything out of it so
// that we can use assert.Equal in tests.
if len(customRecords) == 0 {
c.CustomRecords = nil
}
// Set extra data to nil if we didn't parse anything out of it so that
// we can use assert.Equal in tests.
if len(tlvTypes.officialTypes) == 0 {
c.ExtraData = nil
return nil
}
// Encode the remaining records back into the extra data field. These
// records are not in the custom records TLV type range and do not have
// associated fields in the CommitSig struct.
c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(
tlvTypes.officialTypes,
)
if err != nil {
return err
}
return nil
}
// Encode serializes the target CommitSig into the passed io.Writer
// observing the protocol version specified.
// Encode serializes the target CommitSig into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, &sig)
})
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
@ -125,7 +159,39 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
return WriteBytes(w, c.ExtraData)
// Construct a slice of all the records that we should include in the
// message extra data field. We will start by including any records
// from the extra data field.
msgExtraDataRecords, err := c.ExtraData.RecordProducers()
if err != nil {
return err
}
// Include the partial sig record if it is set.
c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
msgExtraDataRecords = append(msgExtraDataRecords, &sig)
})
// Include custom records in the extra data wire field if they are
// present. Ensure that the custom records are validated before
// encoding them.
if err := c.CustomRecords.Validate(); err != nil {
return fmt.Errorf("custom records validation error: %w", err)
}
// Extend the message extra data records slice with TLV records from
// the custom records field.
customTlvRecords := c.CustomRecords.RecordProducers()
msgExtraDataRecords = append(msgExtraDataRecords, customTlvRecords...)
// We will now construct the message extra data field that will be
// encoded into the byte writer.
var msgExtraData ExtraOpaqueData
if err := msgExtraData.PackRecords(msgExtraDataRecords...); err != nil {
return err
}
return WriteBytes(w, msgExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -915,6 +915,8 @@ func TestLightningWireProtocol(t *testing.T) {
}
}
req.CustomRecords = randCustomRecords(t, r)
// 50/50 chance to attach a partial sig.
if r.Int31()%2 == 0 {
req.PartialSig = somePartialSigWithNonce(t, r)