From 11cf4216e441a323d98ffa279ebb23653f2e7ae3 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Sep 2021 19:18:17 +0200 Subject: [PATCH] multi: move all channelstate operations to ChannelStateDB --- chainreg/chainregistry.go | 2 +- channeldb/channel.go | 54 +++---- channeldb/channel_test.go | 54 ++++--- channeldb/db.go | 186 ++++++++++++++++--------- channeldb/db_test.go | 46 ++++-- channeldb/nodes_test.go | 8 +- channeldb/waitingproof.go | 4 +- channelnotifier/channelnotifier.go | 4 +- chanrestore.go | 2 +- contractcourt/breacharbiter.go | 6 +- contractcourt/breacharbiter_test.go | 12 +- contractcourt/chain_arbitrator.go | 16 ++- contractcourt/chain_arbitrator_test.go | 6 +- contractcourt/utils_test.go | 4 +- discovery/message_store.go | 11 +- funding/manager_test.go | 6 +- htlcswitch/circuit_map.go | 21 ++- htlcswitch/circuit_test.go | 15 +- htlcswitch/link_test.go | 4 +- htlcswitch/mock.go | 6 +- htlcswitch/payment_result.go | 14 +- htlcswitch/switch.go | 19 ++- htlcswitch/test_utils.go | 26 ++-- lnd.go | 2 +- lnrpc/invoicesrpc/addinvoice.go | 2 +- lnrpc/invoicesrpc/config_active.go | 2 +- lnwallet/config.go | 2 +- lnwallet/test/test_interface.go | 8 +- lnwallet/test_utils.go | 4 +- lnwallet/transactions_test.go | 4 +- peer/brontide.go | 2 +- peer/test_utils.go | 8 +- rpcserver.go | 16 +-- server.go | 37 +++-- subrpcserver_config.go | 2 +- 35 files changed, 377 insertions(+), 238 deletions(-) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index ad0b2f54b..762ca7e2f 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -75,7 +75,7 @@ type Config struct { // ChanStateDB is a pointer to the database that stores the channel // state. - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // BlockCacheSize is the size (in bytes) of blocks kept in memory. BlockCacheSize uint64 diff --git a/channeldb/channel.go b/channeldb/channel.go index f6b5fab4d..858cead91 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -729,7 +729,7 @@ type OpenChannel struct { RevocationKeyLocator keychain.KeyLocator // TODO(roasbeef): eww - Db *DB + Db *ChannelStateDB // TODO(roasbeef): just need to store local and remote HTLC's? @@ -800,7 +800,7 @@ func (c *OpenChannel) RefreshShortChanID() error { c.Lock() defer c.Unlock() - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -995,7 +995,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.Lock() defer c.Unlock() - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1047,7 +1047,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { var commitPoint *btcec.PublicKey - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1271,7 +1271,7 @@ func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) { func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { var closeTx *wire.MsgTx - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1305,7 +1305,7 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { func (c *OpenChannel) putChanStatus(status ChannelStatus, fs ...func(kvdb.RwBucket) error) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1349,7 +1349,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus, } func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1473,7 +1473,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { c.FundingBroadcastHeight = pendingHeight - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return syncNewChannel(tx, c, []net.Addr{addr}) }, func() {}) } @@ -1502,7 +1502,7 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error { // for this channel. The LinkNode metadata contains reachability, // up-time, and service bits related information. linkNode := NewLinkNode( - &LinkNodeDB{backend: c.Db.Backend}, + &LinkNodeDB{backend: c.Db.backend}, wire.MainNet, c.IdentityPub, addrs..., ) @@ -1532,7 +1532,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, return ErrNoRestoredChannelMutation } - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2124,7 +2124,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { return ErrNoRestoredChannelMutation } - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { // First, we'll grab the writable bucket where this channel's // data resides. chanBucket, err := fetchChanBucketRw( @@ -2194,7 +2194,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { // these pointers, causing the tip and the tail to point to the same entry. func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { var cd *CommitDiff - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2233,7 +2233,7 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { // updates that still need to be signed for. func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2267,7 +2267,7 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { // updates that the remote still needs to sign for. func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2311,7 +2311,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { c.RemoteNextRevocation = revKey - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2352,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, var newRemoteCommit *ChannelCommitment - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2527,7 +2527,7 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { defer c.RUnlock() var fwdPkgs []*FwdPkg - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { var err error fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) return err @@ -2547,7 +2547,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckAddHtlcs(tx, addRefs...) }, func() {}) } @@ -2560,7 +2560,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckSettleFails(tx, settleFailRefs...) }, func() {}) } @@ -2571,7 +2571,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }, func() {}) } @@ -2585,7 +2585,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { for _, height := range heights { err := c.Packager.RemovePkg(tx, height) if err != nil { @@ -2613,7 +2613,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { } var commit ChannelCommitment - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2660,7 +2660,7 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { defer c.RUnlock() var height uint64 - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. chanBucket, err := fetchChanBucket( @@ -2697,7 +2697,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e defer c.RUnlock() var commit ChannelCommitment - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2855,7 +2855,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { openChanBucket := tx.ReadWriteBucket(openChannelBucket) if openChanBucket == nil { return ErrNoChanDBExists @@ -3067,7 +3067,7 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { // latest fully committed state is returned. The first commitment returned is // the local commitment, and the second returned is the remote commitment. func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -3089,7 +3089,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen // acting on a possible contract breach to ensure, that the caller has the most // up to date information required to deliver justice. func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index ad1b3c07c..044308f88 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -183,7 +183,7 @@ var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. -func createTestChannel(t *testing.T, cdb *DB, +func createTestChannel(t *testing.T, cdb *ChannelStateDB, opts ...testChannelOption) *OpenChannel { // Create a default set of parameters. @@ -221,7 +221,7 @@ func createTestChannel(t *testing.T, cdb *DB, return params.channel } -func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { +func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { // Simulate 1000 channel updates. producer, err := shachain.NewRevocationProducerFromBytes(key[:]) if err != nil { @@ -359,12 +359,14 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, with additional htlcs on the local // and remote commitment. localHtlcs := []HTLC{ @@ -508,12 +510,14 @@ func TestOptionalShutdown(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a channel with upfront scripts set as // specified in the test. state := createTestChannel( @@ -565,12 +569,14 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func TestChannelStateTransition(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a minimal channel, then perform a full sync in order to // persist the data. channel := createTestChannel(t, cdb) @@ -842,7 +848,7 @@ func TestChannelStateTransition(t *testing.T) { } // At this point, we should have 2 forwarding packages added. - fwdPkgs := loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") // Now attempt to delete the channel from the database. @@ -877,19 +883,21 @@ func TestChannelStateTransition(t *testing.T) { } // All forwarding packages of this channel has been deleted too. - fwdPkgs = loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) require.Empty(t, fwdPkgs, "no forwarding packages should exist") } func TestFetchPendingChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that was broadcast at height 99. const broadcastHeight = 99 createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) @@ -963,12 +971,14 @@ func TestFetchPendingChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel in the database. state := createTestChannel(t, cdb, openChannelOption()) @@ -1054,18 +1064,20 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. - db, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + channels := make([]*OpenChannel, numChannels) for i := 0; i < numChannels; i++ { // Create a pending channel in the database at the broadcast // height. channels[i] = createTestChannel( - t, db, pendingHeightOption(broadcastHeight), + t, cdb, pendingHeightOption(broadcastHeight), ) } @@ -1116,7 +1128,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // Now, we'll fetch all the channels waiting to be closed from the // database. We should expect to see both channels above, even if any of // them haven't had their funding transaction confirm on-chain. - waitingCloseChannels, err := db.FetchWaitingCloseChannels() + waitingCloseChannels, err := cdb.FetchWaitingCloseChannels() if err != nil { t.Fatalf("unable to fetch all waiting close channels: %v", err) } @@ -1169,12 +1181,14 @@ func TestFetchWaitingCloseChannels(t *testing.T) { func TestRefreshShortChanID(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a test channel. state := createTestChannel(t, cdb) @@ -1317,13 +1331,15 @@ func TestCloseInitiator(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1362,13 +1378,15 @@ func TestCloseInitiator(t *testing.T) { // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1427,7 +1445,7 @@ func TestBalanceAtHeight(t *testing.T) { putRevokedState := func(c *OpenChannel, height uint64, local, remote lnwire.MilliSatoshi) error { - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, @@ -1508,13 +1526,15 @@ func TestBalanceAtHeight(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create options to set the heights and balances of // our local and remote commitments. localCommitOpt := channelCommitmentOption( diff --git a/channeldb/db.go b/channeldb/db.go index 8b373d8d7..3ceb93d93 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -222,8 +222,8 @@ var ( type DB struct { kvdb.Backend - // linkNodeDB separates all DB operations on LinkNodes. - linkNodeDB *LinkNodeDB + // channelStateDB separates all DB operations on channel state. + channelStateDB *ChannelStateDB dbPath string graph *ChannelGraph @@ -273,13 +273,19 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, chanDB := &DB{ Backend: backend, - linkNodeDB: &LinkNodeDB{ + channelStateDB: &ChannelStateDB{ + linkNodeDB: &LinkNodeDB{ + backend: backend, + }, backend: backend, }, clock: opts.clock, dryRun: opts.dryRun, } + // Set the parent pointer (only used in tests). + chanDB.channelStateDB.parent = chanDB + chanDB.graph = newChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, opts.BatchCommitInterval, @@ -339,10 +345,10 @@ func (d *DB) Wipe() error { return initChannelDB(d.Backend) } -// createChannelDB creates and initializes a fresh version of channeldb. In -// the case that the target path has not yet been created or doesn't yet exist, -// then the path is created. Additionally, all required top-level buckets used -// within the database are created. +// initChannelDB creates and initializes a fresh version of channeldb. In the +// case that the target path has not yet been created or doesn't yet exist, then +// the path is created. Additionally, all required top-level buckets used within +// the database are created. func initChannelDB(db kvdb.Backend) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error { meta := &Meta{} @@ -409,15 +415,45 @@ func fileExists(path string) bool { return true } +// ChannelStateDB is a database that keeps track of all channel state. +type ChannelStateDB struct { + // linkNodeDB separates all DB operations on LinkNodes. + linkNodeDB *LinkNodeDB + + // parent holds a pointer to the "main" channeldb.DB object. This is + // only used for testing and should never be used in production code. + // For testing use the ChannelStateDB.GetParentDB() function to retrieve + // this pointer. + parent *DB + + // backend points to the actual backend holding the channel state + // database. This may be a real backend or a cache middleware. + backend kvdb.Backend +} + +// GetParentDB returns the "main" channeldb.DB object that is the owner of this +// ChannelStateDB instance. Use this function only in tests where passing around +// pointers makes testing less readable. Never to be used in production code! +func (c *ChannelStateDB) GetParentDB() *DB { + return c.parent +} + +// LinkNodeDB returns the current instance of the link node database. +func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB { + return c.linkNodeDB +} + // FetchOpenChannels starts a new database transaction and returns all stored // currently active/open channels associated with the target nodeID. In the case // that no active channels are known to have been created with this node, then a // zero-length slice is returned. -func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { var err error - channels, err = d.fetchOpenChannels(tx, nodeID) + channels, err = c.fetchOpenChannels(tx, nodeID) return err }, func() { channels = nil @@ -430,7 +466,7 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) // stored currently active/open channels associated with the target nodeID. In // the case that no active channels are known to have been created with this // node, then a zero-length slice is returned. -func (d *DB) fetchOpenChannels(tx kvdb.RTx, +func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx, nodeID *btcec.PublicKey) ([]*OpenChannel, error) { // Get the bucket dedicated to storing the metadata for open channels. @@ -466,7 +502,7 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // Finally, we both of the necessary buckets retrieved, fetch // all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) + nodeChannels, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read channel for "+ "chain_hash=%x, node_key=%x: %v", @@ -483,7 +519,8 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // fetchNodeChannels retrieves all active channels from the target chainBucket // which is under a node's dedicated channel bucket. This function is typically // used to fetch all the active channels related to a particular node. -func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) { +func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( + []*OpenChannel, error) { var channels []*OpenChannel @@ -509,7 +546,7 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) return fmt.Errorf("unable to read channel data for "+ "chan_point=%v: %v", outPoint, err) } - oChannel.Db = d + oChannel.Db = c channels = append(channels, oChannel) @@ -526,8 +563,8 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) // point. If the channel cannot be found, then an error will be returned. // Optionally an existing db tx can be supplied. Optionally an existing db tx // can be supplied. -func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, - error) { +func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( + *OpenChannel, error) { var ( targetChan *OpenChannel @@ -603,7 +640,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, } targetChan = channel - targetChan.Db = d + targetChan.Db = c return nil }) @@ -612,7 +649,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, var err error if tx == nil { - err = kvdb.View(d, chanScan, func() {}) + err = kvdb.View(c.backend, chanScan, func() {}) } else { err = chanScan(tx) } @@ -632,16 +669,16 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, // FetchAllChannels attempts to retrieve all open channels currently stored // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. -func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - return fetchChannels(d) +func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) { + return fetchChannels(c) } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and is not waiting for a closing transaction to be // confirmed. -func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) { return fetchChannels( - d, + c, pendingChannelFilter(false), waitingCloseFilter(false), ) @@ -650,8 +687,8 @@ func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain. -func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, +func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) { + return fetchChannels(c, pendingChannelFilter(true), waitingCloseFilter(false), ) @@ -661,9 +698,9 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { // but are now waiting for a closing transaction to be confirmed. // // NOTE: This includes channels that are also pending to be opened. -func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { return fetchChannels( - d, waitingCloseFilter(true), + c, waitingCloseFilter(true), ) } @@ -704,10 +741,12 @@ func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { // which have a true value returned for *all* of the filters will be returned. // If no filters are provided, every channel in the open channels bucket will // be returned. -func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { +func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. openChanBucket := tx.ReadBucket(openChannelBucket) @@ -749,7 +788,7 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error "bucket for chain=%x", chainHash[:]) } - nodeChans, err := d.fetchNodeChannels(chainBucket) + nodeChans, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read "+ "channel for chain_hash=%x, "+ @@ -798,10 +837,12 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error // it becomes fully closed after a single confirmation. When a channel was // forcibly closed, it will become fully closed after _all_ the pending funds // (if any) have been swept. -func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) ( + []*ChannelCloseSummary, error) { + var chanSummaries []*ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrNoClosedChannels @@ -839,9 +880,11 @@ var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary // FetchClosedChannel queries for a channel close summary using the channel // point of the channel in question. -func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( + *ChannelCloseSummary, error) { + var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -873,11 +916,11 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er // FetchClosedChannelForID queries for a channel close summary using the // channel ID of the channel in question. -func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( +func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( *ChannelCloseSummary, error) { var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -926,12 +969,12 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // cooperatively closed and it's reached a single confirmation, or after all // the pending funds in a channel that has been forcibly closed have been // swept. -func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { +func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { var ( openChannels []*OpenChannel pruneLinkNode *btcec.PublicKey ) - err := kvdb.Update(d, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { var b bytes.Buffer if err := writeOutpoint(&b, chanPoint); err != nil { return err @@ -978,7 +1021,9 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // garbage collect it to ensure we don't establish persistent // connections to peers without open channels. pruneLinkNode = chanSummary.RemotePub - openChannels, err = d.fetchOpenChannels(tx, pruneLinkNode) + openChannels, err = c.fetchOpenChannels( + tx, pruneLinkNode, + ) if err != nil { return fmt.Errorf("unable to fetch open channels for "+ "peer %x: %v", @@ -996,13 +1041,13 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // Decide whether we want to remove the link node, based upon the number // of still open channels. - return d.pruneLinkNode(openChannels, pruneLinkNode) + return c.pruneLinkNode(openChannels, pruneLinkNode) } // pruneLinkNode determines whether we should garbage collect a link node from // the database due to no longer having any open channels with it. If there are // any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(openChannels []*OpenChannel, +func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel, remotePub *btcec.PublicKey) error { if len(openChannels) > 0 { @@ -1012,13 +1057,13 @@ func (d *DB) pruneLinkNode(openChannels []*OpenChannel, log.Infof("Pruning link node %x with zero open channels from database", remotePub.SerializeCompressed()) - return d.linkNodeDB.DeleteLinkNode(remotePub) + return c.linkNodeDB.DeleteLinkNode(remotePub) } // PruneLinkNodes attempts to prune all link nodes found within the databse with // whom we no longer have any open channels with. -func (d *DB) PruneLinkNodes() error { - allLinkNodes, err := d.linkNodeDB.FetchAllLinkNodes() +func (c *ChannelStateDB) PruneLinkNodes() error { + allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes() if err != nil { return err } @@ -1028,9 +1073,9 @@ func (d *DB) PruneLinkNodes() error { openChannels []*OpenChannel linkNode = linkNode ) - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { var err error - openChannels, err = d.fetchOpenChannels( + openChannels, err = c.fetchOpenChannels( tx, linkNode.IdentityPub, ) return err @@ -1041,7 +1086,7 @@ func (d *DB) PruneLinkNodes() error { return err } - err = d.pruneLinkNode(openChannels, linkNode.IdentityPub) + err = c.pruneLinkNode(openChannels, linkNode.IdentityPub) if err != nil { return err } @@ -1069,8 +1114,8 @@ type ChannelShell struct { // addresses, and finally create an edge within the graph for the channel as // well. This method is idempotent, so repeated calls with the same set of // channel shells won't modify the database after the initial call. -func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { - err := kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error { + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { for _, channelShell := range channelShells { channel := channelShell.Chan @@ -1084,7 +1129,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // and link node for this channel. If the channel // already exists, then in order to ensure this method // is idempotent, we'll continue to the next step. - channel.Db = d + channel.Db = c err := syncNewChannel( tx, channel, channelShell.NodeAddrs, ) @@ -1104,8 +1149,10 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // AddrsForNode consults the graph and channel database for all addresses known // to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - linkNode, err := d.linkNodeDB.FetchLinkNode(nodePub) +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, + error) { + + linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) if err != nil { return nil, err } @@ -1157,16 +1204,18 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { // database. If the channel was already removed (has a closed channel entry), // then we'll return a nil error. Otherwise, we'll insert a new close summary // into the database. -func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { +func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, + bestHeight uint32) error { + // With the chanPoint constructed, we'll attempt to find the target // channel in the database. If we can't find the channel, then we'll // return the error back to the caller. - dbChan, err := d.FetchChannel(nil, *chanPoint) + dbChan, err := c.FetchChannel(nil, *chanPoint) switch { // If the channel wasn't found, then it's possible that it was already // abandoned from the database. case err == ErrChannelNotFound: - _, closedErr := d.FetchClosedChannel(chanPoint) + _, closedErr := c.FetchClosedChannel(chanPoint) if closedErr != nil { return closedErr } @@ -1204,8 +1253,10 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { // SaveChannelOpeningState saves the serialized channel state for the provided // chanPoint to the channelOpeningStateBucket. -func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) SaveChannelOpeningState(outPoint, + serializedState []byte) error { + + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) if err != nil { return err @@ -1218,9 +1269,9 @@ func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { // GetChannelOpeningState fetches the serialized channel state for the provided // outPoint from the database, or returns ErrChannelNotFound if the channel // is not found. -func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { +func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { var serializedState []byte - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { bucket := tx.ReadBucket(channelOpeningStateBucket) if bucket == nil { // If the bucket does not exist, it means we never added @@ -1241,8 +1292,8 @@ func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { } // DeleteChannelOpeningState removes any state for outPoint from the database. -func (d *DB) DeleteChannelOpeningState(outPoint []byte) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(channelOpeningStateBucket) if bucket == nil { return ErrChannelNotFound @@ -1330,9 +1381,10 @@ func (d *DB) ChannelGraph() *ChannelGraph { return d.graph } -// LinkNodeDB returns the current instance of the link node database. -func (d *DB) LinkNodeDB() *LinkNodeDB { - return d.linkNodeDB +// ChannelStateDB returns the sub database that is concerned with the channel +// state. +func (d *DB) ChannelStateDB() *ChannelStateDB { + return d.channelStateDB } func getLatestDBVersion(versions []version) uint32 { @@ -1384,9 +1436,11 @@ func fetchHistoricalChanBucket(tx kvdb.RTx, // FetchHistoricalChannel fetches open channel data from the historical channel // bucket. -func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) { +func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) ( + *OpenChannel, error) { + var channel *OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchHistoricalChanBucket(tx, outPoint) if err != nil { return err @@ -1394,7 +1448,7 @@ func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, erro channel, err = fetchOpenChannel(chanBucket, outPoint) - channel.Db = d + channel.Db = c return err }, func() { channel = nil diff --git a/channeldb/db_test.go b/channeldb/db_test.go index ef471c84c..5731c03a8 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -87,15 +87,18 @@ func TestWipe(t *testing.T) { } defer cleanup() - cdb, err := CreateWithBackend(backend) + fullDB, err := CreateWithBackend(backend) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } - defer cdb.Close() + defer fullDB.Close() - if err := cdb.Wipe(); err != nil { + if err := fullDB.Wipe(); err != nil { t.Fatalf("unable to wipe channeldb: %v", err) } + + cdb := fullDB.ChannelStateDB() + // Check correct errors are returned openChannels, err := cdb.FetchAllOpenChannels() require.NoError(t, err, "fetching open channels") @@ -113,12 +116,14 @@ func TestFetchClosedChannelForID(t *testing.T) { const numChans = 101 - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, that we will mutate the index of the // funding point. state := createTestChannelState(t, cdb) @@ -184,18 +189,18 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestAddrsForNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - graph := cdb.ChannelGraph() + graph := fullDB.ChannelGraph() // We'll make a test vertex to insert into the database, as the source // node, but this node will only have half the number of addresses it // usually does. - testNode, err := createTestVertex(cdb) + testNode, err := createTestVertex(fullDB) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -211,7 +216,8 @@ func TestAddrsForNode(t *testing.T) { t.Fatalf("unable to recv node pub: %v", err) } linkNode := NewLinkNode( - cdb.linkNodeDB, wire.MainNet, nodePub, anotherAddr, + fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub, + anotherAddr, ) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to sync link node: %v", err) @@ -219,7 +225,7 @@ func TestAddrsForNode(t *testing.T) { // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - nodeAddrs, err := cdb.AddrsForNode(nodePub) + nodeAddrs, err := fullDB.AddrsForNode(nodePub) if err != nil { t.Fatalf("unable to obtain node addrs: %v", err) } @@ -245,12 +251,14 @@ func TestAddrsForNode(t *testing.T) { func TestFetchChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channelState := createTestChannel(t, cdb, openChannelOption()) @@ -349,12 +357,14 @@ func genRandomChannelShell() (*ChannelShell, error) { func TestRestoreChannelShells(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First, we'll make our channel shell, it will only have the minimal // amount of information required for us to initiate the data loss // protection feature. @@ -423,7 +433,7 @@ func TestRestoreChannelShells(t *testing.T) { // We should also be able to find the link node that was inserted by // its public key. - linkNode, err := cdb.linkNodeDB.FetchLinkNode( + linkNode, err := fullDB.channelStateDB.linkNodeDB.FetchLinkNode( channelShell.Chan.IdentityPub, ) if err != nil { @@ -445,12 +455,14 @@ func TestRestoreChannelShells(t *testing.T) { func TestAbandonChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // If we attempt to abandon the state of a channel that doesn't exist // in the open or closed channel bucket, then we should receive an // error. @@ -618,13 +630,15 @@ func TestFetchChannels(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test "+ "database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that is not awaiting close. createTestChannel( t, cdb, channelIDOption(pendingChan), @@ -687,12 +701,14 @@ func TestFetchChannels(t *testing.T) { // TestFetchHistoricalChannel tests lookup of historical channels. func TestFetchHistoricalChannel(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a an open channel in the database. channel := createTestChannel(t, cdb, openChannelOption()) diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 7e9231fc5..8f60a7986 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -13,12 +13,14 @@ import ( func TestLinkNodeEncodeDecode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First we'll create some initial data to use for populating our test // LinkNode instances. _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) @@ -110,12 +112,14 @@ func TestLinkNodeEncodeDecode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index e8a09b758..7bb53e179 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -36,12 +36,12 @@ type WaitingProofStore struct { // cache is used in order to reduce the number of redundant get // calls, when object isn't stored in it. cache map[WaitingProofKey]struct{} - db *DB + db kvdb.Backend mu sync.RWMutex } // NewWaitingProofStore creates new instance of proofs storage. -func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { +func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) { s := &WaitingProofStore{ db: db, } diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 74c2b15eb..2cf6015c4 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -17,7 +17,7 @@ type ChannelNotifier struct { ntfnServer *subscribe.Server - chanDB *channeldb.DB + chanDB *channeldb.ChannelStateDB } // PendingOpenChannelEvent represents a new event where a new channel has @@ -76,7 +76,7 @@ type FullyResolvedChannelEvent struct { // New creates a new channel notifier. The ChannelNotifier gets channel // events from peers and from the chain arbitrator, and dispatches them to // its clients. -func New(chanDB *channeldb.DB) *ChannelNotifier { +func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier { return &ChannelNotifier{ ntfnServer: subscribe.NewServer(), chanDB: chanDB, diff --git a/chanrestore.go b/chanrestore.go index 7527499cd..cd68b5077 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -34,7 +34,7 @@ const ( // need the secret key chain in order obtain the prior shachain root so we can // verify the DLP protocol as initiated by the remote node. type chanDBRestorer struct { - db *channeldb.DB + db *channeldb.ChannelStateDB secretKeys keychain.SecretKeyRing diff --git a/contractcourt/breacharbiter.go b/contractcourt/breacharbiter.go index 112aa5bce..3253e0009 100644 --- a/contractcourt/breacharbiter.go +++ b/contractcourt/breacharbiter.go @@ -136,7 +136,7 @@ type BreachConfig struct { // DB provides access to the user's channels, allowing the breach // arbiter to determine the current state of a user's channels, and how // it should respond to channel closure. - DB *channeldb.DB + DB *channeldb.ChannelStateDB // Estimator is used by the breach arbiter to determine an appropriate // fee level when generating, signing, and broadcasting sweep @@ -1432,11 +1432,11 @@ func (b *BreachArbiter) sweepSpendableOutputsTxn(txWeight int64, // store is to ensure that we can recover from a restart in the middle of a // breached contract retribution. type RetributionStore struct { - db *channeldb.DB + db kvdb.Backend } // NewRetributionStore creates a new instance of a RetributionStore. -func NewRetributionStore(db *channeldb.DB) *RetributionStore { +func NewRetributionStore(db kvdb.Backend) *RetributionStore { return &RetributionStore{ db: db, } diff --git a/contractcourt/breacharbiter_test.go b/contractcourt/breacharbiter_test.go index 0d423584c..61e819391 100644 --- a/contractcourt/breacharbiter_test.go +++ b/contractcourt/breacharbiter_test.go @@ -987,7 +987,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter, contractBreaches := make(chan *ContractBreachEvent) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -1164,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) { assertNotPendingClosed(t, alice) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -2075,7 +2075,7 @@ func assertNoArbiterBreach(t *testing.T, brar *BreachArbiter, // assertBrarCleanup blocks until the given channel point has been removed the // retribution store and the channel is fully closed in the database. func assertBrarCleanup(t *testing.T, brar *BreachArbiter, - chanPoint *wire.OutPoint, db *channeldb.DB) { + chanPoint *wire.OutPoint, db *channeldb.ChannelStateDB) { t.Helper() @@ -2174,7 +2174,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent, notifier := mock.MakeMockSpendNotifier() ba := NewBreachArbiter(&BreachConfig{ CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {}, - DB: db, + DB: db.ChannelStateDB(), Estimator: chainfee.NewStaticEstimator(12500, 0), GenSweepScript: func() ([]byte, error) { return nil, nil }, ContractBreaches: contractBreaches, @@ -2375,7 +2375,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -2393,7 +2393,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 426382dd3..aeeff69f9 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -258,7 +258,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, // same instance that is used by the link. chanPoint := a.channel.FundingOutpoint - channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -301,7 +303,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) // Now that we know the link can't mutate the channel // state, we'll read the channel from disk the target // channel according to its channel point. - channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -422,7 +426,7 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { // First, we'll we'll mark the channel as fully closed from the PoV of // the channel source. - err := c.chanSource.MarkChanFullyClosed(&chanPoint) + err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint) if err != nil { log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+ "fully closed: %v", chanPoint, err) @@ -480,7 +484,7 @@ func (c *ChainArbitrator) Start() error { // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.FetchAllChannels() + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() if err != nil { return err } @@ -538,7 +542,9 @@ func (c *ChainArbitrator) Start() error { // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.FetchClosedChannels(true) + closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( + true, + ) if err != nil { return err } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index e197c0b09..cb1648065 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -49,7 +49,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { // We manually set the db here to make sure all channels are // synced to the same db. - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -165,7 +165,7 @@ func TestResolveContract(t *testing.T) { } defer cleanup() channel := newChannel.State() - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18556, @@ -205,7 +205,7 @@ func TestResolveContract(t *testing.T) { // While the resolver are active, we'll now remove the channel from the // database (mark is as closed). - err = db.AbandonChannel(&channel.FundingOutpoint, 4) + err = db.ChannelStateDB().AbandonChannel(&channel.FundingOutpoint, 4) if err != nil { t.Fatalf("unable to remove channel: %v", err) } diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 11f23d8cc..0023402c1 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -58,7 +58,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( *channeldb.OpenChannel, func(), error) { // Make a copy of the DB. - dbFile := filepath.Join(state.Db.Path(), "channel.db") + dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db") tempDbPath, err := ioutil.TempDir("", "past-state") if err != nil { return nil, nil, err @@ -81,7 +81,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( return nil, nil, err } - chans, err := newDb.FetchAllChannels() + chans, err := newDb.ChannelStateDB().FetchAllChannels() if err != nil { cleanup() return nil, nil, err diff --git a/discovery/message_store.go b/discovery/message_store.go index 4d5f9b205..40f2df78a 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -59,7 +58,7 @@ type GossipMessageStore interface { // version of a message (like in the case of multiple ChannelUpdate's) for a // channel with a peer. type MessageStore struct { - db *channeldb.DB + db kvdb.Backend } // A compile-time assertion to ensure messageStore implements the @@ -67,8 +66,8 @@ type MessageStore struct { var _ GossipMessageStore = (*MessageStore)(nil) // NewMessageStore creates a new message store backed by a channeldb instance. -func NewMessageStore(db *channeldb.DB) (*MessageStore, error) { - err := kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error { +func NewMessageStore(db kvdb.Backend) (*MessageStore, error) { + err := kvdb.Batch(db, func(tx kvdb.RwTx) error { _, err := tx.CreateTopLevelBucket(messageStoreBucket) return err }) @@ -124,7 +123,7 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore @@ -145,7 +144,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore diff --git a/funding/manager_test.go b/funding/manager_test.go index acd7ca514..64c66f92e 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -261,7 +261,7 @@ func (n *testNode) AddNewChannel(channel *channeldb.OpenChannel, } } -func createTestWallet(cdb *channeldb.DB, netParams *chaincfg.Params, +func createTestWallet(cdb *channeldb.ChannelStateDB, netParams *chaincfg.Params, notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, signer input.Signer, keyRing keychain.SecretKeyRing, bio lnwallet.BlockChainIO, @@ -329,11 +329,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, } dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } + cdb := fullDB.ChannelStateDB() + keyRing := &mock.SecretKeyRing{ RootKey: alicePrivKey, } diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 951c922f0..d5bb5f376 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -199,9 +199,16 @@ type circuitMap struct { // parameterize an instance of circuitMap. type CircuitMapConfig struct { // DB provides the persistent storage engine for the circuit map. - // TODO(conner): create abstraction to allow for the substitution of - // other persistence engines. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // ExtractErrorEncrypter derives the shared secret used to encrypt // errors from the obfuscator's ephemeral public key. @@ -296,7 +303,7 @@ func (cm *circuitMap) cleanClosedChannels() error { // Find closed channels and cache their ShortChannelIDs into a map. // This map will be used for looking up relative circuits and keystones. - closedChannels, err := cm.cfg.DB.FetchClosedChannels(false) + closedChannels, err := cm.cfg.FetchClosedChannels(false) if err != nil { return err } @@ -629,7 +636,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { // channels. Therefore, it must be called before any links are created to avoid // interfering with normal operation. func (cm *circuitMap) trimAllOpenCircuits() error { - activeChannels, err := cm.cfg.DB.FetchAllOpenChannels() + activeChannels, err := cm.cfg.FetchAllOpenChannels() if err != nil { return err } @@ -860,7 +867,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) ( // Write the entire batch of circuits to the persistent circuit bucket // using bolt's Batch write. This method must be called from multiple, // distinct goroutines to have any impact on performance. - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap @@ -1091,7 +1098,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { } cm.mtx.Unlock() - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { for _, circuit := range removedCircuits { // If this htlc made it to an outgoing link, load the // keystone bucket from which we will remove the diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index d3ee7b4fe..fed07958b 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -103,8 +103,11 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig, onionProcessor := newOnionProcessor(t) + db := makeCircuitDB(t, "") circuitMapCfg := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, ""), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, } @@ -634,13 +637,17 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB { func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( *htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) { - // Record the current temp path and close current db. - dbPath := cfg.DB.Path() + // Record the current temp path and close current db. We know we have + // a full channeldb.DB here since we created it just above. + dbPath := cfg.DB.(*channeldb.DB).Path() cfg.DB.Close() // Reinitialize circuit map with same db path. + db := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, dbPath), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, } cm2, err := htlcswitch.NewCircuitMap(cfg2) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 1f99a1d9d..865f3afd1 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1938,7 +1938,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( pCache := newMockPreimageCache() - aliceDb := aliceLc.channel.State().Db + aliceDb := aliceLc.channel.State().Db.GetParentDB() aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, nil, nil, nil, err @@ -4438,7 +4438,7 @@ func (h *persistentLinkHarness) restartLink( pCache = newMockPreimageCache() ) - aliceDb := aliceChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() aliceSwitch := h.coreLink.cfg.Switch if restartSwitch { var err error diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce9b0f838..578a92367 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -169,8 +169,10 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) } cfg := Config{ - DB: db, - SwitchPackager: channeldb.NewSwitchPackager(), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + SwitchPackager: channeldb.NewSwitchPackager(), FwdingLog: &mockForwardingLog{ events: make(map[time.Time]channeldb.ForwardingEvent), }, diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 2bd35f60a..8d6cb5b3a 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -83,7 +83,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) { // is back. The Switch will checkpoint any received result to the store, and // the store will keep results and notify the callers about them. type networkResultStore struct { - db *channeldb.DB + backend kvdb.Backend // results is a map from paymentIDs to channels where subscribers to // payment results will be notified. @@ -96,9 +96,9 @@ type networkResultStore struct { paymentIDMtx *multimutex.Mutex } -func newNetworkResultStore(db *channeldb.DB) *networkResultStore { +func newNetworkResultStore(db kvdb.Backend) *networkResultStore { return &networkResultStore{ - db: db, + backend: db, results: make(map[uint64][]chan *networkResult), paymentIDMtx: multimutex.NewMutex(), } @@ -126,7 +126,7 @@ func (store *networkResultStore) storeResult(paymentID uint64, var paymentIDBytes [8]byte binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID) - err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) @@ -171,7 +171,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( resultChan = make(chan *networkResult, 1) ) - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, paymentID) switch { @@ -219,7 +219,7 @@ func (store *networkResultStore) getResult(pid uint64) ( *networkResult, error) { var result *networkResult - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, pid) return err @@ -260,7 +260,7 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) { // concurrently while this process is ongoing, as its result might end up being // deleted. func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error { - return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Update(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d367d5e6b..17b423857 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -121,9 +121,18 @@ type Config struct { // subsystem. LocalChannelClose func(pubKey []byte, request *ChanClose) - // DB is the channeldb instance that will be used to back the switch's + // DB is the database backend that will be used to back the switch's // persistent circuit map. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // SwitchPackager provides access to the forwarding packages of all // active channels. This gives the switch the ability to read arbitrary @@ -281,6 +290,8 @@ type Switch struct { func New(cfg Config, currentHeight uint32) (*Switch, error) { circuitMap, err := NewCircuitMap(&CircuitMapConfig{ DB: cfg.DB, + FetchAllOpenChannels: cfg.FetchAllOpenChannels, + FetchClosedChannels: cfg.FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, }) if err != nil { @@ -1374,7 +1385,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) { // we're the originator of the payment, so the link stops attempting to // re-broadcast. func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error { - return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error { return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...) }) } @@ -1778,7 +1789,7 @@ func (s *Switch) Start() error { // forwarding packages and reforwards any Settle or Fail HTLCs found. This is // used to resurrect the switch's mailboxes after a restart. func (s *Switch) reforwardResponses() error { - openChannels, err := s.cfg.DB.FetchAllOpenChannels() + openChannels, err := s.cfg.FetchAllOpenChannels() if err != nil { return err } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index d33daff8f..eaf2aa99c 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -306,7 +306,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, ShortChannelID: chanID, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, } @@ -325,7 +325,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: bobCommit, RemoteCommitment: bobCommit, ShortChannelID: chanID, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), } @@ -384,7 +384,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreAlice := func() (*lnwallet.LightningChannel, error) { - aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err := dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -394,7 +395,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err = dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch alice "+ "channel: %v", err) @@ -428,7 +430,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreBob := func() (*lnwallet.LightningChannel, error) { - bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err := dbBob.ChannelStateDB(). + FetchOpenChannels(bobKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -438,7 +441,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err = dbBob.ChannelStateDB(). + FetchOpenChannels(bobKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch bob "+ "channel: %v", err) @@ -950,9 +954,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, secondBobChannel, carolChannel *lnwallet.LightningChannel, startingHeight uint32, opts ...serverOption) *threeHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := firstBobChannel.State().Db - carolDb := carolChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := firstBobChannel.State().Db.GetParentDB() + carolDb := carolChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() @@ -1201,8 +1205,8 @@ func newTwoHopNetwork(t testing.TB, aliceChannel, bobChannel *lnwallet.LightningChannel, startingHeight uint32) *twoHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := bobChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := bobChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() diff --git a/lnd.go b/lnd.go index 8bf4a9fc4..7dcc7bf55 100644 --- a/lnd.go +++ b/lnd.go @@ -697,7 +697,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error BtcdMode: cfg.BtcdMode, LtcdMode: cfg.LtcdMode, HeightHintDB: dbs.heightHintDB, - ChanStateDB: dbs.chanStateDB, + ChanStateDB: dbs.chanStateDB.ChannelStateDB(), PrivateWalletPw: privateWalletPw, PublicWalletPw: publicWalletPw, Birthday: walletInitParams.Birthday, diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 4e88ae0c1..193f3a63a 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -56,7 +56,7 @@ type AddInvoiceConfig struct { // ChanDB is a global boltdb instance which is needed to access the // channel graph. - ChanDB *channeldb.DB + ChanDB *channeldb.ChannelStateDB // Graph holds a reference to the ChannelGraph database. Graph *channeldb.ChannelGraph diff --git a/lnrpc/invoicesrpc/config_active.go b/lnrpc/invoicesrpc/config_active.go index 3246f4b7f..abe8c5565 100644 --- a/lnrpc/invoicesrpc/config_active.go +++ b/lnrpc/invoicesrpc/config_active.go @@ -50,7 +50,7 @@ type Config struct { // ChanStateDB is a possibly replicated db instance which contains the // channel state - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // GenInvoiceFeatures returns a feature containing feature bits that // should be advertised on freshly generated invoices. diff --git a/lnwallet/config.go b/lnwallet/config.go index a73120c02..cf7f3f4b8 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -18,7 +18,7 @@ type Config struct { // Database is a wrapper around a namespace within boltdb reserved for // ln-based wallet metadata. See the 'channeldb' package for further // information. - Database *channeldb.DB + Database *channeldb.ChannelStateDB // Notifier is used by in order to obtain notifications about funding // transaction reaching a specified confirmation depth, and to catch diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index dd6bf1a95..0b2aecff6 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -327,13 +327,13 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } cfg := lnwallet.Config{ - Database: cdb, + Database: fullDB.ChannelStateDB(), Notifier: notifier, SecretKeyRing: keyRing, WalletController: wc, @@ -2944,11 +2944,11 @@ func clearWalletStates(a, b *lnwallet.LightningWallet) error { a.ResetReservations() b.ResetReservations() - if err := a.Cfg.Database.Wipe(); err != nil { + if err := a.Cfg.Database.GetParentDB().Wipe(); err != nil { return err } - return b.Cfg.Database.Wipe() + return b.Cfg.Database.GetParentDB().Wipe() } func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error { diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index bd048b2c0..40af5201c 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -322,7 +322,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceLocalCommit, RemoteCommitment: aliceRemoteCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: testTx, } @@ -340,7 +340,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobLocalCommit, RemoteCommitment: bobRemoteCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 696328cdb..be5bf705a 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -937,7 +937,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: remoteCommit, RemoteCommitment: remoteCommit, - Db: dbRemote, + Db: dbRemote.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } @@ -955,7 +955,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: localCommit, RemoteCommitment: localCommit, - Db: dbLocal, + Db: dbLocal.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } diff --git a/peer/brontide.go b/peer/brontide.go index 60c41af6f..9c6df6734 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -185,7 +185,7 @@ type Config struct { InterceptSwitch *htlcswitch.InterceptableSwitch // ChannelDB is used to fetch opened channels, and closed channels. - ChannelDB *channeldb.DB + ChannelDB *channeldb.ChannelStateDB // ChannelGraph is a pointer to the channel graph which is used to // query information about the set of known active channels. diff --git a/peer/test_utils.go b/peer/test_utils.go index 3ce1cbe03..ac5f5f5ab 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -229,7 +229,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -246,7 +246,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } @@ -321,7 +321,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanStatusSampleInterval: 30 * time.Second, ChanEnableTimeout: chanActiveTimeout, ChanDisableTimeout: 2 * time.Minute, - DB: dbAlice, + DB: dbAlice.ChannelStateDB(), Graph: dbAlice.ChannelGraph(), MessageSigner: nodeSignerAlice, OurPubKey: aliceKeyPub, @@ -359,7 +359,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), - ChannelDB: dbAlice, + ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, Wallet: wallet, ChainNotifier: notifier, diff --git a/rpcserver.go b/rpcserver.go index 79f655375..67934bac1 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3979,7 +3979,7 @@ func (r *rpcServer) createRPCClosedChannel( CloseInitiator: closeInitiator, } - reports, err := r.server.chanStateDB.FetchChannelReports( + reports, err := r.server.miscDB.FetchChannelReports( *r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint, ) switch err { @@ -5142,7 +5142,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context, PendingOnly: req.PendingOnly, Reversed: req.Reversed, } - invoiceSlice, err := r.server.chanStateDB.QueryInvoices(q) + invoiceSlice, err := r.server.miscDB.QueryInvoices(q) if err != nil { return nil, fmt.Errorf("unable to query invoices: %v", err) } @@ -5944,7 +5944,7 @@ func (r *rpcServer) ListPayments(ctx context.Context, query.MaxPayments = math.MaxUint64 } - paymentsQuerySlice, err := r.server.chanStateDB.QueryPayments(query) + paymentsQuerySlice, err := r.server.miscDB.QueryPayments(query) if err != nil { return nil, err } @@ -5985,9 +5985,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context, rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) - err = r.server.chanStateDB.DeletePayment( - hash, req.FailedHtlcsOnly, - ) + err = r.server.miscDB.DeletePayment(hash, req.FailedHtlcsOnly) if err != nil { return nil, err } @@ -6004,7 +6002,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, "failed_htlcs_only=%v", req.FailedPaymentsOnly, req.FailedHtlcsOnly) - err := r.server.chanStateDB.DeletePayments( + err := r.server.miscDB.DeletePayments( req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { @@ -6166,7 +6164,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, return nil, err } - fwdEventLog := r.server.chanStateDB.ForwardingLog() + fwdEventLog := r.server.miscDB.ForwardingLog() // computeFeeSum is a helper function that computes the total fees for // a particular time slice described by a forwarding event query. @@ -6407,7 +6405,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context, IndexOffset: req.IndexOffset, NumMaxEvents: numEvents, } - timeSlice, err := r.server.chanStateDB.ForwardingLog().Query(eventQuery) + timeSlice, err := r.server.miscDB.ForwardingLog().Query(eventQuery) if err != nil { return nil, fmt.Errorf("unable to query forwarding log: %v", err) } diff --git a/server.go b/server.go index 6b8830f03..f8f1f53e8 100644 --- a/server.go +++ b/server.go @@ -222,10 +222,14 @@ type server struct { graphDB *channeldb.ChannelGraph - chanStateDB *channeldb.DB + chanStateDB *channeldb.ChannelStateDB addrSource chanbackup.AddressSource + // miscDB is the DB that contains all "other" databases within the main + // channel DB that haven't been separated out yet. + miscDB *channeldb.DB + htlcSwitch *htlcswitch.Switch interceptableSwitch *htlcswitch.InterceptableSwitch @@ -434,15 +438,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, graphDB: dbs.graphDB.ChannelGraph(), - chanStateDB: dbs.chanStateDB, + chanStateDB: dbs.chanStateDB.ChannelStateDB(), addrSource: dbs.chanStateDB, + miscDB: dbs.chanStateDB, cc: cc, sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), writePool: writePool, readPool: readPool, chansToRestore: chansToRestore, - channelNotifier: channelnotifier.New(dbs.chanStateDB), + channelNotifier: channelnotifier.New( + dbs.chanStateDB.ChannelStateDB(), + ), identityECDH: nodeKeyECDH, nodeSigner: netann.NewNodeSigner(nodeKeySigner), @@ -494,7 +501,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.htlcNotifier = htlcswitch.NewHtlcNotifier(time.Now) s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ - DB: dbs.chanStateDB, + DB: dbs.chanStateDB, + FetchAllOpenChannels: s.chanStateDB.FetchAllOpenChannels, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, LocalChannelClose: func(pubKey []byte, request *htlcswitch.ChanClose) { @@ -536,7 +545,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MessageSigner: s.nodeSigner, IsChannelActive: s.htlcSwitch.HasActiveLink, ApplyChannelUpdate: s.applyChannelUpdate, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Graph: dbs.graphDB.ChannelGraph(), } @@ -804,11 +813,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } chanSeries := discovery.NewChanSeries(s.graphDB) - gossipMessageStore, err := discovery.NewMessageStore(s.chanStateDB) + gossipMessageStore, err := discovery.NewMessageStore(dbs.chanStateDB) if err != nil { return nil, err } - waitingProofStore, err := channeldb.NewWaitingProofStore(s.chanStateDB) + waitingProofStore, err := channeldb.NewWaitingProofStore(dbs.chanStateDB) if err != nil { return nil, err } @@ -890,8 +899,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ ChainIO: cc.ChainIO, ConfDepth: 1, - FetchClosedChannels: dbs.chanStateDB.FetchClosedChannels, - FetchClosedChannel: dbs.chanStateDB.FetchClosedChannel, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, + FetchClosedChannel: s.chanStateDB.FetchClosedChannel, Notifier: cc.ChainNotifier, PublishTransaction: cc.Wallet.PublishTransaction, Store: utxnStore, @@ -1017,7 +1026,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{ CloseLink: closeLink, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Estimator: s.cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), Notifier: cc.ChainNotifier, @@ -1074,7 +1083,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, FindChannel: func(chanID lnwire.ChannelID) ( *channeldb.OpenChannel, error) { - dbChannels, err := dbs.chanStateDB.FetchAllChannels() + dbChannels, err := s.chanStateDB.FetchAllChannels() if err != nil { return nil, err } @@ -1246,7 +1255,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // static backup of the latest channel state. chanNotifier := &channelNotifier{ chanNotifier: s.channelNotifier, - addrs: s.chanStateDB, + addrs: dbs.chanStateDB, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) startingChans, err := chanbackup.FetchStaticChanBackups( @@ -1276,8 +1285,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, GetOpenChannels: s.chanStateDB.FetchAllOpenChannels, Clock: clock.NewDefaultClock(), - ReadFlapCount: s.chanStateDB.ReadFlapCount, - WriteFlapCount: s.chanStateDB.WriteFlapCounts, + ReadFlapCount: s.miscDB.ReadFlapCount, + WriteFlapCount: s.miscDB.WriteFlapCounts, FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate), }) diff --git a/subrpcserver_config.go b/subrpcserver_config.go index bf5911ec2..04853db76 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -93,7 +93,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, routerBackend *routerrpc.RouterBackend, nodeSigner *netann.NodeSigner, graphDB *channeldb.ChannelGraph, - chanStateDB *channeldb.DB, + chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, towerClient wtclient.Client,