diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index deea56f4c..54d622d5e 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" @@ -1617,6 +1618,60 @@ func (i *testInput) RequiredTxOut() *wire.TxOut { return i.reqTxOut } +// CraftInputScript is a custom sign method for the testInput type that will +// encode the spending outpoint and the tx input index as part of the returned +// witness. +func (i *testInput) CraftInputScript(_ input.Signer, txn *wire.MsgTx, + hashCache *txscript.TxSigHashes, txinIdx int) (*input.Script, error) { + + // We'll encode the outpoint in the witness, so we can assert that the + // expected input was signed at the correct index. + op := i.OutPoint() + return &input.Script{ + Witness: [][]byte{ + // We encode the hash of the outpoint... + op.Hash[:], + // ..the outpoint index... + {byte(op.Index)}, + // ..and finally the tx input index. + {byte(txinIdx)}, + }, + }, nil +} + +// assertSignedIndex goes through all inputs to the tx and checks that all +// testInputs have witnesses corresponding to the outpoints they are spending, +// and are signed at the correct tx input index. All found testInputs are +// returned such that we can sum up and sanity check that all testInputs were +// part of the sweep. +func assertSignedIndex(t *testing.T, tx *wire.MsgTx, + testInputs map[wire.OutPoint]*testInput) map[wire.OutPoint]struct{} { + + found := make(map[wire.OutPoint]struct{}) + for idx, txIn := range tx.TxIn { + op := txIn.PreviousOutPoint + + // Not a testInput, it won't have the test encoding we require + // to check outpoint and index. + if _, ok := testInputs[op]; !ok { + continue + } + + if _, ok := found[op]; ok { + t.Fatalf("input already used") + } + + // Check it was signes spending the correct outpoint, and at + // the expected tx input index. + require.Equal(t, txIn.Witness[0], op.Hash[:]) + require.Equal(t, txIn.Witness[1], []byte{byte(op.Index)}) + require.Equal(t, txIn.Witness[2], []byte{byte(idx)}) + found[op] = struct{}{} + } + + return found +} + // TestLockTimes checks that the sweeper properly groups inputs requiring the // same locktime together into sweep transactions. func TestLockTimes(t *testing.T) { @@ -1824,6 +1879,66 @@ func TestRequiredTxOuts(t *testing.T) { ) }, }, + { + // Two inputs, where the first one required no tx out. + name: "two inputs, one with required tx out", + inputs: []*testInput{ + { + + // We add a normal, non-requiredTxOut + // input. We use test input 10, to make + // sure this has a higher yield than + // the other input, and will be + // attempted added first to the sweep + // tx. + BaseInput: inputs[10], + }, + { + // The second input requires a TxOut. + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: inputs[0].SignDesc().Output.Value, + }, + }, + }, + + // We expect the inputs to have been reordered. + assertSweeps: func(t *testing.T, + _ map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 2, len(tx.TxIn)) + require.Equal(t, 2, len(tx.TxOut)) + + // The required TxOut should be the first one. + out := tx.TxOut[0] + require.Equal(t, []byte("aaa"), out.PkScript) + require.Equal( + t, inputs[0].SignDesc().Output.Value, + out.Value, + ) + + // The first input should be the one having the + // required TxOut. + require.Len(t, tx.TxIn, 2) + require.Equal( + t, inputs[0].OutPoint(), + &tx.TxIn[0].PreviousOutPoint, + ) + + // Second one is the one without a required tx + // out. + require.Equal( + t, inputs[10].OutPoint(), + &tx.TxIn[1].PreviousOutPoint, + ) + }, + }, + { // An input committing to an output of equal value, just // add input to pay fees. @@ -2076,6 +2191,30 @@ func TestRequiredTxOuts(t *testing.T) { // Assert the transactions are what we expect. testCase.assertSweeps(t, inputs, sweeps) + + // Finally we assert that all our test inputs were part + // of the sweeps, and that they were signed correctly. + sweptInputs := make(map[wire.OutPoint]struct{}) + for _, sweep := range sweeps { + swept := assertSignedIndex(t, sweep, inputs) + for op := range swept { + if _, ok := sweptInputs[op]; ok { + t.Fatalf("outpoint %v part of "+ + "previous sweep", op) + } + + sweptInputs[op] = struct{}{} + } + } + + require.Equal(t, len(inputs), len(sweptInputs)) + for op := range sweptInputs { + _, ok := inputs[op] + if !ok { + t.Fatalf("swept input %v not part of "+ + "test inputs", op) + } + } }) } } diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 670a22551..ec58b4141 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -137,28 +137,35 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, dustLimit btcutil.Amount, signer input.Signer) (*wire.MsgTx, error) { inputs, estimator := getWeightEstimate(inputs, feePerKw) - txFee := estimator.fee() - // Create the sweep transaction that we will be building. We use - // version 2 as it is required for CSV. - sweepTx := wire.NewMsgTx(2) + var ( + // Create the sweep transaction that we will be building. We + // use version 2 as it is required for CSV. + sweepTx = wire.NewMsgTx(2) - // Track whether any of the inputs require a certain locktime. - locktime := int32(-1) + // Track whether any of the inputs require a certain locktime. + locktime = int32(-1) + + // We keep track of total input amount, and required output + // amount to use for calculating the change amount below. + totalInput btcutil.Amount + requiredOutput btcutil.Amount + + // We'll add the inputs as we go so we know the final ordering + // of inputs to sign. + idxs []input.Input + ) // We start by adding all inputs that commit to an output. We do this // since the input and output index must stay the same for the // signatures to be valid. - var ( - totalInput btcutil.Amount - requiredOutput btcutil.Amount - ) for _, o := range inputs { if o.RequiredTxOut() == nil { continue } + idxs = append(idxs, o) sweepTx.AddTxIn(&wire.TxIn{ PreviousOutPoint: *o.OutPoint(), Sequence: o.BlocksToMaturity(), @@ -186,6 +193,7 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, continue } + idxs = append(idxs, o) sweepTx.AddTxIn(&wire.TxIn{ PreviousOutPoint: *o.OutPoint(), Sequence: o.BlocksToMaturity(), @@ -255,10 +263,8 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, return nil } - // Finally we'll attach a valid input script to each csv and cltv input - // within the sweeping transaction. - for i, input := range inputs { - if err := addInputScript(i, input); err != nil { + for idx, inp := range idxs { + if err := addInputScript(idx, inp); err != nil { return nil, err } }