diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index b261c4773..aca845d5f 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -99,7 +99,7 @@ func TestInvoiceWorkflow(t *testing.T) { // now have the settled bit toggle to true and a non-default // SettledDate payAmt := fakeInvoice.Terms.Value * 2 - if err := db.SettleInvoice(paymentHash, payAmt); err != nil { + if _, err := db.SettleInvoice(paymentHash, payAmt); err != nil { t.Fatalf("unable to settle invoice: %v", err) } dbInvoice2, err := db.LookupInvoice(paymentHash) @@ -260,7 +260,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { invoice.Terms.PaymentPreimage[:], ) - err := db.SettleInvoice(paymentHash, 0) + _, err := db.SettleInvoice(paymentHash, 0) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 5a2c044a2..a6fe0df94 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -393,9 +393,11 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { // payment hash as fully settled. If an invoice matching the passed payment // hash doesn't existing within the database, then the action will fail with a // "not found" error. -func (d *DB) SettleInvoice(paymentHash [32]byte, amtPaid lnwire.MilliSatoshi) error { +func (d *DB) SettleInvoice(paymentHash [32]byte, + amtPaid lnwire.MilliSatoshi) (*Invoice, error) { - return d.Update(func(tx *bolt.Tx) error { + var settledInvoice *Invoice + err := d.Update(func(tx *bolt.Tx) error { invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) if err != nil { return err @@ -420,10 +422,21 @@ func (d *DB) SettleInvoice(paymentHash [32]byte, amtPaid lnwire.MilliSatoshi) er return ErrInvoiceNotFound } - return settleInvoice( + invoice, err := settleInvoice( invoices, settleIndex, invoiceNum, amtPaid, ) + if err != nil { + return err + } + + settledInvoice = invoice + return nil }) + if err != nil { + return nil, err + } + + return settledInvoice, nil } // InvoicesSettledSince can be used by callers to catch up any settled invoices @@ -670,17 +683,17 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { } func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte, - amtPaid lnwire.MilliSatoshi) error { + amtPaid lnwire.MilliSatoshi) (*Invoice, error) { invoice, err := fetchInvoice(invoiceNum, invoices) if err != nil { - return err + return nil, err } // Add idempotency to duplicate settles, return here to avoid // overwriting the previous info. if invoice.Terms.Settled { - return nil + return nil, nil } // Now that we know the invoice hasn't already been settled, we'll @@ -688,13 +701,13 @@ func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte, // proper location within our time series. nextSettleSeqNo, err := settleIndex.NextSequence() if err != nil { - return err + return nil, err } var seqNoBytes [8]byte byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { - return err + return nil, err } invoice.AmtPaid = amtPaid @@ -704,8 +717,12 @@ func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte, var buf bytes.Buffer if err := serializeInvoice(&buf, &invoice); err != nil { - return nil + return nil, err } - return invoices.Put(invoiceNum[:], buf.Bytes()) + if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil { + return nil, err + } + + return &invoice, nil }