mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
sweep: use testify/mock
for MockSweeperStore
This commit is contained in:
parent
8b9d5e0548
commit
f13a3a8053
@ -2,54 +2,58 @@ package sweep
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockSweeperStore is a mock implementation of sweeper store. This type is
|
||||
// exported, because it is currently used in nursery tests too.
|
||||
type MockSweeperStore struct {
|
||||
ourTxes map[chainhash.Hash]struct{}
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// NewMockSweeperStore returns a new instance.
|
||||
func NewMockSweeperStore() *MockSweeperStore {
|
||||
return &MockSweeperStore{
|
||||
ourTxes: make(map[chainhash.Hash]struct{}),
|
||||
}
|
||||
return &MockSweeperStore{}
|
||||
}
|
||||
|
||||
// IsOurTx determines whether a tx is published by us, based on its
|
||||
// hash.
|
||||
// IsOurTx determines whether a tx is published by us, based on its hash.
|
||||
func (s *MockSweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
|
||||
_, ok := s.ourTxes[hash]
|
||||
return ok, nil
|
||||
args := s.Called(hash)
|
||||
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
// StoreTx stores a tx we are about to publish.
|
||||
func (s *MockSweeperStore) StoreTx(tr *TxRecord) error {
|
||||
s.ourTxes[tr.Txid] = struct{}{}
|
||||
|
||||
return nil
|
||||
args := s.Called(tr)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// ListSweeps lists all the sweeps we have successfully published.
|
||||
func (s *MockSweeperStore) ListSweeps() ([]chainhash.Hash, error) {
|
||||
var txns []chainhash.Hash
|
||||
for tx := range s.ourTxes {
|
||||
txns = append(txns, tx)
|
||||
}
|
||||
args := s.Called()
|
||||
|
||||
return txns, nil
|
||||
return args.Get(0).([]chainhash.Hash), args.Error(1)
|
||||
}
|
||||
|
||||
// GetTx queries the database to find the tx that matches the given txid.
|
||||
// Returns ErrTxNotFound if it cannot be found.
|
||||
func (s *MockSweeperStore) GetTx(hash chainhash.Hash) (*TxRecord, error) {
|
||||
return nil, ErrTxNotFound
|
||||
args := s.Called(hash)
|
||||
|
||||
tr := args.Get(0)
|
||||
if tr != nil {
|
||||
return args.Get(0).(*TxRecord), args.Error(1)
|
||||
}
|
||||
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
// DeleteTx removes the given tx from db.
|
||||
func (s *MockSweeperStore) DeleteTx(txid chainhash.Hash) error {
|
||||
return nil
|
||||
args := s.Called(txid)
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure MockSweeperStore implements SweeperStore.
|
||||
|
@ -14,35 +14,13 @@ import (
|
||||
// TestStore asserts that the store persists the presented data to disk and is
|
||||
// able to retrieve it again.
|
||||
func TestStore(t *testing.T) {
|
||||
t.Run("bolt", func(t *testing.T) {
|
||||
// Create new store.
|
||||
cdb, err := channeldb.MakeTestDB(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create new store.
|
||||
cdb, err := channeldb.MakeTestDB(t)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open channel db: %v", err)
|
||||
}
|
||||
|
||||
testStore(t, func() (SweeperStore, error) {
|
||||
var chain chainhash.Hash
|
||||
return NewSweeperStore(cdb, &chain)
|
||||
})
|
||||
})
|
||||
t.Run("mock", func(t *testing.T) {
|
||||
store := NewMockSweeperStore()
|
||||
|
||||
testStore(t, func() (SweeperStore, error) {
|
||||
// Return same store, because the mock has no real
|
||||
// persistence.
|
||||
return store, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testStore(t *testing.T, createStore func() (SweeperStore, error)) {
|
||||
store, err := createStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var chain chainhash.Hash
|
||||
store, err := NewSweeperStore(cdb, &chain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Notify publication of tx1
|
||||
tx1 := wire.MsgTx{}
|
||||
@ -75,10 +53,8 @@ func testStore(t *testing.T, createStore func() (SweeperStore, error)) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recreate the sweeper store
|
||||
store, err = createStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
store, err = NewSweeperStore(cdb, &chain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that both txes are recognized as our own.
|
||||
ours, err := store.IsOurTx(tx1.TxHash())
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||
@ -41,7 +42,7 @@ type sweeperTestContext struct {
|
||||
notifier *MockNotifier
|
||||
estimator *mockFeeEstimator
|
||||
backend *mockBackend
|
||||
store *MockSweeperStore
|
||||
store SweeperStore
|
||||
|
||||
publishChan chan wire.MsgTx
|
||||
}
|
||||
@ -102,7 +103,13 @@ func init() {
|
||||
func createSweeperTestContext(t *testing.T) *sweeperTestContext {
|
||||
notifier := NewMockNotifier(t)
|
||||
|
||||
store := NewMockSweeperStore()
|
||||
// Create new store.
|
||||
cdb, err := channeldb.MakeTestDB(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
var chain chainhash.Hash
|
||||
store, err := NewSweeperStore(cdb, &chain)
|
||||
require.NoError(t, err)
|
||||
|
||||
backend := newMockBackend(t, notifier)
|
||||
backend.walletUtxos = []*lnwallet.Utxo{
|
||||
@ -682,7 +689,6 @@ func TestIdempotency(t *testing.T) {
|
||||
|
||||
// Timer is still running, but spend notification was delivered before
|
||||
// it expired.
|
||||
|
||||
ctx.finish(1)
|
||||
}
|
||||
|
||||
@ -701,9 +707,8 @@ func TestRestart(t *testing.T) {
|
||||
|
||||
// Sweep input and expect sweep tx.
|
||||
input1 := spendableInputs[0]
|
||||
if _, err := ctx.sweeper.SweepInput(input1, defaultFeePref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := ctx.sweeper.SweepInput(input1, defaultFeePref)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx.receiveTx()
|
||||
|
||||
@ -758,23 +763,20 @@ func TestRestart(t *testing.T) {
|
||||
ctx.finish(1)
|
||||
}
|
||||
|
||||
// TestRestartRemoteSpend asserts that the sweeper picks up sweeping properly after
|
||||
// a restart with remote spend.
|
||||
// TestRestartRemoteSpend asserts that the sweeper picks up sweeping properly
|
||||
// after a restart with remote spend.
|
||||
func TestRestartRemoteSpend(t *testing.T) {
|
||||
|
||||
ctx := createSweeperTestContext(t)
|
||||
|
||||
// Sweep input.
|
||||
input1 := spendableInputs[0]
|
||||
if _, err := ctx.sweeper.SweepInput(input1, defaultFeePref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := ctx.sweeper.SweepInput(input1, defaultFeePref)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sweep another input.
|
||||
input2 := spendableInputs[1]
|
||||
if _, err := ctx.sweeper.SweepInput(input2, defaultFeePref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = ctx.sweeper.SweepInput(input2, defaultFeePref)
|
||||
require.NoError(t, err)
|
||||
|
||||
sweepTx := ctx.receiveTx()
|
||||
|
||||
@ -798,7 +800,8 @@ func TestRestartRemoteSpend(t *testing.T) {
|
||||
// Mine remote spending tx.
|
||||
ctx.backend.mine()
|
||||
|
||||
// Simulate other subsystem (e.g. contract resolver) re-offering input 0.
|
||||
// Simulate other subsystem (e.g. contract resolver) re-offering input
|
||||
// 0.
|
||||
spendChan, err := ctx.sweeper.SweepInput(input1, defaultFeePref)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -815,8 +818,8 @@ func TestRestartRemoteSpend(t *testing.T) {
|
||||
ctx.finish(1)
|
||||
}
|
||||
|
||||
// TestRestartConfirmed asserts that the sweeper picks up sweeping properly after
|
||||
// a restart with a confirm of our own sweep tx.
|
||||
// TestRestartConfirmed asserts that the sweeper picks up sweeping properly
|
||||
// after a restart with a confirm of our own sweep tx.
|
||||
func TestRestartConfirmed(t *testing.T) {
|
||||
ctx := createSweeperTestContext(t)
|
||||
|
||||
@ -834,7 +837,8 @@ func TestRestartConfirmed(t *testing.T) {
|
||||
// Mine the sweep tx.
|
||||
ctx.backend.mine()
|
||||
|
||||
// Simulate other subsystem (e.g. contract resolver) re-offering input 0.
|
||||
// Simulate other subsystem (e.g. contract resolver) re-offering input
|
||||
// 0.
|
||||
spendChan, err := ctx.sweeper.SweepInput(input, defaultFeePref)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
Loading…
Reference in New Issue
Block a user