Merge pull request #5640 from bhandras/kvdb-prefetch

kvdb+channeldb: extend `kvdb` with `Prefetch` for prefetching buckets in one go and speed up payment control by prefetching payments on hot paths
Oliver Gugger 2021-09-20 09:42:18 +02:00 committed by GitHub
commit 29a8661517
12 changed files with 798 additions and 151 deletions
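Before the per-file diffs, a sketch of the hot-path pattern this change introduces in PaymentControl, assembled from the hunks below. The helpers, fields, and bucket constants are the ones defined in the diff; the method name registerAttemptSketch and the abbreviated closure body are illustrative only (package and imports as in channeldb):

// registerAttemptSketch shows the shape of every updated hot path: hint the
// backend first, then perform the usual bucket lookups against the (possibly)
// prefetched read set.
func (p *PaymentControl) registerAttemptSketch(paymentHash lntypes.Hash) error {
	return kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
		// One prefetch round trip loads the payment bucket and its
		// htlcs sub-bucket; on backends without prefetch support this
		// is a no-op.
		prefetchPayment(tx, paymentHash)

		bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
		if err != nil {
			return err
		}

		// Subsequent reads of bucket now hit the prefetched data.
		_ = bucket
		return nil
	})
}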


@ -131,6 +131,7 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
// from a previous execution of the batched db transaction.
updateErr = nil
prefetchPayment(tx, paymentHash)
bucket, err := createPaymentBucket(tx, paymentHash)
if err != nil {
return err
@ -292,6 +293,7 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
var payment *MPPayment
err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err != nil {
return err
@ -430,6 +432,7 @@ func (p *PaymentControl) updateHtlcKey(paymentHash lntypes.Hash,
err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
payment = nil
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err != nil {
return err
@ -500,6 +503,7 @@ func (p *PaymentControl) Fail(paymentHash lntypes.Hash,
updateErr = nil
payment = nil
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err == ErrPaymentNotInitiated {
updateErr = ErrPaymentNotInitiated
@ -550,6 +554,7 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
var payment *MPPayment
err := kvdb.View(p.db, func(tx kvdb.RTx) error {
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return err
@ -568,6 +573,26 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
return payment, nil
}
// prefetchPayment attempts to prefetch as much of the payment as possible to
// reduce DB roundtrips.
func prefetchPayment(tx kvdb.RTx, paymentHash lntypes.Hash) {
rb := kvdb.RootBucket(tx)
kvdb.Prefetch(
rb,
[]string{
// Prefetch all keys in the payment's bucket.
string(paymentsRootBucket),
string(paymentHash[:]),
},
[]string{
// Prefetch all keys in the payment's htlc bucket.
string(paymentsRootBucket),
string(paymentHash[:]),
string(paymentHtlcsBucket),
},
)
}
// createPaymentBucket creates or fetches the sub-bucket assigned to this
// payment hash.
func createPaymentBucket(tx kvdb.RwTx, paymentHash lntypes.Hash) (


@ -267,6 +267,9 @@ you.
payments. Deleting all failed payments beforehand makes migration safer and
faster too.
* [Prefetch payments on hot paths](https://github.com/lightningnetwork/lnd/pull/5640)
to reduce roundtrips to the remote DB backend.
## Performance improvements
* [Update MC store in blocks](https://github.com/lightningnetwork/lnd/pull/5515)


@ -63,6 +63,10 @@ func TestBolt(t *testing.T) {
name: "tx rollback",
test: testTxRollback,
},
{
name: "prefetch",
test: testPrefetch,
},
}
for _, test := range tests {


@ -219,7 +219,8 @@ func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil))
}
return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
_, err := RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
return err
}
// Update opens a database read/write transaction and executes the function f
@ -240,7 +241,8 @@ func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error
return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil))
}
return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
_, err := RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
return err
}
// PrintStats returns all collected stats pretty printed into a string.


@ -371,3 +371,37 @@ func (b *readWriteBucket) Sequence() uint64 {
return num
}
func flattenMap(m map[string]struct{}) []string {
result := make([]string, len(m))
i := 0
for key := range m {
result[i] = key
i++
}
return result
}
// Prefetch will prefetch all keys in the passed paths as well as all bucket
// keys along the paths.
func (b *readWriteBucket) Prefetch(paths ...[]string) {
keys := make(map[string]struct{})
ranges := make(map[string]struct{})
for _, path := range paths {
parent := b.id
for _, bucket := range path {
bucketKey := makeBucketKey(parent, []byte(bucket))
keys[string(bucketKey[:])] = struct{}{}
id := makeBucketID(bucketKey)
parent = id[:]
}
ranges[string(parent)] = struct{}{}
}
b.tx.stm.Prefetch(flattenMap(keys), flattenMap(ranges))
}
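To illustrate what one logical path expands to here: each bucket key along the path is prefetched as an exact key, and the final bucket's id is prefetched as a whole range. The snippet below is a self-contained toy with a stand-in key scheme, not the backend's real makeBucketKey/makeBucketID:

package main

import (
	"crypto/sha256"
	"fmt"
)

// toyBucketKey and toyBucketID stand in for the etcd backend's real helpers;
// only the shape of the expansion matters here, not the exact encoding.
func toyBucketKey(parent []byte, bucket string) []byte {
	return append(append([]byte{}, parent...), []byte(bucket)...)
}

func toyBucketID(bucketKey []byte) [32]byte {
	return sha256.Sum256(bucketKey)
}

func main() {
	parent := []byte("root")
	path := []string{"top", "bucket"}

	// Every bucket key along the path becomes an exact key to prefetch.
	keys := make(map[string]struct{})
	for _, bucket := range path {
		bucketKey := toyBucketKey(parent, bucket)
		keys[string(bucketKey)] = struct{}{}

		id := toyBucketID(bucketKey)
		parent = id[:]
	}

	// The leaf bucket's id becomes a full-range prefix to prefetch.
	fmt.Printf("exact keys: %d, range prefix: %x...\n", len(keys), parent[:4])
}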


@ -41,6 +41,13 @@ func rootBucket(tx *readWriteTx) *readWriteBucket {
return newReadWriteBucket(tx, tx.rootBucketID[:], tx.rootBucketID[:])
}
// RootBucket will return a handle to the root bucket. This is not a real handle
// but just a wrapper around the root bucket ID to allow derivation of child
// keys.
func (tx *readWriteTx) RootBucket() walletdb.ReadBucket {
return rootBucket(tx)
}
// ReadBucket opens the root bucket for read only access. If the bucket
// described by the key does not exist, nil is returned.
func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket {


@ -8,6 +8,8 @@ import (
"math"
"strings"
"github.com/google/btree"
pb "go.etcd.io/etcd/api/v3/etcdserverpb"
v3 "go.etcd.io/etcd/client/v3"
)
@ -71,9 +73,13 @@ type STM interface {
// Commit may return CommitError if transaction is outdated and needs retry.
Commit() error
// Rollback emties the read and write sets such that a subsequent commit
// Rollback empties the read and write sets such that a subsequent commit
// won't alter the database.
Rollback()
// Prefetch prefetches the passed keys and prefixes. For prefixes it'll
// fetch the whole range.
Prefetch(keys []string, prefix []string)
}
// CommitError is used to check if there was an error
@ -104,15 +110,26 @@ func (e DatabaseError) Error() string {
return fmt.Sprintf("etcd error: %v - %v", e.msg, e.err)
}
// stmGet is the result of a read operation,
// a value and the mod revision of the key/value.
// stmGet is the result of a read operation, a value and the mod revision of the
// key/value.
type stmGet struct {
val string
KV
rev int64
}
// Less implements less operator for btree.BTree.
func (c *stmGet) Less(than btree.Item) bool {
return c.key < than.(*stmGet).key
}
// readSet stores all reads done in an STM.
type readSet map[string]stmGet
type readSet struct {
// tree stores the items in the read set.
tree *btree.BTree
// fullRanges stores full range prefixes.
fullRanges map[string]struct{}
}
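Moving the read set from a plain map to a B-tree is what lets Next/Prev/Last be answered from a prefetched range without further etcd round trips. A standalone sketch of that ordered-lookup pattern, using the same github.com/google/btree dependency but toy types rather than the stm package's own:

package main

import (
	"fmt"
	"strings"

	"github.com/google/btree"
)

// kvItem is a toy stand-in for stmGet: items are ordered by key so the tree
// can answer "next key under this prefix" with an in-memory scan.
type kvItem struct {
	key, val string
}

func (i *kvItem) Less(than btree.Item) bool {
	return i.key < than.(*kvItem).key
}

func main() {
	tree := btree.New(5)
	for _, k := range []string{"ka", "kb", "kc", "w"} {
		tree.ReplaceOrInsert(&kvItem{key: k, val: strings.ToUpper(k)})
	}

	// Find the first key after "ka" that still carries the "k" prefix,
	// the same ordered scan readSet.next performs on prefetched ranges.
	var next *kvItem
	tree.AscendGreaterOrEqual(&kvItem{key: "ka"}, func(it btree.Item) bool {
		cand := it.(*kvItem)
		if cand.key == "ka" {
			return true // skip the pivot itself
		}
		if strings.HasPrefix(cand.key, "k") {
			next = cand
		}
		return false
	})

	fmt.Println(next.key, next.val) // kb KB
}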
// stmPut stores a value and an operation (put/delete).
type stmPut struct {
@ -141,11 +158,8 @@ type stm struct {
// options stores optional settings passed by the user.
options *STMOptions
// prefetch hold prefetched key values and revisions.
prefetch readSet
// rset holds read key values and revisions.
rset readSet
rset *readSet
// wset holds overwritten keys and their values.
wset writeSet
@ -158,6 +172,9 @@ type stm struct {
// onCommit gets called upon commit.
onCommit func()
// callCount tracks the number of times we called into etcd.
callCount int
}
// STMOptions can be used to pass optional settings
@ -188,9 +205,12 @@ func WithCommitStatsCallback(cb func(bool, CommitStats)) STMOptionFunc {
// RunSTM runs the apply function by creating an STM using serializable snapshot
// isolation, passing it to the apply and handling commit errors and retries.
func RunSTM(cli *v3.Client, apply func(STM) error, txQueue *commitQueue,
so ...STMOptionFunc) error {
so ...STMOptionFunc) (int, error) {
return runSTM(makeSTM(cli, false, txQueue, so...), apply)
stm := makeSTM(cli, false, txQueue, so...)
err := runSTM(stm, apply)
return stm.callCount, err
}
// NewSTM creates a new STM instance, using serializable snapshot isolation.
@ -213,15 +233,15 @@ func makeSTM(cli *v3.Client, manual bool, txQueue *commitQueue,
}
s := &stm{
client: cli,
manual: manual,
txQueue: txQueue,
options: opts,
prefetch: make(map[string]stmGet),
client: cli,
manual: manual,
txQueue: txQueue,
options: opts,
rset: newReadSet(),
}
// Reset read and write set.
s.Rollback()
s.rollback(true)
return s
}
@ -262,8 +282,11 @@ func runSTM(s *stm, apply func(STM) error) error {
return
}
// Rollback before trying to re-apply.
s.Rollback()
// Rollback the write set before trying to re-apply.
// Upon commit we retrieved the latest version of all
// previously fetched keys and ranges so we don't need
// to rollback the read set.
s.rollback(false)
retries++
// Re-apply the transaction closure.
@ -287,14 +310,16 @@ func runSTM(s *stm, apply func(STM) error) error {
// result in queueing up transactions and contending DB access.
// Copying these strings is cheap due to Go's immutable string which is
// always a reference.
rkeys := make([]string, len(s.rset))
rkeys := make([]string, s.rset.tree.Len())
wkeys := make([]string, len(s.wset))
i := 0
for key := range s.rset {
rkeys[i] = key
s.rset.tree.Ascend(func(item btree.Item) bool {
rkeys[i] = item.(*stmGet).key
i++
}
return true
})
i = 0
for key := range s.wset {
@ -320,42 +345,225 @@ func runSTM(s *stm, apply func(STM) error) error {
return executeErr
}
// add inserts a txn response to the read set. This is useful when the txn
// fails due to conflict where the txn response can be used to prefetch
// key/values.
func (rs readSet) add(txnResp *v3.TxnResponse) {
for _, resp := range txnResp.Responses {
getResp := (*v3.GetResponse)(resp.GetResponseRange())
func newReadSet() *readSet {
return &readSet{
tree: btree.New(5),
fullRanges: make(map[string]struct{}),
}
}
// add inserts key/values to the read set.
func (rs *readSet) add(responses []*pb.ResponseOp) {
for _, resp := range responses {
getResp := resp.GetResponseRange()
for _, kv := range getResp.Kvs {
rs[string(kv.Key)] = stmGet{
val: string(kv.Value),
rev: kv.ModRevision,
}
rs.addItem(
string(kv.Key), string(kv.Value), kv.ModRevision,
)
}
}
}
// gets is a helper to create an op slice for transaction
// construction.
func (rs readSet) gets() []v3.Op {
ops := make([]v3.Op, 0, len(rs))
// addFullRange adds all full ranges to the read set.
func (rs *readSet) addFullRange(prefixes []string, responses []*pb.ResponseOp) {
for i, resp := range responses {
getResp := resp.GetResponseRange()
for _, kv := range getResp.Kvs {
rs.addItem(
string(kv.Key), string(kv.Value), kv.ModRevision,
)
}
for k := range rs {
ops = append(ops, v3.OpGet(k))
rs.fullRanges[prefixes[i]] = struct{}{}
}
}
// presetItem presets a key to zero revision if not already present in the read
// set.
func (rs *readSet) presetItem(key string) {
item := &stmGet{
KV: KV{
key: key,
},
rev: 0,
}
if !rs.tree.Has(item) {
rs.tree.ReplaceOrInsert(item)
}
}
// addItem adds a single new key/value to the read set (if not already present).
func (rs *readSet) addItem(key, val string, modRevision int64) {
item := &stmGet{
KV: KV{
key: key,
val: val,
},
rev: modRevision,
}
rs.tree.ReplaceOrInsert(item)
}
// hasFullRange checks if the read set has a full range prefetched.
func (rs *readSet) hasFullRange(prefix string) bool {
_, ok := rs.fullRanges[prefix]
return ok
}
// next returns the pre-fetched next value of the prefix. If matchKey is true,
// it'll simply return the key/value that matches the passed key.
func (rs *readSet) next(prefix, key string, matchKey bool) (*stmGet, bool) {
pivot := &stmGet{
KV: KV{
key: key,
},
}
var result *stmGet
rs.tree.AscendGreaterOrEqual(
pivot,
func(item btree.Item) bool {
next := item.(*stmGet)
if (!matchKey && next.key == key) || next.rev == 0 {
return true
}
if strings.HasPrefix(next.key, prefix) {
result = next
}
return false
},
)
return result, result != nil
}
// prev returns the pre-fetched prev key/value of the prefix from key.
func (rs *readSet) prev(prefix, key string) (*stmGet, bool) {
pivot := &stmGet{
KV: KV{
key: key,
},
}
var result *stmGet
rs.tree.DescendLessOrEqual(
pivot, func(item btree.Item) bool {
prev := item.(*stmGet)
if prev.key == key || prev.rev == 0 {
return true
}
if strings.HasPrefix(prev.key, prefix) {
result = prev
}
return false
},
)
return result, result != nil
}
// last returns the last key/value of the passed range (if prefetched).
func (rs *readSet) last(prefix string) (*stmGet, bool) {
// We create an artificial key here that is just one step away from the
// prefix. This way when we try to get the first item with our prefix
// before this newly crafted key we'll make sure it's the last element
// of our range.
key := []byte(prefix)
key[len(key)-1] += 1
return rs.prev(prefix, string(key))
}
// clear completely clears the readset.
func (rs *readSet) clear() {
rs.tree.Clear(false)
rs.fullRanges = make(map[string]struct{})
}
// getItem returns the matching key/value from the readset.
func (rs *readSet) getItem(key string) (*stmGet, bool) {
pivot := &stmGet{
KV: KV{
key: key,
},
rev: 0,
}
item := rs.tree.Get(pivot)
if item != nil {
return item.(*stmGet), true
}
// It's possible that although this key isn't in the read set, we
// fetched a full range the key is prefixed with. In this case we'll
// insert the key with zero revision.
for prefix := range rs.fullRanges {
if strings.HasPrefix(key, prefix) {
rs.tree.ReplaceOrInsert(pivot)
return pivot, true
}
}
return nil, false
}
// prefetchSet is a helper to create an op slice of all OpGet's that represent
// fetched keys appended with a slice of all OpGet's representing all prefetched
// full ranges.
func (rs *readSet) prefetchSet() []v3.Op {
ops := make([]v3.Op, 0, rs.tree.Len())
rs.tree.Ascend(func(item btree.Item) bool {
key := item.(*stmGet).key
for prefix := range rs.fullRanges {
// Do not add the key if it has been prefetched in a
// full range.
if strings.HasPrefix(key, prefix) {
return true
}
}
ops = append(ops, v3.OpGet(key))
return true
})
for prefix := range rs.fullRanges {
ops = append(ops, v3.OpGet(prefix, v3.WithPrefix()))
}
return ops
}
// getFullRanges returns all prefixes that we prefetched.
func (rs *readSet) getFullRanges() []string {
prefixes := make([]string, 0, len(rs.fullRanges))
for prefix := range rs.fullRanges {
prefixes = append(prefixes, prefix)
}
return prefixes
}
// cmps returns a compare list which will serve as a precondition testing that
// the values in the read set didn't change.
func (rs readSet) cmps() []v3.Cmp {
cmps := make([]v3.Cmp, 0, len(rs))
for key, getValue := range rs {
cmps = append(cmps, v3.Compare(
v3.ModRevision(key), "=", getValue.rev,
))
}
func (rs *readSet) cmps() []v3.Cmp {
cmps := make([]v3.Cmp, 0, rs.tree.Len())
rs.tree.Ascend(func(item btree.Item) bool {
get := item.(*stmGet)
cmps = append(
cmps, v3.Compare(v3.ModRevision(get.key), "=", get.rev),
)
return true
})
return cmps
}
@ -384,6 +592,7 @@ func (ws writeSet) puts() []v3.Op {
// then fetch will try to fix the STM's snapshot revision (if not already set).
// We'll also cache the returned key/value in the read set.
func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
s.callCount++
resp, err := s.client.Get(
s.options.ctx, key, append(opts, s.getOpts...)...,
)
@ -394,7 +603,7 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
}
}
// Set revison and serializable options upon first fetch
// Set revision and serializable options upon first fetch
// for any subsequent fetches.
if s.getOpts == nil {
s.revision = resp.Header.Revision
@ -408,26 +617,18 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
// Add assertion to the read set which will extend our commit
// constraint such that the commit will fail if the key is
// present in the database.
s.rset[key] = stmGet{
rev: 0,
}
s.rset.addItem(key, "", 0)
}
var result []KV
// Fill the read set with key/values returned.
for _, kv := range resp.Kvs {
// Remove from prefetch.
key := string(kv.Key)
val := string(kv.Value)
delete(s.prefetch, key)
// Add to read set.
s.rset[key] = stmGet{
val: val,
rev: kv.ModRevision,
}
s.rset.addItem(key, val, kv.ModRevision)
result = append(result, KV{key, val})
}
@ -452,20 +653,8 @@ func (s *stm) Get(key string) ([]byte, error) {
return []byte(put.val), nil
}
// Populate read set if key is present in
// the prefetch set.
if getValue, ok := s.prefetch[key]; ok {
delete(s.prefetch, key)
// Use the prefetched value only if it is for
// an existing key.
if getValue.rev != 0 {
s.rset[key] = getValue
}
}
// Return value if alread in read set.
if getValue, ok := s.rset[key]; ok {
if getValue, ok := s.rset.getItem(key); ok {
// Return the value if the rset contains an existing key.
if getValue.rev != 0 {
return []byte(getValue.val), nil
@ -497,21 +686,28 @@ func (s *stm) First(prefix string) (*KV, error) {
// Last returns the last key/value with prefix. If there's no key starting with
// prefix, Last will return nil.
func (s *stm) Last(prefix string) (*KV, error) {
// As we don't know the full range, fetch the last
// key/value with this prefix first.
resp, err := s.fetch(prefix, v3.WithLastKey()...)
if err != nil {
return nil, err
}
var (
kv KV
found bool
)
if len(resp) > 0 {
kv = resp[0]
found = true
if s.rset.hasFullRange(prefix) {
if item, ok := s.rset.last(prefix); ok {
kv = item.KV
found = true
}
} else {
// As we don't know the full range, fetch the last
// key/value with this prefix first.
resp, err := s.fetch(prefix, v3.WithLastKey()...)
if err != nil {
return nil, err
}
if len(resp) > 0 {
kv = resp[0]
found = true
}
}
// Now make sure there's nothing in the write set
@ -539,32 +735,41 @@ func (s *stm) Last(prefix string) (*KV, error) {
}
// Prev returns the prior key/value before key (with prefix). If there's no such
// key Next will return nil.
// key Prev will return nil.
func (s *stm) Prev(prefix, startKey string) (*KV, error) {
var result KV
var kv, result KV
fetchKey := startKey
matchFound := false
for {
// Ask etcd to retrieve one key that is a
// match in descending order from the passed key.
opts := []v3.OpOption{
v3.WithRange(fetchKey),
v3.WithSort(v3.SortByKey, v3.SortDescend),
v3.WithLimit(1),
}
if s.rset.hasFullRange(prefix) {
if item, ok := s.rset.prev(prefix, fetchKey); ok {
kv = item.KV
} else {
break
}
} else {
kvs, err := s.fetch(prefix, opts...)
if err != nil {
return nil, err
}
// Ask etcd to retrieve one key that is a
// match in descending order from the passed key.
opts := []v3.OpOption{
v3.WithRange(fetchKey),
v3.WithSort(v3.SortByKey, v3.SortDescend),
v3.WithLimit(1),
}
if len(kvs) == 0 {
break
}
kvs, err := s.fetch(prefix, opts...)
if err != nil {
return nil, err
}
kv := &kvs[0]
if len(kvs) == 0 {
break
}
kv = kvs[0]
}
// WithRange and WithPrefix can't be used
// together, so check prefix here. If the
@ -580,13 +785,13 @@ func (s *stm) Prev(prefix, startKey string) (*KV, error) {
continue
}
result = *kv
result = kv
matchFound = true
break
}
// Closre holding all checks to find a possibly
// Closure holding all checks to find a possibly
// better match.
matches := func(key string) bool {
if !strings.HasPrefix(key, prefix) {
@ -635,47 +840,60 @@ func (s *stm) Seek(prefix, key string) (*KV, error) {
// passed startKey. If includeStartKey is set to true, it'll return the value
// of startKey (essentially implementing seek).
func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) {
var result KV
var kv, result KV
fetchKey := startKey
firstFetch := true
matchFound := false
for {
// Ask etcd to retrieve one key that is a
// match in ascending order from the passed key.
opts := []v3.OpOption{
v3.WithFromKey(),
v3.WithSort(v3.SortByKey, v3.SortAscend),
v3.WithLimit(1),
}
// By default we include the start key too
// if it is a full match.
if includeStartKey && firstFetch {
if s.rset.hasFullRange(prefix) {
matchKey := includeStartKey && firstFetch
firstFetch = false
if item, ok := s.rset.next(
prefix, fetchKey, matchKey,
); ok {
kv = item.KV
} else {
break
}
} else {
// If we'd like to retrieve the first key
// after the start key.
fetchKey += "\x00"
}
// Ask etcd to retrieve one key that is a
// match in ascending order from the passed key.
opts := []v3.OpOption{
v3.WithFromKey(),
v3.WithSort(v3.SortByKey, v3.SortAscend),
v3.WithLimit(1),
}
kvs, err := s.fetch(fetchKey, opts...)
if err != nil {
return nil, err
}
// By default we include the start key too
// if it is a full match.
if includeStartKey && firstFetch {
firstFetch = false
} else {
// If we'd like to retrieve the first key
// after the start key.
fetchKey += "\x00"
}
if len(kvs) == 0 {
break
}
kvs, err := s.fetch(fetchKey, opts...)
if err != nil {
return nil, err
}
kv := &kvs[0]
// WithRange and WithPrefix can't be used
// together, so check prefix here. If the
// returned key no longer has the prefix,
// then break the fetch loop.
if !strings.HasPrefix(kv.key, prefix) {
break
if len(kvs) == 0 {
break
}
kv = kvs[0]
// WithRange and WithPrefix can't be used
// together, so check prefix here. If the
// returned key no longer has the prefix,
// then break the fetch loop.
if !strings.HasPrefix(kv.key, prefix) {
break
}
}
// Move on to fetch starting with the next
@ -685,7 +903,7 @@ func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) {
continue
}
result = *kv
result = kv
matchFound = true
break
@ -753,6 +971,72 @@ func (s *stm) OnCommit(cb func()) {
s.onCommit = cb
}
// Prefetch will prefetch the passed keys and prefixes in one transaction.
// Keys and prefixes that we already have will be skipped.
func (s *stm) Prefetch(keys []string, prefixes []string) {
fetchKeys := make([]string, 0, len(keys))
for _, key := range keys {
if _, ok := s.rset.getItem(key); !ok {
fetchKeys = append(fetchKeys, key)
}
}
fetchPrefixes := make([]string, 0, len(prefixes))
for _, prefix := range prefixes {
if s.rset.hasFullRange(prefix) {
continue
}
fetchPrefixes = append(fetchPrefixes, prefix)
}
if len(fetchKeys) == 0 && len(fetchPrefixes) == 0 {
return
}
prefixOpts := append(
[]v3.OpOption{v3.WithPrefix()}, s.getOpts...,
)
txn := s.client.Txn(s.options.ctx)
ops := make([]v3.Op, 0, len(fetchKeys)+len(fetchPrefixes))
for _, key := range fetchKeys {
ops = append(ops, v3.OpGet(key, s.getOpts...))
}
for _, key := range fetchPrefixes {
ops = append(ops, v3.OpGet(key, prefixOpts...))
}
txn.Then(ops...)
txnresp, err := txn.Commit()
s.callCount++
if err != nil {
return
}
// Set revision and serializable options upon first fetch for any
// subsequent fetches.
if s.getOpts == nil {
s.revision = txnresp.Header.Revision
s.getOpts = []v3.OpOption{
v3.WithRev(s.revision),
v3.WithSerializable(),
}
}
// Preset keys to "not-present" (revision set to zero).
for _, key := range fetchKeys {
s.rset.presetItem(key)
}
// Set prefetched keys.
s.rset.add(txnresp.Responses[:len(fetchKeys)])
// Set prefetched ranges.
s.rset.addFullRange(fetchPrefixes, txnresp.Responses[len(fetchKeys):])
}
// commit builds the final transaction and tries to execute it. If commit fails
// because the keys have changed return a CommitError, otherwise return a
// DatabaseError.
@ -774,10 +1058,11 @@ func (s *stm) commit() (CommitStats, error) {
txn = txn.If(cmps...)
txn = txn.Then(s.wset.puts()...)
// Prefetch keys in case of conflict to save
// a round trip to etcd.
txn = txn.Else(s.rset.gets()...)
// Prefetch keys and ranges in case of conflict to save as many
// round-trips as possible.
txn = txn.Else(s.rset.prefetchSet()...)
s.callCount++
txnresp, err := txn.Commit()
if err != nil {
return stats, DatabaseError{
@ -786,8 +1071,7 @@ func (s *stm) commit() (CommitStats, error) {
}
}
// Call the commit callback if the transaction
// was successful.
// Call the commit callback if the transaction was successful.
if txnresp.Succeeded {
if s.onCommit != nil {
s.onCommit()
@ -796,12 +1080,23 @@ func (s *stm) commit() (CommitStats, error) {
return stats, nil
}
// Load prefetch before if commit failed.
s.rset.add(txnresp)
s.prefetch = s.rset
// Determine where our fetched full ranges begin in the response.
prefixes := s.rset.getFullRanges()
firstPrefixResp := len(txnresp.Responses) - len(prefixes)
// Return CommitError indicating that the transaction
// can be retried.
// Clear the read set and preload it with the prefetched keys and ranges.
s.rset.clear()
s.rset.add(txnresp.Responses[:firstPrefixResp])
s.rset.addFullRange(prefixes, txnresp.Responses[firstPrefixResp:])
// Set our revision boundary.
s.revision = txnresp.Header.Revision
s.getOpts = []v3.OpOption{
v3.WithRev(s.revision),
v3.WithSerializable(),
}
// Return CommitError indicating that the transaction can be retried.
return stats, CommitError{}
}
@ -819,8 +1114,17 @@ func (s *stm) Commit() error {
// Rollback resets the STM. This is useful for uncommitted transaction rollback
// and also used in the STM main loop to reset state if commit fails.
func (s *stm) Rollback() {
s.rset = make(map[string]stmGet)
s.wset = make(map[string]stmPut)
s.getOpts = nil
s.revision = math.MaxInt64 - 1
s.rollback(true)
}
// rollback will reset the read and write sets. If clearReadSet is false we'll
// only reset the write set.
func (s *stm) rollback(clearReadSet bool) {
if clearReadSet {
s.rset.clear()
s.revision = math.MaxInt64 - 1
s.getOpts = nil
}
s.wset = make(map[string]stmPut)
}


@ -39,8 +39,9 @@ func TestPutToEmpty(t *testing.T) {
return nil
}
err = RunSTM(db.cli, apply, txQueue)
callCount, err := RunSTM(db.cli, apply, txQueue)
require.NoError(t, err)
require.Equal(t, 1, callCount)
require.Equal(t, "abc", f.Get("123"))
}
@ -66,6 +67,9 @@ func TestGetPutDel(t *testing.T) {
{"e", "5"},
}
// Extra 2 => Get(x), Commit()
expectedCallCount := len(testKeyValues) + 2
for _, kv := range testKeyValues {
f.Put(kv.key, kv.val)
}
@ -79,11 +83,12 @@ func TestGetPutDel(t *testing.T) {
require.NoError(t, err)
require.Nil(t, v)
// Fetches: 1.
v, err = stm.Get("x")
require.NoError(t, err)
require.Nil(t, v)
// Get all existing keys.
// Get all existing keys. Fetches: len(testKeyValues)
for _, kv := range testKeyValues {
v, err = stm.Get(kv.key)
require.NoError(t, err)
@ -120,8 +125,9 @@ func TestGetPutDel(t *testing.T) {
return nil
}
err = RunSTM(db.cli, apply, txQueue)
callCount, err := RunSTM(db.cli, apply, txQueue)
require.NoError(t, err)
require.Equal(t, expectedCallCount, callCount)
require.Equal(t, "1", f.Get("a"))
require.Equal(t, "2", f.Get("b"))
@ -134,6 +140,17 @@ func TestGetPutDel(t *testing.T) {
func TestFirstLastNextPrev(t *testing.T) {
t.Parallel()
testFirstLastNextPrev(t, nil, nil, 41)
testFirstLastNextPrev(t, nil, []string{"k"}, 4)
testFirstLastNextPrev(t, nil, []string{"k", "w"}, 2)
testFirstLastNextPrev(t, []string{"kb"}, nil, 42)
testFirstLastNextPrev(t, []string{"kb", "ke"}, nil, 42)
testFirstLastNextPrev(t, []string{"kb", "ke", "w"}, []string{"k", "w"}, 2)
}
func testFirstLastNextPrev(t *testing.T, prefetchKeys []string,
prefetchRange []string, expectedCallCount int) {
f := NewEtcdTestFixture(t)
ctx, cancel := context.WithCancel(context.Background())
@ -159,6 +176,8 @@ func TestFirstLastNextPrev(t *testing.T) {
require.NoError(t, err)
apply := func(stm STM) error {
stm.Prefetch(prefetchKeys, prefetchRange)
// First/Last on valid multi item interval.
kv, err := stm.First("k")
require.NoError(t, err)
@ -177,11 +196,25 @@ func TestFirstLastNextPrev(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &KV{"w", "w"}, kv)
// Non existing.
val, err := stm.Get("ke1")
require.Nil(t, val)
require.Nil(t, err)
val, err = stm.Get("ke2")
require.Nil(t, val)
require.Nil(t, err)
// Next/Prev on start/end.
kv, err = stm.Next("k", "ke")
require.NoError(t, err)
require.Nil(t, kv)
// Non existing.
val, err = stm.Get("ka")
require.Nil(t, val)
require.Nil(t, err)
kv, err = stm.Prev("k", "kb")
require.NoError(t, err)
require.Nil(t, kv)
@ -277,8 +310,9 @@ func TestFirstLastNextPrev(t *testing.T) {
return nil
}
err = RunSTM(db.cli, apply, txQueue)
callCount, err := RunSTM(db.cli, apply, txQueue)
require.NoError(t, err)
require.Equal(t, expectedCallCount, callCount)
require.Equal(t, "0", f.Get("ka"))
require.Equal(t, "2", f.Get("kc"))
@ -330,9 +364,11 @@ func TestCommitError(t *testing.T) {
return nil
}
err = RunSTM(db.cli, apply, txQueue)
callCount, err := RunSTM(db.cli, apply, txQueue)
require.NoError(t, err)
require.Equal(t, 2, cnt)
// Get() + 2 * Commit().
require.Equal(t, 3, callCount)
require.Equal(t, "abc", f.Get("123"))
}


@ -135,6 +135,11 @@ func TestEtcd(t *testing.T) {
test: testTxRollback,
expectedDb: map[string]string{},
},
{
name: "prefetch",
test: testPrefetch,
expectedDb: map[string]string{},
},
}
for _, test := range tests {


@ -3,9 +3,12 @@ module github.com/lightningnetwork/lnd/kvdb
require (
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f
github.com/btcsuite/btcwallet/walletdb v1.3.6-0.20210803004036-eebed51155ec
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.0.1
github.com/lightningnetwork/lnd/healthcheck v1.0.0
github.com/stretchr/testify v1.7.0
go.etcd.io/bbolt v1.3.6
go.etcd.io/etcd/api/v3 v3.5.0
go.etcd.io/etcd/client/pkg/v3 v3.5.0
go.etcd.io/etcd/client/v3 v3.5.0
go.etcd.io/etcd/server/v3 v3.5.0


@ -93,6 +93,42 @@ type RwCursor = walletdb.ReadWriteCursor
// writes. When only reads are necessary, consider using a RTx instead.
type RwTx = walletdb.ReadWriteTx
// ExtendedRTx is an extension to walletdb.ReadTx to allow prefetching of keys.
type ExtendedRTx interface {
RTx
// RootBucket returns the "root bucket" which is a pseudo bucket used
// when prefetching (keys from) top level buckets.
RootBucket() RBucket
}
// ExtendedRBucket is an extension to walletdb.ReadBucket to allow prefetching
// of all values inside buckets.
type ExtendedRBucket interface {
RBucket
// Prefetch will attempt to prefetch all values under a path.
Prefetch(paths ...[]string)
}
// Prefetch will attempt to prefetch all values under a path from the passed
// bucket.
func Prefetch(b RBucket, paths ...[]string) {
if bucket, ok := b.(ExtendedRBucket); ok {
bucket.Prefetch(paths...)
}
}
// RootBucket is a wrapper to ExtendedRTx.RootBucket which does nothing if
// the implementation doesn't have ExtendedRTx.
func RootBucket(t RTx) RBucket {
if tx, ok := t.(ExtendedRTx); ok {
return tx.RootBucket()
}
return nil
}
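A short usage sketch of these two helpers (the db handle, bucket names, and readTopBucket wrapper are assumptions for illustration; compare kvdb/prefetch_test.go below). Because RootBucket returns nil and Prefetch type-asserts before acting, callers can add the hint unconditionally and backends without support simply ignore it:

package example

import (
	"github.com/lightningnetwork/lnd/kvdb"
)

// readTopBucket reads every key directly under "top", first hinting the
// backend to preload "top" and its nested "bucket" in one round trip where
// supported.
func readTopBucket(db kvdb.Backend) (map[string]string, error) {
	items := make(map[string]string)

	err := kvdb.View(db, func(tx kvdb.RTx) error {
		// Safe on every backend: a no-op unless the transaction and
		// its buckets implement the extended interfaces above.
		kvdb.Prefetch(
			kvdb.RootBucket(tx),
			[]string{"top"},
			[]string{"top", "bucket"},
		)

		top := tx.ReadBucket([]byte("top"))
		if top == nil {
			return kvdb.ErrBucketNotFound
		}

		return top.ForEach(func(k, v []byte) error {
			if v != nil {
				items[string(k)] = string(v)
			}
			return nil
		})
	}, func() {
		// The reset closure runs before each retry; drop partial state.
		items = make(map[string]string)
	})

	return items, err
}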
var (
// ErrBucketNotFound is returned when trying to access a bucket that
// has not been created yet.

kvdb/prefetch_test.go (new file, 188 lines)

@ -0,0 +1,188 @@
package kvdb
import (
"fmt"
"testing"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/require"
)
func fetchBucket(t *testing.T, bucket walletdb.ReadBucket) map[string]string {
items := make(map[string]string)
err := bucket.ForEach(func(k, v []byte) error {
if v != nil {
items[string(k)] = string(v)
}
return nil
})
require.NoError(t, err)
return items
}
func alterBucket(t *testing.T, bucket walletdb.ReadWriteBucket,
put map[string]string, remove []string) {
for k, v := range put {
require.NoError(t, bucket.Put([]byte(k), []byte(v)))
}
for _, k := range remove {
require.NoError(t, bucket.Delete([]byte(k)))
}
}
func prefetchTest(t *testing.T, db walletdb.DB,
prefetchAt []bool, put map[string]string, remove []string) {
prefetch := func(i int, tx walletdb.ReadTx) {
require.Less(t, i, len(prefetchAt))
if prefetchAt[i] {
Prefetch(
RootBucket(tx),
[]string{"top"}, []string{"top", "bucket"},
)
}
}
items := map[string]string{
"a": "1",
"b": "2",
"c": "3",
"d": "4",
"e": "5",
}
err := Update(db, func(tx walletdb.ReadWriteTx) error {
top, err := tx.CreateTopLevelBucket([]byte("top"))
require.NoError(t, err)
require.NotNil(t, top)
for k, v := range items {
require.NoError(t, top.Put([]byte(k), []byte(v)))
}
bucket, err := top.CreateBucket([]byte("bucket"))
require.NoError(t, err)
require.NotNil(t, bucket)
for k, v := range items {
require.NoError(t, bucket.Put([]byte(k), []byte(v)))
}
return nil
}, func() {})
require.NoError(t, err)
for k, v := range put {
items[k] = v
}
for _, k := range remove {
delete(items, k)
}
err = Update(db, func(tx walletdb.ReadWriteTx) error {
prefetch(0, tx)
top := tx.ReadWriteBucket([]byte("top"))
require.NotNil(t, top)
alterBucket(t, top, put, remove)
prefetch(1, tx)
require.Equal(t, items, fetchBucket(t, top))
prefetch(2, tx)
bucket := top.NestedReadWriteBucket([]byte("bucket"))
require.NotNil(t, bucket)
alterBucket(t, bucket, put, remove)
prefetch(3, tx)
require.Equal(t, items, fetchBucket(t, bucket))
return nil
}, func() {})
require.NoError(t, err)
err = Update(db, func(tx walletdb.ReadWriteTx) error {
return tx.DeleteTopLevelBucket([]byte("top"))
}, func() {})
require.NoError(t, err)
}
// testPrefetch tests that prefetching buckets works as expected even when the
// prefetch happens multiple times and the bucket contents change. Our
// expectation is that with or without prefetches, the kvdb layer works
// according to the interface specification.
func testPrefetch(t *testing.T, db walletdb.DB) {
tests := []struct {
put map[string]string
remove []string
}{
{
put: nil,
remove: nil,
},
{
put: map[string]string{
"a": "a",
"aa": "aa",
"aaa": "aaa",
"x": "x",
"y": "y",
},
remove: nil,
},
{
put: map[string]string{
"a": "a",
"aa": "aa",
"aaa": "aaa",
"x": "x",
"y": "y",
},
remove: []string{"a", "c", "d"},
},
{
put: nil,
remove: []string{"b", "d"},
},
}
prefetchAt := [][]bool{
{false, false, false, false},
{true, false, false, false},
{false, true, false, false},
{false, false, true, false},
{false, false, false, true},
{true, true, false, false},
{true, true, true, false},
{true, true, true, true},
{true, false, true, true},
{true, false, false, true},
{true, false, true, false},
}
for i, test := range tests {
test := test
for j := 0; j < len(prefetchAt); j++ {
if !t.Run(
fmt.Sprintf("prefetch %d %d", i, j),
func(t *testing.T) {
prefetchTest(
t, db, prefetchAt[j], test.put,
test.remove,
)
}) {
fmt.Printf("Prefetch test (%d, %d) failed:\n"+
"testcase=%v\n prefetch=%v\n",
i, j, spew.Sdump(test),
spew.Sdump(prefetchAt[j]))
}
}
}
}