diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 09bb91170..fa92ee56d 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat GetInvoice(ctx context.Context, arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice, + error) + GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]sqlc.InvoiceFeature, error) @@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, params.SetID = ref.SetID()[:] } - rows, err := db.GetInvoice(ctx, params) + var ( + rows []sqlc.Invoice + err error + ) + + // We need to split the query based on how we intend to look up the + // invoice. If only the set ID is given then we want to have an exact + // match on the set ID. If other fields are given, we want to match on + // those fields and the set ID but with a less strict join condition. + if params.Hash == nil && params.PaymentAddr == nil && + params.SetID != nil { + + rows, err = db.GetInvoiceBySetID(ctx, params.SetID) + } else { + rows, err = db.GetInvoice(ctx, params) + } switch { case len(rows) == 0: return nil, ErrInvoiceNotFound @@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more // than one invoice, we'll return an error. - return nil, fmt.Errorf("ambiguous invoice ref: %s", - ref.String()) + return nil, fmt.Errorf("ambiguous invoice ref: %s: %s", + ref.String(), spew.Sdump(rows)) case err != nil: return nil, fmt.Errorf("unable to fetch invoice: %w", err) @@ -1308,13 +1327,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error { // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, - _ *SetID, callback InvoiceUpdateCallback) ( + setID *SetID, callback InvoiceUpdateCallback) ( *Invoice, error) { var updatedInvoice *Invoice txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { + if setID != nil { + // Make sure to use the set ID if this is an AMP update. + var setIDBytes [32]byte + copy(setIDBytes[:], setID[:]) + ref.setID = &setIDBytes + + // If we're updating an AMP invoice, we'll also only + // need to fetch the HTLCs for the given set ID. + ref.refModifier = HtlcSetOnlyModifier + } + invoice, err := i.fetchInvoice(ctx, db, ref) if err != nil { return err diff --git a/sqldb/sqlc/invoices.sql.go b/sqldb/sqlc/invoices.sql.go index fde02391c..5849bca06 100644 --- a/sqldb/sqlc/invoices.sql.go +++ b/sqldb/sqlc/invoices.sql.go @@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = $1 OR $1 IS NULL +) WHERE ( - i.id = $1 OR - $1 IS NULL -) AND ( - i.hash = $2 OR + i.id = $2 OR $2 IS NULL ) AND ( - i.preimage = $3 OR + i.hash = $3 OR $3 IS NULL ) AND ( - i.payment_addr = $4 OR + i.preimage = $4 OR $4 IS NULL ) AND ( - a.set_id = $5 OR + i.payment_addr = $5 OR $5 IS NULL ) GROUP BY i.id @@ -192,11 +193,11 @@ LIMIT 2 ` type GetInvoiceParams struct { + SetID []byte AddIndex sql.NullInt64 Hash []byte Preimage []byte PaymentAddr []byte - SetID []byte } // This method may return more than one invoice if filter using multiple fields @@ -204,11 +205,11 @@ type GetInvoiceParams struct { // we bubble up an error in those cases. func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) { rows, err := q.db.QueryContext(ctx, getInvoice, + arg.SetID, arg.AddIndex, arg.Hash, arg.Preimage, arg.PaymentAddr, - arg.SetID, ) if err != nil { return nil, err @@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi return items, nil } +const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many +SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1 +` + +func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) { + rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Invoice + for rows.Next() { + var i Invoice + if err := rows.Scan( + &i.ID, + &i.Hash, + &i.Preimage, + &i.SettleIndex, + &i.SettledAt, + &i.Memo, + &i.AmountMsat, + &i.CltvDelta, + &i.Expiry, + &i.PaymentAddr, + &i.PaymentRequest, + &i.PaymentRequestHash, + &i.State, + &i.AmountPaidMsat, + &i.IsAmp, + &i.IsHodl, + &i.IsKeysend, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many SELECT feature, invoice_id FROM invoice_features diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index d55d8090a..04b61c700 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -21,6 +21,7 @@ type Querier interface { // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) diff --git a/sqldb/sqlc/queries/invoices.sql b/sqldb/sqlc/queries/invoices.sql index 07c5ca418..b03410a51 100644 --- a/sqldb/sqlc/queries/invoices.sql +++ b/sqldb/sqlc/queries/invoices.sql @@ -26,7 +26,11 @@ WHERE invoice_id = $1; -- name: GetInvoice :many SELECT i.* FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL +) WHERE ( i.id = sqlc.narg('add_index') OR sqlc.narg('add_index') IS NULL @@ -39,13 +43,16 @@ WHERE ( ) AND ( i.payment_addr = sqlc.narg('payment_addr') OR sqlc.narg('payment_addr') IS NULL -) AND ( - a.set_id = sqlc.narg('set_id') OR - sqlc.narg('set_id') IS NULL ) GROUP BY i.id LIMIT 2; +-- name: GetInvoiceBySetID :many +SELECT i.* +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1; + -- name: FilterInvoices :many SELECT invoices.*