discovery/gossiper_test: prevent race conditions within mockGraphSource

This commit is contained in:
Wilmer Paulino 2019-02-05 17:18:41 -08:00
parent 73b4bc4b68
commit 6e556aa897
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F

View File

@ -108,28 +108,36 @@ func (n *mockSigner) SignMessage(pubKey *btcec.PublicKey,
}
type mockGraphSource struct {
nodes []*channeldb.LightningNode
infos map[uint64]*channeldb.ChannelEdgeInfo
edges map[uint64][]*channeldb.ChannelEdgePolicy
bestHeight uint32
mu sync.Mutex
nodes []channeldb.LightningNode
infos map[uint64]channeldb.ChannelEdgeInfo
edges map[uint64][]channeldb.ChannelEdgePolicy
}
func newMockRouter(height uint32) *mockGraphSource {
return &mockGraphSource{
bestHeight: height,
infos: make(map[uint64]*channeldb.ChannelEdgeInfo),
edges: make(map[uint64][]*channeldb.ChannelEdgePolicy),
infos: make(map[uint64]channeldb.ChannelEdgeInfo),
edges: make(map[uint64][]channeldb.ChannelEdgePolicy),
}
}
var _ routing.ChannelGraphSource = (*mockGraphSource)(nil)
func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error {
r.nodes = append(r.nodes, node)
r.mu.Lock()
defer r.mu.Unlock()
r.nodes = append(r.nodes, *node)
return nil
}
func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.infos[info.ChannelID]; ok {
return errors.New("info already exist")
}
@ -137,15 +145,15 @@ func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
// Usually, the capacity is fetched in the router from the funding txout.
// Since the mockGraphSource can't access the txout, assign a default value.
info.Capacity = maxBtcFundingAmount
r.infos[info.ChannelID] = info
r.infos[info.ChannelID] = *info
return nil
}
func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error {
r.edges[edge.ChannelID] = append(
r.edges[edge.ChannelID],
edge,
)
r.mu.Lock()
defer r.mu.Unlock()
r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge)
return nil
}
@ -159,11 +167,19 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
proof *channeldb.ChannelAuthProof) error {
info, ok := r.infos[chanID.ToUint64()]
r.mu.Lock()
defer r.mu.Unlock()
chanIDInt := chanID.ToUint64()
info, ok := r.infos[chanIDInt]
if !ok {
return errors.New("channel does not exist")
}
info.AuthProof = proof
r.infos[chanIDInt] = info
return nil
}
@ -186,6 +202,9 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
*channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy, error) {
r.mu.Lock()
defer r.mu.Unlock()
chanInfo, ok := r.infos[chanID.ToUint64()]
if !ok {
return nil, nil, nil, channeldb.ErrEdgeNotFound
@ -193,14 +212,16 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
edges := r.edges[chanID.ToUint64()]
if len(edges) == 0 {
return chanInfo, nil, nil, nil
return &chanInfo, nil, nil, nil
}
if len(edges) == 1 {
return chanInfo, edges[0], nil, nil
edge1 := edges[0]
return &chanInfo, &edge1, nil, nil
}
return chanInfo, edges[0], edges[1], nil
edge1, edge2 := edges[0], edges[1]
return &chanInfo, &edge1, &edge2, nil
}
func (r *mockGraphSource) FetchLightningNode(
@ -208,7 +229,7 @@ func (r *mockGraphSource) FetchLightningNode(
for _, node := range r.nodes {
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
return node, nil
return &node, nil
}
}
@ -218,6 +239,9 @@ func (r *mockGraphSource) FetchLightningNode(
// IsStaleNode returns true if the graph source has a node announcement for the
// target node with a more recent timestamp.
func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool {
r.mu.Lock()
defer r.mu.Unlock()
for _, node := range r.nodes {
if node.PubKeyBytes == nodePub {
return node.LastUpdate.After(timestamp) ||
@ -258,6 +282,9 @@ func (r *mockGraphSource) IsPublicNode(node routing.Vertex) (bool, error) {
// IsKnownEdge returns true if the graph source already knows of the passed
// channel ID.
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
r.mu.Lock()
defer r.mu.Unlock()
_, ok := r.infos[chanID.ToUint64()]
return ok
}
@ -267,6 +294,9 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
r.mu.Lock()
defer r.mu.Unlock()
edges, ok := r.edges[chanID.ToUint64()]
if !ok {
return false