diff --git a/invoices/resolution_result.go b/invoices/resolution_result.go index 8e2e4327a..9c40c84ca 100644 --- a/invoices/resolution_result.go +++ b/invoices/resolution_result.go @@ -110,6 +110,10 @@ const ( // payment. ResultMppInProgress + // ResultHtlcInvoiceTypeMismatch is returned when an AMP HTLC targets a + // non-AMP invoice and vice versa. + ResultHtlcInvoiceTypeMismatch + // ResultAmpError is returned when we receive invalid AMP parameters. ResultAmpError @@ -176,6 +180,9 @@ func (f FailResolutionResult) FailureString() string { case ResultMppInProgress: return "mpp reception in progress" + case ResultHtlcInvoiceTypeMismatch: + return "htlc invoice type mismatch" + case ResultAmpError: return "invalid amp parameters" diff --git a/invoices/update.go b/invoices/update.go index b41bd1a59..77bb0dee5 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -125,6 +125,21 @@ func updateMpp(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, HtlcResolution, error) { + // Reject HTLCs to AMP invoices if they are missing an AMP payload, and + // HTLCs to MPP invoices if they have an AMP payload. + switch { + + case inv.Terms.Features.RequiresFeature(lnwire.AMPRequired) && + ctx.amp == nil: + + return nil, ctx.failRes(ResultHtlcInvoiceTypeMismatch), nil + + case !inv.Terms.Features.RequiresFeature(lnwire.AMPRequired) && + ctx.amp != nil: + + return nil, ctx.failRes(ResultHtlcInvoiceTypeMismatch), nil + } + setID := ctx.setID() // Start building the accept descriptor.