mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
Merge pull request #8334 from lightningnetwork/tlv-record-enchancements
tlv: various enhancements to the new RecordT type
This commit is contained in:
commit
2f04ce7c6e
@ -20,10 +20,12 @@ package tlv
|
||||
|
||||
type tlvType{{ $index }} struct{}
|
||||
|
||||
func (t *tlvType{{ $index }}) typeVal() Type {
|
||||
func (t *tlvType{{ $index }}) TypeVal() Type {
|
||||
return {{ $index }}
|
||||
}
|
||||
|
||||
func (t *tlvType{{ $index }}) tlv() {}
|
||||
|
||||
type TlvType{{ $index }} = *tlvType{{ $index }}
|
||||
{{- end }}
|
||||
`
|
||||
|
@ -62,11 +62,29 @@ func (t *RecordT[T, V]) Record() Record {
|
||||
tlvRecord, ok := any(&t.Val).(RecordProducer)
|
||||
if !ok {
|
||||
return MakePrimitiveRecord(
|
||||
t.recordType.typeVal(), &t.Val,
|
||||
t.recordType.TypeVal(), &t.Val,
|
||||
)
|
||||
}
|
||||
|
||||
return tlvRecord.Record()
|
||||
// To enforce proper usage of the RecordT type, we'll make a wrapper
|
||||
// record that uses the proper internal type value.
|
||||
ogRecord := tlvRecord.Record()
|
||||
|
||||
return Record{
|
||||
value: ogRecord.value,
|
||||
typ: t.recordType.TypeVal(),
|
||||
staticSize: ogRecord.staticSize,
|
||||
sizeFunc: ogRecord.sizeFunc,
|
||||
encoder: ogRecord.encoder,
|
||||
decoder: ogRecord.decoder,
|
||||
}
|
||||
}
|
||||
|
||||
// TlvType returns the type of the record. This is the value used to identify
|
||||
// this type on the wire. This value is bound to the specified TlvType type
|
||||
// param.
|
||||
func (t *RecordT[T, V]) TlvType() Type {
|
||||
return t.recordType.TypeVal()
|
||||
}
|
||||
|
||||
// OptionalRecordT is a high-order type that represents an optional TLV record.
|
||||
@ -76,6 +94,29 @@ type OptionalRecordT[T TlvType, V any] struct {
|
||||
fn.Option[RecordT[T, V]]
|
||||
}
|
||||
|
||||
// TlvType returns the type of the record. This is the value used to identify
|
||||
// this type on the wire. This value is bound to the specified TlvType type
|
||||
// param.
|
||||
func (t *OptionalRecordT[T, V]) TlvType() Type {
|
||||
zeroRecord := ZeroRecordT[T, V]()
|
||||
return zeroRecord.TlvType()
|
||||
}
|
||||
|
||||
// WhenSomeV executes the given function if the optional record is present.
|
||||
// This operates on the inner most type, V, which is the value of the record.
|
||||
func (t *OptionalRecordT[T, V]) WhenSomeV(f func(V)) {
|
||||
t.Option.WhenSome(func(r RecordT[T, V]) {
|
||||
f(r.Val)
|
||||
})
|
||||
}
|
||||
|
||||
// SomeRecordT creates a new OptionalRecordT type from a given RecordT type.
|
||||
func SomeRecordT[T TlvType, V any](record RecordT[T, V]) OptionalRecordT[T, V] {
|
||||
return OptionalRecordT[T, V]{
|
||||
Option: fn.Some(record),
|
||||
}
|
||||
}
|
||||
|
||||
// ZeroRecordT returns a zero value of the RecordT type.
|
||||
func ZeroRecordT[T TlvType, V any]() RecordT[T, V] {
|
||||
var v V
|
||||
|
@ -63,6 +63,10 @@ type coolWireMsg struct {
|
||||
CsvDelay RecordT[TlvType1, wireCsv]
|
||||
}
|
||||
|
||||
type coolWireMsgDiffContext struct {
|
||||
CsvDelay RecordT[TlvType3, wireCsv]
|
||||
}
|
||||
|
||||
// TestRecordTFromRecord tests that we can create a RecordT type from an
|
||||
// existing record type and encode/decode as normal.
|
||||
func TestRecordTFromRecord(t *testing.T) {
|
||||
@ -91,3 +95,24 @@ func TestRecordTFromRecord(t *testing.T) {
|
||||
|
||||
require.Equal(t, wireMsg, wireMsg2)
|
||||
}
|
||||
|
||||
// TestRecordTFromRecordTypeOverride tests that we can create a RecordT type
|
||||
// from an existing record type and encode/decode as normal. In this variant,
|
||||
// we make sure that we can use the type system to override the type of an
|
||||
// original record.
|
||||
func TestRecordTFromRecordTypeOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll make a new wire message. Instead of using the TLV type
|
||||
// of 1 (hard coded in the Record() method defined above), we'll
|
||||
// instead use TLvType3, as we want to use the same encode/decode, but
|
||||
// in a context with a different integer type.
|
||||
val := wireCsv(5)
|
||||
wireMsg := coolWireMsgDiffContext{
|
||||
CsvDelay: NewRecordT[TlvType3](val),
|
||||
}
|
||||
|
||||
// If we extract the record, we should see that the type is now 3.
|
||||
tlvRecord := wireMsg.CsvDelay.Record()
|
||||
require.Equal(t, tlvRecord.Type(), Type(3))
|
||||
}
|
||||
|
@ -5,7 +5,13 @@ import "fmt"
|
||||
// TlvType is an interface used to enable binding the integer type of a TLV
|
||||
// record to the type at compile time.
|
||||
type TlvType interface {
|
||||
typeVal() Type
|
||||
// TypeVal returns the integer TLV type that this TlvType struct
|
||||
// instance maps to.
|
||||
TypeVal() Type
|
||||
|
||||
// tlv is an internal method to make this a "sealed" interface, meaning
|
||||
// only this package can declare new instances.
|
||||
tlv()
|
||||
}
|
||||
|
||||
//go:generate go run internal/gen/gen_tlv_types.go -o tlv_types_generated.go
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user