lnwire: create common encoder/decoder for raw feature vectors

We'll need to pack feature vectors for route blinding, so we pull
the encoding/decoding out into separate functions (currently
contained in ChannelType). Though it's more lines of code, we keep
most of the ChannelType assertions so that we strictly enforce
use of the alias.
This commit is contained in:
Carla Kirk-Cohen 2023-02-01 10:51:01 -05:00
parent 42069ef2f8
commit 3cc50ced55
No known key found for this signature in database
GPG Key ID: 4CA7FE54A6213C91
2 changed files with 47 additions and 8 deletions

View File

@ -19,7 +19,7 @@ type ChannelType RawFeatureVector
// featureBitLen returns the length in bytes of the encoded feature bits.
func (c ChannelType) featureBitLen() uint64 {
fv := RawFeatureVector(c)
return uint64(fv.SerializeSize())
return fv.sizeFunc()
}
// Record returns a TLV record that can be used to encode/decode the channel
@ -34,25 +34,27 @@ func (c *ChannelType) Record() tlv.Record {
// channelTypeEncoder is a custom TLV encoder for the ChannelType record.
func channelTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*ChannelType); ok {
// Encode the feature bits as a byte slice without its length
// prepended, as that's already taken care of by the TLV record.
fv := RawFeatureVector(*v)
return fv.encode(w, fv.SerializeSize(), 8)
return rawFeatureEncoder(w, &fv, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType")
return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType")
}
// channelTypeDecoder is a custom TLV decoder for the ChannelType record.
func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*ChannelType); ok {
fv := NewRawFeatureVector()
if err := fv.decode(r, int(l), 8); err != nil {
if err := rawFeatureDecoder(r, fv, buf, l); err != nil {
return err
}
*v = ChannelType(*fv)
return nil
}
return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType")
return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType")
}

View File

@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
var (
@ -612,6 +614,41 @@ func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error {
return nil
}
// sizeFunc returns the length required to encode the feature vector.
func (fv *RawFeatureVector) sizeFunc() uint64 {
return uint64(fv.SerializeSize())
}
// Record returns a TLV record that can be used to encode/decode raw feature
// vectors. Note that the length of the feature vector is not included, because
// it is covered by the TLV record's length field.
func (fv *RawFeatureVector) Record(recordType tlv.Type) tlv.Record {
return tlv.MakeDynamicRecord(
recordType, fv, fv.sizeFunc, rawFeatureEncoder,
rawFeatureDecoder,
)
}
// rawFeatureEncoder is a custom TLV encoder for raw feature vectors.
func rawFeatureEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
if f, ok := val.(*RawFeatureVector); ok {
return f.encode(w, f.SerializeSize(), 8)
}
return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector")
}
// rawFeatureDecoder is a custom TLV decoder for raw feature vectors.
func rawFeatureDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if f, ok := val.(*RawFeatureVector); ok {
return f.decode(r, int(l), 8)
}
return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector")
}
// FeatureVector represents a set of enabled features. The set stores
// information on enabled flags and metadata about the feature names. A feature
// vector is serializable to a compact byte representation that is included in