diff --git a/tlv/record_type.go b/tlv/record_type.go index 248ebc415..ec893a06b 100644 --- a/tlv/record_type.go +++ b/tlv/record_type.go @@ -1,6 +1,8 @@ package tlv import ( + "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn" "golang.org/x/exp/constraints" @@ -115,6 +117,29 @@ func (t *OptionalRecordT[T, V]) WhenSomeV(f func(V)) { }) } +// UnwrapOrFailV is used to extract a value from an option within a test +// context. If the option is None, then the test fails. This gives the +// underlying value of the record, rather then the record itself. +func (o *OptionalRecordT[T, V]) UnwrapOrFailV(t *testing.T) V { + inner := o.Option.UnwrapOrFail(t) + + return inner.Val +} + +// UnwrapOrErr is used to extract a value from an option, if the option is +// empty, then the specified error is returned directly. This gives the +// underlying value of the record, instead of the record itself. +func (o *OptionalRecordT[T, V]) UnwrapOrErrV(err error) (V, error) { + var zero V + + inner, err := o.Option.UnwrapOrErr(err) + if err != nil { + return zero, err + } + + return inner.Val, nil +} + // SomeRecordT creates a new OptionalRecordT type from a given RecordT type. func SomeRecordT[T TlvType, V any](record RecordT[T, V]) OptionalRecordT[T, V] { return OptionalRecordT[T, V]{