diff --git a/input/script_utils.go b/input/script_utils.go index 80997eed4..104c24251 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" ) @@ -789,10 +790,10 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, revokeKey *btcec.PublicKey, payHash []byte, - localCommit bool) (*HtlcScriptTree, error) { + whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) { var hType htlcType - if localCommit { + if whoseCommit.IsLocal() { hType = htlcLocalOutgoing } else { hType = htlcRemoteIncoming @@ -1348,10 +1349,11 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // the tap leaf are returned. func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, - payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { + payHash []byte, whoseCommit lntypes.ChannelParty, +) (*HtlcScriptTree, error) { var hType htlcType - if ourCommit { + if whoseCommit.IsLocal() { hType = htlcLocalIncoming } else { hType = htlcRemoteOutgoing diff --git a/input/size_test.go b/input/size_test.go index 9c3446afb..daa7053cc 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -1073,7 +1074,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1115,7 +1116,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1157,7 +1158,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1203,7 +1204,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1263,7 +1264,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1309,7 +1310,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1394,7 +1395,8 @@ func genTimeoutTx(t *testing.T, ) if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( - testPubkey, testPubkey, testPubkey, testHash160, false, + testPubkey, testPubkey, testPubkey, testHash160, + lntypes.Remote, ) require.NoError(t, err) @@ -1463,7 +1465,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, false, + testHash160, lntypes.Remote, ) require.NoError(t, err) diff --git a/input/taproot_test.go b/input/taproot_test.go index 801b0fef4..434be2dfd 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -48,7 +48,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -471,7 +471,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 96af8d7cf..1e1140fbc 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -1095,7 +1095,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Local, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1104,7 +1104,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && !ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Remote, ) // We're sending an HTLC which is being added to our commitment @@ -1113,7 +1113,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Local, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1122,7 +1122,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && !ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Remote, ) }