mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
af50694643
Introduces a couple of new helper functions for both the ExtraOpaqueData and CustomRecords types along with new methods on the ExtraOpaqueData.
264 lines
7.2 KiB
Go
264 lines
7.2 KiB
Go
package lnwire
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
|
|
"github.com/lightningnetwork/lnd/fn"
|
|
"github.com/lightningnetwork/lnd/tlv"
|
|
)
|
|
|
|
// MinCustomRecordsTlvType is the lowest TLV type that may be used for a
// custom record, as defined in BOLT 01 (2^16). Any custom record type must be
// greater than or equal to this value.
const MinCustomRecordsTlvType = 65536
|
|
|
|
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType; Validate
// checks this invariant. Each value holds the bytes of the corresponding
// record's value.
type CustomRecords map[uint64][]byte
|
|
|
|
// NewCustomRecords creates a new CustomRecords instance from a
|
|
// tlv.TypeMap.
|
|
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
|
|
// Make comparisons in unit tests easy by returning nil if the map is
|
|
// empty.
|
|
if len(tlvMap) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
customRecords := make(CustomRecords, len(tlvMap))
|
|
for k, v := range tlvMap {
|
|
customRecords[uint64(k)] = v
|
|
}
|
|
|
|
// Validate the custom records.
|
|
err := customRecords.Validate()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("custom records from tlv map "+
|
|
"validation error: %w", err)
|
|
}
|
|
|
|
return customRecords, nil
|
|
}
|
|
|
|
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
|
|
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
|
|
return ParseCustomRecordsFrom(bytes.NewReader(b))
|
|
}
|
|
|
|
// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
|
|
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
|
|
typeMap, err := DecodeRecords(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
|
|
}
|
|
|
|
return NewCustomRecords(typeMap)
|
|
}
|
|
|
|
// Validate checks that all custom records are in the custom type range.
|
|
func (c CustomRecords) Validate() error {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
|
|
for key := range c {
|
|
if key < MinCustomRecordsTlvType {
|
|
return fmt.Errorf("custom records entry with TLV "+
|
|
"type below min: %d", MinCustomRecordsTlvType)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Copy returns a copy of the custom records.
|
|
func (c CustomRecords) Copy() CustomRecords {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
|
|
customRecords := make(CustomRecords, len(c))
|
|
for k, v := range c {
|
|
customRecords[k] = v
|
|
}
|
|
|
|
return customRecords
|
|
}
|
|
|
|
// ExtendRecordProducers extends the given records slice with the custom
|
|
// records. The resultant records slice will be sorted if the given records
|
|
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
|
|
func (c CustomRecords) ExtendRecordProducers(
|
|
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
|
|
|
|
// If the custom records are nil or empty, there is nothing to do.
|
|
if len(c) == 0 {
|
|
return producers, nil
|
|
}
|
|
|
|
// Validate the custom records.
|
|
err := c.Validate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Ensure that the existing records slice TLV types are not also present
|
|
// in the custom records. If they are, the resultant extended records
|
|
// slice would erroneously contain duplicate TLV types.
|
|
for _, rp := range producers {
|
|
record := rp.Record()
|
|
recordTlvType := uint64(record.Type())
|
|
|
|
_, foundDuplicateTlvType := c[recordTlvType]
|
|
if foundDuplicateTlvType {
|
|
return nil, fmt.Errorf("custom records contains a TLV "+
|
|
"type that is already present in the "+
|
|
"existing records: %d", recordTlvType)
|
|
}
|
|
}
|
|
|
|
// Convert the custom records map to a TLV record producer slice and
|
|
// append them to the exiting records slice.
|
|
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
|
|
producers = append(producers, customRecordProducers...)
|
|
|
|
// If the records slice which was given as an argument included TLV
|
|
// values greater than or equal to the minimum custom records TLV type
|
|
// we will sort the extended records slice to ensure that it is ordered
|
|
// correctly.
|
|
SortProducers(producers)
|
|
|
|
return producers, nil
|
|
}
|
|
|
|
// RecordProducers returns a slice of record producers for the custom records.
|
|
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
|
|
// If the custom records are nil or empty, return an empty slice.
|
|
if len(c) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Convert the custom records map to a TLV record producer slice.
|
|
records := tlv.MapToRecords(c)
|
|
|
|
return RecordsAsProducers(records)
|
|
}
|
|
|
|
// Serialize serializes the custom records into a byte slice.
|
|
func (c CustomRecords) Serialize() ([]byte, error) {
|
|
records := tlv.MapToRecords(c)
|
|
return EncodeRecords(records)
|
|
}
|
|
|
|
// SerializeTo serializes the custom records into the given writer.
|
|
func (c CustomRecords) SerializeTo(w io.Writer) error {
|
|
records := tlv.MapToRecords(c)
|
|
return EncodeRecordsTo(w, records)
|
|
}
|
|
|
|
// ProduceRecordsSorted converts a slice of record producers into a slice of
|
|
// records and then sorts it by type.
|
|
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
|
|
records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
|
|
return producer.Record()
|
|
}, recordProducers)
|
|
|
|
// Ensure that the set of records are sorted before we attempt to
|
|
// decode from the stream, to ensure they're canonical.
|
|
tlv.SortRecords(records)
|
|
|
|
return records
|
|
}
|
|
|
|
// SortProducers sorts the given record producers by their type.
|
|
func SortProducers(producers []tlv.RecordProducer) {
|
|
sort.Slice(producers, func(i, j int) bool {
|
|
recordI := producers[i].Record()
|
|
recordJ := producers[j].Record()
|
|
return recordI.Type() < recordJ.Type()
|
|
})
|
|
}
|
|
|
|
// TlvMapToRecords converts a TLV map into a slice of records.
|
|
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
|
|
tlvMapGeneric := make(map[uint64][]byte)
|
|
for k, v := range tlvMap {
|
|
tlvMapGeneric[uint64(k)] = v
|
|
}
|
|
|
|
return tlv.MapToRecords(tlvMapGeneric)
|
|
}
|
|
|
|
// RecordsAsProducers converts a slice of records into a slice of record
|
|
// producers.
|
|
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
|
|
return fn.Map(func(record tlv.Record) tlv.RecordProducer {
|
|
return &record
|
|
}, records)
|
|
}
|
|
|
|
// EncodeRecords encodes the given records into a byte slice.
|
|
func EncodeRecords(records []tlv.Record) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
if err := EncodeRecordsTo(&buf, records); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// EncodeRecordsTo encodes the given records into the given writer.
|
|
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
|
|
tlvStream, err := tlv.NewStream(records...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tlvStream.Encode(w)
|
|
}
|
|
|
|
// DecodeRecords decodes the given byte slice into the given records and returns
|
|
// the rest as a TLV type map.
|
|
func DecodeRecords(r io.Reader,
|
|
records ...tlv.Record) (tlv.TypeMap, error) {
|
|
|
|
tlvStream, err := tlv.NewStream(records...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tlvStream.DecodeWithParsedTypes(r)
|
|
}
|
|
|
|
// DecodeRecordsP2P decodes the given byte slice into the given records and
|
|
// returns the rest as a TLV type map. This function is identical to
|
|
// DecodeRecords except that the record size is capped at 65535.
|
|
func DecodeRecordsP2P(r *bytes.Reader,
|
|
records ...tlv.Record) (tlv.TypeMap, error) {
|
|
|
|
tlvStream, err := tlv.NewStream(records...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tlvStream.DecodeWithParsedTypesP2P(r)
|
|
}
|
|
|
|
// AssertUniqueTypes asserts that the given records have unique types.
|
|
func AssertUniqueTypes(r []tlv.Record) error {
|
|
seen := make(fn.Set[tlv.Type], len(r))
|
|
for _, record := range r {
|
|
t := record.Type()
|
|
if seen.Contains(t) {
|
|
return fmt.Errorf("duplicate record type: %d", t)
|
|
}
|
|
seen.Add(t)
|
|
}
|
|
|
|
return nil
|
|
}
|