From a242ad5acb6b46e82ef839be84b0695b2de089a7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 20 Sep 2023 11:13:35 +0200 Subject: [PATCH] channeldb+discovery: use timestamps to maybe revive zombie --- channeldb/graph.go | 48 +++++++++++++++--- channeldb/graph_test.go | 106 +++++++++++++++++++++++++++++++-------- discovery/chan_series.go | 17 +++---- discovery/syncer.go | 73 +++++++++++++++++++++++---- discovery/syncer_test.go | 21 +++++--- 5 files changed, 212 insertions(+), 53 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 0a1b233e8..1ab2897bb 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2086,10 +2086,12 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, // words, we perform a set difference of our set of chan ID's and the ones // passed in. This method can be used by callers to determine the set of // channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { +func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ([]uint64, error) { + var newChanIDs []uint64 - err := kvdb.View(c.db, func(tx kvdb.RTx) error { + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2107,8 +2109,9 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // We'll run through the set of chanIDs and collate only the // set of channel that are unable to be found within our db. var cidBytes [8]byte - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) + for _, info := range chansInfo { + scid := info.ShortChannelID.ToUint64() + byteOrder.PutUint64(cidBytes[:], scid) // If the edge is already known, skip it. if v := edgeIndex.Get(cidBytes[:]); v != nil { @@ -2117,13 +2120,37 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // If the edge is a known zombie, skip it. if zombieIndex != nil { - isZombie, _, _ := isZombieEdge(zombieIndex, cid) - if isZombie { + isZombie, _, _ := isZombieEdge( + zombieIndex, scid, + ) + + isStillZombie := isZombieChan( + info.Node1UpdateTimestamp, + info.Node2UpdateTimestamp, + ) + + switch { + // If the edge is a known zombie and if we + // would still consider it a zombie given the + // latest update timestamps, then we skip this + // channel. + case isZombie && isStillZombie: continue + + // Otherwise, if we have marked it as a zombie + // but the latest update timestamps could bring + // it back from the dead, then we mark it alive, + // and we let it be added to the set of IDs to + // query our peer for. + case isZombie && !isStillZombie: + err := c.markEdgeLive(tx, scid) + if err != nil { + return err + } } } - newChanIDs = append(newChanIDs, cid) + newChanIDs = append(newChanIDs, scid) } return nil @@ -2134,7 +2161,12 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // If we don't know of any edges yet, then we'll return the entire set // of chan IDs specified. case err == ErrGraphNoEdgesFound: - return chanIDs, nil + ogChanIDs := make([]uint64, len(chansInfo)) + for i, info := range chansInfo { + ogChanIDs[i] = info.ShortChannelID.ToUint64() + } + + return ogChanIDs, nil case err != nil: return nil, err diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 266d78b9f..cbfabb3cc 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1928,14 +1928,32 @@ func TestFilterKnownChanIDs(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") + isZombieUpdate := func(updateTime1 time.Time, + updateTime2 time.Time) bool { + + return true + } + + var ( + scid1 = lnwire.ShortChannelID{BlockHeight: 1} + scid2 = lnwire.ShortChannelID{BlockHeight: 2} + scid3 = lnwire.ShortChannelID{BlockHeight: 3} + ) + // If we try to filter out a set of channel ID's before we even know of // any channels, then we should get the entire set back. - preChanIDs := []uint64{1, 2, 3, 4} - filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) - require.NoError(t, err, "unable to filter chan IDs") - if !reflect.DeepEqual(preChanIDs, filteredIDs) { - t.Fatalf("chan IDs shouldn't have been filtered!") + preChanIDs := []ChannelUpdateInfo{ + {ShortChannelID: scid1}, + {ShortChannelID: scid2}, + {ShortChannelID: scid3}, } + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs, isZombieUpdate) + require.NoError(t, err, "unable to filter chan IDs") + require.EqualValues(t, []uint64{ + scid1.ToUint64(), + scid2.ToUint64(), + scid3.ToUint64(), + }, filteredIDs) // We'll start by creating two nodes which will seed our test graph. node1, err := createTestVertex(graph.db) @@ -1952,7 +1970,7 @@ func TestFilterKnownChanIDs(t *testing.T) { // Next, we'll add 5 channel ID's to the graph, each of them having a // block height 10 blocks after the previous. const numChans = 5 - chanIDs := make([]uint64, 0, numChans) + chanIDs := make([]ChannelUpdateInfo, 0, numChans) for i := 0; i < numChans; i++ { channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, @@ -1962,11 +1980,13 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Fatalf("unable to create channel edge: %v", err) } - chanIDs = append(chanIDs, chanID.ToUint64()) + chanIDs = append(chanIDs, ChannelUpdateInfo{ + ShortChannelID: chanID, + }) } const numZombies = 5 - zombieIDs := make([]uint64, 0, numZombies) + zombieIDs := make([]ChannelUpdateInfo, 0, numZombies) for i := 0; i < numZombies; i++ { channel, chanID := createEdge( uint32(i*10+1), 0, 0, 0, node1, node2, @@ -1979,13 +1999,15 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Fatalf("unable to mark edge zombie: %v", err) } - zombieIDs = append(zombieIDs, chanID.ToUint64()) + zombieIDs = append( + zombieIDs, ChannelUpdateInfo{ShortChannelID: chanID}, + ) } queryCases := []struct { - queryIDs []uint64 + queryIDs []ChannelUpdateInfo - resp []uint64 + resp []ChannelUpdateInfo }{ // If we attempt to filter out all chanIDs we know of, the // response should be the empty set. @@ -2001,28 +2023,70 @@ func TestFilterKnownChanIDs(t *testing.T) { // If we query for a set of ID's that we didn't insert, we // should get the same set back. { - queryIDs: []uint64{99, 100}, - resp: []uint64{99, 100}, + queryIDs: []ChannelUpdateInfo{ + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }}, + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 100, + }}, + }, + resp: []ChannelUpdateInfo{ + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }}, + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 100, + }}, + }, }, // If we query for a super-set of our the chan ID's inserted, // we should only get those new chanIDs back. { - queryIDs: append(chanIDs, []uint64{99, 101}...), - resp: []uint64{99, 101}, + queryIDs: append(chanIDs, []ChannelUpdateInfo{ + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }, + }, + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 101, + }, + }, + }...), + resp: []ChannelUpdateInfo{ + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }, + }, + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 101, + }, + }, + }, }, } for _, queryCase := range queryCases { - resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) + resp, err := graph.FilterKnownChanIDs( + queryCase.queryIDs, isZombieUpdate, + ) + require.NoError(t, err) + + expectedSCIDs := make([]uint64, len(queryCase.resp)) + for i, info := range queryCase.resp { + expectedSCIDs[i] = info.ShortChannelID.ToUint64() } - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), - spew.Sdump(resp)) + if len(expectedSCIDs) == 0 { + expectedSCIDs = nil } + + require.EqualValues(t, expectedSCIDs, resp) } } diff --git a/discovery/chan_series.go b/discovery/chan_series.go index d0dc091d7..34e6d4a9d 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -37,7 +37,9 @@ type ChannelGraphTimeSeries interface { // ID's represents the ID's that we don't know of which were in the // passed superSet. FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) // FilterChannelRange returns the set of channels that we created // between the start height and the end height. The channel IDs are @@ -197,15 +199,12 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // represents the ID's that we don't know of which were in the passed superSet. // // NOTE: This is part of the ChannelGraphTimeSeries interface. -func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { +func (c *ChanSeries) FilterKnownChanIDs(_ chainhash.Hash, + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) { - chanIDs := make([]uint64, 0, len(superSet)) - for _, chanID := range superSet { - chanIDs = append(chanIDs, chanID.ToUint64()) - } - - newChanIDs, err := c.graph.FilterKnownChanIDs(chanIDs) + newChanIDs, err := c.graph.FilterKnownChanIDs(superSet, isZombieChan) if err != nil { return nil, err } diff --git a/discovery/syncer.go b/discovery/syncer.go index 937b0abbb..dc1565371 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -180,6 +180,9 @@ const ( // requestBatchSize is the maximum number of channels we will query the // remote peer for in a QueryShortChanIDs message. requestBatchSize = 500 + + // filterSemaSize is the capacity of gossipFilterSema. + filterSemaSize = 5 ) var ( @@ -372,7 +375,7 @@ type GossipSyncer struct { // bufferedChanRangeReplies is used in the waitingQueryChanReply to // buffer all the chunked response to our query. - bufferedChanRangeReplies []lnwire.ShortChannelID + bufferedChanRangeReplies []channeldb.ChannelUpdateInfo // numChanRangeRepliesRcvd is used to track the number of replies // received as part of a QueryChannelRange. This field is primarily used @@ -398,6 +401,8 @@ type GossipSyncer struct { sync.Mutex + gossipFilterSema chan struct{} + quit chan struct{} wg sync.WaitGroup } @@ -426,6 +431,11 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { interval, cfg.maxUndelayedQueryReplies, ) + filterSema := make(chan struct{}, filterSemaSize) + for i := 0; i < filterSemaSize; i++ { + filterSema <- struct{}{} + } + return &GossipSyncer{ cfg: cfg, rateLimiter: rateLimiter, @@ -433,6 +443,7 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { historicalSyncReqs: make(chan *historicalSyncReq), gossipMsgs: make(chan lnwire.Message, 100), queryMsgs: make(chan lnwire.Message, 100), + gossipFilterSema: filterSema, quit: make(chan struct{}), } } @@ -819,9 +830,31 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro } g.prevReplyChannelRange = msg - g.bufferedChanRangeReplies = append( - g.bufferedChanRangeReplies, msg.ShortChanIDs..., - ) + if len(msg.Timestamps) != 0 && + len(msg.Timestamps) != len(msg.ShortChanIDs) { + + return fmt.Errorf("number of timestamps not equal to " + + "number of SCIDs") + } + + for i, scid := range msg.ShortChanIDs { + info := channeldb.ChannelUpdateInfo{ + ShortChannelID: scid, + } + + if len(msg.Timestamps) != 0 { + t1 := time.Unix(int64(msg.Timestamps[i].Timestamp1), 0) + info.Node1UpdateTimestamp = t1 + + t2 := time.Unix(int64(msg.Timestamps[i].Timestamp2), 0) + info.Node2UpdateTimestamp = t2 + } + + g.bufferedChanRangeReplies = append( + g.bufferedChanRangeReplies, info, + ) + } + switch g.cfg.encodingType { case lnwire.EncodingSortedPlain: g.numChanRangeRepliesRcvd++ @@ -868,6 +901,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // which channels they know of that we don't. newChans, err := g.cfg.channelSeries.FilterKnownChanIDs( g.cfg.chainHash, g.bufferedChanRangeReplies, + g.cfg.isStillZombieChannel, ) if err != nil { return fmt.Errorf("unable to filter chan ids: %v", err) @@ -1107,10 +1141,20 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro lastHeight uint32 channelChunk []channeldb.ChannelUpdateInfo ) + + // chunkSize is the maximum number of SCIDs that we can safely put in a + // single message. If we also need to include timestamps though, then + // this number is halved since encoding two timestamps takes the same + // number of bytes as encoding an SCID. + chunkSize := g.cfg.chunkSize + if withTimestamps { + chunkSize /= 2 + } + for _, channelRange := range channelRanges { channels := channelRange.Channels numChannels := int32(len(channels)) - numLeftToAdd := g.cfg.chunkSize - int32(len(channelChunk)) + numLeftToAdd := chunkSize - int32(len(channelChunk)) // Include the current block in the ongoing chunk if it can fit // and move on to the next block. @@ -1126,6 +1170,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // to. log.Infof("GossipSyncer(%x): sending range chunk of size=%v", g.cfg.peerPub[:], len(channelChunk)) + lastHeight = channelRange.Height - 1 err := sendReplyForChunk( channelChunk, firstHeight, lastHeight, false, @@ -1140,15 +1185,15 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // this isn't an issue since we'll randomly shuffle them and we // assume a historical gossip sync is performed at a later time. firstHeight = channelRange.Height - chunkSize := numChannels - exceedsChunkSize := numChannels > g.cfg.chunkSize + finalChunkSize := numChannels + exceedsChunkSize := numChannels > chunkSize if exceedsChunkSize { rand.Shuffle(len(channels), func(i, j int) { channels[i], channels[j] = channels[j], channels[i] }) - chunkSize = g.cfg.chunkSize + finalChunkSize = chunkSize } - channelChunk = channels[:chunkSize] + channelChunk = channels[:finalChunkSize] // Sort the chunk once again if we had to shuffle it. if exceedsChunkSize { @@ -1164,6 +1209,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // Send the remaining chunk as the final reply. log.Infof("GossipSyncer(%x): sending final chan range chunk, size=%v", g.cfg.peerPub[:], len(channelChunk)) + return sendReplyForChunk( channelChunk, firstHeight, query.LastBlockHeight(), true, ) @@ -1267,10 +1313,19 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er return nil } + select { + case <-g.gossipFilterSema: + case <-g.quit: + return ErrGossipSyncerExiting + } + // We'll conclude by launching a goroutine to send out any updates. g.wg.Add(1) go func() { defer g.wg.Done() + defer func() { + g.gossipFilterSema <- struct{}{} + }() for _, msg := range newUpdatestoSend { err := g.cfg.sendToPeerSync(msg) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 3c7835107..3a77e046c 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -42,7 +42,7 @@ type mockChannelGraphTimeSeries struct { horizonReq chan horizonQuery horizonResp chan []lnwire.Message - filterReq chan []lnwire.ShortChannelID + filterReq chan []channeldb.ChannelUpdateInfo filterResp chan []lnwire.ShortChannelID filterRangeReqs chan filterRangeReq @@ -64,7 +64,7 @@ func newMockChannelGraphTimeSeries( horizonReq: make(chan horizonQuery, 1), horizonResp: make(chan []lnwire.Message, 1), - filterReq: make(chan []lnwire.ShortChannelID, 1), + filterReq: make(chan []channeldb.ChannelUpdateInfo, 1), filterResp: make(chan []lnwire.ShortChannelID, 1), filterRangeReqs: make(chan filterRangeReq, 1), @@ -90,8 +90,11 @@ func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, return <-m.horizonResp, nil } + func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) { m.filterReq <- superSet @@ -1309,11 +1312,17 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { return case req := <-chanSeries.filterReq: + scids := make([]lnwire.ShortChannelID, len(req)) + for i, scid := range req { + scids[i] = scid.ShortChannelID + } + // We should get a request for the entire range of short // chan ID's. - if !reflect.DeepEqual(expectedReq, req) { - errCh <- fmt.Errorf("wrong request: expected %v, got %v", - expectedReq, req) + if !reflect.DeepEqual(expectedReq, scids) { + errCh <- fmt.Errorf("wrong request: "+ + "expected %v, got %v", expectedReq, req) + return }