Merge pull request #9050 from bhandras/native-sql-invoice-fixes

invoices+sqldb: small fixes to address some inconsistencies between KV and native SQL invoice DB implementations
This commit is contained in:
Oliver Gugger 2024-09-04 01:30:52 -06:00 committed by GitHub
commit 258cf81240
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 197 additions and 41 deletions

View File

@ -93,11 +93,13 @@ linters-settings:
- 'errors.Wrap'
gomoddirectives:
replace-local: true
replace-allow-list:
# See go.mod for the explanation why these are needed.
- github.com/ulikunitz/xz
- github.com/gogo/protobuf
- google.golang.org/protobuf
- github.com/lightningnetwork/lnd/sqldb
linters:

View File

@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices)
invoice, err := fetchInvoice(
invoiceKey, invoices, nil, false,
)
if err != nil {
return err
}
@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
// An invoice was found, retrieve the remainder of the invoice
// body.
i, err := fetchInvoice(invoiceNum, invoices, setID)
i, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
)
if err != nil {
return err
}
@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
return nil
}
invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
// characteristics for our query and returns the number of items
// we have added to our set of invoices.
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
invoice, err := fetchInvoice(indexValue, invoices)
invoice, err := fetchInvoice(
indexValue, invoices, nil, false,
)
if err != nil {
return false, err
}
@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
if setIDHint != nil {
invSetID = *setIDHint
}
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
invoice, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
)
if err != nil {
return err
}
@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
updatedInvoice, err = invpkg.UpdateInvoice(
payHash, updater.invoice, now, callback, updater,
)
if err != nil {
return err
}
return err
// If this is an AMP update, then limit the returned AMP state
// to only the requested set ID.
if setIDHint != nil {
filterInvoiceAMPState(updatedInvoice, &invSetID)
}
return nil
}, func() {
updatedInvoice = nil
})
@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
return updatedInvoice, err
}
// filterInvoiceAMPState filters the AMP state of the invoice to only include
// state for the specified set IDs.
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
filteredAMPState := make(invpkg.AMPInvoiceState)
for _, setID := range setIDs {
if setID == nil {
return
}
ampState, ok := invoice.AMPState[*setID]
if ok {
filteredAMPState[*setID] = ampState
}
}
invoice.AMPState = filteredAMPState
}
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(
invoiceKey[:], invoices, setID,
invoiceKey[:], invoices, []*invpkg.SetID{setID},
true,
)
if err != nil {
return err
@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
// specified by the invoice number. If the setID fields are set, then only the
// HTLC information pertaining to those set IDs is returned.
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
log.Errorf("unable to fetch amp htlcs for inv "+
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
}
if filterAMPState {
filterInvoiceAMPState(&invoice, setIDs...)
}
}
return invoice, nil
@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
return nil
}
invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}

View File

@ -266,6 +266,11 @@ that validate `ChannelAnnouncement` messages.
our health checker to correctly shut down LND if network partitioning occurs
towards the etcd cluster.
* [Fix](https://github.com/lightningnetwork/lnd/pull/9050) some inconsistencies
to make the native SQL invoice DB compatible with the KV implementation.
Furthermore fix a native SQL invoice issue where AMP subinvoice HTLCs are
sometimes updated incorrectly on settlement.
## Code Health
* [Move graph building and
@ -282,6 +287,7 @@ that validate `ChannelAnnouncement` messages.
# Contributors (Alphabetical Order)
* Alex Akselrod
* Andras Banki-Horvath
* bitromortac
* Bufo

3
go.mod
View File

@ -204,6 +204,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
// allows us to specify that as an option.
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
// Temporary replace until the next version of sqldb is taged.
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb
// If you change this please also update .github/pull_request_template.md,
// docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh.
go 1.22.6

2
go.sum
View File

@ -458,8 +458,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kf
github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A=
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
github.com/lightningnetwork/lnd/sqldb v1.0.3 h1:zLfAwOvM+6+3+hahYO9Q3h8pVV0TghAR7iJ5YMLCd3I=
github.com/lightningnetwork/lnd/sqldb v1.0.3/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo=
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw=

View File

@ -10,6 +10,7 @@ import (
"strconv"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)
GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error)
@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
params.SetID = ref.SetID()[:]
}
rows, err := db.GetInvoice(ctx, params)
var (
rows []sqlc.Invoice
err error
)
// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {
rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
case len(rows) == 0:
return nil, ErrInvoiceNotFound
@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s",
ref.String())
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String(), spew.Sdump(rows))
case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
@ -906,8 +925,10 @@ func (i *SQLStore) QueryInvoices(ctx context.Context,
}
if q.CreationDateEnd != 0 {
// We need to add 1 to the end date as we're
// checking less than the end date in SQL.
params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd, 0).UTC(),
time.Unix(q.CreationDateEnd+1, 0).UTC(),
)
}
@ -1116,6 +1137,9 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
SetID: setID[:],
HtlcID: int64(circuitKey.HtlcID),
Preimage: preimage[:],
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
},
)
if err != nil {
@ -1280,6 +1304,13 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
return err
}
if settleIndex.Valid {
updatedState := s.invoice.AMPState[setID]
updatedState.SettleIndex = uint64(settleIndex.Int64)
updatedState.SettleDate = s.updateTime.UTC()
s.invoice.AMPState[setID] = updatedState
}
return nil
}
@ -1298,13 +1329,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) (
setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {
var updatedInvoice *Invoice
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
if setID != nil {
// Make sure to use the set ID if this is an AMP update.
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes
// If we're updating an AMP invoice, we'll also only
// need to fetch the HTLCs for the given set ID.
ref.refModifier = HtlcSetOnlyModifier
}
invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil {
return err

View File

@ -260,7 +260,8 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription)
// The notification should signal that the invoice is now settled, and
// should also include the set ID, and show the proper amount paid.
// should also include the set ID, show the proper amount paid, and have
// the correct settle index and time.
require.True(ht, invoiceNtfn.Settled)
require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State)
require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat))
@ -270,6 +271,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
firstSetID, _ = hex.DecodeString(setIDStr)
require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED,
ampState.State)
require.GreaterOrEqual(ht, ampState.SettleTime,
rpcInvoice.CreationDate)
require.Equal(ht, uint64(1), ampState.SettleIndex)
}
// Pay the invoice again, we should get another notification that Dave
@ -299,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// return the "projected" sub-invoice for a given setID.
require.Equal(ht, 1, len(invoiceNtfn.Htlcs))
// However the AMP state index should show that there've been two
// repeated payments to this invoice so far.
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState))
// The AMP state should also be restricted to a single entry for the
// "projected" sub-invoice.
require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))
// Now we'll look up the invoice using the new LookupInvoice2 RPC call
// by the set ID of each of the invoices.
@ -360,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// through.
backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
require.Equal(ht, 1, len(backlogInv.Htlcs))
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState))
require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
require.True(ht, backlogInv.Settled)
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
}

