diff --git a/tlv/record_type.go b/tlv/record_type.go index 1b25a6ade..badb86de2 100644 --- a/tlv/record_type.go +++ b/tlv/record_type.go @@ -66,7 +66,18 @@ func (t *RecordT[T, V]) Record() Record { ) } - return tlvRecord.Record() + // To enforce proper usage of the RecordT type, we'll make a wrapper + // record that uses the proper internal type value. + ogRecord := tlvRecord.Record() + + return Record{ + value: ogRecord.value, + typ: t.recordType.typeVal(), + staticSize: ogRecord.staticSize, + sizeFunc: ogRecord.sizeFunc, + encoder: ogRecord.encoder, + decoder: ogRecord.decoder, + } } // OptionalRecordT is a high-order type that represents an optional TLV record. diff --git a/tlv/record_type_test.go b/tlv/record_type_test.go index 29d2fc7b6..0a4d38b60 100644 --- a/tlv/record_type_test.go +++ b/tlv/record_type_test.go @@ -63,6 +63,10 @@ type coolWireMsg struct { CsvDelay RecordT[TlvType1, wireCsv] } +type coolWireMsgDiffContext struct { + CsvDelay RecordT[TlvType3, wireCsv] +} + // TestRecordTFromRecord tests that we can create a RecordT type from an // existing record type and encode/decode as normal. func TestRecordTFromRecord(t *testing.T) { @@ -91,3 +95,24 @@ func TestRecordTFromRecord(t *testing.T) { require.Equal(t, wireMsg, wireMsg2) } + +// TestRecordTFromRecordTypeOverride tests that we can create a RecordT type +// from an existing record type and encode/decode as normal. In this variant, +// we make sure that we can use the type system to override the type of an +// original record. +func TestRecordTFromRecordTypeOverride(t *testing.T) { + t.Parallel() + + // First, we'll make a new wire message. Instead of using the TLV type + // of 1 (hard coded in the Record() method defined above), we'll + // instead use TLvType3, as we want to use the same encode/decode, but + // in a context with a different integer type. + val := wireCsv(5) + wireMsg := coolWireMsgDiffContext{ + CsvDelay: NewRecordT[TlvType3](val), + } + + // If we extract the record, we should see that the type is now 3. + tlvRecord := wireMsg.CsvDelay.Record() + require.Equal(t, tlvRecord.Type(), Type(3)) +}