sweep: make sure inputs with different locktime values are not grouped

This commit is contained in:
yyforyongyu 2024-04-11 22:01:18 +08:00
parent 49cfb91af1
commit 871cab4bc0
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 141 additions and 8 deletions

View File

@ -551,10 +551,15 @@ func (b *BudgetAggregator) ClusterInputs(inputs InputsMap,
// Sort the inputs by their economical value.
sortedInputs := b.sortInputs(cluster)
// Split on locktimes if they are different.
splitClusters := splitOnLocktime(sortedInputs)
// Create input sets from the cluster.
sets := b.createInputSets(sortedInputs, height)
for _, cluster := range splitClusters {
sets := b.createInputSets(cluster, height)
inputSets = append(inputSets, sets...)
}
}
// Create input sets from the exclusive inputs.
for _, cluster := range exclusiveInputs {
@ -742,6 +747,62 @@ func (b *BudgetAggregator) sortInputs(inputs []SweeperInput) []SweeperInput {
return sortedInputs
}
// splitOnLocktime splits the list of inputs based on their locktime.
//
// TODO(yy): this is a temporary hack as the blocks are not synced among the
// contractcourt and the sweeper.
func splitOnLocktime(inputs []SweeperInput) map[uint32][]SweeperInput {
result := make(map[uint32][]SweeperInput)
noLocktimeInputs := make([]SweeperInput, 0, len(inputs))
// mergeLocktime is the locktime that we use to merge all the
// nolocktime inputs into.
var mergeLocktime uint32
// Iterate all inputs and split them based on their locktimes.
for _, inp := range inputs {
locktime, required := inp.RequiredLockTime()
if !required {
log.Tracef("No locktime required for input=%v",
inp.OutPoint())
noLocktimeInputs = append(noLocktimeInputs, inp)
continue
}
log.Tracef("Split input=%v on locktime=%v", inp.OutPoint(),
locktime)
// Get the slice - the slice will be initialized if not found.
inputList := result[locktime]
// Add the input to the list.
inputList = append(inputList, inp)
// Update the map.
result[locktime] = inputList
// Update the merge locktime.
mergeLocktime = locktime
}
// If there are locktime inputs, we will merge the no locktime inputs
// to the last locktime group found.
if len(result) > 0 {
log.Tracef("No locktime inputs has been merged to locktime=%v",
mergeLocktime)
result[mergeLocktime] = append(
result[mergeLocktime], noLocktimeInputs...,
)
} else {
// Otherwise just return the no locktime inputs.
result[mergeLocktime] = noLocktimeInputs
}
return result
}
// isDustOutput checks if the given output is considered as dust.
func isDustOutput(output *wire.TxOut) bool {
// Fetch the dust limit for this output.

View File

@ -839,7 +839,7 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
// 3. when assigning the input to the exclusiveInputs.
// 4. when iterating the exclusiveInputs.
opExclusive := wire.OutPoint{Hash: chainhash.Hash{1, 2, 3, 4, 5}}
inpExclusive.On("OutPoint").Return(opExclusive).Times(4)
inpExclusive.On("OutPoint").Return(opExclusive).Maybe()
// Mock the `WitnessType` method to return the witness type.
inpExclusive.On("WitnessType").Return(wt)
@ -895,11 +895,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
// `filterInputs`.
inpLow.On("OutPoint").Return(opLow).Once()
// We expect the high budget input to call this method three
// times, one in `filterInputs` and one in `createInputSet`,
// and one in `NewBudgetInputSet`.
inpHigh1.On("OutPoint").Return(opHigh1).Times(3)
inpHigh2.On("OutPoint").Return(opHigh2).Times(3)
// The number of times this method is called is dependent on
// the log level.
inpHigh1.On("OutPoint").Return(opHigh1).Maybe()
inpHigh2.On("OutPoint").Return(opHigh2).Maybe()
// Mock the `WitnessType` method to return the witness type.
inpLow.On("WitnessType").Return(wt)
@ -910,6 +909,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
inpHigh1.On("RequiredTxOut").Return(nil)
inpHigh2.On("RequiredTxOut").Return(nil)
// Mock the `RequiredLockTime` to return 0.
inpHigh1.On("RequiredLockTime").Return(uint32(0), false)
inpHigh2.On("RequiredLockTime").Return(uint32(0), false)
// Add the low input, which should be filtered out.
inputs[opLow] = &SweeperInput{
Input: inpLow,
@ -969,3 +972,72 @@ func TestBudgetInputSetClusterInputs(t *testing.T) {
require.Contains(t, deadlines, deadline1.UnwrapOrFail(t))
require.Contains(t, deadlines, deadline2.UnwrapOrFail(t))
}
// TestSplitOnLocktime asserts `splitOnLocktime` works as expected.
func TestSplitOnLocktime(t *testing.T) {
t.Parallel()
// Create two locktimes.
lockTime1 := uint32(1)
lockTime2 := uint32(2)
// Create cluster one, which has a locktime of 1.
input1LockTime1 := &input.MockInput{}
input2LockTime1 := &input.MockInput{}
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
// Create cluster two, which has a locktime of 2.
input3LockTime2 := &input.MockInput{}
input4LockTime2 := &input.MockInput{}
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
// Create cluster three, which has no locktime.
// Create cluster three, which has no locktime.
input5NoLockTime := &input.MockInput{}
input6NoLockTime := &input.MockInput{}
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
// Mock `OutPoint` - it may or may not be called due to log settings.
input1LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 1}).Maybe()
input2LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe()
input3LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe()
input4LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 4}).Maybe()
input5NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 5}).Maybe()
input6NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 6}).Maybe()
// With the inner Input being mocked, we can now create the pending
// inputs.
input1 := SweeperInput{Input: input1LockTime1}
input2 := SweeperInput{Input: input2LockTime1}
input3 := SweeperInput{Input: input3LockTime2}
input4 := SweeperInput{Input: input4LockTime2}
input5 := SweeperInput{Input: input5NoLockTime}
input6 := SweeperInput{Input: input6NoLockTime}
// Call the method under test.
inputs := []SweeperInput{input1, input2, input3, input4, input5, input6}
result := splitOnLocktime(inputs)
// We expect the no locktime inputs to be grouped with locktime2.
expectedResult := map[uint32][]SweeperInput{
lockTime1: {input1, input2},
lockTime2: {input3, input4, input5, input6},
}
require.Len(t, result[lockTime1], 2)
require.Len(t, result[lockTime2], 4)
require.Equal(t, expectedResult, result)
// Test the case where there are no locktime inputs.
inputs = []SweeperInput{input5, input6}
result = splitOnLocktime(inputs)
// We expect the no locktime inputs to be returned as is.
expectedResult = map[uint32][]SweeperInput{
uint32(0): {input5, input6},
}
require.Len(t, result[uint32(0)], 2)
require.Equal(t, expectedResult, result)
}