diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 8b8a3917e..212246ccf 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -155,7 +155,16 @@ func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error { r.mu.Lock() defer r.mu.Unlock() - r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge) + if len(r.edges[edge.ChannelID]) == 0 { + r.edges[edge.ChannelID] = make([]channeldb.ChannelEdgePolicy, 2) + } + + if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + r.edges[edge.ChannelID][0] = *edge + } else { + r.edges[edge.ChannelID][1] = *edge + } + return nil } @@ -226,13 +235,17 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return &chanInfo, nil, nil, nil } - if len(edges) == 1 { - edge1 := edges[0] - return &chanInfo, &edge1, nil, nil + var edge1 *channeldb.ChannelEdgePolicy + if !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy{}) { + edge1 = &edges[0] } - edge1, edge2 := edges[0], edges[1] - return &chanInfo, &edge1, &edge2, nil + var edge2 *channeldb.ChannelEdgePolicy + if !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy{}) { + edge2 = &edges[1] + } + + return &chanInfo, edge1, edge2, nil } func (r *mockGraphSource) FetchLightningNode( @@ -327,11 +340,15 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, } switch { - case len(edges) >= 1 && edges[0].ChannelFlags == flags: - return !edges[0].LastUpdate.Before(timestamp) + case flags&lnwire.ChanUpdateDirection == 0 && + !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy{}): - case len(edges) >= 2 && edges[1].ChannelFlags == flags: - return !edges[1].LastUpdate.Before(timestamp) + return !timestamp.After(edges[0].LastUpdate) + + case flags&lnwire.ChanUpdateDirection == 1 && + !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy{}): + + return !timestamp.After(edges[1].LastUpdate) default: return false