sqldb+invoices: synchronize SQL invoice updater behavior with KV version

Previously SQL invoice updater ignored the set ID hint when updating an
AMP invoice resulting in update subscriptions returning all of the AMP
state as well as all AMP HTLCs. This commit synchornizes behavior with
the KV implementation such that we now only return relevant AMP state
and HTLCs when updating an AMP invoice.
This commit is contained in:
Andras Banki-Horvath 2024-08-30 11:22:53 +02:00
parent c8de7a1699
commit b57910ee3a
No known key found for this signature in database
GPG key ID: 80E5375C094198D8
4 changed files with 106 additions and 18 deletions

View file

@ -10,6 +10,7 @@ import (
"strconv"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)
GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error)
@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
params.SetID = ref.SetID()[:]
}
rows, err := db.GetInvoice(ctx, params)
var (
rows []sqlc.Invoice
err error
)
// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {
rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
case len(rows) == 0:
return nil, ErrInvoiceNotFound
@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s",
ref.String())
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String(), spew.Sdump(rows))
case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
@ -1308,13 +1327,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) (
setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {
var updatedInvoice *Invoice
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
if setID != nil {
// Make sure to use the set ID if this is an AMP update.
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes
// If we're updating an AMP invoice, we'll also only
// need to fetch the HTLCs for the given set ID.
ref.refModifier = HtlcSetOnlyModifier
}
invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil {
return err

View file

@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = $1 OR $1 IS NULL
)
WHERE (
i.id = $1 OR
$1 IS NULL
) AND (
i.hash = $2 OR
i.id = $2 OR
$2 IS NULL
) AND (
i.preimage = $3 OR
i.hash = $3 OR
$3 IS NULL
) AND (
i.payment_addr = $4 OR
i.preimage = $4 OR
$4 IS NULL
) AND (
a.set_id = $5 OR
i.payment_addr = $5 OR
$5 IS NULL
)
GROUP BY i.id
@ -192,11 +193,11 @@ LIMIT 2
`
type GetInvoiceParams struct {
SetID []byte
AddIndex sql.NullInt64
Hash []byte
Preimage []byte
PaymentAddr []byte
SetID []byte
}
// This method may return more than one invoice if filter using multiple fields
@ -204,11 +205,11 @@ type GetInvoiceParams struct {
// we bubble up an error in those cases.
func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoice,
arg.SetID,
arg.AddIndex,
arg.Hash,
arg.Preimage,
arg.PaymentAddr,
arg.SetID,
)
if err != nil {
return nil, err
@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
return items, nil
}
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1
`
func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Invoice
for rows.Next() {
var i Invoice
if err := rows.Scan(
&i.ID,
&i.Hash,
&i.Preimage,
&i.SettleIndex,
&i.SettledAt,
&i.Memo,
&i.AmountMsat,
&i.CltvDelta,
&i.Expiry,
&i.PaymentAddr,
&i.PaymentRequest,
&i.PaymentRequestHash,
&i.State,
&i.AmountPaidMsat,
&i.IsAmp,
&i.IsHodl,
&i.IsKeysend,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many
SELECT feature, invoice_id
FROM invoice_features

View file

@ -21,6 +21,7 @@ type Querier interface {
// from different invoices. It is the caller's responsibility to ensure that
// we bubble up an error in those cases.
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)

View file

@ -26,7 +26,11 @@ WHERE invoice_id = $1;
-- name: GetInvoice :many
SELECT i.*
FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL
)
WHERE (
i.id = sqlc.narg('add_index') OR
sqlc.narg('add_index') IS NULL
@ -39,13 +43,16 @@ WHERE (
) AND (
i.payment_addr = sqlc.narg('payment_addr') OR
sqlc.narg('payment_addr') IS NULL
) AND (
a.set_id = sqlc.narg('set_id') OR
sqlc.narg('set_id') IS NULL
)
GROUP BY i.id
LIMIT 2;
-- name: GetInvoiceBySetID :many
SELECT i.*
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1;
-- name: FilterInvoices :many
SELECT
invoices.*