diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 8940063bc..e7e28b500 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -1,6 +1,7 @@ package invoices import ( + "sync" "testing" "time" @@ -13,6 +14,7 @@ import ( // for InvoiceExpiryWatcher tests. type invoiceExpiryWatcherTest struct { t *testing.T + wg sync.WaitGroup watcher *InvoiceExpiryWatcher testData invoiceExpiryTestData canceledInvoices []lntypes.Hash @@ -30,8 +32,11 @@ func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, ), } + test.wg.Add(numExpiredInvoices) + err := test.watcher.Start(func(paymentHash lntypes.Hash) error { test.canceledInvoices = append(test.canceledInvoices, paymentHash) + test.wg.Done() return nil }) @@ -42,6 +47,22 @@ func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, return test } +func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) { + done := make(chan struct{}) + + // Wait for all cancels. + go func() { + t.wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + t.t.Fatalf("test timeout") + } +} + func (t *invoiceExpiryWatcherTest) checkExpectations() { // Check that invoices that got canceled during the test are the ones // that expired. @@ -83,9 +104,10 @@ func TestInvoiceExpiryWatcherStartStop(t *testing.T) { // Tests that no invoices will expire from an empty InvoiceExpiryWatcher. func TestInvoiceExpiryWithNoInvoices(t *testing.T) { t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0) - time.Sleep(testTimeout) + test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } @@ -101,7 +123,7 @@ func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { test.watcher.AddInvoice(paymentHash, invoice) } - time.Sleep(testTimeout) + test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } @@ -110,6 +132,7 @@ func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { // will be canceled. func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) for paymentHash, invoice := range test.testData.expiredInvoices { @@ -120,7 +143,7 @@ func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { test.watcher.AddInvoice(paymentHash, invoice) } - time.Sleep(testTimeout) + test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } @@ -128,8 +151,10 @@ func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { // Tests adding multiple invoices at once. func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) var invoices []channeldb.InvoiceWithPaymentHash + for hash, invoice := range test.testData.expiredInvoices { invoices = append(invoices, channeldb.InvoiceWithPaymentHash{ @@ -138,6 +163,7 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { }, ) } + for hash, invoice := range test.testData.pendingInvoices { invoices = append(invoices, channeldb.InvoiceWithPaymentHash{ @@ -148,7 +174,7 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { } test.watcher.AddInvoices(invoices) - time.Sleep(testTimeout) + test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() }