diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index c4813cd2f..2bffefdbe 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/queue" ) @@ -1070,3 +1071,26 @@ func (b *BitcoindNotifier) CancelMempoolSpendEvent( b.memNotifier.UnsubscribeEvent(sub) } + +// LookupInputMempoolSpend takes an outpoint and queries the mempool to find +// its spending tx. Returns the tx if found, otherwise fn.None. +// +// NOTE: part of the MempoolWatcher interface. +func (b *BitcoindNotifier) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + // Find the spending txid. + txid, found := b.chainConn.LookupInputMempoolSpend(op) + if !found { + return fn.None[wire.MsgTx]() + } + + // Query the spending tx using the id. + tx, err := b.chainConn.GetRawTransaction(&txid) + if err != nil { + // TODO(yy): enable logging errors in this package. + return fn.None[wire.MsgTx]() + } + + return fn.Some(*tx.MsgTx().Copy()) +} diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index d2e9c77bd..e865426e9 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -17,6 +17,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/queue" ) @@ -1137,3 +1138,26 @@ func (b *BtcdNotifier) CancelMempoolSpendEvent( b.memNotifier.UnsubscribeEvent(sub) } + +// LookupInputMempoolSpend takes an outpoint and queries the mempool to find +// its spending tx. Returns the tx if found, otherwise fn.None. +// +// NOTE: part of the MempoolWatcher interface. +func (b *BtcdNotifier) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + // Find the spending txid. + txid, found := b.chainConn.LookupInputMempoolSpend(op) + if !found { + return fn.None[wire.MsgTx]() + } + + // Query the spending tx using the id. + tx, err := b.chainConn.GetRawTransaction(&txid) + if err != nil { + // TODO(yy): enable logging errors in this package. + return fn.None[wire.MsgTx]() + } + + return fn.Some(*tx.MsgTx().Copy()) +} diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index e40c271b4..3337f1451 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" ) var ( @@ -849,4 +850,9 @@ type MempoolWatcher interface { // CancelMempoolSpendEvent allows the caller to cancel a subscription to // watch for a spend of an outpoint in the mempool. CancelMempoolSpendEvent(sub *MempoolSpendEvent) + + // LookupInputMempoolSpend looks up the mempool to find a spending tx + // which spends the given outpoint. A fn.None is returned if it's not + // found. + LookupInputMempoolSpend(op wire.OutPoint) fn.Option[wire.MsgTx] } diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go index 2db586d6c..31b75d46f 100644 --- a/chainntnfs/mocks.go +++ b/chainntnfs/mocks.go @@ -2,6 +2,7 @@ package chainntnfs import ( "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/mock" ) @@ -39,3 +40,13 @@ func (m *MockMempoolWatcher) CancelMempoolSpendEvent( m.Called(sub) } + +// LookupInputMempoolSpend looks up the mempool to find a spending tx which +// spends the given outpoint. +func (m *MockMempoolWatcher) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + args := m.Called(op) + + return args.Get(0).(fn.Option[wire.MsgTx]) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 034dd46e8..4df08a8a0 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1338,45 +1338,18 @@ func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) { // mempoolLookup takes an input's outpoint and queries the mempool to see // whether it's already been spent in a transaction found in the mempool. // Returns the transaction if found. -func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) (*wire.MsgTx, bool) { +func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] { // For neutrino backend, there's no mempool available, so we exit // early. if s.cfg.Mempool == nil { log.Debugf("Skipping mempool lookup for %v, no mempool ", op) - return nil, false + return fn.None[wire.MsgTx]() } - // Make a subscription to the mempool. If this outpoint is already - // spent in mempool, we should get a spending event back immediately. - mempoolSpent, err := s.cfg.Mempool.SubscribeMempoolSpent(op) - if err != nil { - log.Errorf("Unable to subscribe to mempool spend for input "+ - "%v: %v", op, err) - - return nil, false - } - - // We want to cancel this subscription in the end as we are only - // interested in a one-time query and this subscription won't be - // listened once this method returns. - defer s.cfg.Mempool.CancelMempoolSpendEvent(mempoolSpent) - - // Do a non-blocking read on the spent event channel. - select { - case details := <-mempoolSpent.Spend: - log.Debugf("Found mempool spend of input %s in tx=%s", - op, details.SpenderTxHash) - - // Found the spending transaction in mempool. This means we - // need to consider RBF constraints if we want to include this - // input in a new sweeping transaction. - return details.SpendingTx, true - - default: - } - - return nil, false + // Query this input in the mempool. If this outpoint is already spent + // in mempool, we should get a spending event back immediately. + return s.cfg.Mempool.LookupInputMempoolSpend(op) } // handleNewInput processes a new input by registering spend notification and @@ -1431,7 +1404,7 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) { // fee info of the spending transction, hence preparing for possible RBF. func (s *UtxoSweeper) attachAvailableRBFInfo(pi *pendingInput) *pendingInput { // Check if we can find the spending tx of this input in mempool. - tx, spent := s.mempoolLookup(*pi.OutPoint()) + txOption := s.mempoolLookup(*pi.OutPoint()) // Exit early if it's not found. // @@ -1439,10 +1412,14 @@ func (s *UtxoSweeper) attachAvailableRBFInfo(pi *pendingInput) *pendingInput { // lookup: // - for neutrino we don't have a mempool. // - for btcd below v0.24.1 we don't have `gettxspendingprevout`. - if !spent { + if txOption.IsNone() { return pi } + // NOTE: we use UnsafeFromSome for here because we are sure this option + // is NOT none. + tx := txOption.UnsafeFromSome() + // Otherwise the input is already spent in the mempool, update its // state to StatePublished. pi.state = StatePublished diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 48785b11e..a4ef79695 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -21,7 +21,6 @@ import ( lnmock "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -2309,69 +2308,38 @@ func TestMempoolLookup(t *testing.T) { // Create a mock mempool watcher. mockMempool := chainntnfs.NewMockMempoolWatcher() + defer mockMempool.AssertExpectations(t) // Create a test sweeper without a mempool. s := New(&UtxoSweeperConfig{}) - // Since we don't have a mempool, we expect the call to return an empty - // transaction plus a false value indicating it's not found. - tx, found := s.mempoolLookup(op) - require.Nil(tx) - require.False(found) + // Since we don't have a mempool, we expect the call to return a + // fn.None indicating it's not found. + tx := s.mempoolLookup(op) + require.True(tx.IsNone()) // Re-create the sweeper with the mocked mempool watcher. s = New(&UtxoSweeperConfig{ Mempool: mockMempool, }) - // Create a mempool spend event to be returned by the mempool watcher. - spendChan := make(chan *chainntnfs.SpendDetail, 1) - spendEvent := &chainntnfs.MempoolSpendEvent{ - Spend: spendChan, - } + // Mock the mempool watcher to return not found. + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.None[wire.MsgTx]()).Once() - // Mock the cancel subscription calls. - mockMempool.On("CancelMempoolSpendEvent", spendEvent) + // We expect a fn.None tx to be returned. + tx = s.mempoolLookup(op) + require.True(tx.IsNone()) - // Mock the mempool watcher to return an error. - dummyErr := errors.New("dummy err") - mockMempool.On("SubscribeMempoolSpent", op).Return(nil, dummyErr).Once() - - // We expect a nil tx and a false value to be returned. - // - // TODO(yy): this means the behavior of not having a mempool is the - // same as an erroneous mempool. The question is should we - // differentiate the two from their returned values? - tx, found = s.mempoolLookup(op) - require.Nil(tx) - require.False(found) - - // Mock the mempool to subscribe to the outpoint. - mockMempool.On("SubscribeMempoolSpent", op).Return( - spendEvent, nil).Once() - - // Without sending a spending details to the `spendChan`, we still - // expect a nil tx and a false value to be returned. - tx, found = s.mempoolLookup(op) - require.Nil(tx) - require.False(found) - - // Send a dummy spending details to the `spendChan`. - dummyTx := &wire.MsgTx{} - spendChan <- &chainntnfs.SpendDetail{ - SpendingTx: dummyTx, - } - - // Mock the mempool to subscribe to the outpoint. - mockMempool.On("SubscribeMempoolSpent", op).Return( - spendEvent, nil).Once() + // Mock the mempool to return a spending tx. + dummyTx := wire.MsgTx{} + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.Some(dummyTx)).Once() // Calling the loopup again, we expect the dummyTx to be returned. - tx, found = s.mempoolLookup(op) - require.Equal(dummyTx, tx) - require.True(found) - - mockMempool.AssertExpectations(t) + tx = s.mempoolLookup(op) + require.False(tx.IsNone()) + require.Equal(dummyTx, tx.UnsafeFromSome()) } // TestUpdateSweeperInputs checks that the method `updateSweeperInputs` will @@ -2444,6 +2412,8 @@ func TestAttachAvailableRBFInfo(t *testing.T) { // Create a mock input. testInput := &input.MockInput{} + defer testInput.AssertExpectations(t) + testInput.On("OutPoint").Return(&op) pi := &pendingInput{ Input: testInput, @@ -2452,16 +2422,9 @@ func TestAttachAvailableRBFInfo(t *testing.T) { // Create a mock mempool watcher and a mock sweeper store. mockMempool := chainntnfs.NewMockMempoolWatcher() + defer mockMempool.AssertExpectations(t) mockStore := NewMockSweeperStore() - - // Create a mempool spend event to be returned by the mempool watcher. - spendChan := make(chan *chainntnfs.SpendDetail, 1) - spendEvent := &chainntnfs.MempoolSpendEvent{ - Spend: spendChan, - } - - // Mock the cancel subscription calls. - mockMempool.On("CancelMempoolSpendEvent", spendEvent) + defer mockStore.AssertExpectations(t) // Create a test sweeper. s := New(&UtxoSweeperConfig{ @@ -2469,9 +2432,9 @@ func TestAttachAvailableRBFInfo(t *testing.T) { Mempool: mockMempool, }) - // First, mock the mempool to return an error. - dummyErr := errors.New("dummy err") - mockMempool.On("SubscribeMempoolSpent", op).Return(nil, dummyErr).Once() + // First, mock the mempool to return false. + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.None[wire.MsgTx]()).Once() // Since the mempool lookup failed, we exepect the original pending // input to stay unchanged. @@ -2479,16 +2442,11 @@ func TestAttachAvailableRBFInfo(t *testing.T) { require.True(result.rbf.IsNone()) require.Equal(StateInit, result.state) - // Mock the mempool lookup to return a tx three times. - tx := &wire.MsgTx{} - mockMempool.On("SubscribeMempoolSpent", op).Return( - spendEvent, nil).Times(3).Run(func(_ mock.Arguments) { - // Eeac time the method is called, we send a tx to the spend - // channel. - spendChan <- &chainntnfs.SpendDetail{ - SpendingTx: tx, - } - }) + // Mock the mempool lookup to return a tx three times as we are calling + // attachAvailableRBFInfo three times. + tx := wire.MsgTx{} + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.Some(tx)).Times(3) // Mock the store to return an error saying the tx cannot be found. mockStore.On("GetTx", tx.TxHash()).Return(nil, ErrTxNotFound).Once() @@ -2500,6 +2458,7 @@ func TestAttachAvailableRBFInfo(t *testing.T) { require.Equal(StatePublished, result.state) // Mock the store to return a db error. + dummyErr := errors.New("dummy error") mockStore.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() // Although the db lookup failed, the pending input should have been @@ -2528,11 +2487,6 @@ func TestAttachAvailableRBFInfo(t *testing.T) { // Assert the state is updated. require.Equal(StatePublished, result.state) - - // Assert mocked statements. - testInput.AssertExpectations(t) - mockMempool.AssertExpectations(t) - mockStore.AssertExpectations(t) } // TestMarkInputFailed checks that the input is marked as failed as expected.