Merge pull request #5642 from guggero/in-memory-graph

In-memory graph cache for faster pathfinding
This commit is contained in:
Oliver Gugger 2021-10-04 11:20:23 +02:00 committed by GitHub
commit 692ea25295
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
65 changed files with 2595 additions and 1171 deletions

View File

@ -148,7 +148,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey,
return nil, err
}
dbNode, err := d.db.FetchLightningNode(nil, vertex)
dbNode, err := d.db.FetchLightningNode(vertex)
switch {
case err == channeldb.ErrGraphNodeNotFound:
fallthrough

View File

@ -75,7 +75,7 @@ type Config struct {
// ChanStateDB is a pointer to the database that stores the channel
// state.
ChanStateDB *channeldb.DB
ChanStateDB *channeldb.ChannelStateDB
// BlockCacheSize is the size (in bytes) of blocks kept in memory.
BlockCacheSize uint64

View File

@ -21,7 +21,11 @@ type LiveChannelSource interface {
// passed chanPoint. Optionally an existing db tx can be supplied.
FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error)
}
// AddressSource is an interface that allows us to query for the set of
// addresses a node can be connected to.
type AddressSource interface {
// AddrsForNode returns all known addresses for the target node public
// key.
AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error)
@ -31,15 +35,15 @@ type LiveChannelSource interface {
// passed open channel. The backup includes all information required to restore
// the channel, as well as addressing information so we can find the peer and
// reconnect to them to initiate the protocol.
func assembleChanBackup(chanSource LiveChannelSource,
func assembleChanBackup(addrSource AddressSource,
openChan *channeldb.OpenChannel) (*Single, error) {
log.Debugf("Crafting backup for ChannelPoint(%v)",
openChan.FundingOutpoint)
// First, we'll query the channel source to obtain all the addresses
// that are are associated with the peer for this channel.
nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub)
// that are associated with the peer for this channel.
nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub)
if err != nil {
return nil, err
}
@ -52,8 +56,8 @@ func assembleChanBackup(chanSource LiveChannelSource,
// FetchBackupForChan attempts to create a plaintext static channel backup for
// the target channel identified by its channel point. If we're unable to find
// the target channel, then an error will be returned.
func FetchBackupForChan(chanPoint wire.OutPoint,
chanSource LiveChannelSource) (*Single, error) {
func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
addrSource AddressSource) (*Single, error) {
// First, we'll query the channel source to see if the channel is known
// and open within the database.
@ -66,7 +70,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// Once we have the target channel, we can assemble the backup using
// the source to obtain any extra information that we may need.
staticChanBackup, err := assembleChanBackup(chanSource, targetChan)
staticChanBackup, err := assembleChanBackup(addrSource, targetChan)
if err != nil {
return nil, fmt.Errorf("unable to create chan backup: %v", err)
}
@ -76,7 +80,9 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// FetchStaticChanBackups will return a plaintext static channel back up for
// all known active/open channels within the passed channel source.
func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
func FetchStaticChanBackups(chanSource LiveChannelSource,
addrSource AddressSource) ([]Single, error) {
// First, we'll query the backup source for information concerning all
// currently open and available channels.
openChans, err := chanSource.FetchAllChannels()
@ -89,7 +95,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
// channel.
staticChanBackups := make([]Single, 0, len(openChans))
for _, openChan := range openChans {
chanBackup, err := assembleChanBackup(chanSource, openChan)
chanBackup, err := assembleChanBackup(addrSource, openChan)
if err != nil {
return nil, err
}

View File

@ -124,7 +124,9 @@ func TestFetchBackupForChan(t *testing.T) {
},
}
for i, testCase := range testCases {
_, err := FetchBackupForChan(testCase.chanPoint, chanSource)
_, err := FetchBackupForChan(
testCase.chanPoint, chanSource, chanSource,
)
switch {
// If this is a valid test case, and we failed, then we'll
// return an error.
@ -167,7 +169,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// With the channel source populated, we'll now attempt to create a set
// of backups for all the channels. This should succeed, as all items
// are populated within the channel source.
backups, err := FetchStaticChanBackups(chanSource)
backups, err := FetchStaticChanBackups(chanSource, chanSource)
if err != nil {
t.Fatalf("unable to create chan back ups: %v", err)
}
@ -184,7 +186,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
delete(chanSource.addrs, n)
_, err = FetchStaticChanBackups(chanSource)
_, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil {
t.Fatalf("query with incomplete information should fail")
}
@ -193,7 +195,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// source at all, then we'll fail as well.
chanSource = newMockChannelSource()
chanSource.failQuery = true
_, err = FetchStaticChanBackups(chanSource)
_, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil {
t.Fatalf("query should fail")
}

View File

@ -729,7 +729,7 @@ type OpenChannel struct {
RevocationKeyLocator keychain.KeyLocator
// TODO(roasbeef): eww
Db *DB
Db *ChannelStateDB
// TODO(roasbeef): just need to store local and remote HTLC's?
@ -800,7 +800,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
c.Lock()
defer c.Unlock()
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey,
func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer
outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) {
readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash)
if err != nil {
return nil, err
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
}
return readBucket.(kvdb.RwBucket), nil
// TODO(roasbeef): CreateTopLevelBucket on the interface isn't like
// CreateIfNotExists, will return error
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := nodeKey.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
if err := writeOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
}
// fullSync syncs the contents of an OpenChannel while re-using an existing
@ -964,8 +995,8 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
c.Lock()
defer c.Unlock()
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucket(
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
channel.IsPending = false
channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket.(kvdb.RwBucket), channel)
return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil {
return err
}
@ -1016,7 +1047,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error {
func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
var commitPoint *btcec.PublicKey
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -1240,7 +1271,7 @@ func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) {
func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
var closeTx *wire.MsgTx
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -1274,7 +1305,7 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
func (c *OpenChannel) putChanStatus(status ChannelStatus,
fs ...func(kvdb.RwBucket) error) error {
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -1318,7 +1349,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus,
}
func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -1442,7 +1473,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
c.FundingBroadcastHeight = pendingHeight
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return syncNewChannel(tx, c, []net.Addr{addr})
}, func() {})
}
@ -1470,7 +1501,10 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error {
// Next, we need to establish a (possibly) new LinkNode relationship
// for this channel. The LinkNode metadata contains reachability,
// up-time, and service bits related information.
linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...)
linkNode := NewLinkNode(
&LinkNodeDB{backend: c.Db.backend},
wire.MainNet, c.IdentityPub, addrs...,
)
// TODO(roasbeef): do away with link node all together?
@ -1498,7 +1532,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment,
return ErrNoRestoredChannelMutation
}
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2090,7 +2124,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
return ErrNoRestoredChannelMutation
}
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
// First, we'll grab the writable bucket where this channel's
// data resides.
chanBucket, err := fetchChanBucketRw(
@ -2160,7 +2194,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
// these pointers, causing the tip and the tail to point to the same entry.
func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
var cd *CommitDiff
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2199,7 +2233,7 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
// updates that still need to be signed for.
func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
var updates []LogUpdate
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2233,7 +2267,7 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
// updates that the remote still needs to sign for.
func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
var updates []LogUpdate
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2277,7 +2311,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
c.RemoteNextRevocation = revKey
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2318,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
var newRemoteCommit *ChannelCommitment
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2493,7 +2527,7 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
defer c.RUnlock()
var fwdPkgs []*FwdPkg
if err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err
@ -2513,7 +2547,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckAddHtlcs(tx, addRefs...)
}, func() {})
}
@ -2526,7 +2560,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckSettleFails(tx, settleFailRefs...)
}, func() {})
}
@ -2537,7 +2571,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter)
}, func() {})
}
@ -2551,7 +2585,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
for _, height := range heights {
err := c.Packager.RemovePkg(tx, height)
if err != nil {
@ -2579,7 +2613,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
}
var commit ChannelCommitment
if err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2626,7 +2660,7 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
defer c.RUnlock()
var height uint64
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
chanBucket, err := fetchChanBucket(
@ -2663,7 +2697,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
defer c.RUnlock()
var commit ChannelCommitment
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -2821,7 +2855,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoChanDBExists
@ -3033,7 +3067,7 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot {
// latest fully committed state is returned. The first commitment returned is
// the local commitment, and the second returned is the remote commitment.
func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) {
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
@ -3055,7 +3089,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
// acting on a possible contract breach to ensure, that the caller has the most
// up to date information required to deliver justice.
func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
err := kvdb.View(c.Db, func(tx kvdb.RTx) error {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)

View File

@ -183,7 +183,7 @@ var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
// createTestChannel writes a test channel to the database. It takes a set of
// functional options which can be used to overwrite the default of creating
// a pending channel that was broadcast at height 100.
func createTestChannel(t *testing.T, cdb *DB,
func createTestChannel(t *testing.T, cdb *ChannelStateDB,
opts ...testChannelOption) *OpenChannel {
// Create a default set of parameters.
@ -221,7 +221,7 @@ func createTestChannel(t *testing.T, cdb *DB,
return params.channel
}
func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel {
func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
// Simulate 1000 channel updates.
producer, err := shachain.NewRevocationProducerFromBytes(key[:])
if err != nil {
@ -359,12 +359,14 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel {
func TestOpenChannelPutGetDelete(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create the test channel state, with additional htlcs on the local
// and remote commitment.
localHtlcs := []HTLC{
@ -508,12 +510,14 @@ func TestOptionalShutdown(t *testing.T) {
test := test
t.Run(test.name, func(t *testing.T) {
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a channel with upfront scripts set as
// specified in the test.
state := createTestChannel(
@ -565,12 +569,14 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
func TestChannelStateTransition(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First create a minimal channel, then perform a full sync in order to
// persist the data.
channel := createTestChannel(t, cdb)
@ -842,7 +848,7 @@ func TestChannelStateTransition(t *testing.T) {
}
// At this point, we should have 2 forwarding packages added.
fwdPkgs := loadFwdPkgs(t, cdb, channel.Packager)
fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager)
require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages")
// Now attempt to delete the channel from the database.
@ -877,19 +883,21 @@ func TestChannelStateTransition(t *testing.T) {
}
// All forwarding packages of this channel has been deleted too.
fwdPkgs = loadFwdPkgs(t, cdb, channel.Packager)
fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager)
require.Empty(t, fwdPkgs, "no forwarding packages should exist")
}
func TestFetchPendingChannels(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a pending channel that was broadcast at height 99.
const broadcastHeight = 99
createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))
@ -963,12 +971,14 @@ func TestFetchPendingChannels(t *testing.T) {
func TestFetchClosedChannels(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel in the database.
state := createTestChannel(t, cdb, openChannelOption())
@ -1054,18 +1064,20 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// We'll start by creating two channels within our test database. One of
// them will have their funding transaction confirmed on-chain, while
// the other one will remain unconfirmed.
db, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
channels := make([]*OpenChannel, numChannels)
for i := 0; i < numChannels; i++ {
// Create a pending channel in the database at the broadcast
// height.
channels[i] = createTestChannel(
t, db, pendingHeightOption(broadcastHeight),
t, cdb, pendingHeightOption(broadcastHeight),
)
}
@ -1116,7 +1128,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// Now, we'll fetch all the channels waiting to be closed from the
// database. We should expect to see both channels above, even if any of
// them haven't had their funding transaction confirm on-chain.
waitingCloseChannels, err := db.FetchWaitingCloseChannels()
waitingCloseChannels, err := cdb.FetchWaitingCloseChannels()
if err != nil {
t.Fatalf("unable to fetch all waiting close channels: %v", err)
}
@ -1169,12 +1181,14 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
func TestRefreshShortChanID(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First create a test channel.
state := createTestChannel(t, cdb)
@ -1317,13 +1331,15 @@ func TestCloseInitiator(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v",
err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel.
channel := createTestChannel(
t, cdb, openChannelOption(),
@ -1362,13 +1378,15 @@ func TestCloseInitiator(t *testing.T) {
// TestCloseChannelStatus tests setting of a channel status on the historical
// channel on channel close.
func TestCloseChannelStatus(t *testing.T) {
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v",
err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel.
channel := createTestChannel(
t, cdb, openChannelOption(),
@ -1427,7 +1445,7 @@ func TestBalanceAtHeight(t *testing.T) {
putRevokedState := func(c *OpenChannel, height uint64, local,
remote lnwire.MilliSatoshi) error {
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint,
c.ChainHash,
@ -1508,13 +1526,15 @@ func TestBalanceAtHeight(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v",
err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create options to set the heights and balances of
// our local and remote commitments.
localCommitOpt := channelCommitmentOption(

View File

@ -23,6 +23,7 @@ import (
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
@ -209,6 +210,11 @@ var (
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
// channelOpeningStateBucket is the database bucket used to store the
// channelOpeningState for each channel that is currently in the process
// of being opened.
channelOpeningStateBucket = []byte("channelOpeningState")
)
// DB is the primary datastore for the lnd daemon. The database stores
@ -217,6 +223,9 @@ var (
type DB struct {
kvdb.Backend
// channelStateDB separates all DB operations on channel state.
channelStateDB *ChannelStateDB
dbPath string
graph *ChannelGraph
clock clock.Clock
@ -265,13 +274,27 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB,
chanDB := &DB{
Backend: backend,
clock: opts.clock,
dryRun: opts.dryRun,
channelStateDB: &ChannelStateDB{
linkNodeDB: &LinkNodeDB{
backend: backend,
},
backend: backend,
},
clock: opts.clock,
dryRun: opts.dryRun,
}
chanDB.graph = newChannelGraph(
chanDB, opts.RejectCacheSize, opts.ChannelCacheSize,
opts.BatchCommitInterval,
// Set the parent pointer (only used in tests).
chanDB.channelStateDB.parent = chanDB
var err error
chanDB.graph, err = NewChannelGraph(
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
)
if err != nil {
return nil, err
}
// Synchronize the version of database and apply migrations if needed.
if err := chanDB.syncVersions(dbVersions); err != nil {
@ -287,7 +310,7 @@ func (d *DB) Path() string {
return d.dbPath
}
var topLevelBuckets = [][]byte{
var dbTopLevelBuckets = [][]byte{
openChannelBucket,
closedChannelBucket,
forwardingLogBucket,
@ -298,10 +321,6 @@ var topLevelBuckets = [][]byte{
paymentsIndexBucket,
peersBucket,
nodeInfoBucket,
nodeBucket,
edgeBucket,
edgeIndexBucket,
graphMetaBucket,
metaBucket,
closeSummaryBucket,
outpointBucket,
@ -312,7 +331,7 @@ var topLevelBuckets = [][]byte{
// operation is fully atomic.
func (d *DB) Wipe() error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
for _, tlb := range topLevelBuckets {
for _, tlb := range dbTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
@ -327,10 +346,10 @@ func (d *DB) Wipe() error {
return initChannelDB(d.Backend)
}
// createChannelDB creates and initializes a fresh version of channeldb. In
// the case that the target path has not yet been created or doesn't yet exist,
// then the path is created. Additionally, all required top-level buckets used
// within the database are created.
// initChannelDB creates and initializes a fresh version of channeldb. In the
// case that the target path has not yet been created or doesn't yet exist, then
// the path is created. Additionally, all required top-level buckets used within
// the database are created.
func initChannelDB(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
meta := &Meta{}
@ -340,42 +359,12 @@ func initChannelDB(db kvdb.Backend) error {
return nil
}
for _, tlb := range topLevelBuckets {
for _, tlb := range dbTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err
}
}
nodes := tx.ReadWriteBucket(nodeBucket)
_, err = nodes.CreateBucket(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucket(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges := tx.ReadWriteBucket(edgeBucket)
if _, err := edges.CreateBucket(edgeIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(channelPointBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(zombieBucket); err != nil {
return err
}
graphMeta := tx.ReadWriteBucket(graphMetaBucket)
_, err = graphMeta.CreateBucket(pruneLogBucket)
if err != nil {
return err
}
meta.DbVersionNumber = getLatestDBVersion(dbVersions)
return putMeta(meta, tx)
}, func() {})
@ -397,15 +386,45 @@ func fileExists(path string) bool {
return true
}
// ChannelStateDB is a database that keeps track of all channel state.
type ChannelStateDB struct {
// linkNodeDB separates all DB operations on LinkNodes.
linkNodeDB *LinkNodeDB
// parent holds a pointer to the "main" channeldb.DB object. This is
// only used for testing and should never be used in production code.
// For testing use the ChannelStateDB.GetParentDB() function to retrieve
// this pointer.
parent *DB
// backend points to the actual backend holding the channel state
// database. This may be a real backend or a cache middleware.
backend kvdb.Backend
}
// GetParentDB returns the "main" channeldb.DB object that is the owner of this
// ChannelStateDB instance. Use this function only in tests where passing around
// pointers makes testing less readable. Never to be used in production code!
func (c *ChannelStateDB) GetParentDB() *DB {
return c.parent
}
// LinkNodeDB returns the current instance of the link node database.
func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB {
return c.linkNodeDB
}
// FetchOpenChannels starts a new database transaction and returns all stored
// currently active/open channels associated with the target nodeID. In the case
// that no active channels are known to have been created with this node, then a
// zero-length slice is returned.
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) (
[]*OpenChannel, error) {
var channels []*OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error {
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error
channels, err = d.fetchOpenChannels(tx, nodeID)
channels, err = c.fetchOpenChannels(tx, nodeID)
return err
}, func() {
channels = nil
@ -418,7 +437,7 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
// stored currently active/open channels associated with the target nodeID. In
// the case that no active channels are known to have been created with this
// node, then a zero-length slice is returned.
func (d *DB) fetchOpenChannels(tx kvdb.RTx,
func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx,
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
// Get the bucket dedicated to storing the metadata for open channels.
@ -454,7 +473,7 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx,
// Finally, we both of the necessary buckets retrieved, fetch
// all the active channels related to this node.
nodeChannels, err := d.fetchNodeChannels(chainBucket)
nodeChannels, err := c.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read channel for "+
"chain_hash=%x, node_key=%x: %v",
@ -471,7 +490,8 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx,
// fetchNodeChannels retrieves all active channels from the target chainBucket
// which is under a node's dedicated channel bucket. This function is typically
// used to fetch all the active channels related to a particular node.
func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) {
func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (
[]*OpenChannel, error) {
var channels []*OpenChannel
@ -497,7 +517,7 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error)
return fmt.Errorf("unable to read channel data for "+
"chan_point=%v: %v", outPoint, err)
}
oChannel.Db = d
oChannel.Db = c
channels = append(channels, oChannel)
@ -514,8 +534,8 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error)
// point. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied. Optionally an existing db tx
// can be supplied.
func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
error) {
func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*OpenChannel, error) {
var (
targetChan *OpenChannel
@ -591,7 +611,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
}
targetChan = channel
targetChan.Db = d
targetChan.Db = c
return nil
})
@ -600,7 +620,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
var err error
if tx == nil {
err = kvdb.View(d, chanScan, func() {})
err = kvdb.View(c.backend, chanScan, func() {})
} else {
err = chanScan(tx)
}
@ -620,16 +640,16 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
// FetchAllChannels attempts to retrieve all open channels currently stored
// within the database, including pending open, fully open and channels waiting
// for a closing transaction to confirm.
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) {
return fetchChannels(d)
func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) {
return fetchChannels(c)
}
// FetchAllOpenChannels will return all channels that have the funding
// transaction confirmed, and is not waiting for a closing transaction to be
// confirmed.
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) {
return fetchChannels(
d,
c,
pendingChannelFilter(false),
waitingCloseFilter(false),
)
@ -638,8 +658,8 @@ func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
// FetchPendingChannels will return channels that have completed the process of
// generating and broadcasting funding transactions, but whose funding
// transactions have yet to be confirmed on the blockchain.
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(d,
func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(c,
pendingChannelFilter(true),
waitingCloseFilter(false),
)
@ -649,9 +669,9 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
// but are now waiting for a closing transaction to be confirmed.
//
// NOTE: This includes channels that are also pending to be opened.
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
return fetchChannels(
d, waitingCloseFilter(true),
c, waitingCloseFilter(true),
)
}
@ -692,10 +712,12 @@ func waitingCloseFilter(waitingClose bool) fetchChannelsFilter {
// which have a true value returned for *all* of the filters will be returned.
// If no filters are provided, every channel in the open channels bucket will
// be returned.
func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) {
func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) (
[]*OpenChannel, error) {
var channels []*OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error {
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
openChanBucket := tx.ReadBucket(openChannelBucket)
@ -737,7 +759,7 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
"bucket for chain=%x", chainHash[:])
}
nodeChans, err := d.fetchNodeChannels(chainBucket)
nodeChans, err := c.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read "+
"channel for chain_hash=%x, "+
@ -786,10 +808,12 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
// it becomes fully closed after a single confirmation. When a channel was
// forcibly closed, it will become fully closed after _all_ the pending funds
// (if any) have been swept.
func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) {
func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) (
[]*ChannelCloseSummary, error) {
var chanSummaries []*ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error {
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrNoClosedChannels
@ -827,9 +851,11 @@ var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary
// FetchClosedChannel queries for a channel close summary using the channel
// point of the channel in question.
func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) {
func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error {
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
@ -861,11 +887,11 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
// FetchClosedChannelForID queries for a channel close summary using the
// channel ID of the channel in question.
func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error {
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
@ -914,8 +940,12 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
// cooperatively closed and it's reached a single confirmation, or after all
// the pending funds in a channel that has been forcibly closed have been
// swept.
func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
var (
openChannels []*OpenChannel
pruneLinkNode *btcec.PublicKey
)
err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
var b bytes.Buffer
if err := writeOutpoint(&b, chanPoint); err != nil {
return err
@ -961,19 +991,35 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
// other open channels with this peer. If we don't we'll
// garbage collect it to ensure we don't establish persistent
// connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub)
}, func() {})
pruneLinkNode = chanSummary.RemotePub
openChannels, err = c.fetchOpenChannels(
tx, pruneLinkNode,
)
if err != nil {
return fmt.Errorf("unable to fetch open channels for "+
"peer %x: %v",
pruneLinkNode.SerializeCompressed(), err)
}
return nil
}, func() {
openChannels = nil
pruneLinkNode = nil
})
if err != nil {
return err
}
// Decide whether we want to remove the link node, based upon the number
// of still open channels.
return c.pruneLinkNode(openChannels, pruneLinkNode)
}
// pruneLinkNode determines whether we should garbage collect a link node from
// the database due to no longer having any open channels with it. If there are
// any left, then this acts as a no-op.
func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error {
openChannels, err := d.fetchOpenChannels(tx, remotePub)
if err != nil {
return fmt.Errorf("unable to fetch open channels for peer %x: "+
"%v", remotePub.SerializeCompressed(), err)
}
func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel,
remotePub *btcec.PublicKey) error {
if len(openChannels) > 0 {
return nil
@ -982,27 +1028,42 @@ func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error {
log.Infof("Pruning link node %x with zero open channels from database",
remotePub.SerializeCompressed())
return d.deleteLinkNode(tx, remotePub)
return c.linkNodeDB.DeleteLinkNode(remotePub)
}
// PruneLinkNodes attempts to prune all link nodes found within the databse with
// whom we no longer have any open channels with.
func (d *DB) PruneLinkNodes() error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
linkNodes, err := d.fetchAllLinkNodes(tx)
func (c *ChannelStateDB) PruneLinkNodes() error {
allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes()
if err != nil {
return err
}
for _, linkNode := range allLinkNodes {
var (
openChannels []*OpenChannel
linkNode = linkNode
)
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error
openChannels, err = c.fetchOpenChannels(
tx, linkNode.IdentityPub,
)
return err
}, func() {
openChannels = nil
})
if err != nil {
return err
}
for _, linkNode := range linkNodes {
err := d.pruneLinkNode(tx, linkNode.IdentityPub)
if err != nil {
return err
}
err = c.pruneLinkNode(openChannels, linkNode.IdentityPub)
if err != nil {
return err
}
}
return nil
}, func() {})
return nil
}
// ChannelShell is a shell of a channel that is meant to be used for channel
@ -1024,8 +1085,8 @@ type ChannelShell struct {
// addresses, and finally create an edge within the graph for the channel as
// well. This method is idempotent, so repeated calls with the same set of
// channel shells won't modify the database after the initial call.
func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error {
err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
for _, channelShell := range channelShells {
channel := channelShell.Chan
@ -1039,7 +1100,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
// and link node for this channel. If the channel
// already exists, then in order to ensure this method
// is idempotent, we'll continue to the next step.
channel.Db = d
channel.Db = c
err := syncNewChannel(
tx, channel, channelShell.NodeAddrs,
)
@ -1059,41 +1120,28 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
// AddrsForNode consults the graph and channel database for all addresses known
// to the passed node public key.
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
var (
linkNode *LinkNode
graphNode LightningNode
)
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr,
error) {
dbErr := kvdb.View(d, func(tx kvdb.RTx) error {
var err error
linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub)
if err != nil {
return nil, err
}
linkNode, err = fetchLinkNode(tx, nodePub)
if err != nil {
return err
}
// We'll also query the graph for this peer to see if they have
// any addresses that we don't currently have stored within the
// link node database.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
compressedPubKey := nodePub.SerializeCompressed()
graphNode, err = fetchLightningNode(nodes, compressedPubKey)
if err != nil && err != ErrGraphNodeNotFound {
// If the node isn't found, then that's OK, as we still
// have the link node data.
return err
}
return nil
}, func() {
linkNode = nil
})
if dbErr != nil {
return nil, dbErr
// We'll also query the graph for this peer to see if they have any
// addresses that we don't currently have stored within the link node
// database.
pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed())
if err != nil {
return nil, err
}
graphNode, err := d.graph.FetchLightningNode(pubKey)
if err != nil && err != ErrGraphNodeNotFound {
return nil, err
} else if err == ErrGraphNodeNotFound {
// If the node isn't found, then that's OK, as we still have the
// link node data. But any other error needs to be returned.
graphNode = &LightningNode{}
}
// Now that we have both sources of addrs for this node, we'll use a
@ -1118,16 +1166,18 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
// database. If the channel was already removed (has a closed channel entry),
// then we'll return a nil error. Otherwise, we'll insert a new close summary
// into the database.
func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error {
func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint,
bestHeight uint32) error {
// With the chanPoint constructed, we'll attempt to find the target
// channel in the database. If we can't find the channel, then we'll
// return the error back to the caller.
dbChan, err := d.FetchChannel(nil, *chanPoint)
dbChan, err := c.FetchChannel(nil, *chanPoint)
switch {
// If the channel wasn't found, then it's possible that it was already
// abandoned from the database.
case err == ErrChannelNotFound:
_, closedErr := d.FetchClosedChannel(chanPoint)
_, closedErr := c.FetchClosedChannel(chanPoint)
if closedErr != nil {
return closedErr
}
@ -1163,6 +1213,58 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error {
return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator)
}
// SaveChannelOpeningState saves the serialized channel state for the provided
// chanPoint to the channelOpeningStateBucket.
func (c *ChannelStateDB) SaveChannelOpeningState(outPoint,
serializedState []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket)
if err != nil {
return err
}
return bucket.Put(outPoint, serializedState)
}, func() {})
}
// GetChannelOpeningState fetches the serialized channel state for the provided
// outPoint from the database, or returns ErrChannelNotFound if the channel
// is not found.
func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) {
var serializedState []byte
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(channelOpeningStateBucket)
if bucket == nil {
// If the bucket does not exist, it means we never added
// a channel to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
serializedState = bucket.Get(outPoint)
if serializedState == nil {
return ErrChannelNotFound
}
return nil
}, func() {
serializedState = nil
})
return serializedState, err
}
// DeleteChannelOpeningState removes any state for outPoint from the database.
func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(channelOpeningStateBucket)
if bucket == nil {
return ErrChannelNotFound
}
return bucket.Delete(outPoint)
}, func() {})
}
// syncVersions function is used for safe db version synchronization. It
// applies migration functions to the current database and recovers the
// previous state of db if at least one error/panic appeared during migration.
@ -1236,11 +1338,17 @@ func (d *DB) syncVersions(versions []version) error {
}, func() {})
}
// ChannelGraph returns a new instance of the directed channel graph.
// ChannelGraph returns the current instance of the directed channel graph.
func (d *DB) ChannelGraph() *ChannelGraph {
return d.graph
}
// ChannelStateDB returns the sub database that is concerned with the channel
// state.
func (d *DB) ChannelStateDB() *ChannelStateDB {
return d.channelStateDB
}
func getLatestDBVersion(versions []version) uint32 {
return versions[len(versions)-1].number
}
@ -1290,9 +1398,11 @@ func fetchHistoricalChanBucket(tx kvdb.RTx,
// FetchHistoricalChannel fetches open channel data from the historical channel
// bucket.
func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) {
func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) (
*OpenChannel, error) {
var channel *OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error {
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchHistoricalChanBucket(tx, outPoint)
if err != nil {
return err
@ -1300,7 +1410,7 @@ func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, erro
channel, err = fetchOpenChannel(chanBucket, outPoint)
channel.Db = d
channel.Db = c
return err
}, func() {
channel = nil

View File

@ -87,15 +87,18 @@ func TestWipe(t *testing.T) {
}
defer cleanup()
cdb, err := CreateWithBackend(backend)
fullDB, err := CreateWithBackend(backend)
if err != nil {
t.Fatalf("unable to create channeldb: %v", err)
}
defer cdb.Close()
defer fullDB.Close()
if err := cdb.Wipe(); err != nil {
if err := fullDB.Wipe(); err != nil {
t.Fatalf("unable to wipe channeldb: %v", err)
}
cdb := fullDB.ChannelStateDB()
// Check correct errors are returned
openChannels, err := cdb.FetchAllOpenChannels()
require.NoError(t, err, "fetching open channels")
@ -113,12 +116,14 @@ func TestFetchClosedChannelForID(t *testing.T) {
const numChans = 101
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create the test channel state, that we will mutate the index of the
// funding point.
state := createTestChannelState(t, cdb)
@ -184,18 +189,18 @@ func TestFetchClosedChannelForID(t *testing.T) {
func TestAddrsForNode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
graph := cdb.ChannelGraph()
graph := fullDB.ChannelGraph()
// We'll make a test vertex to insert into the database, as the source
// node, but this node will only have half the number of addresses it
// usually does.
testNode, err := createTestVertex(cdb)
testNode, err := createTestVertex(fullDB)
if err != nil {
t.Fatalf("unable to create test node: %v", err)
}
@ -210,8 +215,9 @@ func TestAddrsForNode(t *testing.T) {
if err != nil {
t.Fatalf("unable to recv node pub: %v", err)
}
linkNode := cdb.NewLinkNode(
wire.MainNet, nodePub, anotherAddr,
linkNode := NewLinkNode(
fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub,
anotherAddr,
)
if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to sync link node: %v", err)
@ -219,7 +225,7 @@ func TestAddrsForNode(t *testing.T) {
// Now that we've created a link node, as well as a vertex for the
// node, we'll query for all its addresses.
nodeAddrs, err := cdb.AddrsForNode(nodePub)
nodeAddrs, err := fullDB.AddrsForNode(nodePub)
if err != nil {
t.Fatalf("unable to obtain node addrs: %v", err)
}
@ -245,12 +251,14 @@ func TestAddrsForNode(t *testing.T) {
func TestFetchChannel(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel.
channelState := createTestChannel(t, cdb, openChannelOption())
@ -349,12 +357,14 @@ func genRandomChannelShell() (*ChannelShell, error) {
func TestRestoreChannelShells(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First, we'll make our channel shell, it will only have the minimal
// amount of information required for us to initiate the data loss
// protection feature.
@ -423,7 +433,9 @@ func TestRestoreChannelShells(t *testing.T) {
// We should also be able to find the link node that was inserted by
// its public key.
linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub)
linkNode, err := fullDB.channelStateDB.linkNodeDB.FetchLinkNode(
channelShell.Chan.IdentityPub,
)
if err != nil {
t.Fatalf("unable to fetch link node: %v", err)
}
@ -443,12 +455,14 @@ func TestRestoreChannelShells(t *testing.T) {
func TestAbandonChannel(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// If we attempt to abandon the state of a channel that doesn't exist
// in the open or closed channel bucket, then we should receive an
// error.
@ -616,13 +630,15 @@ func TestFetchChannels(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test "+
"database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a pending channel that is not awaiting close.
createTestChannel(
t, cdb, channelIDOption(pendingChan),
@ -685,12 +701,14 @@ func TestFetchChannels(t *testing.T) {
// TestFetchHistoricalChannel tests lookup of historical channels.
func TestFetchHistoricalChannel(t *testing.T) {
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a an open channel in the database.
channel := createTestChannel(t, cdb, openChannelOption())

View File

@ -174,39 +174,132 @@ const (
// independently. Edge removal results in the deletion of all edge information
// for that edge.
type ChannelGraph struct {
db *DB
db kvdb.Backend
cacheMu sync.RWMutex
rejectCache *rejectCache
chanCache *channelCache
graphCache *GraphCache
chanScheduler batch.Scheduler
nodeScheduler batch.Scheduler
}
// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The
// NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The
// returned instance has its own unique reject cache and channel cache.
func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int,
batchCommitInterval time.Duration) *ChannelGraph {
func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
batchCommitInterval time.Duration,
preAllocCacheNumNodes int) (*ChannelGraph, error) {
if err := initChannelGraph(db); err != nil {
return nil, err
}
g := &ChannelGraph{
db: db,
rejectCache: newRejectCache(rejectCacheSize),
chanCache: newChannelCache(chanCacheSize),
graphCache: NewGraphCache(preAllocCacheNumNodes),
}
g.chanScheduler = batch.NewTimeScheduler(
db.Backend, &g.cacheMu, batchCommitInterval,
db, &g.cacheMu, batchCommitInterval,
)
g.nodeScheduler = batch.NewTimeScheduler(
db.Backend, nil, batchCommitInterval,
db, nil, batchCommitInterval,
)
return g
startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might take a " +
"while...")
err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
return g.graphCache.AddNode(tx, node)
})
if err != nil {
return nil, err
}
log.Debugf("Finished populating in-memory channel graph (took %v, %s)",
time.Since(startTime), g.graphCache.Stats())
return g, nil
}
// Database returns a pointer to the underlying database.
func (c *ChannelGraph) Database() *DB {
return c.db
var graphTopLevelBuckets = [][]byte{
nodeBucket,
edgeBucket,
edgeIndexBucket,
graphMetaBucket,
}
// Wipe completely deletes all saved state within all used buckets within the
// database. The deletion is done in a single transaction, therefore this
// operation is fully atomic.
func (c *ChannelGraph) Wipe() error {
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
return initChannelGraph(c.db)
}
// createChannelDB creates and initializes a fresh version of channeldb. In
// the case that the target path has not yet been created or doesn't yet exist,
// then the path is created. Additionally, all required top-level buckets used
// within the database are created.
func initChannelGraph(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err
}
}
nodes := tx.ReadWriteBucket(nodeBucket)
_, err := nodes.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucketIfNotExists(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges := tx.ReadWriteBucket(edgeBucket)
_, err = edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(channelPointBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
graphMeta := tx.ReadWriteBucket(graphMetaBucket)
_, err = graphMeta.CreateBucketIfNotExists(pruneLogBucket)
return err
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channel graph: %v", err)
}
return nil
}
// ForEachChannel iterates through all the channel edges stored within the
@ -218,7 +311,9 @@ func (c *ChannelGraph) Database() *DB {
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
*ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
// TODO(roasbeef): ptr map to reduce # of allocs? no duplicates
return kvdb.View(c.db, func(tx kvdb.RTx) error {
@ -270,23 +365,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// ForEachNodeChannel iterates through all channels of a given node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
db := c.db
return c.graphCache.ForEachChannel(node, cb)
}
return nodeTraversal(tx, nodePub, db, cb)
// FetchNodeFeatures returns the features of a given node.
func (c *ChannelGraph) FetchNodeFeatures(
node route.Vertex) (*lnwire.FeatureVector, error) {
return c.graphCache.GetFeatures(node), nil
}
// DisabledChannelIDs returns the channel ids of disabled channels.
@ -374,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
return kvdb.View(c.db, traversal, func() {})
}
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
GraphCacheNode) error) error {
traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
nodeReader := bytes.NewReader(nodeBytes)
err := deserializeLightningNodeCacheable(
nodeReader, cacheableNode,
)
if err != nil {
return err
}
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(tx, cacheableNode)
})
}
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
@ -465,6 +600,13 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,
r := &batch.Request{
Update: func(tx kvdb.RwTx) error {
cNode := newGraphCacheNode(
node.PubKeyBytes, node.Features,
)
if err := c.graphCache.AddNode(tx, cNode); err != nil {
return err
}
return addLightningNode(tx, node)
},
}
@ -543,6 +685,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
return ErrGraphNodeNotFound
}
c.graphCache.RemoveNode(nodePub)
return c.deleteLightningNode(nodes, nodePub[:])
}, func() {})
}
@ -669,6 +813,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error
return ErrEdgeAlreadyExist
}
c.graphCache.AddChannel(edge, nil, nil)
// Before we insert the channel into the database, we'll ensure that
// both nodes already exist in the channel graph. If either node
// doesn't, then we'll insert a "shell" node that just includes its
@ -868,6 +1014,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
return ErrEdgeNotFound
}
c.graphCache.UpdateChannel(edge)
return putChanEdgeInfo(edgeIndex, edge, chanKey)
}, func() {})
}
@ -953,7 +1101,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// will be returned if that outpoint isn't known to be
// a channel. If no error is returned, then a channel
// was successfully pruned.
err = delChannelEdge(
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
chanID, false, false,
)
@ -1004,6 +1152,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
c.chanCache.remove(channel.ChannelID)
}
log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats())
return chansClosed, nil
}
@ -1104,6 +1254,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket,
continue
}
c.graphCache.RemoveNode(nodePubKey)
// If we reach this point, then there are no longer any edges
// that connect this node, so we can delete it.
if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil {
@ -1202,7 +1354,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf
}
for _, k := range keys {
err = delChannelEdge(
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
k, false, false,
)
@ -1310,7 +1462,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state.
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error {
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool,
chanIDs ...uint64) error {
// TODO(roasbeef): possibly delete from node bucket if node has no more
// channels
// TODO(roasbeef): don't delete both edges?
@ -1343,7 +1497,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u
var rawChanID [8]byte
for _, chanID := range chanIDs {
byteOrder.PutUint64(rawChanID[:], chanID)
err := delChannelEdge(
err := c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
rawChanID[:], true, strictZombiePruning,
)
@ -1472,7 +1626,9 @@ type ChannelEdge struct {
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) {
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime,
endTime time.Time) ([]ChannelEdge, error) {
// To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent
// re-adding it.
@ -1605,7 +1761,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) {
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime,
endTime time.Time) ([]LightningNode, error) {
var nodesInHorizon []LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@ -1933,7 +2091,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
return nil
}
func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error {
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
@ -1941,6 +2099,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
return err
}
c.graphCache.RemoveChannel(
edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes,
edgeInfo.ChannelID,
)
// We'll also remove the entry in the edge update index bucket before
// we delete the edges themselves so we can access their last update
// times.
@ -2075,7 +2238,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy,
},
Update: func(tx kvdb.RwTx) error {
var err error
isUpdate1, err = updateEdgePolicy(tx, edge)
isUpdate1, err = updateEdgePolicy(
tx, edge, c.graphCache,
)
// Silence ErrEdgeNotFound so that the batch can
// succeed, but propagate the error via local state.
@ -2138,7 +2303,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) {
// buckets using an existing database transaction. The returned boolean will be
// true if the updated policy belongs to node1, and false if the policy belonged
// to node2.
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy,
graphCache *GraphCache) (bool, error) {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return false, ErrEdgeNotFound
@ -2186,6 +2353,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
return false, err
}
var (
fromNodePubKey route.Vertex
toNodePubKey route.Vertex
)
copy(fromNodePubKey[:], fromNode)
copy(toNodePubKey[:], toNode)
graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1)
return isUpdate1, nil
}
@ -2232,7 +2407,7 @@ type LightningNode struct {
// compatible manner.
ExtraOpaqueData []byte
db *DB
db kvdb.Backend
// TODO(roasbeef): discovery will need storage to keep it's last IP
// address and re-announce if interface changes?
@ -2356,17 +2531,11 @@ func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error)
// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
*LightningNode, error) {
var node *LightningNode
fetchNode := func(tx kvdb.RTx) error {
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
@ -2393,14 +2562,9 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
node = &n
return nil
}
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNode, func() {})
} else {
err = fetchNode(tx)
}
}, func() {
node = nil
})
if err != nil {
return nil, err
}
@ -2408,6 +2572,52 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
return node, nil
}
// graphCacheNode is a struct that wraps a LightningNode in a way that it can be
// cached in the graph cache.
type graphCacheNode struct {
pubKeyBytes route.Vertex
features *lnwire.FeatureVector
nodeScratch [8]byte
}
// newGraphCacheNode returns a new cache optimized node.
func newGraphCacheNode(pubKey route.Vertex,
features *lnwire.FeatureVector) *graphCacheNode {
return &graphCacheNode{
pubKeyBytes: pubKey,
features: features,
}
}
// PubKey returns the node's public identity key.
func (n *graphCacheNode) PubKey() route.Vertex {
return n.pubKeyBytes
}
// Features returns the node's features.
func (n *graphCacheNode) Features() *lnwire.FeatureVector {
return n.features
}
// ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb)
}
var _ GraphCacheNode = (*graphCacheNode)(nil)
// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was lasted updated is returned along
@ -2460,7 +2670,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
// nodeTraversal is used to traverse all channels of a node given by its
// public key and passes channel information into the specified callback.
func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
traversal := func(tx kvdb.RTx) error {
@ -2548,7 +2758,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
@ -2559,7 +2769,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (l *LightningNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
nodePub := l.PubKeyBytes[:]
db := l.db
@ -2627,7 +2838,7 @@ type ChannelEdgeInfo struct {
// compatible manner.
ExtraOpaqueData []byte
db *DB
db kvdb.Backend
}
// AddNodeKeys is a setter-like method that can be used to replace the set of
@ -2988,7 +3199,7 @@ type ChannelEdgePolicy struct {
// compatible manner.
ExtraOpaqueData []byte
db *DB
db kvdb.Backend
}
// Signature is a channel announcement signature, which is needed for proper
@ -3406,7 +3617,7 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := kvdb.Batch(c.db.Backend, func(tx kvdb.RwTx) error {
err := kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
@ -3417,6 +3628,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
"bucket: %w", err)
}
c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID)
return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2)
})
if err != nil {
@ -3471,6 +3684,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
c.rejectCache.remove(chanID)
c.chanCache.remove(chanID)
// We need to add the channel back into our graph cache, otherwise we
// won't use it for path finding.
edgeInfos, err := c.FetchChanInfos([]uint64{chanID})
if err != nil {
return err
}
for _, edgeInfo := range edgeInfos {
c.graphCache.AddChannel(
edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2,
)
}
return nil
}
@ -3696,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
return deserializeLightningNode(nodeReader)
}
func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.features = lnwire.EmptyFeatureVector()
// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(node.nodeScratch[:]); err != nil {
return err
}
if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
return err
}
// Read the node announcement flag.
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
return err
}
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if _, err := wire.ReadVarString(r, 0); err != nil {
return err
}
return node.features.Decode(r)
}
func deserializeLightningNode(r io.Reader) (LightningNode, error) {
var (
node LightningNode
@ -4102,7 +4374,7 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
nodes kvdb.RBucket, chanID []byte,
db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) {
db kvdb.Backend) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) {
edgeInfo := edgeIndex.Get(chanID)
if edgeInfo == nil {

460
channeldb/graph_cache.go Normal file
View File

@ -0,0 +1,460 @@
package channeldb
import (
"fmt"
"sync"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// GraphCacheNode is an interface for all the information the cache needs to know
// about a lightning node.
type GraphCacheNode interface {
// PubKey is the node's public identity key.
PubKey() route.Vertex
// Features returns the node's p2p features.
Features() *lnwire.FeatureVector
// ForEachChannel iterates through all channels of a given node,
// executing the passed callback with an edge info structure and the
// policies of each end of the channel. The first edge policy is the
// outgoing edge *to* the connecting node, while the second is the
// incoming edge *from* the connecting node. If the callback returns an
// error, then the iteration is halted with the error propagated back up
// to the caller.
ForEachChannel(kvdb.RTx,
func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error
}
// CachedEdgePolicy is a struct that only caches the information of a
// ChannelEdgePolicy that we actually use for pathfinding and therefore need to
// store in the cache.
type CachedEdgePolicy struct {
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// MessageFlags is a bitfield which indicates the presence of optional
// fields (like max_htlc) in the policy.
MessageFlags lnwire.ChanUpdateMsgFlags
// ChannelFlags is a bitfield which signals the capabilities of the
// channel as well as the directed edge this update applies to.
ChannelFlags lnwire.ChanUpdateChanFlags
// TimeLockDelta is the number of blocks this node will subtract from
// the expiry of an incoming HTLC. This value expresses the time buffer
// the node would like to HTLC exchanges.
TimeLockDelta uint16
// MinHTLC is the smallest value HTLC this node will forward, expressed
// in millisatoshi.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the largest value HTLC this node will forward, expressed
// in millisatoshi.
MaxHTLC lnwire.MilliSatoshi
// FeeBaseMSat is the base HTLC fee that will be charged for forwarding
// ANY HTLC, expressed in mSAT's.
FeeBaseMSat lnwire.MilliSatoshi
// FeeProportionalMillionths is the rate that the node will charge for
// HTLCs for each millionth of a satoshi forwarded.
FeeProportionalMillionths lnwire.MilliSatoshi
// ToNodePubKey is a function that returns the to node of a policy.
// Since we only ever store the inbound policy, this is always the node
// that we query the channels for in ForEachChannel(). Therefore, we can
// save a lot of space by not storing this information in the memory and
// instead just set this function when we copy the policy from cache in
// ForEachChannel().
ToNodePubKey func() route.Vertex
// ToNodeFeatures are the to node's features. They are never set while
// the edge is in the cache, only on the copy that is returned in
// ForEachChannel().
ToNodeFeatures *lnwire.FeatureVector
}
// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over
// the passed active payment channel. This value is currently computed as
// specified in BOLT07, but will likely change in the near future.
func (c *CachedEdgePolicy) ComputeFee(
amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts
}
// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming
// amount.
func (c *CachedEdgePolicy) ComputeFeeFromIncoming(
incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return incomingAmt - divideCeil(
feeRateParts*(incomingAmt-c.FeeBaseMSat),
feeRateParts+c.FeeProportionalMillionths,
)
}
// NewCachedPolicy turns a full policy into a minimal one that can be cached.
func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy {
return &CachedEdgePolicy{
ChannelID: policy.ChannelID,
MessageFlags: policy.MessageFlags,
ChannelFlags: policy.ChannelFlags,
TimeLockDelta: policy.TimeLockDelta,
MinHTLC: policy.MinHTLC,
MaxHTLC: policy.MaxHTLC,
FeeBaseMSat: policy.FeeBaseMSat,
FeeProportionalMillionths: policy.FeeProportionalMillionths,
}
}
// DirectedChannel is a type that stores the channel information as seen from
// one side of the channel.
type DirectedChannel struct {
// ChannelID is the unique identifier of this channel.
ChannelID uint64
// IsNode1 indicates if this is the node with the smaller public key.
IsNode1 bool
// OtherNode is the public key of the node on the other end of this
// channel.
OtherNode route.Vertex
// Capacity is the announced capacity of this channel in satoshis.
Capacity btcutil.Amount
// OutPolicySet is a boolean that indicates whether the node has an
// outgoing policy set. For pathfinding only the existence of the policy
// is important to know, not the actual content.
OutPolicySet bool
// InPolicy is the incoming policy *from* the other node to this node.
// In path finding, we're walking backward from the destination to the
// source, so we're always interested in the edge that arrives to us
// from the other node.
InPolicy *CachedEdgePolicy
}
// DeepCopy creates a deep copy of the channel, including the incoming policy.
func (c *DirectedChannel) DeepCopy() *DirectedChannel {
channelCopy := *c
if channelCopy.InPolicy != nil {
inPolicyCopy := *channelCopy.InPolicy
channelCopy.InPolicy = &inPolicyCopy
// The fields for the ToNode can be overwritten by the path
// finding algorithm, which is why we need a deep copy in the
// first place. So we always start out with nil values, just to
// be sure they don't contain any old data.
channelCopy.InPolicy.ToNodePubKey = nil
channelCopy.InPolicy.ToNodeFeatures = nil
}
return &channelCopy
}
// GraphCache is a type that holds a minimal set of information of the public
// channel graph that can be used for pathfinding.
type GraphCache struct {
nodeChannels map[route.Vertex]map[uint64]*DirectedChannel
nodeFeatures map[route.Vertex]*lnwire.FeatureVector
mtx sync.RWMutex
}
// NewGraphCache creates a new graphCache.
func NewGraphCache(preAllocNumNodes int) *GraphCache {
return &GraphCache{
nodeChannels: make(
map[route.Vertex]map[uint64]*DirectedChannel,
// A channel connects two nodes, so we can look it up
// from both sides, meaning we get double the number of
// entries.
preAllocNumNodes*2,
),
nodeFeatures: make(
map[route.Vertex]*lnwire.FeatureVector,
preAllocNumNodes,
),
}
}
// Stats returns statistics about the current cache size.
func (c *GraphCache) Stats() string {
c.mtx.RLock()
defer c.mtx.RUnlock()
numChannels := 0
for node := range c.nodeChannels {
numChannels += len(c.nodeChannels[node])
}
return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+
"num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels),
numChannels)
}
// AddNode adds a graph node, including all the (directed) channels of that
// node.
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
nodePubKey := node.PubKey()
// Only hold the lock for a short time. The `ForEachChannel()` below is
// possibly slow as it has to go to the backend, so we can unlock
// between the calls. And the AddChannel() method will acquire its own
// lock anyway.
c.mtx.Lock()
c.nodeFeatures[nodePubKey] = node.Features()
c.mtx.Unlock()
return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
outPolicy *ChannelEdgePolicy,
inPolicy *ChannelEdgePolicy) error {
c.AddChannel(info, outPolicy, inPolicy)
return nil
},
)
}
// AddChannel adds a non-directed channel, meaning that the order of policy 1
// and policy 2 does not matter, the directionality is extracted from the info
// and policy flags automatically. The policy will be set as the outgoing policy
// on one node and the incoming policy on the peer's side.
func (c *GraphCache) AddChannel(info *ChannelEdgeInfo,
policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) {
if info == nil {
return
}
if policy1 != nil && policy1.IsDisabled() &&
policy2 != nil && policy2.IsDisabled() {
return
}
// Create the edge entry for both nodes.
c.mtx.Lock()
c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: true,
OtherNode: info.NodeKey2Bytes,
Capacity: info.Capacity,
})
c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: false,
OtherNode: info.NodeKey1Bytes,
Capacity: info.Capacity,
})
c.mtx.Unlock()
// The policy's node is always the to_node. So if policy 1 has to_node
// of node 2 then we have the policy 1 as seen from node 1.
if policy1 != nil {
fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes
if policy1.Node.PubKeyBytes != info.NodeKey2Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy1, fromNode, toNode, isEdge1)
}
if policy2 != nil {
fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes
if policy2.Node.PubKeyBytes != info.NodeKey1Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy2, fromNode, toNode, isEdge1)
}
}
// updateOrAddEdge makes sure the edge information for a node is either updated
// if it already exists or is added to that node's list of channels.
func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) {
if len(c.nodeChannels[node]) == 0 {
c.nodeChannels[node] = make(map[uint64]*DirectedChannel)
}
c.nodeChannels[node][edge.ChannelID] = edge
}
// UpdatePolicy updates a single policy on both the from and to node. The order
// of the from and to node is not strictly important. But we assume that a
// channel edge was added beforehand so that the directed channel struct already
// exists in the cache.
func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode,
toNode route.Vertex, edge1 bool) {
c.mtx.Lock()
defer c.mtx.Unlock()
updatePolicy := func(nodeKey route.Vertex) {
if len(c.nodeChannels[nodeKey]) == 0 {
return
}
channel, ok := c.nodeChannels[nodeKey][policy.ChannelID]
if !ok {
return
}
// Edge 1 is defined as the policy for the direction of node1 to
// node2.
switch {
// This is node 1, and it is edge 1, so this is the outgoing
// policy for node 1.
case channel.IsNode1 && edge1:
channel.OutPolicySet = true
// This is node 2, and it is edge 2, so this is the outgoing
// policy for node 2.
case !channel.IsNode1 && !edge1:
channel.OutPolicySet = true
// The other two cases left mean it's the inbound policy for the
// node.
default:
channel.InPolicy = NewCachedPolicy(policy)
}
}
updatePolicy(fromNode)
updatePolicy(toNode)
}
// RemoveNode completely removes a node and all its channels (including the
// peer's side).
func (c *GraphCache) RemoveNode(node route.Vertex) {
c.mtx.Lock()
defer c.mtx.Unlock()
delete(c.nodeFeatures, node)
// First remove all channels from the other nodes' lists.
for _, channel := range c.nodeChannels[node] {
c.removeChannelIfFound(channel.OtherNode, channel.ChannelID)
}
// Then remove our whole node completely.
delete(c.nodeChannels, node)
}
// RemoveChannel removes a single channel between two nodes.
func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) {
c.mtx.Lock()
defer c.mtx.Unlock()
// Remove that one channel from both sides.
c.removeChannelIfFound(node1, chanID)
c.removeChannelIfFound(node2, chanID)
}
// removeChannelIfFound removes a single channel from one side.
func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) {
if len(c.nodeChannels[node]) == 0 {
return
}
delete(c.nodeChannels[node], chanID)
}
// UpdateChannel updates the channel edge information for a specific edge. We
// expect the edge to already exist and be known. If it does not yet exist, this
// call is a no-op.
func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) {
c.mtx.Lock()
defer c.mtx.Unlock()
if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 ||
len(c.nodeChannels[info.NodeKey2Bytes]) == 0 {
return
}
channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID]
if ok {
// We only expect to be called when the channel is already
// known.
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey2Bytes
}
channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID]
if ok {
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey1Bytes
}
}
// ForEachChannel invokes the given callback for each channel of the given node.
func (c *GraphCache) ForEachChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
c.mtx.RLock()
defer c.mtx.RUnlock()
channels, ok := c.nodeChannels[node]
if !ok {
return nil
}
features, ok := c.nodeFeatures[node]
if !ok {
log.Warnf("Node %v has no features defined, falling back to "+
"default feature vector for path finding", node)
features = lnwire.EmptyFeatureVector()
}
toNodeCallback := func() route.Vertex {
return node
}
for _, channel := range channels {
// We need to copy the channel and policy to avoid it being
// updated in the cache if the path finding algorithm sets
// fields on it (currently only the ToNodeFeatures of the
// policy).
channelCopy := channel.DeepCopy()
if channelCopy.InPolicy != nil {
channelCopy.InPolicy.ToNodePubKey = toNodeCallback
channelCopy.InPolicy.ToNodeFeatures = features
}
if err := cb(channelCopy); err != nil {
return err
}
}
return nil
}
// GetFeatures returns the features of the node with the given ID.
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
c.mtx.RLock()
defer c.mtx.RUnlock()
features, ok := c.nodeFeatures[node]
if !ok || features == nil {
// The router expects the features to never be nil, so we return
// an empty feature set instead.
return lnwire.EmptyFeatureVector()
}
return features
}

