sweep: add method getSpentInputs

To track the input and its spending tx, which will be used later to
detect unknown spends.
This commit is contained in:
yyforyongyu 2025-01-24 08:35:40 +08:00
parent 0e87863481
commit 8c9ba327cc
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
2 changed files with 175 additions and 22 deletions

View file

@ -908,7 +908,7 @@ func (t *TxPublisher) processRecords() {
// Check whether the inputs has been spent by a third party. // Check whether the inputs has been spent by a third party.
// //
// NOTE: this check is only done for neutrino backend. // NOTE: this check is only done for neutrino backend.
if t.isThirdPartySpent(r.tx.TxHash(), r.req.Inputs) { if t.isThirdPartySpent(r) {
failedRecords[requestID] = r failedRecords[requestID] = r
// Move to the next record. // Move to the next record.
@ -1253,26 +1253,59 @@ func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool {
// //
// NOTE: this check is only performed for neutrino backend as it has no // NOTE: this check is only performed for neutrino backend as it has no
// reliable way to tell a tx has been replaced. // reliable way to tell a tx has been replaced.
func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash, func (t *TxPublisher) isThirdPartySpent(r *monitorRecord) bool {
inputs []input.Input) bool {
// Skip this check for if this is not neutrino backend. // Skip this check for if this is not neutrino backend.
if !t.isNeutrinoBackend() { if !t.isNeutrinoBackend() {
return false return false
} }
txid := r.tx.TxHash()
spends := t.getSpentInputs(r)
// Iterate all the spending txns and check if they match the sweeping
// tx.
for op, spendingTx := range spends {
spendingTxID := spendingTx.TxHash()
// If the spending tx is the same as the sweeping tx
// then we are good.
if spendingTxID == txid {
continue
}
log.Warnf("Detected third party spent of output=%v "+
"in tx=%v", op, spendingTx.TxHash())
return true
}
return false
}
// getSpentInputs performs a non-blocking read on the spending subscriptions to
// see whether any of the monitored inputs has been spent. A map of inputs with
// their spending txns are returned if found.
func (t *TxPublisher) getSpentInputs(
r *monitorRecord) map[wire.OutPoint]*wire.MsgTx {
// Create a slice to record the inputs spent.
spentInputs := make(map[wire.OutPoint]*wire.MsgTx, len(r.req.Inputs))
// Iterate all the inputs and check if they have been spent already. // Iterate all the inputs and check if they have been spent already.
for _, inp := range inputs { for _, inp := range r.req.Inputs {
op := inp.OutPoint() op := inp.OutPoint()
// For wallet utxos, the height hint is not set - we don't need // For wallet utxos, the height hint is not set - we don't need
// to monitor them for third party spend. // to monitor them for third party spend.
//
// TODO(yy): We need to properly lock wallet utxos before
// skipping this check as the same wallet utxo can be used by
// different sweeping txns.
heightHint := inp.HeightHint() heightHint := inp.HeightHint()
if heightHint == 0 { if heightHint == 0 {
log.Debugf("Skipped third party check for wallet "+ heightHint = uint32(t.currentHeight.Load())
"input %v", op) log.Debugf("Checking wallet input %v using heightHint "+
"%v", op, heightHint)
continue
} }
// If the input has already been spent after the height hint, a // If the input has already been spent after the height hint, a
@ -1283,7 +1316,8 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash,
if err != nil { if err != nil {
log.Criticalf("Failed to register spend ntfn for "+ log.Criticalf("Failed to register spend ntfn for "+
"input=%v: %v", op, err) "input=%v: %v", op, err)
return false
return nil
} }
// Remove the subscription when exit. // Remove the subscription when exit.
@ -1294,28 +1328,24 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash,
case spend, ok := <-spendEvent.Spend: case spend, ok := <-spendEvent.Spend:
if !ok { if !ok {
log.Debugf("Spend ntfn for %v canceled", op) log.Debugf("Spend ntfn for %v canceled", op)
return false
}
spendingTxID := spend.SpendingTx.TxHash()
// If the spending tx is the same as the sweeping tx
// then we are good.
if spendingTxID == txid {
continue continue
} }
log.Warnf("Detected third party spent of output=%v "+ spendingTx := spend.SpendingTx
"in tx=%v", op, spend.SpendingTx.TxHash())
return true log.Debugf("Detected spent of input=%v in tx=%v", op,
spendingTx.TxHash())
spentInputs[op] = spendingTx
// Move to the next input. // Move to the next input.
default: default:
log.Tracef("Input %v not spent yet", op)
} }
} }
return false return spentInputs
} }
// calcCurrentConfTarget calculates the current confirmation target based on // calcCurrentConfTarget calculates the current confirmation target based on

View file

@ -55,7 +55,7 @@ func createTestInput(value int64,
PubKey: testPubKey, PubKey: testPubKey,
}, },
}, },
0, 1,
nil, nil,
) )
@ -1776,3 +1776,126 @@ func TestHandleInitialBroadcastFail(t *testing.T) {
require.Equal(t, 0, tp.records.Len()) require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len()) require.Equal(t, 0, tp.subscriberChans.Len())
} }
// TestHasInputsSpent checks the expected outpoint:tx map is returned.
func TestHasInputsSpent(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create mock inputs.
op1 := wire.OutPoint{
Hash: chainhash.Hash{1},
Index: 1,
}
inp1 := &input.MockInput{}
heightHint1 := uint32(1)
defer inp1.AssertExpectations(t)
op2 := wire.OutPoint{
Hash: chainhash.Hash{1},
Index: 2,
}
inp2 := &input.MockInput{}
heightHint2 := uint32(2)
defer inp2.AssertExpectations(t)
op3 := wire.OutPoint{
Hash: chainhash.Hash{1},
Index: 3,
}
walletInp := &input.MockInput{}
heightHint3 := uint32(0)
defer walletInp.AssertExpectations(t)
// We expect all the inputs to call OutPoint and HeightHint.
inp1.On("OutPoint").Return(op1).Once()
inp2.On("OutPoint").Return(op2).Once()
walletInp.On("OutPoint").Return(op3).Once()
inp1.On("HeightHint").Return(heightHint1).Once()
inp2.On("HeightHint").Return(heightHint2).Once()
walletInp.On("HeightHint").Return(heightHint3).Once()
// We expect the normal inputs to call SignDesc.
pkScript1 := []byte{1}
sd1 := &input.SignDescriptor{
Output: &wire.TxOut{
PkScript: pkScript1,
},
}
inp1.On("SignDesc").Return(sd1).Once()
pkScript2 := []byte{1}
sd2 := &input.SignDescriptor{
Output: &wire.TxOut{
PkScript: pkScript2,
},
}
inp2.On("SignDesc").Return(sd2).Once()
pkScript3 := []byte{3}
sd3 := &input.SignDescriptor{
Output: &wire.TxOut{
PkScript: pkScript3,
},
}
walletInp.On("SignDesc").Return(sd3).Once()
// Mock RegisterSpendNtfn.
//
// spendingTx1 is the tx spending op1.
spendingTx1 := &wire.MsgTx{}
se1 := createTestSpendEvent(spendingTx1)
m.notifier.On("RegisterSpendNtfn",
&op1, pkScript1, heightHint1).Return(se1, nil).Once()
// Create the spending event that doesn't send an event.
se2 := &chainntnfs.SpendEvent{
Cancel: func() {},
}
m.notifier.On("RegisterSpendNtfn",
&op2, pkScript2, heightHint2).Return(se2, nil).Once()
se3 := &chainntnfs.SpendEvent{
Cancel: func() {},
}
m.notifier.On("RegisterSpendNtfn",
&op3, pkScript3, heightHint3).Return(se3, nil).Once()
// Prepare the test inputs.
inputs := []input.Input{inp1, inp2, walletInp}
// Prepare the test record.
record := &monitorRecord{
req: &BumpRequest{
Inputs: inputs,
},
}
// Call the method under test.
result := tp.getSpentInputs(record)
// Assert the expected map is created.
expected := map[wire.OutPoint]*wire.MsgTx{
op1: spendingTx1,
}
require.Equal(t, expected, result)
}
// createTestSpendEvent creates a SpendEvent which places the specified tx in
// the channel, which can be read by a spending subscriber.
func createTestSpendEvent(tx *wire.MsgTx) *chainntnfs.SpendEvent {
// Create a monitor record that's confirmed.
spendDetails := chainntnfs.SpendDetail{
SpendingTx: tx,
}
spendChan1 := make(chan *chainntnfs.SpendDetail, 1)
spendChan1 <- &spendDetails
// Create the spend events.
return &chainntnfs.SpendEvent{
Spend: spendChan1,
Cancel: func() {},
}
}