From c882223ead5e8d84af33feb334bfc7a90a5bde6d Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 19 Sep 2023 20:43:14 +0200 Subject: [PATCH] lnwire+discovery: rename ShortChannelIDEncoding to QueryEncoding Since the the encoding can be used for multiple different fields, we rename it here to be more generic. --- discovery/syncer.go | 4 +- discovery/syncer_test.go | 2 +- lnwire/encoding.go | 17 ++++ lnwire/lnwire.go | 6 +- lnwire/query_short_chan_ids.go | 33 ++------ lnwire/query_short_chan_ids_test.go | 4 +- lnwire/reply_channel_range.go | 2 +- lnwire/reply_channel_range_test.go | 2 +- lnwire/timestamps.go | 123 ++++++++++++++++++++++++++++ lnwire/writer.go | 5 +- lnwire/writer_test.go | 4 +- 11 files changed, 162 insertions(+), 40 deletions(-) create mode 100644 lnwire/encoding.go create mode 100644 lnwire/timestamps.go diff --git a/discovery/syncer.go b/discovery/syncer.go index 722519fb0..3f03bcea8 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -185,7 +185,7 @@ var ( // encodingTypeToChunkSize maps an encoding type, to the max number of // short chan ID's using the encoding type that we can fit into a // single message safely. - encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{ + encodingTypeToChunkSize = map[lnwire.QueryEncoding]int32{ lnwire.EncodingSortedPlain: 8000, } @@ -232,7 +232,7 @@ type gossipSyncerCfg struct { // encodingType is the current encoding type we're aware of. Requests // with different encoding types will be rejected. - encodingType lnwire.ShortChanIDEncoding + encodingType lnwire.QueryEncoding // chunkSize is the max number of short chan IDs using the syncer's // encoding type that we can fit into a single message safely. diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index a7b514db8..32b7b7032 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -156,7 +156,7 @@ var _ ChannelGraphTimeSeries = (*mockChannelGraphTimeSeries)(nil) // ignored. If no flags are provided, both a channelGraphSyncer and replyHandler // will be spawned by default. func newTestSyncer(hID lnwire.ShortChannelID, - encodingType lnwire.ShortChanIDEncoding, chunkSize int32, + encodingType lnwire.QueryEncoding, chunkSize int32, flags ...bool) (chan []lnwire.Message, *GossipSyncer, *mockChannelGraphTimeSeries) { diff --git a/lnwire/encoding.go b/lnwire/encoding.go new file mode 100644 index 000000000..e04b2b01d --- /dev/null +++ b/lnwire/encoding.go @@ -0,0 +1,17 @@ +package lnwire + +// QueryEncoding is an enum-like type that represents exactly how a set data is +// encoded on the wire. +type QueryEncoding uint8 + +const ( + // EncodingSortedPlain signals that the set of data is encoded using the + // regular encoding, in a sorted order. + EncodingSortedPlain QueryEncoding = 0 + + // EncodingSortedZlib signals that the set of data is encoded by first + // sorting the set of channel ID's, as then compressing them using zlib. + // + // NOTE: this should no longer be used or accepted. + EncodingSortedZlib QueryEncoding = 1 +) diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 50a547e22..8ab082b0b 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -85,7 +85,7 @@ func WriteElement(w *bytes.Buffer, element interface{}) error { return err } - case ShortChanIDEncoding: + case QueryEncoding: var b [1]byte b[0] = uint8(e) if _, err := w.Write(b[:]); err != nil { @@ -509,12 +509,12 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = alias - case *ShortChanIDEncoding: + case *QueryEncoding: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { return err } - *e = ShortChanIDEncoding(b[0]) + *e = QueryEncoding(b[0]) case *uint8: var b [1]uint8 diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 323a936db..281c10eae 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -11,23 +11,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" ) -// ShortChanIDEncoding is an enum-like type that represents exactly how a set -// of short channel ID's is encoded on the wire. The set of encodings allows us -// to take advantage of the structure of a list of short channel ID's to -// achieving a high degree of compression. -type ShortChanIDEncoding uint8 - -const ( - // EncodingSortedPlain signals that the set of short channel ID's is - // encoded using the regular encoding, in a sorted order. - EncodingSortedPlain ShortChanIDEncoding = 0 - - // EncodingSortedZlib signals that the set of short channel ID's is - // encoded by first sorting the set of channel ID's, as then - // compressing them using zlib. - EncodingSortedZlib ShortChanIDEncoding = 1 -) - const ( // maxZlibBufSize is the max number of bytes that we'll accept from a // zlib decoding instance. We do this in order to limit the total @@ -56,7 +39,7 @@ var zlibDecodeMtx sync.Mutex // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we // came across an unknown short channel ID encoding, and therefore were unable // to continue parsing. -func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { +func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error { return fmt.Errorf("unknown short chan id encoding: %v", encoding) } @@ -76,7 +59,7 @@ type QueryShortChanIDs struct { // EncodingType is a signal to the receiver of the message that // indicates exactly how the set of short channel ID's that follow have // been encoded. - EncodingType ShortChanIDEncoding + EncodingType QueryEncoding // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID @@ -94,7 +77,7 @@ type QueryShortChanIDs struct { } // NewQueryShortChanIDs creates a new QueryShortChanIDs message. -func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, +func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, s []ShortChannelID) *QueryShortChanIDs { return &QueryShortChanIDs{ @@ -130,7 +113,7 @@ func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { // encoded. The first byte of the body details how the short chan ID's were // encoded. We'll use this type to govern exactly how we go about encoding the // set of short channel ID's. -func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { +func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) { // First, we'll attempt to read the number of bytes in the body of the // set of encoded short channel ID's. var numBytesResp uint16 @@ -150,7 +133,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // The first byte is the encoding type, so we'll extract that so we can // continue our parsing. - encodingType := ShortChanIDEncoding(queryBody[0]) + encodingType := QueryEncoding(queryBody[0]) // Before continuing, we'll snip off the first byte of the query body // as that was just the encoding type. @@ -309,7 +292,7 @@ func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { // encodeShortChanIDs encodes the passed short channel ID's into the passed // io.Writer, respecting the specified encoding type. -func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, +func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding, shortChanIDs []ShortChannelID, noSort bool) error { // For both of the current encoding types, the channel ID's are to be @@ -337,7 +320,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, // We'll then write out the encoding that that follows the // actual encoded short channel ID's. - err := WriteShortChanIDEncoding(w, encodingType) + err := WriteQueryEncoding(w, encodingType) if err != nil { return err } @@ -421,7 +404,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, if err := WriteUint16(w, uint16(numBytesBody)); err != nil { return err } - err := WriteShortChanIDEncoding(w, encodingType) + err := WriteQueryEncoding(w, encodingType) if err != nil { return err } diff --git a/lnwire/query_short_chan_ids_test.go b/lnwire/query_short_chan_ids_test.go index e42184044..996c9f744 100644 --- a/lnwire/query_short_chan_ids_test.go +++ b/lnwire/query_short_chan_ids_test.go @@ -7,7 +7,7 @@ import ( type unsortedSidTest struct { name string - encType ShortChanIDEncoding + encType QueryEncoding sids []ShortChannelID } @@ -79,7 +79,7 @@ func TestQueryShortChanIDsUnsorted(t *testing.T) { func TestQueryShortChanIDsZero(t *testing.T) { testCases := []struct { name string - encoding ShortChanIDEncoding + encoding QueryEncoding }{ { name: "plain", diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 9dc0fca9c..e02ccd474 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -33,7 +33,7 @@ type ReplyChannelRange struct { // EncodingType is a signal to the receiver of the message that // indicates exactly how the set of short channel ID's that follow have // been encoded. - EncodingType ShortChanIDEncoding + EncodingType QueryEncoding // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index ff3414958..85863ae53 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -44,7 +44,7 @@ func TestReplyChannelRangeEmpty(t *testing.T) { emptyChannelsTests := []struct { name string - encType ShortChanIDEncoding + encType QueryEncoding encodedHex string }{ { diff --git a/lnwire/timestamps.go b/lnwire/timestamps.go new file mode 100644 index 000000000..1d0d8c6a4 --- /dev/null +++ b/lnwire/timestamps.go @@ -0,0 +1,123 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // TimestampsRecordType is the TLV number of the timestamps TLV record + // in the reply_channel_range message. + TimestampsRecordType tlv.Type = 1 + + // timestampPairSize is the number of bytes required to encode two + // timestamps. Each timestamp is four bytes. + timestampPairSize = 8 +) + +// Timestamps is a type representing the timestamps TLV field used in the +// reply_channel_range message to communicate the timestamps info of the updates +// of the SCID list being communicated. +type Timestamps []ChanUpdateTimestamps + +// ChanUpdateTimestamps holds the timestamp info of the latest known channel +// updates corresponding to the two sides of a channel. +type ChanUpdateTimestamps struct { + Timestamp1 uint32 + Timestamp2 uint32 +} + +// Record constructs the tlv.Record from the Timestamps. +func (t *Timestamps) Record() tlv.Record { + return tlv.MakeDynamicRecord( + TimestampsRecordType, t, t.encodedLen, timeStampsEncoder, + timeStampsDecoder, + ) +} + +// encodedLen calculates the length of the encoded Timestamps. +func (t *Timestamps) encodedLen() uint64 { + return uint64(1 + timestampPairSize*(len(*t))) +} + +// timeStampsEncoder encodes the Timestamps and writes the encoded bytes to the +// given writer. +func timeStampsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*Timestamps); ok { + var buf bytes.Buffer + + // Add the encoding byte. + err := WriteQueryEncoding(&buf, EncodingSortedPlain) + if err != nil { + return err + } + + // For each timestamp, write 4 byte timestamp of node 1 and the + // 4 byte timestamp of node 2. + for _, timestamps := range *v { + err = WriteUint32(&buf, timestamps.Timestamp1) + if err != nil { + return err + } + + err = WriteUint32(&buf, timestamps.Timestamp2) + if err != nil { + return err + } + } + + _, err = w.Write(buf.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.Timestamps") +} + +// timeStampsDecoder attempts to read and reconstruct a Timestamps object from +// the given reader. +func timeStampsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*Timestamps); ok { + var encodingByte [1]byte + if _, err := r.Read(encodingByte[:]); err != nil { + return err + } + + encoding := QueryEncoding(encodingByte[0]) + if encoding != EncodingSortedPlain { + return fmt.Errorf("unsupported encoding: %x", encoding) + } + + // The number of timestamps bytes is equal to the passed length + // minus one since the first byte is used for the encoding type. + numTimestampBytes := l - 1 + + if numTimestampBytes%timestampPairSize != 0 { + return fmt.Errorf("whole number of timestamps not " + + "encoded") + } + + numTimestamps := int(numTimestampBytes) / timestampPairSize + timestamps := make(Timestamps, numTimestamps) + for i := 0; i < numTimestamps; i++ { + err := ReadElements( + r, ×tamps[i].Timestamp1, + ×tamps[i].Timestamp2, + ) + if err != nil { + return err + } + } + + *v = timestamps + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.Timestamps") +} diff --git a/lnwire/writer.go b/lnwire/writer.go index 671ebfdc0..fa6247de0 100644 --- a/lnwire/writer.go +++ b/lnwire/writer.go @@ -205,9 +205,8 @@ func WriteColorRGBA(buf *bytes.Buffer, e color.RGBA) error { return WriteUint8(buf, e.B) } -// WriteShortChanIDEncoding appends the ShortChanIDEncoding to the provided -// buffer. -func WriteShortChanIDEncoding(buf *bytes.Buffer, e ShortChanIDEncoding) error { +// WriteQueryEncoding appends the QueryEncoding to the provided buffer. +func WriteQueryEncoding(buf *bytes.Buffer, e QueryEncoding) error { return WriteUint8(buf, uint8(e)) } diff --git a/lnwire/writer_test.go b/lnwire/writer_test.go index ccdeabcf6..3e2550443 100644 --- a/lnwire/writer_test.go +++ b/lnwire/writer_test.go @@ -225,10 +225,10 @@ func TestWriteColorRGBA(t *testing.T) { func TestWriteShortChanIDEncoding(t *testing.T) { buf := new(bytes.Buffer) - data := ShortChanIDEncoding(1) + data := QueryEncoding(1) expectedBytes := []byte{1} - err := WriteShortChanIDEncoding(buf, data) + err := WriteQueryEncoding(buf, data) require.NoError(t, err) require.Equal(t, expectedBytes, buf.Bytes())