View File

@ -268,15 +268,16 @@ func (q *Queries) InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubI
const updateAMPSubInvoiceHTLCPreimage = `-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult
UPDATE amp_sub_invoice_htlcs AS a
SET preimage = $4
SET preimage = $5
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
)
`
type UpdateAMPSubInvoiceHTLCPreimageParams struct {
InvoiceID int64
SetID []byte
ChanID string
HtlcID int64
Preimage []byte
}
@ -285,6 +286,7 @@ func (q *Queries) UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg Updat
return q.db.ExecContext(ctx, updateAMPSubInvoiceHTLCPreimage,
arg.InvoiceID,
arg.SetID,
arg.ChanID,
arg.HtlcID,
arg.Preimage,
)

View File

@ -78,7 +78,7 @@ WHERE (
created_at >= $6 OR
$6 IS NULL
) AND (
created_at <= $7 OR
created_at < $7 OR
$7 IS NULL
) AND (
CASE
@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = $1 OR $1 IS NULL
)
WHERE (
i.id = $1 OR
$1 IS NULL
) AND (
i.hash = $2 OR
i.id = $2 OR
$2 IS NULL
) AND (
i.preimage = $3 OR
i.hash = $3 OR
$3 IS NULL
) AND (
i.payment_addr = $4 OR
i.preimage = $4 OR
$4 IS NULL
) AND (
a.set_id = $5 OR
i.payment_addr = $5 OR
$5 IS NULL
)
GROUP BY i.id
@ -192,11 +193,11 @@ LIMIT 2
`
type GetInvoiceParams struct {
SetID []byte
AddIndex sql.NullInt64
Hash []byte
Preimage []byte
PaymentAddr []byte
SetID []byte
}
// This method may return more than one invoice if filter using multiple fields
@ -204,11 +205,11 @@ type GetInvoiceParams struct {
// we bubble up an error in those cases.
func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoice,
arg.SetID,
arg.AddIndex,
arg.Hash,
arg.Preimage,
arg.PaymentAddr,
arg.SetID,
)
if err != nil {
return nil, err
@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
return items, nil
}
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1
`
func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Invoice
for rows.Next() {
var i Invoice
if err := rows.Scan(
&i.ID,
&i.Hash,
&i.Preimage,
&i.SettleIndex,
&i.SettledAt,
&i.Memo,
&i.AmountMsat,
&i.CltvDelta,
&i.Expiry,
&i.PaymentAddr,
&i.PaymentRequest,
&i.PaymentRequestHash,
&i.State,
&i.AmountPaidMsat,
&i.IsAmp,
&i.IsHodl,
&i.IsKeysend,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many
SELECT feature, invoice_id
FROM invoice_features

View File

@ -21,6 +21,7 @@ type Querier interface {
// from different invoices. It is the caller's responsibility to ensure that
// we bubble up an error in those cases.
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)

View File

@ -61,7 +61,7 @@ WHERE (
-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult
UPDATE amp_sub_invoice_htlcs AS a
SET preimage = $4
SET preimage = $5
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
);

View File

@ -26,7 +26,11 @@ WHERE invoice_id = $1;
-- name: GetInvoice :many
SELECT i.*
FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL
)
WHERE (
i.id = sqlc.narg('add_index') OR
sqlc.narg('add_index') IS NULL
@ -39,13 +43,16 @@ WHERE (
) AND (
i.payment_addr = sqlc.narg('payment_addr') OR
sqlc.narg('payment_addr') IS NULL
) AND (
a.set_id = sqlc.narg('set_id') OR
sqlc.narg('set_id') IS NULL
)
GROUP BY i.id
LIMIT 2;
-- name: GetInvoiceBySetID :many
SELECT i.*
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1;
-- name: FilterInvoices :many
SELECT
invoices.*
@ -69,7 +76,7 @@ WHERE (
created_at >= sqlc.narg('created_after') OR
sqlc.narg('created_after') IS NULL
) AND (
created_at <= sqlc.narg('created_before') OR
created_at < sqlc.narg('created_before') OR
sqlc.narg('created_before') IS NULL
) AND (
CASE