lnd/invoices/test_utils_test.go

421 lines
10 KiB
Go

package invoices_test
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
"os"
"runtime/pprof"
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
// mockPayload is a test double exposing the payload accessors MultiPath,
// AMPRecord, CustomRecords and Metadata. Each accessor hands back the
// corresponding field, so tests configure behavior by setting fields.
type mockPayload struct {
	mpp           *record.MPP
	amp           *record.AMP
	customRecords record.CustomSet
	metadata      []byte
}

// MultiPath returns the configured MPP record (nil when unset).
func (p *mockPayload) MultiPath() *record.MPP {
	return p.mpp
}

// AMPRecord returns the configured AMP record (nil when unset).
func (p *mockPayload) AMPRecord() *record.AMP {
	return p.amp
}

// CustomRecords returns the configured custom records. The real accessor is
// expected to always yield a map, so an unset (nil) field is replaced by a
// freshly allocated empty set rather than returned as nil.
func (p *mockPayload) CustomRecords() record.CustomSet {
	if records := p.customRecords; records != nil {
		return records
	}

	return make(record.CustomSet)
}

// Metadata returns the configured metadata bytes (nil when unset).
func (p *mockPayload) Metadata() []byte {
	return p.metadata
}
// mockChainNotifier is a chainntnfs.ChainNotifier stand-in. It embeds the
// interface only to satisfy it; the embedded value is left nil, so every
// method other than RegisterBlockEpochNtfn panics if called.
type mockChainNotifier struct {
	chainntnfs.ChainNotifier

	// blockChan is handed out to block epoch subscribers; tests push
	// epochs into it to simulate new blocks.
	blockChan chan *chainntnfs.BlockEpoch
}

// newMockNotifier returns a mockChainNotifier whose block channel is
// unbuffered, so each send blocks until a subscriber receives it.
func newMockNotifier() *mockChainNotifier {
	n := &mockChainNotifier{}
	n.blockChan = make(chan *chainntnfs.BlockEpoch)

	return n
}

// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's
// block channel to deliver blocks to the client. The argument is ignored,
// every registration shares the same channel, and Cancel does nothing.
func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
	*chainntnfs.BlockEpochEvent, error) {

	event := &chainntnfs.BlockEpochEvent{
		Epochs: m.blockChan,
		Cancel: func() {},
	}

	return event, nil
}
// Height and CLTV-related constants shared by the registry tests.
const (
	// testHtlcExpiry is an HTLC expiry height used by tests.
	testHtlcExpiry uint32 = 5

	// testInvoiceCltvDelta is a CLTV delta value for test invoices.
	testInvoiceCltvDelta uint32 = 4

	// testFinalCltvRejectDelta is used as the registry's
	// FinalCltvRejectDelta in defaultRegistryConfig.
	testFinalCltvRejectDelta int32 = 4

	// testCurrentHeight is the starting height handed to the invoice
	// expiry watcher in newTestContext.
	testCurrentHeight int32 = 1
)
var (
	// testTimeout bounds how long the helpers in this file wait on channel
	// operations before failing the test.
	testTimeout = 5 * time.Second

	// testTime is the fixed instant the test clock starts from; it is
	// also the default invoice creation date.
	testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)

	testInvoicePreimage    = lntypes.Preimage{1}
	testInvoicePaymentHash = testInvoicePreimage.Hash()

	// The decode error is deliberately ignored: the input is a constant,
	// well-formed 32-byte hex string.
	testPrivKeyBytes, _ = hex.DecodeString(
		"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2d" +
			"b734",
	)

	// Only the first return value (the private key) is used; the second
	// is discarded.
	testPrivKey, _ = btcec.PrivKeyFromBytes(testPrivKeyBytes)

	testInvoiceDescription = "coffee"

	testInvoiceAmount = lnwire.MilliSatoshi(100000)

	testNetParams = &chaincfg.MainNetParams

	// testMessageSigner produces compact signatures with testPrivKey over
	// chainhash.HashB(msg); it is used to encode test payment requests.
	testMessageSigner = zpay32.MessageSigner{
		SignCompact: func(msg []byte) ([]byte, error) {
			digest := chainhash.HashB(msg)

			sig, err := ecdsa.SignCompact(testPrivKey, digest, true)
			if err != nil {
				// Wrap with %w so callers can inspect the
				// underlying signing error.
				return nil, fmt.Errorf("can't sign the "+
					"message: %w", err)
			}

			return sig, nil
		},
	}

	testFeatures = lnwire.NewFeatureVector(
		nil, lnwire.Features,
	)

	testPayload = &mockPayload{}

	testInvoiceCreationDate = testTime
)
// newTestChannelDB opens a fresh channeldb instance inside a per-test
// temporary directory, configured with the supplied clock. The database is
// closed via t.Cleanup once the test finishes; any error returned by Close
// during that teardown is ignored.
func newTestChannelDB(t *testing.T, clock clock.Clock) (*channeldb.DB, error) {
	t.Helper()

	dbDir := t.TempDir()

	db, err := channeldb.Open(dbDir, channeldb.OptionClock(clock))
	if err != nil {
		return nil, err
	}

	t.Cleanup(func() {
		_ = db.Close()
	})

	return db, nil
}
// testContext bundles the objects built by newTestContext so tests can drive
// the invoice registry and inspect its dependencies.
type testContext struct {
// idb is the channeldb instance backing the registry.
idb *channeldb.DB
// registry is the started invoice registry under test.
registry *invpkg.InvoiceRegistry
// notifier is the mock chain notifier given to the expiry watcher.
notifier *mockChainNotifier
// clock is the test clock shared by the database and the registry.
clock *clock.TestClock
// t is the test that owns this context.
t *testing.T
}
func defaultRegistryConfig() invpkg.RegistryConfig {
return invpkg.RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
}
}
// newTestContext builds a complete registry test harness:
//   - a channeldb in a temporary directory,
//   - a mock chain notifier,
//   - a test clock starting at testTime,
//   - an invoice expiry watcher,
//   - a started InvoiceRegistry.
//
// If registryCfg is nil, defaultRegistryConfig is used. The config is copied
// before its Clock is overridden, so the caller's struct is never mutated.
// The registry is stopped automatically during test cleanup.
func newTestContext(t *testing.T,
	registryCfg *invpkg.RegistryConfig) *testContext {

	t.Helper()

	// Named testClock to avoid shadowing the imported clock package.
	testClock := clock.NewTestClock(testTime)

	idb, err := newTestChannelDB(t, testClock)
	require.NoError(t, err)

	notifier := newMockNotifier()

	expiryWatcher := invpkg.NewInvoiceExpiryWatcher(
		testClock, 0, uint32(testCurrentHeight), nil, notifier,
	)

	cfg := defaultRegistryConfig()
	if registryCfg != nil {
		cfg = *registryCfg
	}
	cfg.Clock = testClock

	// Instantiate and start the invoice registry.
	registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg)
	require.NoError(t, registry.Start())

	// Cleanup functions run in LIFO order, so the registry is stopped
	// before the database registered in newTestChannelDB is closed.
	t.Cleanup(func() {
		require.NoError(t, registry.Stop())
	})

	return &testContext{
		idb:      idb,
		registry: registry,
		notifier: notifier,
		clock:    testClock,
		t:        t,
	}
}
// getCircuitKey builds a CircuitKey for the given HTLC ID on a fixed test
// channel whose short channel ID is block height 1, tx index 2, tx position 3.
func getCircuitKey(htlcID uint64) invpkg.CircuitKey {
	chanID := lnwire.ShortChannelID{
		BlockHeight: 1,
		TxIndex:     2,
		TxPosition:  3,
	}

	return invpkg.CircuitKey{
		ChanID: chanID,
		HtlcID: htlcID,
	}
}
// newInvoice returns an invoice that can be used for testing, using the
// constant values defined above (deep copied if necessary).
//
// Note that this invoice *does not* have a payment address set. It will
// create a regular invoice with a preimage if hodl is false, and a hodl
// invoice with no preimage otherwise.
func newInvoice(t *testing.T, hodl bool) *invpkg.Invoice {
	t.Helper()

	invoice := &invpkg.Invoice{
		Terms: invpkg.ContractTerm{
			Value:  testInvoiceAmount,
			Expiry: time.Hour,
			// Clone so tests can't mutate the shared feature vector.
			Features: testFeatures.Clone(),
		},
		CreationDate: testInvoiceCreationDate,
	}

	// If creating a hodl invoice, we don't include a preimage.
	if hodl {
		invoice.HodlInvoice = true
		return invoice
	}

	// Build a distinct preimage value so the invoice doesn't point at the
	// package-level testInvoicePreimage.
	preimage, err := lntypes.MakePreimage(
		testInvoicePreimage[:],
	)
	require.NoError(t, err)
	invoice.Terms.PaymentPreimage = &preimage

	return invoice
}
// timeout implements a test level timeout. It starts a watchdog goroutine
// that, unless the returned function is called within testTimeout, writes
// the stacks of all goroutines to stdout and then panics. The panic happens
// outside the test goroutine, so it aborts the whole test binary.
//
// The returned function cancels the watchdog and must be called at most
// once: it closes an internal channel, and a second close panics.
func timeout() func() {
	done := make(chan struct{})

	go func() {
		// A stoppable timer (rather than time.After) is released as
		// soon as the watchdog is cancelled.
		timer := time.NewTimer(testTimeout)
		defer timer.Stop()

		select {
		case <-timer.C:
			// Dump every goroutine's stack first so the hang can be
			// diagnosed from the test output.
			err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
			if err != nil {
				panic(fmt.Sprintf("error writing to std out "+
					"after timeout: %v", err))
			}

			panic("timeout")

		case <-done:
		}
	}()

	return func() {
		close(done)
	}
}
// invoiceExpiryTestData simply holds generated expired and pending invoices.
// Both maps are keyed by the hash of the preimage each invoice was built
// with.
type invoiceExpiryTestData struct {
// expiredInvoices holds invoices whose expiry lies before the anchor time.
expiredInvoices map[lntypes.Hash]*invpkg.Invoice
// pendingInvoices holds invoices created at the anchor time.
pendingInvoices map[lntypes.Hash]*invpkg.Invoice
}
// generateInvoiceExpiryTestData generates the specified number of fake expired
// and pending invoices anchored to the passed now timestamp.
//
// Expired invoices are created 24 hours before now with an expiry of
// ((i+offset) % 24) hours. A zero result becomes one hour inside
// newInvoiceExpiryTestInvoice, so every such expiry is at most 23 hours and
// has already passed at now.
//
// Pending invoices are created at now with the same expiry scheme, so they
// are still open at now.
//
// Preimages are derived from offset+i. Expired invoices write that counter
// into bytes [0:4] and pending invoices into bytes [4:8], so the two sets
// produce distinct hashes. Using different offsets across calls keeps
// batches apart.
func generateInvoiceExpiryTestData(
	t *testing.T, now time.Time,
	offset, numExpired, numPending int) invoiceExpiryTestData {

	t.Helper()

	testData := invoiceExpiryTestData{
		expiredInvoices: make(map[lntypes.Hash]*invpkg.Invoice),
		pendingInvoices: make(map[lntypes.Hash]*invpkg.Invoice),
	}

	expiredCreationDate := now.Add(-24 * time.Hour)

	for i := 1; i <= numExpired; i++ {
		var preimage lntypes.Preimage
		binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
		expiry := time.Duration((i+offset)%24) * time.Hour

		invoice := newInvoiceExpiryTestInvoice(
			t, preimage, expiredCreationDate, expiry,
		)
		testData.expiredInvoices[preimage.Hash()] = invoice
	}

	for i := 1; i <= numPending; i++ {
		var preimage lntypes.Preimage
		binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
		expiry := time.Duration((i+offset)%24) * time.Hour

		invoice := newInvoiceExpiryTestInvoice(t, preimage, now, expiry)
		testData.pendingInvoices[preimage.Hash()] = invoice
	}

	return testData
}
// newInvoiceExpiryTestInvoice creates a test invoice with a randomly generated
// payment address and custom preimage and expiry details. It should be used in
// the case where tests require custom invoice expiry and unique payment
// hashes.
//
// A zero expiry is replaced with one hour. The invoice carries a payment
// request encoded for testNetParams and signed with testMessageSigner. It
// also gets its own copy of testFeatures, so mutating one invoice's features
// cannot affect other invoices or the shared vector.
func newInvoiceExpiryTestInvoice(t *testing.T, preimage lntypes.Preimage,
	timestamp time.Time, expiry time.Duration) *invpkg.Invoice {

	t.Helper()

	if expiry == 0 {
		expiry = time.Hour
	}

	var payAddr [32]byte
	_, err := rand.Read(payAddr[:])
	require.NoError(t, err, "unable to generate payment addr")

	rawInvoice, err := zpay32.NewInvoice(
		testNetParams,
		preimage.Hash(),
		timestamp,
		zpay32.Amount(testInvoiceAmount),
		zpay32.Description(testInvoiceDescription),
		zpay32.Expiry(expiry),
		zpay32.PaymentAddr(payAddr),
	)
	require.NoError(t, err, "Error while creating new invoice")

	paymentRequest, err := rawInvoice.Encode(testMessageSigner)
	require.NoError(t, err, "Error while encoding payment request")

	return &invpkg.Invoice{
		Terms: invpkg.ContractTerm{
			PaymentPreimage: &preimage,
			PaymentAddr:     payAddr,
			Value:           testInvoiceAmount,
			Expiry:          expiry,
			// Clone (as newInvoice does) instead of sharing the
			// package-level vector across every generated invoice.
			Features: testFeatures.Clone(),
		},
		PaymentRequest: []byte(paymentRequest),
		CreationDate:   timestamp,
	}
}
// checkSettleResolution asserts the resolution is a settle with the correct
// preimage. If successful, the HtlcSettleResolution is returned in case further
// checks are desired. If res has any other dynamic type, the test fails
// immediately and the failure message names the actual type.
func checkSettleResolution(t *testing.T, res invpkg.HtlcResolution,
	expPreimage lntypes.Preimage) *invpkg.HtlcSettleResolution {

	t.Helper()

	settleResolution, ok := res.(*invpkg.HtlcSettleResolution)
	require.Truef(t, ok, "expected *invpkg.HtlcSettleResolution, got %T",
		res)
	require.Equal(t, expPreimage, settleResolution.Preimage)

	return settleResolution
}
// checkFailResolution asserts the resolution is a fail with the correct reason.
// If successful, the HtlcFailResolution is returned in case further checks are
// desired. If res has any other dynamic type, the test fails immediately and
// the failure message names the actual type.
func checkFailResolution(t *testing.T, res invpkg.HtlcResolution,
	expOutcome invpkg.FailResolutionResult) *invpkg.HtlcFailResolution {

	t.Helper()

	failResolution, ok := res.(*invpkg.HtlcFailResolution)
	require.Truef(t, ok, "expected *invpkg.HtlcFailResolution, got %T",
		res)
	require.Equal(t, expOutcome, failResolution.Outcome)

	return failResolution
}
// hodlExpiryTest groups the fixtures for an expiry-watcher test; judging by
// the name, presumably for hodl invoices (NOTE(review): confirm against the
// tests that construct it).
type hodlExpiryTest struct {
// hash is presumably the payment hash of the invoice under test.
hash lntypes.Hash
// state holds an invoice contract state. NOTE(review): assumed to be
// guarded by stateLock, which sits next to it — confirm at the write site.
state invpkg.ContractState
stateLock sync.Mutex
// mockNotifier's block channel is where announceBlock sends epochs.
mockNotifier *mockChainNotifier
mockClock *clock.TestClock
// cancelChan delivers invoice hashes that assertCanceled reads and checks.
cancelChan chan lntypes.Hash
watcher *invpkg.InvoiceExpiryWatcher
}
// announceBlock pushes a block epoch at the given height onto the mock
// notifier's block channel. The test fails if no receiver takes the epoch
// within testTimeout.
func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) {
	t.Helper()

	epoch := &chainntnfs.BlockEpoch{Height: int32(height)}

	select {
	case h.mockNotifier.blockChan <- epoch:

	case <-time.After(testTimeout):
		t.Fatalf("block %v not consumed", height)
	}
}
// assertCanceled waits up to testTimeout for a hash to arrive on cancelChan
// and asserts that it equals expected. The test fails if nothing arrives in
// time. The timeout message names the expected hash (the one the caller
// waited for), not whatever hash the fixture happens to store.
func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) {
	t.Helper()

	select {
	case actual := <-h.cancelChan:
		require.Equal(t, expected, actual)

	case <-time.After(testTimeout):
		t.Fatalf("invoice: %v not canceled", expected)
	}
}