diff --git a/channeldb/channel.go b/channeldb/channel.go index 31873ae3f..a4a47e587 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) { - readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash) - if err != nil { - return nil, err + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists } - return readBucket.(kvdb.RwBucket), nil + // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like + // CreateIfNotExists, will return error + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil } // fullSync syncs the contents of an OpenChannel while re-using an existing @@ -965,7 +996,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { defer c.Unlock() if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucket( + chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { channel.IsPending = false channel.ShortChannelID = openLoc - return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) + return putOpenChannel(chanBucket, channel) }, func() {}); err != nil { return err }