multi: replace defer cleanup with t.Cleanup

Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
This commit is contained in:
Eng Zer Jun 2022-08-27 15:04:55 +08:00
parent 5c5997935d
commit c70e39cd21
No known key found for this signature in database
GPG Key ID: DAEBBD2E34C111E6
29 changed files with 393 additions and 606 deletions

View File

@ -346,9 +346,8 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
func TestOpenChannelPutGetDelete(t *testing.T) { func TestOpenChannelPutGetDelete(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -487,11 +486,10 @@ func TestOptionalShutdown(t *testing.T) {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -572,9 +570,8 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
func TestChannelStateTransition(t *testing.T) { func TestChannelStateTransition(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -889,9 +886,8 @@ func TestChannelStateTransition(t *testing.T) {
func TestFetchPendingChannels(t *testing.T) { func TestFetchPendingChannels(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -960,9 +956,8 @@ func TestFetchPendingChannels(t *testing.T) {
func TestFetchClosedChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -1041,9 +1036,8 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// We'll start by creating two channels within our test database. One of // We'll start by creating two channels within our test database. One of
// them will have their funding transaction confirmed on-chain, while // them will have their funding transaction confirmed on-chain, while
// the other one will remain unconfirmed. // the other one will remain unconfirmed.
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -1154,9 +1148,8 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
func TestRefresh(t *testing.T) { func TestRefresh(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -1298,12 +1291,11 @@ func TestCloseInitiator(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", t.Fatalf("unable to make test database: %v",
err) err)
} }
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -1345,12 +1337,11 @@ func TestCloseInitiator(t *testing.T) {
// TestCloseChannelStatus tests setting of a channel status on the historical // TestCloseChannelStatus tests setting of a channel status on the historical
// channel on channel close. // channel on channel close.
func TestCloseChannelStatus(t *testing.T) { func TestCloseChannelStatus(t *testing.T) {
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", t.Fatalf("unable to make test database: %v",
err) err)
} }
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"testing"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
@ -1655,33 +1655,28 @@ func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) (
// MakeTestDB creates a new instance of the ChannelDB for testing purposes. // MakeTestDB creates a new instance of the ChannelDB for testing purposes.
// A callback which cleans up the created temporary directories is also // A callback which cleans up the created temporary directories is also
// returned and intended to be executed after the test completes. // returned and intended to be executed after the test completes.
func MakeTestDB(modifiers ...OptionModifier) (*DB, func(), error) { func MakeTestDB(t *testing.T, modifiers ...OptionModifier) (*DB, error) {
// First, create a temporary directory to be used for the duration of // First, create a temporary directory to be used for the duration of
// this test. // this test.
tempDirName, err := ioutil.TempDir("", "channeldb") tempDirName := t.TempDir()
if err != nil {
return nil, nil, err
}
// Next, create channeldb for the first time. // Next, create channeldb for the first time.
backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb") backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb")
if err != nil { if err != nil {
backendCleanup() backendCleanup()
return nil, nil, err return nil, err
} }
cdb, err := CreateWithBackend(backend, modifiers...) cdb, err := CreateWithBackend(backend, modifiers...)
if err != nil { if err != nil {
backendCleanup() backendCleanup()
os.RemoveAll(tempDirName) return nil, err
return nil, nil, err
} }
cleanUp := func() { t.Cleanup(func() {
cdb.Close() cdb.Close()
backendCleanup() backendCleanup()
os.RemoveAll(tempDirName) })
}
return cdb, cleanUp, nil return cdb, nil
} }

View File

@ -36,7 +36,7 @@ func TestOpenWithCreate(t *testing.T) {
dbPath := filepath.Join(tempDirName, "cdb") dbPath := filepath.Join(tempDirName, "cdb")
backend, cleanup, err := kvdb.GetTestBackend(dbPath, "cdb") backend, cleanup, err := kvdb.GetTestBackend(dbPath, "cdb")
require.NoError(t, err, "unable to get test db backend") require.NoError(t, err, "unable to get test db backend")
defer cleanup() t.Cleanup(cleanup)
cdb, err := CreateWithBackend(backend) cdb, err := CreateWithBackend(backend)
require.NoError(t, err, "unable to create channeldb") require.NoError(t, err, "unable to create channeldb")
@ -72,7 +72,7 @@ func TestWipe(t *testing.T) {
dbPath := filepath.Join(tempDirName, "cdb") dbPath := filepath.Join(tempDirName, "cdb")
backend, cleanup, err := kvdb.GetTestBackend(dbPath, "cdb") backend, cleanup, err := kvdb.GetTestBackend(dbPath, "cdb")
require.NoError(t, err, "unable to get test db backend") require.NoError(t, err, "unable to get test db backend")
defer cleanup() t.Cleanup(cleanup)
fullDB, err := CreateWithBackend(backend) fullDB, err := CreateWithBackend(backend)
require.NoError(t, err, "unable to create channeldb") require.NoError(t, err, "unable to create channeldb")
@ -101,9 +101,8 @@ func TestFetchClosedChannelForID(t *testing.T) {
const numChans = 101 const numChans = 101
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -172,9 +171,8 @@ func TestFetchClosedChannelForID(t *testing.T) {
func TestAddrsForNode(t *testing.T) { func TestAddrsForNode(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
graph := fullDB.ChannelGraph() graph := fullDB.ChannelGraph()
@ -226,9 +224,8 @@ func TestAddrsForNode(t *testing.T) {
func TestFetchChannel(t *testing.T) { func TestFetchChannel(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -324,9 +321,8 @@ func genRandomChannelShell() (*ChannelShell, error) {
func TestRestoreChannelShells(t *testing.T) { func TestRestoreChannelShells(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -414,9 +410,8 @@ func TestRestoreChannelShells(t *testing.T) {
func TestAbandonChannel(t *testing.T) { func TestAbandonChannel(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -581,12 +576,11 @@ func TestFetchChannels(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to make test "+ t.Fatalf("unable to make test "+
"database: %v", err) "database: %v", err)
} }
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -652,9 +646,8 @@ func TestFetchChannels(t *testing.T) {
// TestFetchHistoricalChannel tests lookup of historical channels. // TestFetchHistoricalChannel tests lookup of historical channels.
func TestFetchHistoricalChannel(t *testing.T) { func TestFetchHistoricalChannel(t *testing.T) {
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()

View File

@ -20,9 +20,8 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) {
// First, we'll set up a test database, and use that to instantiate the // First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the // forwarding event log that we'll be using for the duration of the
// test. // test.
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
defer cleanUp()
log := ForwardingLog{ log := ForwardingLog{
db: db, db: db,
@ -89,9 +88,8 @@ func TestForwardingLogQueryOptions(t *testing.T) {
// First, we'll set up a test database, and use that to instantiate the // First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the // forwarding event log that we'll be using for the duration of the
// test. // test.
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
defer cleanUp()
log := ForwardingLog{ log := ForwardingLog{
db: db, db: db,
@ -189,9 +187,8 @@ func TestForwardingLogQueryLimit(t *testing.T) {
// First, we'll set up a test database, and use that to instantiate the // First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the // forwarding event log that we'll be using for the duration of the
// test. // test.
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
defer cleanUp()
log := ForwardingLog{ log := ForwardingLog{
db: db, db: db,
@ -301,9 +298,8 @@ func TestForwardingLogStoreEvent(t *testing.T) {
// First, we'll set up a test database, and use that to instantiate the // First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the // forwarding event log that we'll be using for the duration of the
// test. // test.
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
defer cleanUp()
log := ForwardingLog{ log := ForwardingLog{
db: db, db: db,

View File

@ -149,8 +149,7 @@ func TestInvoiceWorkflow(t *testing.T) {
} }
func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
// Create a fake invoice which we'll use several times in the tests // Create a fake invoice which we'll use several times in the tests
@ -293,8 +292,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) {
// TestAddDuplicatePayAddr asserts that the payment addresses of inserted // TestAddDuplicatePayAddr asserts that the payment addresses of inserted
// invoices are unique. // invoices are unique.
func TestAddDuplicatePayAddr(t *testing.T) { func TestAddDuplicatePayAddr(t *testing.T) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Create two invoices with the same payment addr. // Create two invoices with the same payment addr.
@ -320,8 +318,7 @@ func TestAddDuplicatePayAddr(t *testing.T) {
// addresses to be inserted if they are blank to support JIT legacy keysend // addresses to be inserted if they are blank to support JIT legacy keysend
// invoices. // invoices.
func TestAddDuplicateKeysendPayAddr(t *testing.T) { func TestAddDuplicateKeysendPayAddr(t *testing.T) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Create two invoices with the same _blank_ payment addr. // Create two invoices with the same _blank_ payment addr.
@ -363,8 +360,7 @@ func TestAddDuplicateKeysendPayAddr(t *testing.T) {
// ensures that the HTLC's payment hash always matches the payment hash in the // ensures that the HTLC's payment hash always matches the payment hash in the
// returned invoice. // returned invoice.
func TestFailInvoiceLookupMPPPayAddrOnly(t *testing.T) { func TestFailInvoiceLookupMPPPayAddrOnly(t *testing.T) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Create and insert a random invoice. // Create and insert a random invoice.
@ -391,8 +387,7 @@ func TestFailInvoiceLookupMPPPayAddrOnly(t *testing.T) {
// TestInvRefEquivocation asserts that retrieving or updating an invoice using // TestInvRefEquivocation asserts that retrieving or updating an invoice using
// an equivocating InvoiceRef results in ErrInvRefEquivocation. // an equivocating InvoiceRef results in ErrInvRefEquivocation.
func TestInvRefEquivocation(t *testing.T) { func TestInvRefEquivocation(t *testing.T) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Add two random invoices. // Add two random invoices.
@ -431,8 +426,7 @@ func TestInvRefEquivocation(t *testing.T) {
func TestInvoiceCancelSingleHtlc(t *testing.T) { func TestInvoiceCancelSingleHtlc(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
preimage := lntypes.Preimage{1} preimage := lntypes.Preimage{1}
@ -499,8 +493,7 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { func TestInvoiceCancelSingleHtlcAMP(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.NoError(t, err, "unable to make test db: %v", err) require.NoError(t, err, "unable to make test db: %v", err)
// We'll start out by creating an invoice and writing it to the DB. // We'll start out by creating an invoice and writing it to the DB.
@ -656,8 +649,7 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) {
func TestInvoiceAddTimeSeries(t *testing.T) { func TestInvoiceAddTimeSeries(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
_, err = db.InvoicesAddedSince(0) _, err = db.InvoicesAddedSince(0)
@ -812,8 +804,7 @@ func TestSettleIndexAmpPayments(t *testing.T) {
t.Parallel() t.Parallel()
testClock := clock.NewTestClock(testNow) testClock := clock.NewTestClock(testNow)
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.Nil(t, err) require.Nil(t, err)
// First, we'll make a sample invoice that'll be paid to several times // First, we'll make a sample invoice that'll be paid to several times
@ -969,8 +960,7 @@ func TestSettleIndexAmpPayments(t *testing.T) {
func TestScanInvoices(t *testing.T) { func TestScanInvoices(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
var invoices map[lntypes.Hash]*Invoice var invoices map[lntypes.Hash]*Invoice
@ -1028,8 +1018,7 @@ func TestScanInvoices(t *testing.T) {
func TestDuplicateSettleInvoice(t *testing.T) { func TestDuplicateSettleInvoice(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
// We'll start out by creating an invoice and writing it to the DB. // We'll start out by creating an invoice and writing it to the DB.
@ -1087,8 +1076,7 @@ func TestDuplicateSettleInvoice(t *testing.T) {
func TestQueryInvoices(t *testing.T) { func TestQueryInvoices(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
// To begin the test, we'll add 50 invoices to the database. We'll // To begin the test, we'll add 50 invoices to the database. We'll
@ -1400,8 +1388,7 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
func TestCustomRecords(t *testing.T) { func TestCustomRecords(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
preimage := lntypes.Preimage{1} preimage := lntypes.Preimage{1}
@ -1470,8 +1457,7 @@ func TestInvoiceHtlcAMPFields(t *testing.T) {
} }
func testInvoiceHtlcAMPFields(t *testing.T, isAMP bool) { func testInvoiceHtlcAMPFields(t *testing.T, isAMP bool) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.Nil(t, err) require.Nil(t, err)
testInvoice, err := randInvoice(1000) testInvoice, err := randInvoice(1000)
@ -1652,8 +1638,7 @@ func TestHTLCSet(t *testing.T) {
// TestAddInvoiceWithHTLCs asserts that you can't insert an invoice that already // TestAddInvoiceWithHTLCs asserts that you can't insert an invoice that already
// has HTLCs. // has HTLCs.
func TestAddInvoiceWithHTLCs(t *testing.T) { func TestAddInvoiceWithHTLCs(t *testing.T) {
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.Nil(t, err) require.Nil(t, err)
testInvoice, err := randInvoice(1000) testInvoice, err := randInvoice(1000)
@ -1672,8 +1657,7 @@ func TestAddInvoiceWithHTLCs(t *testing.T) {
// that invoices with duplicate set ids are disallowed. // that invoices with duplicate set ids are disallowed.
func TestSetIDIndex(t *testing.T) { func TestSetIDIndex(t *testing.T) {
testClock := clock.NewTestClock(testNow) testClock := clock.NewTestClock(testNow)
db, cleanUp, err := MakeTestDB(OptionClock(testClock)) db, err := MakeTestDB(t, OptionClock(testClock))
defer cleanUp()
require.Nil(t, err) require.Nil(t, err)
// We'll start out by creating an invoice and writing it to the DB. // We'll start out by creating an invoice and writing it to the DB.
@ -1983,8 +1967,7 @@ func getUpdateInvoiceAMPSettle(setID *[32]byte,
func TestUnexpectedInvoicePreimage(t *testing.T) { func TestUnexpectedInvoicePreimage(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
invoice, err := randInvoice(lnwire.MilliSatoshi(100)) invoice, err := randInvoice(lnwire.MilliSatoshi(100))
@ -2040,8 +2023,7 @@ func TestUpdateHTLCPreimages(t *testing.T) {
} }
func testUpdateHTLCPreimages(t *testing.T, test updateHTLCPreimageTestCase) { func testUpdateHTLCPreimages(t *testing.T, test updateHTLCPreimageTestCase) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
// We'll start out by creating an invoice and writing it to the DB. // We'll start out by creating an invoice and writing it to the DB.
@ -2772,8 +2754,7 @@ func testUpdateHTLC(t *testing.T, test updateHTLCTest) {
func TestDeleteInvoices(t *testing.T) { func TestDeleteInvoices(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
// Add some invoices to the test db. // Add some invoices to the test db.
@ -2856,9 +2837,8 @@ func TestDeleteInvoices(t *testing.T) {
func TestAddInvoiceInvalidFeatureDeps(t *testing.T) { func TestAddInvoiceInvalidFeatureDeps(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test db") require.NoError(t, err, "unable to make test db")
defer cleanup()
invoice, err := randInvoice(500) invoice, err := randInvoice(500)
require.NoError(t, err) require.NoError(t, err)

View File

@ -15,8 +15,7 @@ import (
func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
migrationFunc migration, shouldFail bool, dryRun bool) { migrationFunc migration, shouldFail bool, dryRun bool) {
cdb, cleanUp, err := MakeTestDB() cdb, err := MakeTestDB(t)
defer cleanUp()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -86,8 +85,7 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
func TestVersionFetchPut(t *testing.T) { func TestVersionFetchPut(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -450,7 +448,7 @@ func TestMigrationReversion(t *testing.T) {
backend, cleanup, err = kvdb.GetTestBackend(tempDirName, "cdb") backend, cleanup, err = kvdb.GetTestBackend(tempDirName, "cdb")
require.NoError(t, err, "unable to get test db backend") require.NoError(t, err, "unable to get test db backend")
defer cleanup() t.Cleanup(cleanup)
_, err = CreateWithBackend(backend) _, err = CreateWithBackend(backend)
if err != ErrDBReversion { if err != ErrDBReversion {
@ -498,8 +496,7 @@ func TestMigrationDryRun(t *testing.T) {
func TestOptionalMeta(t *testing.T) { func TestOptionalMeta(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Test read an empty optional meta. // Test read an empty optional meta.
@ -527,8 +524,7 @@ func TestOptionalMeta(t *testing.T) {
func TestApplyOptionalVersions(t *testing.T) { func TestApplyOptionalVersions(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Overwrite the migration function so we can count how many times the // Overwrite the migration function so we can count how many times the
@ -581,8 +577,7 @@ func TestApplyOptionalVersions(t *testing.T) {
func TestFetchMeta(t *testing.T) { func TestFetchMeta(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
meta := &Meta{} meta := &Meta{}
@ -601,8 +596,7 @@ func TestFetchMeta(t *testing.T) {
func TestMarkerAndTombstone(t *testing.T) { func TestMarkerAndTombstone(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanUp, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanUp()
require.NoError(t, err) require.NoError(t, err)
// Test that a generic marker is not present in a fresh DB. // Test that a generic marker is not present in a fresh DB.

View File

@ -14,9 +14,8 @@ import (
func TestLinkNodeEncodeDecode(t *testing.T) { func TestLinkNodeEncodeDecode(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()
@ -103,9 +102,8 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
func TestDeleteLinkNode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
cdb := fullDB.ChannelStateDB() cdb := fullDB.ChannelStateDB()

View File

@ -54,8 +54,7 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo,
func TestPaymentControlSwitchFail(t *testing.T) { func TestPaymentControlSwitchFail(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -185,9 +184,7 @@ func TestPaymentControlSwitchFail(t *testing.T) {
func TestPaymentControlSwitchDoubleSend(t *testing.T) { func TestPaymentControlSwitchDoubleSend(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -258,9 +255,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -287,9 +282,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
func TestPaymentControlFailsWithoutInFlight(t *testing.T) { func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -311,9 +304,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
func TestPaymentControlDeleteNonInFlight(t *testing.T) { func TestPaymentControlDeleteNonInFlight(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
// Create a sequence number for duplicate payments that will not collide // Create a sequence number for duplicate payments that will not collide
@ -520,8 +511,7 @@ func TestPaymentControlDeleteNonInFlight(t *testing.T) {
func TestPaymentControlDeletePayments(t *testing.T) { func TestPaymentControlDeletePayments(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -574,8 +564,7 @@ func TestPaymentControlDeletePayments(t *testing.T) {
func TestPaymentControlDeleteSinglePayment(t *testing.T) { func TestPaymentControlDeleteSinglePayment(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -678,9 +667,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
} }
runSubTest := func(t *testing.T, test testCase) { runSubTest := func(t *testing.T, test testCase) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
if err != nil { if err != nil {
t.Fatalf("unable to init db: %v", err) t.Fatalf("unable to init db: %v", err)
} }
@ -924,9 +911,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
func TestPaymentControlMPPRecordValidation(t *testing.T) { func TestPaymentControlMPPRecordValidation(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
@ -1017,8 +1002,7 @@ func TestDeleteFailedAttempts(t *testing.T) {
} }
func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
defer cleanup()
require.NoError(t, err, "unable to init db") require.NoError(t, err, "unable to init db")
db.keepFailedPaymentAttempts = keepFailedPaymentAttempts db.keepFailedPaymentAttempts = keepFailedPaymentAttempts

View File

@ -398,11 +398,10 @@ func TestQueryPayments(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to init db: %v", err) t.Fatalf("unable to init db: %v", err)
} }
defer cleanup()
// Make a preliminary query to make sure it's ok to // Make a preliminary query to make sure it's ok to
// query when we have no payments. // query when we have no payments.
@ -514,11 +513,9 @@ func TestQueryPayments(t *testing.T) {
// case where a specific duplicate is not found and the duplicates bucket is not // case where a specific duplicate is not found and the duplicates bucket is not
// present when we expect it to be. // present when we expect it to be.
func TestFetchPaymentWithSequenceNumber(t *testing.T) { func TestFetchPaymentWithSequenceNumber(t *testing.T) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanup()
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
// Generate a test payment which does not have duplicates. // Generate a test payment which does not have duplicates.

View File

@ -10,9 +10,8 @@ import (
// TestFlapCount tests lookup and writing of flap count to disk. // TestFlapCount tests lookup and writing of flap count to disk.
func TestFlapCount(t *testing.T) { func TestFlapCount(t *testing.T) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanup()
// Try to read flap count for a peer that we have no records for. // Try to read flap count for a peer that we have no records for.
_, err = db.ReadFlapCount(testPub) _, err = db.ReadFlapCount(testPub)

View File

@ -48,9 +48,8 @@ func TestPersistReport(t *testing.T) {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanup()
channelOutpoint := testChanPoint1 channelOutpoint := testChanPoint1
@ -85,9 +84,8 @@ func TestPersistReport(t *testing.T) {
// channel, testing that the appropriate error is returned based on the state // channel, testing that the appropriate error is returned based on the state
// of the existing bucket. // of the existing bucket.
func TestFetchChannelReadBucket(t *testing.T) { func TestFetchChannelReadBucket(t *testing.T) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanup()
channelOutpoint := testChanPoint1 channelOutpoint := testChanPoint1
@ -197,9 +195,8 @@ func TestFetchChannelWriteBucket(t *testing.T) {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanup()
// Update our db to the starting state we expect. // Update our db to the starting state we expect.
err = kvdb.Update(db, test.setup, func() {}) err = kvdb.Update(db, test.setup, func() {})

View File

@ -291,9 +291,8 @@ func TestDerializeRevocationLog(t *testing.T) {
func TestFetchLogBucket(t *testing.T) { func TestFetchLogBucket(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanUp()
backend := fullDB.ChannelStateDB().backend backend := fullDB.ChannelStateDB().backend
@ -326,9 +325,8 @@ func TestFetchLogBucket(t *testing.T) {
func TestDeleteLogBucket(t *testing.T) { func TestDeleteLogBucket(t *testing.T) {
t.Parallel() t.Parallel()
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanUp()
backend := fullDB.ChannelStateDB().backend backend := fullDB.ChannelStateDB().backend
@ -423,9 +421,8 @@ func TestPutRevocationLog(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanUp()
backend := fullDB.ChannelStateDB().backend backend := fullDB.ChannelStateDB().backend
@ -523,9 +520,8 @@ func TestFetchRevocationLogCompatible(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
fullDB, cleanUp, err := MakeTestDB() fullDB, err := MakeTestDB(t)
require.NoError(t, err) require.NoError(t, err)
defer cleanUp()
backend := fullDB.ChannelStateDB().backend backend := fullDB.ChannelStateDB().backend

View File

@ -15,9 +15,8 @@ import (
func TestWaitingProofStore(t *testing.T) { func TestWaitingProofStore(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, err := MakeTestDB(t)
require.NoError(t, err, "failed to make test database") require.NoError(t, err, "failed to make test database")
defer cleanup()
proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{
NodeSignature: wireSig, NodeSignature: wireSig,

View File

@ -13,9 +13,8 @@ import (
func TestWitnessCacheSha256Retrieval(t *testing.T) { func TestWitnessCacheSha256Retrieval(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() cdb, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()
@ -54,9 +53,8 @@ func TestWitnessCacheSha256Retrieval(t *testing.T) {
func TestWitnessCacheSha256Deletion(t *testing.T) { func TestWitnessCacheSha256Deletion(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() cdb, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()
@ -101,9 +99,8 @@ func TestWitnessCacheSha256Deletion(t *testing.T) {
func TestWitnessCacheUnknownWitness(t *testing.T) { func TestWitnessCacheUnknownWitness(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() cdb, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()
@ -118,9 +115,8 @@ func TestWitnessCacheUnknownWitness(t *testing.T) {
// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves // TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves
// identically to the insertion via the generalized interface. // identically to the insertion via the generalized interface.
func TestAddSha256Witnesses(t *testing.T) { func TestAddSha256Witnesses(t *testing.T) {
cdb, cleanUp, err := MakeTestDB() cdb, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database") require.NoError(t, err, "unable to make test database")
defer cleanUp()
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()

View File

@ -956,18 +956,18 @@ restartCheck:
func initBreachedState(t *testing.T) (*BreachArbiter, func initBreachedState(t *testing.T) (*BreachArbiter,
*lnwallet.LightningChannel, *lnwallet.LightningChannel, *lnwallet.LightningChannel, *lnwallet.LightningChannel,
*lnwallet.LocalForceCloseSummary, chan *ContractBreachEvent, *lnwallet.LocalForceCloseSummary, chan *ContractBreachEvent) {
func(), func()) {
// Create a pair of channels using a notifier that allows us to signal // Create a pair of channels using a notifier that allows us to signal
// a spend of the funding transaction. Alice's channel will be the on // a spend of the funding transaction. Alice's channel will be the on
// observing a breach. // observing a breach.
alice, bob, cleanUpChans, err := createInitChannels(t, 1) alice, bob, err := createInitChannels(t, 1)
require.NoError(t, err, "unable to create test channels") require.NoError(t, err, "unable to create test channels")
// Instantiate a breach arbiter to handle the breach of alice's channel. // Instantiate a breach arbiter to handle the breach of alice's channel.
contractBreaches := make(chan *ContractBreachEvent) contractBreaches := make(chan *ContractBreachEvent)
brar, cleanUpArb, err := createTestArbiter( brar, err := createTestArbiter(
t, contractBreaches, alice.State().Db.GetParentDB(), t, contractBreaches, alice.State().Db.GetParentDB(),
) )
require.NoError(t, err, "unable to initialize test breach arbiter") require.NoError(t, err, "unable to initialize test breach arbiter")
@ -1003,8 +1003,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
t.Fatalf("Can't update the channel state: %v", err) t.Fatalf("Can't update the channel state: %v", err)
} }
return brar, alice, bob, bobClose, contractBreaches, cleanUpChans, return brar, alice, bob, bobClose, contractBreaches
cleanUpArb
} }
// TestBreachHandoffSuccess tests that a channel's close observer properly // TestBreachHandoffSuccess tests that a channel's close observer properly
@ -1012,10 +1011,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
// breach close. This test verifies correctness in the event that the handoff // breach close. This test verifies correctness in the event that the handoff
// experiences no interruptions. // experiences no interruptions.
func TestBreachHandoffSuccess(t *testing.T) { func TestBreachHandoffSuccess(t *testing.T) {
brar, alice, _, bobClose, contractBreaches, brar, alice, _, bobClose, contractBreaches := initBreachedState(t)
cleanUpChans, cleanUpArb := initBreachedState(t)
defer cleanUpChans()
defer cleanUpArb()
chanPoint := alice.ChanPoint chanPoint := alice.ChanPoint
@ -1093,10 +1089,7 @@ func TestBreachHandoffSuccess(t *testing.T) {
// arbiter fails to write the information to disk, and that a subsequent attempt // arbiter fails to write the information to disk, and that a subsequent attempt
// at the handoff succeeds. // at the handoff succeeds.
func TestBreachHandoffFail(t *testing.T) { func TestBreachHandoffFail(t *testing.T) {
brar, alice, _, bobClose, contractBreaches, brar, alice, _, bobClose, contractBreaches := initBreachedState(t)
cleanUpChans, cleanUpArb := initBreachedState(t)
defer cleanUpChans()
defer cleanUpArb()
// Before alerting Alice of the breach, instruct our failing retribution // Before alerting Alice of the breach, instruct our failing retribution
// store to fail the next database operation, which we expect to write // store to fail the next database operation, which we expect to write
@ -1140,11 +1133,10 @@ func TestBreachHandoffFail(t *testing.T) {
assertNoArbiterBreach(t, brar, chanPoint) assertNoArbiterBreach(t, brar, chanPoint)
assertNotPendingClosed(t, alice) assertNotPendingClosed(t, alice)
brar, cleanUpArb, err := createTestArbiter( brar, err := createTestArbiter(
t, contractBreaches, alice.State().Db.GetParentDB(), t, contractBreaches, alice.State().Db.GetParentDB(),
) )
require.NoError(t, err, "unable to initialize test breach arbiter") require.NoError(t, err, "unable to initialize test breach arbiter")
defer cleanUpArb()
// Signal a spend of the funding transaction and wait for the close // Signal a spend of the funding transaction and wait for the close
// observer to exit. This time we are allowing the handoff to succeed. // observer to exit. This time we are allowing the handoff to succeed.
@ -1183,9 +1175,7 @@ func TestBreachHandoffFail(t *testing.T) {
// TestBreachCreateJusticeTx tests that we create three different variants of // TestBreachCreateJusticeTx tests that we create three different variants of
// the justice tx. // the justice tx.
func TestBreachCreateJusticeTx(t *testing.T) { func TestBreachCreateJusticeTx(t *testing.T) {
brar, _, _, _, _, cleanUpChans, cleanUpArb := initBreachedState(t) brar, _, _, _, _ := initBreachedState(t)
defer cleanUpChans()
defer cleanUpArb()
// In this test we just want to check that the correct inputs are added // In this test we just want to check that the correct inputs are added
// to the justice tx, not that we create a valid spend, so we just set // to the justice tx, not that we create a valid spend, so we just set
@ -1564,10 +1554,7 @@ func TestBreachSpends(t *testing.T) {
} }
func testBreachSpends(t *testing.T, test breachTest) { func testBreachSpends(t *testing.T, test breachTest) {
brar, alice, _, bobClose, contractBreaches, brar, alice, _, bobClose, contractBreaches := initBreachedState(t)
cleanUpChans, cleanUpArb := initBreachedState(t)
defer cleanUpChans()
defer cleanUpArb()
var ( var (
height = bobClose.ChanSnapshot.CommitHeight height = bobClose.ChanSnapshot.CommitHeight
@ -1783,10 +1770,7 @@ func testBreachSpends(t *testing.T, test breachTest) {
// "split" the justice tx in case the first justice tx doesn't confirm within // "split" the justice tx in case the first justice tx doesn't confirm within
// a reasonable time. // a reasonable time.
func TestBreachDelayedJusticeConfirmation(t *testing.T) { func TestBreachDelayedJusticeConfirmation(t *testing.T) {
brar, alice, _, bobClose, contractBreaches, brar, alice, _, bobClose, contractBreaches := initBreachedState(t)
cleanUpChans, cleanUpArb := initBreachedState(t)
defer cleanUpChans()
defer cleanUpArb()
var ( var (
height = bobClose.ChanSnapshot.CommitHeight height = bobClose.ChanSnapshot.CommitHeight
@ -2123,7 +2107,7 @@ func assertNotPendingClosed(t *testing.T, c *lnwallet.LightningChannel) {
// createTestArbiter instantiates a breach arbiter with a failing retribution // createTestArbiter instantiates a breach arbiter with a failing retribution
// store, so that controlled failures can be tested. // store, so that controlled failures can be tested.
func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent, func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
db *channeldb.DB) (*BreachArbiter, func(), error) { db *channeldb.DB) (*BreachArbiter, error) {
// Create a failing retribution store, that wraps a normal one. // Create a failing retribution store, that wraps a normal one.
store := newFailingRetributionStore(func() RetributionStorer { store := newFailingRetributionStore(func() RetributionStorer {
@ -2148,21 +2132,21 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
}) })
if err := ba.Start(); err != nil { if err := ba.Start(); err != nil {
return nil, nil, err return nil, err
} }
t.Cleanup(func() {
require.NoError(t, ba.Stop())
})
// The caller is responsible for closing the database. return ba, nil
cleanUp := func() {
ba.Stop()
}
return ba, cleanUp, nil
} }
// createInitChannels creates two initialized test channels funded with 10 BTC, // createInitChannels creates two initialized test channels funded with 10 BTC,
// with 5 BTC allocated to each side. Within the channel, Alice is the // with 5 BTC allocated to each side. Within the channel, Alice is the
// initiator. // initiator.
func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) { func createInitChannels(t *testing.T, revocationWindow int) (
*lnwallet.LightningChannel, *lnwallet.LightningChannel, error) {
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes( aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
channels.AlicesPrivKey, channels.AlicesPrivKey,
) )
@ -2172,7 +2156,7 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
channelCapacity, err := btcutil.NewAmount(10) channelCapacity, err := btcutil.NewAmount(10)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
channelBal := channelCapacity / 2 channelBal := channelCapacity / 2
@ -2240,23 +2224,23 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize()) bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot) bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize()) aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot) alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
@ -2266,23 +2250,29 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
false, 0, false, 0,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
dbAlice, err := channeldb.Open(t.TempDir()) dbAlice, err := channeldb.Open(t.TempDir())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
t.Cleanup(func() {
require.NoError(t, dbAlice.Close())
})
dbBob, err := channeldb.Open(t.TempDir()) dbBob, err := channeldb.Open(t.TempDir())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
t.Cleanup(func() {
require.NoError(t, dbBob.Close())
})
estimator := chainfee.NewStaticEstimator(12500, 0) estimator := chainfee.NewStaticEstimator(12500, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
commitFee := feePerKw.FeeForWeight(input.CommitWeight) commitFee := feePerKw.FeeForWeight(input.CommitWeight)
@ -2309,7 +2299,7 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
var chanIDBytes [8]byte var chanIDBytes [8]byte
if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil { if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil {
return nil, nil, nil, err return nil, nil, err
} }
shortChanID := lnwire.NewShortChanIDFromInt( shortChanID := lnwire.NewShortChanIDFromInt(
@ -2360,25 +2350,31 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
aliceSigner, aliceChannelState, alicePool, aliceSigner, aliceChannelState, alicePool,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
alicePool.Start() alicePool.Start()
t.Cleanup(func() {
require.NoError(t, alicePool.Stop())
})
bobPool := lnwallet.NewSigPool(1, bobSigner) bobPool := lnwallet.NewSigPool(1, bobSigner)
channelBob, err := lnwallet.NewLightningChannel( channelBob, err := lnwallet.NewLightningChannel(
bobSigner, bobChannelState, bobPool, bobSigner, bobChannelState, bobPool,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobPool.Start() bobPool.Start()
t.Cleanup(func() {
require.NoError(t, bobPool.Stop())
})
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 18556, Port: 18556,
} }
if err := channelAlice.State().SyncPending(addr, 101); err != nil { if err := channelAlice.State().SyncPending(addr, 101); err != nil {
return nil, nil, nil, err return nil, nil, err
} }
addr = &net.TCPAddr{ addr = &net.TCPAddr{
@ -2386,22 +2382,17 @@ func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.Lightning
Port: 18555, Port: 18555,
} }
if err := channelBob.State().SyncPending(addr, 101); err != nil { if err := channelBob.State().SyncPending(addr, 101); err != nil {
return nil, nil, nil, err return nil, nil, err
}
cleanUpFunc := func() {
dbBob.Close()
dbAlice.Close()
} }
// Now that the channel are open, simulate the start of a session by // Now that the channel are open, simulate the start of a session by
// having Alice and Bob extend their revocation windows to each other. // having Alice and Bob extend their revocation windows to each other.
err = initRevocationWindows(channelAlice, channelBob, revocationWindow) err = initRevocationWindows(channelAlice, channelBob, revocationWindow)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
return channelAlice, channelBob, cleanUpFunc, nil return channelAlice, channelBob, nil
} }
// initRevocationWindows simulates a new channel being opened within the p2p // initRevocationWindows simulates a new channel being opened within the p2p

View File

@ -24,19 +24,20 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer db.Close() t.Cleanup(func() {
require.NoError(t, db.Close())
})
// Create 10 test channels and sync them to the database. // Create 10 test channels and sync them to the database.
const numChans = 10 const numChans = 10
var channels []*channeldb.OpenChannel var channels []*channeldb.OpenChannel
for i := 0; i < numChans; i++ { for i := 0; i < numChans; i++ {
lChannel, _, cleanup, err := lnwallet.CreateTestChannels( lChannel, _, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderTweaklessBit, t, channeldb.SingleFunderTweaklessBit,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer cleanup()
channel := lChannel.State() channel := lChannel.State()
@ -94,11 +95,9 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
if err := chainArb.Start(); err != nil { if err := chainArb.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() { t.Cleanup(func() {
if err := chainArb.Stop(); err != nil { require.NoError(t, chainArb.Stop())
t.Fatal(err) })
}
}()
// Half of the channels should have had their closing tx re-published. // Half of the channels should have had their closing tx re-published.
if len(published) != numChans/2 { if len(published) != numChans/2 {
@ -137,15 +136,16 @@ func TestResolveContract(t *testing.T) {
db, err := channeldb.Open(t.TempDir()) db, err := channeldb.Open(t.TempDir())
require.NoError(t, err, "unable to open db") require.NoError(t, err, "unable to open db")
defer db.Close() t.Cleanup(func() {
require.NoError(t, db.Close())
})
// With the DB created, we'll make a new channel, and mark it as // With the DB created, we'll make a new channel, and mark it as
// pending open within the database. // pending open within the database.
newChannel, _, cleanup, err := lnwallet.CreateTestChannels( newChannel, _, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderTweaklessBit, t, channeldb.SingleFunderTweaklessBit,
) )
require.NoError(t, err, "unable to make new test channel") require.NoError(t, err, "unable to make new test channel")
defer cleanup()
channel := newChannel.State() channel := newChannel.State()
channel.Db = db.ChannelStateDB() channel.Db = db.ChannelStateDB()
addr := &net.TCPAddr{ addr := &net.TCPAddr{
@ -177,11 +177,9 @@ func TestResolveContract(t *testing.T) {
if err := chainArb.Start(); err != nil { if err := chainArb.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() { t.Cleanup(func() {
if err := chainArb.Stop(); err != nil { require.NoError(t, chainArb.Stop())
t.Fatal(err) })
}
}()
channelArb := chainArb.activeChannels[channel.FundingOutpoint] channelArb := chainArb.activeChannels[channel.FundingOutpoint]

View File

@ -25,11 +25,10 @@ func TestChainWatcherRemoteUnilateralClose(t *testing.T) {
// First, we'll create two channels which already have established a // First, we'll create two channels which already have established a
// commitment contract between themselves. // commitment contract between themselves.
aliceChannel, bobChannel, cleanUp, err := lnwallet.CreateTestChannels( aliceChannel, bobChannel, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderTweaklessBit, t, channeldb.SingleFunderTweaklessBit,
) )
require.NoError(t, err, "unable to create test channels") require.NoError(t, err, "unable to create test channels")
defer cleanUp()
// With the channels created, we'll now create a chain watcher instance // With the channels created, we'll now create a chain watcher instance
// which will be watching for any closes of Alice's channel. // which will be watching for any closes of Alice's channel.
@ -110,11 +109,10 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) {
// First, we'll create two channels which already have established a // First, we'll create two channels which already have established a
// commitment contract between themselves. // commitment contract between themselves.
aliceChannel, bobChannel, cleanUp, err := lnwallet.CreateTestChannels( aliceChannel, bobChannel, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderTweaklessBit, t, channeldb.SingleFunderTweaklessBit,
) )
require.NoError(t, err, "unable to create test channels") require.NoError(t, err, "unable to create test channels")
defer cleanUp()
// With the channels created, we'll now create a chain watcher instance // With the channels created, we'll now create a chain watcher instance
// which will be watching for any closes of Alice's channel. // which will be watching for any closes of Alice's channel.
@ -255,13 +253,12 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
dlpScenario := func(t *testing.T, testCase dlpTestCase) bool { dlpScenario := func(t *testing.T, testCase dlpTestCase) bool {
// First, we'll create two channels which already have // First, we'll create two channels which already have
// established a commitment contract between themselves. // established a commitment contract between themselves.
aliceChannel, bobChannel, cleanUp, err := lnwallet.CreateTestChannels( aliceChannel, bobChannel, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderBit, t, channeldb.SingleFunderBit,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create test channels: %v", err) t.Fatalf("unable to create test channels: %v", err)
} }
defer cleanUp()
// Based on the number of random updates for this state, make a // Based on the number of random updates for this state, make a
// new HTLC to add to the commitment, and then lock in a state // new HTLC to add to the commitment, and then lock in a state
@ -430,13 +427,12 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
// First, we'll create two channels which already have // First, we'll create two channels which already have
// established a commitment contract between themselves. // established a commitment contract between themselves.
aliceChannel, bobChannel, cleanUp, err := lnwallet.CreateTestChannels( aliceChannel, bobChannel, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderBit, t, channeldb.SingleFunderBit,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create test channels: %v", err) t.Fatalf("unable to create test channels: %v", err)
} }
defer cleanUp()
// We'll execute a number of state transitions based on the // We'll execute a number of state transitions based on the
// randomly selected number from testing/quick. We do this to // randomly selected number from testing/quick. We do this to

View File

@ -460,11 +460,9 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
if err := chanArbCtx.chanArb.Start(nil); err != nil { if err := chanArbCtx.chanArb.Start(nil); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer func() { t.Cleanup(func() {
if err := chanArbCtx.chanArb.Stop(); err != nil { require.NoError(t, chanArbCtx.chanArb.Stop())
t.Fatalf("unable to stop chan arb: %v", err) })
}
}()
// It should start out in the default state. // It should start out in the default state.
chanArbCtx.AssertState(StateDefault) chanArbCtx.AssertState(StateDefault)
@ -681,11 +679,9 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
if err := chanArb.Start(nil); err != nil { if err := chanArb.Start(nil); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer func() { t.Cleanup(func() {
if err := chanArb.Stop(); err != nil { require.NoError(t, chanArb.Stop())
t.Fatal(err) })
}
}()
// It should start out in the default state. // It should start out in the default state.
chanArbCtx.AssertState(StateDefault) chanArbCtx.AssertState(StateDefault)
@ -1990,11 +1986,9 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) {
if err := chanArb.Start(nil); err != nil { if err := chanArb.Start(nil); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer func() { t.Cleanup(func() {
if err := chanArb.Stop(); err != nil { require.NoError(t, chanArb.Stop())
t.Fatalf("unable to stop chan arb: %v", err) })
}
}()
// Now that our channel arb has started, we'll set up // Now that our channel arb has started, we'll set up
// its contract signals channel so we can send it // its contract signals channel so we can send it
@ -2098,14 +2092,13 @@ func TestRemoteCloseInitiator(t *testing.T) {
t.Parallel() t.Parallel()
// First, create alice's channel. // First, create alice's channel.
alice, _, cleanUp, err := lnwallet.CreateTestChannels( alice, _, err := lnwallet.CreateTestChannels(
channeldb.SingleFunderTweaklessBit, t, channeldb.SingleFunderTweaklessBit,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create test channels: %v", t.Fatalf("unable to create test channels: %v",
err) err)
} }
defer cleanUp()
// Create a mock log which will not block the test's // Create a mock log which will not block the test's
// expected number of transitions transitions, and has // expected number of transitions transitions, and has
@ -2148,11 +2141,9 @@ func TestRemoteCloseInitiator(t *testing.T) {
t.Fatalf("unable to start "+ t.Fatalf("unable to start "+
"ChannelArbitrator: %v", err) "ChannelArbitrator: %v", err)
} }
defer func() { t.Cleanup(func() {
if err := chanArb.Stop(); err != nil { require.NoError(t, chanArb.Stop())
t.Fatal(err) })
}
}()
// It should start out in the default state. // It should start out in the default state.
chanArbCtx.AssertState(StateDefault) chanArbCtx.AssertState(StateDefault)
@ -2501,11 +2492,9 @@ func TestChannelArbitratorAnchors(t *testing.T) {
if err := chanArb.Start(nil); err != nil { if err := chanArb.Start(nil); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer func() { t.Cleanup(func() {
if err := chanArb.Stop(); err != nil { require.NoError(t, chanArb.Stop())
t.Fatal(err) })
}
}()
signals := &ContractSignals{ signals := &ContractSignals{
ShortChanID: lnwire.ShortChannelID{}, ShortChanID: lnwire.ShortChannelID{},

View File

@ -170,7 +170,7 @@ var _ UtxoSweeper = &mockSweeper{}
// unencumbered by a time lock. // unencumbered by a time lock.
func TestCommitSweepResolverNoDelay(t *testing.T) { func TestCommitSweepResolverNoDelay(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
res := lnwallet.CommitOutputResolution{ res := lnwallet.CommitOutputResolution{
SelfOutputSignDesc: input.SignDescriptor{ SelfOutputSignDesc: input.SignDescriptor{
@ -227,7 +227,7 @@ func TestCommitSweepResolverNoDelay(t *testing.T) {
// that is encumbered by a time lock. sweepErr indicates whether the local node // that is encumbered by a time lock. sweepErr indicates whether the local node
// fails to sweep the output. // fails to sweep the output.
func testCommitSweepResolverDelay(t *testing.T, sweepErr error) { func testCommitSweepResolverDelay(t *testing.T, sweepErr error) {
defer timeout(t)() defer timeout()()
const sweepProcessInterval = 100 * time.Millisecond const sweepProcessInterval = 100 * time.Millisecond
amt := int64(100) amt := int64(100)

View File

@ -36,7 +36,7 @@ var (
// for which the preimage is already known initially. // for which the preimage is already known initially.
func TestHtlcIncomingResolverFwdPreimageKnown(t *testing.T) { func TestHtlcIncomingResolverFwdPreimageKnown(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, false) ctx := newIncomingResolverTestContext(t, false)
ctx.witnessBeacon.lookupPreimage[testResHash] = testResPreimage ctx.witnessBeacon.lookupPreimage[testResHash] = testResPreimage
@ -49,7 +49,7 @@ func TestHtlcIncomingResolverFwdPreimageKnown(t *testing.T) {
// started. // started.
func TestHtlcIncomingResolverFwdContestedSuccess(t *testing.T) { func TestHtlcIncomingResolverFwdContestedSuccess(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, false) ctx := newIncomingResolverTestContext(t, false)
ctx.resolve() ctx.resolve()
@ -65,7 +65,7 @@ func TestHtlcIncomingResolverFwdContestedSuccess(t *testing.T) {
// htlc that times out after the resolver has been started. // htlc that times out after the resolver has been started.
func TestHtlcIncomingResolverFwdContestedTimeout(t *testing.T) { func TestHtlcIncomingResolverFwdContestedTimeout(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, false) ctx := newIncomingResolverTestContext(t, false)
@ -104,7 +104,7 @@ func TestHtlcIncomingResolverFwdContestedTimeout(t *testing.T) {
// has already expired when the resolver starts. // has already expired when the resolver starts.
func TestHtlcIncomingResolverFwdTimeout(t *testing.T) { func TestHtlcIncomingResolverFwdTimeout(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)
ctx.witnessBeacon.lookupPreimage[testResHash] = testResPreimage ctx.witnessBeacon.lookupPreimage[testResHash] = testResPreimage
@ -117,7 +117,7 @@ func TestHtlcIncomingResolverFwdTimeout(t *testing.T) {
// which the invoice has already been settled when the resolver starts. // which the invoice has already been settled when the resolver starts.
func TestHtlcIncomingResolverExitSettle(t *testing.T) { func TestHtlcIncomingResolverExitSettle(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)
ctx.registry.notifyResolution = invoices.NewSettleResolution( ctx.registry.notifyResolution = invoices.NewSettleResolution(
@ -149,7 +149,7 @@ func TestHtlcIncomingResolverExitSettle(t *testing.T) {
// an invoice that is already canceled when the resolver starts. // an invoice that is already canceled when the resolver starts.
func TestHtlcIncomingResolverExitCancel(t *testing.T) { func TestHtlcIncomingResolverExitCancel(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)
ctx.registry.notifyResolution = invoices.NewFailResolution( ctx.registry.notifyResolution = invoices.NewFailResolution(
@ -165,7 +165,7 @@ func TestHtlcIncomingResolverExitCancel(t *testing.T) {
// for a hodl invoice that is settled after the resolver has started. // for a hodl invoice that is settled after the resolver has started.
func TestHtlcIncomingResolverExitSettleHodl(t *testing.T) { func TestHtlcIncomingResolverExitSettleHodl(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)
ctx.resolve() ctx.resolve()
@ -183,7 +183,7 @@ func TestHtlcIncomingResolverExitSettleHodl(t *testing.T) {
// for a hodl invoice that times out. // for a hodl invoice that times out.
func TestHtlcIncomingResolverExitTimeoutHodl(t *testing.T) { func TestHtlcIncomingResolverExitTimeoutHodl(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)
@ -220,7 +220,7 @@ func TestHtlcIncomingResolverExitTimeoutHodl(t *testing.T) {
// for a hodl invoice that is canceled after the resolver has started. // for a hodl invoice that is canceled after the resolver has started.
func TestHtlcIncomingResolverExitCancelHodl(t *testing.T) { func TestHtlcIncomingResolverExitCancelHodl(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
ctx := newIncomingResolverTestContext(t, true) ctx := newIncomingResolverTestContext(t, true)

View File

@ -23,7 +23,7 @@ const (
// timed out. // timed out.
func TestHtlcOutgoingResolverTimeout(t *testing.T) { func TestHtlcOutgoingResolverTimeout(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
// Setup the resolver with our test resolution. // Setup the resolver with our test resolution.
ctx := newOutgoingResolverTestContext(t) ctx := newOutgoingResolverTestContext(t)
@ -44,7 +44,7 @@ func TestHtlcOutgoingResolverTimeout(t *testing.T) {
// is claimed by the remote party. // is claimed by the remote party.
func TestHtlcOutgoingResolverRemoteClaim(t *testing.T) { func TestHtlcOutgoingResolverRemoteClaim(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)() defer timeout()()
// Setup the resolver with our test resolution and start the resolution // Setup the resolver with our test resolution and start the resolution
// process. // process.

View File

@ -477,7 +477,7 @@ type checkpoint struct {
func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution, func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution,
checkpoints []checkpoint) { checkpoints []checkpoint) {
defer timeout(t)() defer timeout()()
// We first run the resolver from start to finish, ensuring it gets // We first run the resolver from start to finish, ensuring it gets
// checkpointed at every expected stage. We store the checkpointed data // checkpointed at every expected stage. We store the checkpointed data
@ -521,7 +521,7 @@ func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution,
func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
expectedCheckpoints []checkpoint) [][]byte { expectedCheckpoints []checkpoint) [][]byte {
defer timeout(t)() defer timeout()()
var checkpointedState [][]byte var checkpointedState [][]byte

View File

@ -1286,7 +1286,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) {
func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution,
checkpoints []checkpoint) { checkpoints []checkpoint) {
defer timeout(t)() defer timeout()()
// We first run the resolver from start to finish, ensuring it gets // We first run the resolver from start to finish, ensuring it gets
// checkpointed at every expected stage. We store the checkpointed data // checkpointed at every expected stage. We store the checkpointed data

View File

@ -53,9 +53,8 @@ func initIncubateTests() {
// TestNurseryStoreInit verifies basic properties of the nursery store before // TestNurseryStoreInit verifies basic properties of the nursery store before
// any modifying calls are made. // any modifying calls are made.
func TestNurseryStoreInit(t *testing.T) { func TestNurseryStoreInit(t *testing.T) {
cdb, cleanUp, err := channeldb.MakeTestDB() cdb, err := channeldb.MakeTestDB(t)
require.NoError(t, err, "unable to open channel db") require.NoError(t, err, "unable to open channel db")
defer cleanUp()
ns, err := NewNurseryStore(&chainHash, cdb) ns, err := NewNurseryStore(&chainHash, cdb)
require.NoError(t, err, "unable to open nursery store") require.NoError(t, err, "unable to open nursery store")
@ -69,9 +68,8 @@ func TestNurseryStoreInit(t *testing.T) {
// outputs through the nursery store, verifying the properties of the // outputs through the nursery store, verifying the properties of the
// intermediate states. // intermediate states.
func TestNurseryStoreIncubate(t *testing.T) { func TestNurseryStoreIncubate(t *testing.T) {
cdb, cleanUp, err := channeldb.MakeTestDB() cdb, err := channeldb.MakeTestDB(t)
require.NoError(t, err, "unable to open channel db") require.NoError(t, err, "unable to open channel db")
defer cleanUp()
ns, err := NewNurseryStore(&chainHash, cdb) ns, err := NewNurseryStore(&chainHash, cdb)
require.NoError(t, err, "unable to open nursery store") require.NoError(t, err, "unable to open nursery store")
@ -306,9 +304,8 @@ func TestNurseryStoreIncubate(t *testing.T) {
// populated entries from the height index as it is purged, and that the last // populated entries from the height index as it is purged, and that the last
// purged height is set appropriately. // purged height is set appropriately.
func TestNurseryStoreGraduate(t *testing.T) { func TestNurseryStoreGraduate(t *testing.T) {
cdb, cleanUp, err := channeldb.MakeTestDB() cdb, err := channeldb.MakeTestDB(t)
require.NoError(t, err, "unable to open channel db") require.NoError(t, err, "unable to open channel db")
defer cleanUp()
ns, err := NewNurseryStore(&chainHash, cdb) ns, err := NewNurseryStore(&chainHash, cdb)
require.NoError(t, err, "unable to open nursery store") require.NoError(t, err, "unable to open nursery store")

View File

@ -13,7 +13,7 @@ import (
) )
// timeout implements a test level timeout. // timeout implements a test level timeout.
func timeout(t *testing.T) func() { func timeout() func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
select { select {

View File

@ -409,7 +409,6 @@ type nurseryTestContext struct {
sweeper *mockSweeperFull sweeper *mockSweeperFull
timeoutChan chan chan time.Time timeoutChan chan chan time.Time
t *testing.T t *testing.T
dbCleanup func()
} }
func createNurseryTestContext(t *testing.T, func createNurseryTestContext(t *testing.T,
@ -419,7 +418,7 @@ func createNurseryTestContext(t *testing.T,
// alternative, mocking nurseryStore, is not chosen because there is // alternative, mocking nurseryStore, is not chosen because there is
// still considerable logic in the store. // still considerable logic in the store.
cdb, cleanup, err := channeldb.MakeTestDB() cdb, err := channeldb.MakeTestDB(t)
require.NoError(t, err, "unable to open channeldb") require.NoError(t, err, "unable to open channeldb")
store, err := NewNurseryStore(&chainhash.Hash{}, cdb) store, err := NewNurseryStore(&chainhash.Hash{}, cdb)
@ -480,7 +479,6 @@ func createNurseryTestContext(t *testing.T,
sweeper: sweeper, sweeper: sweeper,
timeoutChan: timeoutChan, timeoutChan: timeoutChan,
t: t, t: t,
dbCleanup: cleanup,
} }
ctx.receiveTx = func() wire.MsgTx { ctx.receiveTx = func() wire.MsgTx {
@ -528,8 +526,6 @@ func (ctx *nurseryTestContext) notifyEpoch(height int32) {
} }
func (ctx *nurseryTestContext) finish() { func (ctx *nurseryTestContext) finish() {
defer ctx.dbCleanup()
// Add a final restart point in this state // Add a final restart point in this state
ctx.restart() ctx.restart()

File diff suppressed because it is too large Load Diff

View File

@ -5,10 +5,9 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"io" "io"
"io/ioutil"
prand "math/rand" prand "math/rand"
"net" "net"
"os" "testing"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
@ -20,6 +19,7 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -103,17 +103,15 @@ var (
// CreateTestChannels creates to fully populated channels to be used within // CreateTestChannels creates to fully populated channels to be used within
// testing fixtures. The channels will be returned as if the funding process // testing fixtures. The channels will be returned as if the funding process
// has just completed. The channel itself is funded with 10 BTC, with 5 BTC // has just completed. The channel itself is funded with 10 BTC, with 5 BTC
// allocated to each side. Within the channel, Alice is the initiator. The // allocated to each side. Within the channel, Alice is the initiator. If
// function also returns a "cleanup" function that is meant to be called once // tweaklessCommits is true, then the commits within the channels will use the
// the test has been finalized. The clean up function will remote all temporary // new format, otherwise the legacy format.
// files created. If tweaklessCommits is true, then the commits within the func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType) (
// channels will use the new format, otherwise the legacy format. *LightningChannel, *LightningChannel, error) {
func CreateTestChannels(chanType channeldb.ChannelType) (
*LightningChannel, *LightningChannel, func(), error) {
channelCapacity, err := btcutil.NewAmount(testChannelCapacity) channelCapacity, err := btcutil.NewAmount(testChannelCapacity)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
channelBal := channelCapacity / 2 channelBal := channelCapacity / 2
@ -202,23 +200,23 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
bobRoot, err := chainhash.NewHash(bobKeys[0].Serialize()) bobRoot, err := chainhash.NewHash(bobKeys[0].Serialize())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot) bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
aliceRoot, err := chainhash.NewHash(aliceKeys[0].Serialize()) aliceRoot, err := chainhash.NewHash(aliceKeys[0].Serialize())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot) alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
@ -227,33 +225,29 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
bobCommitPoint, *fundingTxIn, chanType, isAliceInitiator, 0, bobCommitPoint, *fundingTxIn, chanType, isAliceInitiator, 0,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
alicePath, err := ioutil.TempDir("", "alicedb") dbAlice, err := channeldb.Open(t.TempDir())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
t.Cleanup(func() {
require.NoError(t, dbAlice.Close())
})
dbAlice, err := channeldb.Open(alicePath) dbBob, err := channeldb.Open(t.TempDir())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
}
bobPath, err := ioutil.TempDir("", "bobdb")
if err != nil {
return nil, nil, nil, err
}
dbBob, err := channeldb.Open(bobPath)
if err != nil {
return nil, nil, nil, err
} }
t.Cleanup(func() {
require.NoError(t, dbBob.Close())
})
estimator := chainfee.NewStaticEstimator(6000, 0) estimator := chainfee.NewStaticEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
commitFee := calcStaticFee(chanType, 0) commitFee := calcStaticFee(chanType, 0)
var anchorAmt btcutil.Amount var anchorAmt btcutil.Amount
@ -305,7 +299,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
var chanIDBytes [8]byte var chanIDBytes [8]byte
if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil { if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil {
return nil, nil, nil, err return nil, nil, err
} }
shortChanID := lnwire.NewShortChanIDFromInt( shortChanID := lnwire.NewShortChanIDFromInt(
@ -358,9 +352,12 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
aliceSigner, aliceChannelState, alicePool, aliceSigner, aliceChannelState, alicePool,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
alicePool.Start() alicePool.Start()
t.Cleanup(func() {
require.NoError(t, alicePool.Stop())
})
obfuscator := createStateHintObfuscator(aliceChannelState) obfuscator := createStateHintObfuscator(aliceChannelState)
@ -369,21 +366,24 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
bobSigner, bobChannelState, bobPool, bobSigner, bobChannelState, bobPool,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
bobPool.Start() bobPool.Start()
t.Cleanup(func() {
require.NoError(t, bobPool.Stop())
})
err = SetStateNumHint( err = SetStateNumHint(
aliceCommitTx, 0, obfuscator, aliceCommitTx, 0, obfuscator,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
err = SetStateNumHint( err = SetStateNumHint(
bobCommitTx, 0, obfuscator, bobCommitTx, 0, obfuscator,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
addr := &net.TCPAddr{ addr := &net.TCPAddr{
@ -391,7 +391,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
Port: 18556, Port: 18556,
} }
if err := channelAlice.channelState.SyncPending(addr, 101); err != nil { if err := channelAlice.channelState.SyncPending(addr, 101); err != nil {
return nil, nil, nil, err return nil, nil, err
} }
addr = &net.TCPAddr{ addr = &net.TCPAddr{
@ -400,25 +400,17 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
} }
if err := channelBob.channelState.SyncPending(addr, 101); err != nil { if err := channelBob.channelState.SyncPending(addr, 101); err != nil {
return nil, nil, nil, err return nil, nil, err
}
cleanUpFunc := func() {
os.RemoveAll(bobPath)
os.RemoveAll(alicePath)
alicePool.Stop()
bobPool.Stop()
} }
// Now that the channel are open, simulate the start of a session by // Now that the channel are open, simulate the start of a session by
// having Alice and Bob extend their revocation windows to each other. // having Alice and Bob extend their revocation windows to each other.
err = initRevocationWindows(channelAlice, channelBob) err = initRevocationWindows(channelAlice, channelBob)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
return channelAlice, channelBob, cleanUpFunc, nil return channelAlice, channelBob, nil
} }
// initRevocationWindows simulates a new channel being opened within the p2p // initRevocationWindows simulates a new channel being opened within the p2p

View File

@ -15,15 +15,10 @@ func TestStore(t *testing.T) {
t.Run("bolt", func(t *testing.T) { t.Run("bolt", func(t *testing.T) {
// Create new store. // Create new store.
cdb, cleanUp, err := channeldb.MakeTestDB() cdb, err := channeldb.MakeTestDB(t)
if err != nil { if err != nil {
t.Fatalf("unable to open channel db: %v", err) t.Fatalf("unable to open channel db: %v", err)
} }
defer cleanUp()
if err != nil {
t.Fatal(err)
}
testStore(t, func() (SweeperStore, error) { testStore(t, func() (SweeperStore, error) {
var chain chainhash.Hash var chain chainhash.Hash