channeldb: add k/v implementation for InvoiceDB.FetchPendingInvoices

This commit is contained in:
Andras Banki-Horvath 2023-10-10 18:41:06 +02:00
parent ad5cd9c8bb
commit 1f8065de35
No known key found for this signature in database
GPG key ID: 80E5375C094198D8
4 changed files with 127 additions and 0 deletions

View file

@ -1013,6 +1013,60 @@ func TestSettleIndexAmpPayments(t *testing.T) {
require.Nil(t, err)
}
// TestFetchPendingInvoices tests that we can fetch all pending invoices from
// the database using the FetchPendingInvoices method.
func TestFetchPendingInvoices(t *testing.T) {
t.Parallel()
db, err := MakeTestInvoiceDB(t, OptionClock(testClock))
require.NoError(t, err, "unable to make test db")
ctxb := context.Background()
// Make sure that fetching pending invoices from an empty database
// returns an empty result and no errors.
pending, err := db.FetchPendingInvoices(ctxb)
require.NoError(t, err)
require.Empty(t, pending)
const numInvoices = 20
var settleIndex uint64 = 1
pendingInvoices := make(map[lntypes.Hash]invpkg.Invoice)
for i := 1; i <= numInvoices; i++ {
amt := lnwire.MilliSatoshi(i * 1000)
invoice, err := randInvoice(amt)
require.NoError(t, err)
invoice.CreationDate = invoice.CreationDate.Add(
time.Duration(i-1) * time.Second,
)
paymentHash := invoice.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(ctxb, invoice, paymentHash)
require.NoError(t, err)
// Settle every second invoice.
if i%2 == 0 {
pendingInvoices[paymentHash] = *invoice
continue
}
ref := invpkg.InvoiceRefByHash(paymentHash)
_, err = db.UpdateInvoice(ctxb, ref, nil, getUpdateInvoice(amt))
require.NoError(t, err)
settleTestInvoice(invoice, settleIndex)
settleIndex++
}
// Fetch all pending invoices.
pending, err = db.FetchPendingInvoices(ctxb)
require.NoError(t, err)
require.Equal(t, pendingInvoices, pending)
}
// TestScanInvoices tests that ScanInvoices scans through all stored invoices
// correctly.
func TestScanInvoices(t *testing.T) {

View file

@ -433,6 +433,65 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
}
}
// FetchPendingInvoices returns all invoices that have not yet been settled or
// canceled. The returned map is keyed by the payment hash of each respective
// invoice.
func (d *DB) FetchPendingInvoices(_ context.Context) (
map[lntypes.Hash]invpkg.Invoice, error) {
result := make(map[lntypes.Hash]invpkg.Invoice)
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return nil
}
invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
if invoiceIndex == nil {
// Mask the error if there's no invoice
// index as that simply means there are no
// invoices added yet to the DB. In this case
// we simply return an empty list.
return nil
}
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
// Skip sub-buckets.
if v == nil {
return nil
}
invoice, err := fetchInvoice(v, invoices)
if err != nil {
return err
}
if invoice.IsPending() {
var paymentHash lntypes.Hash
copy(paymentHash[:], k)
result[paymentHash] = invoice
}
return nil
})
}, func() {
result = make(map[lntypes.Hash]invpkg.Invoice)
})
if err != nil {
return nil, err
}
return result, nil
}
// ScanInvoices scans through all invoices and calls the passed scanFunc for
// for each invoice with its respective payment hash. Additionally a reset()
// closure is passed which is used to reset/initialize partial results and also

View file

@ -56,6 +56,11 @@ type InvoiceDB interface {
ScanInvoices(ctx context.Context, scanFunc InvScanFunc,
reset func()) error
// FetchPendingInvoices returns all invoices that have not yet been
// settled or canceled.
FetchPendingInvoices(ctx context.Context) (map[lntypes.Hash]Invoice,
error)
// QueryInvoices allows a caller to query the invoice database for
// invoices within the specified add index range.
QueryInvoices(ctx context.Context, q InvoiceQuery) (InvoiceSlice, error)

View file

@ -1,6 +1,8 @@
package invoices
import (
"context"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/mock"
)
@ -55,6 +57,13 @@ func (m *MockInvoiceDB) ScanInvoices(scanFunc InvScanFunc,
return args.Error(0)
}
func (m *MockInvoiceDB) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]Invoice, error) {
args := m.Called(ctx)
return args.Get(0).(map[lntypes.Hash]Invoice), args.Error(1)
}
func (m *MockInvoiceDB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
args := m.Called(q)
invoiceSlice, _ := args.Get(0).(InvoiceSlice)