zpay32: add functional opt to error out on unknown feature bit

This commit adds two functional options to the zpay32.Decode function.
`WithKnownFeatureBits` allows the caller to overwrite the default set of
known feature bits used by the function.
`WithErrorOnUnknownFeatureBit` allows the caller to instruct the
function to error out if the invoice that is decoded contaijns unknown
feature bits. We then use this new error-out option from the
`rpcServer`'s `extractPaymentIntent` method.
This commit is contained in:
Elle Mouton 2024-08-06 12:33:56 +02:00
parent e4619afc08
commit 48a9a8d20e
No known key found for this signature in database
GPG key ID: D7D916376026F177
3 changed files with 133 additions and 4 deletions

View file

@ -5210,6 +5210,7 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme
if rpcPayReq.PaymentRequest != "" {
payReq, err := zpay32.Decode(
rpcPayReq.PaymentRequest, r.cfg.ActiveNetParams.Params,
zpay32.WithErrorOnUnknownFeatureBit(),
)
if err != nil {
return payIntent, err

View file

@ -17,10 +17,53 @@ import (
"github.com/lightningnetwork/lnd/lnwire"
)
// DecodeOption is a type that can be used to supply functional options to the
// Decode function.
type DecodeOption func(*decodeOptions)
// WithKnownFeatureBits is a functional option that overwrites the set of
// known feature bits. If not set, then LND's lnwire.Features variable will be
// used by default.
func WithKnownFeatureBits(features map[lnwire.FeatureBit]string) DecodeOption {
return func(options *decodeOptions) {
options.knownFeatureBits = features
}
}
// WithErrorOnUnknownFeatureBit is a functional option that will cause the
// Decode function to return an error if the decoded invoice contains an unknown
// feature bit.
func WithErrorOnUnknownFeatureBit() DecodeOption {
return func(options *decodeOptions) {
options.errorOnUnknownFeature = true
}
}
// decodeOptions holds the set of Decode options.
type decodeOptions struct {
knownFeatureBits map[lnwire.FeatureBit]string
errorOnUnknownFeature bool
}
// newDecodeOptions constructs the default decodeOptions struct.
func newDecodeOptions() *decodeOptions {
return &decodeOptions{
knownFeatureBits: lnwire.Features,
errorOnUnknownFeature: false,
}
}
// Decode parses the provided encoded invoice and returns a decoded Invoice if
// it is valid by BOLT-0011 and matches the provided active network.
func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) {
decodedInvoice := Invoice{}
func Decode(invoice string, net *chaincfg.Params, opts ...DecodeOption) (
*Invoice, error) {
options := newDecodeOptions()
for _, o := range opts {
o(options)
}
var decodedInvoice Invoice
// Before bech32 decoding the invoice, make sure that it is not too large.
// This is done as an anti-DoS measure since bech32 decoding is expensive.
@ -134,7 +177,7 @@ func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) {
// If no feature vector was decoded, populate an empty one.
if decodedInvoice.Features == nil {
decodedInvoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
nil, options.knownFeatureBits,
)
}
@ -144,6 +187,24 @@ func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) {
return nil, err
}
if options.errorOnUnknownFeature {
// Make sure that we understand all the required feature bits
// in the invoice.
unknownFeatureBits := decodedInvoice.Features.
UnknownRequiredFeatures()
if len(unknownFeatureBits) > 0 {
errStr := fmt.Sprintf("invoice contains " +
"unknown feature bits:")
for _, bit := range unknownFeatureBits {
errStr += fmt.Sprintf(" %d,", bit)
}
return nil, fmt.Errorf(strings.TrimRight(errStr, ","))
}
}
return &decodedInvoice, nil
}

View file

@ -191,6 +191,7 @@ func TestDecodeEncode(t *testing.T) {
encodedInvoice string
valid bool
decodedInvoice func() *Invoice
decodeOpts []DecodeOption
skipEncoding bool
beforeEncoding func(*Invoice)
}{
@ -758,6 +759,70 @@ func TestDecodeEncode(t *testing.T) {
i.Destination = nil
},
},
{
// Invoice with unknown feature bits but since the
// WithErrorOnUnknownFeatureBit option is not provided,
// it is not expected to error out.
encodedInvoice: "lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q4psqqqqqqqqqqqqqqqpqsqq40wa3khl49yue3zsgm26jrepqr2eghqlx86rttutve3ugd05em86nsefzh4pfurpd9ek9w2vp95zxqnfe2u7ckudyahsa52q66tgzcp6t2dyk",
valid: true,
skipEncoding: true,
decodedInvoice: func() *Invoice {
return &Invoice{
Net: &chaincfg.MainNetParams,
MilliSat: &testMillisat25mBTC,
Timestamp: time.Unix(1496314658, 0),
PaymentHash: &testPaymentHash,
PaymentAddr: &specPaymentAddr,
Description: &testCoffeeBeans,
Destination: testPubKey,
Features: lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(
9, 15, 99, 100,
),
lnwire.Features,
),
}
},
decodeOpts: []DecodeOption{
WithKnownFeatureBits(map[lnwire.FeatureBit]string{
9: "9",
15: "15",
99: "99",
}),
},
},
{
// Invoice with unknown feature bits with option set to
// error out on unknown feature bit.
encodedInvoice: "lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q4psqqqqqqqqqqqqqqqpqsqq40wa3khl49yue3zsgm26jrepqr2eghqlx86rttutve3ugd05em86nsefzh4pfurpd9ek9w2vp95zxqnfe2u7ckudyahsa52q66tgzcp6t2dyk",
valid: false,
skipEncoding: true,
decodedInvoice: func() *Invoice {
return &Invoice{
Net: &chaincfg.MainNetParams,
MilliSat: &testMillisat25mBTC,
Timestamp: time.Unix(1496314658, 0),
PaymentHash: &testPaymentHash,
PaymentAddr: &specPaymentAddr,
Description: &testCoffeeBeans,
Destination: testPubKey,
Features: lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(
9, 15, 99, 100,
),
lnwire.Features,
),
}
},
decodeOpts: []DecodeOption{
WithKnownFeatureBits(map[lnwire.FeatureBit]string{
9: "9",
15: "15",
99: "99",
}),
WithErrorOnUnknownFeatureBit(),
},
},
}
for i, test := range tests {
@ -773,7 +838,9 @@ func TestDecodeEncode(t *testing.T) {
net = decodedInvoice.Net
}
invoice, err := Decode(test.encodedInvoice, net)
invoice, err := Decode(
test.encodedInvoice, net, test.decodeOpts...,
)
if !test.valid {
require.Error(t, err)
} else {