// lnd/invoices/invoice_expiry_watcher_test.go
//
// 319 lines, 8.5 KiB, Go

package invoices
import (
"sync"
"testing"
"time"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
)
// invoiceExpiryWatcherTest bundles the state shared by the
// InvoiceExpiryWatcher tests and carries the helper methods used to wait for
// and verify cancellations.
type invoiceExpiryWatcherTest struct {
	// t is the test handle the helper methods use to report failures.
	t *testing.T

	// wg counts the cancellations that are still outstanding. It is
	// incremented once per expected cancellation and decremented by the
	// cancel callback.
	wg sync.WaitGroup

	// watcher is the InvoiceExpiryWatcher under test.
	watcher *InvoiceExpiryWatcher

	// testData holds the generated expired and pending invoices.
	testData invoiceExpiryTestData

	// canceledInvoices records the payment hash of every invoice passed
	// to the cancel callback, in call order.
	canceledInvoices []lntypes.Hash
}
// mockChainNotifier is a chain notifier stand-in for tests. It only provides
// its own implementation of RegisterBlockEpochNtfn; block epochs are delivered
// through blockChan.
type mockChainNotifier struct {
	// ChainNotifier is embedded so the remaining notifier methods are
	// present on the mock. newMockNotifier leaves it at its zero value.
	chainntnfs.ChainNotifier

	// blockChan is the channel handed to clients as their block epoch
	// stream.
	blockChan chan *chainntnfs.BlockEpoch
}
// newMockNotifier returns a mockChainNotifier with a freshly created,
// unbuffered block channel. The embedded ChainNotifier is left unset.
func newMockNotifier() *mockChainNotifier {
	notifier := new(mockChainNotifier)
	notifier.blockChan = make(chan *chainntnfs.BlockEpoch)

	return notifier
}
// RegisterBlockEpochNtfn returns a block epoch subscription backed by the
// mock's block channel, so anything sent on that channel is delivered to the
// subscriber. The argument is ignored and the returned error is always nil.
func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
	*chainntnfs.BlockEpochEvent, error) {

	// Cancelling the mock subscription is a no-op.
	event := &chainntnfs.BlockEpochEvent{
		Cancel: func() {},
		Epochs: m.blockChan,
	}

	return event, nil
}
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
// and starts the watcher.
//
// The fixture holds test data generated relative to now, containing
// numExpiredInvoices expired and numPendingInvoices pending invoices. Exactly
// numExpiredInvoices cancel callbacks are expected: the fixture's wait group
// is primed with that count and each callback invocation records the payment
// hash and marks one cancellation as done.
//
// The test is failed immediately if the watcher cannot be started.
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
	numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {

	mockNotifier := newMockNotifier()

	test := &invoiceExpiryWatcherTest{
		// The helper methods report failures through this handle, so
		// it must be set; leaving it nil turns every failure into a
		// nil pointer dereference.
		t: t,
		// NOTE(review): the watcher clock is pinned to testTime while
		// the test data is generated relative to now. All callers in
		// this file pass testTime - confirm before passing another
		// value.
		watcher: NewInvoiceExpiryWatcher(
			clock.NewTestClock(testTime), 0,
			uint32(testCurrentHeight), nil, mockNotifier,
		),
		testData: generateInvoiceExpiryTestData(
			t, now, 0, numExpiredInvoices, numPendingInvoices,
		),
	}

	// One cancel callback is expected per expired invoice.
	test.wg.Add(numExpiredInvoices)

	err := test.watcher.Start(func(paymentHash lntypes.Hash,
		force bool) error {

		test.canceledInvoices = append(
			test.canceledInvoices, paymentHash,
		)
		test.wg.Done()

		return nil
	})
	if err != nil {
		t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
	}

	return test
}
// waitForFinish blocks until every expected cancellation has been signalled
// on the fixture's wait group, and fails the test if that does not happen
// within timeout.
func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) {
	allCanceled := make(chan struct{})

	// Translate the wait group into a channel so it can be combined with
	// the timeout below.
	go func() {
		defer close(allCanceled)
		t.wg.Wait()
	}()

	select {
	case <-time.After(timeout):
		t.t.Fatalf("test timeout")

	case <-allCanceled:
	}
}
// checkExpectations verifies that the number of recorded cancellations equals
// the number of expired invoices in the test data, and that every canceled
// payment hash belongs to one of those expired invoices.
func (t *invoiceExpiryWatcherTest) checkExpectations() {
	expired := t.testData.expiredInvoices

	want, got := len(expired), len(t.canceledInvoices)
	if got != want {
		t.t.Fatalf("expected %v cancellations, got %v", want, got)
	}

	for _, hash := range t.canceledInvoices {
		if _, found := expired[hash]; found {
			continue
		}

		t.t.Fatalf("wrong invoice canceled")
	}
}
// TestInvoiceExpiryWatcherStartStop tests that InvoiceExpiryWatcher can be
// started and stopped: a second Start while running must return an error, and
// Start must succeed again after Stop.
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
	watcher := NewInvoiceExpiryWatcher(
		clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil,
		newMockNotifier(),
	)

	// No invoices are added, so the cancel callback must never fire. The
	// callback is invoked by the watcher rather than by this function
	// (presumably on the watcher's own goroutine), so report with Errorf:
	// Fatalf may only be called from the goroutine running the test.
	cancel := func(lntypes.Hash, bool) error {
		t.Errorf("unexpected call")
		return nil
	}

	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}

	// Starting an already running watcher must be rejected.
	if err := watcher.Start(cancel); err == nil {
		t.Fatalf("expected error upon second start")
	}

	watcher.Stop()

	// After a stop the watcher must be startable again.
	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}

	// Shut the restarted watcher down so it does not outlive the test.
	watcher.Stop()
}
// TestInvoiceExpiryWithNoInvoices checks that a watcher which never receives
// any invoices does not cancel anything.
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
	t.Parallel()

	// Neither expired nor pending invoices are generated.
	fixture := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()

	fixture.checkExpectations()
}
// TestInvoiceExpiryWithOnlyExpiredInvoices tests that if all added invoices
// are expired, then all of them will be canceled.
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
	t.Parallel()

	// Generate five expired invoices and no pending ones, so the fixture
	// expects exactly five cancellations.
	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 0)

	for paymentHash, invoice := range test.testData.expiredInvoices {
		test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice))
	}

	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}
