input+lnwallet: refactor select methods in input to use ChannelParty

This commit is contained in:
Keagan McClelland 2024-07-30 16:18:09 -07:00
parent 1a5b5c5f62
commit 3a15085014
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
4 changed files with 22 additions and 18 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"golang.org/x/crypto/ripemd160"
)
@ -789,10 +790,10 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
// unilaterally spend the created output.
func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey,
revokeKey *btcec.PublicKey, payHash []byte,
localCommit bool) (*HtlcScriptTree, error) {
whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) {
var hType htlcType
if localCommit {
if whoseCommit.IsLocal() {
hType = htlcLocalOutgoing
} else {
hType = htlcRemoteIncoming
@ -1348,10 +1349,11 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
// the tap leaf are returned.
func ReceiverHTLCScriptTaproot(cltvExpiry uint32,
senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey,
payHash []byte, ourCommit bool) (*HtlcScriptTree, error) {
payHash []byte, whoseCommit lntypes.ChannelParty,
) (*HtlcScriptTree, error) {
var hType htlcType
if ourCommit {
if whoseCommit.IsLocal() {
hType = htlcLocalIncoming
} else {
hType = htlcRemoteOutgoing

View File

@ -13,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/stretchr/testify/require"
)
@ -1073,7 +1074,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false,
revokeKey.PubKey(), payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1115,7 +1116,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false,
payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1157,7 +1158,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false,
payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1203,7 +1204,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false,
revokeKey.PubKey(), payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1263,7 +1264,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false,
revokeKey.PubKey(), payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1309,7 +1310,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false,
payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -1394,7 +1395,8 @@ func genTimeoutTx(t *testing.T,
)
if chanType.IsTaproot() {
tapscriptTree, err = input.SenderHTLCScriptTaproot(
testPubkey, testPubkey, testPubkey, testHash160, false,
testPubkey, testPubkey, testPubkey, testHash160,
lntypes.Remote,
)
require.NoError(t, err)
@ -1463,7 +1465,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx {
if chanType.IsTaproot() {
tapscriptTree, err = input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, testPubkey, testPubkey, testPubkey,
testHash160, false,
testHash160, lntypes.Remote,
)
require.NoError(t, err)

View File

@ -48,7 +48,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree {
payHash := preImage.Hash()
htlcScriptTree, err := SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false,
payHash[:], lntypes.Remote,
)
require.NoError(t, err)
@ -471,7 +471,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree {
payHash := preImage.Hash()
htlcScriptTree, err := ReceiverHTLCScriptTaproot(
cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false,
revokeKey.PubKey(), payHash[:], lntypes.Remote,
)
require.NoError(t, err)

View File

@ -1095,7 +1095,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
case isIncoming && ourCommit:
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit,
keyRing.RevocationKey, rHash[:], lntypes.Local,
)
// We're being paid via an HTLC by the remote party, and the HTLC is
@ -1104,7 +1104,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
case isIncoming && !ourCommit:
htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit,
keyRing.RevocationKey, rHash[:], lntypes.Remote,
)
// We're sending an HTLC which is being added to our commitment
@ -1113,7 +1113,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
case !isIncoming && ourCommit:
htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit,
keyRing.RevocationKey, rHash[:], lntypes.Local,
)
// Finally, we're paying the remote party via an HTLC, which is being
@ -1122,7 +1122,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
case !isIncoming && !ourCommit:
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit,
keyRing.RevocationKey, rHash[:], lntypes.Remote,
)
}