diff --git a/watchtower/blob/type.go b/watchtower/blob/type.go index 83dfcf184..aee163ec0 100644 --- a/watchtower/blob/type.go +++ b/watchtower/blob/type.go @@ -80,6 +80,19 @@ const ( TypeAltruistTaprootCommit = Type(FlagCommitOutputs | FlagTaprootChannel) ) +// TypeFromChannel returns the appropriate blob Type for the given channel +// type. +func TypeFromChannel(chanType channeldb.ChannelType) Type { + switch { + case chanType.IsTaproot(): + return TypeAltruistTaprootCommit + case chanType.HasAnchors(): + return TypeAltruistAnchorCommit + default: + return TypeAltruistCommit + } +} + // Identifier returns a unique, stable string identifier for the blob Type. func (t Type) Identifier() (string, error) { switch t { diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index a458afecb..44a82aacb 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -213,12 +213,6 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody, } } - if chanType.HasAnchors() != session.Policy.IsAnchorChannel() { - log.Criticalf("Invalid task (has_anchors=%t) for session "+ - "(has_anchors=%t)", chanType.HasAnchors(), - session.Policy.IsAnchorChannel()) - } - // Now, compute the output values depending on whether FlagReward is set // in the current session's policy. outputs, err := session.Policy.ComputeJusticeTxOuts( @@ -334,6 +328,7 @@ func (t *backupTask) craftSessionPayload( switch inp.WitnessType() { case toLocalWitnessType: justiceKit.AddToLocalSig(signature) + case toRemoteWitnessType: justiceKit.AddToRemoteSig(signature) default: diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 6604935ea..d8c207c0f 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -85,9 +85,11 @@ func genTaskTest( bindErr error, chanType channeldb.ChannelType) backupTaskTest { - // Set the anchor flag in the blob type if the session needs to support - // anchor channels. - if chanType.HasAnchors() { + // Set the anchor or taproot flag in the blob type if the session needs + // to support anchor or taproot channels. + if chanType.IsTaproot() { + blobType |= blob.Type(blob.FlagTaprootChannel) + } else if chanType.HasAnchors() { blobType |= blob.Type(blob.FlagAnchorChannel) } @@ -129,30 +131,112 @@ func genTaskTest( // to that output as local, though relative to their commitment, it is // paying to-the-remote party (which is us). if toLocalAmt > 0 { - toLocalSignDesc := &input.SignDescriptor{ - KeyDesc: keychain.KeyDescriptor{ - KeyLocator: revKeyLoc, - PubKey: revPK, - }, - Output: &wire.TxOut{ - Value: toLocalAmt, - }, - HashType: txscript.SigHashAll, + var toLocalSignDesc *input.SignDescriptor + + if chanType.IsTaproot() { + scriptTree, _ := input.NewLocalCommitScriptTree( + csvDelay, toLocalPK, revPK, + ) + + pkScript, _ := input.PayToTaprootScript( + scriptTree.TaprootKey, + ) + + revokeTapleafHash := txscript.NewBaseTapLeaf( + scriptTree.RevocationLeaf.Script, + ).TapHash() + + tapTree := scriptTree.TapscriptTree + revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash] + revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx] + revokeControlBlock := revokeMerkleProof.ToControlBlock( + &input.TaprootNUMSKey, + ) + ctrlBytes, _ := revokeControlBlock.ToBytes() + + toLocalSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: revKeyLoc, + PubKey: revPK, + }, + Output: &wire.TxOut{ + Value: toLocalAmt, + PkScript: pkScript, + }, + WitnessScript: scriptTree.RevocationLeaf.Script, + SignMethod: input.TaprootScriptSpendSignMethod, //nolint:lll + HashType: txscript.SigHashDefault, + ControlBlock: ctrlBytes, + } + } else { + toLocalSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: revKeyLoc, + PubKey: revPK, + }, + Output: &wire.TxOut{ + Value: toLocalAmt, + }, + HashType: txscript.SigHashAll, + } } + breachInfo.RemoteOutputSignDesc = toLocalSignDesc breachTxn.AddTxOut(toLocalSignDesc.Output) } if toRemoteAmt > 0 { - toRemoteSignDesc := &input.SignDescriptor{ - KeyDesc: keychain.KeyDescriptor{ - KeyLocator: toRemoteKeyLoc, - PubKey: toRemotePK, - }, - Output: &wire.TxOut{ - Value: toRemoteAmt, - }, - HashType: txscript.SigHashAll, + var toRemoteSignDesc *input.SignDescriptor + + if chanType.IsTaproot() { + scriptTree, _ := input.NewRemoteCommitScriptTree( + toRemotePK, + ) + + pkScript, _ := input.PayToTaprootScript( + scriptTree.TaprootKey, + ) + + revokeTapleafHash := txscript.NewBaseTapLeaf( + scriptTree.SettleLeaf.Script, + ).TapHash() + + tapTree := scriptTree.TapscriptTree + revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash] + revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx] + revokeControlBlock := revokeMerkleProof.ToControlBlock( + &input.TaprootNUMSKey, + ) + + ctrlBytes, _ := revokeControlBlock.ToBytes() + + toRemoteSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: toRemoteKeyLoc, + PubKey: toRemotePK, + }, + WitnessScript: scriptTree.SettleLeaf.Script, + SignMethod: input.TaprootScriptSpendSignMethod, //nolint:lll + Output: &wire.TxOut{ + Value: toRemoteAmt, + PkScript: pkScript, + }, + HashType: txscript.SigHashDefault, + InputIndex: 1, + ControlBlock: ctrlBytes, + } + } else { + toRemoteSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: toRemoteKeyLoc, + PubKey: toRemotePK, + }, + Output: &wire.TxOut{ + Value: toRemoteAmt, + }, + HashType: txscript.SigHashAll, + } } + breachInfo.LocalOutputSignDesc = toRemoteSignDesc breachTxn.AddTxOut(toRemoteSignDesc.Output) } @@ -248,6 +332,7 @@ func TestBackupTask(t *testing.T) { channeldb.SingleFunderBit, channeldb.SingleFunderTweaklessBit, channeldb.AnchorOutputsBit, + channeldb.SimpleTaprootFeatureBit, } var backupTaskTests []backupTaskTest @@ -272,7 +357,16 @@ func TestBackupTask(t *testing.T) { sweepFeeRateNoRewardRemoteDust chainfee.SatPerKWeight = 227500 sweepFeeRateRewardRemoteDust chainfee.SatPerKWeight = 175350 ) - if chanType.HasAnchors() { + if chanType.IsTaproot() { + expSweepCommitNoRewardBoth = 299165 + expSweepCommitNoRewardLocal = 199468 + expSweepCommitNoRewardRemote = 99531 + sweepFeeRateNoRewardRemoteDust = 213200 + expSweepCommitRewardBoth = 295993 + expSweepCommitRewardLocal = 197296 + expSweepCommitRewardRemote = 98359 + sweepFeeRateRewardRemoteDust = 167000 + } else if chanType.HasAnchors() { expSweepCommitNoRewardBoth = 299236 expSweepCommitNoRewardLocal = 199513 expSweepCommitNoRewardRemote = 99557 diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index c2380dbff..81674f0c0 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -76,12 +76,12 @@ var ( waitTime = 15 * time.Second defaultTxPolicy = wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, + BlobType: blob.TypeAltruistTaprootCommit, SweepFeeRate: wtpolicy.DefaultSweepFeeRate, } highSweepRateTxPolicy = wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, + BlobType: blob.TypeAltruistTaprootCommit, SweepFeeRate: 1000000, // The high sweep fee creates dust. } ) @@ -229,17 +229,25 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { t.Helper() // Construct the to-local witness script. - toLocalScript, err := input.CommitScriptToSelf( + toLocalScriptTree, err := input.NewLocalCommitScriptTree( c.csvDelay, c.toLocalPK, c.revPK, ) require.NoError(t, err, "unable to create to-local script") + // Construct the to-remote witness script. + toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK) + require.NoError(t, err, "unable to create to-remote script") + // Compute the to-local witness script hash. - toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript) + toLocalScriptHash, err := input.PayToTaprootScript( + toLocalScriptTree.TaprootKey, + ) require.NoError(t, err, "unable to create to-local witness script hash") // Compute the to-remote witness script hash. - toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK) + toRemoteScriptHash, err := input.PayToTaprootScript( + toRemoteScriptTree.TaprootKey, + ) require.NoError(t, err, "unable to create to-remote script") // Construct the remote commitment txn, containing the to-local and @@ -264,6 +272,19 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { PkScript: toLocalScriptHash, }) + revokeTapleafHash := txscript.NewBaseTapLeaf( + toLocalScriptTree.RevocationLeaf.Script, + ).TapHash() + tapTree := toLocalScriptTree.TapscriptTree + revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash] + revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx] + revokeControlBlock := revokeMerkleProof.ToControlBlock( + &input.TaprootNUMSKey, + ) + + ctrlBytes, err := revokeControlBlock.ToBytes() + require.NoError(t, err) + // Create the sign descriptor used to sign for the to-local // input. toLocalSignDesc = &input.SignDescriptor{ @@ -271,9 +292,11 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { KeyLocator: c.revKeyLoc, PubKey: c.revPK, }, - WitnessScript: toLocalScript, + WitnessScript: toLocalScriptTree.RevocationLeaf.Script, Output: commitTxn.TxOut[outputIndex], - HashType: txscript.SigHashAll, + HashType: txscript.SigHashDefault, + SignMethod: input.TaprootScriptSpendSignMethod, + ControlBlock: ctrlBytes, } outputIndex++ } @@ -283,6 +306,18 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { PkScript: toRemoteScriptHash, }) + toRemoteTapleafHash := txscript.NewBaseTapLeaf( + toRemoteScriptTree.SettleLeaf.Script, + ).TapHash() + tapTree := toRemoteScriptTree.TapscriptTree + remoteIdx := tapTree.LeafProofIndex[toRemoteTapleafHash] + remoteMerkleProof := tapTree.LeafMerkleProofs[remoteIdx] + remoteControlBlock := remoteMerkleProof.ToControlBlock( + &input.TaprootNUMSKey, + ) + + ctrlBytes, _ := remoteControlBlock.ToBytes() + // Create the sign descriptor used to sign for the to-remote // input. toRemoteSignDesc = &input.SignDescriptor{ @@ -290,9 +325,11 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { KeyLocator: c.toRemoteKeyLoc, PubKey: c.toRemotePK, }, - WitnessScript: toRemoteScriptHash, + WitnessScript: toRemoteScriptTree.SettleLeaf.Script, Output: commitTxn.TxOut[outputIndex], - HashType: txscript.SigHashAll, + HashType: txscript.SigHashDefault, + SignMethod: input.TaprootScriptSpendSignMethod, + ControlBlock: ctrlBytes, } outputIndex++ } @@ -516,7 +553,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { _, retribution := h.channelFromID(id).getState(commitHeight) - return retribution, channeldb.SingleFunderBit, nil + return retribution, channeldb.SimpleTaprootFeatureBit, nil } if !cfg.noServerStart { @@ -664,7 +701,9 @@ func (h *testHarness) registerChannel(id uint64) { h.t.Helper() chanID := chanIDFromInt(id) - err := h.clientMgr.RegisterChannel(chanID, channeldb.SingleFunderBit) + err := h.clientMgr.RegisterChannel( + chanID, channeldb.SimpleTaprootFeatureBit, + ) require.NoError(h.t, err) } @@ -1404,7 +1443,7 @@ var clientTests = []clientTest{ // Wait for all the updates to be populated in the // server's database. - h.server.waitForUpdates(hints, 10*time.Second) + h.server.waitForUpdates(hints, waitTime) }, }, { diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 73f259085..70f2a0542 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -523,10 +523,7 @@ func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) { func (m *Manager) RegisterChannel(id lnwire.ChannelID, chanType channeldb.ChannelType) error { - blobType := blob.TypeAltruistCommit - if chanType.HasAnchors() { - blobType = blob.TypeAltruistAnchorCommit - } + blobType := blob.TypeFromChannel(chanType) m.clientsMu.Lock() if _, ok := m.clients[blobType]; !ok { diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 2b1e988a1..ad359233e 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -120,15 +120,8 @@ var _ SessionNegotiator = (*sessionNegotiator)(nil) // newSessionNegotiator initializes a fresh sessionNegotiator instance. func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { // Generate the set of features the negotiator will present to the tower - // upon connection. For anchor channels, we'll conditionally signal that - // we require support for anchor channels depending on the requested - // policy. - features := []lnwire.FeatureBit{ - wtwire.AltruistSessionsRequired, - } - if cfg.Policy.IsAnchorChannel() { - features = append(features, wtwire.AnchorCommitRequired) - } + // upon connection. + features := cfg.Policy.FeatureBits() localInit := wtwire.NewInitMessage( lnwire.NewRawFeatureVector(features...), diff --git a/watchtower/wtpolicy/policy.go b/watchtower/wtpolicy/policy.go index c6a2467c6..9d5763fb2 100644 --- a/watchtower/wtpolicy/policy.go +++ b/watchtower/wtpolicy/policy.go @@ -8,7 +8,9 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtwire" ) const ( @@ -120,20 +122,38 @@ func (p Policy) String() string { p.SweepFeeRate) } +// FeatureBits returns the watchtower feature bits required for the given +// policy. +func (p *Policy) FeatureBits() []lnwire.FeatureBit { + features := []lnwire.FeatureBit{ + wtwire.AltruistSessionsRequired, + } + + t := p.TxPolicy.BlobType + switch { + case t.IsTaprootChannel(): + features = append(features, wtwire.TaprootCommitRequired) + case t.IsAnchorChannel(): + features = append(features, wtwire.AnchorCommitRequired) + } + + return features +} + // IsAnchorChannel returns true if the session policy requires anchor channels. -func (p Policy) IsAnchorChannel() bool { +func (p *Policy) IsAnchorChannel() bool { return p.TxPolicy.BlobType.IsAnchorChannel() } // IsTaprootChannel returns true if the session policy requires taproot // channels. -func (p Policy) IsTaprootChannel() bool { +func (p *Policy) IsTaprootChannel() bool { return p.TxPolicy.BlobType.IsTaprootChannel() } // Validate ensures that the policy satisfies some minimal correctness // constraints. -func (p Policy) Validate() error { +func (p *Policy) Validate() error { // RewardBase and RewardRate should not be set if the policy doesn't // have a reward. if !p.BlobType.Has(blob.FlagReward) &&