From 49a0370dcd1b0b31f13a91216e74dcd6e64c2ff8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 19 Sep 2023 20:53:58 +0200 Subject: [PATCH] lnwire: add timestamps to ReplyChannelRange msg --- lnwire/lnwire_test.go | 15 +- lnwire/reply_channel_range.go | 81 +++++++- lnwire/reply_channel_range_test.go | 292 +++++++++++++++++++++++++++-- 3 files changed, 368 insertions(+), 20 deletions(-) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 9b119c851..c3248c3c9 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1159,12 +1159,25 @@ func TestLightningWireProtocol(t *testing.T) { req.EncodingType = EncodingSortedPlain } - numChanIDs := rand.Int31n(5000) + numChanIDs := rand.Int31n(4000) for i := int32(0); i < numChanIDs; i++ { req.ShortChanIDs = append(req.ShortChanIDs, NewShortChanIDFromInt(uint64(r.Int63()))) } + // With a 50/50 chance, add some timestamps. + if r.Int31()%2 == 0 { + for i := int32(0); i < numChanIDs; i++ { + timestamps := ChanUpdateTimestamps{ + Timestamp1: rand.Uint32(), + Timestamp2: rand.Uint32(), + } + req.Timestamps = append( + req.Timestamps, timestamps, + ) + } + } + v[0] = reflect.ValueOf(req) }, MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 2a461cded..ea45a5843 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -2,11 +2,13 @@ package lnwire import ( "bytes" + "fmt" "io" "math" "sort" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" ) // ReplyChannelRange is the response to the QueryChannelRange message. It @@ -39,6 +41,12 @@ type ReplyChannelRange struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // Timestamps is an optional set of timestamps corresponding to the + // latest timestamps for the channel update messages corresponding to + // those referenced in the ShortChanIDs list. If this field is used, + // then the length must match the length of ShortChanIDs. + Timestamps Timestamps + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -53,7 +61,9 @@ type ReplyChannelRange struct { // NewReplyChannelRange creates a new empty ReplyChannelRange message. func NewReplyChannelRange() *ReplyChannelRange { - return &ReplyChannelRange{} + return &ReplyChannelRange{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure ReplyChannelRange implements the @@ -80,7 +90,27 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { return err } - return c.ExtraData.Decode(r) + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var timeStamps Timestamps + typeMap, err := tlvRecords.ExtractRecords(&timeStamps) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[TimestampsRecordType]; ok && val == nil { + c.Timestamps = timeStamps + } + + if len(tlvRecords) != 0 { + c.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -108,10 +138,48 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { // sorted in place, so we'll do that now. The sorting is applied unless // we were specifically requested not to for testing purposes. if !c.noSort { + var scidPreSortIndex map[uint64]int + if len(c.Timestamps) != 0 { + // Sanity check that a timestamp was provided for each + // SCID. + if len(c.Timestamps) != len(c.ShortChanIDs) { + return fmt.Errorf("must provide a timestamp " + + "pair for each of the given SCIDs") + } + + // Create a map from SCID value to the original index of + // the SCID in the unsorted list. + scidPreSortIndex = make( + map[uint64]int, len(c.ShortChanIDs), + ) + for i, scid := range c.ShortChanIDs { + scidPreSortIndex[scid.ToUint64()] = i + } + + // Sanity check that there were no duplicates in the + // SCID list. + if len(scidPreSortIndex) != len(c.ShortChanIDs) { + return fmt.Errorf("scid list should not " + + "contain duplicates") + } + } + + // Now sort the SCIDs. sort.Slice(c.ShortChanIDs, func(i, j int) bool { return c.ShortChanIDs[i].ToUint64() < c.ShortChanIDs[j].ToUint64() }) + + if len(c.Timestamps) != 0 { + timestamps := make(Timestamps, len(c.Timestamps)) + + for i, scid := range c.ShortChanIDs { + timestamps[i] = []ChanUpdateTimestamps( + c.Timestamps, + )[scidPreSortIndex[scid.ToUint64()]] + } + c.Timestamps = timestamps + } } err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs) @@ -119,6 +187,15 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { return err } + recordProducers := make([]tlv.RecordProducer, 0, 1) + if len(c.Timestamps) != 0 { + recordProducers = append(recordProducers, &c.Timestamps) + } + err = EncodeMessageExtraData(&c.ExtraData, recordProducers...) + if err != nil { + return err + } + return WriteBytes(w, c.ExtraData) } diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index 85863ae53..12955cfd9 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -3,10 +3,9 @@ package lnwire import ( "bytes" "encoding/hex" - "reflect" "testing" - "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" ) // TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request @@ -78,29 +77,288 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // First decode the hex string in the test case into a // new ReplyChannelRange message. It should be // identical to the one created above. - var req2 ReplyChannelRange + req2 := NewReplyChannelRange() b, _ := hex.DecodeString(test.encodedHex) err := req2.Decode(bytes.NewReader(b), 0) - if err != nil { - t.Fatalf("unable to decode req: %v", err) - } - if !reflect.DeepEqual(req, req2) { - t.Fatalf("requests don't match: expected %v got %v", - spew.Sdump(req), spew.Sdump(req2)) - } + require.NoError(t, err) + require.Equal(t, req, *req2) // Next, we go in the reverse direction: encode the // request created above, and assert that it matches // the raw byte encoding. var b2 bytes.Buffer err = req.Encode(&b2, 0) - if err != nil { - t.Fatalf("unable to encode req: %v", err) - } - if !bytes.Equal(b, b2.Bytes()) { - t.Fatalf("encoded requests don't match: expected %x got %x", - b, b2.Bytes()) - } + require.NoError(t, err) + require.Equal(t, b, b2.Bytes()) + }) + } +} + +// TestReplyChannelRangeEncode tests that encoding a ReplyChannelRange message +// results in the correct sorting of the SCIDs and Timestamps. +func TestReplyChannelRangeEncode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scids []ShortChannelID + timestamps Timestamps + expError string + expScids []ShortChannelID + expTimestamps Timestamps + }{ + { + name: "scids only, sorted", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + }, + { + name: "scids only, unsorted", + scids: []ShortChannelID{ + {BlockHeight: 300}, + {BlockHeight: 100}, + {BlockHeight: 200}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + }, + { + name: "scids and timestamps, sorted", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expTimestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + }, + { + name: "scids and timestamps, unsorted", + scids: []ShortChannelID{ + {BlockHeight: 300}, + {BlockHeight: 100}, + {BlockHeight: 200}, + }, + timestamps: Timestamps{ + {Timestamp1: 5, Timestamp2: 6}, + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expTimestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + }, + { + name: "scid and timestamp count does not match", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + }, + expError: "got must provide a timestamp pair for " + + "each of the given SCIDs", + }, + { + name: "duplicate scids", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 200}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + expError: "scid list should not contain duplicates", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + replyMsg := &ReplyChannelRange{ + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: EncodingSortedPlain, + ShortChanIDs: test.scids, + Timestamps: test.timestamps, + ExtraData: make([]byte, 0), + } + + var buf bytes.Buffer + _, err := WriteMessage(&buf, replyMsg, 0) + if len(test.expError) != 0 { + require.ErrorContains(t, err, test.expError) + + return + } + + require.NoError(t, err) + + r := bytes.NewBuffer(buf.Bytes()) + msg, err := ReadMessage(r, 0) + require.NoError(t, err) + + msg2, ok := msg.(*ReplyChannelRange) + require.True(t, ok) + + require.Equal(t, test.expScids, msg2.ShortChanIDs) + require.Equal(t, test.expTimestamps, msg2.Timestamps) + }) + } +} + +// TestReplyChannelRangeDecode tests the decoding of some ReplyChannelRange +// test vectors. +func TestReplyChannelRangeDecode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hex string + expEncoding QueryEncoding + expSCIDs []string + expTimestamps Timestamps + expError string + }{ + { + name: "plain encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + + "12afca590b1a11466e2206000b8a06000005dc01001" + + "900000000000000008e0000000000003c6900000000" + + "0045a6c4", + expEncoding: EncodingSortedPlain, + expSCIDs: []string{ + "0:0:142", + "0:0:15465", + "0:69:42692", + }, + }, + { + name: "zlib encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + + "12afca590b1a11466e2206000006400000006e010016" + + "01789c636000833e08659309a65878be010010a9023a", + expEncoding: EncodingSortedZlib, + expSCIDs: []string{ + "0:0:142", + "0:0:15465", + "0:4:3318", + }, + }, + { + name: "plain encoding including timestamps", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb43601" + + "2afca590b1a11466e22060001ddde000005dc0100190" + + "0000000000000304300000000000778d600000000004" + + "6e1c1011900000282c1000e77c5000778ad00490ab00" + + "000b57800955bff031800000457000008ae00000d050" + + "000115c000015b300001a0a", + expEncoding: EncodingSortedPlain, + expSCIDs: []string{ + "0:0:12355", + "0:7:30934", + "0:70:57793", + }, + expTimestamps: Timestamps{ + { + Timestamp1: 164545, + Timestamp2: 948165, + }, + { + Timestamp1: 489645, + Timestamp2: 4786864, + }, + { + Timestamp1: 46456, + Timestamp2: 9788415, + }, + }, + }, + { + name: "unsupported encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb" + + "436012afca590b1a11466e22060001ddde000005dc01" + + "001801789c63600001036730c55e710d4cbb3d3c0800" + + "17c303b1012201789c63606a3ac8c0577e9481bd622d" + + "8327d7060686ad150c53a3ff0300554707db03180000" + + "0457000008ae00000d050000115c000015b300001a0a", + expError: "unsupported encoding", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + b, err := hex.DecodeString(test.hex) + require.NoError(t, err) + + r := bytes.NewBuffer(b) + + msg, err := ReadMessage(r, 0) + if len(test.expError) != 0 { + require.ErrorContains(t, err, test.expError) + + return + } + require.NoError(t, err) + + replyMsg, ok := msg.(*ReplyChannelRange) + require.True(t, ok) + require.Equal( + t, test.expEncoding, replyMsg.EncodingType, + ) + + scids := make([]string, len(replyMsg.ShortChanIDs)) + for i, id := range replyMsg.ShortChanIDs { + scids[i] = id.String() + } + require.Equal(t, scids, test.expSCIDs) + + require.Equal( + t, test.expTimestamps, replyMsg.Timestamps, + ) }) } }