lnwire: add timestamps to ReplyChannelRange msg

This commit is contained in:
Elle Mouton 2023-09-19 20:53:58 +02:00
parent 4872010779
commit 49a0370dcd
No known key found for this signature in database
GPG key ID: D7D916376026F177
3 changed files with 368 additions and 20 deletions

View file

@ -1159,12 +1159,25 @@ func TestLightningWireProtocol(t *testing.T) {
req.EncodingType = EncodingSortedPlain
}
numChanIDs := rand.Int31n(5000)
numChanIDs := rand.Int31n(4000)
for i := int32(0); i < numChanIDs; i++ {
req.ShortChanIDs = append(req.ShortChanIDs,
NewShortChanIDFromInt(uint64(r.Int63())))
}
// With a 50/50 chance, add some timestamps.
if r.Int31()%2 == 0 {
for i := int32(0); i < numChanIDs; i++ {
timestamps := ChanUpdateTimestamps{
Timestamp1: rand.Uint32(),
Timestamp2: rand.Uint32(),
}
req.Timestamps = append(
req.Timestamps, timestamps,
)
}
}
v[0] = reflect.ValueOf(req)
},
MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) {

View file

@ -2,11 +2,13 @@ package lnwire
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// ReplyChannelRange is the response to the QueryChannelRange message. It
@ -39,6 +41,12 @@ type ReplyChannelRange struct {
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// Timestamps is an optional set of timestamps corresponding to the
// latest timestamps for the channel update messages corresponding to
// those referenced in the ShortChanIDs list. If this field is used,
// then the length must match the length of ShortChanIDs.
Timestamps Timestamps
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -53,7 +61,9 @@ type ReplyChannelRange struct {
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
func NewReplyChannelRange() *ReplyChannelRange {
return &ReplyChannelRange{}
return &ReplyChannelRange{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure ReplyChannelRange implements the
@ -80,7 +90,27 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
return err
}
return c.ExtraData.Decode(r)
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var timeStamps Timestamps
typeMap, err := tlvRecords.ExtractRecords(&timeStamps)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[TimestampsRecordType]; ok && val == nil {
c.Timestamps = timeStamps
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ReplyChannelRange into the passed io.Writer
@ -108,10 +138,48 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !c.noSort {
var scidPreSortIndex map[uint64]int
if len(c.Timestamps) != 0 {
// Sanity check that a timestamp was provided for each
// SCID.
if len(c.Timestamps) != len(c.ShortChanIDs) {
return fmt.Errorf("must provide a timestamp " +
"pair for each of the given SCIDs")
}
// Create a map from SCID value to the original index of
// the SCID in the unsorted list.
scidPreSortIndex = make(
map[uint64]int, len(c.ShortChanIDs),
)
for i, scid := range c.ShortChanIDs {
scidPreSortIndex[scid.ToUint64()] = i
}
// Sanity check that there were no duplicates in the
// SCID list.
if len(scidPreSortIndex) != len(c.ShortChanIDs) {
return fmt.Errorf("scid list should not " +
"contain duplicates")
}
}
// Now sort the SCIDs.
sort.Slice(c.ShortChanIDs, func(i, j int) bool {
return c.ShortChanIDs[i].ToUint64() <
c.ShortChanIDs[j].ToUint64()
})
if len(c.Timestamps) != 0 {
timestamps := make(Timestamps, len(c.Timestamps))
for i, scid := range c.ShortChanIDs {
timestamps[i] = []ChanUpdateTimestamps(
c.Timestamps,
)[scidPreSortIndex[scid.ToUint64()]]
}
c.Timestamps = timestamps
}
}
err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
@ -119,6 +187,15 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
if len(c.Timestamps) != 0 {
recordProducers = append(recordProducers, &c.Timestamps)
}
err = EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}

View file

@ -3,10 +3,9 @@ package lnwire
import (
"bytes"
"encoding/hex"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/require"
)
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
@ -78,29 +77,288 @@ func TestReplyChannelRangeEmpty(t *testing.T) {
// First decode the hex string in the test case into a
// new ReplyChannelRange message. It should be
// identical to the one created above.
var req2 ReplyChannelRange
req2 := NewReplyChannelRange()
b, _ := hex.DecodeString(test.encodedHex)
err := req2.Decode(bytes.NewReader(b), 0)
if err != nil {
t.Fatalf("unable to decode req: %v", err)
}
if !reflect.DeepEqual(req, req2) {
t.Fatalf("requests don't match: expected %v got %v",
spew.Sdump(req), spew.Sdump(req2))
}
require.NoError(t, err)
require.Equal(t, req, *req2)
// Next, we go in the reverse direction: encode the
// request created above, and assert that it matches
// the raw byte encoding.
var b2 bytes.Buffer
err = req.Encode(&b2, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
if !bytes.Equal(b, b2.Bytes()) {
t.Fatalf("encoded requests don't match: expected %x got %x",
b, b2.Bytes())
}
require.NoError(t, err)
require.Equal(t, b, b2.Bytes())
})
}
}
// TestReplyChannelRangeEncode tests that encoding a ReplyChannelRange message
// results in the correct sorting of the SCIDs and Timestamps.
func TestReplyChannelRangeEncode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
scids []ShortChannelID
timestamps Timestamps
expError string
expScids []ShortChannelID
expTimestamps Timestamps
}{
{
name: "scids only, sorted",
scids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
expScids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
},
{
name: "scids only, unsorted",
scids: []ShortChannelID{
{BlockHeight: 300},
{BlockHeight: 100},
{BlockHeight: 200},
},
expScids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
},
{
name: "scids and timestamps, sorted",
scids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
timestamps: Timestamps{
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
{Timestamp1: 5, Timestamp2: 6},
},
expScids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
expTimestamps: Timestamps{
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
{Timestamp1: 5, Timestamp2: 6},
},
},
{
name: "scids and timestamps, unsorted",
scids: []ShortChannelID{
{BlockHeight: 300},
{BlockHeight: 100},
{BlockHeight: 200},
},
timestamps: Timestamps{
{Timestamp1: 5, Timestamp2: 6},
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
},
expScids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
expTimestamps: Timestamps{
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
{Timestamp1: 5, Timestamp2: 6},
},
},
{
name: "scid and timestamp count does not match",
scids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 300},
},
timestamps: Timestamps{
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
},
expError: "got must provide a timestamp pair for " +
"each of the given SCIDs",
},
{
name: "duplicate scids",
scids: []ShortChannelID{
{BlockHeight: 100},
{BlockHeight: 200},
{BlockHeight: 200},
},
timestamps: Timestamps{
{Timestamp1: 1, Timestamp2: 2},
{Timestamp1: 3, Timestamp2: 4},
{Timestamp1: 5, Timestamp2: 6},
},
expError: "scid list should not contain duplicates",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
replyMsg := &ReplyChannelRange{
FirstBlockHeight: 1,
NumBlocks: 2,
Complete: 1,
EncodingType: EncodingSortedPlain,
ShortChanIDs: test.scids,
Timestamps: test.timestamps,
ExtraData: make([]byte, 0),
}
var buf bytes.Buffer
_, err := WriteMessage(&buf, replyMsg, 0)
if len(test.expError) != 0 {
require.ErrorContains(t, err, test.expError)
return
}
require.NoError(t, err)
r := bytes.NewBuffer(buf.Bytes())
msg, err := ReadMessage(r, 0)
require.NoError(t, err)
msg2, ok := msg.(*ReplyChannelRange)
require.True(t, ok)
require.Equal(t, test.expScids, msg2.ShortChanIDs)
require.Equal(t, test.expTimestamps, msg2.Timestamps)
})
}
}
// TestReplyChannelRangeDecode tests the decoding of some ReplyChannelRange
// test vectors.
func TestReplyChannelRangeDecode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
hex string
expEncoding QueryEncoding
expSCIDs []string
expTimestamps Timestamps
expError string
}{
{
name: "plain encoding",
hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" +
"12afca590b1a11466e2206000b8a06000005dc01001" +
"900000000000000008e0000000000003c6900000000" +
"0045a6c4",
expEncoding: EncodingSortedPlain,
expSCIDs: []string{
"0:0:142",
"0:0:15465",
"0:69:42692",
},
},
{
name: "zlib encoding",
hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" +
"12afca590b1a11466e2206000006400000006e010016" +
"01789c636000833e08659309a65878be010010a9023a",
expEncoding: EncodingSortedZlib,
expSCIDs: []string{
"0:0:142",
"0:0:15465",
"0:4:3318",
},
},
{
name: "plain encoding including timestamps",
hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb43601" +
"2afca590b1a11466e22060001ddde000005dc0100190" +
"0000000000000304300000000000778d600000000004" +
"6e1c1011900000282c1000e77c5000778ad00490ab00" +
"000b57800955bff031800000457000008ae00000d050" +
"000115c000015b300001a0a",
expEncoding: EncodingSortedPlain,
expSCIDs: []string{
"0:0:12355",
"0:7:30934",
"0:70:57793",
},
expTimestamps: Timestamps{
{
Timestamp1: 164545,
Timestamp2: 948165,
},
{
Timestamp1: 489645,
Timestamp2: 4786864,
},
{
Timestamp1: 46456,
Timestamp2: 9788415,
},
},
},
{
name: "unsupported encoding",
hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb" +
"436012afca590b1a11466e22060001ddde000005dc01" +
"001801789c63600001036730c55e710d4cbb3d3c0800" +
"17c303b1012201789c63606a3ac8c0577e9481bd622d" +
"8327d7060686ad150c53a3ff0300554707db03180000" +
"0457000008ae00000d050000115c000015b300001a0a",
expError: "unsupported encoding",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
b, err := hex.DecodeString(test.hex)
require.NoError(t, err)
r := bytes.NewBuffer(b)
msg, err := ReadMessage(r, 0)
if len(test.expError) != 0 {
require.ErrorContains(t, err, test.expError)
return
}
require.NoError(t, err)
replyMsg, ok := msg.(*ReplyChannelRange)
require.True(t, ok)
require.Equal(
t, test.expEncoding, replyMsg.EncodingType,
)
scids := make([]string, len(replyMsg.ShortChanIDs))
for i, id := range replyMsg.ShortChanIDs {
scids[i] = id.String()
}
require.Equal(t, scids, test.expSCIDs)
require.Equal(
t, test.expTimestamps, replyMsg.Timestamps,
)
})
}
}