mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
discovery/gossiper_test: prevent race conditions within mockGraphSource
This commit is contained in:
parent
73b4bc4b68
commit
6e556aa897
@ -108,28 +108,36 @@ func (n *mockSigner) SignMessage(pubKey *btcec.PublicKey,
|
||||
}
|
||||
|
||||
type mockGraphSource struct {
|
||||
nodes []*channeldb.LightningNode
|
||||
infos map[uint64]*channeldb.ChannelEdgeInfo
|
||||
edges map[uint64][]*channeldb.ChannelEdgePolicy
|
||||
bestHeight uint32
|
||||
|
||||
mu sync.Mutex
|
||||
nodes []channeldb.LightningNode
|
||||
infos map[uint64]channeldb.ChannelEdgeInfo
|
||||
edges map[uint64][]channeldb.ChannelEdgePolicy
|
||||
}
|
||||
|
||||
func newMockRouter(height uint32) *mockGraphSource {
|
||||
return &mockGraphSource{
|
||||
bestHeight: height,
|
||||
infos: make(map[uint64]*channeldb.ChannelEdgeInfo),
|
||||
edges: make(map[uint64][]*channeldb.ChannelEdgePolicy),
|
||||
infos: make(map[uint64]channeldb.ChannelEdgeInfo),
|
||||
edges: make(map[uint64][]channeldb.ChannelEdgePolicy),
|
||||
}
|
||||
}
|
||||
|
||||
var _ routing.ChannelGraphSource = (*mockGraphSource)(nil)
|
||||
|
||||
func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error {
|
||||
r.nodes = append(r.nodes, node)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.nodes = append(r.nodes, *node)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, ok := r.infos[info.ChannelID]; ok {
|
||||
return errors.New("info already exist")
|
||||
}
|
||||
@ -137,15 +145,15 @@ func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
|
||||
// Usually, the capacity is fetched in the router from the funding txout.
|
||||
// Since the mockGraphSource can't access the txout, assign a default value.
|
||||
info.Capacity = maxBtcFundingAmount
|
||||
r.infos[info.ChannelID] = info
|
||||
r.infos[info.ChannelID] = *info
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error {
|
||||
r.edges[edge.ChannelID] = append(
|
||||
r.edges[edge.ChannelID],
|
||||
edge,
|
||||
)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -159,11 +167,19 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
|
||||
|
||||
func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
|
||||
proof *channeldb.ChannelAuthProof) error {
|
||||
info, ok := r.infos[chanID.ToUint64()]
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
chanIDInt := chanID.ToUint64()
|
||||
info, ok := r.infos[chanIDInt]
|
||||
if !ok {
|
||||
return errors.New("channel does not exist")
|
||||
}
|
||||
|
||||
info.AuthProof = proof
|
||||
r.infos[chanIDInt] = info
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -186,6 +202,9 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
||||
*channeldb.ChannelEdgePolicy,
|
||||
*channeldb.ChannelEdgePolicy, error) {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
chanInfo, ok := r.infos[chanID.ToUint64()]
|
||||
if !ok {
|
||||
return nil, nil, nil, channeldb.ErrEdgeNotFound
|
||||
@ -193,14 +212,16 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
||||
|
||||
edges := r.edges[chanID.ToUint64()]
|
||||
if len(edges) == 0 {
|
||||
return chanInfo, nil, nil, nil
|
||||
return &chanInfo, nil, nil, nil
|
||||
}
|
||||
|
||||
if len(edges) == 1 {
|
||||
return chanInfo, edges[0], nil, nil
|
||||
edge1 := edges[0]
|
||||
return &chanInfo, &edge1, nil, nil
|
||||
}
|
||||
|
||||
return chanInfo, edges[0], edges[1], nil
|
||||
edge1, edge2 := edges[0], edges[1]
|
||||
return &chanInfo, &edge1, &edge2, nil
|
||||
}
|
||||
|
||||
func (r *mockGraphSource) FetchLightningNode(
|
||||
@ -208,7 +229,7 @@ func (r *mockGraphSource) FetchLightningNode(
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
|
||||
return node, nil
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -218,6 +239,9 @@ func (r *mockGraphSource) FetchLightningNode(
|
||||
// IsStaleNode returns true if the graph source has a node announcement for the
|
||||
// target node with a more recent timestamp.
|
||||
func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if node.PubKeyBytes == nodePub {
|
||||
return node.LastUpdate.After(timestamp) ||
|
||||
@ -258,6 +282,9 @@ func (r *mockGraphSource) IsPublicNode(node routing.Vertex) (bool, error) {
|
||||
// IsKnownEdge returns true if the graph source already knows of the passed
|
||||
// channel ID.
|
||||
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
_, ok := r.infos[chanID.ToUint64()]
|
||||
return ok
|
||||
}
|
||||
@ -267,6 +294,9 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
||||
func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
|
||||
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
edges, ok := r.edges[chanID.ToUint64()]
|
||||
if !ok {
|
||||
return false
|
||||
|
Loading…
Reference in New Issue
Block a user