diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 620ad7554..f215df71e 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -908,7 +908,7 @@ func (t *TxPublisher) processRecords() { // Check whether the inputs has been spent by a third party. // // NOTE: this check is only done for neutrino backend. - if t.isThirdPartySpent(r.tx.TxHash(), r.req.Inputs) { + if t.isThirdPartySpent(r) { failedRecords[requestID] = r // Move to the next record. @@ -1253,26 +1253,59 @@ func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool { // // NOTE: this check is only performed for neutrino backend as it has no // reliable way to tell a tx has been replaced. -func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash, - inputs []input.Input) bool { - +func (t *TxPublisher) isThirdPartySpent(r *monitorRecord) bool { // Skip this check for if this is not neutrino backend. if !t.isNeutrinoBackend() { return false } + txid := r.tx.TxHash() + spends := t.getSpentInputs(r) + + // Iterate all the spending txns and check if they match the sweeping + // tx. + for op, spendingTx := range spends { + spendingTxID := spendingTx.TxHash() + + // If the spending tx is the same as the sweeping tx + // then we are good. + if spendingTxID == txid { + continue + } + + log.Warnf("Detected third party spent of output=%v "+ + "in tx=%v", op, spendingTx.TxHash()) + + return true + } + + return false +} + +// getSpentInputs performs a non-blocking read on the spending subscriptions to +// see whether any of the monitored inputs has been spent. A map of inputs with +// their spending txns are returned if found. +func (t *TxPublisher) getSpentInputs( + r *monitorRecord) map[wire.OutPoint]*wire.MsgTx { + + // Create a slice to record the inputs spent. + spentInputs := make(map[wire.OutPoint]*wire.MsgTx, len(r.req.Inputs)) + // Iterate all the inputs and check if they have been spent already. - for _, inp := range inputs { + for _, inp := range r.req.Inputs { op := inp.OutPoint() // For wallet utxos, the height hint is not set - we don't need // to monitor them for third party spend. + // + // TODO(yy): We need to properly lock wallet utxos before + // skipping this check as the same wallet utxo can be used by + // different sweeping txns. heightHint := inp.HeightHint() if heightHint == 0 { - log.Debugf("Skipped third party check for wallet "+ - "input %v", op) - - continue + heightHint = uint32(t.currentHeight.Load()) + log.Debugf("Checking wallet input %v using heightHint "+ + "%v", op, heightHint) } // If the input has already been spent after the height hint, a @@ -1283,7 +1316,8 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash, if err != nil { log.Criticalf("Failed to register spend ntfn for "+ "input=%v: %v", op, err) - return false + + return nil } // Remove the subscription when exit. @@ -1294,28 +1328,24 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash, case spend, ok := <-spendEvent.Spend: if !ok { log.Debugf("Spend ntfn for %v canceled", op) - return false - } - spendingTxID := spend.SpendingTx.TxHash() - - // If the spending tx is the same as the sweeping tx - // then we are good. - if spendingTxID == txid { continue } - log.Warnf("Detected third party spent of output=%v "+ - "in tx=%v", op, spend.SpendingTx.TxHash()) + spendingTx := spend.SpendingTx - return true + log.Debugf("Detected spent of input=%v in tx=%v", op, + spendingTx.TxHash()) + + spentInputs[op] = spendingTx // Move to the next input. default: + log.Tracef("Input %v not spent yet", op) } } - return false + return spentInputs } // calcCurrentConfTarget calculates the current confirmation target based on diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 0531dec8d..1d53432d9 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -55,7 +55,7 @@ func createTestInput(value int64, PubKey: testPubKey, }, }, - 0, + 1, nil, ) @@ -1776,3 +1776,126 @@ func TestHandleInitialBroadcastFail(t *testing.T) { require.Equal(t, 0, tp.records.Len()) require.Equal(t, 0, tp.subscriberChans.Len()) } + +// TestHasInputsSpent checks the expected outpoint:tx map is returned. +func TestHasInputsSpent(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create mock inputs. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 1, + } + inp1 := &input.MockInput{} + heightHint1 := uint32(1) + defer inp1.AssertExpectations(t) + + op2 := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 2, + } + inp2 := &input.MockInput{} + heightHint2 := uint32(2) + defer inp2.AssertExpectations(t) + + op3 := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 3, + } + walletInp := &input.MockInput{} + heightHint3 := uint32(0) + defer walletInp.AssertExpectations(t) + + // We expect all the inputs to call OutPoint and HeightHint. + inp1.On("OutPoint").Return(op1).Once() + inp2.On("OutPoint").Return(op2).Once() + walletInp.On("OutPoint").Return(op3).Once() + inp1.On("HeightHint").Return(heightHint1).Once() + inp2.On("HeightHint").Return(heightHint2).Once() + walletInp.On("HeightHint").Return(heightHint3).Once() + + // We expect the normal inputs to call SignDesc. + pkScript1 := []byte{1} + sd1 := &input.SignDescriptor{ + Output: &wire.TxOut{ + PkScript: pkScript1, + }, + } + inp1.On("SignDesc").Return(sd1).Once() + + pkScript2 := []byte{1} + sd2 := &input.SignDescriptor{ + Output: &wire.TxOut{ + PkScript: pkScript2, + }, + } + inp2.On("SignDesc").Return(sd2).Once() + + pkScript3 := []byte{3} + sd3 := &input.SignDescriptor{ + Output: &wire.TxOut{ + PkScript: pkScript3, + }, + } + walletInp.On("SignDesc").Return(sd3).Once() + + // Mock RegisterSpendNtfn. + // + // spendingTx1 is the tx spending op1. + spendingTx1 := &wire.MsgTx{} + se1 := createTestSpendEvent(spendingTx1) + m.notifier.On("RegisterSpendNtfn", + &op1, pkScript1, heightHint1).Return(se1, nil).Once() + + // Create the spending event that doesn't send an event. + se2 := &chainntnfs.SpendEvent{ + Cancel: func() {}, + } + m.notifier.On("RegisterSpendNtfn", + &op2, pkScript2, heightHint2).Return(se2, nil).Once() + + se3 := &chainntnfs.SpendEvent{ + Cancel: func() {}, + } + m.notifier.On("RegisterSpendNtfn", + &op3, pkScript3, heightHint3).Return(se3, nil).Once() + + // Prepare the test inputs. + inputs := []input.Input{inp1, inp2, walletInp} + + // Prepare the test record. + record := &monitorRecord{ + req: &BumpRequest{ + Inputs: inputs, + }, + } + + // Call the method under test. + result := tp.getSpentInputs(record) + + // Assert the expected map is created. + expected := map[wire.OutPoint]*wire.MsgTx{ + op1: spendingTx1, + } + require.Equal(t, expected, result) +} + +// createTestSpendEvent creates a SpendEvent which places the specified tx in +// the channel, which can be read by a spending subscriber. +func createTestSpendEvent(tx *wire.MsgTx) *chainntnfs.SpendEvent { + // Create a monitor record that's confirmed. + spendDetails := chainntnfs.SpendDetail{ + SpendingTx: tx, + } + spendChan1 := make(chan *chainntnfs.SpendDetail, 1) + spendChan1 <- &spendDetails + + // Create the spend events. + return &chainntnfs.SpendEvent{ + Spend: spendChan1, + Cancel: func() {}, + } +}