discovery: let mockGraphSource track call count

We will use this in an upcoming test.
This commit is contained in:
Elle Mouton 2025-03-03 14:00:46 +02:00
parent f744a5477f
commit 78413f5b84
No known key found for this signature in database
GPG key ID: D7D916376026F177

View file

@ -77,6 +77,7 @@ var (
)
type mockGraphSource struct {
t *testing.T
bestHeight uint32
mu sync.Mutex
@ -85,25 +86,42 @@ type mockGraphSource struct {
edges map[uint64][]models.ChannelEdgePolicy
zombies map[uint64][][33]byte
chansToReject map[uint64]struct{}
callCount map[string]int
}
func newMockRouter(height uint32) *mockGraphSource {
func newMockRouter(t *testing.T, height uint32) *mockGraphSource {
return &mockGraphSource{
t: t,
bestHeight: height,
infos: make(map[uint64]models.ChannelEdgeInfo),
edges: make(map[uint64][]models.ChannelEdgePolicy),
edges: make(
map[uint64][]models.ChannelEdgePolicy,
),
zombies: make(map[uint64][][33]byte),
chansToReject: make(map[uint64]struct{}),
callCount: make(map[string]int),
}
}
// getCallCount returns the number of times the given method has been called.
func (r *mockGraphSource) getCallCount(method string) int {
r.mu.Lock()
defer r.mu.Unlock()
return r.callCount[method]
}
var _ graph.ChannelGraphSource = (*mockGraphSource)(nil)
func (r *mockGraphSource) AddNode(node *models.LightningNode,
_ ...batch.SchedulerOption) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["AddNode"]++
r.mu.Unlock()
}()
r.nodes = append(r.nodes, *node)
return nil
@ -119,7 +137,10 @@ func (r *mockGraphSource) IsZombieEdge(chanID lnwire.ShortChannelID) (bool,
error) {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["IsZombieEdge"]++
r.mu.Unlock()
}()
_, ok := r.zombies[chanID.ToUint64()]
@ -130,7 +151,10 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo,
_ ...batch.SchedulerOption) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["AddEdge"]++
r.mu.Unlock()
}()
if _, ok := r.infos[info.ChannelID]; ok {
return errors.New("info already exist")
@ -155,7 +179,10 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
_ ...batch.SchedulerOption) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["UpdateEdge"]++
r.mu.Unlock()
}()
if len(r.edges[edge.ChannelID]) == 0 {
r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2)
@ -171,6 +198,12 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
}
func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
r.mu.Lock()
defer func() {
r.callCount["CurrentBlock"]++
r.mu.Unlock()
}()
return r.bestHeight, nil
}
@ -178,7 +211,10 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
proof *models.ChannelAuthProof) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["AddProof"]++
r.mu.Unlock()
}()
chanIDInt := chanID.ToUint64()
info, ok := r.infos[chanIDInt]
@ -192,17 +228,14 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
return nil
}
func (r *mockGraphSource) ForEachNode(
func(node *models.LightningNode) error) error {
return nil
}
func (r *mockGraphSource) ForAllOutgoingChannels(cb func(
i *models.ChannelEdgeInfo, c *models.ChannelEdgePolicy) error) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["ForAllOutgoingChannels"]++
r.mu.Unlock()
}()
chans := make(map[uint64]graphdb.ChannelEdge)
for _, info := range r.infos {
@ -235,7 +268,10 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
*models.ChannelEdgePolicy, error) {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["GetChannelByID"]++
r.mu.Unlock()
}()
chanIDInt := chanID.ToUint64()
chanInfo, ok := r.infos[chanIDInt]
@ -272,6 +308,12 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
func (r *mockGraphSource) FetchLightningNode(
nodePub route.Vertex) (*models.LightningNode, error) {
r.mu.Lock()
defer func() {
r.callCount["FetchLightningNode"]++
r.mu.Unlock()
}()
for _, node := range r.nodes {
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
return &node, nil
@ -285,7 +327,10 @@ func (r *mockGraphSource) FetchLightningNode(
// target node with a more recent timestamp.
func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time) bool {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["IsStaleNode"]++
r.mu.Unlock()
}()
for _, node := range r.nodes {
if node.PubKeyBytes == nodePub {
@ -311,6 +356,12 @@ func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time)
// IsPublicNode determines whether the given vertex is seen as a public node in
// the graph from the graph's source node's point of view.
func (r *mockGraphSource) IsPublicNode(node route.Vertex) (bool, error) {
r.mu.Lock()
defer func() {
r.callCount["IsPublicNode"]++
r.mu.Unlock()
}()
for _, info := range r.infos {
if !bytes.Equal(node[:], info.NodeKey1Bytes[:]) &&
!bytes.Equal(node[:], info.NodeKey2Bytes[:]) {
@ -328,7 +379,10 @@ func (r *mockGraphSource) IsPublicNode(node route.Vertex) (bool, error) {
// channel ID either as a live or zombie channel.
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["IsKnownEdge"]++
r.mu.Unlock()
}()
chanIDInt := chanID.ToUint64()
_, exists := r.infos[chanIDInt]
@ -342,7 +396,10 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["IsStaleEdgePolicy"]++
r.mu.Unlock()
}()
chanIDInt := chanID.ToUint64()
edges, ok := r.edges[chanIDInt]
@ -381,7 +438,11 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
// NOTE: This method is part of the ChannelGraphSource interface.
func (r *mockGraphSource) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["MarkEdgeLive"]++
r.mu.Unlock()
}()
delete(r.zombies, chanID.ToUint64())
return nil
}
@ -391,7 +452,10 @@ func (r *mockGraphSource) MarkEdgeZombie(chanID lnwire.ShortChannelID, pubKey1,
pubKey2 [33]byte) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.callCount["MarkEdgeZombie"]++
r.mu.Unlock()
}()
r.zombies[chanID.ToUint64()] = [][33]byte{pubKey1, pubKey2}
@ -874,7 +938,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) (
// any p2p functionality, the peer send and switch send,
// broadcast functions won't be populated.
notifier := newMockNotifier()
router := newMockRouter(startHeight)
router := newMockRouter(t, startHeight)
chain := &lnmock.MockChain{}
t.Cleanup(func() {
chain.AssertExpectations(t)