// TestInvoiceExpiryWithPendingAndExpiredInvoices checks that when a mix of
// expired and pending invoices is added one by one, exactly the expired ones
// are canceled.
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
	t.Parallel()

	fixture := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
	data := fixture.testData

	// Feed the expired invoices first, then the pending ones, each with
	// its own AddInvoices call.
	for hash, inv := range data.expiredInvoices {
		entry := makeInvoiceExpiry(hash, inv)
		fixture.watcher.AddInvoices(entry)
	}

	for hash, inv := range data.pendingInvoices {
		entry := makeInvoiceExpiry(hash, inv)
		fixture.watcher.AddInvoices(entry)
	}

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()

	fixture.checkExpectations()
}
// TestInvoiceExpiryWhenAddingMultipleInvoices checks that expired invoices are
// canceled when expired and pending invoices are handed to the watcher in a
// single AddInvoices call.
func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
	t.Parallel()

	fixture := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
	data := fixture.testData

	// Collect every invoice, expired ones first, into one batch.
	var batch []invoiceExpiry
	for paymentHash, inv := range data.expiredInvoices {
		batch = append(batch, makeInvoiceExpiry(paymentHash, inv))
	}
	for paymentHash, inv := range data.pendingInvoices {
		batch = append(batch, makeInvoiceExpiry(paymentHash, inv))
	}

	fixture.watcher.AddInvoices(batch...)

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()

	fixture.checkExpectations()
}
// TestExpiredHodlInv tests expiration of an already-expired hodl invoice
// which has no htlcs.
func TestExpiredHodlInv(t *testing.T) {
	t.Parallel()

	// The invoice was created a day ago with a one hour expiry, so it is
	// already expired relative to testTime.
	creationDate := testTime.Add(time.Hour * -24)
	expiry := time.Hour

	test := setupHodlExpiry(
		t, creationDate, expiry, 0, channeldb.ContractOpen, nil,
	)

	// Defer the stop so the watcher is shut down even when the assertion
	// below fails the test.
	defer test.watcher.Stop()

	test.assertCanceled(t, test.hash)
}
// TestAcceptedHodlNotExpired checks that a hodl invoice in the accepted state
// is not expired when its time-based expiry elapses. A regular invoice with
// the same expiry time serves as a control, showing that an invoice with that
// timestamp would otherwise be expired.
func TestAcceptedHodlNotExpired(t *testing.T) {
	t.Parallel()

	var (
		creationDate = testTime
		expiry       = time.Hour
	)

	fixture := setupHodlExpiry(
		t, creationDate, expiry, 0, channeldb.ContractAccepted, nil,
	)
	defer fixture.watcher.Stop()

	// The control invoice expires at exactly the same moment as the hodl
	// invoice.
	control := &invoiceExpiryTs{
		PaymentHash: lntypes.Hash{1, 2, 3},
		Expiry:      creationDate.Add(expiry),
		Keysend:     true,
	}
	fixture.watcher.AddInvoices(control)

	// Move the clock just past the shared expiry time.
	fixture.mockClock.SetTime(creationDate.Add(expiry + 1))

	// Only the control invoice is expected to be canceled.
	fixture.assertCanceled(t, control.PaymentHash)
}
// TestHeightAlreadyExpired covers adding an invoice to the expiry watcher
// whose htlcs have already reached their expiry height.
func TestHeightAlreadyExpired(t *testing.T) {
	t.Parallel()

	// A single accepted htlc expiring at the current test height.
	htlcs := []*channeldb.InvoiceHTLC{{
		State:  channeldb.HtlcStateAccepted,
		Expiry: uint32(testCurrentHeight),
	}}

	fixture := setupHodlExpiry(
		t, testTime, time.Hour, 0, channeldb.ContractAccepted, htlcs,
	)
	defer fixture.watcher.Stop()

	fixture.assertCanceled(t, fixture.hash)
}
// TestExpiryHeightArrives covers a hodl invoice that is added to the expiry
// watcher without htlcs, then receives htlcs, and is finally canceled once
// the lowest htlc expiry height (less the delta) is announced. A non-zero
// delta is used to check that expiry happens with sufficient buffer.
func TestExpiryHeightArrives(t *testing.T) {
	var delta uint32 = 1

	creationDate := testTime
	expiry := time.Hour * 2

	// Begin with an open hodl invoice that has no htlcs yet.
	fixture := setupHodlExpiry(
		t, creationDate, expiry, delta, channeldb.ContractOpen, nil,
	)
	defer fixture.watcher.Stop()

	// Register a first htlc expiry height for the invoice, then move the
	// invoice to the accepted state.
	firstHtlcHeight := uint32(testCurrentHeight + 10)
	firstExpiry := makeHeightExpiry(fixture.hash, firstHtlcHeight)

	fixture.watcher.AddInvoices(firstExpiry)
	fixture.setState(channeldb.ContractAccepted)

	// Advance the clock to the time-based expiry. Because the invoice has
	// been accepted, no cancellation is expected from this.
	fixture.mockClock.SetTime(creationDate.Add(expiry))

	// Deliver the next block on the mock subscription; nothing is
	// expected to happen yet.
	height := uint32(testCurrentHeight + 1)
	fixture.announceBlock(t, height)

	// Register a second htlc whose expiry height is lower than the first
	// one.
	secondHtlcHeight := height + 5
	secondExpiry := makeHeightExpiry(fixture.hash, secondHtlcHeight)
	fixture.watcher.AddInvoices(secondExpiry)

	// Announcing the lowest htlc expiry height minus the delta should now
	// cause the invoice to be canceled.
	fixture.announceBlock(t, secondHtlcHeight-delta)
	fixture.assertCanceled(t, fixture.hash)
}