watchtower/wtclient: prep client for taproot towers

This commit is contained in:
Elle Mouton 2023-05-31 09:44:28 +02:00
parent c50aa10194
commit 660f1f361e
No known key found for this signature in database
GPG key ID: D7D916376026F177
7 changed files with 207 additions and 56 deletions

View file

@ -80,6 +80,19 @@ const (
TypeAltruistTaprootCommit = Type(FlagCommitOutputs | FlagTaprootChannel) TypeAltruistTaprootCommit = Type(FlagCommitOutputs | FlagTaprootChannel)
) )
// TypeFromChannel returns the appropriate blob Type for the given channel
// type.
func TypeFromChannel(chanType channeldb.ChannelType) Type {
switch {
case chanType.IsTaproot():
return TypeAltruistTaprootCommit
case chanType.HasAnchors():
return TypeAltruistAnchorCommit
default:
return TypeAltruistCommit
}
}
// Identifier returns a unique, stable string identifier for the blob Type. // Identifier returns a unique, stable string identifier for the blob Type.
func (t Type) Identifier() (string, error) { func (t Type) Identifier() (string, error) {
switch t { switch t {

View file

@ -213,12 +213,6 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody,
} }
} }
if chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
log.Criticalf("Invalid task (has_anchors=%t) for session "+
"(has_anchors=%t)", chanType.HasAnchors(),
session.Policy.IsAnchorChannel())
}
// Now, compute the output values depending on whether FlagReward is set // Now, compute the output values depending on whether FlagReward is set
// in the current session's policy. // in the current session's policy.
outputs, err := session.Policy.ComputeJusticeTxOuts( outputs, err := session.Policy.ComputeJusticeTxOuts(
@ -334,6 +328,7 @@ func (t *backupTask) craftSessionPayload(
switch inp.WitnessType() { switch inp.WitnessType() {
case toLocalWitnessType: case toLocalWitnessType:
justiceKit.AddToLocalSig(signature) justiceKit.AddToLocalSig(signature)
case toRemoteWitnessType: case toRemoteWitnessType:
justiceKit.AddToRemoteSig(signature) justiceKit.AddToRemoteSig(signature)
default: default:

View file

@ -85,9 +85,11 @@ func genTaskTest(
bindErr error, bindErr error,
chanType channeldb.ChannelType) backupTaskTest { chanType channeldb.ChannelType) backupTaskTest {
// Set the anchor flag in the blob type if the session needs to support // Set the anchor or taproot flag in the blob type if the session needs
// anchor channels. // to support anchor or taproot channels.
if chanType.HasAnchors() { if chanType.IsTaproot() {
blobType |= blob.Type(blob.FlagTaprootChannel)
} else if chanType.HasAnchors() {
blobType |= blob.Type(blob.FlagAnchorChannel) blobType |= blob.Type(blob.FlagAnchorChannel)
} }
@ -129,30 +131,112 @@ func genTaskTest(
// to that output as local, though relative to their commitment, it is // to that output as local, though relative to their commitment, it is
// paying to-the-remote party (which is us). // paying to-the-remote party (which is us).
if toLocalAmt > 0 { if toLocalAmt > 0 {
toLocalSignDesc := &input.SignDescriptor{ var toLocalSignDesc *input.SignDescriptor
KeyDesc: keychain.KeyDescriptor{
KeyLocator: revKeyLoc, if chanType.IsTaproot() {
PubKey: revPK, scriptTree, _ := input.NewLocalCommitScriptTree(
}, csvDelay, toLocalPK, revPK,
Output: &wire.TxOut{ )
Value: toLocalAmt,
}, pkScript, _ := input.PayToTaprootScript(
HashType: txscript.SigHashAll, scriptTree.TaprootKey,
)
revokeTapleafHash := txscript.NewBaseTapLeaf(
scriptTree.RevocationLeaf.Script,
).TapHash()
tapTree := scriptTree.TapscriptTree
revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash]
revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx]
revokeControlBlock := revokeMerkleProof.ToControlBlock(
&input.TaprootNUMSKey,
)
ctrlBytes, _ := revokeControlBlock.ToBytes()
toLocalSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: revKeyLoc,
PubKey: revPK,
},
Output: &wire.TxOut{
Value: toLocalAmt,
PkScript: pkScript,
},
WitnessScript: scriptTree.RevocationLeaf.Script,
SignMethod: input.TaprootScriptSpendSignMethod, //nolint:lll
HashType: txscript.SigHashDefault,
ControlBlock: ctrlBytes,
}
} else {
toLocalSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: revKeyLoc,
PubKey: revPK,
},
Output: &wire.TxOut{
Value: toLocalAmt,
},
HashType: txscript.SigHashAll,
}
} }
breachInfo.RemoteOutputSignDesc = toLocalSignDesc breachInfo.RemoteOutputSignDesc = toLocalSignDesc
breachTxn.AddTxOut(toLocalSignDesc.Output) breachTxn.AddTxOut(toLocalSignDesc.Output)
} }
if toRemoteAmt > 0 { if toRemoteAmt > 0 {
toRemoteSignDesc := &input.SignDescriptor{ var toRemoteSignDesc *input.SignDescriptor
KeyDesc: keychain.KeyDescriptor{
KeyLocator: toRemoteKeyLoc, if chanType.IsTaproot() {
PubKey: toRemotePK, scriptTree, _ := input.NewRemoteCommitScriptTree(
}, toRemotePK,
Output: &wire.TxOut{ )
Value: toRemoteAmt,
}, pkScript, _ := input.PayToTaprootScript(
HashType: txscript.SigHashAll, scriptTree.TaprootKey,
)
revokeTapleafHash := txscript.NewBaseTapLeaf(
scriptTree.SettleLeaf.Script,
).TapHash()
tapTree := scriptTree.TapscriptTree
revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash]
revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx]
revokeControlBlock := revokeMerkleProof.ToControlBlock(
&input.TaprootNUMSKey,
)
ctrlBytes, _ := revokeControlBlock.ToBytes()
toRemoteSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: toRemoteKeyLoc,
PubKey: toRemotePK,
},
WitnessScript: scriptTree.SettleLeaf.Script,
SignMethod: input.TaprootScriptSpendSignMethod, //nolint:lll
Output: &wire.TxOut{
Value: toRemoteAmt,
PkScript: pkScript,
},
HashType: txscript.SigHashDefault,
InputIndex: 1,
ControlBlock: ctrlBytes,
}
} else {
toRemoteSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: toRemoteKeyLoc,
PubKey: toRemotePK,
},
Output: &wire.TxOut{
Value: toRemoteAmt,
},
HashType: txscript.SigHashAll,
}
} }
breachInfo.LocalOutputSignDesc = toRemoteSignDesc breachInfo.LocalOutputSignDesc = toRemoteSignDesc
breachTxn.AddTxOut(toRemoteSignDesc.Output) breachTxn.AddTxOut(toRemoteSignDesc.Output)
} }
@ -248,6 +332,7 @@ func TestBackupTask(t *testing.T) {
channeldb.SingleFunderBit, channeldb.SingleFunderBit,
channeldb.SingleFunderTweaklessBit, channeldb.SingleFunderTweaklessBit,
channeldb.AnchorOutputsBit, channeldb.AnchorOutputsBit,
channeldb.SimpleTaprootFeatureBit,
} }
var backupTaskTests []backupTaskTest var backupTaskTests []backupTaskTest
@ -272,7 +357,16 @@ func TestBackupTask(t *testing.T) {
sweepFeeRateNoRewardRemoteDust chainfee.SatPerKWeight = 227500 sweepFeeRateNoRewardRemoteDust chainfee.SatPerKWeight = 227500
sweepFeeRateRewardRemoteDust chainfee.SatPerKWeight = 175350 sweepFeeRateRewardRemoteDust chainfee.SatPerKWeight = 175350
) )
if chanType.HasAnchors() { if chanType.IsTaproot() {
expSweepCommitNoRewardBoth = 299165
expSweepCommitNoRewardLocal = 199468
expSweepCommitNoRewardRemote = 99531
sweepFeeRateNoRewardRemoteDust = 213200
expSweepCommitRewardBoth = 295993
expSweepCommitRewardLocal = 197296
expSweepCommitRewardRemote = 98359
sweepFeeRateRewardRemoteDust = 167000
} else if chanType.HasAnchors() {
expSweepCommitNoRewardBoth = 299236 expSweepCommitNoRewardBoth = 299236
expSweepCommitNoRewardLocal = 199513 expSweepCommitNoRewardLocal = 199513
expSweepCommitNoRewardRemote = 99557 expSweepCommitNoRewardRemote = 99557

View file

@ -76,12 +76,12 @@ var (
waitTime = 15 * time.Second waitTime = 15 * time.Second
defaultTxPolicy = wtpolicy.TxPolicy{ defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit, BlobType: blob.TypeAltruistTaprootCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
} }
highSweepRateTxPolicy = wtpolicy.TxPolicy{ highSweepRateTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit, BlobType: blob.TypeAltruistTaprootCommit,
SweepFeeRate: 1000000, // The high sweep fee creates dust. SweepFeeRate: 1000000, // The high sweep fee creates dust.
} }
) )
@ -229,17 +229,25 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
t.Helper() t.Helper()
// Construct the to-local witness script. // Construct the to-local witness script.
toLocalScript, err := input.CommitScriptToSelf( toLocalScriptTree, err := input.NewLocalCommitScriptTree(
c.csvDelay, c.toLocalPK, c.revPK, c.csvDelay, c.toLocalPK, c.revPK,
) )
require.NoError(t, err, "unable to create to-local script") require.NoError(t, err, "unable to create to-local script")
// Construct the to-remote witness script.
toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK)
require.NoError(t, err, "unable to create to-remote script")
// Compute the to-local witness script hash. // Compute the to-local witness script hash.
toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript) toLocalScriptHash, err := input.PayToTaprootScript(
toLocalScriptTree.TaprootKey,
)
require.NoError(t, err, "unable to create to-local witness script hash") require.NoError(t, err, "unable to create to-local witness script hash")
// Compute the to-remote witness script hash. // Compute the to-remote witness script hash.
toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK) toRemoteScriptHash, err := input.PayToTaprootScript(
toRemoteScriptTree.TaprootKey,
)
require.NoError(t, err, "unable to create to-remote script") require.NoError(t, err, "unable to create to-remote script")
// Construct the remote commitment txn, containing the to-local and // Construct the remote commitment txn, containing the to-local and
@ -264,6 +272,19 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
PkScript: toLocalScriptHash, PkScript: toLocalScriptHash,
}) })
revokeTapleafHash := txscript.NewBaseTapLeaf(
toLocalScriptTree.RevocationLeaf.Script,
).TapHash()
tapTree := toLocalScriptTree.TapscriptTree
revokeIdx := tapTree.LeafProofIndex[revokeTapleafHash]
revokeMerkleProof := tapTree.LeafMerkleProofs[revokeIdx]
revokeControlBlock := revokeMerkleProof.ToControlBlock(
&input.TaprootNUMSKey,
)
ctrlBytes, err := revokeControlBlock.ToBytes()
require.NoError(t, err)
// Create the sign descriptor used to sign for the to-local // Create the sign descriptor used to sign for the to-local
// input. // input.
toLocalSignDesc = &input.SignDescriptor{ toLocalSignDesc = &input.SignDescriptor{
@ -271,9 +292,11 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
KeyLocator: c.revKeyLoc, KeyLocator: c.revKeyLoc,
PubKey: c.revPK, PubKey: c.revPK,
}, },
WitnessScript: toLocalScript, WitnessScript: toLocalScriptTree.RevocationLeaf.Script,
Output: commitTxn.TxOut[outputIndex], Output: commitTxn.TxOut[outputIndex],
HashType: txscript.SigHashAll, HashType: txscript.SigHashDefault,
SignMethod: input.TaprootScriptSpendSignMethod,
ControlBlock: ctrlBytes,
} }
outputIndex++ outputIndex++
} }
@ -283,6 +306,18 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
PkScript: toRemoteScriptHash, PkScript: toRemoteScriptHash,
}) })
toRemoteTapleafHash := txscript.NewBaseTapLeaf(
toRemoteScriptTree.SettleLeaf.Script,
).TapHash()
tapTree := toRemoteScriptTree.TapscriptTree
remoteIdx := tapTree.LeafProofIndex[toRemoteTapleafHash]
remoteMerkleProof := tapTree.LeafMerkleProofs[remoteIdx]
remoteControlBlock := remoteMerkleProof.ToControlBlock(
&input.TaprootNUMSKey,
)
ctrlBytes, _ := remoteControlBlock.ToBytes()
// Create the sign descriptor used to sign for the to-remote // Create the sign descriptor used to sign for the to-remote
// input. // input.
toRemoteSignDesc = &input.SignDescriptor{ toRemoteSignDesc = &input.SignDescriptor{
@ -290,9 +325,11 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
KeyLocator: c.toRemoteKeyLoc, KeyLocator: c.toRemoteKeyLoc,
PubKey: c.toRemotePK, PubKey: c.toRemotePK,
}, },
WitnessScript: toRemoteScriptHash, WitnessScript: toRemoteScriptTree.SettleLeaf.Script,
Output: commitTxn.TxOut[outputIndex], Output: commitTxn.TxOut[outputIndex],
HashType: txscript.SigHashAll, HashType: txscript.SigHashDefault,
SignMethod: input.TaprootScriptSpendSignMethod,
ControlBlock: ctrlBytes,
} }
outputIndex++ outputIndex++
} }
@ -516,7 +553,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
_, retribution := h.channelFromID(id).getState(commitHeight) _, retribution := h.channelFromID(id).getState(commitHeight)
return retribution, channeldb.SingleFunderBit, nil return retribution, channeldb.SimpleTaprootFeatureBit, nil
} }
if !cfg.noServerStart { if !cfg.noServerStart {
@ -664,7 +701,9 @@ func (h *testHarness) registerChannel(id uint64) {
h.t.Helper() h.t.Helper()
chanID := chanIDFromInt(id) chanID := chanIDFromInt(id)
err := h.clientMgr.RegisterChannel(chanID, channeldb.SingleFunderBit) err := h.clientMgr.RegisterChannel(
chanID, channeldb.SimpleTaprootFeatureBit,
)
require.NoError(h.t, err) require.NoError(h.t, err)
} }
@ -1404,7 +1443,7 @@ var clientTests = []clientTest{
// Wait for all the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.server.waitForUpdates(hints, 10*time.Second) h.server.waitForUpdates(hints, waitTime)
}, },
}, },
{ {

View file

@ -523,10 +523,7 @@ func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) {
func (m *Manager) RegisterChannel(id lnwire.ChannelID, func (m *Manager) RegisterChannel(id lnwire.ChannelID,
chanType channeldb.ChannelType) error { chanType channeldb.ChannelType) error {
blobType := blob.TypeAltruistCommit blobType := blob.TypeFromChannel(chanType)
if chanType.HasAnchors() {
blobType = blob.TypeAltruistAnchorCommit
}
m.clientsMu.Lock() m.clientsMu.Lock()
if _, ok := m.clients[blobType]; !ok { if _, ok := m.clients[blobType]; !ok {

View file

@ -120,15 +120,8 @@ var _ SessionNegotiator = (*sessionNegotiator)(nil)
// newSessionNegotiator initializes a fresh sessionNegotiator instance. // newSessionNegotiator initializes a fresh sessionNegotiator instance.
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
// Generate the set of features the negotiator will present to the tower // Generate the set of features the negotiator will present to the tower
// upon connection. For anchor channels, we'll conditionally signal that // upon connection.
// we require support for anchor channels depending on the requested features := cfg.Policy.FeatureBits()
// policy.
features := []lnwire.FeatureBit{
wtwire.AltruistSessionsRequired,
}
if cfg.Policy.IsAnchorChannel() {
features = append(features, wtwire.AnchorCommitRequired)
}
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(features...), lnwire.NewRawFeatureVector(features...),

View file

@ -8,7 +8,9 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
) )
const ( const (
@ -120,20 +122,38 @@ func (p Policy) String() string {
p.SweepFeeRate) p.SweepFeeRate)
} }
// FeatureBits returns the watchtower feature bits required for the given
// policy.
func (p *Policy) FeatureBits() []lnwire.FeatureBit {
features := []lnwire.FeatureBit{
wtwire.AltruistSessionsRequired,
}
t := p.TxPolicy.BlobType
switch {
case t.IsTaprootChannel():
features = append(features, wtwire.TaprootCommitRequired)
case t.IsAnchorChannel():
features = append(features, wtwire.AnchorCommitRequired)
}
return features
}
// IsAnchorChannel returns true if the session policy requires anchor channels. // IsAnchorChannel returns true if the session policy requires anchor channels.
func (p Policy) IsAnchorChannel() bool { func (p *Policy) IsAnchorChannel() bool {
return p.TxPolicy.BlobType.IsAnchorChannel() return p.TxPolicy.BlobType.IsAnchorChannel()
} }
// IsTaprootChannel returns true if the session policy requires taproot // IsTaprootChannel returns true if the session policy requires taproot
// channels. // channels.
func (p Policy) IsTaprootChannel() bool { func (p *Policy) IsTaprootChannel() bool {
return p.TxPolicy.BlobType.IsTaprootChannel() return p.TxPolicy.BlobType.IsTaprootChannel()
} }
// Validate ensures that the policy satisfies some minimal correctness // Validate ensures that the policy satisfies some minimal correctness
// constraints. // constraints.
func (p Policy) Validate() error { func (p *Policy) Validate() error {
// RewardBase and RewardRate should not be set if the policy doesn't // RewardBase and RewardRate should not be set if the policy doesn't
// have a reward. // have a reward.
if !p.BlobType.Has(blob.FlagReward) && if !p.BlobType.Has(blob.FlagReward) &&