mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-04 17:55:36 +01:00
discovery: let mockGraphSource track call count
We will use this in an upcoming test.
This commit is contained in:
parent
f744a5477f
commit
78413f5b84
1 changed files with 87 additions and 23 deletions
|
@ -77,6 +77,7 @@ var (
|
|||
)
|
||||
|
||||
type mockGraphSource struct {
|
||||
t *testing.T
|
||||
bestHeight uint32
|
||||
|
||||
mu sync.Mutex
|
||||
|
@ -85,25 +86,42 @@ type mockGraphSource struct {
|
|||
edges map[uint64][]models.ChannelEdgePolicy
|
||||
zombies map[uint64][][33]byte
|
||||
chansToReject map[uint64]struct{}
|
||||
|
||||
callCount map[string]int
|
||||
}
|
||||
|
||||
func newMockRouter(height uint32) *mockGraphSource {
|
||||
func newMockRouter(t *testing.T, height uint32) *mockGraphSource {
|
||||
return &mockGraphSource{
|
||||
t: t,
|
||||
bestHeight: height,
|
||||
infos: make(map[uint64]models.ChannelEdgeInfo),
|
||||
edges: make(map[uint64][]models.ChannelEdgePolicy),
|
||||
edges: make(
|
||||
map[uint64][]models.ChannelEdgePolicy,
|
||||
),
|
||||
zombies: make(map[uint64][][33]byte),
|
||||
chansToReject: make(map[uint64]struct{}),
|
||||
callCount: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// getCallCount returns the number of times the given method has been called.
|
||||
func (r *mockGraphSource) getCallCount(method string) int {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
return r.callCount[method]
|
||||
}
|
||||
|
||||
var _ graph.ChannelGraphSource = (*mockGraphSource)(nil)
|
||||
|
||||
func (r *mockGraphSource) AddNode(node *models.LightningNode,
|
||||
_ ...batch.SchedulerOption) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["AddNode"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
r.nodes = append(r.nodes, *node)
|
||||
return nil
|
||||
|
@ -119,7 +137,10 @@ func (r *mockGraphSource) IsZombieEdge(chanID lnwire.ShortChannelID) (bool,
|
|||
error) {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["IsZombieEdge"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
_, ok := r.zombies[chanID.ToUint64()]
|
||||
|
||||
|
@ -130,7 +151,10 @@ func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo,
|
|||
_ ...batch.SchedulerOption) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["AddEdge"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
if _, ok := r.infos[info.ChannelID]; ok {
|
||||
return errors.New("info already exist")
|
||||
|
@ -155,7 +179,10 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
|
|||
_ ...batch.SchedulerOption) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["UpdateEdge"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
if len(r.edges[edge.ChannelID]) == 0 {
|
||||
r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2)
|
||||
|
@ -171,6 +198,12 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
|
|||
}
|
||||
|
||||
func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
|
||||
r.mu.Lock()
|
||||
defer func() {
|
||||
r.callCount["CurrentBlock"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
return r.bestHeight, nil
|
||||
}
|
||||
|
||||
|
@ -178,7 +211,10 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
|
|||
proof *models.ChannelAuthProof) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["AddProof"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
chanIDInt := chanID.ToUint64()
|
||||
info, ok := r.infos[chanIDInt]
|
||||
|
@ -192,17 +228,14 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *mockGraphSource) ForEachNode(
|
||||
func(node *models.LightningNode) error) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockGraphSource) ForAllOutgoingChannels(cb func(
|
||||
i *models.ChannelEdgeInfo, c *models.ChannelEdgePolicy) error) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["ForAllOutgoingChannels"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
chans := make(map[uint64]graphdb.ChannelEdge)
|
||||
for _, info := range r.infos {
|
||||
|
@ -235,7 +268,10 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
|||
*models.ChannelEdgePolicy, error) {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["GetChannelByID"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
chanIDInt := chanID.ToUint64()
|
||||
chanInfo, ok := r.infos[chanIDInt]
|
||||
|
@ -272,6 +308,12 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
|||
func (r *mockGraphSource) FetchLightningNode(
|
||||
nodePub route.Vertex) (*models.LightningNode, error) {
|
||||
|
||||
r.mu.Lock()
|
||||
defer func() {
|
||||
r.callCount["FetchLightningNode"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
|
||||
return &node, nil
|
||||
|
@ -285,7 +327,10 @@ func (r *mockGraphSource) FetchLightningNode(
|
|||
// target node with a more recent timestamp.
|
||||
func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["IsStaleNode"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if node.PubKeyBytes == nodePub {
|
||||
|
@ -311,6 +356,12 @@ func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time)
|
|||
// IsPublicNode determines whether the given vertex is seen as a public node in
|
||||
// the graph from the graph's source node's point of view.
|
||||
func (r *mockGraphSource) IsPublicNode(node route.Vertex) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer func() {
|
||||
r.callCount["IsPublicNode"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
for _, info := range r.infos {
|
||||
if !bytes.Equal(node[:], info.NodeKey1Bytes[:]) &&
|
||||
!bytes.Equal(node[:], info.NodeKey2Bytes[:]) {
|
||||
|
@ -328,7 +379,10 @@ func (r *mockGraphSource) IsPublicNode(node route.Vertex) (bool, error) {
|
|||
// channel ID either as a live or zombie channel.
|
||||
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["IsKnownEdge"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
chanIDInt := chanID.ToUint64()
|
||||
_, exists := r.infos[chanIDInt]
|
||||
|
@ -342,7 +396,10 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
|
|||
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["IsStaleEdgePolicy"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
chanIDInt := chanID.ToUint64()
|
||||
edges, ok := r.edges[chanIDInt]
|
||||
|
@ -381,7 +438,11 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
|
|||
// NOTE: This method is part of the ChannelGraphSource interface.
|
||||
func (r *mockGraphSource) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["MarkEdgeLive"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
delete(r.zombies, chanID.ToUint64())
|
||||
return nil
|
||||
}
|
||||
|
@ -391,7 +452,10 @@ func (r *mockGraphSource) MarkEdgeZombie(chanID lnwire.ShortChannelID, pubKey1,
|
|||
pubKey2 [33]byte) error {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
defer func() {
|
||||
r.callCount["MarkEdgeZombie"]++
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
r.zombies[chanID.ToUint64()] = [][33]byte{pubKey1, pubKey2}
|
||||
|
||||
|
@ -874,7 +938,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) (
|
|||
// any p2p functionality, the peer send and switch send,
|
||||
// broadcast functions won't be populated.
|
||||
notifier := newMockNotifier()
|
||||
router := newMockRouter(startHeight)
|
||||
router := newMockRouter(t, startHeight)
|
||||
chain := &lnmock.MockChain{}
|
||||
t.Cleanup(func() {
|
||||
chain.AssertExpectations(t)
|
||||
|
|
Loading…
Add table
Reference in a new issue