From cab0560d5e51dd21855479b0064ec31625c269c9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Mar 2023 08:39:53 +0200 Subject: [PATCH 01/12] wtclient: cleanup the test file This commit just does some linting of the client_test.go file so that future commits are easier to parse. --- watchtower/wtclient/client_test.go | 78 +++++++++++++++--------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 7be616d77..741443f49 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -36,8 +36,6 @@ import ( ) const ( - csvDelay uint32 = 144 - towerAddrStr = "18.28.243.2:9911" towerAddr2Str = "19.29.244.3:9912" ) @@ -73,6 +71,16 @@ var ( addrScript, _ = txscript.PayToAddrScript(addr) waitTime = 5 * time.Second + + defaultTxPolicy = wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + } + + highSweepRateTxPolicy = wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: 1000000, // The high sweep fee creates dust. + } ) // randPrivKey generates a new secp keypair, and returns the public key. @@ -823,7 +831,7 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d", len(hints), len(matches)) - // Assert that all of the matches correspond to a session with the + // Assert that all the matches correspond to a session with the // expected policy. for _, match := range matches { matchPolicy := match.SessionInfo.Policy @@ -969,11 +977,6 @@ const ( remoteBalance = lnwire.MilliSatoshi(200000000) ) -var defaultTxPolicy = wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, -} - type clientTest struct { name string cfg harnessCfg @@ -1072,7 +1075,7 @@ var clientTests = []clientTest{ // pipeline is always flushed before it exits. go h.client.Stop() - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, time.Second) }, @@ -1086,10 +1089,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: 1000000, // high sweep fee creates dust - }, + TxPolicy: highSweepRateTxPolicy, MaxUpdates: 20000, }, }, @@ -1177,7 +1177,7 @@ var clientTests = []clientTest{ // the tower to receive the remaining states. h.backupStates(chanID, numSent, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, time.Second) @@ -1230,7 +1230,7 @@ var clientTests = []clientTest{ h.serverCfg.NoAckUpdates = false h.startServer() - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. 
h.waitServerUpdates(hints, waitTime) }, @@ -1252,9 +1252,11 @@ var clientTests = []clientTest{ }, fn: func(h *testHarness) { var ( - capacity = h.cfg.localBalance + h.cfg.remoteBalance + capacity = h.cfg.localBalance + + h.cfg.remoteBalance paymentAmt = lnwire.MilliSatoshi(2000000) - numSends = uint64(h.cfg.localBalance / paymentAmt) + numSends = uint64(h.cfg.localBalance) / + uint64(paymentAmt) numRecvs = uint64(capacity / paymentAmt) numUpdates = numSends + numRecvs // 200 updates chanID = uint64(0) @@ -1262,11 +1264,15 @@ var clientTests = []clientTest{ // Send money to the remote party until all funds are // depleted. - sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt) + sendHints := h.sendPayments( + chanID, 0, numSends, paymentAmt, + ) // Now, sequentially receive the entire channel balance // from the remote party. - recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt) + recvHints := h.recvPayments( + chanID, numSends, numUpdates, paymentAmt, + ) // Collect the hints generated by both sending and // receiving. @@ -1275,7 +1281,7 @@ var clientTests = []clientTest{ // Backup the channel's states the client. h.backupStates(chanID, 0, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, 3*time.Second) }, @@ -1292,10 +1298,7 @@ var clientTests = []clientTest{ }, }, fn: func(h *testHarness) { - const ( - numUpdates = 5 - numChans = 10 - ) + const numUpdates = 5 // Initialize and register an additional 9 channels. for id := uint64(1); id < 10; id++ { @@ -1323,7 +1326,7 @@ var clientTests = []clientTest{ // Test reliable flush under multi-client scenario. go h.client.Stop() - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, 10*time.Second) }, @@ -1372,7 +1375,7 @@ var clientTests = []clientTest{ // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, waitTime) @@ -1426,7 +1429,7 @@ var clientTests = []clientTest{ // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, waitTime) @@ -1469,11 +1472,11 @@ var clientTests = []clientTest{ require.NoError(h.t, h.client.Stop()) // Record the policy that the first half was stored - // under. We'll expect the second half to also be stored - // under the original policy, since we are only adjusting - // the MaxUpdates. The client should detect that the - // two policies have equivalent TxPolicies and continue - // using the first. + // under. We'll expect the second half to also be + // stored under the original policy, since we are only + // adjusting the MaxUpdates. The client should detect + // that the two policies have equivalent TxPolicies and + // continue using the first. expPolicy := h.clientCfg.Policy // Restart the client with a new policy. @@ -1483,7 +1486,7 @@ var clientTests = []clientTest{ // Now, queue the second half of the retributions. 
h.backupStates(chanID, numUpdates/2, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, waitTime) @@ -1534,7 +1537,7 @@ var clientTests = []clientTest{ // the second half should actually be sent. h.backupStates(chanID, 0, numUpdates, nil) - // Wait for all of the updates to be populated in the + // Wait for all the updates to be populated in the // server's database. h.waitServerUpdates(hints, waitTime) }, @@ -1599,10 +1602,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1629,7 +1629,7 @@ var clientTests = []clientTest{ // Back up the remaining two states. Once the first is // processed, the session will be exhausted but the - // client won't be able to regnegotiate a session for + // client won't be able to renegotiate a session for // the final state. We'll only wait for the first five // states to arrive at the tower. h.backupStates(chanID, maxUpdates-1, numUpdates, nil) From aebdd2375ca6e66037e121b68b2731449b13a419 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 10:56:15 +0200 Subject: [PATCH 02/12] channeldb: lint FetchChannel method A few lint fixes of the FetchChannel method. --- channeldb/db.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 38493d240..7c204375b 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -654,8 +654,7 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( // FetchChannel attempts to locate a channel specified by the passed channel // point. If the channel cannot be found, then an error will be returned. -// Optionally an existing db tx can be supplied. Optionally an existing db tx -// can be supplied. +// Optionally an existing db tx can be supplied. func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( *OpenChannel, error) { @@ -694,7 +693,9 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( return nil } - nodeChanBucket := openChanBucket.NestedReadBucket(nodePub) + nodeChanBucket := openChanBucket.NestedReadBucket( + nodePub, + ) if nodeChanBucket == nil { return nil } @@ -715,10 +716,11 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( ) if chainBucket == nil { return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) + "bucket for chain=%x", + chainHash) } - // Finally we reach the leaf bucket that stores + // Finally, we reach the leaf bucket that stores // all the chanPoints for this node. chanBucket := chainBucket.NestedReadBucket( targetChanPoint.Bytes(), @@ -757,7 +759,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( } // If we can't find the channel, then we return with an error, as we - // have nothing to backup. + // have nothing to back up. return nil, ErrChannelNotFound } From 908cb6060bc617a2b93f9e0416ea13d5eeb5d10a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 10:59:15 +0200 Subject: [PATCH 03/12] channeldb: optimise FetchChannel method This commit adds a small optimisation to the FetchChannel method. 
Instead of iterating over each channel bucket, an identifiable error is thrown once the wanted channel is found so that the iteration can stop early. --- channeldb/db.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 7c204375b..273189b47 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -661,6 +661,10 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( var ( targetChan *OpenChannel targetChanPoint bytes.Buffer + + // errChanFound is used to signal that the channel has been + // found so that iteration through the DB buckets can stop. + errChanFound = errors.New("channel found") ) if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { @@ -739,7 +743,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( targetChan = channel targetChan.Db = c - return nil + return errChanFound }) }) } @@ -750,7 +754,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( } else { err = chanScan(tx) } - if err != nil { + if err != nil && !errors.Is(err, errChanFound) { return nil, err } From 63442cbe51a1f2e8f8a300266c681f9786310d0c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Mar 2023 09:29:28 +0200 Subject: [PATCH 04/12] channeldb: factor out generic FetchChannel logic This commit introduces a new `channelSelector` method and moves all generic logic from `FetchChannel` to it. This refactor will make it easier to add new methods that require the same open-channel db traversal with slightly different channel selection logic. --- channeldb/db.go | 53 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 273189b47..26268eabe 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/walletdb" "github.com/go-errors/errors" mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" @@ -658,19 +659,42 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( *OpenChannel, error) { + var targetChanPoint bytes.Buffer + if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { + return nil, err + } + + targetChanPointBytes := targetChanPoint.Bytes() + selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) { + + return targetChanPointBytes, &chanPoint, nil + } + + return c.channelScanner(tx, selector) +} + +// channelSelector describes a function that takes a chain-hash bucket from +// within the open-channel DB and returns the wanted channel point bytes, and +// channel point. It must return the ErrChannelNotFound error if the wanted +// channel is not in the given bucket. +type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) + +// channelScanner will traverse the DB to each chain-hash bucket of each node +// pub-key bucket in the open-channel-bucket. The chanSelector will then be used +// to fetch the wanted channel outpoint from the chain bucket. 
+func (c *ChannelStateDB) channelScanner(tx kvdb.RTx, + chanSelect channelSelector) (*OpenChannel, error) { + var ( - targetChan *OpenChannel - targetChanPoint bytes.Buffer + targetChan *OpenChannel // errChanFound is used to signal that the channel has been // found so that iteration through the DB buckets can stop. errChanFound = errors.New("channel found") ) - if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { - return nil, err - } - // chanScan will traverse the following bucket structure: // * nodePub => chainHash => chanPoint // @@ -688,8 +712,8 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( } // Within the node channel bucket, are the set of node pubkeys - // we have channels with, we don't know the entire set, so - // we'll check them all. + // we have channels with, we don't know the entire set, so we'll + // check them all. return openChanBucket.ForEach(func(nodePub, v []byte) error { // Ensure that this is a key the same size as a pubkey, // and also that it leads directly to a bucket. @@ -726,15 +750,24 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( // Finally, we reach the leaf bucket that stores // all the chanPoints for this node. + targetChanBytes, chanPoint, err := chanSelect( + chainBucket, + ) + if errors.Is(err, ErrChannelNotFound) { + return nil + } else if err != nil { + return err + } + chanBucket := chainBucket.NestedReadBucket( - targetChanPoint.Bytes(), + targetChanBytes, ) if chanBucket == nil { return nil } channel, err := fetchOpenChannel( - chanBucket, &chanPoint, + chanBucket, chanPoint, ) if err != nil { return err From fe2304efad9a0ab9d1f02fdbb3343f27ea55bc66 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Mar 2023 09:35:59 +0200 Subject: [PATCH 05/12] channeldb: add a FetchChannelByID method Add a FetchChannelByID method that allows a caller to fetch an OpenChannel using an lnwire.ChannelID. --- channeldb/db.go | 48 ++++++++++++++++++++++++++++++++++++++++++++ channeldb/db_test.go | 23 +++++++++++++-------- 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 26268eabe..146c07449 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -674,6 +674,54 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( return c.channelScanner(tx, selector) } +// FetchChannelByID attempts to locate a channel specified by the passed channel +// ID. If the channel cannot be found, then an error will be returned. +// Optionally an existing db tx can be supplied. +func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) ( + *OpenChannel, error) { + + selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) { + + var ( + targetChanPointBytes []byte + targetChanPoint *wire.OutPoint + + // errChanFound is used to signal that the channel has + // been found so that iteration through the DB buckets + // can stop. 
+ errChanFound = errors.New("channel found") + ) + err := chainBkt.ForEach(func(k, _ []byte) error { + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(k), &outPoint) + if err != nil { + return err + } + + chanID := lnwire.NewChanIDFromOutPoint(&outPoint) + if chanID != id { + return nil + } + + targetChanPoint = &outPoint + targetChanPointBytes = k + + return errChanFound + }) + if err != nil && !errors.Is(err, errChanFound) { + return nil, nil, err + } + if targetChanPoint == nil { + return nil, nil, ErrChannelNotFound + } + + return targetChanPointBytes, targetChanPoint, nil + } + + return c.channelScanner(tx, selector) +} + // channelSelector describes a function that takes a chain-hash bucket from // within the open-channel DB and returns the wanted channel point bytes, and // channel point. It must return the ErrChannelNotFound error if the wanted diff --git a/channeldb/db_test.go b/channeldb/db_test.go index bed9c6e32..62e720407 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -12,7 +12,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) { // The decoded channel state should be identical to what we stored // above. - if !reflect.DeepEqual(channelState, dbChannel) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(channelState), spew.Sdump(dbChannel)) - } + require.Equal(t, channelState, dbChannel) + + // Next, attempt to fetch the channel by its channel ID. + chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint) + dbChannel, err = cdb.FetchChannelByID(nil, chanID) + require.NoError(t, err, "unable to fetch channel") + + // The decoded channel state should be identical to what we stored + // above. + require.Equal(t, channelState, dbChannel) // If we attempt to query for a non-existent channel, then we should // get an error. @@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) { channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load() _, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint) - if err == nil { - t.Fatalf("expected query to fail") - } + require.ErrorIs(t, err, ErrChannelNotFound) + + chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint) + _, err = cdb.FetchChannelByID(nil, chanID2) + require.ErrorIs(t, err, ErrChannelNotFound) } func genRandomChannelShell() (*ChannelShell, error) { From 85ec38f447bd425c93a726663cda7c0e749d7add Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Mar 2023 11:13:51 +0200 Subject: [PATCH 06/12] multi: pass BuildBreachRetribution callback to tower client In this commit, a new BuildBreachRetribution callback is added to the tower client's Config struct. The main LND server provides the client with an implementation of the callback. 
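
For readability, the callback that the server hands to the tower client
boils down to the following condensed illustration (the actual wiring,
split across + lines, is in the server.go diff below; error handling is
shown unabridged there):

    // buildBreachRetribution looks up the channel for the given channel
    // ID and derives the breach retribution for the revoked commitment
    // height, returning it together with the channel type.
    buildBreachRetribution := func(chanID lnwire.ChannelID,
        commitHeight uint64) (*lnwallet.BreachRetribution,
        channeldb.ChannelType, error) {

        // Fetch the channel using the new FetchChannelByID method
        // added in the previous commit.
        channel, err := s.chanStateDB.FetchChannelByID(nil, chanID)
        if err != nil {
            return nil, 0, err
        }

        // Build the retribution info for the revoked height.
        br, err := lnwallet.NewBreachRetribution(
            channel, commitHeight, 0, nil,
        )
        if err != nil {
            return nil, 0, err
        }

        return br, channel.ChanType, nil
    }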
--- server.go | 38 +++++++++++++++++++++++++----- watchtower/wtclient/client.go | 11 +++++++++ watchtower/wtclient/client_test.go | 24 +++++++++++++++++++ 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index c38ebb69e..caefe4a50 100644 --- a/server.go +++ b/server.go @@ -1523,12 +1523,37 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) } + // buildBreachRetribution is a call-back that can be used to + // query the BreachRetribution info and channel type given a + // channel ID and commitment height. + buildBreachRetribution := func(chanID lnwire.ChannelID, + commitHeight uint64) (*lnwallet.BreachRetribution, + channeldb.ChannelType, error) { + + channel, err := s.chanStateDB.FetchChannelByID( + nil, chanID, + ) + if err != nil { + return nil, 0, err + } + + br, err := lnwallet.NewBreachRetribution( + channel, commitHeight, 0, nil, + ) + if err != nil { + return nil, 0, err + } + + return br, channel.ChanType, nil + } + fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID s.towerClient, err = wtclient.New(&wtclient.Config{ - FetchClosedChannel: fetchClosedChannel, - SessionCloseRange: sessionCloseRange, - ChainNotifier: s.cc.ChainNotifier, + FetchClosedChannel: fetchClosedChannel, + BuildBreachRetribution: buildBreachRetribution, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { @@ -1558,9 +1583,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr, blob.Type(blob.FlagAnchorChannel) s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ - FetchClosedChannel: fetchClosedChannel, - SessionCloseRange: sessionCloseRange, - ChainNotifier: s.cc.ChainNotifier, + FetchClosedChannel: fetchClosedChannel, + BuildBreachRetribution: buildBreachRetribution, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 96dcaa866..6a45a5c02 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -178,6 +178,11 @@ type Config struct { // ChainNotifier can be used to subscribe to block notifications. ChainNotifier chainntnfs.ChainNotifier + // BuildBreachRetribution is a function closure that allows the client + // fetch the breach retribution info for a certain channel at a certain + // revoked commitment height. + BuildBreachRetribution BreachRetributionBuilder + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -240,6 +245,12 @@ type Config struct { SessionCloseRange uint32 } +// BreachRetributionBuilder is a function that can be used to construct a +// BreachRetribution from a channel ID and a commitment height. +type BreachRetributionBuilder func(id lnwire.ChannelID, + commitHeight uint64) (*lnwallet.BreachRetribution, + channeldb.ChannelType, error) + // newTowerMsg is an internal message we'll use within the TowerClient to signal // that a new tower can be considered. 
type newTowerMsg struct { diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 741443f49..37300f66a 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -517,6 +517,15 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { SessionCloseRange: 1, } + h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID, + commitHeight uint64) (*lnwallet.BreachRetribution, + channeldb.ChannelType, error) { + + _, retribution := h.channelFromID(id).getState(commitHeight) + + return retribution, channeldb.SingleFunderBit, nil + } + if !cfg.noServerStart { h.startServer() t.Cleanup(h.stopServer) @@ -627,6 +636,21 @@ func (h *testHarness) channel(id uint64) *mockChannel { return c } +// channelFromID retrieves the channel corresponding to id. +// +// NOTE: The method fails if a channel for id does not exist. +func (h *testHarness) channelFromID(chanID lnwire.ChannelID) *mockChannel { + h.t.Helper() + + h.mu.Lock() + defer h.mu.Unlock() + + c, ok := h.channels[chanID] + require.Truef(h.t, ok, "unable to fetch channel %s", chanID) + + return c +} + // closeChannel marks a channel as closed. // // NOTE: The method fails if a channel for id does not exist. From 530a8cae5dfd8d998c98930265e71b52dbae0716 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 11:31:24 +0200 Subject: [PATCH 07/12] wtclient: lint the package This commit fixes some lints in the wtclient package. This is done so that upcoming logic changes are easier to parse. --- watchtower/wtclient/backup_task.go | 7 +++++-- watchtower/wtclient/client.go | 30 ++++++++++++++++++---------- watchtower/wtclient/session_queue.go | 7 ++++--- watchtower/wtclient/task_pipeline.go | 19 +++++++++++------- 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index a72689303..3f95ce096 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -223,7 +223,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // P2WKH output. Anchor channels spend a to-remote confirmed // P2WSH output. if t.chanType.HasAnchors() { - weightEstimate.AddWitnessInput(input.ToRemoteConfirmedWitnessSize) + weightEstimate.AddWitnessInput( + input.ToRemoteConfirmedWitnessSize, + ) } else { weightEstimate.AddWitnessInput(input.P2WKHWitnessSize) } @@ -231,7 +233,8 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // All justice transactions will either use segwit v0 (p2wkh + p2wsh) // or segwit v1 (p2tr). - if err := addScriptWeight(&weightEstimate, t.sweepPkScript); err != nil { + err := addScriptWeight(&weightEstimate, t.sweepPkScript) + if err != nil { return err } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 6a45a5c02..315ac06b8 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -34,8 +34,8 @@ const ( // a read before breaking out of a blocking read. DefaultReadTimeout = 15 * time.Second - // DefaultWriteTimeout specifies the default duration we will wait during - // a write before breaking out of a blocking write. + // DefaultWriteTimeout specifies the default duration we will wait + // during a write before breaking out of a blocking write. DefaultWriteTimeout = 15 * time.Second // DefaultStatInterval specifies the default interval between logging @@ -569,8 +569,11 @@ func (c *TowerClient) Start() error { // committed but unacked state updates. 
This ensures that these // sessions will be able to flush the committed updates after a // restart. + fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates for _, session := range c.candidateSessions { - committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID) + committedUpdates, err := fetchCommittedUpdates( + &session.ID, + ) if err != nil { returnErr = err return @@ -799,8 +802,8 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // - client is force quit, // - justice transaction would create dust outputs when trying to abide by the // negotiated policy, or -// - breached outputs contain too little value to sweep at the target sweep fee -// rate. +// - breached outputs contain too little value to sweep at the target sweep +// fee rate. func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, breachInfo *lnwallet.BreachRetribution, chanType channeldb.ChannelType) error { @@ -817,8 +820,8 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, height, ok := c.chanCommitHeights[*chanID] if ok && breachInfo.RevokedStateNum <= height { c.backupMu.Unlock() - c.log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d", - chanID, breachInfo.RevokedStateNum) + c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ + "height=%d", chanID, breachInfo.RevokedStateNum) return nil } @@ -1496,7 +1499,9 @@ func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) { } // sendMessage sends a watchtower wire message to the target peer. -func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error { +func (c *TowerClient) sendMessage(peer wtserver.Peer, + msg wtwire.Message) error { + // Encode the next wire message into the buffer. // TODO(conner): use buffer pool var b bytes.Buffer @@ -1664,7 +1669,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // negotiations and from being used for any subsequent backups until it's added // again. If an address is provided, then this call only serves as a way of // removing the address from the watchtower instead. -func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { +func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, + addr net.Addr) error { + errChan := make(chan error, 1) select { @@ -1745,8 +1752,9 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // If our active session queue corresponds to the stale tower, we'll // proceed to negotiate a new one. if c.sessionQueue != nil { - activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed() - if bytes.Equal(pubKey, activeTower) { + towerKey := c.sessionQueue.tower.IdentityKey + + if bytes.Equal(pubKey, towerKey.SerializeCompressed()) { c.sessionQueue = nil } } diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index a5c570a71..7f29deb75 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -45,7 +45,8 @@ type sessionQueueConfig struct { Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer, error) - // SendMessage encodes, encrypts, and writes a message to the given peer. + // SendMessage encodes, encrypts, and writes a message to the given + // peer. SendMessage func(wtserver.Peer, wtwire.Message) error // ReadMessage receives, decypts, and decodes a message from the given @@ -343,8 +344,8 @@ func (q *sessionQueue) drainBackups() { // before attempting to dequeue any pending updates. 
stateUpdate, isPending, backupID, err := q.nextStateUpdate() if err != nil { - q.log.Errorf("SessionQueue(%v) unable to get next state "+ - "update: %v", q.ID(), err) + q.log.Errorf("SessionQueue(%v) unable to get next "+ + "state update: %v", q.ID(), err) return } diff --git a/watchtower/wtclient/task_pipeline.go b/watchtower/wtclient/task_pipeline.go index d1dc62ff5..385f477af 100644 --- a/watchtower/wtclient/task_pipeline.go +++ b/watchtower/wtclient/task_pipeline.go @@ -116,8 +116,8 @@ func (q *taskPipeline) QueueBackupTask(task *backupTask) error { default: } - // Queue the new task and signal the queue's condition variable to wake up - // the queueManager for processing. + // Queue the new task and signal the queue's condition variable to wake + // up the queueManager for processing. q.queue.PushBack(task) q.queueCond.L.Unlock() @@ -141,16 +141,21 @@ func (q *taskPipeline) queueManager() { select { case <-q.quit: - // Exit only after the queue has been fully drained. + // Exit only after the queue has been fully + // drained. if q.queue.Len() == 0 { q.queueCond.L.Unlock() - q.log.Debugf("Revoked state pipeline flushed.") + q.log.Debugf("Revoked state pipeline " + + "flushed.") + return } case <-q.forceQuit: q.queueCond.L.Unlock() - q.log.Debugf("Revoked state pipeline force quit.") + q.log.Debugf("Revoked state pipeline force " + + "quit.") + return default: @@ -164,8 +169,8 @@ func (q *taskPipeline) queueManager() { select { - // Backup task submitted to dispatcher. We don't select on quit to - // ensure that we still drain tasks while shutting down. + // Backup task submitted to dispatcher. We don't select on quit + // to ensure that we still drain tasks while shutting down. case q.newBackupTasks <- task: // Force quit, return immediately to allow the client to exit. From 458ac32146b03a532c9998e28eaa83ebbcb7d13d Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 11:40:33 +0200 Subject: [PATCH 08/12] multi: build retribution info in TowerClient Since the TowerClient now has a callback that it can use to retrieve the retribution for a certain channel and commit height, let it use this call back instead of requiring the info to be passed to it through BackupState. --- htlcswitch/interfaces.go | 8 ++------ htlcswitch/link.go | 22 ++++------------------ watchtower/wtclient/client.go | 26 ++++++++++++++++---------- watchtower/wtclient/client_test.go | 7 +++---- 4 files changed, 25 insertions(+), 38 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index d24f48193..32d80ac11 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -7,7 +7,6 @@ import ( "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -255,11 +254,8 @@ type TowerClient interface { // state. If the method returns nil, the backup is guaranteed to be // successful unless the tower is unavailable and client is force quit, // or the justice transaction would create dust outputs when trying to - // abide by the negotiated policy. If the channel we're trying to back - // up doesn't have a tweak for the remote party's output, then - // isTweakless should be true. - BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, - channeldb.ChannelType) error + // abide by the negotiated policy. 
+ BackupState(chanID *lnwire.ChannelID, stateNum uint64) error } // InterceptableHtlcForwarder is the interface to set the interceptor diff --git a/htlcswitch/link.go b/htlcswitch/link.go index d2dfc20ab..4c1745f0c 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2022,11 +2022,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // We've received a revocation from the remote chain, if valid, // this moves the remote chain forward, and expands our // revocation window. - // - // Before advancing our remote chain, we will record the - // current commit tx, which is used by the TowerClient to - // create backups. - oldCommitTx := l.channel.State().RemoteCommitment.CommitTx // We now process the message and advance our remote commit // chain. @@ -2063,24 +2058,15 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // create a backup for the current state. if l.cfg.TowerClient != nil { state := l.channel.State() - breachInfo, err := lnwallet.NewBreachRetribution( - state, state.RemoteCommitment.CommitHeight-1, 0, - // OldCommitTx is the breaching tx at height-1. - oldCommitTx, - ) - if err != nil { - l.fail(LinkFailureError{code: ErrInternalError}, - "failed to load breach info: %v", err) - return - } - chanID := l.ChanID() + err = l.cfg.TowerClient.BackupState( - &chanID, breachInfo, state.ChanType, + &chanID, state.RemoteCommitment.CommitHeight-1, ) if err != nil { l.fail(LinkFailureError{code: ErrInternalError}, - "unable to queue breach backup: %v", err) + "unable to queue breach backup: %v", + err) return } } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 315ac06b8..5373edb66 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -136,11 +136,8 @@ type Client interface { // state. If the method returns nil, the backup is guaranteed to be // successful unless the client is force quit, or the justice // transaction would create dust outputs when trying to abide by the - // negotiated policy. If the channel we're trying to back up doesn't - // have a tweak for the remote party's output, then isTweakless should - // be true. - BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, - channeldb.ChannelType) error + // negotiated policy. + BackupState(chanID *lnwire.ChannelID, stateNum uint64) error // Start initializes the watchtower client, allowing it process requests // to backup revoked channel states. @@ -805,32 +802,41 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // - breached outputs contain too little value to sweep at the target sweep // fee rate. func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, - breachInfo *lnwallet.BreachRetribution, - chanType channeldb.ChannelType) error { + stateNum uint64) error { // Retrieve the cached sweep pkscript used for this channel. c.backupMu.Lock() summary, ok := c.summaries[*chanID] if !ok { c.backupMu.Unlock() + return ErrUnregisteredChannel } // Ignore backups that have already been presented to the client. height, ok := c.chanCommitHeights[*chanID] - if ok && breachInfo.RevokedStateNum <= height { + if ok && stateNum <= height { c.backupMu.Unlock() c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ - "height=%d", chanID, breachInfo.RevokedStateNum) + "height=%d", chanID, stateNum) + return nil } // This backup has a higher commit height than any known backup for this // channel. We'll update our tip so that we won't accept it again if the // link flaps. 
- c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum + c.chanCommitHeights[*chanID] = stateNum c.backupMu.Unlock() + // Fetch the breach retribution info and channel type. + breachInfo, chanType, err := c.cfg.BuildBreachRetribution( + *chanID, stateNum, + ) + if err != nil { + return err + } + task := newBackupTask( chanID, breachInfo, summary.SweepPkScript, chanType, ) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 37300f66a..718a902db 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -731,10 +731,9 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) - err := h.client.BackupState( - &chanID, retribution, channeldb.SingleFunderBit, - ) - require.ErrorIs(h.t, err, expErr) + + err := h.client.BackupState(&chanID, retribution.RevokedStateNum) + require.ErrorIs(h.t, expErr, err) } // sendPayments instructs the channel identified by id to send amt to the remote From 2371bbf09aab7783b094c968704f5bd8e6124846 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 12:12:36 +0200 Subject: [PATCH 09/12] wtclient: only fetch retribution info when needed. Only construct the retribution info at the time that the backup task is being bound to a session. --- watchtower/wtclient/backup_task.go | 163 +++++++++--------- .../wtclient/backup_task_internal_test.go | 35 ++-- watchtower/wtclient/client.go | 34 ++-- watchtower/wtclient/session_queue.go | 10 +- 4 files changed, 128 insertions(+), 114 deletions(-) diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index 3f95ce096..2a02da479 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -10,7 +10,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -40,7 +39,6 @@ import ( type backupTask struct { id wtdb.BackupID breachInfo *lnwallet.BreachRetribution - chanType channeldb.ChannelType // state-dependent variables @@ -55,11 +53,79 @@ type backupTask struct { outputs []*wire.TxOut } -// newBackupTask initializes a new backupTask and populates all state-dependent -// variables. -func newBackupTask(chanID *lnwire.ChannelID, - breachInfo *lnwallet.BreachRetribution, - sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask { +// newBackupTask initializes a new backupTask. +func newBackupTask(id wtdb.BackupID, sweepPkScript []byte) *backupTask { + return &backupTask{ + id: id, + sweepPkScript: sweepPkScript, + } +} + +// inputs returns all non-dust inputs that we will attempt to spend from. +// +// NOTE: Ordering of the inputs is not critical as we sort the transaction with +// BIP69 in a later stage. +func (t *backupTask) inputs() map[wire.OutPoint]input.Input { + inputs := make(map[wire.OutPoint]input.Input) + if t.toLocalInput != nil { + inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput + } + if t.toRemoteInput != nil { + inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput + } + + return inputs +} + +// addrType returns the type of an address after parsing it and matching it to +// the set of known script templates. 
+func addrType(pkScript []byte) txscript.ScriptClass { + // We pass in a set of dummy chain params here as they're only needed + // to make the address struct, which we're ignoring anyway (scripts are + // always the same, it's addresses that change across chains). + scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs( + pkScript, &chaincfg.MainNetParams, + ) + + return scriptClass +} + +// addScriptWeight parses the passed pkScript and adds the computed weight cost +// were the script to be added to the justice transaction. +func addScriptWeight(weightEstimate *input.TxWeightEstimator, + pkScript []byte) error { + + switch addrType(pkScript) { + case txscript.WitnessV0PubKeyHashTy: + weightEstimate.AddP2WKHOutput() + + case txscript.WitnessV0ScriptHashTy: + weightEstimate.AddP2WSHOutput() + + case txscript.WitnessV1TaprootTy: + weightEstimate.AddP2TROutput() + + default: + return fmt.Errorf("invalid addr type: %v", addrType(pkScript)) + } + + return nil +} + +// bindSession first populates all state-dependent variables of the task. Then +// it determines if the backupTask is compatible with the passed SessionInfo's +// policy. If no error is returned, the task has been bound to the session and +// can be queued to upload to the tower. Otherwise, the bind failed and should +// be rescheduled with a different session. +func (t *backupTask) bindSession(session *wtdb.ClientSessionBody, + newBreachRetribution BreachRetributionBuilder) error { + + breachInfo, chanType, err := newBreachRetribution( + t.id.ChanID, t.id.CommitHeight, + ) + if err != nil { + return err + } // Parse the non-dust outputs from the breach transaction, // simultaneously computing the total amount contained in the inputs @@ -123,76 +189,11 @@ func newBackupTask(chanID *lnwire.ChannelID, totalAmt += breachInfo.LocalOutputSignDesc.Output.Value } - return &backupTask{ - id: wtdb.BackupID{ - ChanID: *chanID, - CommitHeight: breachInfo.RevokedStateNum, - }, - breachInfo: breachInfo, - chanType: chanType, - toLocalInput: toLocalInput, - toRemoteInput: toRemoteInput, - totalAmt: btcutil.Amount(totalAmt), - sweepPkScript: sweepPkScript, - } -} + t.breachInfo = breachInfo + t.toLocalInput = toLocalInput + t.toRemoteInput = toRemoteInput + t.totalAmt = btcutil.Amount(totalAmt) -// inputs returns all non-dust inputs that we will attempt to spend from. -// -// NOTE: Ordering of the inputs is not critical as we sort the transaction with -// BIP69. -func (t *backupTask) inputs() map[wire.OutPoint]input.Input { - inputs := make(map[wire.OutPoint]input.Input) - if t.toLocalInput != nil { - inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput - } - if t.toRemoteInput != nil { - inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput - } - return inputs -} - -// addrType returns the type of an address after parsing it and matching it to -// the set of known script templates. -func addrType(pkScript []byte) txscript.ScriptClass { - // We pass in a set of dummy chain params here as they're only needed - // to make the address struct, which we're ignoring anyway (scripts are - // always the same, it's addresses that change across chains). - scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs( - pkScript, &chaincfg.MainNetParams, - ) - - return scriptClass -} - -// addScriptWeight parses the passed pkScript and adds the computed weight cost -// were the script to be added to the justice transaction. 
-func addScriptWeight(weightEstimate *input.TxWeightEstimator, - pkScript []byte) error { - - switch addrType(pkScript) { //nolint: whitespace - - case txscript.WitnessV0PubKeyHashTy: - weightEstimate.AddP2WKHOutput() - - case txscript.WitnessV0ScriptHashTy: - weightEstimate.AddP2WSHOutput() - - case txscript.WitnessV1TaprootTy: - weightEstimate.AddP2TROutput() - - default: - return fmt.Errorf("invalid addr type: %v", addrType(pkScript)) - } - - return nil -} - -// bindSession determines if the backupTask is compatible with the passed -// SessionInfo's policy. If no error is returned, the task has been bound to the -// session and can be queued to upload to the tower. Otherwise, the bind failed -// and should be rescheduled with a different session. -func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // First we'll begin by deriving a weight estimate for the justice // transaction. The final weight can be different depending on whether // the watchtower is taking a reward. @@ -208,7 +209,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // original weight estimate. For anchor channels we'll go ahead // an use the correct penalty witness when signing our justice // transactions. - if t.chanType.HasAnchors() { + if chanType.HasAnchors() { weightEstimate.AddWitnessInput( input.ToLocalPenaltyWitnessSize, ) @@ -222,7 +223,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // Legacy channels (both tweaked and non-tweaked) spend from // P2WKH output. Anchor channels spend a to-remote confirmed // P2WSH output. - if t.chanType.HasAnchors() { + if chanType.HasAnchors() { weightEstimate.AddWitnessInput( input.ToRemoteConfirmedWitnessSize, ) @@ -233,7 +234,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // All justice transactions will either use segwit v0 (p2wkh + p2wsh) // or segwit v1 (p2tr). - err := addScriptWeight(&weightEstimate, t.sweepPkScript) + err = addScriptWeight(&weightEstimate, t.sweepPkScript) if err != nil { return err } @@ -247,9 +248,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { } } - if t.chanType.HasAnchors() != session.Policy.IsAnchorChannel() { + if chanType.HasAnchors() != session.Policy.IsAnchorChannel() { log.Criticalf("Invalid task (has_anchors=%t) for session "+ - "(has_anchors=%t)", t.chanType.HasAnchors(), + "(has_anchors=%t)", chanType.HasAnchors(), session.Policy.IsAnchorChannel()) } diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index c536c433b..43b818044 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -483,19 +483,19 @@ func TestBackupTask(t *testing.T) { func testBackupTask(t *testing.T, test backupTaskTest) { // Create a new backupTask from the channel id and breach info. - task := newBackupTask( - &test.chanID, test.breachInfo, test.expSweepScript, - test.chanType, - ) + id := wtdb.BackupID{ + ChanID: test.chanID, + CommitHeight: test.breachInfo.RevokedStateNum, + } + task := newBackupTask(id, test.expSweepScript) - // Assert that all parameters set during initialization are properly - // populated. 
- require.Equal(t, test.chanID, task.id.ChanID) - require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) - require.Equal(t, test.expTotalAmt, task.totalAmt) - require.Equal(t, test.breachInfo, task.breachInfo) - require.Equal(t, test.expToLocalInput, task.toLocalInput) - require.Equal(t, test.expToRemoteInput, task.toRemoteInput) + // getBreachInfo is a helper closure that returns the breach retribution + // info and channel type for the given channel and commit height. + getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) ( + *lnwallet.BreachRetribution, channeldb.ChannelType, error) { + + return test.breachInfo, test.chanType, nil + } // Reconstruct the expected input.Inputs that will be returned by the // task's inputs() method. @@ -515,9 +515,18 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. - err := task.bindSession(test.session) + err := task.bindSession(test.session, getBreachInfo) require.ErrorIs(t, err, test.bindErr) + // Assert that all parameters set during after binding the backup task + // are properly populated. + require.Equal(t, test.chanID, task.id.ChanID) + require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) + require.Equal(t, test.expTotalAmt, task.totalAmt) + require.Equal(t, test.breachInfo, task.breachInfo) + require.Equal(t, test.expToLocalInput, task.toLocalInput) + require.Equal(t, test.expToRemoteInput, task.toRemoteInput) + // Exit early if the bind was supposed to fail. But first, we check that // all fields set during a bind are still unset. This ensure that a // failed bind doesn't have side-effects if the task is retried with a diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 5373edb66..5a59cbcd9 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -829,17 +829,12 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, c.chanCommitHeights[*chanID] = stateNum c.backupMu.Unlock() - // Fetch the breach retribution info and channel type. 
- breachInfo, chanType, err := c.cfg.BuildBreachRetribution( - *chanID, stateNum, - ) - if err != nil { - return err + id := wtdb.BackupID{ + ChanID: *chanID, + CommitHeight: stateNum, } - task := newBackupTask( - chanID, breachInfo, summary.SweepPkScript, chanType, - ) + task := newBackupTask(id, summary.SweepPkScript) return c.pipeline.QueueBackupTask(task) } @@ -1543,16 +1538,17 @@ func (c *TowerClient) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ - ClientSession: s, - ChainHash: c.cfg.ChainHash, - Dial: c.dial, - ReadMessage: c.readMessage, - SendMessage: c.sendMessage, - Signer: c.cfg.Signer, - DB: c.cfg.DB, - MinBackoff: c.cfg.MinBackoff, - MaxBackoff: c.cfg.MaxBackoff, - Log: c.log, + ClientSession: s, + ChainHash: c.cfg.ChainHash, + Dial: c.dial, + ReadMessage: c.readMessage, + SendMessage: c.sendMessage, + Signer: c.cfg.Signer, + DB: c.cfg.DB, + MinBackoff: c.cfg.MinBackoff, + MaxBackoff: c.cfg.MaxBackoff, + Log: c.log, + BuildBreachRetribution: c.cfg.BuildBreachRetribution, }, updates) } diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 7f29deb75..aa06709ee 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -57,6 +57,11 @@ type sessionQueueConfig struct { // for justice transaction inputs. Signer input.Signer + // BuildBreachRetribution is a function closure that allows the client + // to fetch the breach retribution info for a certain channel at a + // certain revoked commitment height. + BuildBreachRetribution BreachRetributionBuilder + // DB provides access to the client's stable storage. DB DB @@ -220,7 +225,10 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { // // TODO(conner): queue backups and retry with different session params. case reserveAvailable: - err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) + err := task.bindSession( + &q.cfg.ClientSession.ClientSessionBody, + q.cfg.BuildBreachRetribution, + ) if err != nil { q.queueCond.L.Unlock() q.log.Debugf("SessionQueue(%s) rejected %v: %v ", From 65dc20f2ccf6b6014a509d304b6ad55603b409d7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 2 Feb 2023 12:26:27 +0200 Subject: [PATCH 10/12] wtclient: let task pipeline only carry wtdb.BackupID Since the retrubution info of a backup task is now only constructed at the time that the task is being bound to a session, the in-memory queue only needs to carry the BackupID of the task. --- watchtower/wtclient/client.go | 53 ++++++++++++++++++---------- watchtower/wtclient/task_pipeline.go | 13 ++++--- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 5a59cbcd9..fe4b18648 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -301,7 +301,7 @@ type TowerClient struct { activeSessions sessionQueueSet sessionQueue *sessionQueue - prevTask *backupTask + prevTask *wtdb.BackupID closableSessionQueue *sessionCloseMinHeap @@ -804,10 +804,9 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, stateNum uint64) error { - // Retrieve the cached sweep pkscript used for this channel. + // Make sure that this channel is registered with the tower client. 
c.backupMu.Lock() - summary, ok := c.summaries[*chanID] - if !ok { + if _, ok := c.summaries[*chanID]; !ok { c.backupMu.Unlock() return ErrUnregisteredChannel @@ -829,14 +828,12 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, c.chanCommitHeights[*chanID] = stateNum c.backupMu.Unlock() - id := wtdb.BackupID{ + id := &wtdb.BackupID{ ChanID: *chanID, CommitHeight: stateNum, } - task := newBackupTask(id, summary.SweepPkScript) - - return c.pipeline.QueueBackupTask(task) + return c.pipeline.QueueBackupTask(id) } // nextSessionQueue attempts to fetch an active session from our set of @@ -1330,7 +1327,7 @@ func (c *TowerClient) backupDispatcher() { return } - c.log.Debugf("Processing %v", task.id) + c.log.Debugf("Processing %v", task) c.stats.taskReceived() c.processTask(task) @@ -1360,8 +1357,22 @@ func (c *TowerClient) backupDispatcher() { // sessionQueue hasn't been exhausted before proceeding to the next task. Tasks // that are rejected because the active sessionQueue is full will be cached as // the prevTask, and should be reprocessed after obtaining a new sessionQueue. -func (c *TowerClient) processTask(task *backupTask) { - status, accepted := c.sessionQueue.AcceptTask(task) +func (c *TowerClient) processTask(task *wtdb.BackupID) { + c.backupMu.Lock() + summary, ok := c.summaries[task.ChanID] + if !ok { + c.backupMu.Unlock() + + log.Infof("not processing task for unregistered channel: %s", + task.ChanID) + + return + } + c.backupMu.Unlock() + + backupTask := newBackupTask(*task, summary.SweepPkScript) + + status, accepted := c.sessionQueue.AcceptTask(backupTask) if accepted { c.taskAccepted(task, status) } else { @@ -1374,9 +1385,11 @@ func (c *TowerClient) processTask(task *backupTask) { // prevTask is always removed as a result of this call. The client's // sessionQueue will be removed if accepting the task left the sessionQueue in // an exhausted state. -func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) { - c.log.Infof("Queued %v successfully for session %v", - task.id, c.sessionQueue.ID()) +func (c *TowerClient) taskAccepted(task *wtdb.BackupID, + newStatus reserveStatus) { + + c.log.Infof("Queued %v successfully for session %v", task, + c.sessionQueue.ID()) c.stats.taskAccepted() @@ -1409,7 +1422,9 @@ func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) { // the sessionQueue to find a new session. If the sessionQueue was not // exhausted, the client marks the task as ineligible, as this implies we // couldn't construct a valid justice transaction given the session's policy. -func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { +func (c *TowerClient) taskRejected(task *wtdb.BackupID, + curStatus reserveStatus) { + switch curStatus { // The sessionQueue has available capacity but the task was rejected, @@ -1417,14 +1432,14 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { case reserveAvailable: c.stats.taskIneligible() - c.log.Infof("Ignoring ineligible %v", task.id) + c.log.Infof("Ignoring ineligible %v", task) err := c.cfg.DB.MarkBackupIneligible( - task.id.ChanID, task.id.CommitHeight, + task.ChanID, task.CommitHeight, ) if err != nil { c.log.Errorf("Unable to mark %v ineligible: %v", - task.id, err) + task, err) // It is safe to not handle this error, even if we could // not persist the result. 
At worst, this task may be @@ -1444,7 +1459,7 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { c.stats.sessionExhausted() c.log.Debugf("Session %v exhausted, %v queued for next session", - c.sessionQueue.ID(), task.id) + c.sessionQueue.ID(), task) // Cache the task that we pulled off, so that we can process it // once a new session queue is available. diff --git a/watchtower/wtclient/task_pipeline.go b/watchtower/wtclient/task_pipeline.go index 385f477af..9415e1d5b 100644 --- a/watchtower/wtclient/task_pipeline.go +++ b/watchtower/wtclient/task_pipeline.go @@ -6,6 +6,7 @@ import ( "time" "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/watchtower/wtdb" ) // taskPipeline implements a reliable, in-order queue that ensures its queue @@ -25,7 +26,7 @@ type taskPipeline struct { queueCond *sync.Cond queue *list.List - newBackupTasks chan *backupTask + newBackupTasks chan *wtdb.BackupID quit chan struct{} forceQuit chan struct{} @@ -37,7 +38,7 @@ func newTaskPipeline(log btclog.Logger) *taskPipeline { rq := &taskPipeline{ log: log, queue: list.New(), - newBackupTasks: make(chan *backupTask), + newBackupTasks: make(chan *wtdb.BackupID), quit: make(chan struct{}), forceQuit: make(chan struct{}), shutdown: make(chan struct{}), @@ -91,7 +92,7 @@ func (q *taskPipeline) ForceQuit() { // channel will be closed after a call to Stop and all pending tasks have been // delivered, or if a call to ForceQuit is called before the pending entries // have been drained. -func (q *taskPipeline) NewBackupTasks() <-chan *backupTask { +func (q *taskPipeline) NewBackupTasks() <-chan *wtdb.BackupID { return q.newBackupTasks } @@ -99,7 +100,7 @@ func (q *taskPipeline) NewBackupTasks() <-chan *backupTask { // of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is // returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be // delivered via NewBackupTasks unless ForceQuit is called before completion. -func (q *taskPipeline) QueueBackupTask(task *backupTask) error { +func (q *taskPipeline) QueueBackupTask(task *wtdb.BackupID) error { q.queueCond.L.Lock() select { @@ -164,7 +165,9 @@ func (q *taskPipeline) queueManager() { // Pop the first element from the queue. e := q.queue.Front() - task := q.queue.Remove(e).(*backupTask) + + //nolint:forcetypeassert + task := q.queue.Remove(e).(*wtdb.BackupID) q.queueCond.L.Unlock() select { From 08cde9886916bacd8ec9376677005abb027c33c3 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Mar 2023 11:35:25 +0200 Subject: [PATCH 11/12] wtclient: add mutex locking in perUpdate Lock the `backupMu` when accessing `c.chanCommitHeights` in the `New` function. It is not strictly necessary right now but good to add it so that there is no accidental oversight if the `perUpdate` method is ever extracted and reused in future. --- watchtower/wtclient/client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index fe4b18648..8a2597980 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -386,6 +386,9 @@ func New(config *Config) (*TowerClient, error) { return } + c.backupMu.Lock() + defer c.backupMu.Unlock() + // Take the highest commit height found in the session's acked // updates. 
height, ok := c.chanCommitHeights[chanID] From 9a1ed8c7b27cb230686fafa7d5e70029c16fc2a4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Apr 2023 10:40:00 +0200 Subject: [PATCH 12/12] docs: update release notes --- docs/release-notes/release-notes-0.17.0.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/release-notes/release-notes-0.17.0.md b/docs/release-notes/release-notes-0.17.0.md index e32b93157..43f9a4e90 100644 --- a/docs/release-notes/release-notes-0.17.0.md +++ b/docs/release-notes/release-notes-0.17.0.md @@ -6,7 +6,15 @@ implementation](https://github.com/lightningnetwork/lnd/pull/7377) logic in different update types. +## Watchtowers + +* Let the task pipeline [only carry + wtdb.BackupIDs](https://github.com/lightningnetwork/lnd/pull/7623) instead of + the entire retribution struct. This reduces the amount of data that needs to + be held in memory. + # Contributors (Alphabetical Order) +* Elle Mouton * Jordi Montes