channeldb: fix dangerous type casting hack

This commit is contained in:
Andras Banki-Horvath 2021-09-21 19:18:12 +02:00 committed by Oliver Gugger
parent 639faeed6d
commit 292b8e1ce6
No known key found for this signature in database
GPG key ID: 8E4256593F177720

View file

@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey,
func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer
outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) { outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) {
readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash) // First fetch the top level bucket which stores all data related to
if err != nil { // current, active channels.
return nil, err openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
} }
return readBucket.(kvdb.RwBucket), nil // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like
// CreateIfNotExists, will return error
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := nodeKey.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
if err := writeOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
} }
// fullSync syncs the contents of an OpenChannel while re-using an existing // fullSync syncs the contents of an OpenChannel while re-using an existing
@ -965,7 +996,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
defer c.Unlock() defer c.Unlock()
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
if err != nil { if err != nil {
@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
channel.IsPending = false channel.IsPending = false
channel.ShortChannelID = openLoc channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil { }, func() {}); err != nil {
return err return err
} }