diff --git a/zpay32/blinded_path.go b/zpay32/blinded_path.go new file mode 100644 index 000000000..128a05e4b --- /dev/null +++ b/zpay32/blinded_path.go @@ -0,0 +1,246 @@ +package zpay32 + +import ( + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/btcsuite/btcd/btcec/v2" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // relayInfoSize is the number of bytes that the relay info of a blinded + // payment will occupy. + // base fee: 4 bytes + // prop fee: 4 bytes + // cltv delta: 2 bytes + // min htlc: 8 bytes + // max htlc: 8 bytes + relayInfoSize = 26 + + // maxNumHopsPerPath is the maximum number of blinded path hops that can + // be included in a single encoded blinded path. This is calculated + // based on the `data_length` limit of 638 bytes for any tagged field in + // a BOLT 11 invoice along with the estimated number of bytes required + // for encoding the most minimal blinded path hop. See the [bLIP + // proposal](https://github.com/lightning/blips/pull/39) for a detailed + // calculation. + maxNumHopsPerPath = 7 +) + +// BlindedPaymentPath holds all the information a payer needs to know about a +// blinded path to a receiver of a payment. +type BlindedPaymentPath struct { + // FeeBaseMsat is the total base fee for the path in milli-satoshis. + FeeBaseMsat uint32 + + // FeeRate is the total fee rate for the path in parts per million. + FeeRate uint32 + + // CltvExpiryDelta is the total CLTV delta to apply to the path. + CltvExpiryDelta uint16 + + // HTLCMinMsat is the minimum number of milli-satoshis that any hop in + // the path will route. + HTLCMinMsat uint64 + + // HTLCMaxMsat is the maximum number of milli-satoshis that a hop in the + // path will route. + HTLCMaxMsat uint64 + + // Features is the feature bit vector for the path. + Features *lnwire.FeatureVector + + // FirstEphemeralBlindingPoint is the blinding point to send to the + // introduction node. It will be used by the introduction node to derive + // a shared secret with the receiver which can then be used to decode + // the encrypted payload from the receiver. + FirstEphemeralBlindingPoint *btcec.PublicKey + + // Hops is the blinded path. The first hop is the introduction node and + // so the BlindedNodeID of this hop will be the real node ID. + Hops []*sphinx.BlindedHopInfo +} + +// DecodeBlindedPayment attempts to parse a BlindedPaymentPath from the passed +// reader. +func DecodeBlindedPayment(r io.Reader) (*BlindedPaymentPath, error) { + var relayInfo [relayInfoSize]byte + n, err := r.Read(relayInfo[:]) + if err != nil { + return nil, err + } + if n != relayInfoSize { + return nil, fmt.Errorf("unable to read %d relay info bytes "+ + "off of the given stream: %w", relayInfoSize, err) + } + + var payment BlindedPaymentPath + + // Parse the relay info fields. + payment.FeeBaseMsat = binary.BigEndian.Uint32(relayInfo[:4]) + payment.FeeRate = binary.BigEndian.Uint32(relayInfo[4:8]) + payment.CltvExpiryDelta = binary.BigEndian.Uint16(relayInfo[8:10]) + payment.HTLCMinMsat = binary.BigEndian.Uint64(relayInfo[10:18]) + payment.HTLCMaxMsat = binary.BigEndian.Uint64(relayInfo[18:]) + + // Parse the feature bit vector. + f := lnwire.EmptyFeatureVector() + err = f.Decode(r) + if err != nil { + return nil, err + } + payment.Features = f + + // Parse the first ephemeral blinding point. + var blindingPointBytes [btcec.PubKeyBytesLenCompressed]byte + _, err = r.Read(blindingPointBytes[:]) + if err != nil { + return nil, err + } + + blinding, err := btcec.ParsePubKey(blindingPointBytes[:]) + if err != nil { + return nil, err + } + payment.FirstEphemeralBlindingPoint = blinding + + // Read the one byte hop number. + var numHops [1]byte + _, err = r.Read(numHops[:]) + if err != nil { + return nil, err + } + + payment.Hops = make([]*sphinx.BlindedHopInfo, int(numHops[0])) + + // Parse each hop. + for i := 0; i < len(payment.Hops); i++ { + hop, err := DecodeBlindedHop(r) + if err != nil { + return nil, err + } + + payment.Hops[i] = hop + } + + return &payment, nil +} + +// Encode serialises the BlindedPaymentPath and writes the bytes to the passed +// writer. +// 1) The first 26 bytes contain the relay info: +// - Base Fee in msat: uint32 (4 bytes). +// - Proportional Fee in PPM: uint32 (4 bytes). +// - CLTV expiry delta: uint16 (2 bytes). +// - HTLC min msat: uint64 (8 bytes). +// - HTLC max msat: uint64 (8 bytes). +// +// 2) Feature bit vector length (2 bytes). +// 3) Feature bit vector (can be zero length). +// 4) First blinding point: 33 bytes. +// 5) Number of hops: 1 byte. +// 6) Encoded BlindedHops. +func (p *BlindedPaymentPath) Encode(w io.Writer) error { + var relayInfo [26]byte + binary.BigEndian.PutUint32(relayInfo[:4], p.FeeBaseMsat) + binary.BigEndian.PutUint32(relayInfo[4:8], p.FeeRate) + binary.BigEndian.PutUint16(relayInfo[8:10], p.CltvExpiryDelta) + binary.BigEndian.PutUint64(relayInfo[10:18], p.HTLCMinMsat) + binary.BigEndian.PutUint64(relayInfo[18:], p.HTLCMaxMsat) + + _, err := w.Write(relayInfo[:]) + if err != nil { + return err + } + + err = p.Features.Encode(w) + if err != nil { + return err + } + + _, err = w.Write(p.FirstEphemeralBlindingPoint.SerializeCompressed()) + if err != nil { + return err + } + + numHops := len(p.Hops) + if numHops > maxNumHopsPerPath { + return fmt.Errorf("the number of hops, %d, exceeds the "+ + "maximum of %d", numHops, maxNumHopsPerPath) + } + + _, err = w.Write([]byte{byte(numHops)}) + if err != nil { + return err + } + + for _, hop := range p.Hops { + err = EncodeBlindedHop(w, hop) + if err != nil { + return err + } + } + + return nil +} + +// DecodeBlindedHop reads a sphinx.BlindedHopInfo from the passed reader. +func DecodeBlindedHop(r io.Reader) (*sphinx.BlindedHopInfo, error) { + var nodeIDBytes [btcec.PubKeyBytesLenCompressed]byte + _, err := r.Read(nodeIDBytes[:]) + if err != nil { + return nil, err + } + + nodeID, err := btcec.ParsePubKey(nodeIDBytes[:]) + if err != nil { + return nil, err + } + + dataLen, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return nil, err + } + + encryptedData := make([]byte, dataLen) + _, err = r.Read(encryptedData) + if err != nil { + return nil, err + } + + return &sphinx.BlindedHopInfo{ + BlindedNodePub: nodeID, + CipherText: encryptedData, + }, nil +} + +// EncodeBlindedHop writes the passed BlindedHopInfo to the given writer. +// +// 1) Blinded node pub key: 33 bytes +// 2) Cipher text length: BigSize +// 3) Cipher text. +func EncodeBlindedHop(w io.Writer, hop *sphinx.BlindedHopInfo) error { + _, err := w.Write(hop.BlindedNodePub.SerializeCompressed()) + if err != nil { + return err + } + + if len(hop.CipherText) > math.MaxUint16 { + return fmt.Errorf("encrypted recipient data can not exceed a "+ + "length of %d bytes", math.MaxUint16) + } + + err = tlv.WriteVarInt(w, uint64(len(hop.CipherText)), &[8]byte{}) + if err != nil { + return err + } + + _, err = w.Write(hop.CipherText) + + return err +} diff --git a/zpay32/decode.go b/zpay32/decode.go index d37a34cf9..59bfccf00 100644 --- a/zpay32/decode.go +++ b/zpay32/decode.go @@ -215,6 +215,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.PaymentHash, err = parse32Bytes(base32Data) + case fieldTypeS: if invoice.PaymentAddr != nil { // We skip the field if we have already seen a @@ -223,6 +224,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.PaymentAddr, err = parse32Bytes(base32Data) + case fieldTypeD: if invoice.Description != nil { // We skip the field if we have already seen a @@ -231,6 +233,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.Description, err = parseDescription(base32Data) + case fieldTypeM: if invoice.Metadata != nil { // We skip the field if we have already seen a @@ -248,6 +251,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.Destination, err = parseDestination(base32Data) + case fieldTypeH: if invoice.DescriptionHash != nil { // We skip the field if we have already seen a @@ -256,6 +260,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.DescriptionHash, err = parse32Bytes(base32Data) + case fieldTypeX: if invoice.expiry != nil { // We skip the field if we have already seen a @@ -264,6 +269,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.expiry, err = parseExpiry(base32Data) + case fieldTypeC: if invoice.minFinalCLTVExpiry != nil { // We skip the field if we have already seen a @@ -271,7 +277,9 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - invoice.minFinalCLTVExpiry, err = parseMinFinalCLTVExpiry(base32Data) + invoice.minFinalCLTVExpiry, err = + parseMinFinalCLTVExpiry(base32Data) + case fieldTypeF: if invoice.FallbackAddr != nil { // We skip the field if we have already seen a @@ -279,7 +287,10 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - invoice.FallbackAddr, err = parseFallbackAddr(base32Data, net) + invoice.FallbackAddr, err = parseFallbackAddr( + base32Data, net, + ) + case fieldTypeR: // An `r` field can be included in an invoice multiple // times, so we won't skip it if we have already seen @@ -289,7 +300,10 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er return err } - invoice.RouteHints = append(invoice.RouteHints, routeHint) + invoice.RouteHints = append( + invoice.RouteHints, routeHint, + ) + case fieldType9: if invoice.Features != nil { // We skip the field if we have already seen a @@ -298,6 +312,19 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } invoice.Features, err = parseFeatures(base32Data) + + case fieldTypeB: + blindedPaymentPath, err := parseBlindedPaymentPath( + base32Data, + ) + if err != nil { + return err + } + + invoice.BlindedPaymentPaths = append( + invoice.BlindedPaymentPaths, blindedPaymentPath, + ) + default: // Ignore unknown type. } @@ -495,6 +522,17 @@ func parseRouteHint(data []byte) ([]HopHint, error) { return routeHint, nil } +// parseBlindedPaymentPath attempts to parse a BlindedPaymentPath from the given +// byte slice. +func parseBlindedPaymentPath(data []byte) (*BlindedPaymentPath, error) { + base256Data, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + return DecodeBlindedPayment(bytes.NewReader(base256Data)) +} + // parseFeatures decodes any feature bits directly from the base32 // representation. func parseFeatures(data []byte) (*lnwire.FeatureVector, error) { diff --git a/zpay32/encode.go b/zpay32/encode.go index bf544d062..130abdcfd 100644 --- a/zpay32/encode.go +++ b/zpay32/encode.go @@ -260,6 +260,29 @@ func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error { } } + for _, path := range invoice.BlindedPaymentPaths { + var buf bytes.Buffer + + err := path.Encode(&buf) + if err != nil { + return err + } + + blindedPathBase32, err := bech32.ConvertBits( + buf.Bytes(), 8, 5, true, + ) + if err != nil { + return err + } + + err = writeTaggedField( + bufferBase32, fieldTypeB, blindedPathBase32, + ) + if err != nil { + return err + } + } + if invoice.Destination != nil { // Convert 33 byte pubkey to 53 5-bit groups. pubKeyBase32, err := bech32.ConvertBits( diff --git a/zpay32/invoice.go b/zpay32/invoice.go index f23992ec9..2afc59d95 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -76,6 +76,10 @@ const ( // probing the recipient. fieldTypeS = 16 + // fieldTypeB contains blinded payment path information. This field may + // be repeated to include multiple blinded payment paths in the invoice. + fieldTypeB = 20 + // maxInvoiceLength is the maximum total length an invoice can have. // This is chosen to be the maximum number of bytes that can fit into a // single QR code: https://en.wikipedia.org/wiki/QR_code#Storage @@ -180,9 +184,17 @@ type Invoice struct { // hint can be individually used to reach the destination. These usually // represent private routes. // - // NOTE: This is optional. + // NOTE: This is optional and should not be set at the same time as + // BlindedPaymentPaths. RouteHints [][]HopHint + // BlindedPaymentPaths is a set of blinded payment paths that can be + // used to find the payment receiver. + // + // NOTE: This is optional and should not be set at the same time as + // RouteHints. + BlindedPaymentPaths []*BlindedPaymentPath + // Features represents an optional field used to signal optional or // required support for features by the receiver. Features *lnwire.FeatureVector @@ -263,6 +275,15 @@ func RouteHint(routeHint []HopHint) func(*Invoice) { } } +// WithBlindedPaymentPath is a functional option that allows a caller of +// NewInvoice to attach a blinded payment path to the invoice. The option can +// be used multiple times to attach multiple paths. +func WithBlindedPaymentPath(p *BlindedPaymentPath) func(*Invoice) { + return func(i *Invoice) { + i.BlindedPaymentPaths = append(i.BlindedPaymentPaths, p) + } +} + // Features is a functional option that allows callers of NewInvoice to set the // desired feature bits that are advertised on the invoice. If this option is // not used, an empty feature vector will automatically be populated. @@ -355,6 +376,13 @@ func validateInvoice(invoice *Invoice) error { return fmt.Errorf("no payment hash found") } + if len(invoice.RouteHints) != 0 && + len(invoice.BlindedPaymentPaths) != 0 { + + return fmt.Errorf("cannot have both route hints and blinded " + + "payment paths") + } + // Either Description or DescriptionHash must be set, not both. if invoice.Description != nil && invoice.DescriptionHash != nil { return fmt.Errorf("both description and description hash set") diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go index a360ed2a0..006b4fb6d 100644 --- a/zpay32/invoice_test.go +++ b/zpay32/invoice_test.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/hex" "fmt" - "reflect" "strings" "testing" "time" @@ -17,7 +16,9 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) var ( @@ -116,6 +117,62 @@ var ( // Must be initialized in init(). testDescriptionHash [32]byte + + testBlindedPK1Bytes, _ = hex.DecodeString("03f3311e948feb5115242c4e39" + + "6c81c448ab7ee5fd24c4e24e66c73533cc4f98b8") + testBlindedHopPK1, _ = btcec.ParsePubKey(testBlindedPK1Bytes) + testBlindedPK2Bytes, _ = hex.DecodeString("03a8c97ed5cd40d474e4ef18c8" + + "99854b25e5070106504cb225e6d2c112d61a805e") + testBlindedHopPK2, _ = btcec.ParsePubKey(testBlindedPK2Bytes) + testBlindedPK3Bytes, _ = hex.DecodeString("0220293926219d8efe733336e2" + + "b674570dd96aa763acb3564e6e367b384d861a0a") + testBlindedHopPK3, _ = btcec.ParsePubKey(testBlindedPK3Bytes) + testBlindedPK4Bytes, _ = hex.DecodeString("02c75eb336a038294eaaf76015" + + "8b2e851c3c0937262e35401ae64a1bee71a2e40c") + testBlindedHopPK4, _ = btcec.ParsePubKey(testBlindedPK4Bytes) + + blindedPath1 = &BlindedPaymentPath{ + FeeBaseMsat: 40, + FeeRate: 20, + CltvExpiryDelta: 130, + HTLCMinMsat: 2, + HTLCMaxMsat: 100, + Features: lnwire.EmptyFeatureVector(), + FirstEphemeralBlindingPoint: testBlindedHopPK1, + Hops: []*sphinx.BlindedHopInfo{ + { + BlindedNodePub: testBlindedHopPK2, + CipherText: []byte{1, 2, 3, 4, 5}, + }, + { + BlindedNodePub: testBlindedHopPK3, + CipherText: []byte{5, 4, 3, 2, 1}, + }, + { + BlindedNodePub: testBlindedHopPK4, + CipherText: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, + }, + }, + }, + } + + blindedPath2 = &BlindedPaymentPath{ + FeeBaseMsat: 4, + FeeRate: 2, + CltvExpiryDelta: 10, + HTLCMinMsat: 0, + HTLCMaxMsat: 10, + Features: lnwire.EmptyFeatureVector(), + FirstEphemeralBlindingPoint: testBlindedHopPK4, + Hops: []*sphinx.BlindedHopInfo{ + { + BlindedNodePub: testBlindedHopPK3, + CipherText: []byte{1, 2, 3, 4, 5}, + }, + }, + } ) func init() { @@ -125,6 +182,8 @@ func init() { // TestDecodeEncode tests that an encoded invoice gets decoded into the expected // Invoice object, and that reencoding the decoded invoice gets us back to the // original encoded string. +// +//nolint:lll func TestDecodeEncode(t *testing.T) { t.Parallel() @@ -673,52 +732,77 @@ func TestDecodeEncode(t *testing.T) { i.Destination = nil }, }, + { + // Invoice with blinded payment paths. + encodedInvoice: "lnbc20m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4js5fdqqqqq2qqqqqpgqyzqqqqqqqqqqqqyqqqqqqqqqqqvsqqqqlnxy0ffrlt2y2jgtzw89kgr3zg4dlwtlfycn3yuek8x5eucnuchqps82xf0m2u6sx5wnjw7xxgnxz5kf09quqsv5zvkgj7d5kpzttp4qz7q5qsyqcyq5pzq2feycsemrh7wvendc4kw3tsmkt25a36ev6kfehrv7ecfkrp5zs9q5zqxqspqtr4avek5quzjn427asptzews5wrczfhychr2sq6ue9phmn35tjqcrspqgpsgpgxquyqjzstpsxsu59zqqqqqpqqqqqqyqq2qqqqqqqqqqqqqqqqqqqqqqqqpgqqqqk8t6endgpc99824amqzk9japgu8synwf3wx4qp4ej2r0h8rghypsqsygpf8ynzr8vwleenxdhzke69wrwed2nk8t9n2e8xudnm8pxcvxs2q5qsyqcyq5y4rdlhtf84f8rgdj34275juwls2ftxtcfh035863q3p9k6s94hpxhdmzfn5gxpsazdznxs56j4vt3fdhe00g9v2l3szher50hp4xlggqkxf77f", + valid: true, + decodedInvoice: func() *Invoice { + return &Invoice{ + Net: &chaincfg.MainNetParams, + MilliSat: &testMillisat20mBTC, + Timestamp: time.Unix(1496314658, 0), + PaymentHash: &testPaymentHash, + Description: &testCupOfCoffee, + Destination: testPubKey, + Features: emptyFeatures, + BlindedPaymentPaths: []*BlindedPaymentPath{ + blindedPath1, + blindedPath2, + }, + } + }, + beforeEncoding: func(i *Invoice) { + // Since this destination pubkey was recovered + // from the signature, we must set it nil before + // encoding to get back the same invoice string. + i.Destination = nil + }, + }, } for i, test := range tests { - var decodedInvoice *Invoice - net := &chaincfg.MainNetParams - if test.decodedInvoice != nil { - decodedInvoice = test.decodedInvoice() - net = decodedInvoice.Net - } + test := test - invoice, err := Decode(test.encodedInvoice, net) - if (err == nil) != test.valid { - t.Errorf("Decoding test %d failed: %v", i, err) - return - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() - if test.valid { - if err := compareInvoices(decodedInvoice, invoice); err != nil { - t.Errorf("Invoice decoding result %d not as expected: %v", i, err) + var decodedInvoice *Invoice + net := &chaincfg.MainNetParams + if test.decodedInvoice != nil { + decodedInvoice = test.decodedInvoice() + net = decodedInvoice.Net + } + + invoice, err := Decode(test.encodedInvoice, net) + if !test.valid { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, decodedInvoice, invoice) + } + + if test.skipEncoding { return } - } - if test.skipEncoding { - continue - } + if test.beforeEncoding != nil { + test.beforeEncoding(decodedInvoice) + } - if test.beforeEncoding != nil { - test.beforeEncoding(decodedInvoice) - } + if decodedInvoice == nil { + return + } - if decodedInvoice != nil { reencoded, err := decodedInvoice.Encode( testMessageSigner, ) - if (err == nil) != test.valid { - t.Errorf("Encoding test %d failed: %v", i, err) + if !test.valid { + require.Error(t, err) return } - - if test.valid && test.encodedInvoice != reencoded { - t.Errorf("Encoding %d failed, expected %v, got %v", - i, test.encodedInvoice, reencoded) - return - } - } + require.NoError(t, err) + require.Equal(t, test.encodedInvoice, reencoded) + }) } } @@ -805,25 +889,42 @@ func TestNewInvoice(t *testing.T) { valid: true, encodedInvoice: "lnbcrt241pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdqqnp4q0n326hr8v9zprg8gsvezcch06gfaqqhde2aj730yg0durunfhv66df5c8pqjjt4z4ymmuaxfx8eh5v7hmzs3wrfas8m2sz5qz56rw2lxy8mmgm4xln0ha26qkw6u3vhu22pss2udugr9g74c3x20slpcqjgq0el4h6", }, + { + // Mainnet invoice with two blinded paths. + newInvoice: func() (*Invoice, error) { + return NewInvoice(&chaincfg.MainNetParams, + testPaymentHash, + time.Unix(1496314658, 0), + Amount(testMillisat20mBTC), + Description(testCupOfCoffee), + WithBlindedPaymentPath(blindedPath1), + WithBlindedPaymentPath(blindedPath2), + ) + }, + valid: true, + //nolint:lll + encodedInvoice: "lnbc20m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4js5fdqqqqq2qqqqqpgqyzqqqqqqqqqqqqyqqqqqqqqqqqvsqqqqlnxy0ffrlt2y2jgtzw89kgr3zg4dlwtlfycn3yuek8x5eucnuchqps82xf0m2u6sx5wnjw7xxgnxz5kf09quqsv5zvkgj7d5kpzttp4qz7q5qsyqcyq5pzq2feycsemrh7wvendc4kw3tsmkt25a36ev6kfehrv7ecfkrp5zs9q5zqxqspqtr4avek5quzjn427asptzews5wrczfhychr2sq6ue9phmn35tjqcrspqgpsgpgxquyqjzstpsxsu59zqqqqqpqqqqqqyqq2qqqqqqqqqqqqqqqqqqqqqqqqpgqqqqk8t6endgpc99824amqzk9japgu8synwf3wx4qp4ej2r0h8rghypsqsygpf8ynzr8vwleenxdhzke69wrwed2nk8t9n2e8xudnm8pxcvxs2q5qsyqcyq5y4rdlhtf84f8rgdj34275juwls2ftxtcfh035863q3p9k6s94hpxhdmzfn5gxpsazdznxs56j4vt3fdhe00g9v2l3szher50hp4xlggqkxf77f", + }, } for i, test := range tests { + test := test - invoice, err := test.newInvoice() - if err != nil && !test.valid { - continue - } - encoded, err := invoice.Encode(testMessageSigner) - if (err == nil) != test.valid { - t.Errorf("NewInvoice test %d failed: %v", i, err) - return - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() - if test.valid && test.encodedInvoice != encoded { - t.Errorf("Encoding %d failed, expected %v, got %v", - i, test.encodedInvoice, encoded) - return - } + invoice, err := test.newInvoice() + if !test.valid { + require.Error(t, err) + return + } + require.NoError(t, err) + + encoded, err := invoice.Encode(testMessageSigner) + require.NoError(t, err) + + require.Equal(t, test.encodedInvoice, encoded) + }) } } @@ -909,73 +1010,6 @@ func TestInvoiceChecksumMalleability(t *testing.T) { } } -func compareInvoices(expected, actual *Invoice) error { - if !reflect.DeepEqual(expected.Net, actual.Net) { - return fmt.Errorf("expected net %v, got %v", - expected.Net, actual.Net) - } - - if !reflect.DeepEqual(expected.MilliSat, actual.MilliSat) { - return fmt.Errorf("expected milli sat %d, got %d", - *expected.MilliSat, *actual.MilliSat) - } - - if expected.Timestamp != actual.Timestamp { - return fmt.Errorf("expected timestamp %v, got %v", - expected.Timestamp, actual.Timestamp) - } - - if !compareHashes(expected.PaymentHash, actual.PaymentHash) { - return fmt.Errorf("expected payment hash %x, got %x", - *expected.PaymentHash, *actual.PaymentHash) - } - - if !reflect.DeepEqual(expected.Description, actual.Description) { - return fmt.Errorf("expected description \"%s\", got \"%s\"", - *expected.Description, *actual.Description) - } - - if !comparePubkeys(expected.Destination, actual.Destination) { - return fmt.Errorf("expected destination pubkey %x, got %x", - expected.Destination.SerializeCompressed(), - actual.Destination.SerializeCompressed()) - } - - if !compareHashes(expected.DescriptionHash, actual.DescriptionHash) { - return fmt.Errorf("expected description hash %x, got %x", - *expected.DescriptionHash, *actual.DescriptionHash) - } - - if expected.Expiry() != actual.Expiry() { - return fmt.Errorf("expected expiry %d, got %d", - expected.Expiry(), actual.Expiry()) - } - - if !reflect.DeepEqual(expected.FallbackAddr, actual.FallbackAddr) { - return fmt.Errorf("expected FallbackAddr %v, got %v", - expected.FallbackAddr, actual.FallbackAddr) - } - - if len(expected.RouteHints) != len(actual.RouteHints) { - return fmt.Errorf("expected %d RouteHints, got %d", - len(expected.RouteHints), len(actual.RouteHints)) - } - - for i, routeHint := range expected.RouteHints { - err := compareRouteHints(routeHint, actual.RouteHints[i]) - if err != nil { - return err - } - } - - if !reflect.DeepEqual(expected.Features, actual.Features) { - return fmt.Errorf("expected features %v, got %v", - expected.Features, actual.Features) - } - - return nil -} - func comparePubkeys(a, b *btcec.PublicKey) bool { if a == b { return true