Merge pull request #8400 from ellemouton/deadlockFix

channeldb: acquire mutexes in the same order throughout
Olaoluwa Osuntokun 2024-01-22 12:57:48 -08:00 committed by GitHub
commit 41c167d37c
5 changed files with 361 additions and 34 deletions


@@ -176,6 +176,9 @@ const (
type ChannelGraph struct {
db kvdb.Backend
// cacheMu guards all caches (rejectCache, chanCache, graphCache). If
// this mutex is to be acquired at the same time as the DB mutex, then
// cacheMu MUST be acquired first to prevent deadlock.
cacheMu sync.RWMutex
rejectCache *rejectCache
chanCache *channelCache
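The lock-ordering rule documented in the comment above can be illustrated with a rough sketch (the method name below is hypothetical and not part of this change): any ChannelGraph method that needs both the cache mutex and a database transaction takes cacheMu first and only then opens the transaction.

func (c *ChannelGraph) updateCacheAndDB() error {
	// Hypothetical sketch: cacheMu is acquired before the DB mutex
	// (the kvdb transaction), matching the ordering rule above.
	c.cacheMu.Lock()
	defer c.cacheMu.Unlock()

	return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
		// Mutate buckets and the in-memory caches here while
		// cacheMu is held.
		return nil
	}, func() {})
}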
@@ -1331,8 +1334,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// will be returned if that outpoint isn't known to be
// a channel. If no error is returned, then a channel
// was successfully pruned.
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
err = c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
chanID, false, false,
)
if err != nil && err != ErrEdgeNotFound {
@@ -1562,10 +1565,6 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) (
if err != nil {
return err
}
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
// Scan from chanIDStart to chanIDEnd, deleting every
// found edge.
@@ -1590,8 +1589,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) (
}
for _, k := range keys {
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
err = c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
k, false, false,
)
if err != nil && err != ErrEdgeNotFound {
@@ -1734,8 +1733,8 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning, markZombie bool,
var rawChanID [8]byte
for _, chanID := range chanIDs {
byteOrder.PutUint64(rawChanID[:], chanID)
err := c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
err := c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
rawChanID[:], markZombie, strictZombiePruning,
)
if err != nil {
@@ -2091,6 +2090,9 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo,
var newChanIDs []uint64
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
@@ -2143,7 +2145,7 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo,
// and we let it be added to the set of IDs to
// query our peer for.
case isZombie && !isStillZombie:
err := c.markEdgeLive(tx, scid)
err := c.markEdgeLiveUnsafe(tx, scid)
if err != nil {
return err
}
@@ -2355,7 +2357,11 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
// skipped and the result will contain only those edges that exist at the time
// of the query. This can be used to respond to peer queries that are seeking to
// fill in gaps in their view of the channel graph.
func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
//
// NOTE: An optional transaction may be provided. If none is provided, then a
// new one will be created.
func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) (
[]ChannelEdge, error) {
// TODO(roasbeef): sort cids?
var (
@@ -2363,7 +2369,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
cidBytes [8]byte
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
fetchChanInfos := func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
@@ -2425,9 +2431,20 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
})
}
return nil
}, func() {
chanEdges = nil
})
}
if tx == nil {
err := kvdb.View(c.db, fetchChanInfos, func() {
chanEdges = nil
})
if err != nil {
return nil, err
}
return chanEdges, nil
}
err := fetchChanInfos(tx)
if err != nil {
return nil, err
}
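A rough usage sketch of the optional transaction parameter added to FetchChanInfos above (both caller functions below are hypothetical, not part of this PR): passing nil lets the method open its own read transaction, while a caller that already holds a transaction can pass it through to avoid nesting.

// Hypothetical caller with no open transaction: FetchChanInfos creates
// its own read transaction internally.
func fetchEdgesStandalone(graph *ChannelGraph, ids []uint64) ([]ChannelEdge, error) {
	return graph.FetchChanInfos(nil, ids)
}

// Hypothetical caller that is already inside a kvdb transaction: the
// existing transaction is reused rather than opening a nested one.
func fetchEdgesWithinTx(graph *ChannelGraph, tx kvdb.RTx, ids []uint64) ([]ChannelEdge, error) {
	return graph.FetchChanInfos(tx, ids)
}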
@@ -2473,8 +2490,16 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
return nil
}
func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error {
// delChannelEdgeUnsafe deletes the edge with the given chanID from the graph
// cache. It then goes on to delete any policy info and edge info for this
// channel from the DB and finally, if isZombie is true, it will add an entry
// for this channel in the zombie index.
//
// NOTE: this method MUST only be called if the cacheMu has already been
// acquired.
func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex,
zombieIndex kvdb.RwBucket, chanID []byte, isZombie,
strictZombie bool) error {
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
if err != nil {
@@ -3612,16 +3637,19 @@ func markEdgeZombie(zombieIndex kvdb.RwBucket, chanID uint64, pubKey1,
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
return c.markEdgeLive(nil, chanID)
}
// markEdgeLive clears an edge from the zombie index. This method can be called
// with an existing kvdb.RwTx or the argument can be set to nil in which case a
// new transaction will be created.
func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
return c.markEdgeLiveUnsafe(nil, chanID)
}
// markEdgeLiveUnsafe clears an edge from the zombie index. This method can be
// called with an existing kvdb.RwTx or the argument can be set to nil in which
// case a new transaction will be created.
//
// NOTE: this method MUST only be called if the cacheMu has already been
// acquired.
func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error {
dbFn := func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
@@ -3660,7 +3688,7 @@ func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error {
// We need to add the channel back into our graph cache, otherwise we
// won't use it for path finding.
if c.graphCache != nil {
edgeInfos, err := c.FetchChanInfos([]uint64{chanID})
edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID})
if err != nil {
return err
}
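The MarkEdgeLive / markEdgeLiveUnsafe split above follows a pattern used throughout this change: the exported method acquires cacheMu and delegates to an *Unsafe helper that assumes the lock is already held, so other lock-holding methods (such as FilterKnownChanIDs) can reuse the helper without re-acquiring the mutex. A generic sketch of that pattern, with hypothetical method names, looks like this:

// Exported entry point: takes cacheMu, then delegates to the helper.
func (c *ChannelGraph) doWork(chanID uint64) error {
	c.cacheMu.Lock()
	defer c.cacheMu.Unlock()

	return c.doWorkUnsafe(nil, chanID)
}

// doWorkUnsafe assumes the caller already holds cacheMu. An optional
// kvdb.RwTx may be supplied; if nil, a new transaction is created.
func (c *ChannelGraph) doWorkUnsafe(tx kvdb.RwTx, chanID uint64) error {
	dbFn := func(tx kvdb.RwTx) error {
		// DB and cache mutations go here.
		return nil
	}

	if tx == nil {
		return kvdb.Update(c.db, dbFn, func() {})
	}

	return dbFn(tx)
}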


@@ -2090,6 +2090,296 @@ func TestFilterKnownChanIDs(t *testing.T) {
}
}
// TestStressTestChannelGraphAPI is a stress test that concurrently calls
// some of the ChannelGraph methods in various orders to ensure that no
// deadlock can occur. This test currently focuses on stress testing all
// the methods that acquire the cache mutex along with the DB mutex.
func TestStressTestChannelGraphAPI(t *testing.T) {
t.Parallel()
graph, err := MakeTestGraph(t)
require.NoError(t, err)
node1, err := createTestVertex(graph.db)
require.NoError(t, err, "unable to create test node")
require.NoError(t, graph.AddLightningNode(node1))
node2, err := createTestVertex(graph.db)
require.NoError(t, err, "unable to create test node")
require.NoError(t, graph.AddLightningNode(node2))
err = graph.SetSourceNode(node1)
require.NoError(t, err)
type chanInfo struct {
info models.ChannelEdgeInfo
id lnwire.ShortChannelID
}
var (
chans []*chanInfo
mu sync.RWMutex
)
// newBlockHeight returns a random block height between 0 and 100.
newBlockHeight := func() uint32 {
return uint32(rand.Int31n(100))
}
// addNewChan creates and returns a new random channel and adds it to
// the set of channels.
addNewChan := func() *chanInfo {
mu.Lock()
defer mu.Unlock()
channel, chanID := createEdge(
newBlockHeight(), rand.Uint32(), uint16(rand.Int()),
rand.Uint32(), node1, node2,
)
newChan := &chanInfo{
info: channel,
id: chanID,
}
chans = append(chans, newChan)
return newChan
}
// getRandChan picks a random channel from the set and returns it.
getRandChan := func() *chanInfo {
mu.RLock()
defer mu.RUnlock()
if len(chans) == 0 {
return nil
}
return chans[rand.Intn(len(chans))]
}
// getRandChanSet returns a random set of channels.
getRandChanSet := func() []*chanInfo {
mu.RLock()
defer mu.RUnlock()
if len(chans) == 0 {
return nil
}
start := rand.Intn(len(chans))
end := rand.Intn(len(chans))
if end < start {
start, end = end, start
}
var infoCopy []*chanInfo
for i := start; i < end; i++ {
infoCopy = append(infoCopy, &chanInfo{
info: chans[i].info,
id: chans[i].id,
})
}
return infoCopy
}
// delChan deletes the channel with the given ID from the set if it
// exists.
delChan := func(id lnwire.ShortChannelID) {
mu.Lock()
defer mu.Unlock()
index := -1
for i, c := range chans {
if c.id == id {
index = i
break
}
}
if index == -1 {
return
}
chans = append(chans[:index], chans[index+1:]...)
}
var blockHash chainhash.Hash
copy(blockHash[:], bytes.Repeat([]byte{2}, 32))
var methodsMu sync.Mutex
methods := []struct {
name string
fn func() error
}{
{
name: "MarkEdgeZombie",
fn: func() error {
channel := getRandChan()
if channel == nil {
return nil
}
return graph.MarkEdgeZombie(
channel.id.ToUint64(),
node1.PubKeyBytes,
node2.PubKeyBytes,
)
},
},
{
name: "FilterKnownChanIDs",
fn: func() error {
chanSet := getRandChanSet()
var chanIDs []ChannelUpdateInfo
for _, c := range chanSet {
chanIDs = append(
chanIDs,
ChannelUpdateInfo{
ShortChannelID: c.id,
},
)
}
_, err := graph.FilterKnownChanIDs(
chanIDs,
func(t time.Time, t2 time.Time) bool {
return rand.Intn(2) == 0
},
)
return err
},
},
{
name: "HasChannelEdge",
fn: func() error {
channel := getRandChan()
if channel == nil {
return nil
}
_, _, _, _, err := graph.HasChannelEdge(
channel.id.ToUint64(),
)
return err
},
},
{
name: "PruneGraph",
fn: func() error {
chanSet := getRandChanSet()
var spentOutpoints []*wire.OutPoint
for _, c := range chanSet {
spentOutpoints = append(
spentOutpoints,
&c.info.ChannelPoint,
)
}
_, err := graph.PruneGraph(
spentOutpoints, &blockHash, 100,
)
return err
},
},
{
name: "ChanUpdateInHorizon",
fn: func() error {
_, err := graph.ChanUpdatesInHorizon(
time.Now().Add(-time.Hour), time.Now(),
)
return err
},
},
{
name: "DeleteChannelEdges",
fn: func() error {
var (
strictPruning = rand.Intn(2) == 0
markZombie = rand.Intn(2) == 0
channels = getRandChanSet()
chanIDs []uint64
)
for _, c := range channels {
chanIDs = append(
chanIDs, c.id.ToUint64(),
)
delChan(c.id)
}
err := graph.DeleteChannelEdges(
strictPruning, markZombie, chanIDs...,
)
if err != nil &&
!errors.Is(err, ErrEdgeNotFound) {
return err
}
return nil
},
},
{
name: "DisconnectBlockAtHeight",
fn: func() error {
_, err := graph.DisconnectBlockAtHeight(
newBlockHeight(),
)
return err
},
},
{
name: "AddChannelEdge",
fn: func() error {
channel := addNewChan()
return graph.AddChannelEdge(&channel.info)
},
},
}
const (
// concurrencyLevel is the number of concurrent goroutines that
// will be run simultaneously.
concurrencyLevel = 10
// executionCount is the number of methods that will be called
// per goroutine.
executionCount = 100
)
for i := 0; i < concurrencyLevel; i++ {
i := i
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
for j := 0; j < executionCount; j++ {
// Randomly select a method to execute.
methodIndex := rand.Intn(len(methods))
methodsMu.Lock()
fn := methods[methodIndex].fn
name := methods[methodIndex].name
methodsMu.Unlock()
err := fn()
require.NoErrorf(t, err, fmt.Sprintf(name))
}
})
}
}
// TestFilterChannelRange tests that we're able to properly retrieve the full
// set of short channel ID's for a given block range.
func TestFilterChannelRange(t *testing.T) {
@@ -2395,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) {
// We'll now attempt to query for the range of channel ID's we just
// inserted into the database. We should get the exact same set of
// edges back.
resp, err := graph.FetchChanInfos(edgeQuery)
resp, err := graph.FetchChanInfos(nil, edgeQuery)
require.NoError(t, err, "unable to fetch chan edges")
if len(resp) != len(edges) {
t.Fatalf("expected %v edges, instead got %v", len(edges),


@@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash,
chanIDs = append(chanIDs, chanID.ToUint64())
}
channels, err := c.graph.FetchChanInfos(chanIDs)
channels, err := c.graph.FetchChanInfos(nil, chanIDs)
if err != nil {
return nil, err
}


@@ -24,6 +24,10 @@
channel opening was pruned from memory no more channels were able to be
created nor accepted. This PR fixes this issue and enhances the test suite
for this behavior.
* [Fix deadlock possibility in
FilterKnownChanIDs](https://github.com/lightningnetwork/lnd/pull/8400) by
ensuring the `cacheMu` mutex is acquired before the main database lock.
# New Features
## Functional Enhancements
@@ -46,3 +50,5 @@
## Tooling and Documentation
# Contributors (Alphabetical Order)
* Elle Mouton
* ziggie1984


@@ -1008,20 +1008,23 @@ func (r *ChannelRouter) pruneZombieChans() error {
if r.cfg.AssumeChannelValid {
disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs()
if err != nil {
return fmt.Errorf("unable to get disabled channels ids "+
"chans: %v", err)
return fmt.Errorf("unable to get disabled channels "+
"ids chans: %v", err)
}
disabledEdges, err := r.cfg.Graph.FetchChanInfos(disabledChanIDs)
disabledEdges, err := r.cfg.Graph.FetchChanInfos(
nil, disabledChanIDs,
)
if err != nil {
return fmt.Errorf("unable to fetch disabled channels edges "+
"chans: %v", err)
return fmt.Errorf("unable to fetch disabled channels "+
"edges chans: %v", err)
}
// Ensuring we won't prune our own channel from the graph.
for _, disabledEdge := range disabledEdges {
if !isSelfChannelEdge(disabledEdge.Info) {
chansToPrune[disabledEdge.Info.ChannelID] = struct{}{}
chansToPrune[disabledEdge.Info.ChannelID] =
struct{}{}
}
}
}