diff --git a/channeldb/graph.go b/channeldb/graph.go index 4712412f5..1ab2897bb 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -314,7 +314,6 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( var graphTopLevelBuckets = [][]byte{ nodeBucket, edgeBucket, - edgeIndexBucket, graphMetaBucket, } @@ -2087,10 +2086,12 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, // words, we perform a set difference of our set of chan ID's and the ones // passed in. This method can be used by callers to determine the set of // channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { +func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ([]uint64, error) { + var newChanIDs []uint64 - err := kvdb.View(c.db, func(tx kvdb.RTx) error { + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2108,8 +2109,9 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // We'll run through the set of chanIDs and collate only the // set of channel that are unable to be found within our db. var cidBytes [8]byte - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) + for _, info := range chansInfo { + scid := info.ShortChannelID.ToUint64() + byteOrder.PutUint64(cidBytes[:], scid) // If the edge is already known, skip it. if v := edgeIndex.Get(cidBytes[:]); v != nil { @@ -2118,13 +2120,37 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // If the edge is a known zombie, skip it. if zombieIndex != nil { - isZombie, _, _ := isZombieEdge(zombieIndex, cid) - if isZombie { + isZombie, _, _ := isZombieEdge( + zombieIndex, scid, + ) + + isStillZombie := isZombieChan( + info.Node1UpdateTimestamp, + info.Node2UpdateTimestamp, + ) + + switch { + // If the edge is a known zombie and if we + // would still consider it a zombie given the + // latest update timestamps, then we skip this + // channel. + case isZombie && isStillZombie: continue + + // Otherwise, if we have marked it as a zombie + // but the latest update timestamps could bring + // it back from the dead, then we mark it alive, + // and we let it be added to the set of IDs to + // query our peer for. + case isZombie && !isStillZombie: + err := c.markEdgeLive(tx, scid) + if err != nil { + return err + } } } - newChanIDs = append(newChanIDs, cid) + newChanIDs = append(newChanIDs, scid) } return nil @@ -2135,7 +2161,12 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { // If we don't know of any edges yet, then we'll return the entire set // of chan IDs specified. case err == ErrGraphNoEdgesFound: - return chanIDs, nil + ogChanIDs := make([]uint64, len(chansInfo)) + for i, info := range chansInfo { + ogChanIDs[i] = info.ShortChannelID.ToUint64() + } + + return ogChanIDs, nil case err != nil: return nil, err @@ -2144,6 +2175,23 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { return newChanIDs, nil } +// ChannelUpdateInfo couples the SCID of a channel with the timestamps of the +// latest received channel updates for the channel. +type ChannelUpdateInfo struct { + // ShortChannelID is the SCID identifier of the channel. + ShortChannelID lnwire.ShortChannelID + + // Node1UpdateTimestamp is the timestamp of the latest received update + // from the node 1 channel peer. This will be set to zero time if no + // update has yet been received from this node. + Node1UpdateTimestamp time.Time + + // Node2UpdateTimestamp is the timestamp of the latest received update + // from the node 2 channel peer. This will be set to zero time if no + // update has yet been received from this node. + Node2UpdateTimestamp time.Time +} + // BlockChannelRange represents a range of channels for a given block height. type BlockChannelRange struct { // Height is the height of the block all of the channels below were @@ -2152,17 +2200,20 @@ type BlockChannelRange struct { // Channels is the list of channels identified by their short ID // representation known to us that were included in the block height - // above. - Channels []lnwire.ShortChannelID + // above. The list may include channel update timestamp information if + // requested. + Channels []ChannelUpdateInfo } // FilterChannelRange returns the channel ID's of all known channels which were // mined in a block height within the passed range. The channel IDs are grouped // by their common block height. This method can be used to quickly share with a // peer the set of channels we know of within a particular range to catch them -// up after a period of time offline. +// up after a period of time offline. If withTimestamps is true then the +// timestamp info of the latest received channel update messages of the channel +// will be included in the response. func (c *ChannelGraph) FilterChannelRange(startHeight, - endHeight uint32) ([]BlockChannelRange, error) { + endHeight uint32, withTimestamps bool) ([]BlockChannelRange, error) { startChanID := &lnwire.ShortChannelID{ BlockHeight: startHeight, @@ -2181,7 +2232,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) - var channelsPerBlock map[uint32][]lnwire.ShortChannelID + var channelsPerBlock map[uint32][]ChannelUpdateInfo err := kvdb.View(c.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { @@ -2213,14 +2264,60 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // we'll add it to our returned set. rawCid := byteOrder.Uint64(k) cid := lnwire.NewShortChanIDFromInt(rawCid) + + chanInfo := ChannelUpdateInfo{ + ShortChannelID: cid, + } + + if !withTimestamps { + channelsPerBlock[cid.BlockHeight] = append( + channelsPerBlock[cid.BlockHeight], + chanInfo, + ) + + continue + } + + node1Key, node2Key := computeEdgePolicyKeys(&edgeInfo) + + rawPolicy := edges.Get(node1Key) + if len(rawPolicy) != 0 { + r := bytes.NewReader(rawPolicy) + + edge, err := deserializeChanEdgePolicyRaw(r) + if err != nil && !errors.Is( + err, ErrEdgePolicyOptionalFieldNotFound, + ) { + + return err + } + + chanInfo.Node1UpdateTimestamp = edge.LastUpdate + } + + rawPolicy = edges.Get(node2Key) + if len(rawPolicy) != 0 { + r := bytes.NewReader(rawPolicy) + + edge, err := deserializeChanEdgePolicyRaw(r) + if err != nil && !errors.Is( + err, ErrEdgePolicyOptionalFieldNotFound, + ) { + + return err + } + + chanInfo.Node2UpdateTimestamp = edge.LastUpdate + } + channelsPerBlock[cid.BlockHeight] = append( - channelsPerBlock[cid.BlockHeight], cid, + channelsPerBlock[cid.BlockHeight], chanInfo, ) } return nil }, func() { - channelsPerBlock = make(map[uint32][]lnwire.ShortChannelID) + channelsPerBlock = make(map[uint32][]ChannelUpdateInfo) }) switch { @@ -3119,6 +3216,24 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, return targetNode, err } +// computeEdgePolicyKeys is a helper function that can be used to compute the +// keys used to index the channel edge policy info for the two nodes of the +// edge. The keys for node 1 and node 2 are returned respectively. +func computeEdgePolicyKeys(info *models.ChannelEdgeInfo) ([]byte, []byte) { + var ( + node1Key [33 + 8]byte + node2Key [33 + 8]byte + ) + + copy(node1Key[:], info.NodeKey1Bytes[:]) + copy(node2Key[:], info.NodeKey2Bytes[:]) + + byteOrder.PutUint64(node1Key[33:], info.ChannelID) + byteOrder.PutUint64(node2Key[33:], info.ChannelID) + + return node1Key[:], node2Key[:] +} + // FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for // the channel identified by the funding outpoint. If the channel can't be // found, then ErrEdgeNotFound is returned. A struct which houses the general @@ -3497,10 +3612,17 @@ func markEdgeZombie(zombieIndex kvdb.RwBucket, chanID uint64, pubKey1, // MarkEdgeLive clears an edge from our zombie index, deeming it as live. func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { + return c.markEdgeLive(nil, chanID) +} + +// markEdgeLive clears an edge from the zombie index. This method can be called +// with an existing kvdb.RwTx or the argument can be set to nil in which case a +// new transaction will be created. +func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error { c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + dbFn := func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -3518,7 +3640,16 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { } return zombieIndex.Delete(k[:]) - }, func() {}) + } + + // If the transaction is nil, we'll create a new one. Otherwise, we use + // the existing transaction + var err error + if tx == nil { + err = kvdb.Update(c.db, dbFn, func() {}) + } else { + err = dbFn(tx) + } if err != nil { return err } @@ -3528,11 +3659,12 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { // We need to add the channel back into our graph cache, otherwise we // won't use it for path finding. - edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) - if err != nil { - return err - } if c.graphCache != nil { + edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + if err != nil { + return err + } + for _, edgeInfo := range edgeInfos { c.graphCache.AddChannel( edgeInfo.Info, edgeInfo.Policy1, diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 41c65e353..cbfabb3cc 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" ) var ( @@ -1927,14 +1928,32 @@ func TestFilterKnownChanIDs(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") + isZombieUpdate := func(updateTime1 time.Time, + updateTime2 time.Time) bool { + + return true + } + + var ( + scid1 = lnwire.ShortChannelID{BlockHeight: 1} + scid2 = lnwire.ShortChannelID{BlockHeight: 2} + scid3 = lnwire.ShortChannelID{BlockHeight: 3} + ) + // If we try to filter out a set of channel ID's before we even know of // any channels, then we should get the entire set back. - preChanIDs := []uint64{1, 2, 3, 4} - filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) - require.NoError(t, err, "unable to filter chan IDs") - if !reflect.DeepEqual(preChanIDs, filteredIDs) { - t.Fatalf("chan IDs shouldn't have been filtered!") + preChanIDs := []ChannelUpdateInfo{ + {ShortChannelID: scid1}, + {ShortChannelID: scid2}, + {ShortChannelID: scid3}, } + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs, isZombieUpdate) + require.NoError(t, err, "unable to filter chan IDs") + require.EqualValues(t, []uint64{ + scid1.ToUint64(), + scid2.ToUint64(), + scid3.ToUint64(), + }, filteredIDs) // We'll start by creating two nodes which will seed our test graph. node1, err := createTestVertex(graph.db) @@ -1951,7 +1970,7 @@ func TestFilterKnownChanIDs(t *testing.T) { // Next, we'll add 5 channel ID's to the graph, each of them having a // block height 10 blocks after the previous. const numChans = 5 - chanIDs := make([]uint64, 0, numChans) + chanIDs := make([]ChannelUpdateInfo, 0, numChans) for i := 0; i < numChans; i++ { channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, @@ -1961,11 +1980,13 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Fatalf("unable to create channel edge: %v", err) } - chanIDs = append(chanIDs, chanID.ToUint64()) + chanIDs = append(chanIDs, ChannelUpdateInfo{ + ShortChannelID: chanID, + }) } const numZombies = 5 - zombieIDs := make([]uint64, 0, numZombies) + zombieIDs := make([]ChannelUpdateInfo, 0, numZombies) for i := 0; i < numZombies; i++ { channel, chanID := createEdge( uint32(i*10+1), 0, 0, 0, node1, node2, @@ -1978,13 +1999,15 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Fatalf("unable to mark edge zombie: %v", err) } - zombieIDs = append(zombieIDs, chanID.ToUint64()) + zombieIDs = append( + zombieIDs, ChannelUpdateInfo{ShortChannelID: chanID}, + ) } queryCases := []struct { - queryIDs []uint64 + queryIDs []ChannelUpdateInfo - resp []uint64 + resp []ChannelUpdateInfo }{ // If we attempt to filter out all chanIDs we know of, the // response should be the empty set. @@ -2000,28 +2023,70 @@ func TestFilterKnownChanIDs(t *testing.T) { // If we query for a set of ID's that we didn't insert, we // should get the same set back. { - queryIDs: []uint64{99, 100}, - resp: []uint64{99, 100}, + queryIDs: []ChannelUpdateInfo{ + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }}, + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 100, + }}, + }, + resp: []ChannelUpdateInfo{ + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }}, + {ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 100, + }}, + }, }, // If we query for a super-set of our the chan ID's inserted, // we should only get those new chanIDs back. { - queryIDs: append(chanIDs, []uint64{99, 101}...), - resp: []uint64{99, 101}, + queryIDs: append(chanIDs, []ChannelUpdateInfo{ + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }, + }, + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 101, + }, + }, + }...), + resp: []ChannelUpdateInfo{ + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 99, + }, + }, + { + ShortChannelID: lnwire.ShortChannelID{ + BlockHeight: 101, + }, + }, + }, }, } for _, queryCase := range queryCases { - resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) + resp, err := graph.FilterKnownChanIDs( + queryCase.queryIDs, isZombieUpdate, + ) + require.NoError(t, err) + + expectedSCIDs := make([]uint64, len(queryCase.resp)) + for i, info := range queryCase.resp { + expectedSCIDs[i] = info.ShortChannelID.ToUint64() } - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), - spew.Sdump(resp)) + if len(expectedSCIDs) == 0 { + expectedSCIDs = nil } + + require.EqualValues(t, expectedSCIDs, resp) } } @@ -2031,79 +2096,141 @@ func TestFilterChannelRange(t *testing.T) { t.Parallel() graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + require.NoError(t, err) // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } + require.NoError(t, err) + require.NoError(t, graph.AddLightningNode(node1)) + node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } + require.NoError(t, err) + require.NoError(t, graph.AddLightningNode(node2)) // If we try to filter a channel range before we have any channels // inserted, we should get an empty slice of results. - resp, err := graph.FilterChannelRange(10, 100) - require.NoError(t, err, "unable to filter channels") - if len(resp) != 0 { - t.Fatalf("expected zero chans, instead got %v", len(resp)) - } + resp, err := graph.FilterChannelRange(10, 100, false) + require.NoError(t, err) + require.Empty(t, resp) // To start, we'll create a set of channels, two mined in a block 10 // blocks after the prior one. startHeight := uint32(100) endHeight := startHeight const numChans = 10 - channelRanges := make([]BlockChannelRange, 0, numChans/2) + + var ( + channelRanges = make( + []BlockChannelRange, 0, numChans/2, + ) + channelRangesWithTimestamps = make( + []BlockChannelRange, 0, numChans/2, + ) + ) + + updateTimeSeed := int64(1) + maybeAddPolicy := func(chanID uint64, node *LightningNode, + node2 bool) time.Time { + + var chanFlags lnwire.ChanUpdateChanFlags + if node2 { + chanFlags = lnwire.ChanUpdateDirection + } + + var updateTime time.Time + if rand.Int31n(2) == 0 { + updateTime = time.Unix(updateTimeSeed, 0) + err = graph.UpdateEdgePolicy(&models.ChannelEdgePolicy{ + ToNode: node.PubKeyBytes, + ChannelFlags: chanFlags, + ChannelID: chanID, + LastUpdate: updateTime, + }) + require.NoError(t, err) + } + updateTimeSeed++ + + return updateTime + } + for i := 0; i < numChans/2; i++ { chanHeight := endHeight channel1, chanID1 := createEdge( chanHeight, uint32(i+1), 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel1); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } + require.NoError(t, graph.AddChannelEdge(&channel1)) channel2, chanID2 := createEdge( chanHeight, uint32(i+2), 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel2); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } + require.NoError(t, graph.AddChannelEdge(&channel2)) channelRanges = append(channelRanges, BlockChannelRange{ - Height: chanHeight, - Channels: []lnwire.ShortChannelID{chanID1, chanID2}, + Height: chanHeight, + Channels: []ChannelUpdateInfo{ + {ShortChannelID: chanID1}, + {ShortChannelID: chanID2}, + }, }) + + var ( + time1 = maybeAddPolicy(channel1.ChannelID, node1, false) + time2 = maybeAddPolicy(channel1.ChannelID, node2, true) + time3 = maybeAddPolicy(channel2.ChannelID, node1, false) + time4 = maybeAddPolicy(channel2.ChannelID, node2, true) + ) + + channelRangesWithTimestamps = append( + channelRangesWithTimestamps, BlockChannelRange{ + Height: chanHeight, + Channels: []ChannelUpdateInfo{ + { + ShortChannelID: chanID1, + Node1UpdateTimestamp: time1, + Node2UpdateTimestamp: time2, + }, + { + ShortChannelID: chanID2, + Node1UpdateTimestamp: time3, + Node2UpdateTimestamp: time4, + }, + }, + }, + ) + endHeight += 10 } // With our channels inserted, we'll construct a series of queries that // we'll execute below in order to exercise the features of the // FilterKnownChanIDs method. - queryCases := []struct { + tests := []struct { + name string + startHeight uint32 endHeight uint32 - resp []BlockChannelRange + resp []BlockChannelRange + expStartIndex int + expEndIndex int }{ // If we query for the entire range, then we should get the same // set of short channel IDs back. { + name: "entire range", startHeight: startHeight, endHeight: endHeight, - resp: channelRanges, + resp: channelRanges, + expStartIndex: 0, + expEndIndex: len(channelRanges), }, - // If we query for a range of channels right before our range, we - // shouldn't get any results back. + // If we query for a range of channels right before our range, + // we shouldn't get any results back. { + name: "range before", startHeight: 0, endHeight: 10, }, @@ -2111,40 +2238,72 @@ func TestFilterChannelRange(t *testing.T) { // If we only query for the last height (range wise), we should // only get that last channel. { + name: "last height", startHeight: endHeight - 10, endHeight: endHeight - 10, - resp: channelRanges[4:], + resp: channelRanges[4:], + expStartIndex: 4, + expEndIndex: len(channelRanges), }, // If we query for just the first height, we should only get a // single channel back (the first one). { + name: "first height", startHeight: startHeight, endHeight: startHeight, - resp: channelRanges[:1], + resp: channelRanges[:1], + expStartIndex: 0, + expEndIndex: 1, }, { + name: "subset", startHeight: startHeight + 10, endHeight: endHeight - 10, - resp: channelRanges[1:5], + resp: channelRanges[1:5], + expStartIndex: 1, + expEndIndex: 5, }, } - for i, queryCase := range queryCases { - resp, err := graph.FilterChannelRange( - queryCase.startHeight, queryCase.endHeight, - ) - if err != nil { - t.Fatalf("unable to issue range query: %v", err) - } - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("case #%v: expected %v, got %v", i, - queryCase.resp, resp) - } + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // First, do the query without requesting timestamps. + resp, err := graph.FilterChannelRange( + test.startHeight, test.endHeight, false, + ) + require.NoError(t, err) + + expRes := channelRanges[test.expStartIndex:test.expEndIndex] //nolint:lll + + if len(expRes) == 0 { + require.Nil(t, resp) + } else { + require.Equal(t, expRes, resp) + } + + // Now, query the timestamps as well. + resp, err = graph.FilterChannelRange( + test.startHeight, test.endHeight, true, + ) + require.NoError(t, err) + + expRes = channelRangesWithTimestamps[test.expStartIndex:test.expEndIndex] //nolint:lll + + if len(expRes) == 0 { + require.Nil(t, resp) + } else { + require.Equal(t, expRes, resp) + } + }) } } diff --git a/discovery/chan_series.go b/discovery/chan_series.go index c811017ab..34e6d4a9d 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -37,14 +37,16 @@ type ChannelGraphTimeSeries interface { // ID's represents the ID's that we don't know of which were in the // passed superSet. FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) // FilterChannelRange returns the set of channels that we created // between the start height and the end height. The channel IDs are // grouped by their common block height. We'll use this to to a remote // peer's QueryChannelRange message. - FilterChannelRange(chain chainhash.Hash, - startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) + FilterChannelRange(chain chainhash.Hash, startHeight, endHeight uint32, + withTimestamps bool) ([]channeldb.BlockChannelRange, error) // FetchChanAnns returns a full set of channel announcements as well as // their updates that match the set of specified short channel ID's. @@ -197,15 +199,12 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // represents the ID's that we don't know of which were in the passed superSet. // // NOTE: This is part of the ChannelGraphTimeSeries interface. -func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { +func (c *ChanSeries) FilterKnownChanIDs(_ chainhash.Hash, + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) { - chanIDs := make([]uint64, 0, len(superSet)) - for _, chanID := range superSet { - chanIDs = append(chanIDs, chanID.ToUint64()) - } - - newChanIDs, err := c.graph.FilterKnownChanIDs(chanIDs) + newChanIDs, err := c.graph.FilterKnownChanIDs(superSet, isZombieChan) if err != nil { return nil, err } @@ -226,10 +225,13 @@ func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash, // message. // // NOTE: This is part of the ChannelGraphTimeSeries interface. -func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash, - startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) { +func (c *ChanSeries) FilterChannelRange(_ chainhash.Hash, startHeight, + endHeight uint32, withTimestamps bool) ([]channeldb.BlockChannelRange, + error) { - return c.graph.FilterChannelRange(startHeight, endHeight) + return c.graph.FilterChannelRange( + startHeight, endHeight, withTimestamps, + ) } // FetchChanAnns returns a full set of channel announcements as well as their diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 17ef933e0..8fd18bf5e 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -261,6 +261,11 @@ type Config struct { // gossip syncers will be passive. NumActiveSyncers int + // NoTimestampQueries will prevent the GossipSyncer from querying + // timestamps of announcement messages from the peer and from replying + // to timestamp queries. + NoTimestampQueries bool + // RotateTicker is a ticker responsible for notifying the SyncManager // when it should rotate its active syncers. A single active syncer with // a chansSynced state will be exchanged for a passive syncer in order @@ -330,6 +335,11 @@ type Config struct { // to without iterating over the entire set of open channels. FindChannel func(node *btcec.PublicKey, chanID lnwire.ChannelID) ( *channeldb.OpenChannel, error) + + // IsStillZombieChannel takes the timestamps of the latest channel + // updates for a channel and returns true if the channel should be + // considered a zombie based on these timestamps. + IsStillZombieChannel func(time.Time, time.Time) bool } // processedNetworkMsg is a wrapper around networkMsg and a boolean. It is @@ -510,9 +520,11 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper RotateTicker: cfg.RotateTicker, HistoricalSyncTicker: cfg.HistoricalSyncTicker, NumActiveSyncers: cfg.NumActiveSyncers, + NoTimestampQueries: cfg.NoTimestampQueries, IgnoreHistoricalFilters: cfg.IgnoreHistoricalFilters, BestHeight: gossiper.latestHeight, PinnedSyncers: cfg.PinnedSyncers, + IsStillZombieChannel: cfg.IsStillZombieChannel, }) gossiper.reliableSender = newReliableSender(&reliableSenderCfg{ diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index 8123c37eb..39098f70e 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -73,6 +73,11 @@ type SyncManagerCfg struct { // gossip syncers will be passive. NumActiveSyncers int + // NoTimestampQueries will prevent the GossipSyncer from querying + // timestamps of announcement messages from the peer and from responding + // to timestamp queries + NoTimestampQueries bool + // RotateTicker is a ticker responsible for notifying the SyncManager // when it should rotate its active syncers. A single active syncer with // a chansSynced state will be exchanged for a passive syncer in order @@ -97,6 +102,11 @@ type SyncManagerCfg struct { // ActiveSync upon connection. These peers will never transition to // PassiveSync. PinnedSyncers PinnedSyncers + + // IsStillZombieChannel takes the timestamps of the latest channel + // updates for a channel and returns true if the channel should be + // considered a zombie based on these timestamps. + IsStillZombieChannel func(time.Time, time.Time) bool } // SyncManager is a subsystem of the gossiper that manages the gossip syncers @@ -495,6 +505,8 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer { bestHeight: m.cfg.BestHeight, markGraphSynced: m.markGraphSynced, maxQueryChanRangeReplies: maxQueryChanRangeReplies, + noTimestampQueryOption: m.cfg.NoTimestampQueries, + isStillZombieChannel: m.cfg.IsStillZombieChannel, }) // Gossip syncers are initialized by default in a PassiveSync type diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index f6861f0c6..f71d1728e 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -277,6 +277,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) { assertMsgSent(t, peer, &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), }) // The graph should not be considered as synced since the initial @@ -379,6 +380,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) { assertMsgSent(t, peer, &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), }) // If an additional peer connects, then a historical sync should not be @@ -394,6 +396,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) { assertMsgSent(t, extraPeer, &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), }) } @@ -415,6 +418,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) { assertMsgSent(t, peer, &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), }) // The graph should not be considered as synced since the initial @@ -620,6 +624,7 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer query := &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), } assertMsgSent(t, peer, query) diff --git a/discovery/syncer.go b/discovery/syncer.go index 722519fb0..dc1565371 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -11,6 +11,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/time/rate" @@ -179,13 +180,16 @@ const ( // requestBatchSize is the maximum number of channels we will query the // remote peer for in a QueryShortChanIDs message. requestBatchSize = 500 + + // filterSemaSize is the capacity of gossipFilterSema. + filterSemaSize = 5 ) var ( // encodingTypeToChunkSize maps an encoding type, to the max number of // short chan ID's using the encoding type that we can fit into a // single message safely. - encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{ + encodingTypeToChunkSize = map[lnwire.QueryEncoding]int32{ lnwire.EncodingSortedPlain: 8000, } @@ -232,7 +236,7 @@ type gossipSyncerCfg struct { // encodingType is the current encoding type we're aware of. Requests // with different encoding types will be rejected. - encodingType lnwire.ShortChanIDEncoding + encodingType lnwire.QueryEncoding // chunkSize is the max number of short chan IDs using the syncer's // encoding type that we can fit into a single message safely. @@ -271,6 +275,11 @@ type gossipSyncerCfg struct { // peer. noReplyQueries bool + // noTimestampQueryOption will prevent the GossipSyncer from querying + // timestamps of announcement messages from the peer, and it will + // prevent it from responding to timestamp queries. + noTimestampQueryOption bool + // ignoreHistoricalFilters will prevent syncers from replying with // historical data when the remote peer sets a gossip_timestamp_range. // This prevents ranges with old start times from causing us to dump the @@ -287,6 +296,11 @@ type gossipSyncerCfg struct { // maxQueryChanRangeReplies is the maximum number of replies we'll allow // for a single QueryChannelRange request. maxQueryChanRangeReplies uint32 + + // isStillZombieChannel takes the timestamps of the latest channel + // updates for a channel and returns true if the channel should be + // considered a zombie based on these timestamps. + isStillZombieChannel func(time.Time, time.Time) bool } // GossipSyncer is a struct that handles synchronizing the channel graph state @@ -361,7 +375,7 @@ type GossipSyncer struct { // bufferedChanRangeReplies is used in the waitingQueryChanReply to // buffer all the chunked response to our query. - bufferedChanRangeReplies []lnwire.ShortChannelID + bufferedChanRangeReplies []channeldb.ChannelUpdateInfo // numChanRangeRepliesRcvd is used to track the number of replies // received as part of a QueryChannelRange. This field is primarily used @@ -387,6 +401,8 @@ type GossipSyncer struct { sync.Mutex + gossipFilterSema chan struct{} + quit chan struct{} wg sync.WaitGroup } @@ -415,6 +431,11 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { interval, cfg.maxUndelayedQueryReplies, ) + filterSema := make(chan struct{}, filterSemaSize) + for i := 0; i < filterSemaSize; i++ { + filterSema <- struct{}{} + } + return &GossipSyncer{ cfg: cfg, rateLimiter: rateLimiter, @@ -422,6 +443,7 @@ func newGossipSyncer(cfg gossipSyncerCfg) *GossipSyncer { historicalSyncReqs: make(chan *historicalSyncReq), gossipMsgs: make(chan lnwire.Message, 100), queryMsgs: make(chan lnwire.Message, 100), + gossipFilterSema: filterSema, quit: make(chan struct{}), } } @@ -808,9 +830,31 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro } g.prevReplyChannelRange = msg - g.bufferedChanRangeReplies = append( - g.bufferedChanRangeReplies, msg.ShortChanIDs..., - ) + if len(msg.Timestamps) != 0 && + len(msg.Timestamps) != len(msg.ShortChanIDs) { + + return fmt.Errorf("number of timestamps not equal to " + + "number of SCIDs") + } + + for i, scid := range msg.ShortChanIDs { + info := channeldb.ChannelUpdateInfo{ + ShortChannelID: scid, + } + + if len(msg.Timestamps) != 0 { + t1 := time.Unix(int64(msg.Timestamps[i].Timestamp1), 0) + info.Node1UpdateTimestamp = t1 + + t2 := time.Unix(int64(msg.Timestamps[i].Timestamp2), 0) + info.Node2UpdateTimestamp = t2 + } + + g.bufferedChanRangeReplies = append( + g.bufferedChanRangeReplies, info, + ) + } + switch g.cfg.encodingType { case lnwire.EncodingSortedPlain: g.numChanRangeRepliesRcvd++ @@ -857,6 +901,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // which channels they know of that we don't. newChans, err := g.cfg.channelSeries.FilterKnownChanIDs( g.cfg.chainHash, g.bufferedChanRangeReplies, + g.cfg.isStillZombieChannel, ) if err != nil { return fmt.Errorf("unable to filter chan ids: %v", err) @@ -922,7 +967,7 @@ func (g *GossipSyncer) genChanRangeQuery( case newestChan.BlockHeight <= chanRangeQueryBuffer: startHeight = 0 default: - startHeight = uint32(newestChan.BlockHeight - chanRangeQueryBuffer) + startHeight = newestChan.BlockHeight - chanRangeQueryBuffer } // Determine the number of blocks to request based on our best height. @@ -945,6 +990,11 @@ func (g *GossipSyncer) genChanRangeQuery( FirstBlockHeight: startHeight, NumBlocks: numBlocks, } + + if !g.cfg.noTimestampQueryOption { + query.QueryOptions = lnwire.NewTimestampQueryOption() + } + g.curQueryRangeMsg = query return query, nil @@ -1016,12 +1066,18 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro "num_blocks=%v", g.cfg.peerPub[:], query.FirstBlockHeight, query.NumBlocks) + // Check if the query asked for timestamps. We will only serve + // timestamps if this has not been disabled with + // noTimestampQueryOption. + withTimestamps := query.WithTimestamps() && + !g.cfg.noTimestampQueryOption + // Next, we'll consult the time series to obtain the set of known // channel ID's that match their query. startBlock := query.FirstBlockHeight endBlock := query.LastBlockHeight() channelRanges, err := g.cfg.channelSeries.FilterChannelRange( - query.ChainHash, startBlock, endBlock, + query.ChainHash, startBlock, endBlock, withTimestamps, ) if err != nil { return err @@ -1034,7 +1090,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // this as there's a transport message size limit which we'll need to // adhere to. We also need to make sure all of our replies cover the // expected range of the query. - sendReplyForChunk := func(channelChunk []lnwire.ShortChannelID, + sendReplyForChunk := func(channelChunk []channeldb.ChannelUpdateInfo, firstHeight, lastHeight uint32, finalChunk bool) error { // The number of blocks contained in the current chunk (the @@ -1047,25 +1103,58 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro complete = 1 } + var timestamps lnwire.Timestamps + if withTimestamps { + timestamps = make(lnwire.Timestamps, len(channelChunk)) + } + + scids := make([]lnwire.ShortChannelID, len(channelChunk)) + for i, info := range channelChunk { + scids[i] = info.ShortChannelID + + if !withTimestamps { + continue + } + + timestamps[i].Timestamp1 = uint32( + info.Node1UpdateTimestamp.Unix(), + ) + + timestamps[i].Timestamp2 = uint32( + info.Node2UpdateTimestamp.Unix(), + ) + } + return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ ChainHash: query.ChainHash, NumBlocks: numBlocks, FirstBlockHeight: firstHeight, Complete: complete, EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ShortChanIDs: scids, + Timestamps: timestamps, }) } var ( firstHeight = query.FirstBlockHeight lastHeight uint32 - channelChunk []lnwire.ShortChannelID + channelChunk []channeldb.ChannelUpdateInfo ) + + // chunkSize is the maximum number of SCIDs that we can safely put in a + // single message. If we also need to include timestamps though, then + // this number is halved since encoding two timestamps takes the same + // number of bytes as encoding an SCID. + chunkSize := g.cfg.chunkSize + if withTimestamps { + chunkSize /= 2 + } + for _, channelRange := range channelRanges { channels := channelRange.Channels numChannels := int32(len(channels)) - numLeftToAdd := g.cfg.chunkSize - int32(len(channelChunk)) + numLeftToAdd := chunkSize - int32(len(channelChunk)) // Include the current block in the ongoing chunk if it can fit // and move on to the next block. @@ -1081,6 +1170,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // to. log.Infof("GossipSyncer(%x): sending range chunk of size=%v", g.cfg.peerPub[:], len(channelChunk)) + lastHeight = channelRange.Height - 1 err := sendReplyForChunk( channelChunk, firstHeight, lastHeight, false, @@ -1095,21 +1185,23 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // this isn't an issue since we'll randomly shuffle them and we // assume a historical gossip sync is performed at a later time. firstHeight = channelRange.Height - chunkSize := numChannels - exceedsChunkSize := numChannels > g.cfg.chunkSize + finalChunkSize := numChannels + exceedsChunkSize := numChannels > chunkSize if exceedsChunkSize { rand.Shuffle(len(channels), func(i, j int) { channels[i], channels[j] = channels[j], channels[i] }) - chunkSize = g.cfg.chunkSize + finalChunkSize = chunkSize } - channelChunk = channels[:chunkSize] + channelChunk = channels[:finalChunkSize] // Sort the chunk once again if we had to shuffle it. if exceedsChunkSize { sort.Slice(channelChunk, func(i, j int) bool { - return channelChunk[i].ToUint64() < - channelChunk[j].ToUint64() + id1 := channelChunk[i].ShortChannelID.ToUint64() + id2 := channelChunk[j].ShortChannelID.ToUint64() + + return id1 < id2 }) } } @@ -1117,6 +1209,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // Send the remaining chunk as the final reply. log.Infof("GossipSyncer(%x): sending final chan range chunk, size=%v", g.cfg.peerPub[:], len(channelChunk)) + return sendReplyForChunk( channelChunk, firstHeight, query.LastBlockHeight(), true, ) @@ -1220,10 +1313,19 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er return nil } + select { + case <-g.gossipFilterSema: + case <-g.quit: + return ErrGossipSyncerExiting + } + // We'll conclude by launching a goroutine to send out any updates. g.wg.Add(1) go func() { defer g.wg.Done() + defer func() { + g.gossipFilterSema <- struct{}{} + }() for _, msg := range newUpdatestoSend { err := g.cfg.sendToPeerSync(msg) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index a7b514db8..3a77e046c 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -42,7 +42,7 @@ type mockChannelGraphTimeSeries struct { horizonReq chan horizonQuery horizonResp chan []lnwire.Message - filterReq chan []lnwire.ShortChannelID + filterReq chan []channeldb.ChannelUpdateInfo filterResp chan []lnwire.ShortChannelID filterRangeReqs chan filterRangeReq @@ -64,7 +64,7 @@ func newMockChannelGraphTimeSeries( horizonReq: make(chan horizonQuery, 1), horizonResp: make(chan []lnwire.Message, 1), - filterReq: make(chan []lnwire.ShortChannelID, 1), + filterReq: make(chan []channeldb.ChannelUpdateInfo, 1), filterResp: make(chan []lnwire.ShortChannelID, 1), filterRangeReqs: make(chan filterRangeReq, 1), @@ -90,23 +90,30 @@ func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, return <-m.horizonResp, nil } + func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, - superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { + superSet []channeldb.ChannelUpdateInfo, + isZombieChan func(time.Time, time.Time) bool) ( + []lnwire.ShortChannelID, error) { m.filterReq <- superSet return <-m.filterResp, nil } func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, - startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) { + startHeight, endHeight uint32, withTimestamps bool) ( + []channeldb.BlockChannelRange, error) { m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} reply := <-m.filterRangeResp - channelsPerBlock := make(map[uint32][]lnwire.ShortChannelID) + channelsPerBlock := make(map[uint32][]channeldb.ChannelUpdateInfo) for _, cid := range reply { channelsPerBlock[cid.BlockHeight] = append( - channelsPerBlock[cid.BlockHeight], cid, + channelsPerBlock[cid.BlockHeight], + channeldb.ChannelUpdateInfo{ + ShortChannelID: cid, + }, ) } @@ -119,16 +126,21 @@ func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, return blocks[i] < blocks[j] }) - channelRanges := make([]channeldb.BlockChannelRange, 0, len(channelsPerBlock)) + channelRanges := make( + []channeldb.BlockChannelRange, 0, len(channelsPerBlock), + ) for _, block := range blocks { - channelRanges = append(channelRanges, channeldb.BlockChannelRange{ - Height: block, - Channels: channelsPerBlock[block], - }) + channelRanges = append( + channelRanges, channeldb.BlockChannelRange{ + Height: block, + Channels: channelsPerBlock[block], + }, + ) } return channelRanges, nil } + func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) { @@ -156,27 +168,34 @@ var _ ChannelGraphTimeSeries = (*mockChannelGraphTimeSeries)(nil) // ignored. If no flags are provided, both a channelGraphSyncer and replyHandler // will be spawned by default. func newTestSyncer(hID lnwire.ShortChannelID, - encodingType lnwire.ShortChanIDEncoding, chunkSize int32, + encodingType lnwire.QueryEncoding, chunkSize int32, flags ...bool) (chan []lnwire.Message, *GossipSyncer, *mockChannelGraphTimeSeries) { - syncChannels := true - replyQueries := true + var ( + syncChannels = true + replyQueries = true + timestamps = false + ) if len(flags) > 0 { syncChannels = flags[0] } if len(flags) > 1 { replyQueries = flags[1] } + if len(flags) > 2 { + timestamps = flags[2] + } msgChan := make(chan []lnwire.Message, 20) cfg := gossipSyncerCfg{ - channelSeries: newMockChannelGraphTimeSeries(hID), - encodingType: encodingType, - chunkSize: chunkSize, - batchSize: chunkSize, - noSyncChannels: !syncChannels, - noReplyQueries: !replyQueries, + channelSeries: newMockChannelGraphTimeSeries(hID), + encodingType: encodingType, + chunkSize: chunkSize, + batchSize: chunkSize, + noSyncChannels: !syncChannels, + noReplyQueries: !replyQueries, + noTimestampQueryOption: !timestamps, sendToPeer: func(msgs ...lnwire.Message) error { msgChan <- msgs return nil @@ -1293,11 +1312,17 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { return case req := <-chanSeries.filterReq: + scids := make([]lnwire.ShortChannelID, len(req)) + for i, scid := range req { + scids[i] = scid.ShortChannelID + } + // We should get a request for the entire range of short // chan ID's. - if !reflect.DeepEqual(expectedReq, req) { - errCh <- fmt.Errorf("wrong request: expected %v, got %v", - expectedReq, req) + if !reflect.DeepEqual(expectedReq, scids) { + errCh <- fmt.Errorf("wrong request: "+ + "expected %v, got %v", expectedReq, req) + return } @@ -2250,7 +2275,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { // historical sync requests in this state. msgChan, syncer, _ := newTestSyncer( lnwire.ShortChannelID{BlockHeight: latestKnownHeight}, - defaultEncoding, defaultChunkSize, + defaultEncoding, defaultChunkSize, true, true, true, ) syncer.setSyncType(PassiveSync) syncer.setSyncState(chansSynced) @@ -2265,6 +2290,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { expectedMsg := &lnwire.QueryChannelRange{ FirstBlockHeight: 0, NumBlocks: latestKnownHeight, + QueryOptions: lnwire.NewTimestampQueryOption(), } select { diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 06f093604..6839045b7 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -149,6 +149,13 @@ * [Add Dynamic Commitment Wire Types](https://github.com/lightningnetwork/lnd/pull/8026). This change begins the development of Dynamic Commitments allowing for the negotiation of new channel parameters and the upgrading of channel types. + +* Start using the [timestamps query + option](https://github.com/lightningnetwork/lnd/pull/8030) in the + `query_channel_range` message. This will allow us to know if our peer has a + newer update for a channel that we have marked as a zombie. This addition can + be switched off using the new `protocol.no-timestamp-query-option` config + option. ## Testing diff --git a/lncfg/protocol.go b/lncfg/protocol.go index 4376a145b..f8ac08e86 100644 --- a/lncfg/protocol.go +++ b/lncfg/protocol.go @@ -46,7 +46,14 @@ type ProtocolOptions struct { // NoOptionAnySegwit should be set to true if we don't want to use any // Taproot (and beyond) addresses for co-op closing. - NoOptionAnySegwit bool `long:"no-any-segwit" description:"disallow using any segiwt witness version as a co-op close address"` + NoOptionAnySegwit bool `long:"no-any-segwit" description:"disallow using any segwit witness version as a co-op close address"` + + // NoTimestampQueryOption should be set to true if we don't want our + // syncing peers to also send us the timestamps of announcement messages + // when we send them a channel range query. Setting this to true will + // also mean that we won't respond with timestamps if requested by our + // peers. + NoTimestampQueryOption bool `long:"no-timestamp-query-option" description:"do not query syncing peers for announcement timestamps and do not respond with timestamps if requested"` } // Wumbo returns true if lnd should permit the creation and acceptance of wumbo @@ -82,3 +89,11 @@ func (l *ProtocolOptions) ZeroConf() bool { func (l *ProtocolOptions) NoAnySegwit() bool { return l.NoOptionAnySegwit } + +// NoTimestampsQuery returns true if we should not ask our syncing peers to also +// send us the timestamps of announcement messages when we send them a channel +// range query, and it also means that we will not respond with timestamps if +// requested by our peer. +func (l *ProtocolOptions) NoTimestampsQuery() bool { + return l.NoTimestampQueryOption +} diff --git a/lncfg/protocol_integration.go b/lncfg/protocol_integration.go index 18e09ca72..ff74ba9e9 100644 --- a/lncfg/protocol_integration.go +++ b/lncfg/protocol_integration.go @@ -50,6 +50,13 @@ type ProtocolOptions struct { // NoOptionAnySegwit should be set to true if we don't want to use any // Taproot (and beyond) addresses for co-op closing. NoOptionAnySegwit bool `long:"no-any-segwit" description:"disallow using any segiwt witness version as a co-op close address"` + + // NoTimestampQueryOption should be set to true if we don't want our + // syncing peers to also send us the timestamps of announcement messages + // when we send them a channel range query. Setting this to true will + // also mean that we won't respond with timestamps if requested by our + // peers. + NoTimestampQueryOption bool `long:"no-timestamp-query-option" description:"do not query syncing peers for announcement timestamps and do not respond with timestamps if requested"` } // Wumbo returns true if lnd should permit the creation and acceptance of wumbo diff --git a/lnwire/encoding.go b/lnwire/encoding.go new file mode 100644 index 000000000..e04b2b01d --- /dev/null +++ b/lnwire/encoding.go @@ -0,0 +1,17 @@ +package lnwire + +// QueryEncoding is an enum-like type that represents exactly how a set data is +// encoded on the wire. +type QueryEncoding uint8 + +const ( + // EncodingSortedPlain signals that the set of data is encoded using the + // regular encoding, in a sorted order. + EncodingSortedPlain QueryEncoding = 0 + + // EncodingSortedZlib signals that the set of data is encoded by first + // sorting the set of channel ID's, as then compressing them using zlib. + // + // NOTE: this should no longer be used or accepted. + EncodingSortedZlib QueryEncoding = 1 +) diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 50a547e22..8ab082b0b 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -85,7 +85,7 @@ func WriteElement(w *bytes.Buffer, element interface{}) error { return err } - case ShortChanIDEncoding: + case QueryEncoding: var b [1]byte b[0] = uint8(e) if _, err := w.Write(b[:]); err != nil { @@ -509,12 +509,12 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = alias - case *ShortChanIDEncoding: + case *QueryEncoding: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { return err } - *e = ShortChanIDEncoding(b[0]) + *e = QueryEncoding(b[0]) case *uint8: var b [1]uint8 diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 9d92b970b..c3248c3c9 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1159,12 +1159,42 @@ func TestLightningWireProtocol(t *testing.T) { req.EncodingType = EncodingSortedPlain } - numChanIDs := rand.Int31n(5000) + numChanIDs := rand.Int31n(4000) for i := int32(0); i < numChanIDs; i++ { req.ShortChanIDs = append(req.ShortChanIDs, NewShortChanIDFromInt(uint64(r.Int63()))) } + // With a 50/50 chance, add some timestamps. + if r.Int31()%2 == 0 { + for i := int32(0); i < numChanIDs; i++ { + timestamps := ChanUpdateTimestamps{ + Timestamp1: rand.Uint32(), + Timestamp2: rand.Uint32(), + } + req.Timestamps = append( + req.Timestamps, timestamps, + ) + } + } + + v[0] = reflect.ValueOf(req) + }, + MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { + req := QueryChannelRange{ + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + ExtraData: make([]byte, 0), + } + + _, err := rand.Read(req.ChainHash[:]) + require.NoError(t, err) + + // With a 50/50 change, we'll set a query option. + if r.Int31()%2 == 0 { + req.QueryOptions = NewTimestampQueryOption() + } + v[0] = reflect.ValueOf(req) }, MsgPing: func(v []reflect.Value, r *rand.Rand) { diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index cfe88e9bf..1e0dcb0fa 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -6,6 +6,7 @@ import ( "math" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" ) // QueryChannelRange is a message sent by a node in order to query the @@ -27,6 +28,10 @@ type QueryChannelRange struct { // channel ID's should be sent for. NumBlocks uint32 + // QueryOptions is an optional feature bit vector that can be used to + // specify additional query options. + QueryOptions *QueryOptions + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -35,7 +40,9 @@ type QueryChannelRange struct { // NewQueryChannelRange creates a new empty QueryChannelRange message. func NewQueryChannelRange() *QueryChannelRange { - return &QueryChannelRange{} + return &QueryChannelRange{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure QueryChannelRange implements the @@ -46,20 +53,42 @@ var _ Message = (*QueryChannelRange)(nil) // passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, - q.ChainHash[:], - &q.FirstBlockHeight, - &q.NumBlocks, - &q.ExtraData, +func (q *QueryChannelRange) Decode(r io.Reader, _ uint32) error { + err := ReadElements( + r, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, ) + if err != nil { + return err + } + + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var queryOptions QueryOptions + typeMap, err := tlvRecords.ExtractRecords(&queryOptions) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[QueryOptionsRecordType]; ok && val == nil { + q.QueryOptions = &queryOptions + } + + if len(tlvRecords) != 0 { + q.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target QueryChannelRange into the passed io.Writer // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { +func (q *QueryChannelRange) Encode(w *bytes.Buffer, _ uint32) error { if err := WriteBytes(w, q.ChainHash[:]); err != nil { return err } @@ -72,6 +101,15 @@ func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { return err } + recordProducers := make([]tlv.RecordProducer, 0, 1) + if q.QueryOptions != nil { + recordProducers = append(recordProducers, q.QueryOptions) + } + err := EncodeMessageExtraData(&q.ExtraData, recordProducers...) + if err != nil { + return err + } + return WriteBytes(w, q.ExtraData) } @@ -93,3 +131,14 @@ func (q *QueryChannelRange) LastBlockHeight() uint32 { } return uint32(lastBlockHeight) } + +// WithTimestamps returns true if the query has asked for timestamps too. +func (q *QueryChannelRange) WithTimestamps() bool { + if q.QueryOptions == nil { + return false + } + + queryOpts := RawFeatureVector(*q.QueryOptions) + + return queryOpts.IsSet(QueryOptionTimestampBit) +} diff --git a/lnwire/query_channel_range_test.go b/lnwire/query_channel_range_test.go new file mode 100644 index 000000000..5d690f38d --- /dev/null +++ b/lnwire/query_channel_range_test.go @@ -0,0 +1,79 @@ +package lnwire + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestQueryChannelRange tests that a few query_channel_range test vectors can +// correctly be decoded and encoded. +func TestQueryChannelRange(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expFirstBlockNum int + expNumOfBlocks int + expWantTimestamps bool + }{ + { + name: "without timestamps query option", + input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" + + "012afca590b1a11466e2206000186a0000005dc", + expFirstBlockNum: 100000, + expNumOfBlocks: 1500, + expWantTimestamps: false, + }, + { + name: "with timestamps query option", + input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" + + "012afca590b1a11466e2206000088b800000064010103", + expFirstBlockNum: 35000, + expNumOfBlocks: 100, + expWantTimestamps: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + b, err := hex.DecodeString(test.input) + require.NoError(t, err) + + r := bytes.NewBuffer(b) + + msg, err := ReadMessage(r, 0) + require.NoError(t, err) + + queryMsg, ok := msg.(*QueryChannelRange) + require.True(t, ok) + + require.EqualValues( + t, test.expFirstBlockNum, + queryMsg.FirstBlockHeight, + ) + + require.EqualValues( + t, test.expNumOfBlocks, queryMsg.NumBlocks, + ) + + require.Equal( + t, test.expWantTimestamps, + queryMsg.WithTimestamps(), + ) + + var buf bytes.Buffer + _, err = WriteMessage(&buf, queryMsg, 0) + require.NoError(t, err) + + require.Equal(t, buf.Bytes(), b) + }) + } +} diff --git a/lnwire/query_options.go b/lnwire/query_options.go new file mode 100644 index 000000000..1f1730ca0 --- /dev/null +++ b/lnwire/query_options.go @@ -0,0 +1,80 @@ +package lnwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // QueryOptionsRecordType is the TLV number of the query_options TLV + // record in the query_channel_range message. + QueryOptionsRecordType tlv.Type = 1 + + // QueryOptionTimestampBit is the bit position in the query_option + // feature bit vector which is used to indicate that timestamps are + // desired in the reply_channel_range response. + QueryOptionTimestampBit = 0 +) + +// QueryOptions is the type used to represent the query_options feature bit +// vector in the query_channel_range message. +type QueryOptions RawFeatureVector + +// NewTimestampQueryOption is a helper constructor used to construct a +// QueryOption with the timestamp bit set. +func NewTimestampQueryOption() *QueryOptions { + opt := QueryOptions(*NewRawFeatureVector( + QueryOptionTimestampBit, + )) + + return &opt +} + +// featureBitLen calculates and returns the size of the resulting feature bit +// vector. +func (c *QueryOptions) featureBitLen() uint64 { + fv := RawFeatureVector(*c) + + return uint64(fv.SerializeSize()) +} + +// Record constructs a tlv.Record from the QueryOptions to be used in the +// query_channel_range message. +func (c *QueryOptions) Record() tlv.Record { + return tlv.MakeDynamicRecord( + QueryOptionsRecordType, c, c.featureBitLen, queryOptionsEncoder, + queryOptionsDecoder, + ) +} + +// queryOptionsEncoder encodes the QueryOptions and writes it to the provided +// writer. +func queryOptionsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*QueryOptions); ok { + // Encode the feature bits as a byte slice without its length + // prepended, as that's already taken care of by the TLV record. + fv := RawFeatureVector(*v) + return fv.encode(w, fv.SerializeSize(), 8) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.QueryOptions") +} + +// queryOptionsDecoder attempts to read a QueryOptions from the given reader. +func queryOptionsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*QueryOptions); ok { + fv := NewRawFeatureVector() + if err := fv.decode(r, int(l), 8); err != nil { + return err + } + + *v = QueryOptions(*fv) + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.QueryOptions") +} diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 323a936db..6a90bed75 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -11,23 +11,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" ) -// ShortChanIDEncoding is an enum-like type that represents exactly how a set -// of short channel ID's is encoded on the wire. The set of encodings allows us -// to take advantage of the structure of a list of short channel ID's to -// achieving a high degree of compression. -type ShortChanIDEncoding uint8 - -const ( - // EncodingSortedPlain signals that the set of short channel ID's is - // encoded using the regular encoding, in a sorted order. - EncodingSortedPlain ShortChanIDEncoding = 0 - - // EncodingSortedZlib signals that the set of short channel ID's is - // encoded by first sorting the set of channel ID's, as then - // compressing them using zlib. - EncodingSortedZlib ShortChanIDEncoding = 1 -) - const ( // maxZlibBufSize is the max number of bytes that we'll accept from a // zlib decoding instance. We do this in order to limit the total @@ -56,7 +39,7 @@ var zlibDecodeMtx sync.Mutex // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we // came across an unknown short channel ID encoding, and therefore were unable // to continue parsing. -func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { +func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error { return fmt.Errorf("unknown short chan id encoding: %v", encoding) } @@ -76,7 +59,7 @@ type QueryShortChanIDs struct { // EncodingType is a signal to the receiver of the message that // indicates exactly how the set of short channel ID's that follow have // been encoded. - EncodingType ShortChanIDEncoding + EncodingType QueryEncoding // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID @@ -94,7 +77,7 @@ type QueryShortChanIDs struct { } // NewQueryShortChanIDs creates a new QueryShortChanIDs message. -func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, +func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, s []ShortChannelID) *QueryShortChanIDs { return &QueryShortChanIDs{ @@ -130,7 +113,7 @@ func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { // encoded. The first byte of the body details how the short chan ID's were // encoded. We'll use this type to govern exactly how we go about encoding the // set of short channel ID's. -func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { +func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) { // First, we'll attempt to read the number of bytes in the body of the // set of encoded short channel ID's. var numBytesResp uint16 @@ -150,7 +133,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // The first byte is the encoding type, so we'll extract that so we can // continue our parsing. - encodingType := ShortChanIDEncoding(queryBody[0]) + encodingType := QueryEncoding(queryBody[0]) // Before continuing, we'll snip off the first byte of the query body // as that was just the encoding type. @@ -297,9 +280,19 @@ func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { return err } + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. The sorting is applied unless + // we were specifically requested not to for testing purposes. + if !q.noSort { + sort.Slice(q.ShortChanIDs, func(i, j int) bool { + return q.ShortChanIDs[i].ToUint64() < + q.ShortChanIDs[j].ToUint64() + }) + } + // Base on our encoding type, we'll write out the set of short channel // ID's. - err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs) if err != nil { return err } @@ -309,18 +302,8 @@ func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { // encodeShortChanIDs encodes the passed short channel ID's into the passed // io.Writer, respecting the specified encoding type. -func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, - shortChanIDs []ShortChannelID, noSort bool) error { - - // For both of the current encoding types, the channel ID's are to be - // sorted in place, so we'll do that now. The sorting is applied unless - // we were specifically requested not to for testing purposes. - if !noSort { - sort.Slice(shortChanIDs, func(i, j int) bool { - return shortChanIDs[i].ToUint64() < - shortChanIDs[j].ToUint64() - }) - } +func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding, + shortChanIDs []ShortChannelID) error { switch encodingType { @@ -337,7 +320,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, // We'll then write out the encoding that that follows the // actual encoded short channel ID's. - err := WriteShortChanIDEncoding(w, encodingType) + err := WriteQueryEncoding(w, encodingType) if err != nil { return err } @@ -421,7 +404,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, if err := WriteUint16(w, uint16(numBytesBody)); err != nil { return err } - err := WriteShortChanIDEncoding(w, encodingType) + err := WriteQueryEncoding(w, encodingType) if err != nil { return err } diff --git a/lnwire/query_short_chan_ids_test.go b/lnwire/query_short_chan_ids_test.go index e42184044..996c9f744 100644 --- a/lnwire/query_short_chan_ids_test.go +++ b/lnwire/query_short_chan_ids_test.go @@ -7,7 +7,7 @@ import ( type unsortedSidTest struct { name string - encType ShortChanIDEncoding + encType QueryEncoding sids []ShortChannelID } @@ -79,7 +79,7 @@ func TestQueryShortChanIDsUnsorted(t *testing.T) { func TestQueryShortChanIDsZero(t *testing.T) { testCases := []struct { name string - encoding ShortChanIDEncoding + encoding QueryEncoding }{ { name: "plain", diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 9dc0fca9c..ea45a5843 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -2,10 +2,13 @@ package lnwire import ( "bytes" + "fmt" "io" "math" + "sort" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" ) // ReplyChannelRange is the response to the QueryChannelRange message. It @@ -33,11 +36,17 @@ type ReplyChannelRange struct { // EncodingType is a signal to the receiver of the message that // indicates exactly how the set of short channel ID's that follow have // been encoded. - EncodingType ShortChanIDEncoding + EncodingType QueryEncoding // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // Timestamps is an optional set of timestamps corresponding to the + // latest timestamps for the channel update messages corresponding to + // those referenced in the ShortChanIDs list. If this field is used, + // then the length must match the length of ShortChanIDs. + Timestamps Timestamps + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -52,7 +61,9 @@ type ReplyChannelRange struct { // NewReplyChannelRange creates a new empty ReplyChannelRange message. func NewReplyChannelRange() *ReplyChannelRange { - return &ReplyChannelRange{} + return &ReplyChannelRange{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure ReplyChannelRange implements the @@ -79,7 +90,27 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { return err } - return c.ExtraData.Decode(r) + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var timeStamps Timestamps + typeMap, err := tlvRecords.ExtractRecords(&timeStamps) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[TimestampsRecordType]; ok && val == nil { + c.Timestamps = timeStamps + } + + if len(tlvRecords) != 0 { + c.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -103,7 +134,64 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { return err } - err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. The sorting is applied unless + // we were specifically requested not to for testing purposes. + if !c.noSort { + var scidPreSortIndex map[uint64]int + if len(c.Timestamps) != 0 { + // Sanity check that a timestamp was provided for each + // SCID. + if len(c.Timestamps) != len(c.ShortChanIDs) { + return fmt.Errorf("must provide a timestamp " + + "pair for each of the given SCIDs") + } + + // Create a map from SCID value to the original index of + // the SCID in the unsorted list. + scidPreSortIndex = make( + map[uint64]int, len(c.ShortChanIDs), + ) + for i, scid := range c.ShortChanIDs { + scidPreSortIndex[scid.ToUint64()] = i + } + + // Sanity check that there were no duplicates in the + // SCID list. + if len(scidPreSortIndex) != len(c.ShortChanIDs) { + return fmt.Errorf("scid list should not " + + "contain duplicates") + } + } + + // Now sort the SCIDs. + sort.Slice(c.ShortChanIDs, func(i, j int) bool { + return c.ShortChanIDs[i].ToUint64() < + c.ShortChanIDs[j].ToUint64() + }) + + if len(c.Timestamps) != 0 { + timestamps := make(Timestamps, len(c.Timestamps)) + + for i, scid := range c.ShortChanIDs { + timestamps[i] = []ChanUpdateTimestamps( + c.Timestamps, + )[scidPreSortIndex[scid.ToUint64()]] + } + c.Timestamps = timestamps + } + } + + err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs) + if err != nil { + return err + } + + recordProducers := make([]tlv.RecordProducer, 0, 1) + if len(c.Timestamps) != 0 { + recordProducers = append(recordProducers, &c.Timestamps) + } + err = EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err } diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index ff3414958..12955cfd9 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -3,10 +3,9 @@ package lnwire import ( "bytes" "encoding/hex" - "reflect" "testing" - "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" ) // TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request @@ -44,7 +43,7 @@ func TestReplyChannelRangeEmpty(t *testing.T) { emptyChannelsTests := []struct { name string - encType ShortChanIDEncoding + encType QueryEncoding encodedHex string }{ { @@ -78,29 +77,288 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // First decode the hex string in the test case into a // new ReplyChannelRange message. It should be // identical to the one created above. - var req2 ReplyChannelRange + req2 := NewReplyChannelRange() b, _ := hex.DecodeString(test.encodedHex) err := req2.Decode(bytes.NewReader(b), 0) - if err != nil { - t.Fatalf("unable to decode req: %v", err) - } - if !reflect.DeepEqual(req, req2) { - t.Fatalf("requests don't match: expected %v got %v", - spew.Sdump(req), spew.Sdump(req2)) - } + require.NoError(t, err) + require.Equal(t, req, *req2) // Next, we go in the reverse direction: encode the // request created above, and assert that it matches // the raw byte encoding. var b2 bytes.Buffer err = req.Encode(&b2, 0) - if err != nil { - t.Fatalf("unable to encode req: %v", err) - } - if !bytes.Equal(b, b2.Bytes()) { - t.Fatalf("encoded requests don't match: expected %x got %x", - b, b2.Bytes()) - } + require.NoError(t, err) + require.Equal(t, b, b2.Bytes()) + }) + } +} + +// TestReplyChannelRangeEncode tests that encoding a ReplyChannelRange message +// results in the correct sorting of the SCIDs and Timestamps. +func TestReplyChannelRangeEncode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scids []ShortChannelID + timestamps Timestamps + expError string + expScids []ShortChannelID + expTimestamps Timestamps + }{ + { + name: "scids only, sorted", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + }, + { + name: "scids only, unsorted", + scids: []ShortChannelID{ + {BlockHeight: 300}, + {BlockHeight: 100}, + {BlockHeight: 200}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + }, + { + name: "scids and timestamps, sorted", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expTimestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + }, + { + name: "scids and timestamps, unsorted", + scids: []ShortChannelID{ + {BlockHeight: 300}, + {BlockHeight: 100}, + {BlockHeight: 200}, + }, + timestamps: Timestamps{ + {Timestamp1: 5, Timestamp2: 6}, + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + }, + expScids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + expTimestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + }, + { + name: "scid and timestamp count does not match", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 300}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + }, + expError: "got must provide a timestamp pair for " + + "each of the given SCIDs", + }, + { + name: "duplicate scids", + scids: []ShortChannelID{ + {BlockHeight: 100}, + {BlockHeight: 200}, + {BlockHeight: 200}, + }, + timestamps: Timestamps{ + {Timestamp1: 1, Timestamp2: 2}, + {Timestamp1: 3, Timestamp2: 4}, + {Timestamp1: 5, Timestamp2: 6}, + }, + expError: "scid list should not contain duplicates", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + replyMsg := &ReplyChannelRange{ + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: EncodingSortedPlain, + ShortChanIDs: test.scids, + Timestamps: test.timestamps, + ExtraData: make([]byte, 0), + } + + var buf bytes.Buffer + _, err := WriteMessage(&buf, replyMsg, 0) + if len(test.expError) != 0 { + require.ErrorContains(t, err, test.expError) + + return + } + + require.NoError(t, err) + + r := bytes.NewBuffer(buf.Bytes()) + msg, err := ReadMessage(r, 0) + require.NoError(t, err) + + msg2, ok := msg.(*ReplyChannelRange) + require.True(t, ok) + + require.Equal(t, test.expScids, msg2.ShortChanIDs) + require.Equal(t, test.expTimestamps, msg2.Timestamps) + }) + } +} + +// TestReplyChannelRangeDecode tests the decoding of some ReplyChannelRange +// test vectors. +func TestReplyChannelRangeDecode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hex string + expEncoding QueryEncoding + expSCIDs []string + expTimestamps Timestamps + expError string + }{ + { + name: "plain encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + + "12afca590b1a11466e2206000b8a06000005dc01001" + + "900000000000000008e0000000000003c6900000000" + + "0045a6c4", + expEncoding: EncodingSortedPlain, + expSCIDs: []string{ + "0:0:142", + "0:0:15465", + "0:69:42692", + }, + }, + { + name: "zlib encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + + "12afca590b1a11466e2206000006400000006e010016" + + "01789c636000833e08659309a65878be010010a9023a", + expEncoding: EncodingSortedZlib, + expSCIDs: []string{ + "0:0:142", + "0:0:15465", + "0:4:3318", + }, + }, + { + name: "plain encoding including timestamps", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb43601" + + "2afca590b1a11466e22060001ddde000005dc0100190" + + "0000000000000304300000000000778d600000000004" + + "6e1c1011900000282c1000e77c5000778ad00490ab00" + + "000b57800955bff031800000457000008ae00000d050" + + "000115c000015b300001a0a", + expEncoding: EncodingSortedPlain, + expSCIDs: []string{ + "0:0:12355", + "0:7:30934", + "0:70:57793", + }, + expTimestamps: Timestamps{ + { + Timestamp1: 164545, + Timestamp2: 948165, + }, + { + Timestamp1: 489645, + Timestamp2: 4786864, + }, + { + Timestamp1: 46456, + Timestamp2: 9788415, + }, + }, + }, + { + name: "unsupported encoding", + hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb" + + "436012afca590b1a11466e22060001ddde000005dc01" + + "001801789c63600001036730c55e710d4cbb3d3c0800" + + "17c303b1012201789c63606a3ac8c0577e9481bd622d" + + "8327d7060686ad150c53a3ff0300554707db03180000" + + "0457000008ae00000d050000115c000015b300001a0a", + expError: "unsupported encoding", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + b, err := hex.DecodeString(test.hex) + require.NoError(t, err) + + r := bytes.NewBuffer(b) + + msg, err := ReadMessage(r, 0) + if len(test.expError) != 0 { + require.ErrorContains(t, err, test.expError) + + return + } + require.NoError(t, err) + + replyMsg, ok := msg.(*ReplyChannelRange) + require.True(t, ok) + require.Equal( + t, test.expEncoding, replyMsg.EncodingType, + ) + + scids := make([]string, len(replyMsg.ShortChanIDs)) + for i, id := range replyMsg.ShortChanIDs { + scids[i] = id.String() + } + require.Equal(t, scids, test.expSCIDs) + + require.Equal( + t, test.expTimestamps, replyMsg.Timestamps, + ) }) } } diff --git a/lnwire/timestamps.go b/lnwire/timestamps.go new file mode 100644 index 000000000..1d0d8c6a4 --- /dev/null +++ b/lnwire/timestamps.go @@ -0,0 +1,123 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // TimestampsRecordType is the TLV number of the timestamps TLV record + // in the reply_channel_range message. + TimestampsRecordType tlv.Type = 1 + + // timestampPairSize is the number of bytes required to encode two + // timestamps. Each timestamp is four bytes. + timestampPairSize = 8 +) + +// Timestamps is a type representing the timestamps TLV field used in the +// reply_channel_range message to communicate the timestamps info of the updates +// of the SCID list being communicated. +type Timestamps []ChanUpdateTimestamps + +// ChanUpdateTimestamps holds the timestamp info of the latest known channel +// updates corresponding to the two sides of a channel. +type ChanUpdateTimestamps struct { + Timestamp1 uint32 + Timestamp2 uint32 +} + +// Record constructs the tlv.Record from the Timestamps. +func (t *Timestamps) Record() tlv.Record { + return tlv.MakeDynamicRecord( + TimestampsRecordType, t, t.encodedLen, timeStampsEncoder, + timeStampsDecoder, + ) +} + +// encodedLen calculates the length of the encoded Timestamps. +func (t *Timestamps) encodedLen() uint64 { + return uint64(1 + timestampPairSize*(len(*t))) +} + +// timeStampsEncoder encodes the Timestamps and writes the encoded bytes to the +// given writer. +func timeStampsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*Timestamps); ok { + var buf bytes.Buffer + + // Add the encoding byte. + err := WriteQueryEncoding(&buf, EncodingSortedPlain) + if err != nil { + return err + } + + // For each timestamp, write 4 byte timestamp of node 1 and the + // 4 byte timestamp of node 2. + for _, timestamps := range *v { + err = WriteUint32(&buf, timestamps.Timestamp1) + if err != nil { + return err + } + + err = WriteUint32(&buf, timestamps.Timestamp2) + if err != nil { + return err + } + } + + _, err = w.Write(buf.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.Timestamps") +} + +// timeStampsDecoder attempts to read and reconstruct a Timestamps object from +// the given reader. +func timeStampsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*Timestamps); ok { + var encodingByte [1]byte + if _, err := r.Read(encodingByte[:]); err != nil { + return err + } + + encoding := QueryEncoding(encodingByte[0]) + if encoding != EncodingSortedPlain { + return fmt.Errorf("unsupported encoding: %x", encoding) + } + + // The number of timestamps bytes is equal to the passed length + // minus one since the first byte is used for the encoding type. + numTimestampBytes := l - 1 + + if numTimestampBytes%timestampPairSize != 0 { + return fmt.Errorf("whole number of timestamps not " + + "encoded") + } + + numTimestamps := int(numTimestampBytes) / timestampPairSize + timestamps := make(Timestamps, numTimestamps) + for i := 0; i < numTimestamps; i++ { + err := ReadElements( + r, ×tamps[i].Timestamp1, + ×tamps[i].Timestamp2, + ) + if err != nil { + return err + } + } + + *v = timestamps + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.Timestamps") +} diff --git a/lnwire/writer.go b/lnwire/writer.go index 671ebfdc0..fa6247de0 100644 --- a/lnwire/writer.go +++ b/lnwire/writer.go @@ -205,9 +205,8 @@ func WriteColorRGBA(buf *bytes.Buffer, e color.RGBA) error { return WriteUint8(buf, e.B) } -// WriteShortChanIDEncoding appends the ShortChanIDEncoding to the provided -// buffer. -func WriteShortChanIDEncoding(buf *bytes.Buffer, e ShortChanIDEncoding) error { +// WriteQueryEncoding appends the QueryEncoding to the provided buffer. +func WriteQueryEncoding(buf *bytes.Buffer, e QueryEncoding) error { return WriteUint8(buf, uint8(e)) } diff --git a/lnwire/writer_test.go b/lnwire/writer_test.go index ccdeabcf6..3e2550443 100644 --- a/lnwire/writer_test.go +++ b/lnwire/writer_test.go @@ -225,10 +225,10 @@ func TestWriteColorRGBA(t *testing.T) { func TestWriteShortChanIDEncoding(t *testing.T) { buf := new(bytes.Buffer) - data := ShortChanIDEncoding(1) + data := QueryEncoding(1) expectedBytes := []byte{1} - err := WriteShortChanIDEncoding(buf, data) + err := WriteQueryEncoding(buf, data) require.NoError(t, err) require.Equal(t, expectedBytes, buf.Bytes()) diff --git a/routing/router.go b/routing/router.go index 48aeead18..80f647807 100644 --- a/routing/router.go +++ b/routing/router.go @@ -888,6 +888,55 @@ func (r *ChannelRouter) syncGraphWithChain() error { return nil } +// isZombieChannel takes two edge policy updates and determines if the +// corresponding channel should be considered a zombie. The first boolean is +// true if the policy update from node 1 is considered a zombie, the second +// boolean is that of node 2, and the final boolean is true if the channel +// is considered a zombie. +func (r *ChannelRouter) isZombieChannel(e1, + e2 *models.ChannelEdgePolicy) (bool, bool, bool) { + + chanExpiry := r.cfg.ChannelPruneExpiry + + e1Zombie := e1 == nil || time.Since(e1.LastUpdate) >= chanExpiry + e2Zombie := e2 == nil || time.Since(e2.LastUpdate) >= chanExpiry + + var e1Time, e2Time time.Time + if e1 != nil { + e1Time = e1.LastUpdate + } + if e2 != nil { + e2Time = e2.LastUpdate + } + + return e1Zombie, e2Zombie, r.IsZombieChannel(e1Time, e2Time) +} + +// IsZombieChannel takes the timestamps of the latest channel updates for a +// channel and returns true if the channel should be considered a zombie based +// on these timestamps. +func (r *ChannelRouter) IsZombieChannel(updateTime1, + updateTime2 time.Time) bool { + + chanExpiry := r.cfg.ChannelPruneExpiry + + e1Zombie := updateTime1.IsZero() || + time.Since(updateTime1) >= chanExpiry + + e2Zombie := updateTime2.IsZero() || + time.Since(updateTime2) >= chanExpiry + + // If we're using strict zombie pruning, then a channel is only + // considered live if both edges have a recent update we know of. + if r.cfg.StrictZombiePruning { + return e1Zombie || e2Zombie + } + + // Otherwise, if we're using the less strict variant, then a channel is + // considered live if either of the edges have a recent update. + return e1Zombie && e2Zombie +} + // pruneZombieChans is a method that will be called periodically to prune out // any "zombie" channels. We consider channels zombies if *both* edges haven't // been updated since our zombie horizon. If AssumeChannelValid is present, @@ -911,8 +960,10 @@ func (r *ChannelRouter) pruneZombieChans() error { filterPruneChans := func(info *models.ChannelEdgeInfo, e1, e2 *models.ChannelEdgePolicy) error { - // Exit early in case this channel is already marked to be pruned - if _, markedToPrune := chansToPrune[info.ChannelID]; markedToPrune { + // Exit early in case this channel is already marked to be + // pruned + _, markedToPrune := chansToPrune[info.ChannelID] + if markedToPrune { return nil } @@ -923,39 +974,22 @@ func (r *ChannelRouter) pruneZombieChans() error { return nil } - // If either edge hasn't been updated for a period of - // chanExpiry, then we'll mark the channel itself as eligible - // for graph pruning. - e1Zombie := e1 == nil || time.Since(e1.LastUpdate) >= chanExpiry - e2Zombie := e2 == nil || time.Since(e2.LastUpdate) >= chanExpiry + e1Zombie, e2Zombie, isZombieChan := r.isZombieChannel(e1, e2) if e1Zombie { log.Tracef("Node1 pubkey=%x of chan_id=%v is zombie", info.NodeKey1Bytes, info.ChannelID) } + if e2Zombie { log.Tracef("Node2 pubkey=%x of chan_id=%v is zombie", info.NodeKey2Bytes, info.ChannelID) } - // If we're using strict zombie pruning, then a channel is only - // considered live if both edges have a recent update we know - // of. - var channelIsLive bool - switch { - case r.cfg.StrictZombiePruning: - channelIsLive = !e1Zombie && !e2Zombie - - // Otherwise, if we're using the less strict variant, then a - // channel is considered live if either of the edges have a - // recent update. - default: - channelIsLive = !e1Zombie || !e2Zombie - } - - // Return early if the channel is still considered to be live - // with the current set of configuration parameters. - if channelIsLive { + // If either edge hasn't been updated for a period of + // chanExpiry, then we'll mark the channel itself as eligible + // for graph pruning. + if !isZombieChan { return nil } diff --git a/sample-lnd.conf b/sample-lnd.conf index f13098723..2e7f79376 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -1272,6 +1272,9 @@ ; closing. ; protocol.no-any-segwit=false +; Set to disable querying our peers for the timestamps of announcement +; messages and to disable responding to such queries +; protocol.no-timestamp-query-option=false ; Set to enable support for the experimental taproot channel type. ; protocol.simple-taproot-chans=false diff --git a/server.go b/server.go index fd53ee5e6..aed4d3c85 100644 --- a/server.go +++ b/server.go @@ -1018,6 +1018,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, RotateTicker: ticker.New(discovery.DefaultSyncerRotationInterval), HistoricalSyncTicker: ticker.New(cfg.HistoricalSyncInterval), NumActiveSyncers: cfg.NumGraphSyncPeers, + NoTimestampQueries: cfg.ProtocolOptions.NoTimestampQueryOption, //nolint:lll MinimumBatchSize: 10, SubBatchDelay: cfg.Gossip.SubBatchDelay, IgnoreHistoricalFilters: cfg.IgnoreHistoricalGossipFilters, @@ -1029,6 +1030,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, FindBaseByAlias: s.aliasMgr.FindBaseSCID, GetAlias: s.aliasMgr.GetPeerAlias, FindChannel: s.findChannel, + IsStillZombieChannel: s.chanRouter.IsZombieChannel, }, nodeKeyDesc) s.localChanMgr = &localchans.Manager{