View File

@ -0,0 +1,147 @@
package channeldb
import (
"encoding/hex"
"testing"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
)
var (
pubKey1Bytes, _ = hex.DecodeString(
"0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" +
"22c91d",
)
pubKey2Bytes, _ = hex.DecodeString(
"038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" +
"f4484f",
)
pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes)
pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes)
)
type node struct {
pubKey route.Vertex
features *lnwire.FeatureVector
edgeInfos []*ChannelEdgeInfo
outPolicies []*ChannelEdgePolicy
inPolicies []*ChannelEdgePolicy
}
func (n *node) PubKey() route.Vertex {
return n.pubKey
}
func (n *node) Features() *lnwire.FeatureVector {
return n.features
}
func (n *node) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
for idx := range n.edgeInfos {
err := cb(
tx, n.edgeInfos[idx], n.outPolicies[idx],
n.inPolicies[idx],
)
if err != nil {
return err
}
}
return nil
}
// TestGraphCacheAddNode tests that a channel going from node A to node B can be
// cached correctly, independent of the direction we add the channel as.
func TestGraphCacheAddNode(t *testing.T) {
runTest := func(nodeA, nodeB route.Vertex) {
t.Helper()
channelFlagA, channelFlagB := 0, 1
if nodeA == pubKey2 {
channelFlagA, channelFlagB = 1, 0
}
outPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA),
Node: &LightningNode{
PubKeyBytes: nodeB,
Features: lnwire.EmptyFeatureVector(),
},
}
inPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB),
Node: &LightningNode{
PubKeyBytes: nodeA,
Features: lnwire.EmptyFeatureVector(),
},
}
node := &node{
pubKey: nodeA,
features: lnwire.EmptyFeatureVector(),
edgeInfos: []*ChannelEdgeInfo{{
ChannelID: 1000,
// Those are direction independent!
NodeKey1Bytes: pubKey1,
NodeKey2Bytes: pubKey2,
Capacity: 500,
}},
outPolicies: []*ChannelEdgePolicy{outPolicy1},
inPolicies: []*ChannelEdgePolicy{inPolicy1},
}
cache := NewGraphCache(10)
require.NoError(t, cache.AddNode(nil, node))
var fromChannels, toChannels []*DirectedChannel
_ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error {
fromChannels = append(fromChannels, c)
return nil
})
_ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error {
toChannels = append(toChannels, c)
return nil
})
require.Len(t, fromChannels, 1)
require.Len(t, toChannels, 1)
require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet)
assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy)
require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
}
runTest(pubKey1, pubKey2)
runTest(pubKey2, pubKey1)
}
func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy,
cached *CachedEdgePolicy) {
require.Equal(t, original.ChannelID, cached.ChannelID)
require.Equal(t, original.MessageFlags, cached.MessageFlags)
require.Equal(t, original.ChannelFlags, cached.ChannelFlags)
require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta)
require.Equal(t, original.MinHTLC, cached.MinHTLC)
require.Equal(t, original.MaxHTLC, cached.MaxHTLC)
require.Equal(t, original.FeeBaseMSat, cached.FeeBaseMSat)
require.Equal(
t, original.FeeProportionalMillionths,
cached.FeeProportionalMillionths,
)
require.Equal(
t,
route.Vertex(original.Node.PubKeyBytes),
cached.ToNodePubKey(),
)
require.Equal(t, original.Node.Features, cached.ToNodeFeatures)
}

