tlv: modify RecordT type to ensure type param takes precedence

In this commit, we modify the RecordT type to allow callers to re-use
the Record definition of a different type, but use the new type param to
override the integer type used on the wire.

This will let use do things like encode a signature using the same
RecordProducer instance, but with a diff type in another context.

The upcoming use for this is allowing our `lnwire.Sig` type to be
encoded in the same message using distinct TLV integer types (new co-op
close protocol).
This commit is contained in:
Olaoluwa Osuntokun 2024-01-02 16:53:28 -08:00
parent 7f8b185f40
commit 66cf4396a2
No known key found for this signature in database
GPG key ID: 3BBD59E99B280306
2 changed files with 37 additions and 1 deletions

View file

@ -66,7 +66,18 @@ func (t *RecordT[T, V]) Record() Record {
)
}
return tlvRecord.Record()
// To enforce proper usage of the RecordT type, we'll make a wrapper
// record that uses the proper internal type value.
ogRecord := tlvRecord.Record()
return Record{
value: ogRecord.value,
typ: t.recordType.typeVal(),
staticSize: ogRecord.staticSize,
sizeFunc: ogRecord.sizeFunc,
encoder: ogRecord.encoder,
decoder: ogRecord.decoder,
}
}
// OptionalRecordT is a high-order type that represents an optional TLV record.

View file

@ -63,6 +63,10 @@ type coolWireMsg struct {
CsvDelay RecordT[TlvType1, wireCsv]
}
type coolWireMsgDiffContext struct {
CsvDelay RecordT[TlvType3, wireCsv]
}
// TestRecordTFromRecord tests that we can create a RecordT type from an
// existing record type and encode/decode as normal.
func TestRecordTFromRecord(t *testing.T) {
@ -91,3 +95,24 @@ func TestRecordTFromRecord(t *testing.T) {
require.Equal(t, wireMsg, wireMsg2)
}
// TestRecordTFromRecordTypeOverride tests that we can create a RecordT type
// from an existing record type and encode/decode as normal. In this variant,
// we make sure that we can use the type system to override the type of an
// original record.
func TestRecordTFromRecordTypeOverride(t *testing.T) {
t.Parallel()
// First, we'll make a new wire message. Instead of using the TLV type
// of 1 (hard coded in the Record() method defined above), we'll
// instead use TLvType3, as we want to use the same encode/decode, but
// in a context with a different integer type.
val := wireCsv(5)
wireMsg := coolWireMsgDiffContext{
CsvDelay: NewRecordT[TlvType3](val),
}
// If we extract the record, we should see that the type is now 3.
tlvRecord := wireMsg.CsvDelay.Record()
require.Equal(t, tlvRecord.Type(), Type(3))
}