From 1f8065de3593635a5bbfa40b2e2e0b87c1f5104e Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 10 Oct 2023 18:41:06 +0200 Subject: [PATCH] channeldb: add k/v implementation for InvoiceDB.FetchPendingInvoices --- channeldb/invoice_test.go | 54 +++++++++++++++++++++++++++++++++++ channeldb/invoices.go | 59 +++++++++++++++++++++++++++++++++++++++ invoices/interface.go | 5 ++++ invoices/mock.go | 9 ++++++ 4 files changed, 127 insertions(+) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index e4d9ffffa..f0d667332 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1013,6 +1013,60 @@ func TestSettleIndexAmpPayments(t *testing.T) { require.Nil(t, err) } +// TestFetchPendingInvoices tests that we can fetch all pending invoices from +// the database using the FetchPendingInvoices method. +func TestFetchPendingInvoices(t *testing.T) { + t.Parallel() + + db, err := MakeTestInvoiceDB(t, OptionClock(testClock)) + require.NoError(t, err, "unable to make test db") + + ctxb := context.Background() + + // Make sure that fetching pending invoices from an empty database + // returns an empty result and no errors. + pending, err := db.FetchPendingInvoices(ctxb) + require.NoError(t, err) + require.Empty(t, pending) + + const numInvoices = 20 + var settleIndex uint64 = 1 + pendingInvoices := make(map[lntypes.Hash]invpkg.Invoice) + + for i := 1; i <= numInvoices; i++ { + amt := lnwire.MilliSatoshi(i * 1000) + invoice, err := randInvoice(amt) + require.NoError(t, err) + + invoice.CreationDate = invoice.CreationDate.Add( + time.Duration(i-1) * time.Second, + ) + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + _, err = db.AddInvoice(ctxb, invoice, paymentHash) + require.NoError(t, err) + + // Settle every second invoice. + if i%2 == 0 { + pendingInvoices[paymentHash] = *invoice + continue + } + + ref := invpkg.InvoiceRefByHash(paymentHash) + _, err = db.UpdateInvoice(ctxb, ref, nil, getUpdateInvoice(amt)) + require.NoError(t, err) + + settleTestInvoice(invoice, settleIndex) + settleIndex++ + } + + // Fetch all pending invoices. + pending, err = db.FetchPendingInvoices(ctxb) + require.NoError(t, err) + require.Equal(t, pendingInvoices, pending) +} + // TestScanInvoices tests that ScanInvoices scans through all stored invoices // correctly. func TestScanInvoices(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index d7016ff33..cc665d097 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -433,6 +433,65 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, } } +// FetchPendingInvoices returns all invoices that have not yet been settled or +// canceled. The returned map is keyed by the payment hash of each respective +// invoice. +func (d *DB) FetchPendingInvoices(_ context.Context) ( + map[lntypes.Hash]invpkg.Invoice, error) { + + result := make(map[lntypes.Hash]invpkg.Invoice) + + err := kvdb.View(d, func(tx kvdb.RTx) error { + invoices := tx.ReadBucket(invoiceBucket) + if invoices == nil { + return nil + } + + invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket) + if invoiceIndex == nil { + // Mask the error if there's no invoice + // index as that simply means there are no + // invoices added yet to the DB. In this case + // we simply return an empty list. + return nil + } + + return invoiceIndex.ForEach(func(k, v []byte) error { + // Skip the special numInvoicesKey as that does not + // point to a valid invoice. + if bytes.Equal(k, numInvoicesKey) { + return nil + } + + // Skip sub-buckets. + if v == nil { + return nil + } + + invoice, err := fetchInvoice(v, invoices) + if err != nil { + return err + } + + if invoice.IsPending() { + var paymentHash lntypes.Hash + copy(paymentHash[:], k) + result[paymentHash] = invoice + } + + return nil + }) + }, func() { + result = make(map[lntypes.Hash]invpkg.Invoice) + }) + + if err != nil { + return nil, err + } + + return result, nil +} + // ScanInvoices scans through all invoices and calls the passed scanFunc for // for each invoice with its respective payment hash. Additionally a reset() // closure is passed which is used to reset/initialize partial results and also diff --git a/invoices/interface.go b/invoices/interface.go index 36bc0ea48..d88a96753 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -56,6 +56,11 @@ type InvoiceDB interface { ScanInvoices(ctx context.Context, scanFunc InvScanFunc, reset func()) error + // FetchPendingInvoices returns all invoices that have not yet been + // settled or canceled. + FetchPendingInvoices(ctx context.Context) (map[lntypes.Hash]Invoice, + error) + // QueryInvoices allows a caller to query the invoice database for // invoices within the specified add index range. QueryInvoices(ctx context.Context, q InvoiceQuery) (InvoiceSlice, error) diff --git a/invoices/mock.go b/invoices/mock.go index 8410208bf..5c419d0b1 100644 --- a/invoices/mock.go +++ b/invoices/mock.go @@ -1,6 +1,8 @@ package invoices import ( + "context" + "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/mock" ) @@ -55,6 +57,13 @@ func (m *MockInvoiceDB) ScanInvoices(scanFunc InvScanFunc, return args.Error(0) } +func (m *MockInvoiceDB) FetchPendingInvoices(ctx context.Context) ( + map[lntypes.Hash]Invoice, error) { + + args := m.Called(ctx) + return args.Get(0).(map[lntypes.Hash]Invoice), args.Error(1) +} + func (m *MockInvoiceDB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { args := m.Called(q) invoiceSlice, _ := args.Get(0).(InvoiceSlice)