File diff suppressed because it is too large Load Diff

View File

@ -56,12 +56,14 @@ type LinkNode struct {
// authenticated connection for the stored identity public key.
Addresses []net.Addr
db *DB
// db is the database instance this node was fetched from. This is used
// to sync back the node's state if it is updated.
db *LinkNodeDB
}
// NewLinkNode creates a new LinkNode from the provided parameters, which is
// backed by an instance of channeldb.
func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
// backed by an instance of a link node DB.
func NewLinkNode(db *LinkNodeDB, bitNet wire.BitcoinNet, pub *btcec.PublicKey,
addrs ...net.Addr) *LinkNode {
return &LinkNode{
@ -69,7 +71,7 @@ func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
IdentityPub: pub,
LastSeen: time.Now(),
Addresses: addrs,
db: d,
db: db,
}
}
@ -98,10 +100,9 @@ func (l *LinkNode) AddAddress(addr net.Addr) error {
// Sync performs a full database sync which writes the current up-to-date data
// within the struct to the database.
func (l *LinkNode) Sync() error {
// Finally update the database by storing the link node and updating
// any relevant indexes.
return kvdb.Update(l.db, func(tx kvdb.RwTx) error {
return kvdb.Update(l.db.backend, func(tx kvdb.RwTx) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
@ -127,15 +128,20 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error {
return nodeMetaBucket.Put(nodePub, b.Bytes())
}
// LinkNodeDB is a database that keeps track of all link nodes.
type LinkNodeDB struct {
backend kvdb.Backend
}
// DeleteLinkNode removes the link node with the given identity from the
// database.
func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
return d.deleteLinkNode(tx, identity)
func (l *LinkNodeDB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(l.backend, func(tx kvdb.RwTx) error {
return deleteLinkNode(tx, identity)
}, func() {})
}
func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
func deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
@ -148,9 +154,9 @@ func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target
// identity public key. If a particular LinkNode for the passed identity public
// key cannot be found, then ErrNodeNotFound if returned.
func (d *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
func (l *LinkNodeDB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
var linkNode *LinkNode
err := kvdb.View(d, func(tx kvdb.RTx) error {
err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
node, err := fetchLinkNode(tx, identity)
if err != nil {
return err
@ -191,10 +197,10 @@ func fetchLinkNode(tx kvdb.RTx, targetPub *btcec.PublicKey) (*LinkNode, error) {
// FetchAllLinkNodes starts a new database transaction to fetch all nodes with
// whom we have active channels with.
func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
func (l *LinkNodeDB) FetchAllLinkNodes() ([]*LinkNode, error) {
var linkNodes []*LinkNode
err := kvdb.View(d, func(tx kvdb.RTx) error {
nodes, err := d.fetchAllLinkNodes(tx)
err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
nodes, err := fetchAllLinkNodes(tx)
if err != nil {
return err
}
@ -213,7 +219,7 @@ func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
// with whom we have active channels with.
func (d *DB) fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) {
func fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) {
nodeMetaBucket := tx.ReadBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound

View File

@ -13,12 +13,14 @@ import (
func TestLinkNodeEncodeDecode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First we'll create some initial data to use for populating our test
// LinkNode instances.
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
@ -34,8 +36,8 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
// Create two fresh link node instances with the above dummy data, then
// fully sync both instances to disk.
node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1)
node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2)
node1 := NewLinkNode(cdb.linkNodeDB, wire.MainNet, pub1, addr1)
node2 := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pub2, addr2)
if err := node1.Sync(); err != nil {
t.Fatalf("unable to sync node: %v", err)
}
@ -46,7 +48,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
// Fetch all current link nodes from the database, they should exactly
// match the two created above.
originalNodes := []*LinkNode{node2, node1}
linkNodes, err := cdb.FetchAllLinkNodes()
linkNodes, err := cdb.linkNodeDB.FetchAllLinkNodes()
if err != nil {
t.Fatalf("unable to fetch nodes: %v", err)
}
@ -82,7 +84,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
}
// Fetch the same node from the database according to its public key.
node1DB, err := cdb.FetchLinkNode(pub1)
node1DB, err := cdb.linkNodeDB.FetchLinkNode(pub1)
if err != nil {
t.Fatalf("unable to find node: %v", err)
}
@ -110,31 +112,33 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
func TestDeleteLinkNode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := MakeTestDB()
fullDB, cleanUp, err := MakeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
cdb := fullDB.ChannelStateDB()
_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1337,
}
linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr)
linkNode := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pubKey, addr)
if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to write link node to db: %v", err)
}
if _, err := cdb.FetchLinkNode(pubKey); err != nil {
if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err != nil {
t.Fatalf("unable to find link node: %v", err)
}
if err := cdb.DeleteLinkNode(pubKey); err != nil {
if err := cdb.linkNodeDB.DeleteLinkNode(pubKey); err != nil {
t.Fatalf("unable to delete link node from db: %v", err)
}
if _, err := cdb.FetchLinkNode(pubKey); err == nil {
if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err == nil {
t.Fatal("should not have found link node in db, but did")
}
}

View File

@ -17,6 +17,12 @@ const (
// in order to reply to gossip queries. This produces a cache size of
// around 40MB.
DefaultChannelCacheSize = 20000
// DefaultPreAllocCacheNumNodes is the default number of channels we
// assume for mainnet for pre-allocating the graph cache. As of
// September 2021, there currently are 14k nodes in a strictly pruned
// graph, so we choose a number that is slightly higher.
DefaultPreAllocCacheNumNodes = 15000
)
// Options holds parameters for tuning and customizing a channeldb.DB.
@ -35,6 +41,10 @@ type Options struct {
// wait before attempting to commit a pending set of updates.
BatchCommitInterval time.Duration
// PreAllocCacheNumNodes is the number of nodes we expect to be in the
// graph cache, so we can pre-allocate the map accordingly.
PreAllocCacheNumNodes int
// clock is the time source used by the database.
clock clock.Clock
@ -52,9 +62,10 @@ func DefaultOptions() Options {
AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge,
DBTimeout: kvdb.DefaultDBTimeout,
},
RejectCacheSize: DefaultRejectCacheSize,
ChannelCacheSize: DefaultChannelCacheSize,
clock: clock.NewDefaultClock(),
RejectCacheSize: DefaultRejectCacheSize,
ChannelCacheSize: DefaultChannelCacheSize,
PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes,
clock: clock.NewDefaultClock(),
}
}
@ -75,6 +86,13 @@ func OptionSetChannelCacheSize(n int) OptionModifier {
}
}
// OptionSetPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n.
func OptionSetPreAllocCacheNumNodes(n int) OptionModifier {
return func(o *Options) {
o.PreAllocCacheNumNodes = n
}
}
// OptionSetSyncFreelist allows the database to sync its freelist.
func OptionSetSyncFreelist(b bool) OptionModifier {
return func(o *Options) {

View File

@ -36,12 +36,12 @@ type WaitingProofStore struct {
// cache is used in order to reduce the number of redundant get
// calls, when object isn't stored in it.
cache map[WaitingProofKey]struct{}
db *DB
db kvdb.Backend
mu sync.RWMutex
}
// NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) {
s := &WaitingProofStore{
db: db,
}

View File

@ -17,7 +17,7 @@ type ChannelNotifier struct {
ntfnServer *subscribe.Server
chanDB *channeldb.DB
chanDB *channeldb.ChannelStateDB
}
// PendingOpenChannelEvent represents a new event where a new channel has
@ -76,7 +76,7 @@ type FullyResolvedChannelEvent struct {
// New creates a new channel notifier. The ChannelNotifier gets channel
// events from peers and from the chain arbitrator, and dispatches them to
// its clients.
func New(chanDB *channeldb.DB) *ChannelNotifier {
func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier {
return &ChannelNotifier{
ntfnServer: subscribe.NewServer(),
chanDB: chanDB,

View File

@ -34,7 +34,7 @@ const (
// need the secret key chain in order obtain the prior shachain root so we can
// verify the DLP protocol as initiated by the remote node.
type chanDBRestorer struct {
db *channeldb.DB
db *channeldb.ChannelStateDB
secretKeys keychain.SecretKeyRing

View File

@ -136,7 +136,7 @@ type BreachConfig struct {
// DB provides access to the user's channels, allowing the breach
// arbiter to determine the current state of a user's channels, and how
// it should respond to channel closure.
DB *channeldb.DB
DB *channeldb.ChannelStateDB
// Estimator is used by the breach arbiter to determine an appropriate
// fee level when generating, signing, and broadcasting sweep
@ -1432,11 +1432,11 @@ func (b *BreachArbiter) sweepSpendableOutputsTxn(txWeight int64,
// store is to ensure that we can recover from a restart in the middle of a
// breached contract retribution.
type RetributionStore struct {
db *channeldb.DB
db kvdb.Backend
}
// NewRetributionStore creates a new instance of a RetributionStore.
func NewRetributionStore(db *channeldb.DB) *RetributionStore {
func NewRetributionStore(db kvdb.Backend) *RetributionStore {
return &RetributionStore{
db: db,
}

View File

@ -987,7 +987,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
contractBreaches := make(chan *ContractBreachEvent)
brar, cleanUpArb, err := createTestArbiter(
t, contractBreaches, alice.State().Db,
t, contractBreaches, alice.State().Db.GetParentDB(),
)
if err != nil {
t.Fatalf("unable to initialize test breach arbiter: %v", err)
@ -1164,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) {
assertNotPendingClosed(t, alice)
brar, cleanUpArb, err := createTestArbiter(
t, contractBreaches, alice.State().Db,
t, contractBreaches, alice.State().Db.GetParentDB(),
)
if err != nil {
t.Fatalf("unable to initialize test breach arbiter: %v", err)
@ -2075,7 +2075,7 @@ func assertNoArbiterBreach(t *testing.T, brar *BreachArbiter,
// assertBrarCleanup blocks until the given channel point has been removed the
// retribution store and the channel is fully closed in the database.
func assertBrarCleanup(t *testing.T, brar *BreachArbiter,
chanPoint *wire.OutPoint, db *channeldb.DB) {
chanPoint *wire.OutPoint, db *channeldb.ChannelStateDB) {
t.Helper()
@ -2174,7 +2174,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
notifier := mock.MakeMockSpendNotifier()
ba := NewBreachArbiter(&BreachConfig{
CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {},
DB: db,
DB: db.ChannelStateDB(),
Estimator: chainfee.NewStaticEstimator(12500, 0),
GenSweepScript: func() ([]byte, error) { return nil, nil },
ContractBreaches: contractBreaches,
@ -2375,7 +2375,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit,
Db: dbAlice,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: channels.TestFundingTx,
}
@ -2393,7 +2393,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobCommit,
RemoteCommitment: bobCommit,
Db: dbBob,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
}

View File

@ -258,7 +258,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
// same instance that is used by the link.
chanPoint := a.channel.FundingOutpoint
channel, err := a.c.chanSource.FetchChannel(nil, chanPoint)
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
if err != nil {
return nil, err
}
@ -301,7 +303,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error)
// Now that we know the link can't mutate the channel
// state, we'll read the channel from disk the target
// channel according to its channel point.
channel, err := a.c.chanSource.FetchChannel(nil, chanPoint)
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
if err != nil {
return nil, err
}
@ -422,7 +426,7 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
// First, we'll we'll mark the channel as fully closed from the PoV of
// the channel source.
err := c.chanSource.MarkChanFullyClosed(&chanPoint)
err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint)
if err != nil {
log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+
"fully closed: %v", chanPoint, err)
@ -480,7 +484,7 @@ func (c *ChainArbitrator) Start() error {
// First, we'll fetch all the channels that are still open, in order to
// collect them within our set of active contracts.
openChannels, err := c.chanSource.FetchAllChannels()
openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels()
if err != nil {
return err
}
@ -538,7 +542,9 @@ func (c *ChainArbitrator) Start() error {
// In addition to the channels that we know to be open, we'll also
// launch arbitrators to finishing resolving any channels that are in
// the pending close state.
closingChannels, err := c.chanSource.FetchClosedChannels(true)
closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels(
true,
)
if err != nil {
return err
}

View File

@ -49,7 +49,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
// We manually set the db here to make sure all channels are
// synced to the same db.
channel.Db = db
channel.Db = db.ChannelStateDB()
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
@ -165,7 +165,7 @@ func TestResolveContract(t *testing.T) {
}
defer cleanup()
channel := newChannel.State()
channel.Db = db
channel.Db = db.ChannelStateDB()
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
@ -205,7 +205,7 @@ func TestResolveContract(t *testing.T) {
// While the resolver are active, we'll now remove the channel from the
// database (mark is as closed).
err = db.AbandonChannel(&channel.FundingOutpoint, 4)
err = db.ChannelStateDB().AbandonChannel(&channel.FundingOutpoint, 4)
if err != nil {
t.Fatalf("unable to remove channel: %v", err)
}

View File

@ -58,7 +58,7 @@ func copyChannelState(state *channeldb.OpenChannel) (
*channeldb.OpenChannel, func(), error) {
// Make a copy of the DB.
dbFile := filepath.Join(state.Db.Path(), "channel.db")
dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db")
tempDbPath, err := ioutil.TempDir("", "past-state")
if err != nil {
return nil, nil, err
@ -81,7 +81,7 @@ func copyChannelState(state *channeldb.OpenChannel) (
return nil, nil, err
}
chans, err := newDb.FetchAllChannels()
chans, err := newDb.ChannelStateDB().FetchAllChannels()
if err != nil {
cleanup()
return nil, nil, err

View File

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
@ -59,7 +58,7 @@ type GossipMessageStore interface {
// version of a message (like in the case of multiple ChannelUpdate's) for a
// channel with a peer.
type MessageStore struct {
db *channeldb.DB
db kvdb.Backend
}
// A compile-time assertion to ensure messageStore implements the
@ -67,8 +66,8 @@ type MessageStore struct {
var _ GossipMessageStore = (*MessageStore)(nil)
// NewMessageStore creates a new message store backed by a channeldb instance.
func NewMessageStore(db *channeldb.DB) (*MessageStore, error) {
err := kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error {
func NewMessageStore(db kvdb.Backend) (*MessageStore, error) {
err := kvdb.Batch(db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(messageStoreBucket)
return err
})
@ -124,7 +123,7 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error
return err
}
return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error {
return kvdb.Batch(s.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket)
if messageStore == nil {
return ErrCorruptedMessageStore
@ -145,7 +144,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message,
return err
}
return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error {
return kvdb.Batch(s.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket)
if messageStore == nil {
return ErrCorruptedMessageStore

View File

@ -59,6 +59,18 @@ in `lnd`, saving developer time and limiting the potential for bugs.
Instructions for enabling Postgres can be found in
[docs/postgres.md](../postgres.md).
### In-memory path finding
Finding a path through the channel graph for sending a payment doesn't involve
any database queries anymore. The [channel graph is now kept fully
in-memory](https://github.com/lightningnetwork/lnd/pull/5642) for up a massive
performance boost when calling `QueryRoutes` or any of the `SendPayment`
variants. Keeping the full graph in memory naturally comes with increased RAM
usage. Users running `lnd` on low-memory systems are advised to run with the
`routing.strictgraphpruning=true` configuration option that more aggressively
removes zombie channels from the graph, reducing the number of channels that
need to be kept in memory.
## Protocol Extensions
### Explicit Channel Negotiation

View File

@ -23,7 +23,6 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lnrpc"
@ -550,19 +549,6 @@ const (
addedToRouterGraph
)
var (
// channelOpeningStateBucket is the database bucket used to store the
// channelOpeningState for each channel that is currently in the process
// of being opened.
channelOpeningStateBucket = []byte("channelOpeningState")
// ErrChannelNotFound is an error returned when a channel is not known
// to us. In this case of the fundingManager, this error is returned
// when the channel in question is not considered being in an opening
// state.
ErrChannelNotFound = fmt.Errorf("channel not found")
)
// NewFundingManager creates and initializes a new instance of the
// fundingManager.
func NewFundingManager(cfg Config) (*Manager, error) {
@ -887,7 +873,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel,
channelState, shortChanID, err := f.getChannelOpeningState(
&channel.FundingOutpoint,
)
if err == ErrChannelNotFound {
if err == channeldb.ErrChannelNotFound {
// Channel not in fundingManager's opening database,
// meaning it was successfully announced to the
// network.
@ -3551,26 +3537,20 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey {
// chanPoint to the channelOpeningStateBucket.
func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint,
state channelOpeningState, shortChanID *lnwire.ShortChannelID) error {
return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket)
if err != nil {
return err
}
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
var outpointBytes bytes.Buffer
if err = WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
// Save state and the uint64 representation of the shortChanID
// for later use.
scratch := make([]byte, 10)
byteOrder.PutUint16(scratch[:2], uint16(state))
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
return bucket.Put(outpointBytes.Bytes(), scratch)
}, func() {})
// Save state and the uint64 representation of the shortChanID
// for later use.
scratch := make([]byte, 10)
byteOrder.PutUint16(scratch[:2], uint16(state))
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
return f.cfg.Wallet.Cfg.Database.SaveChannelOpeningState(
outpointBytes.Bytes(), scratch,
)
}
// getChannelOpeningState fetches the channelOpeningState for the provided
@ -3579,51 +3559,31 @@ func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint,
func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) (
channelOpeningState, *lnwire.ShortChannelID, error) {
var state channelOpeningState
var shortChanID lnwire.ShortChannelID
err := kvdb.View(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RTx) error {
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return 0, nil, err
}
bucket := tx.ReadBucket(channelOpeningStateBucket)
if bucket == nil {
// If the bucket does not exist, it means we never added
// a channel to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
value := bucket.Get(outpointBytes.Bytes())
if value == nil {
return ErrChannelNotFound
}
state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil
}, func() {})
value, err := f.cfg.Wallet.Cfg.Database.GetChannelOpeningState(
outpointBytes.Bytes(),
)
if err != nil {
return 0, nil, err
}
state := channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return state, &shortChanID, nil
}
// deleteChannelOpeningState removes any state for chanPoint from the database.
func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error {
return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(channelOpeningStateBucket)
if bucket == nil {
return fmt.Errorf("bucket not found")
}
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
return bucket.Delete(outpointBytes.Bytes())
}, func() {})
return f.cfg.Wallet.Cfg.Database.DeleteChannelOpeningState(
outpointBytes.Bytes(),
)
}

View File

@ -262,7 +262,7 @@ func (n *testNode) AddNewChannel(channel *channeldb.OpenChannel,
}
}
func createTestWallet(cdb *channeldb.DB, netParams *chaincfg.Params,
func createTestWallet(cdb *channeldb.ChannelStateDB, netParams *chaincfg.Params,
notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController,
signer input.Signer, keyRing keychain.SecretKeyRing,
bio lnwallet.BlockChainIO,
@ -330,11 +330,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey,
}
dbDir := filepath.Join(tempTestDir, "cdb")
cdb, err := channeldb.Open(dbDir)
fullDB, err := channeldb.Open(dbDir)
if err != nil {
return nil, err
}
cdb := fullDB.ChannelStateDB()
keyRing := &mock.SecretKeyRing{
RootKey: alicePrivKey,
}
@ -923,12 +925,12 @@ func assertDatabaseState(t *testing.T, node *testNode,
}
state, _, err = node.fundingMgr.getChannelOpeningState(
fundingOutPoint)
if err != nil && err != ErrChannelNotFound {
if err != nil && err != channeldb.ErrChannelNotFound {
t.Fatalf("unable to get channel state: %v", err)
}
// If we found the channel, check if it had the expected state.
if err != ErrChannelNotFound && state == expectedState {
if err != channeldb.ErrChannelNotFound && state == expectedState {
// Got expected state, return with success.
return
}
@ -1166,7 +1168,7 @@ func assertErrChannelNotFound(t *testing.T, node *testNode,
}
state, _, err = node.fundingMgr.getChannelOpeningState(
fundingOutPoint)
if err == ErrChannelNotFound {
if err == channeldb.ErrChannelNotFound {
// Got expected state, return with success.
return
} else if err != nil {

View File

@ -199,9 +199,16 @@ type circuitMap struct {
// parameterize an instance of circuitMap.
type CircuitMapConfig struct {
// DB provides the persistent storage engine for the circuit map.
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
DB *channeldb.DB
DB kvdb.Backend
// FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// ExtractErrorEncrypter derives the shared secret used to encrypt
// errors from the obfuscator's ephemeral public key.
@ -296,7 +303,7 @@ func (cm *circuitMap) cleanClosedChannels() error {
// Find closed channels and cache their ShortChannelIDs into a map.
// This map will be used for looking up relative circuits and keystones.
closedChannels, err := cm.cfg.DB.FetchClosedChannels(false)
closedChannels, err := cm.cfg.FetchClosedChannels(false)
if err != nil {
return err
}
@ -629,7 +636,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) {
// channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.cfg.DB.FetchAllOpenChannels()
activeChannels, err := cm.cfg.FetchAllOpenChannels()
if err != nil {
return err
}
@ -860,7 +867,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
// Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance.
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error {
err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
@ -1091,7 +1098,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
}
cm.mtx.Unlock()
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error {
err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the

View File

@ -103,8 +103,11 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
onionProcessor := newOnionProcessor(t)
db := makeCircuitDB(t, "")
circuitMapCfg := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, ""),
DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
}
@ -634,13 +637,17 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB {
func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
*htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) {
// Record the current temp path and close current db.
dbPath := cfg.DB.Path()
// Record the current temp path and close current db. We know we have
// a full channeldb.DB here since we created it just above.
dbPath := cfg.DB.(*channeldb.DB).Path()
cfg.DB.Close()
// Reinitialize circuit map with same db path.
db := makeCircuitDB(t, dbPath)
cfg2 := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, dbPath),
DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
}
cm2, err := htlcswitch.NewCircuitMap(cfg2)

View File

@ -1938,7 +1938,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
pCache := newMockPreimageCache()
aliceDb := aliceLc.channel.State().Db
aliceDb := aliceLc.channel.State().Db.GetParentDB()
aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
@ -4438,7 +4438,7 @@ func (h *persistentLinkHarness) restartLink(
pCache = newMockPreimageCache()
)
aliceDb := aliceChannel.State().Db
aliceDb := aliceChannel.State().Db.GetParentDB()
aliceSwitch := h.coreLink.cfg.Switch
if restartSwitch {
var err error

View File

@ -170,8 +170,10 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error)
}
cfg := Config{
DB: db,
SwitchPackager: channeldb.NewSwitchPackager(),
DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},

View File

@ -83,7 +83,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
// is back. The Switch will checkpoint any received result to the store, and
// the store will keep results and notify the callers about them.
type networkResultStore struct {
db *channeldb.DB
backend kvdb.Backend
// results is a map from paymentIDs to channels where subscribers to
// payment results will be notified.
@ -96,9 +96,9 @@ type networkResultStore struct {
paymentIDMtx *multimutex.Mutex
}
func newNetworkResultStore(db *channeldb.DB) *networkResultStore {
func newNetworkResultStore(db kvdb.Backend) *networkResultStore {
return &networkResultStore{
db: db,
backend: db,
results: make(map[uint64][]chan *networkResult),
paymentIDMtx: multimutex.NewMutex(),
}
@ -126,7 +126,7 @@ func (store *networkResultStore) storeResult(paymentID uint64,
var paymentIDBytes [8]byte
binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID)
err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error {
err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey,
)
@ -171,7 +171,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
resultChan = make(chan *networkResult, 1)
)
err := kvdb.View(store.db, func(tx kvdb.RTx) error {
err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error
result, err = fetchResult(tx, paymentID)
switch {
@ -219,7 +219,7 @@ func (store *networkResultStore) getResult(pid uint64) (
*networkResult, error) {
var result *networkResult
err := kvdb.View(store.db, func(tx kvdb.RTx) error {
err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error
result, err = fetchResult(tx, pid)
return err
@ -260,7 +260,7 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) {
// concurrently while this process is ongoing, as its result might end up being
// deleted.
func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error {
return kvdb.Update(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey,
)

View File

@ -130,9 +130,18 @@ type Config struct {
// subsystem.
LocalChannelClose func(pubKey []byte, request *ChanClose)
// DB is the channeldb instance that will be used to back the switch's
// DB is the database backend that will be used to back the switch's
// persistent circuit map.
DB *channeldb.DB
DB kvdb.Backend
// FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// SwitchPackager provides access to the forwarding packages of all
// active channels. This gives the switch the ability to read arbitrary
@ -294,6 +303,8 @@ type Switch struct {
func New(cfg Config, currentHeight uint32) (*Switch, error) {
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB,
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
FetchClosedChannels: cfg.FetchClosedChannels,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
})
if err != nil {
@ -1455,7 +1466,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// we're the originator of the payment, so the link stops attempting to
// re-broadcast.
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error {
return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error {
return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error {
return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...)
})
}
@ -1859,7 +1870,7 @@ func (s *Switch) Start() error {
// forwarding packages and reforwards any Settle or Fail HTLCs found. This is
// used to resurrect the switch's mailboxes after a restart.
func (s *Switch) reforwardResponses() error {
openChannels, err := s.cfg.DB.FetchAllOpenChannels()
openChannels, err := s.cfg.FetchAllOpenChannels()
if err != nil {
return err
}
@ -2122,6 +2133,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
return link, nil
}
// GetLinkByShortID attempts to return the link which possesses the target short
// channel ID.
func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
return s.getLinkByShortID(chanID)
}
// getLinkByShortID attempts to return the link which possesses the target
// short channel ID.
//

View File

@ -306,7 +306,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit,
ShortChannelID: chanID,
Db: dbAlice,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID),
FundingTxn: channels.TestFundingTx,
}
@ -325,7 +325,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
LocalCommitment: bobCommit,
RemoteCommitment: bobCommit,
ShortChannelID: chanID,
Db: dbBob,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID),
}
@ -384,7 +384,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
restoreAlice := func() (*lnwallet.LightningChannel, error) {
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub)
aliceStoredChannels, err := dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
switch err {
case nil:
case kvdb.ErrDatabaseNotOpen:
@ -394,7 +395,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
"db: %v", err)
}
aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub)
aliceStoredChannels, err = dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
if err != nil {
return nil, errors.Errorf("unable to fetch alice "+
"channel: %v", err)
@ -428,7 +430,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
restoreBob := func() (*lnwallet.LightningChannel, error) {
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub)
bobStoredChannels, err := dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
switch err {
case nil:
case kvdb.ErrDatabaseNotOpen:
@ -438,7 +441,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
"db: %v", err)
}
bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub)
bobStoredChannels, err = dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
if err != nil {
return nil, errors.Errorf("unable to fetch bob "+
"channel: %v", err)
@ -950,9 +954,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
secondBobChannel, carolChannel *lnwallet.LightningChannel,
startingHeight uint32, opts ...serverOption) *threeHopNetwork {
aliceDb := aliceChannel.State().Db
bobDb := firstBobChannel.State().Db
carolDb := carolChannel.State().Db
aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := firstBobChannel.State().Db.GetParentDB()
carolDb := carolChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork()
@ -1201,8 +1205,8 @@ func newTwoHopNetwork(t testing.TB,
aliceChannel, bobChannel *lnwallet.LightningChannel,
startingHeight uint32) *twoHopNetwork {
aliceDb := aliceChannel.State().Db
bobDb := bobChannel.State().Db
aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := bobChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork()

24
lnd.go
View File

@ -22,6 +22,7 @@ import (
"sync"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/wallet"
@ -697,7 +698,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error
BtcdMode: cfg.BtcdMode,
LtcdMode: cfg.LtcdMode,
HeightHintDB: dbs.heightHintDB,
ChanStateDB: dbs.chanStateDB,
ChanStateDB: dbs.chanStateDB.ChannelStateDB(),
PrivateWalletPw: privateWalletPw,
PublicWalletPw: publicWalletPw,
Birthday: walletInitParams.Birthday,
@ -1679,14 +1680,27 @@ func initializeDatabases(ctx context.Context,
"instances")
}
// Otherwise, we'll open two instances, one for the state we only need
// locally, and the other for things we want to ensure are replicated.
dbs.graphDB, err = channeldb.CreateWithBackend(
databaseBackends.GraphDB,
dbOptions := []channeldb.OptionModifier{
channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize),
channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize),
channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval),
channeldb.OptionDryRunMigration(cfg.DryRunMigration),
}
// We want to pre-allocate the channel graph cache according to what we
// expect for mainnet to speed up memory allocation.
if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name {
dbOptions = append(
dbOptions, channeldb.OptionSetPreAllocCacheNumNodes(
channeldb.DefaultPreAllocCacheNumNodes,
),
)
}
// Otherwise, we'll open two instances, one for the state we only need
// locally, and the other for things we want to ensure are replicated.
dbs.graphDB, err = channeldb.CreateWithBackend(
databaseBackends.GraphDB, dbOptions...,
)
switch {
// Give the DB a chance to dry run the migration. Since we know that

View File

@ -56,7 +56,7 @@ type AddInvoiceConfig struct {
// ChanDB is a global boltdb instance which is needed to access the
// channel graph.
ChanDB *channeldb.DB
ChanDB *channeldb.ChannelStateDB
// Graph holds a reference to the ChannelGraph database.
Graph *channeldb.ChannelGraph

View File

@ -51,7 +51,7 @@ type Config struct {
// ChanStateDB is a possibly replicated db instance which contains the
// channel state
ChanStateDB *channeldb.DB
ChanStateDB *channeldb.ChannelStateDB
// GenInvoiceFeatures returns a feature containing feature bits that
// should be advertised on freshly generated invoices.

View File

@ -55,7 +55,7 @@ type RouterBackend struct {
FindRoute func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
destCustomRecords record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy,
routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error)
MissionControl MissionControl

View File

@ -126,7 +126,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
findRoute := func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
_ record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy,
routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error) {
if int64(amt) != amtSat*1000 {

View File

@ -25,24 +25,20 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
ctxb := context.Background()
// Create two fresh nodes and open a channel between them.
alice := net.NewNode(
t.t, "Alice", []string{
"--minbackoff=10s",
"--chan-enable-timeout=1.5s",
"--chan-disable-timeout=3s",
"--chan-status-sample-interval=.5s",
},
)
alice := net.NewNode(t.t, "Alice", []string{
"--minbackoff=10s",
"--chan-enable-timeout=1.5s",
"--chan-disable-timeout=3s",
"--chan-status-sample-interval=.5s",
})
defer shutdownAndAssert(net, t, alice)
bob := net.NewNode(
t.t, "Bob", []string{
"--minbackoff=10s",
"--chan-enable-timeout=1.5s",
"--chan-disable-timeout=3s",
"--chan-status-sample-interval=.5s",
},
)
bob := net.NewNode(t.t, "Bob", []string{
"--minbackoff=10s",
"--chan-enable-timeout=1.5s",
"--chan-disable-timeout=3s",
"--chan-status-sample-interval=.5s",
})
defer shutdownAndAssert(net, t, bob)
// Connect Alice to Bob.
@ -55,36 +51,32 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// being the sole funder of the channel.
chanAmt := btcutil.Amount(100000)
chanPoint := openChannelAndAssert(
t, net, alice, bob,
lntest.OpenChannelParams{
t, net, alice, bob, lntest.OpenChannelParams{
Amt: chanAmt,
},
)
// Wait for Alice and Bob to receive the channel edge from the
// funding manager.
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
err := alice.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("alice didn't see the alice->bob channel before "+
"timeout: %v", err)
}
require.NoError(t.t, err, "alice didn't see the alice->bob channel")
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = bob.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("bob didn't see the bob->alice channel before "+
"timeout: %v", err)
}
require.NoError(t.t, err, "bob didn't see the alice->bob channel")
// Launch a node for Carol which will connect to Alice and Bob in
// order to receive graph updates. This will ensure that the
// channel updates are propagated throughout the network.
// Launch a node for Carol which will connect to Alice and Bob in order
// to receive graph updates. This will ensure that the channel updates
// are propagated throughout the network.
carol := net.NewNode(t.t, "Carol", nil)
defer shutdownAndAssert(net, t, carol)
// Connect both Alice and Bob to the new node Carol, so she can sync her
// graph.
net.ConnectNodes(t.t, alice, carol)
net.ConnectNodes(t.t, bob, carol)
waitForGraphSync(t, carol)
// assertChannelUpdate checks that the required policy update has
// happened on the given node.
@ -109,12 +101,11 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
ChanPoint: chanPoint,
Action: action,
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
_, err = node.RouterClient.UpdateChanStatus(ctxt, req)
if err != nil {
t.Fatalf("unable to call UpdateChanStatus for %s's node: %v",
node.Name(), err)
}
require.NoErrorf(t.t, err, "UpdateChanStatus")
}
// assertEdgeDisabled ensures that a given node has the correct
@ -122,26 +113,30 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
assertEdgeDisabled := func(node *lntest.HarnessNode,
chanPoint *lnrpc.ChannelPoint, disabled bool) {
var predErr error
err = wait.Predicate(func() bool {
outPoint, err := lntest.MakeOutpoint(chanPoint)
require.NoError(t.t, err)
err = wait.NoError(func() error {
req := &lnrpc.ChannelGraphRequest{
IncludeUnannounced: true,
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
chanGraph, err := node.DescribeGraph(ctxt, req)
if err != nil {
predErr = fmt.Errorf("unable to query node %v's graph: %v", node, err)
return false
return fmt.Errorf("unable to query node %v's "+
"graph: %v", node, err)
}
numEdges := len(chanGraph.Edges)
if numEdges != 1 {
predErr = fmt.Errorf("expected to find 1 edge in the graph, found %d", numEdges)
return false
return fmt.Errorf("expected to find 1 edge in "+
"the graph, found %d", numEdges)
}
edge := chanGraph.Edges[0]
if edge.ChanPoint != chanPoint.GetFundingTxidStr() {
predErr = fmt.Errorf("expected chan_point %v, got %v",
chanPoint.GetFundingTxidStr(), edge.ChanPoint)
if edge.ChanPoint != outPoint.String() {
return fmt.Errorf("expected chan_point %v, "+
"got %v", outPoint, edge.ChanPoint)
}
var policy *lnrpc.RoutingPolicy
if node.PubKeyStr == edge.Node1Pub {
@ -150,15 +145,14 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
policy = edge.Node2Policy
}
if disabled != policy.Disabled {
predErr = fmt.Errorf("expected policy.Disabled to be %v, "+
"but policy was %v", disabled, policy)
return false
return fmt.Errorf("expected policy.Disabled "+
"to be %v, but policy was %v", disabled,
policy)
}
return true
return nil
}, defaultTimeout)
if err != nil {
t.Fatalf("%v", predErr)
}
require.NoError(t.t, err)
}
// When updating the state of the channel between Alice and Bob, we
@ -193,9 +187,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// disconnections from automatically disabling the channel again
// (we don't want to clutter the network with channels that are
// falsely advertised as enabled when they don't work).
if err := net.DisconnectNodes(alice, bob); err != nil {
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
require.NoError(t.t, net.DisconnectNodes(alice, bob))
expectedPolicy.Disabled = true
assertChannelUpdate(alice, expectedPolicy)
assertChannelUpdate(bob, expectedPolicy)
@ -217,9 +209,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
expectedPolicy.Disabled = true
assertChannelUpdate(alice, expectedPolicy)
if err := net.DisconnectNodes(alice, bob); err != nil {
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
require.NoError(t.t, net.DisconnectNodes(alice, bob))
// Bob sends a "Disabled = true" update upon detecting the
// disconnect.
@ -237,9 +227,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// note the asymmetry between manual enable and manual disable!
assertEdgeDisabled(alice, chanPoint, true)
if err := net.DisconnectNodes(alice, bob); err != nil {
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
require.NoError(t.t, net.DisconnectNodes(alice, bob))
// Bob sends a "Disabled = true" update upon detecting the
// disconnect.

View File

@ -18,7 +18,7 @@ type Config struct {
// Database is a wrapper around a namespace within boltdb reserved for
// ln-based wallet metadata. See the 'channeldb' package for further
// information.
Database *channeldb.DB
Database *channeldb.ChannelStateDB
// Notifier is used by in order to obtain notifications about funding
// transaction reaching a specified confirmation depth, and to catch

View File

@ -327,13 +327,13 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness,
signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) {
dbDir := filepath.Join(tempTestDir, "cdb")
cdb, err := channeldb.Open(dbDir)
fullDB, err := channeldb.Open(dbDir)
if err != nil {
return nil, err
}
cfg := lnwallet.Config{
Database: cdb,
Database: fullDB.ChannelStateDB(),
Notifier: notifier,
SecretKeyRing: keyRing,
WalletController: wc,
@ -2944,11 +2944,11 @@ func clearWalletStates(a, b *lnwallet.LightningWallet) error {
a.ResetReservations()
b.ResetReservations()
if err := a.Cfg.Database.Wipe(); err != nil {
if err := a.Cfg.Database.GetParentDB().Wipe(); err != nil {
return err
}
return b.Cfg.Database.Wipe()
return b.Cfg.Database.GetParentDB().Wipe()
}
func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error {

View File

@ -323,7 +323,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceLocalCommit,
RemoteCommitment: aliceRemoteCommit,
Db: dbAlice,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: testTx,
}
@ -341,7 +341,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobLocalCommit,
RemoteCommitment: bobRemoteCommit,
Db: dbBob,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
}

View File

@ -940,7 +940,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: remoteCommit,
RemoteCommitment: remoteCommit,
Db: dbRemote,
Db: dbRemote.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: tc.fundingTx.MsgTx(),
}
@ -958,7 +958,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: localCommit,
RemoteCommitment: localCommit,
Db: dbLocal,
Db: dbLocal.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: tc.fundingTx.MsgTx(),
}

View File

@ -185,7 +185,7 @@ type Config struct {
InterceptSwitch *htlcswitch.InterceptableSwitch
// ChannelDB is used to fetch opened channels, and closed channels.
ChannelDB *channeldb.DB
ChannelDB *channeldb.ChannelStateDB
// ChannelGraph is a pointer to the channel graph which is used to
// query information about the set of known active channels.

View File

@ -229,7 +229,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit,
Db: dbAlice,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: channels.TestFundingTx,
}
@ -246,7 +246,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobCommit,
RemoteCommitment: bobCommit,
Db: dbBob,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
}
@ -321,7 +321,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
ChanStatusSampleInterval: 30 * time.Second,
ChanEnableTimeout: chanActiveTimeout,
ChanDisableTimeout: 2 * time.Minute,
DB: dbAlice,
DB: dbAlice.ChannelStateDB(),
Graph: dbAlice.ChannelGraph(),
MessageSigner: nodeSignerAlice,
OurPubKey: aliceKeyPub,
@ -359,7 +359,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
ChanActiveTimeout: chanActiveTimeout,
InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil),
ChannelDB: dbAlice,
ChannelDB: dbAlice.ChannelStateDB(),
FeeEstimator: estimator,
Wallet: wallet,
ChainNotifier: notifier,

View File

@ -2,7 +2,6 @@ package routing
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
@ -10,10 +9,10 @@ import (
// routingGraph is an abstract interface that provides information about nodes
// and edges to pathfinding.
type routingGraph interface {
// forEachNodeChannel calls the callback for every channel of the given node.
// forEachNodeChannel calls the callback for every channel of the given
// node.
forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error
cb func(channel *channeldb.DirectedChannel) error) error
// sourceNode returns the source node of the graph.
sourceNode() route.Vertex
@ -22,59 +21,44 @@ type routingGraph interface {
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}
// dbRoutingTx is a routingGraph implementation that retrieves from the
// CachedGraph is a routingGraph implementation that retrieves from the
// database.
type dbRoutingTx struct {
type CachedGraph struct {
graph *channeldb.ChannelGraph
tx kvdb.RTx
source route.Vertex
}
// newDbRoutingTx instantiates a new db-connected routing graph. It implictly
// A compile time assertion to make sure CachedGraph implements the routingGraph
// interface.
var _ routingGraph = (*CachedGraph)(nil)
// NewCachedGraph instantiates a new db-connected routing graph. It implictly
// instantiates a new read transaction.
func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) {
sourceNode, err := graph.SourceNode()
if err != nil {
return nil, err
}
tx, err := graph.Database().BeginReadTx()
if err != nil {
return nil, err
}
return &dbRoutingTx{
return &CachedGraph{
graph: graph,
tx: tx,
source: sourceNode.PubKeyBytes,
}, nil
}
// close closes the underlying db transaction.
func (g *dbRoutingTx) close() error {
return g.tx.Rollback()
}
// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error {
func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error {
txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo,
p1, p2 *channeldb.ChannelEdgePolicy) error {
return cb(info, p1, p2)
}
return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb)
return g.graph.ForEachNodeChannel(nodePub, cb)
}
// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) sourceNode() route.Vertex {
func (g *CachedGraph) sourceNode() route.Vertex {
return g.source
}
@ -82,23 +66,8 @@ func (g *dbRoutingTx) sourceNode() route.Vertex {
// unknown, assume no additional features are supported.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) (
func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub)
switch err {
// If the node exists and has features, return them directly.
case nil:
return targetNode.Features, nil
// If we couldn't find a node announcement, populate a blank feature
// vector.
case channeldb.ErrGraphNodeNotFound:
return lnwire.EmptyFeatureVector(), nil
// Otherwise bubble the error up.
default:
return nil, err
}
return g.graph.FetchNodeFeatures(nodePub)
}

View File

@ -39,7 +39,7 @@ type nodeWithDist struct {
weight int64
// nextHop is the edge this route comes from.
nextHop *channeldb.ChannelEdgePolicy
nextHop *channeldb.CachedEdgePolicy
// routingInfoSize is the total size requirement for the payloads field
// in the onion packet from this hop towards the final destination.

View File

@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
}
session, err := newPaymentSession(
&payment, getBandwidthHints,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
mc, c.pathFindingCfg,
&payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg,
)
if err != nil {
c.t.Fatal(err)

View File

@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,
//
// NOTE: Part of the routingGraph interface.
func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error {
cb func(channel *channeldb.DirectedChannel) error) error {
// Look up the mock node.
node, ok := m.nodes[nodePub]
@ -171,36 +170,31 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
// Iterate over all of its channels.
for peer, channel := range node.channels {
// Lexicographically sort the pubkeys.
var node1, node2 route.Vertex
var node1 route.Vertex
if bytes.Compare(nodePub[:], peer[:]) == -1 {
node1, node2 = peer, nodePub
node1 = peer
} else {
node1, node2 = nodePub, peer
node1 = nodePub
}
peerNode := m.nodes[peer]
// Call the per channel callback.
err := cb(
&channeldb.ChannelEdgeInfo{
NodeKey1Bytes: node1,
NodeKey2Bytes: node2,
},
&channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: peer,
Features: lnwire.EmptyFeatureVector(),
&channeldb.DirectedChannel{
ChannelID: channel.id,
IsNode1: nodePub == node1,
OtherNode: peer,
Capacity: channel.capacity,
OutPolicySet: true,
InPolicy: &channeldb.CachedEdgePolicy{
ChannelID: channel.id,
ToNodePubKey: func() route.Vertex {
return nodePub
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
FeeBaseMSat: peerNode.baseFee,
},
FeeBaseMSat: node.baseFee,
},
&channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: nodePub,
Features: lnwire.EmptyFeatureVector(),
},
FeeBaseMSat: peerNode.baseFee,
},
)
if err != nil {

View File

@ -173,13 +173,13 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
}
func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate,
_ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool {
_ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool {
return false
}
func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey,
_ uint64) *channeldb.ChannelEdgePolicy {
_ uint64) *channeldb.CachedEdgePolicy {
return nil
}
@ -637,17 +637,17 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
}
func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool {
pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool {
args := m.Called(msg, pubKey, policy)
return args.Bool(0)
}
func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy {
channelID uint64) *channeldb.CachedEdgePolicy {
args := m.Called(pubKey, channelID)
return args.Get(0).(*channeldb.ChannelEdgePolicy)
return args.Get(0).(*channeldb.CachedEdgePolicy)
}
type mockControlTower struct {

View File

@ -42,7 +42,7 @@ const (
type pathFinder = func(g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, source, target route.Vertex,
amt lnwire.MilliSatoshi, finalHtlcExpiry int32) (
[]*channeldb.ChannelEdgePolicy, error)
[]*channeldb.CachedEdgePolicy, error)
var (
// DefaultAttemptCost is the default fixed virtual cost in path finding
@ -76,7 +76,7 @@ var (
// of the edge.
type edgePolicyWithSource struct {
sourceNode route.Vertex
edge *channeldb.ChannelEdgePolicy
edge *channeldb.CachedEdgePolicy
}
// finalHopParams encapsulates various parameters for route construction that
@ -102,7 +102,7 @@ type finalHopParams struct {
// any feature vectors on all hops have been validated for transitive
// dependencies.
func newRoute(sourceVertex route.Vertex,
pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32,
pathEdges []*channeldb.CachedEdgePolicy, currentHeight uint32,
finalHop finalHopParams) (*route.Route, error) {
var (
@ -147,10 +147,10 @@ func newRoute(sourceVertex route.Vertex,
supports := func(feature lnwire.FeatureBit) bool {
// If this edge comes from router hints, the features
// could be nil.
if edge.Node.Features == nil {
if edge.ToNodeFeatures == nil {
return false
}
return edge.Node.Features.HasFeature(feature)
return edge.ToNodeFeatures.HasFeature(feature)
}
// We start by assuming the node doesn't support TLV. We'll now
@ -225,7 +225,7 @@ func newRoute(sourceVertex route.Vertex,
// each new hop such that, the final slice of hops will be in
// the forwards order.
currentHop := &route.Hop{
PubKeyBytes: edge.Node.PubKeyBytes,
PubKeyBytes: edge.ToNodePubKey(),
ChannelID: edge.ChannelID,
AmtToForward: amtToForward,
OutgoingTimeLock: outgoingTimeLock,
@ -280,7 +280,7 @@ type graphParams struct {
// additionalEdges is an optional set of edges that should be
// considered during path finding, that is not already found in the
// channel graph.
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy
additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy
// bandwidthHints is an optional map from channels to bandwidths that
// can be populated if the caller has a better estimate of the current
@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
var max, total lnwire.MilliSatoshi
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge,
_ *channeldb.ChannelEdgePolicy) error {
if outEdge == nil {
cb := func(channel *channeldb.DirectedChannel) error {
if !channel.OutPolicySet {
return nil
}
chanID := outEdge.ChannelID
chanID := channel.ChannelID
// Enforce outgoing channel restriction.
if outgoingChans != nil {
@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// This can happen when a channel is added to the graph after
// we've already queried the bandwidth hints.
if !ok {
bandwidth = lnwire.NewMSatFromSatoshis(
edgeInfo.Capacity,
)
bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity)
}
if bandwidth > max {
@ -416,7 +412,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
// Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to
@ -523,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Build reverse lookup to find incoming edges. Needed because
// search is taken place from target to source.
for _, outgoingEdgePolicy := range outgoingEdgePolicies {
toVertex := outgoingEdgePolicy.Node.PubKeyBytes
toVertex := outgoingEdgePolicy.ToNodePubKey()
incomingEdgePolicy := &edgePolicyWithSource{
sourceNode: vertex,
edge: outgoingEdgePolicy,
@ -587,7 +583,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// satisfy our specific requirements.
processEdge := func(fromVertex route.Vertex,
fromFeatures *lnwire.FeatureVector,
edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) {
edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) {
edgesExpanded++
@ -883,13 +879,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Use the distance map to unravel the forward path from source to
// target.
var pathEdges []*channeldb.ChannelEdgePolicy
var pathEdges []*channeldb.CachedEdgePolicy
currentNode := source
for {
// Determine the next hop forward using the next map.
currentNodeWithDist, ok := distance[currentNode]
if !ok {
// If the node doesnt have a next hop it means we didn't find a path.
// If the node doesn't have a next hop it means we
// didn't find a path.
return nil, errNoPathFound
}
@ -897,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
pathEdges = append(pathEdges, currentNodeWithDist.nextHop)
// Advance current node.
currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes
currentNode = currentNodeWithDist.nextHop.ToNodePubKey()
// Check stop condition at the end of this loop. This prevents
// breaking out too soon for self-payments that have target set
@ -918,7 +915,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// route construction does not care where the features are actually
// taken from. In the future we may wish to do route construction within
// findPath, and avoid using ChannelEdgePolicy altogether.
pathEdges[len(pathEdges)-1].Node.Features = features
pathEdges[len(pathEdges)-1].ToNodeFeatures = features
log.Debugf("Found route: probability=%v, hops=%v, fee=%v",
distance[source].probability, len(pathEdges),

View File

@ -23,6 +23,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
@ -148,26 +149,36 @@ type testChan struct {
// makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing
// purposes. A callback which cleans up the created temporary directories is
// also returned and intended to be executed after the test completes.
func makeTestGraph() (*channeldb.ChannelGraph, func(), error) {
func makeTestGraph() (*channeldb.ChannelGraph, kvdb.Backend, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
// Next, create channeldb for the first time.
cdb, err := channeldb.Open(tempDirName)
// Next, create channelgraph for the first time.
backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr")
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
cleanUp := func() {
cdb.Close()
os.RemoveAll(tempDirName)
backendCleanup()
_ = os.RemoveAll(tempDirName)
}
return cdb.ChannelGraph(), cleanUp, nil
opts := channeldb.DefaultOptions()
graph, err := channeldb.NewChannelGraph(
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
)
if err != nil {
cleanUp()
return nil, nil, nil, err
}
return graph, backend, cleanUp, nil
}
// parseTestGraph returns a fully populated ChannelGraph given a path to a JSON
@ -197,7 +208,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph()
graph, graphBackend, cleanUp, err := makeTestGraph()
if err != nil {
return nil, err
}
@ -293,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
}
}
aliasForNode := func(node route.Vertex) string {
for alias, pubKey := range aliasMap {
if pubKey == node {
return alias
}
}
return ""
}
// With all the vertexes inserted, we can now insert the edges into the
// test graph.
for _, edge := range g.Edges {
@ -342,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
return nil, err
}
channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags)
isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0
targetNode := edgeInfo.NodeKey1Bytes
if isUpdate1 {
targetNode = edgeInfo.NodeKey2Bytes
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags),
ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags),
ChannelFlags: channelFlags,
ChannelID: edge.ChannelID,
LastUpdate: testTime,
TimeLockDelta: edge.Expiry,
@ -353,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC),
FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat),
FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate),
Node: &channeldb.LightningNode{
Alias: aliasForNode(targetNode),
PubKeyBytes: targetNode,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -381,11 +413,12 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
}
return &testGraphInstance{
graph: graph,
cleanUp: cleanUp,
aliasMap: aliasMap,
privKeyMap: privKeyMap,
channelIDs: channelIDs,
graph: graph,
graphBackend: graphBackend,
cleanUp: cleanUp,
aliasMap: aliasMap,
privKeyMap: privKeyMap,
channelIDs: channelIDs,
}, nil
}
@ -447,8 +480,9 @@ type testChannel struct {
}
type testGraphInstance struct {
graph *channeldb.ChannelGraph
cleanUp func()
graph *channeldb.ChannelGraph
graphBackend kvdb.Backend
cleanUp func()
// aliasMap is a map from a node's alias to its public key. This type is
// provided in order to allow easily look up from the human memorable alias
@ -482,7 +516,7 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph()
graph, graphBackend, cleanUp, err := makeTestGraph()
if err != nil {
return nil, err
}
@ -622,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
channelFlags |= lnwire.ChanUpdateDisabled
}
node2Features := lnwire.EmptyFeatureVector()
if node2.testChannelPolicy != nil {
node2Features = node2.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: msgFlags,
@ -633,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node1.MaxHTLC,
FeeBaseMSat: node1.FeeBaseMsat,
FeeProportionalMillionths: node1.FeeRate,
Node: &channeldb.LightningNode{
Alias: node2.Alias,
PubKeyBytes: node2Vertex,
Features: node2Features,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -650,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
}
channelFlags |= lnwire.ChanUpdateDirection
node1Features := lnwire.EmptyFeatureVector()
if node1.testChannelPolicy != nil {
node1Features = node1.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: msgFlags,
@ -661,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node2.MaxHTLC,
FeeBaseMSat: node2.FeeBaseMsat,
FeeProportionalMillionths: node2.FeeRate,
Node: &channeldb.LightningNode{
Alias: node1.Alias,
PubKeyBytes: node1Vertex,
Features: node1Features,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -671,10 +725,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
}
return &testGraphInstance{
graph: graph,
cleanUp: cleanUp,
aliasMap: aliasMap,
privKeyMap: privKeyMap,
graph: graph,
graphBackend: graphBackend,
cleanUp: cleanUp,
aliasMap: aliasMap,
privKeyMap: privKeyMap,
}, nil
}
@ -1044,20 +1099,23 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) {
// Create the channel edge going from songoku to doge and include it in
// our map of additional edges.
songokuToDoge := &channeldb.ChannelEdgePolicy{
Node: doge,
songokuToDoge := &channeldb.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return doge.PubKeyBytes
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
ChannelID: 1337,
FeeBaseMSat: 1,
FeeProportionalMillionths: 1000,
TimeLockDelta: 9,
}
additionalEdges := map[route.Vertex][]*channeldb.ChannelEdgePolicy{
additionalEdges := map[route.Vertex][]*channeldb.CachedEdgePolicy{
graph.aliasMap["songoku"]: {songokuToDoge},
}
find := func(r *RestrictParams) (
[]*channeldb.ChannelEdgePolicy, error) {
[]*channeldb.CachedEdgePolicy, error) {
return dbFindPath(
graph.graph, additionalEdges, nil,
@ -1124,14 +1182,13 @@ func TestNewRoute(t *testing.T) {
createHop := func(baseFee lnwire.MilliSatoshi,
feeRate lnwire.MilliSatoshi,
bandwidth lnwire.MilliSatoshi,
timeLockDelta uint16) *channeldb.ChannelEdgePolicy {
timeLockDelta uint16) *channeldb.CachedEdgePolicy {
return &channeldb.ChannelEdgePolicy{
Node: &channeldb.LightningNode{
Features: lnwire.NewFeatureVector(
nil, nil,
),
return &channeldb.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return route.Vertex{}
},
ToNodeFeatures: lnwire.NewFeatureVector(nil, nil),
FeeProportionalMillionths: feeRate,
FeeBaseMSat: baseFee,
TimeLockDelta: timeLockDelta,
@ -1144,7 +1201,7 @@ func TestNewRoute(t *testing.T) {
// hops is the list of hops (the route) that gets passed into
// the call to newRoute.
hops []*channeldb.ChannelEdgePolicy
hops []*channeldb.CachedEdgePolicy
// paymentAmount is the amount that is send into the route
// indicated by hops.
@ -1193,7 +1250,7 @@ func TestNewRoute(t *testing.T) {
// For a single hop payment, no fees are expected to be paid.
name: "single hop",
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(100, 1000, 1000000, 10),
},
expectedFees: []lnwire.MilliSatoshi{0},
@ -1206,7 +1263,7 @@ func TestNewRoute(t *testing.T) {
// a fee to receive the payment.
name: "two hop",
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5),
},
@ -1221,7 +1278,7 @@ func TestNewRoute(t *testing.T) {
name: "two hop tlv onion feature",
destFeatures: tlvFeatures,
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5),
},
@ -1238,7 +1295,7 @@ func TestNewRoute(t *testing.T) {
destFeatures: tlvPayAddrFeatures,
paymentAddr: &testPaymentAddr,
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5),
},
@ -1258,7 +1315,7 @@ func TestNewRoute(t *testing.T) {
// gets rounded down to 1.
name: "three hop",
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10, 1000000, 10),
createHop(0, 10, 1000000, 5),
createHop(0, 10, 1000000, 3),
@ -1273,7 +1330,7 @@ func TestNewRoute(t *testing.T) {
// because of the increase amount to forward.
name: "three hop with fee carry over",
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10000, 1000000, 10),
createHop(0, 10000, 1000000, 5),
createHop(0, 10000, 1000000, 3),
@ -1288,7 +1345,7 @@ func TestNewRoute(t *testing.T) {
// effect.
name: "three hop with minimal fees for carry over",
paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{
hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10000, 1000000, 10),
// First hop charges 0.1% so the second hop fee
@ -1312,7 +1369,7 @@ func TestNewRoute(t *testing.T) {
// custom feature vector.
if testCase.destFeatures != nil {
finalHop := testCase.hops[len(testCase.hops)-1]
finalHop.Node.Features = testCase.destFeatures
finalHop.ToNodeFeatures = testCase.destFeatures
}
assertRoute := func(t *testing.T, route *route.Route) {
@ -1539,7 +1596,7 @@ func TestDestTLVGraphFallback(t *testing.T) {
}
find := func(r *RestrictParams,
target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) {
target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) {
return dbFindPath(
ctx.graph, nil, nil,
@ -2120,7 +2177,7 @@ func TestPathFindSpecExample(t *testing.T) {
// Carol, so we set "B" as the source node so path finding starts from
// Bob.
bob := ctx.aliases["B"]
bobNode, err := ctx.graph.FetchLightningNode(nil, bob)
bobNode, err := ctx.graph.FetchLightningNode(bob)
if err != nil {
t.Fatalf("unable to find bob: %v", err)
}
@ -2170,7 +2227,7 @@ func TestPathFindSpecExample(t *testing.T) {
// Next, we'll set A as the source node so we can assert that we create
// the proper route for any queries starting with Alice.
alice := ctx.aliases["A"]
aliceNode, err := ctx.graph.FetchLightningNode(nil, alice)
aliceNode, err := ctx.graph.FetchLightningNode(alice)
if err != nil {
t.Fatalf("unable to find alice: %v", err)
}
@ -2270,16 +2327,16 @@ func TestPathFindSpecExample(t *testing.T) {
}
func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex,
path []*channeldb.ChannelEdgePolicy, nodeAliases ...string) {
path []*channeldb.CachedEdgePolicy, nodeAliases ...string) {
if len(path) != len(nodeAliases) {
t.Fatal("number of hops and number of aliases do not match")
}
for i, hop := range path {
if hop.Node.PubKeyBytes != aliasMap[nodeAliases[i]] {
if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] {
t.Fatalf("expected %v to be pos #%v in hop, instead "+
"%v was", nodeAliases[i], i, hop.Node.Alias)
"%v was", nodeAliases[i], i, hop.ToNodePubKey())
}
}
}
@ -2930,7 +2987,7 @@ func (c *pathFindingTestContext) cleanup() {
}
func (c *pathFindingTestContext) findPath(target route.Vertex,
amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy,
amt lnwire.MilliSatoshi) ([]*channeldb.CachedEdgePolicy,
error) {
return dbFindPath(
@ -2939,7 +2996,9 @@ func (c *pathFindingTestContext) findPath(target route.Vertex,
)
}
func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) {
func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy,
expected []uint64) {
if len(path) != len(expected) {
c.t.Fatalf("expected path of length %v, but got %v",
len(expected), len(path))
@ -2956,28 +3015,22 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy,
// dbFindPath calls findPath after getting a db transaction from the database
// graph.
func dbFindPath(graph *channeldb.ChannelGraph,
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy,
additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy,
bandwidthHints map[uint64]lnwire.MilliSatoshi,
r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
routingTx, err := newDbRoutingTx(graph)
routingGraph, err := NewCachedGraph(graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPath(
&graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: routingTx,
graph: routingGraph,
},
r, cfg, source, target, amt, finalHtlcExpiry,
)

View File

@ -898,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route,
var (
isAdditionalEdge bool
policy *channeldb.ChannelEdgePolicy
policy *channeldb.CachedEdgePolicy
)
// Before we apply the channel update, we need to decide whether the

View File

@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
Payer: payer,
ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity)
QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(c.Capacity)
},
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)

View File

@ -144,13 +144,13 @@ type PaymentSession interface {
// a boolean to indicate whether the update has been applied without
// error.
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey,
policy *channeldb.ChannelEdgePolicy) bool
policy *channeldb.CachedEdgePolicy) bool
// GetAdditionalEdgePolicy uses the public key and channel ID to query
// the ephemeral channel edge policy for additional edges. Returns a nil
// if nothing found.
GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy
channelID uint64) *channeldb.CachedEdgePolicy
}
// paymentSession is used during an HTLC routings session to prune the local
@ -162,7 +162,7 @@ type PaymentSession interface {
// loop if payment attempts take long enough. An additional set of edges can
// also be provided to assist in reaching the payment's destination.
type paymentSession struct {
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy
additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error)
@ -172,7 +172,7 @@ type paymentSession struct {
pathFinder pathFinder
getRoutingGraph func() (routingGraph, func(), error)
routingGraph routingGraph
// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity.
@ -193,7 +193,7 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment,
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error),
getRoutingGraph func() (routingGraph, func(), error),
routingGraph routingGraph,
missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) {
@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
getRoutingGraph: getRoutingGraph,
routingGraph: routingGraph,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.log.Debugf("pathfinding for amt=%v", maxAmt)
// Get a routing graph.
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}
sourceVertex := routingGraph.sourceNode()
sourceVertex := p.routingGraph.sourceNode()
// Find a route for the current amount.
path, err := p.pathFinder(
&graphParams{
additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints,
graph: routingGraph,
graph: p.routingGraph,
},
restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry,
)
// Close routing graph.
cleanup()
switch {
case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp
@ -403,7 +394,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// updates to the supplied policy. It returns a boolean to indicate whether
// there's an error when applying the updates.
func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool {
pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool {
// Validate the message signature.
if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil {
@ -428,7 +419,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
// ephemeral channel edge policy for additional edges. Returns a nil if nothing
// found.
func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy {
channelID uint64) *channeldb.CachedEdgePolicy {
target := route.NewVertex(pubKey)

View File

@ -17,14 +17,14 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
type SessionSource struct {
// Graph is the channel graph that will be used to gather metrics from
// and also to carry out path finding queries.
Graph *channeldb.ChannelGraph
Graph routingGraph
// QueryBandwidth is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link
// to be traversed. If the link isn't available, then a value of zero
// should be returned. Otherwise, the current up to date knowledge of
// the available bandwidth of the link should be returned.
QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi
QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// MissionControl is a shared memory of sorts that executions of payment
// path finding use in order to remember which vertexes/edges were
@ -40,21 +40,6 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig
}
// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := newDbRoutingTx(m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
@ -62,19 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) {
sourceNode, err := m.Graph.SourceNode()
if err != nil {
return nil, err
}
getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi,
error) {
return generateBandwidthHints(sourceNode, m.QueryBandwidth)
return generateBandwidthHints(
m.Graph.sourceNode(), m.Graph, m.QueryBandwidth,
)
}
session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph,
p, getBandwidthHints, m.Graph,
m.MissionControl, m.PathFindingConfig,
)
if err != nil {
@ -96,9 +78,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession {
// RouteHintsToEdges converts a list of invoice route hints to an edge map that
// can be passed into pathfinding.
func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) (
map[route.Vertex][]*channeldb.ChannelEdgePolicy, error) {
map[route.Vertex][]*channeldb.CachedEdgePolicy, error) {
edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy)
edges := make(map[route.Vertex][]*channeldb.CachedEdgePolicy)
// Traverse through all of the available hop hints and include them in
// our edges map, indexed by the public key of the channel's starting
@ -128,9 +110,12 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) (
// Finally, create the channel edge from the hop hint
// and add it to list of edges corresponding to the node
// at the start of the channel.
edge := &channeldb.ChannelEdgePolicy{
Node: endNode,
ChannelID: hopHint.ChannelID,
edge := &channeldb.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return endNode.PubKeyBytes
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
ChannelID: hopHint.ChannelID,
FeeBaseMSat: lnwire.MilliSatoshi(
hopHint.FeeBaseMSat,
),

View File

@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
return nil, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) {
return nil, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
@ -217,7 +213,7 @@ func TestRequestRoute(t *testing.T) {
session.pathFinder = func(
g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
// We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding).
@ -225,13 +221,14 @@ func TestRequestRoute(t *testing.T) {
t.Fatal("wrong cltv limit")
}
path := []*channeldb.ChannelEdgePolicy{
path := []*channeldb.CachedEdgePolicy{
{
Node: &channeldb.LightningNode{
Features: lnwire.NewFeatureVector(
nil, nil,
),
ToNodePubKey: func() route.Vertex {
return route.Vertex{}
},
ToNodeFeatures: lnwire.NewFeatureVector(
nil, nil,
),
},
}

View File

@ -339,7 +339,7 @@ type Config struct {
// a value of zero should be returned. Otherwise, the current up to
// date knowledge of the available bandwidth of the link should be
// returned.
QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi
QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// NextPaymentID is a method that guarantees to return a new, unique ID
// each time it is called. This is used by the router to generate a
@ -406,6 +406,10 @@ type ChannelRouter struct {
// when doing any path finding.
selfNode *channeldb.LightningNode
// cachedGraph is an instance of routingGraph that caches the source node as
// well as the channel graph itself in memory.
cachedGraph routingGraph
// newBlocks is a channel in which new blocks connected to the end of
// the main chain are sent over, and blocks updated after a call to
// UpdateFilter.
@ -460,14 +464,17 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil)
// channel graph is a subset of the UTXO set) set, then the router will proceed
// to fully sync to the latest state of the UTXO set.
func New(cfg Config) (*ChannelRouter, error) {
selfNode, err := cfg.Graph.SourceNode()
if err != nil {
return nil, err
}
r := &ChannelRouter{
cfg: &cfg,
cfg: &cfg,
cachedGraph: &CachedGraph{
graph: cfg.Graph,
source: selfNode.PubKeyBytes,
},
networkUpdates: make(chan *routingMsg),
topologyClients: make(map[uint64]*topologyClient),
ntfnClientUpdates: make(chan *topologyClientUpdate),
@ -1727,7 +1734,7 @@ type routingMsg struct {
func (r *ChannelRouter) FindRoute(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *RestrictParams,
destCustomRecords record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy,
routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error) {
log.Debugf("Searching for path to %v, sending %v", target, amt)
@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// We'll attempt to obtain a set of bandwidth hints that can help us
// eliminate certain routes early on in the path finding process.
bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth,
r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
)
if err != nil {
return nil, err
@ -1752,22 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// execute our path finding algorithm.
finalHtlcExpiry := currentHeight + int32(finalExpiry)
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
path, err := findPath(
&graphParams{
additionalEdges: routeHints,
bandwidthHints: bandwidthHints,
graph: routingTx,
graph: r.cachedGraph,
},
restrictions,
&r.cfg.PathFindingConfig,
@ -2505,8 +2501,10 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) (
// within the graph.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.LightningNode, error) {
return r.cfg.Graph.FetchLightningNode(nil, node)
func (r *ChannelRouter) FetchLightningNode(
node route.Vertex) (*channeldb.LightningNode, error) {
return r.cfg.Graph.FetchLightningNode(node)
}
// ForEachNode is used to iterate over every node in router topology.
@ -2661,19 +2659,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
// these hints allows us to reduce the number of extraneous attempts as we can
// skip channels that are inactive, or just don't have enough bandwidth to
// carry the payment.
func generateBandwidthHints(sourceNode *channeldb.LightningNode,
queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) {
func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph,
queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) (
map[uint64]lnwire.MilliSatoshi, error) {
// First, we'll collect the set of outbound edges from the target
// source node.
var localChans []*channeldb.ChannelEdgeInfo
err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx,
edgeInfo *channeldb.ChannelEdgeInfo,
_, _ *channeldb.ChannelEdgePolicy) error {
localChans = append(localChans, edgeInfo)
return nil
})
var localChans []*channeldb.DirectedChannel
err := graph.forEachNodeChannel(
sourceNode, func(channel *channeldb.DirectedChannel) error {
localChans = append(localChans, channel)
return nil
},
)
if err != nil {
return nil, err
}
@ -2726,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// We'll attempt to obtain a set of bandwidth hints that helps us select
// the best outgoing channel to use in case no outgoing channel is set.
bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth,
r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
)
if err != nil {
return nil, err
@ -2756,18 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
runningAmt = *amt
}
// Open a transaction to execute the graph queries in.
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
// Traverse hops backwards to accumulate fees in the running amounts.
source := r.selfNode.PubKeyBytes
for i := len(hops) - 1; i >= 0; i-- {
@ -2786,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// known in the graph.
u := newUnifiedPolicies(source, toNode, outgoingChans)
err := u.addGraphPolicies(routingTx)
err := u.addGraphPolicies(r.cachedGraph)
if err != nil {
return nil, err
}
@ -2832,7 +2818,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// total amount, we make a forward pass. Because the amount may have
// been increased in the backward pass, fees need to be recalculated and
// amount ranges re-checked.
var pathEdges []*channeldb.ChannelEdgePolicy
var pathEdges []*channeldb.CachedEdgePolicy
receiverAmt := runningAmt
for i, edge := range edges {
policy := edge.getPolicy(receiverAmt, bandwidthHints)

View File

@ -125,17 +125,19 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
}
mc, err := NewMissionControl(
graphInstance.graph.Database(), route.Vertex{},
mcConfig,
graphInstance.graphBackend, route.Vertex{}, mcConfig,
)
require.NoError(t, err, "failed to create missioncontrol")
sessionSource := &SessionSource{
Graph: graphInstance.graph,
QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
cachedGraph, err := NewCachedGraph(graphInstance.graph)
require.NoError(t, err)
return lnwire.NewMSatFromSatoshis(e.Capacity)
sessionSource := &SessionSource{
Graph: cachedGraph,
QueryBandwidth: func(
c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(c.Capacity)
},
PathFindingConfig: pathFindingConfig,
MissionControl: mc,
@ -159,7 +161,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
e *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity)
},
@ -188,7 +190,6 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
cleanUp := func() {
ctx.router.Stop()
graphInstance.cleanUp()
}
return ctx, cleanUp
@ -197,17 +198,10 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
func createTestCtxSingleNode(t *testing.T,
startingHeight uint32) (*testCtx, func()) {
var (
graph *channeldb.ChannelGraph
sourceNode *channeldb.LightningNode
cleanup func()
err error
)
graph, cleanup, err = makeTestGraph()
graph, graphBackend, cleanup, err := makeTestGraph()
require.NoError(t, err, "failed to make test graph")
sourceNode, err = createTestNode()
sourceNode, err := createTestNode()
require.NoError(t, err, "failed to create test node")
require.NoError(t,
@ -215,8 +209,9 @@ func createTestCtxSingleNode(t *testing.T,
)
graphInstance := &testGraphInstance{
graph: graph,
cleanUp: cleanup,
graph: graph,
graphBackend: graphBackend,
cleanUp: cleanup,
}
return createTestCtxFromGraphInstance(
@ -1401,6 +1396,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
}
edgePolicy.ChannelFlags = 0
@ -1417,6 +1415,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
}
edgePolicy.ChannelFlags = 1
@ -1498,6 +1499,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
}
edgePolicy.ChannelFlags = 0
@ -1513,6 +1517,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
}
edgePolicy.ChannelFlags = 1
@ -1577,7 +1584,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
t.Fatalf("unable to find any routes: %v", err)
}
copy1, err := ctx.graph.FetchLightningNode(nil, pub1)
copy1, err := ctx.graph.FetchLightningNode(pub1)
if err != nil {
t.Fatalf("unable to fetch node: %v", err)
}
@ -1586,7 +1593,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
t.Fatalf("fetched node not equal to original")
}
copy2, err := ctx.graph.FetchLightningNode(nil, pub2)
copy2, err := ctx.graph.FetchLightningNode(pub2)
if err != nil {
t.Fatalf("unable to fetch node: %v", err)
}
@ -2474,8 +2481,8 @@ func TestFindPathFeeWeighting(t *testing.T) {
if len(path) != 1 {
t.Fatalf("expected path length of 1, instead was: %v", len(path))
}
if path[0].Node.Alias != "luoji" {
t.Fatalf("wrong node: %v", path[0].Node.Alias)
if path[0].ToNodePubKey() != ctx.aliases["luoji"] {
t.Fatalf("wrong node: %v", path[0].ToNodePubKey())
}
}

View File

@ -40,7 +40,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex,
// addPolicy adds a single channel policy. Capacity may be zero if unknown
// (light clients).
func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) {
edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) {
localChan := fromNode == u.sourceNode
@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the
// graph.
func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _,
inEdge *channeldb.ChannelEdgePolicy) error {
cb := func(channel *channeldb.DirectedChannel) error {
// If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have
// come prior to the pivot node in the route.
if inEdge == nil {
if channel.InPolicy == nil {
return nil
}
// The node on the other end of this channel is the from node.
fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:])
if err != nil {
return err
}
// Add this policy to the unified policies map.
u.addPolicy(fromNode, inEdge, edgeInfo.Capacity)
u.addPolicy(
channel.OtherNode, channel.InPolicy, channel.Capacity,
)
return nil
}
@ -98,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
// unifiedPolicyEdge is the individual channel data that is kept inside an
// unifiedPolicy object.
type unifiedPolicyEdge struct {
policy *channeldb.ChannelEdgePolicy
policy *channeldb.CachedEdgePolicy
capacity btcutil.Amount
}
@ -139,7 +133,7 @@ type unifiedPolicy struct {
// specific amount to send. It differentiates between local and network
// channels.
func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi,
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy {
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
if u.localChan {
return u.getPolicyLocal(amt, bandwidthHints)
@ -151,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi,
// getPolicyLocal returns the optimal policy to use for this local connection
// given a specific amount to send.
func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi,
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy {
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
var (
bestPolicy *channeldb.ChannelEdgePolicy
bestPolicy *channeldb.CachedEdgePolicy
maxBandwidth lnwire.MilliSatoshi
)
@ -206,10 +200,10 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi,
// a specific amount to send. The goal is to return a policy that maximizes the
// probability of a successful forward in a non-strict forwarding context.
func (u *unifiedPolicy) getPolicyNetwork(
amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy {
amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
var (
bestPolicy *channeldb.ChannelEdgePolicy
bestPolicy *channeldb.CachedEdgePolicy
maxFee lnwire.MilliSatoshi
maxTimelock uint16
)

View File

@ -20,7 +20,7 @@ func TestUnifiedPolicies(t *testing.T) {
u := newUnifiedPolicies(source, toNode, nil)
// Add two channels between the pair of nodes.
p1 := channeldb.ChannelEdgePolicy{
p1 := channeldb.CachedEdgePolicy{
FeeProportionalMillionths: 100000,
FeeBaseMSat: 30,
TimeLockDelta: 60,
@ -28,7 +28,7 @@ func TestUnifiedPolicies(t *testing.T) {
MaxHTLC: 500,
MinHTLC: 100,
}
p2 := channeldb.ChannelEdgePolicy{
p2 := channeldb.CachedEdgePolicy{
FeeProportionalMillionths: 190000,
FeeBaseMSat: 10,
TimeLockDelta: 40,
@ -39,7 +39,7 @@ func TestUnifiedPolicies(t *testing.T) {
u.addPolicy(fromNode, &p1, 7)
u.addPolicy(fromNode, &p2, 7)
checkPolicy := func(policy *channeldb.ChannelEdgePolicy,
checkPolicy := func(policy *channeldb.CachedEdgePolicy,
feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi,
timeLockDelta uint16) {

View File

@ -3989,7 +3989,7 @@ func (r *rpcServer) createRPCClosedChannel(
CloseInitiator: closeInitiator,
}
reports, err := r.server.chanStateDB.FetchChannelReports(
reports, err := r.server.miscDB.FetchChannelReports(
*r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint,
)
switch err {
@ -5152,7 +5152,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context,
PendingOnly: req.PendingOnly,
Reversed: req.Reversed,
}
invoiceSlice, err := r.server.chanStateDB.QueryInvoices(q)
invoiceSlice, err := r.server.miscDB.QueryInvoices(q)
if err != nil {
return nil, fmt.Errorf("unable to query invoices: %v", err)
}
@ -5549,7 +5549,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context,
// With the public key decoded, attempt to fetch the node corresponding
// to this public key. If the node cannot be found, then an error will
// be returned.
node, err := graph.FetchLightningNode(nil, pubKey)
node, err := graph.FetchLightningNode(pubKey)
switch {
case err == channeldb.ErrGraphNodeNotFound:
return nil, status.Error(codes.NotFound, err.Error())
@ -5954,7 +5954,7 @@ func (r *rpcServer) ListPayments(ctx context.Context,
query.MaxPayments = math.MaxUint64
}
paymentsQuerySlice, err := r.server.chanStateDB.QueryPayments(query)
paymentsQuerySlice, err := r.server.miscDB.QueryPayments(query)
if err != nil {
return nil, err
}
@ -5995,9 +5995,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context,
rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+
"failed_htlcs_only=%v", hash, req.FailedHtlcsOnly)
err = r.server.chanStateDB.DeletePayment(
hash, req.FailedHtlcsOnly,
)
err = r.server.miscDB.DeletePayment(hash, req.FailedHtlcsOnly)
if err != nil {
return nil, err
}
@ -6014,7 +6012,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context,
"failed_htlcs_only=%v", req.FailedPaymentsOnly,
req.FailedHtlcsOnly)
err := r.server.chanStateDB.DeletePayments(
err := r.server.miscDB.DeletePayments(
req.FailedPaymentsOnly, req.FailedHtlcsOnly,
)
if err != nil {
@ -6176,7 +6174,7 @@ func (r *rpcServer) FeeReport(ctx context.Context,
return nil, err
}
fwdEventLog := r.server.chanStateDB.ForwardingLog()
fwdEventLog := r.server.miscDB.ForwardingLog()
// computeFeeSum is a helper function that computes the total fees for
// a particular time slice described by a forwarding event query.
@ -6417,7 +6415,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context,
IndexOffset: req.IndexOffset,
NumMaxEvents: numEvents,
}
timeSlice, err := r.server.chanStateDB.ForwardingLog().Query(eventQuery)
timeSlice, err := r.server.miscDB.ForwardingLog().Query(eventQuery)
if err != nil {
return nil, fmt.Errorf("unable to query forwarding log: %v", err)
}
@ -6479,7 +6477,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context,
// the database. If this channel has been closed, or the outpoint is
// unknown, then we'll return an error
unpackedBackup, err := chanbackup.FetchBackupForChan(
chanPoint, r.server.chanStateDB,
chanPoint, r.server.chanStateDB, r.server.addrSource,
)
if err != nil {
return nil, err
@ -6649,7 +6647,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context,
// First, we'll attempt to read back ups for ALL currently opened
// channels from disk.
allUnpackedBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB,
r.server.chanStateDB, r.server.addrSource,
)
if err != nil {
return nil, fmt.Errorf("unable to fetch all static chan "+
@ -6776,7 +6774,7 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription
// we'll obtains the current set of single channel
// backups from disk.
chanBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB,
r.server.chanStateDB, r.server.addrSource,
)
if err != nil {
return fmt.Errorf("unable to fetch all "+

View File

@ -222,7 +222,13 @@ type server struct {
graphDB *channeldb.ChannelGraph
chanStateDB *channeldb.DB
chanStateDB *channeldb.ChannelStateDB
addrSource chanbackup.AddressSource
// miscDB is the DB that contains all "other" databases within the main
// channel DB that haven't been separated out yet.
miscDB *channeldb.DB
htlcSwitch *htlcswitch.Switch
@ -432,14 +438,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s := &server{
cfg: cfg,
graphDB: dbs.graphDB.ChannelGraph(),
chanStateDB: dbs.chanStateDB,
chanStateDB: dbs.chanStateDB.ChannelStateDB(),
addrSource: dbs.chanStateDB,
miscDB: dbs.chanStateDB,
cc: cc,
sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer),
writePool: writePool,
readPool: readPool,
chansToRestore: chansToRestore,
channelNotifier: channelnotifier.New(dbs.chanStateDB),
channelNotifier: channelnotifier.New(
dbs.chanStateDB.ChannelStateDB(),
),
identityECDH: nodeKeyECDH,
nodeSigner: netann.NewNodeSigner(nodeKeySigner),
@ -494,7 +504,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats)
s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
DB: dbs.chanStateDB,
DB: dbs.chanStateDB,
FetchAllOpenChannels: s.chanStateDB.FetchAllOpenChannels,
FetchClosedChannels: s.chanStateDB.FetchClosedChannels,
LocalChannelClose: func(pubKey []byte,
request *htlcswitch.ChanClose) {
@ -537,7 +549,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MessageSigner: s.nodeSigner,
IsChannelActive: s.htlcSwitch.HasActiveLink,
ApplyChannelUpdate: s.applyChannelUpdate,
DB: dbs.chanStateDB,
DB: s.chanStateDB,
Graph: dbs.graphDB.ChannelGraph(),
}
@ -702,9 +714,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, err
}
queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint)
link, err := s.htlcSwitch.GetLink(cid)
queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
cid := lnwire.NewShortChanIDFromInt(c.ChannelID)
link, err := s.htlcSwitch.GetLinkByShortID(cid)
if err != nil {
// If the link isn't online, then we'll report
// that it has zero bandwidth to the router.
@ -768,8 +780,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MinProbability: routingConfig.MinRouteProbability,
}
cachedGraph, err := routing.NewCachedGraph(chanGraph)
if err != nil {
return nil, err
}
paymentSessionSource := &routing.SessionSource{
Graph: chanGraph,
Graph: cachedGraph,
MissionControl: s.missionControl,
QueryBandwidth: queryBandwidth,
PathFindingConfig: pathFindingConfig,
@ -805,11 +821,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
}
chanSeries := discovery.NewChanSeries(s.graphDB)
gossipMessageStore, err := discovery.NewMessageStore(s.chanStateDB)
gossipMessageStore, err := discovery.NewMessageStore(dbs.chanStateDB)
if err != nil {
return nil, err
}
waitingProofStore, err := channeldb.NewWaitingProofStore(s.chanStateDB)
waitingProofStore, err := channeldb.NewWaitingProofStore(dbs.chanStateDB)
if err != nil {
return nil, err
}
@ -891,8 +907,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{
ChainIO: cc.ChainIO,
ConfDepth: 1,
FetchClosedChannels: dbs.chanStateDB.FetchClosedChannels,
FetchClosedChannel: dbs.chanStateDB.FetchClosedChannel,
FetchClosedChannels: s.chanStateDB.FetchClosedChannels,
FetchClosedChannel: s.chanStateDB.FetchClosedChannel,
Notifier: cc.ChainNotifier,
PublishTransaction: cc.Wallet.PublishTransaction,
Store: utxnStore,
@ -1018,7 +1034,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{
CloseLink: closeLink,
DB: dbs.chanStateDB,
DB: s.chanStateDB,
Estimator: s.cc.FeeEstimator,
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
Notifier: cc.ChainNotifier,
@ -1075,7 +1091,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
FindChannel: func(chanID lnwire.ChannelID) (
*channeldb.OpenChannel, error) {
dbChannels, err := dbs.chanStateDB.FetchAllChannels()
dbChannels, err := s.chanStateDB.FetchAllChannels()
if err != nil {
return nil, err
}
@ -1247,10 +1263,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
// static backup of the latest channel state.
chanNotifier := &channelNotifier{
chanNotifier: s.channelNotifier,
addrs: s.chanStateDB,
addrs: dbs.chanStateDB,
}
backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath)
startingChans, err := chanbackup.FetchStaticChanBackups(s.chanStateDB)
startingChans, err := chanbackup.FetchStaticChanBackups(
s.chanStateDB, s.addrSource,
)
if err != nil {
return nil, err
}
@ -1275,8 +1293,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
},
GetOpenChannels: s.chanStateDB.FetchAllOpenChannels,
Clock: clock.NewDefaultClock(),
ReadFlapCount: s.chanStateDB.ReadFlapCount,
WriteFlapCount: s.chanStateDB.WriteFlapCounts,
ReadFlapCount: s.miscDB.ReadFlapCount,
WriteFlapCount: s.miscDB.WriteFlapCounts,
FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate),
})
@ -2531,7 +2549,7 @@ func (s *server) establishPersistentConnections() error {
// Iterate through the list of LinkNodes to find addresses we should
// attempt to connect to based on our set of previous connections. Set
// the reconnection port to the default peer port.
linkNodes, err := s.chanStateDB.FetchAllLinkNodes()
linkNodes, err := s.chanStateDB.LinkNodeDB().FetchAllLinkNodes()
if err != nil && err != channeldb.ErrLinkNodesNotFound {
return err
}
@ -3911,7 +3929,7 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error)
return nil, err
}
node, err := s.graphDB.FetchLightningNode(nil, vertex)
node, err := s.graphDB.FetchLightningNode(vertex)
if err != nil {
return nil, err
}

View File

@ -93,7 +93,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
routerBackend *routerrpc.RouterBackend,
nodeSigner *netann.NodeSigner,
graphDB *channeldb.ChannelGraph,
chanStateDB *channeldb.DB,
chanStateDB *channeldb.ChannelStateDB,
sweeper *sweep.UtxoSweeper,
tower *watchtower.Standalone,
towerClient wtclient.Client,