diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index 3f95ce096..2a02da479 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -10,7 +10,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -40,7 +39,6 @@ import ( type backupTask struct { id wtdb.BackupID breachInfo *lnwallet.BreachRetribution - chanType channeldb.ChannelType // state-dependent variables @@ -55,11 +53,79 @@ type backupTask struct { outputs []*wire.TxOut } -// newBackupTask initializes a new backupTask and populates all state-dependent -// variables. -func newBackupTask(chanID *lnwire.ChannelID, - breachInfo *lnwallet.BreachRetribution, - sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask { +// newBackupTask initializes a new backupTask. +func newBackupTask(id wtdb.BackupID, sweepPkScript []byte) *backupTask { + return &backupTask{ + id: id, + sweepPkScript: sweepPkScript, + } +} + +// inputs returns all non-dust inputs that we will attempt to spend from. +// +// NOTE: Ordering of the inputs is not critical as we sort the transaction with +// BIP69 in a later stage. +func (t *backupTask) inputs() map[wire.OutPoint]input.Input { + inputs := make(map[wire.OutPoint]input.Input) + if t.toLocalInput != nil { + inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput + } + if t.toRemoteInput != nil { + inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput + } + + return inputs +} + +// addrType returns the type of an address after parsing it and matching it to +// the set of known script templates. +func addrType(pkScript []byte) txscript.ScriptClass { + // We pass in a set of dummy chain params here as they're only needed + // to make the address struct, which we're ignoring anyway (scripts are + // always the same, it's addresses that change across chains). + scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs( + pkScript, &chaincfg.MainNetParams, + ) + + return scriptClass +} + +// addScriptWeight parses the passed pkScript and adds the computed weight cost +// were the script to be added to the justice transaction. +func addScriptWeight(weightEstimate *input.TxWeightEstimator, + pkScript []byte) error { + + switch addrType(pkScript) { + case txscript.WitnessV0PubKeyHashTy: + weightEstimate.AddP2WKHOutput() + + case txscript.WitnessV0ScriptHashTy: + weightEstimate.AddP2WSHOutput() + + case txscript.WitnessV1TaprootTy: + weightEstimate.AddP2TROutput() + + default: + return fmt.Errorf("invalid addr type: %v", addrType(pkScript)) + } + + return nil +} + +// bindSession first populates all state-dependent variables of the task. Then +// it determines if the backupTask is compatible with the passed SessionInfo's +// policy. If no error is returned, the task has been bound to the session and +// can be queued to upload to the tower. Otherwise, the bind failed and should +// be rescheduled with a different session. +func (t *backupTask) bindSession(session *wtdb.ClientSessionBody, + newBreachRetribution BreachRetributionBuilder) error { + + breachInfo, chanType, err := newBreachRetribution( + t.id.ChanID, t.id.CommitHeight, + ) + if err != nil { + return err + } // Parse the non-dust outputs from the breach transaction, // simultaneously computing the total amount contained in the inputs @@ -123,76 +189,11 @@ func newBackupTask(chanID *lnwire.ChannelID, totalAmt += breachInfo.LocalOutputSignDesc.Output.Value } - return &backupTask{ - id: wtdb.BackupID{ - ChanID: *chanID, - CommitHeight: breachInfo.RevokedStateNum, - }, - breachInfo: breachInfo, - chanType: chanType, - toLocalInput: toLocalInput, - toRemoteInput: toRemoteInput, - totalAmt: btcutil.Amount(totalAmt), - sweepPkScript: sweepPkScript, - } -} + t.breachInfo = breachInfo + t.toLocalInput = toLocalInput + t.toRemoteInput = toRemoteInput + t.totalAmt = btcutil.Amount(totalAmt) -// inputs returns all non-dust inputs that we will attempt to spend from. -// -// NOTE: Ordering of the inputs is not critical as we sort the transaction with -// BIP69. -func (t *backupTask) inputs() map[wire.OutPoint]input.Input { - inputs := make(map[wire.OutPoint]input.Input) - if t.toLocalInput != nil { - inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput - } - if t.toRemoteInput != nil { - inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput - } - return inputs -} - -// addrType returns the type of an address after parsing it and matching it to -// the set of known script templates. -func addrType(pkScript []byte) txscript.ScriptClass { - // We pass in a set of dummy chain params here as they're only needed - // to make the address struct, which we're ignoring anyway (scripts are - // always the same, it's addresses that change across chains). - scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs( - pkScript, &chaincfg.MainNetParams, - ) - - return scriptClass -} - -// addScriptWeight parses the passed pkScript and adds the computed weight cost -// were the script to be added to the justice transaction. -func addScriptWeight(weightEstimate *input.TxWeightEstimator, - pkScript []byte) error { - - switch addrType(pkScript) { //nolint: whitespace - - case txscript.WitnessV0PubKeyHashTy: - weightEstimate.AddP2WKHOutput() - - case txscript.WitnessV0ScriptHashTy: - weightEstimate.AddP2WSHOutput() - - case txscript.WitnessV1TaprootTy: - weightEstimate.AddP2TROutput() - - default: - return fmt.Errorf("invalid addr type: %v", addrType(pkScript)) - } - - return nil -} - -// bindSession determines if the backupTask is compatible with the passed -// SessionInfo's policy. If no error is returned, the task has been bound to the -// session and can be queued to upload to the tower. Otherwise, the bind failed -// and should be rescheduled with a different session. -func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // First we'll begin by deriving a weight estimate for the justice // transaction. The final weight can be different depending on whether // the watchtower is taking a reward. @@ -208,7 +209,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // original weight estimate. For anchor channels we'll go ahead // an use the correct penalty witness when signing our justice // transactions. - if t.chanType.HasAnchors() { + if chanType.HasAnchors() { weightEstimate.AddWitnessInput( input.ToLocalPenaltyWitnessSize, ) @@ -222,7 +223,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // Legacy channels (both tweaked and non-tweaked) spend from // P2WKH output. Anchor channels spend a to-remote confirmed // P2WSH output. - if t.chanType.HasAnchors() { + if chanType.HasAnchors() { weightEstimate.AddWitnessInput( input.ToRemoteConfirmedWitnessSize, ) @@ -233,7 +234,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // All justice transactions will either use segwit v0 (p2wkh + p2wsh) // or segwit v1 (p2tr). - err := addScriptWeight(&weightEstimate, t.sweepPkScript) + err = addScriptWeight(&weightEstimate, t.sweepPkScript) if err != nil { return err } @@ -247,9 +248,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { } } - if t.chanType.HasAnchors() != session.Policy.IsAnchorChannel() { + if chanType.HasAnchors() != session.Policy.IsAnchorChannel() { log.Criticalf("Invalid task (has_anchors=%t) for session "+ - "(has_anchors=%t)", t.chanType.HasAnchors(), + "(has_anchors=%t)", chanType.HasAnchors(), session.Policy.IsAnchorChannel()) } diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index c536c433b..43b818044 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -483,19 +483,19 @@ func TestBackupTask(t *testing.T) { func testBackupTask(t *testing.T, test backupTaskTest) { // Create a new backupTask from the channel id and breach info. - task := newBackupTask( - &test.chanID, test.breachInfo, test.expSweepScript, - test.chanType, - ) + id := wtdb.BackupID{ + ChanID: test.chanID, + CommitHeight: test.breachInfo.RevokedStateNum, + } + task := newBackupTask(id, test.expSweepScript) - // Assert that all parameters set during initialization are properly - // populated. - require.Equal(t, test.chanID, task.id.ChanID) - require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) - require.Equal(t, test.expTotalAmt, task.totalAmt) - require.Equal(t, test.breachInfo, task.breachInfo) - require.Equal(t, test.expToLocalInput, task.toLocalInput) - require.Equal(t, test.expToRemoteInput, task.toRemoteInput) + // getBreachInfo is a helper closure that returns the breach retribution + // info and channel type for the given channel and commit height. + getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) ( + *lnwallet.BreachRetribution, channeldb.ChannelType, error) { + + return test.breachInfo, test.chanType, nil + } // Reconstruct the expected input.Inputs that will be returned by the // task's inputs() method. @@ -515,9 +515,18 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. - err := task.bindSession(test.session) + err := task.bindSession(test.session, getBreachInfo) require.ErrorIs(t, err, test.bindErr) + // Assert that all parameters set during after binding the backup task + // are properly populated. + require.Equal(t, test.chanID, task.id.ChanID) + require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) + require.Equal(t, test.expTotalAmt, task.totalAmt) + require.Equal(t, test.breachInfo, task.breachInfo) + require.Equal(t, test.expToLocalInput, task.toLocalInput) + require.Equal(t, test.expToRemoteInput, task.toRemoteInput) + // Exit early if the bind was supposed to fail. But first, we check that // all fields set during a bind are still unset. This ensure that a // failed bind doesn't have side-effects if the task is retried with a diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 5373edb66..5a59cbcd9 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -829,17 +829,12 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, c.chanCommitHeights[*chanID] = stateNum c.backupMu.Unlock() - // Fetch the breach retribution info and channel type. - breachInfo, chanType, err := c.cfg.BuildBreachRetribution( - *chanID, stateNum, - ) - if err != nil { - return err + id := wtdb.BackupID{ + ChanID: *chanID, + CommitHeight: stateNum, } - task := newBackupTask( - chanID, breachInfo, summary.SweepPkScript, chanType, - ) + task := newBackupTask(id, summary.SweepPkScript) return c.pipeline.QueueBackupTask(task) } @@ -1543,16 +1538,17 @@ func (c *TowerClient) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ - ClientSession: s, - ChainHash: c.cfg.ChainHash, - Dial: c.dial, - ReadMessage: c.readMessage, - SendMessage: c.sendMessage, - Signer: c.cfg.Signer, - DB: c.cfg.DB, - MinBackoff: c.cfg.MinBackoff, - MaxBackoff: c.cfg.MaxBackoff, - Log: c.log, + ClientSession: s, + ChainHash: c.cfg.ChainHash, + Dial: c.dial, + ReadMessage: c.readMessage, + SendMessage: c.sendMessage, + Signer: c.cfg.Signer, + DB: c.cfg.DB, + MinBackoff: c.cfg.MinBackoff, + MaxBackoff: c.cfg.MaxBackoff, + Log: c.log, + BuildBreachRetribution: c.cfg.BuildBreachRetribution, }, updates) } diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 7f29deb75..aa06709ee 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -57,6 +57,11 @@ type sessionQueueConfig struct { // for justice transaction inputs. Signer input.Signer + // BuildBreachRetribution is a function closure that allows the client + // to fetch the breach retribution info for a certain channel at a + // certain revoked commitment height. + BuildBreachRetribution BreachRetributionBuilder + // DB provides access to the client's stable storage. DB DB @@ -220,7 +225,10 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { // // TODO(conner): queue backups and retry with different session params. case reserveAvailable: - err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) + err := task.bindSession( + &q.cfg.ClientSession.ClientSessionBody, + q.cfg.BuildBreachRetribution, + ) if err != nil { q.queueCond.L.Unlock() q.log.Debugf("SessionQueue(%s) rejected %v: %v ",