From 0ad4ef373a9f8c0efa56252a34f8549c21ae48c2 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 20 Sep 2023 10:24:41 +0200 Subject: [PATCH] channeldb+discovery: fetch timestamps from DB if required --- channeldb/graph.go | 98 +++++++++++++++++++++++++++++--- channeldb/graph_test.go | 119 +++++++++++++++++++++++++++++++++++---- discovery/chan_series.go | 2 +- discovery/syncer.go | 18 ++++-- discovery/syncer_test.go | 22 +++++--- 5 files changed, 228 insertions(+), 31 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index d3c3a5508..5d5cb45df 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2143,6 +2143,23 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { return newChanIDs, nil } +// ChannelUpdateInfo couples the SCID of a channel with the timestamps of the +// latest received channel updates for the channel. +type ChannelUpdateInfo struct { + // ShortChannelID is the SCID identifier of the channel. + ShortChannelID lnwire.ShortChannelID + + // Node1UpdateTimestamp is the timestamp of the latest received update + // from the node 1 channel peer. This will be set to zero time if no + // update has yet been received from this node. + Node1UpdateTimestamp time.Time + + // Node2UpdateTimestamp is the timestamp of the latest received update + // from the node 2 channel peer. This will be set to zero time if no + // update has yet been received from this node. + Node2UpdateTimestamp time.Time +} + // BlockChannelRange represents a range of channels for a given block height. type BlockChannelRange struct { // Height is the height of the block all of the channels below were @@ -2151,17 +2168,20 @@ type BlockChannelRange struct { // Channels is the list of channels identified by their short ID // representation known to us that were included in the block height - // above. - Channels []lnwire.ShortChannelID + // above. The list may include channel update timestamp information if + // requested. + Channels []ChannelUpdateInfo } // FilterChannelRange returns the channel ID's of all known channels which were // mined in a block height within the passed range. The channel IDs are grouped // by their common block height. This method can be used to quickly share with a // peer the set of channels we know of within a particular range to catch them -// up after a period of time offline. +// up after a period of time offline. If withTimestamps is true then the +// timestamp info of the latest received channel update messages of the channel +// will be included in the response. func (c *ChannelGraph) FilterChannelRange(startHeight, - endHeight uint32) ([]BlockChannelRange, error) { + endHeight uint32, withTimestamps bool) ([]BlockChannelRange, error) { startChanID := &lnwire.ShortChannelID{ BlockHeight: startHeight, @@ -2180,7 +2200,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) - var channelsPerBlock map[uint32][]lnwire.ShortChannelID + var channelsPerBlock map[uint32][]ChannelUpdateInfo err := kvdb.View(c.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { @@ -2212,14 +2232,60 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // we'll add it to our returned set. rawCid := byteOrder.Uint64(k) cid := lnwire.NewShortChanIDFromInt(rawCid) + + chanInfo := ChannelUpdateInfo{ + ShortChannelID: cid, + } + + if !withTimestamps { + channelsPerBlock[cid.BlockHeight] = append( + channelsPerBlock[cid.BlockHeight], + chanInfo, + ) + + continue + } + + node1Key, node2Key := computeEdgePolicyKeys(&edgeInfo) + + rawPolicy := edges.Get(node1Key) + if len(rawPolicy) != 0 { + r := bytes.NewReader(rawPolicy) + + edge, err := deserializeChanEdgePolicyRaw(r) + if err != nil && !errors.Is( + err, ErrEdgePolicyOptionalFieldNotFound, + ) { + + return err + } + + chanInfo.Node1UpdateTimestamp = edge.LastUpdate + } + + rawPolicy = edges.Get(node2Key) + if len(rawPolicy) != 0 { + r := bytes.NewReader(rawPolicy) + + edge, err := deserializeChanEdgePolicyRaw(r) + if err != nil && !errors.Is( + err, ErrEdgePolicyOptionalFieldNotFound, + ) { + + return err + } + + chanInfo.Node2UpdateTimestamp = edge.LastUpdate + } + channelsPerBlock[cid.BlockHeight] = append( - channelsPerBlock[cid.BlockHeight], cid, + channelsPerBlock[cid.BlockHeight], chanInfo, ) } return nil }, func() { - channelsPerBlock = make(map[uint32][]lnwire.ShortChannelID) + channelsPerBlock = make(map[uint32][]ChannelUpdateInfo) }) switch { @@ -3118,6 +3184,24 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, return targetNode, err } +// computeEdgePolicyKeys is a helper function that can be used to compute the +// keys used to index the channel edge policy info for the two nodes of the +// edge. The keys for node 1 and node 2 are returned respectively. +func computeEdgePolicyKeys(info *models.ChannelEdgeInfo) ([]byte, []byte) { + var ( + node1Key [33 + 8]byte + node2Key [33 + 8]byte + ) + + copy(node1Key[:], info.NodeKey1Bytes[:]) + copy(node2Key[:], info.NodeKey2Bytes[:]) + + byteOrder.PutUint64(node1Key[33:], info.ChannelID) + byteOrder.PutUint64(node2Key[33:], info.ChannelID) + + return node1Key[:], node2Key[:] +} + // FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for // the channel identified by the funding outpoint. If the channel can't be // found, then ErrEdgeNotFound is returned. A struct which houses the general diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 49a2eb9ff..266d78b9f 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" ) var ( @@ -2045,7 +2046,7 @@ func TestFilterChannelRange(t *testing.T) { // If we try to filter a channel range before we have any channels // inserted, we should get an empty slice of results. - resp, err := graph.FilterChannelRange(10, 100) + resp, err := graph.FilterChannelRange(10, 100, false) require.NoError(t, err) require.Empty(t, resp) @@ -2054,7 +2055,41 @@ func TestFilterChannelRange(t *testing.T) { startHeight := uint32(100) endHeight := startHeight const numChans = 10 - channelRanges := make([]BlockChannelRange, 0, numChans/2) + + var ( + channelRanges = make( + []BlockChannelRange, 0, numChans/2, + ) + channelRangesWithTimestamps = make( + []BlockChannelRange, 0, numChans/2, + ) + ) + + updateTimeSeed := int64(1) + maybeAddPolicy := func(chanID uint64, node *LightningNode, + node2 bool) time.Time { + + var chanFlags lnwire.ChanUpdateChanFlags + if node2 { + chanFlags = lnwire.ChanUpdateDirection + } + + var updateTime time.Time + if rand.Int31n(2) == 0 { + updateTime = time.Unix(updateTimeSeed, 0) + err = graph.UpdateEdgePolicy(&models.ChannelEdgePolicy{ + ToNode: node.PubKeyBytes, + ChannelFlags: chanFlags, + ChannelID: chanID, + LastUpdate: updateTime, + }) + require.NoError(t, err) + } + updateTimeSeed++ + + return updateTime + } + for i := 0; i < numChans/2; i++ { chanHeight := endHeight channel1, chanID1 := createEdge( @@ -2068,9 +2103,38 @@ func TestFilterChannelRange(t *testing.T) { require.NoError(t, graph.AddChannelEdge(&channel2)) channelRanges = append(channelRanges, BlockChannelRange{ - Height: chanHeight, - Channels: []lnwire.ShortChannelID{chanID1, chanID2}, + Height: chanHeight, + Channels: []ChannelUpdateInfo{ + {ShortChannelID: chanID1}, + {ShortChannelID: chanID2}, + }, }) + + var ( + time1 = maybeAddPolicy(channel1.ChannelID, node1, false) + time2 = maybeAddPolicy(channel1.ChannelID, node2, true) + time3 = maybeAddPolicy(channel2.ChannelID, node1, false) + time4 = maybeAddPolicy(channel2.ChannelID, node2, true) + ) + + channelRangesWithTimestamps = append( + channelRangesWithTimestamps, BlockChannelRange{ + Height: chanHeight, + Channels: []ChannelUpdateInfo{ + { + ShortChannelID: chanID1, + Node1UpdateTimestamp: time1, + Node2UpdateTimestamp: time2, + }, + { + ShortChannelID: chanID2, + Node1UpdateTimestamp: time3, + Node2UpdateTimestamp: time4, + }, + }, + }, + ) + endHeight += 10 } @@ -2083,7 +2147,9 @@ func TestFilterChannelRange(t *testing.T) { startHeight uint32 endHeight uint32 - resp []BlockChannelRange + resp []BlockChannelRange + expStartIndex int + expEndIndex int }{ // If we query for the entire range, then we should get the same // set of short channel IDs back. @@ -2092,7 +2158,9 @@ func TestFilterChannelRange(t *testing.T) { startHeight: startHeight, endHeight: endHeight, - resp: channelRanges, + resp: channelRanges, + expStartIndex: 0, + expEndIndex: len(channelRanges), }, // If we query for a range of channels right before our range, @@ -2110,7 +2178,9 @@ func TestFilterChannelRange(t *testing.T) { startHeight: endHeight - 10, endHeight: endHeight - 10, - resp: channelRanges[4:], + resp: channelRanges[4:], + expStartIndex: 4, + expEndIndex: len(channelRanges), }, // If we query for just the first height, we should only get a @@ -2120,7 +2190,9 @@ func TestFilterChannelRange(t *testing.T) { startHeight: startHeight, endHeight: startHeight, - resp: channelRanges[:1], + resp: channelRanges[:1], + expStartIndex: 0, + expEndIndex: 1, }, { @@ -2128,20 +2200,45 @@ func TestFilterChannelRange(t *testing.T) { startHeight: startHeight + 10, endHeight: endHeight - 10, - resp: channelRanges[1:5], + resp: channelRanges[1:5], + expStartIndex: 1, + expEndIndex: 5, }, } + for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() + // First, do the query without requesting timestamps. resp, err := graph.FilterChannelRange( - test.startHeight, test.endHeight, + test.startHeight, test.endHeight, false, ) require.NoError(t, err) - require.Equal(t, test.resp, resp) + + expRes := channelRanges[test.expStartIndex:test.expEndIndex] //nolint:lll + + if len(expRes) == 0 { + require.Nil(t, resp) + } else { + require.Equal(t, expRes, resp) + } + + // Now, query the timestamps as well. + resp, err = graph.FilterChannelRange( + test.startHeight, test.endHeight, true, + ) + require.NoError(t, err) + + expRes = channelRangesWithTimestamps[test.expStartIndex:test.expEndIndex] //nolint:lll + + if len(expRes) == 0 { + require.Nil(t, resp) + } else { + require.Equal(t, expRes, resp) + } }) } } diff --git a/discovery/chan_series.go b/discovery/chan_series.go index c811017ab..647d741a5 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -229,7 +229,7 @@ func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash, func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash, startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) { - return c.graph.FilterChannelRange(startHeight, endHeight) + return c.graph.FilterChannelRange(startHeight, endHeight, false) } // FetchChanAnns returns a full set of channel announcements as well as their diff --git a/discovery/syncer.go b/discovery/syncer.go index 5940b23d4..35368ddf6 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -11,6 +11,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/time/rate" @@ -1044,7 +1045,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // this as there's a transport message size limit which we'll need to // adhere to. We also need to make sure all of our replies cover the // expected range of the query. - sendReplyForChunk := func(channelChunk []lnwire.ShortChannelID, + sendReplyForChunk := func(channelChunk []channeldb.ChannelUpdateInfo, firstHeight, lastHeight uint32, finalChunk bool) error { // The number of blocks contained in the current chunk (the @@ -1057,20 +1058,25 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro complete = 1 } + scids := make([]lnwire.ShortChannelID, len(channelChunk)) + for i, info := range channelChunk { + scids[i] = info.ShortChannelID + } + return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ ChainHash: query.ChainHash, NumBlocks: numBlocks, FirstBlockHeight: firstHeight, Complete: complete, EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ShortChanIDs: scids, }) } var ( firstHeight = query.FirstBlockHeight lastHeight uint32 - channelChunk []lnwire.ShortChannelID + channelChunk []channeldb.ChannelUpdateInfo ) for _, channelRange := range channelRanges { channels := channelRange.Channels @@ -1118,8 +1124,10 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // Sort the chunk once again if we had to shuffle it. if exceedsChunkSize { sort.Slice(channelChunk, func(i, j int) bool { - return channelChunk[i].ToUint64() < - channelChunk[j].ToUint64() + id1 := channelChunk[i].ShortChannelID.ToUint64() + id2 := channelChunk[j].ShortChannelID.ToUint64() + + return id1 < id2 }) } } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 9f9adcf4f..23fa2eb4e 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -103,10 +103,13 @@ func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} reply := <-m.filterRangeResp - channelsPerBlock := make(map[uint32][]lnwire.ShortChannelID) + channelsPerBlock := make(map[uint32][]channeldb.ChannelUpdateInfo) for _, cid := range reply { channelsPerBlock[cid.BlockHeight] = append( - channelsPerBlock[cid.BlockHeight], cid, + channelsPerBlock[cid.BlockHeight], + channeldb.ChannelUpdateInfo{ + ShortChannelID: cid, + }, ) } @@ -119,16 +122,21 @@ func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, return blocks[i] < blocks[j] }) - channelRanges := make([]channeldb.BlockChannelRange, 0, len(channelsPerBlock)) + channelRanges := make( + []channeldb.BlockChannelRange, 0, len(channelsPerBlock), + ) for _, block := range blocks { - channelRanges = append(channelRanges, channeldb.BlockChannelRange{ - Height: block, - Channels: channelsPerBlock[block], - }) + channelRanges = append( + channelRanges, channeldb.BlockChannelRange{ + Height: block, + Channels: channelsPerBlock[block], + }, + ) } return channelRanges, nil } + func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) {