sweep: make sure inputs with different locktime values are not grouped

This commit is contained in:
yyforyongyu 2024-04-11 22:01:18 +08:00
parent 49cfb91af1
commit 871cab4bc0
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 141 additions and 8 deletions

View File

@ -551,10 +551,15 @@ func (b *BudgetAggregator) ClusterInputs(inputs InputsMap,
// Sort the inputs by their economical value. // Sort the inputs by their economical value.
sortedInputs := b.sortInputs(cluster) sortedInputs := b.sortInputs(cluster)
// Split on locktimes if they are different.
splitClusters := splitOnLocktime(sortedInputs)
// Create input sets from the cluster. // Create input sets from the cluster.
sets := b.createInputSets(sortedInputs, height) for _, cluster := range splitClusters {
sets := b.createInputSets(cluster, height)
inputSets = append(inputSets, sets...) inputSets = append(inputSets, sets...)
} }
}
// Create input sets from the exclusive inputs. // Create input sets from the exclusive inputs.
for _, cluster := range exclusiveInputs { for _, cluster := range exclusiveInputs {
@ -742,6 +747,62 @@ func (b *BudgetAggregator) sortInputs(inputs []SweeperInput) []SweeperInput {
return sortedInputs return sortedInputs
} }
// splitOnLocktime splits the list of inputs based on their locktime.
//
// TODO(yy): this is a temporary hack as the blocks are not synced among the
// contractcourt and the sweeper.
func splitOnLocktime(inputs []SweeperInput) map[uint32][]SweeperInput {
result := make(map[uint32][]SweeperInput)
noLocktimeInputs := make([]SweeperInput, 0, len(inputs))
// mergeLocktime is the locktime that we use to merge all the
// nolocktime inputs into.
var mergeLocktime uint32
// Iterate all inputs and split them based on their locktimes.
for _, inp := range inputs {
locktime, required := inp.RequiredLockTime()
if !required {
log.Tracef("No locktime required for input=%v",
inp.OutPoint())
noLocktimeInputs = append(noLocktimeInputs, inp)
continue
}
log.Tracef("Split input=%v on locktime=%v", inp.OutPoint(),
locktime)
// Get the slice - the slice will be initialized if not found.
inputList := result[locktime]
// Add the input to the list.
inputList = append(inputList, inp)
// Update the map.
result[locktime] = inputList
// Update the merge locktime.
mergeLocktime = locktime
}
// If there are locktime inputs, we will merge the no locktime inputs
// to the last locktime group found.
if len(result) > 0 {
log.Tracef("No locktime inputs has been merged to locktime=%v",
mergeLocktime)
result[mergeLocktime] = append(
result[mergeLocktime], noLocktimeInputs...,
)
} else {
// Otherwise just return the no locktime inputs.
result[mergeLocktime] = noLocktimeInputs
}
return result
}
// isDustOutput checks if the given output is considered as dust. // isDustOutput checks if the given output is considered as dust.
func isDustOutput(output *wire.TxOut) bool { func isDustOutput(output *wire.TxOut) bool {
// Fetch the dust limit for this output. // Fetch the dust limit for this output.

View File

@ -839,7 +839,7 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
// 3. when assigning the input to the exclusiveInputs. // 3. when assigning the input to the exclusiveInputs.
// 4. when iterating the exclusiveInputs. // 4. when iterating the exclusiveInputs.
opExclusive := wire.OutPoint{Hash: chainhash.Hash{1, 2, 3, 4, 5}} opExclusive := wire.OutPoint{Hash: chainhash.Hash{1, 2, 3, 4, 5}}
inpExclusive.On("OutPoint").Return(opExclusive).Times(4) inpExclusive.On("OutPoint").Return(opExclusive).Maybe()
// Mock the `WitnessType` method to return the witness type. // Mock the `WitnessType` method to return the witness type.
inpExclusive.On("WitnessType").Return(wt) inpExclusive.On("WitnessType").Return(wt)
@ -895,11 +895,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
// `filterInputs`. // `filterInputs`.
inpLow.On("OutPoint").Return(opLow).Once() inpLow.On("OutPoint").Return(opLow).Once()
// We expect the high budget input to call this method three // The number of times this method is called is dependent on
// times, one in `filterInputs` and one in `createInputSet`, // the log level.
// and one in `NewBudgetInputSet`. inpHigh1.On("OutPoint").Return(opHigh1).Maybe()
inpHigh1.On("OutPoint").Return(opHigh1).Times(3) inpHigh2.On("OutPoint").Return(opHigh2).Maybe()
inpHigh2.On("OutPoint").Return(opHigh2).Times(3)
// Mock the `WitnessType` method to return the witness type. // Mock the `WitnessType` method to return the witness type.
inpLow.On("WitnessType").Return(wt) inpLow.On("WitnessType").Return(wt)
@ -910,6 +909,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
inpHigh1.On("RequiredTxOut").Return(nil) inpHigh1.On("RequiredTxOut").Return(nil)
inpHigh2.On("RequiredTxOut").Return(nil) inpHigh2.On("RequiredTxOut").Return(nil)
// Mock the `RequiredLockTime` to return 0.
inpHigh1.On("RequiredLockTime").Return(uint32(0), false)
inpHigh2.On("RequiredLockTime").Return(uint32(0), false)
// Add the low input, which should be filtered out. // Add the low input, which should be filtered out.
inputs[opLow] = &SweeperInput{ inputs[opLow] = &SweeperInput{
Input: inpLow, Input: inpLow,
@ -969,3 +972,72 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
require.Contains(t, deadlines, deadline1.UnwrapOrFail(t)) require.Contains(t, deadlines, deadline1.UnwrapOrFail(t))
require.Contains(t, deadlines, deadline2.UnwrapOrFail(t)) require.Contains(t, deadlines, deadline2.UnwrapOrFail(t))
} }
// TestSplitOnLocktime asserts `splitOnLocktime` works as expected.
func TestSplitOnLocktime(t *testing.T) {
t.Parallel()
// Create two locktimes.
lockTime1 := uint32(1)
lockTime2 := uint32(2)
// Create cluster one, which has a locktime of 1.
input1LockTime1 := &input.MockInput{}
input2LockTime1 := &input.MockInput{}
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
// Create cluster two, which has a locktime of 2.
input3LockTime2 := &input.MockInput{}
input4LockTime2 := &input.MockInput{}
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
// Create cluster three, which has no locktime.
// Create cluster three, which has no locktime.
input5NoLockTime := &input.MockInput{}
input6NoLockTime := &input.MockInput{}
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
// Mock `OutPoint` - it may or may not be called due to log settings.
input1LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 1}).Maybe()
input2LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe()
input3LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe()
input4LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 4}).Maybe()
input5NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 5}).Maybe()
input6NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 6}).Maybe()
// With the inner Input being mocked, we can now create the pending
// inputs.
input1 := SweeperInput{Input: input1LockTime1}
input2 := SweeperInput{Input: input2LockTime1}
input3 := SweeperInput{Input: input3LockTime2}
input4 := SweeperInput{Input: input4LockTime2}
input5 := SweeperInput{Input: input5NoLockTime}
input6 := SweeperInput{Input: input6NoLockTime}
// Call the method under test.
inputs := []SweeperInput{input1, input2, input3, input4, input5, input6}
result := splitOnLocktime(inputs)
// We expect the no locktime inputs to be grouped with locktime2.
expectedResult := map[uint32][]SweeperInput{
lockTime1: {input1, input2},
lockTime2: {input3, input4, input5, input6},
}
require.Len(t, result[lockTime1], 2)
require.Len(t, result[lockTime2], 4)
require.Equal(t, expectedResult, result)
// Test the case where there are no locktime inputs.
inputs = []SweeperInput{input5, input6}
result = splitOnLocktime(inputs)
// We expect the no locktime inputs to be returned as is.
expectedResult = map[uint32][]SweeperInput{
uint32(0): {input5, input6},
}
require.Len(t, result[uint32(0)], 2)
require.Equal(t, expectedResult, result)
}