lnwallet: uniformly use sighash default everywhere for taproot chans

We use a helper function to ensure that anytime we're about to make a
normal sighash, we consult the channel type to check if we should use
the default value or sighash all explicitly.
This commit is contained in:
Olaoluwa Osuntokun 2023-08-18 15:17:51 -07:00
parent dd05dd55d4
commit ff055ce0a4
No known key found for this signature in database
GPG key ID: 3BBD59E99B280306
2 changed files with 23 additions and 22 deletions

View file

@ -2530,7 +2530,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
PkScript: ourScript.PkScript(), PkScript: ourScript.PkScript(),
Value: ourAmt, Value: ourAmt,
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
} }
// For taproot channels, we'll make sure to set the script path // For taproot channels, we'll make sure to set the script path
@ -2577,7 +2577,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
PkScript: theirScript.PkScript(), PkScript: theirScript.PkScript(),
Value: theirAmt, Value: theirAmt,
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
} }
// For taproot channels, the remote output (the revoked output) // For taproot channels, the remote output (the revoked output)
@ -2659,7 +2659,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
PkScript: scriptInfo.PkScript(), PkScript: scriptInfo.PkScript(),
Value: int64(htlc.Amt), Value: int64(htlc.Amt),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
} }
// For taproot HTLC outputs, we need to set the sign method to key // For taproot HTLC outputs, we need to set the sign method to key
@ -6477,7 +6477,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si
Value: localBalance, Value: localBalance,
PkScript: selfScript.PkScript(), PkScript: selfScript.PkScript(),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
}, },
MaturityDelay: maturityDelay, MaturityDelay: maturityDelay,
} }
@ -6501,8 +6501,6 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO(roasbeef): put ctrl block in resolution?
} }
} }
@ -6712,7 +6710,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
PkScript: htlcPkScript, PkScript: htlcPkScript,
Value: int64(htlc.Amt.ToSatoshis()), Value: int64(htlc.Amt.ToSatoshis()),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
} }
scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor) scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor)
@ -6772,7 +6770,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
SingleTweak: keyRing.LocalHtlcKeyTweak, SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript, WitnessScript: htlcWitnessScript,
Output: txOut, Output: txOut,
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher, PrevOutputFetcher: prevFetcher,
SigHashes: hashCache, SigHashes: hashCache,
InputIndex: 0, InputIndex: 0,
@ -6788,9 +6786,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
sigHashType := HtlcSigHashType(chanType) sigHashType := HtlcSigHashType(chanType)
var timeoutWitness wire.TxWitness var timeoutWitness wire.TxWitness
if scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor); ok { if scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor); ok {
// TODO(roasbeef): make sure default elsewhere
timeoutSignDesc.SignMethod = input.TaprootScriptSpendSignMethod timeoutSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
timeoutSignDesc.HashType = txscript.SigHashDefault
timeoutWitness, err = input.SenderHTLCScriptTaprootTimeout( timeoutWitness, err = input.SenderHTLCScriptTaprootTimeout(
htlcSig, sigHashType, signer, &timeoutSignDesc, htlcSig, sigHashType, signer, &timeoutSignDesc,
@ -6896,7 +6892,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
PkScript: htlcSweepScript.PkScript(), PkScript: htlcSweepScript.PkScript(),
Value: int64(secondLevelOutputAmt), Value: int64(secondLevelOutputAmt),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
SignMethod: signMethod, SignMethod: signMethod,
ControlBlock: ctrlBlock, ControlBlock: ctrlBlock,
}, },
@ -6957,7 +6953,7 @@ func newIncomingHtlcResolution(signer input.Signer,
PkScript: htlcPkScript, PkScript: htlcPkScript,
Value: int64(htlc.Amt.ToSatoshis()), Value: int64(htlc.Amt.ToSatoshis()),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
} }
//nolint:lll //nolint:lll
@ -7009,7 +7005,7 @@ func newIncomingHtlcResolution(signer input.Signer,
SingleTweak: keyRing.LocalHtlcKeyTweak, SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript, WitnessScript: htlcWitnessScript,
Output: txOut, Output: txOut,
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher, PrevOutputFetcher: prevFetcher,
SigHashes: hashCache, SigHashes: hashCache,
InputIndex: 0, InputIndex: 0,
@ -7027,7 +7023,6 @@ func newIncomingHtlcResolution(signer input.Signer,
var successWitness wire.TxWitness var successWitness wire.TxWitness
sigHashType := HtlcSigHashType(chanType) sigHashType := HtlcSigHashType(chanType)
if scriptTree, ok := scriptInfo.(input.TapscriptDescriptor); ok { if scriptTree, ok := scriptInfo.(input.TapscriptDescriptor); ok {
successSignDesc.HashType = txscript.SigHashDefault
successSignDesc.SignMethod = input.TaprootScriptSpendSignMethod successSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
successWitness, err = input.ReceiverHTLCScriptTaprootRedeem( successWitness, err = input.ReceiverHTLCScriptTaprootRedeem(
@ -7101,8 +7096,6 @@ func newIncomingHtlcResolution(signer input.Signer,
return nil, err return nil, err
} }
// TODO(roasbeef): conslidate logic to reduce vertical noise
htlcSweepScript = secondLevelScriptTree htlcSweepScript = secondLevelScriptTree
} }
@ -7134,7 +7127,7 @@ func newIncomingHtlcResolution(signer input.Signer,
PkScript: htlcSweepScript.PkScript(), PkScript: htlcSweepScript.PkScript(),
Value: int64(secondLevelOutputAmt), Value: int64(secondLevelOutputAmt),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanType),
SignMethod: signMethod, SignMethod: signMethod,
ControlBlock: ctrlBlock, ControlBlock: ctrlBlock,
}, },
@ -7422,15 +7415,13 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel,
PkScript: delayOut.PkScript, PkScript: delayOut.PkScript,
Value: localBalance, Value: localBalance,
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
}, },
MaturityDelay: csvTimeout, MaturityDelay: csvTimeout,
} }
// For taproot channels, we'll need to set some additional // For taproot channels, we'll need to set some additional
// fields to ensure the output can be swept. // fields to ensure the output can be swept.
//
// TODO(roasbef): abstract into new func
scriptTree, ok := toLocalScript.(input.TapscriptDescriptor) scriptTree, ok := toLocalScript.(input.TapscriptDescriptor)
if ok { if ok {
commitResolution.SelfOutputSignDesc.SignMethod = commitResolution.SelfOutputSignDesc.SignMethod =
@ -7874,14 +7865,13 @@ func NewAnchorResolution(chanState *channeldb.OpenChannel,
PkScript: localAnchor.PkScript(), PkScript: localAnchor.PkScript(),
Value: int64(anchorSize), Value: int64(anchorSize),
}, },
HashType: txscript.SigHashAll, HashType: sweepSigHash(chanState.ChanType),
} }
// For taproot outputs, we'll need to ensure that the proper sign // For taproot outputs, we'll need to ensure that the proper sign
// method is used, and the tweak as well. // method is used, and the tweak as well.
if scriptTree, ok := localAnchor.(input.TapscriptDescriptor); ok { if scriptTree, ok := localAnchor.(input.TapscriptDescriptor); ok {
signDesc.SignMethod = input.TaprootKeySpendSignMethod signDesc.SignMethod = input.TaprootKeySpendSignMethod
signDesc.HashType = txscript.SigHashDefault
//nolint:lll //nolint:lll
signDesc.PrevOutputFetcher = txscript.NewCannedPrevOutputFetcher( signDesc.PrevOutputFetcher = txscript.NewCannedPrevOutputFetcher(

View file

@ -400,6 +400,17 @@ func HtlcSecondLevelInputSequence(chanType channeldb.ChannelType) uint32 {
return 0 return 0
} }
// sweepSigHash returns the sign descriptor to use when signing a sweep
// transaction. For taproot channels, we'll use this to always sweep with
// sighash default.
func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType {
if chanType.IsTaproot() {
return txscript.SigHashDefault
}
return txscript.SigHashAll
}
// SecondLevelHtlcScript derives the appropriate second level HTLC script based // SecondLevelHtlcScript derives the appropriate second level HTLC script based
// on the channel's commitment type. It is the uniform script that's used as the // on the channel's commitment type. It is the uniform script that's used as the
// output for the second-level HTLC transactions. The second level transaction // output for the second-level HTLC transactions. The second level transaction