mirror of https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 09:53:54 +01:00

Merge pull request #8400 from ellemouton/deadlockFix

channeldb: acquire mutexes in the same order throughout

This commit is contained in commit 41c167d37c.
@@ -176,6 +176,9 @@ const (
 type ChannelGraph struct {
 	db kvdb.Backend
 
+	// cacheMu guards all caches (rejectCache, chanCache, graphCache). If
+	// this mutex will be acquired at the same time as the DB mutex then
+	// the cacheMu MUST be acquired first to prevent deadlock.
 	cacheMu     sync.RWMutex
 	rejectCache *rejectCache
 	chanCache   *channelCache
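The comment added above states the ordering rule that the rest of this change enforces: whenever a code path needs both the cache mutex and the database lock, cacheMu must be taken first. A minimal sketch of that pattern follows; the method name is hypothetical and not part of this diff, and it assumes that the kvdb transaction holds the DB lock for the duration of the closure:

```go
// someCacheAndDBUpdate is a hypothetical illustration of the lock-ordering
// rule: take cacheMu first, then enter the DB transaction.
func (c *ChannelGraph) someCacheAndDBUpdate() error {
	c.cacheMu.Lock()
	defer c.cacheMu.Unlock()

	// Only after cacheMu is held do we take the DB lock via the
	// transaction. Acquiring the two locks in the opposite order on
	// another code path is what makes a deadlock possible.
	return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
		// ...mutate buckets and the in-memory caches here...
		return nil
	}, func() {})
}
```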
@@ -1331,8 +1334,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
 			// will be returned if that outpoint isn't known to be
 			// a channel. If no error is returned, then a channel
 			// was successfully pruned.
-			err = c.delChannelEdge(
-				edges, edgeIndex, chanIndex, zombieIndex, nodes,
+			err = c.delChannelEdgeUnsafe(
+				edges, edgeIndex, chanIndex, zombieIndex,
 				chanID, false, false,
 			)
 			if err != nil && err != ErrEdgeNotFound {
@@ -1562,10 +1565,6 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) (
 		if err != nil {
 			return err
 		}
-		nodes, err := tx.CreateTopLevelBucket(nodeBucket)
-		if err != nil {
-			return err
-		}
 
 		// Scan from chanIDStart to chanIDEnd, deleting every
 		// found edge.
@@ -1590,8 +1589,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) (
 		}
 
 		for _, k := range keys {
-			err = c.delChannelEdge(
-				edges, edgeIndex, chanIndex, zombieIndex, nodes,
+			err = c.delChannelEdgeUnsafe(
+				edges, edgeIndex, chanIndex, zombieIndex,
 				k, false, false,
 			)
 			if err != nil && err != ErrEdgeNotFound {
@@ -1734,8 +1733,8 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning, markZombie bool,
 		var rawChanID [8]byte
 		for _, chanID := range chanIDs {
 			byteOrder.PutUint64(rawChanID[:], chanID)
-			err := c.delChannelEdge(
-				edges, edgeIndex, chanIndex, zombieIndex, nodes,
+			err := c.delChannelEdgeUnsafe(
+				edges, edgeIndex, chanIndex, zombieIndex,
 				rawChanID[:], markZombie, strictZombiePruning,
 			)
 			if err != nil {
@@ -2091,6 +2090,9 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo,
 
 	var newChanIDs []uint64
 
+	c.cacheMu.Lock()
+	defer c.cacheMu.Unlock()
+
 	err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
 		edges := tx.ReadBucket(edgeBucket)
 		if edges == nil {
@@ -2143,7 +2145,7 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo,
 			// and we let it be added to the set of IDs to
 			// query our peer for.
 			case isZombie && !isStillZombie:
-				err := c.markEdgeLive(tx, scid)
+				err := c.markEdgeLiveUnsafe(tx, scid)
 				if err != nil {
 					return err
 				}
@@ -2355,7 +2357,11 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
 // skipped and the result will contain only those edges that exist at the time
 // of the query. This can be used to respond to peer queries that are seeking to
 // fill in gaps in their view of the channel graph.
-func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
+//
+// NOTE: An optional transaction may be provided. If none is provided, then a
+// new one will be created.
+func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) (
+	[]ChannelEdge, error) {
 	// TODO(roasbeef): sort cids?
 
 	var (
@@ -2363,7 +2369,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
 		cidBytes [8]byte
 	)
 
-	err := kvdb.View(c.db, func(tx kvdb.RTx) error {
+	fetchChanInfos := func(tx kvdb.RTx) error {
 		edges := tx.ReadBucket(edgeBucket)
 		if edges == nil {
 			return ErrGraphNoEdgesFound
@@ -2425,9 +2431,20 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
 			})
 		}
 		return nil
-	}, func() {
-		chanEdges = nil
-	})
+	}
+
+	if tx == nil {
+		err := kvdb.View(c.db, fetchChanInfos, func() {
+			chanEdges = nil
+		})
+		if err != nil {
+			return nil, err
+		}
+
+		return chanEdges, nil
+	}
+
+	err := fetchChanInfos(tx)
 	if err != nil {
 		return nil, err
 	}
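The refactor above splits the body of FetchChanInfos into a reusable fetchChanInfos closure, so the method can either open its own read transaction or run inside one supplied by the caller. A hedged, package-internal usage sketch follows; the helper name is hypothetical, and the only API calls used are ones that appear elsewhere in this diff:

```go
// fetchForChan is a hypothetical package-internal helper illustrating both
// call styles of the new FetchChanInfos signature.
func fetchForChan(graph *ChannelGraph, chanID uint64) ([]ChannelEdge, error) {
	// Outside of any transaction: pass nil and FetchChanInfos starts its
	// own read transaction, as the router and ChanSeries callers below do.
	edges, err := graph.FetchChanInfos(nil, []uint64{chanID})
	if err != nil {
		return nil, err
	}

	// Inside an existing read transaction (as markEdgeLiveUnsafe now
	// does), pass the transaction through instead of nesting a new one.
	err = kvdb.View(graph.db, func(tx kvdb.RTx) error {
		edges, err = graph.FetchChanInfos(tx, []uint64{chanID})
		return err
	}, func() {
		edges = nil
	})
	if err != nil {
		return nil, err
	}

	return edges, nil
}
```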
@@ -2473,8 +2490,16 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
 	return nil
 }
 
-func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
-	nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error {
+// delChannelEdgeUnsafe deletes the edge with the given chanID from the graph
+// cache. It then goes on to delete any policy info and edge info for this
+// channel from the DB and finally, if isZombie is true, it will add an entry
+// for this channel in the zombie index.
+//
+// NOTE: this method MUST only be called if the cacheMu has already been
+// acquired.
+func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex,
+	zombieIndex kvdb.RwBucket, chanID []byte, isZombie,
+	strictZombie bool) error {
 
 	edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
 	if err != nil {
@@ -3612,16 +3637,19 @@ func markEdgeZombie(zombieIndex kvdb.RwBucket, chanID uint64, pubKey1,
 
 // MarkEdgeLive clears an edge from our zombie index, deeming it as live.
 func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
-	return c.markEdgeLive(nil, chanID)
-}
-
-// markEdgeLive clears an edge from the zombie index. This method can be called
-// with an existing kvdb.RwTx or the argument can be set to nil in which case a
-// new transaction will be created.
-func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error {
 	c.cacheMu.Lock()
 	defer c.cacheMu.Unlock()
 
+	return c.markEdgeLiveUnsafe(nil, chanID)
+}
+
+// markEdgeLiveUnsafe clears an edge from the zombie index. This method can be
+// called with an existing kvdb.RwTx or the argument can be set to nil in which
+// case a new transaction will be created.
+//
+// NOTE: this method MUST only be called if the cacheMu has already been
+// acquired.
+func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error {
 	dbFn := func(tx kvdb.RwTx) error {
 		edges := tx.ReadWriteBucket(edgeBucket)
 		if edges == nil {
|
||||
// We need to add the channel back into our graph cache, otherwise we
|
||||
// won't use it for path finding.
|
||||
if c.graphCache != nil {
|
||||
edgeInfos, err := c.FetchChanInfos([]uint64{chanID})
|
||||
edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -2090,6 +2090,296 @@ func TestFilterKnownChanIDs(t *testing.T) {
 	}
 }
 
+// TestStressTestChannelGraphAPI is a stress test that concurrently calls some
+// of the ChannelGraph methods in various orders in order to ensure that no
+// deadlock can occur. This test currently focuses on stress testing all the
+// methods that acquire the cache mutex along with the DB mutex.
+func TestStressTestChannelGraphAPI(t *testing.T) {
+	t.Parallel()
+
+	graph, err := MakeTestGraph(t)
+	require.NoError(t, err)
+
+	node1, err := createTestVertex(graph.db)
+	require.NoError(t, err, "unable to create test node")
+	require.NoError(t, graph.AddLightningNode(node1))
+
+	node2, err := createTestVertex(graph.db)
+	require.NoError(t, err, "unable to create test node")
+	require.NoError(t, graph.AddLightningNode(node2))
+
+	err = graph.SetSourceNode(node1)
+	require.NoError(t, err)
+
+	type chanInfo struct {
+		info models.ChannelEdgeInfo
+		id   lnwire.ShortChannelID
+	}
+
+	var (
+		chans []*chanInfo
+		mu    sync.RWMutex
+	)
+
+	// newBlockHeight returns a random block height between 0 and 100.
+	newBlockHeight := func() uint32 {
+		return uint32(rand.Int31n(100))
+	}
+
+	// addNewChan will create and return a new random channel and will
+	// add it to the set of channels.
+	addNewChan := func() *chanInfo {
+		mu.Lock()
+		defer mu.Unlock()
+
+		channel, chanID := createEdge(
+			newBlockHeight(), rand.Uint32(), uint16(rand.Int()),
+			rand.Uint32(), node1, node2,
+		)
+
+		newChan := &chanInfo{
+			info: channel,
+			id:   chanID,
+		}
+		chans = append(chans, newChan)
+
+		return newChan
+	}
+
+	// getRandChan picks a random channel from the set and returns it.
+	getRandChan := func() *chanInfo {
+		mu.RLock()
+		defer mu.RUnlock()
+
+		if len(chans) == 0 {
+			return nil
+		}
+
+		return chans[rand.Intn(len(chans))]
+	}
+
+	// getRandChanSet returns a random set of channels.
+	getRandChanSet := func() []*chanInfo {
+		mu.RLock()
+		defer mu.RUnlock()
+
+		if len(chans) == 0 {
+			return nil
+		}
+
+		start := rand.Intn(len(chans))
+		end := rand.Intn(len(chans))
+
+		if end < start {
+			start, end = end, start
+		}
+
+		var infoCopy []*chanInfo
+		for i := start; i < end; i++ {
+			infoCopy = append(infoCopy, &chanInfo{
+				info: chans[i].info,
+				id:   chans[i].id,
+			})
+		}
+
+		return infoCopy
+	}
+
+	// delChan deletes the channel with the given ID from the set if it
+	// exists.
+	delChan := func(id lnwire.ShortChannelID) {
+		mu.Lock()
+		defer mu.Unlock()
+
+		index := -1
+		for i, c := range chans {
+			if c.id == id {
+				index = i
+				break
+			}
+		}
+
+		if index == -1 {
+			return
+		}
+
+		chans = append(chans[:index], chans[index+1:]...)
+	}
+
+	var blockHash chainhash.Hash
+	copy(blockHash[:], bytes.Repeat([]byte{2}, 32))
+
+	var methodsMu sync.Mutex
+	methods := []struct {
+		name string
+		fn   func() error
+	}{
+		{
+			name: "MarkEdgeZombie",
+			fn: func() error {
+				channel := getRandChan()
+				if channel == nil {
+					return nil
+				}
+
+				return graph.MarkEdgeZombie(
+					channel.id.ToUint64(),
+					node1.PubKeyBytes,
+					node2.PubKeyBytes,
+				)
+			},
+		},
+		{
+			name: "FilterKnownChanIDs",
+			fn: func() error {
+				chanSet := getRandChanSet()
+				var chanIDs []ChannelUpdateInfo
+
+				for _, c := range chanSet {
+					chanIDs = append(
+						chanIDs,
+						ChannelUpdateInfo{
+							ShortChannelID: c.id,
+						},
+					)
+				}
+
+				_, err := graph.FilterKnownChanIDs(
+					chanIDs,
+					func(t time.Time, t2 time.Time) bool {
+						return rand.Intn(2) == 0
+					},
+				)
+
+				return err
+			},
+		},
+		{
+			name: "HasChannelEdge",
+			fn: func() error {
+				channel := getRandChan()
+				if channel == nil {
+					return nil
+				}
+
+				_, _, _, _, err := graph.HasChannelEdge(
+					channel.id.ToUint64(),
+				)
+
+				return err
+			},
+		},
+		{
+			name: "PruneGraph",
+			fn: func() error {
+				chanSet := getRandChanSet()
+				var spentOutpoints []*wire.OutPoint
+
+				for _, c := range chanSet {
+					spentOutpoints = append(
+						spentOutpoints,
+						&c.info.ChannelPoint,
+					)
+				}
+
+				_, err := graph.PruneGraph(
+					spentOutpoints, &blockHash, 100,
+				)
+
+				return err
+			},
+		},
+		{
+			name: "ChanUpdateInHorizon",
+			fn: func() error {
+				_, err := graph.ChanUpdatesInHorizon(
+					time.Now().Add(-time.Hour), time.Now(),
+				)
+
+				return err
+			},
+		},
+		{
+			name: "DeleteChannelEdges",
+			fn: func() error {
+				var (
+					strictPruning = rand.Intn(2) == 0
+					markZombie    = rand.Intn(2) == 0
+					channels      = getRandChanSet()
+					chanIDs       []uint64
+				)
+
+				for _, c := range channels {
+					chanIDs = append(
+						chanIDs, c.id.ToUint64(),
+					)
+					delChan(c.id)
+				}
+
+				err := graph.DeleteChannelEdges(
+					strictPruning, markZombie, chanIDs...,
+				)
+				if err != nil &&
+					!errors.Is(err, ErrEdgeNotFound) {
+
+					return err
+				}
+
+				return nil
+			},
+		},
+		{
+			name: "DisconnectBlockAtHeight",
+			fn: func() error {
+				_, err := graph.DisconnectBlockAtHeight(
+					newBlockHeight(),
+				)
+
+				return err
+			},
+		},
+		{
+			name: "AddChannelEdge",
+			fn: func() error {
+				channel := addNewChan()
+
+				return graph.AddChannelEdge(&channel.info)
+			},
+		},
+	}
+
+	const (
+		// concurrencyLevel is the number of concurrent goroutines that
+		// will be run simultaneously.
+		concurrencyLevel = 10
+
+		// executionCount is the number of methods that will be called
+		// per goroutine.
+		executionCount = 100
+	)
+
+	for i := 0; i < concurrencyLevel; i++ {
+		i := i
+
+		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+			t.Parallel()
+
+			for j := 0; j < executionCount; j++ {
+				// Randomly select a method to execute.
+				methodIndex := rand.Intn(len(methods))
+
+				methodsMu.Lock()
+				fn := methods[methodIndex].fn
+				name := methods[methodIndex].name
+				methodsMu.Unlock()
+
+				err := fn()
+				require.NoErrorf(t, err, fmt.Sprintf(name))
+			}
+		})
+	}
+}
+
 // TestFilterChannelRange tests that we're able to properly retrieve the full
 // set of short channel ID's for a given block range.
 func TestFilterChannelRange(t *testing.T) {
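Because the new test runs ten goroutines that each invoke a hundred randomly chosen graph methods, a reintroduced lock-order inversion would show up as the test hanging until the Go test timeout rather than as a failed assertion; running it with the race detector enabled, for example `go test -race -run TestStressTestChannelGraphAPI` in the package that defines ChannelGraph (path assumed from the `channeldb:` commit prefix), additionally flags any data races between the cache and DB code paths.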
@@ -2395,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) {
 	// We'll now attempt to query for the range of channel ID's we just
 	// inserted into the database. We should get the exact same set of
 	// edges back.
-	resp, err := graph.FetchChanInfos(edgeQuery)
+	resp, err := graph.FetchChanInfos(nil, edgeQuery)
 	require.NoError(t, err, "unable to fetch chan edges")
 	if len(resp) != len(edges) {
 		t.Fatalf("expected %v edges, instead got %v", len(edges),
@@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash,
 		chanIDs = append(chanIDs, chanID.ToUint64())
 	}
 
-	channels, err := c.graph.FetchChanInfos(chanIDs)
+	channels, err := c.graph.FetchChanInfos(nil, chanIDs)
 	if err != nil {
 		return nil, err
 	}
@@ -24,6 +24,10 @@
   channel opening was pruned from memory no more channels were able to be
   created nor accepted. This PR fixes this issue and enhances the test suite
   for this behavior.
 
+* [Fix deadlock possibility in
+  FilterKnownChanIDs](https://github.com/lightningnetwork/lnd/pull/8400) by
+  ensuring the `cacheMu` mutex is acquired before the main database lock.
+
 # New Features
 ## Functional Enhancements
@@ -46,3 +50,5 @@
 ## Tooling and Documentation
 
 # Contributors (Alphabetical Order)
+* Elle Mouton
+* ziggie1984
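For context on the release note above: the bug class being fixed is lock-order inversion, where one goroutine holds the database lock and waits for cacheMu while another holds cacheMu and waits for the database lock. A self-contained toy illustration follows, with plain sync.Mutex values standing in for lnd's locks; this is not code from this change:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	// cacheMu and dbMu stand in for the two locks involved. If one
	// goroutine locked them as (db, cache) while another locked them as
	// (cache, db), each could end up holding one lock and waiting forever
	// on the other. Using the same order everywhere, as this PR enforces,
	// removes the cycle.
	var cacheMu, dbMu sync.Mutex
	var wg sync.WaitGroup

	lockBoth := func() {
		defer wg.Done()
		cacheMu.Lock()
		defer cacheMu.Unlock()
		dbMu.Lock()
		defer dbMu.Unlock()
	}

	wg.Add(2)
	go lockBoth() // Consistent order: cache first, then db.
	go lockBoth() // Swapping the order in one caller reintroduces the risk.
	wg.Wait()

	fmt.Println("no deadlock with a consistent lock order")
}
```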
@@ -1008,20 +1008,23 @@ func (r *ChannelRouter) pruneZombieChans() error {
 	if r.cfg.AssumeChannelValid {
 		disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs()
 		if err != nil {
-			return fmt.Errorf("unable to get disabled channels ids "+
-				"chans: %v", err)
+			return fmt.Errorf("unable to get disabled channels "+
+				"ids chans: %v", err)
 		}
 
-		disabledEdges, err := r.cfg.Graph.FetchChanInfos(disabledChanIDs)
+		disabledEdges, err := r.cfg.Graph.FetchChanInfos(
+			nil, disabledChanIDs,
+		)
 		if err != nil {
-			return fmt.Errorf("unable to fetch disabled channels edges "+
-				"chans: %v", err)
+			return fmt.Errorf("unable to fetch disabled channels "+
+				"edges chans: %v", err)
 		}
 
 		// Ensuring we won't prune our own channel from the graph.
 		for _, disabledEdge := range disabledEdges {
 			if !isSelfChannelEdge(disabledEdge.Info) {
-				chansToPrune[disabledEdge.Info.ChannelID] = struct{}{}
+				chansToPrune[disabledEdge.Info.ChannelID] =
+					struct{}{}
 			}
 		}
 	}