diff --git a/funding/batch.go b/funding/batch.go index 2adb4e3d8..3b18a7b5d 100644 --- a/funding/batch.go +++ b/funding/batch.go @@ -320,6 +320,7 @@ func (b *Batcher) BatchFund(ctx context.Context, // anyway. firstReq := b.channels[0].fundingReq feeRateSatPerKVByte := firstReq.FundingFeePerKw.FeePerKVByte() + changeType := walletrpc.ChangeAddressType_CHANGE_ADDRESS_TYPE_P2TR fundPsbtReq := &walletrpc.FundPsbtRequest{ Template: &walletrpc.FundPsbtRequest_Raw{ Raw: txTemplate, @@ -329,6 +330,7 @@ func (b *Batcher) BatchFund(ctx context.Context, }, MinConfs: firstReq.MinConfs, SpendUnconfirmed: firstReq.MinConfs == 0, + ChangeType: changeType, } fundPsbtResp, err := b.cfg.WalletKitServer.FundPsbt(ctx, fundPsbtReq) if err != nil { diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index 82ac7f153..1162142b9 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/input" @@ -777,6 +778,21 @@ func testBatchChanFunding(ht *lntest.HarnessTest) { ht.AssertTopologyChannelOpen(alice, chanPoint2) ht.AssertTopologyChannelOpen(alice, chanPoint3) + // Check if the change type from the batch_open_channel funding is P2TR. + rawTx := ht.Miner.GetRawTransaction(txHash) + require.Len(ht, rawTx.MsgTx().TxOut, 4) + + // For calculating the change output index we use the formula for the + // sum of consecutive of integers (n(n+1)/2). All the channel point + // indexes are known, so we just calculate the difference to get the + // change output index. + changeIndex := uint32(6) - (chanPoint1.OutputIndex + + chanPoint2.OutputIndex + chanPoint3.OutputIndex) + + ht.AssertOutputScriptClass( + rawTx, changeIndex, txscript.WitnessV1TaprootTy, + ) + // With the channel open, ensure that it is counted towards Alice's // total channel balance. balRes := alice.RPC.ChannelBalance() diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index bf1bb4ba8..47a1fa38b 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -377,6 +377,20 @@ func (h *HarnessTest) AssertChannelExists(hn *node.HarnessNode, return channel } +// AssertOutputScriptClass checks that the specified transaction output has the +// expected script class. +func (h *HarnessTest) AssertOutputScriptClass(tx *btcutil.Tx, + outputIndex uint32, scriptClass txscript.ScriptClass) { + + require.Greater(h, len(tx.MsgTx().TxOut), int(outputIndex)) + + txOut := tx.MsgTx().TxOut[outputIndex] + + pkScript, err := txscript.ParsePkScript(txOut.PkScript) + require.NoError(h, err) + require.Equal(h, pkScript.Class(), scriptClass) +} + // findChannel tries to find a target channel in the node using the given // channel point. func (h *HarnessTest) findChannel(hn *node.HarnessNode,