From d29a55bbb5d01b580486cfd4f57c2cd2d2c989ed Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 08:59:20 +0200 Subject: [PATCH 01/10] watchtower: make use of t.Cleanup funcs in tests Make use of the t.Cleanup helper function to clean up watchtower client tests instead of relying on defer calls. --- watchtower/wtclient/client_test.go | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index fcaad1588..53ef15d15 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -449,13 +449,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { if err := server.Start(); err != nil { t.Fatalf("Unable to start wtserver: %v", err) } + t.Cleanup(func() { + _ = server.Stop() + }) if err = client.Start(); err != nil { - server.Stop() t.Fatalf("Unable to start wtclient: %v", err) } + t.Cleanup(client.ForceQuit) + if err := client.AddTower(towerAddr); err != nil { - server.Stop() t.Fatalf("Unable to add tower to wtclient: %v", err) } @@ -979,7 +982,6 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckUpdates = true h.startServer() - defer h.server.Stop() // Send the next state update to the tower. Since the // tower isn't acking state updates, we expect this @@ -1000,12 +1002,10 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckUpdates = false h.startServer() - defer h.server.Stop() // Restart the client and allow it to process the // committed update. h.startClient() - defer h.client.ForceQuit() // Wait for the committed update to be accepted by the // tower. @@ -1052,7 +1052,6 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckUpdates = true h.startServer() - defer h.server.Stop() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) @@ -1071,7 +1070,6 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckUpdates = false h.startServer() - defer h.server.Stop() // Wait for all of the updates to be populated in the // server's database. @@ -1215,13 +1213,11 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckCreateSession = false h.startServer() - defer h.server.Stop() // Restart the client with the same policy, which will // immediately try to overwrite the old session with an // identical one. h.startClient() - defer h.client.ForceQuit() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) @@ -1273,14 +1269,12 @@ var clientTests = []clientTest{ h.server.Stop() h.serverCfg.NoAckCreateSession = false h.startServer() - defer h.server.Stop() // Restart the client with a new policy, which will // immediately try to overwrite the prior session with // the old policy. h.clientCfg.Policy.SweepFeeRate *= 2 h.startClient() - defer h.client.ForceQuit() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) @@ -1341,7 +1335,6 @@ var clientTests = []clientTest{ // Restart the client with a new policy. h.clientCfg.Policy.MaxUpdates = 20 h.startClient() - defer h.client.ForceQuit() // Now, queue the second half of the retributions. h.backupStates(chanID, numUpdates/2, numUpdates, nil) @@ -1395,7 +1388,6 @@ var clientTests = []clientTest{ // maintained across restarts. h.client.Stop() h.startClient() - defer h.client.ForceQuit() // Try to back up the full range of retributions. Only // the second half should actually be sent. @@ -1528,10 +1520,6 @@ func TestClient(t *testing.T) { t.Parallel() h := newHarness(t, tc.cfg) - t.Cleanup(func() { - require.NoError(t, h.server.Stop()) - h.client.ForceQuit() - }) tc.fn(h) }) From ab4d4a19be51ccbef19f83c294fb737b5bbd4bab Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Oct 2022 17:31:12 +0200 Subject: [PATCH 02/10] watchtower/wtclient: upgrade pkg to use require Upgrade all the tests in the wtclient package to make use of the `require` package. --- .../wtclient/backup_task_internal_test.go | 145 +++++------------- .../wtclient/candidate_iterator_test.go | 59 ++++--- watchtower/wtclient/client_test.go | 137 ++++++----------- 3 files changed, 108 insertions(+), 233 deletions(-) diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 7d3178f3e..c536c433b 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -2,9 +2,6 @@ package wtclient import ( "bytes" - "crypto/rand" - "io" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -12,7 +9,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -54,14 +50,6 @@ var ( } ) -func makeAddrSlice(size int) []byte { - addr := make([]byte, size) - if _, err := io.ReadFull(rand.Reader, addr); err != nil { - panic("cannot make addr") - } - return addr -} - type backupTaskTest struct { name string chanID lnwire.ChannelID @@ -502,35 +490,12 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that all parameters set during initialization are properly // populated. - if task.id.ChanID != test.chanID { - t.Fatalf("channel id mismatch, want: %s, got: %s", - test.chanID, task.id.ChanID) - } - - if task.id.CommitHeight != test.breachInfo.RevokedStateNum { - t.Fatalf("commit height mismatch, want: %d, got: %d", - test.breachInfo.RevokedStateNum, task.id.CommitHeight) - } - - if task.totalAmt != test.expTotalAmt { - t.Fatalf("total amount mismatch, want: %d, got: %v", - test.expTotalAmt, task.totalAmt) - } - - if !reflect.DeepEqual(task.breachInfo, test.breachInfo) { - t.Fatalf("breach info mismatch, want: %v, got: %v", - test.breachInfo, task.breachInfo) - } - - if !reflect.DeepEqual(task.toLocalInput, test.expToLocalInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToLocalInput, task.toLocalInput) - } - - if !reflect.DeepEqual(task.toRemoteInput, test.expToRemoteInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToRemoteInput, task.toRemoteInput) - } + require.Equal(t, test.chanID, task.id.ChanID) + require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) + require.Equal(t, test.expTotalAmt, task.totalAmt) + require.Equal(t, test.breachInfo, task.breachInfo) + require.Equal(t, test.expToLocalInput, task.toLocalInput) + require.Equal(t, test.expToRemoteInput, task.toRemoteInput) // Reconstruct the expected input.Inputs that will be returned by the // task's inputs() method. @@ -545,34 +510,24 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the inputs method returns the correct slice of // input.Inputs. inputs := task.inputs() - if !reflect.DeepEqual(expInputs, inputs) { - t.Fatalf("inputs mismatch, want: %v, got: %v", - expInputs, inputs) - } + require.Equal(t, expInputs, inputs) // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. err := task.bindSession(test.session) - if err != test.bindErr { - t.Fatalf("expected: %v when binding session, got: %v", - test.bindErr, err) - } + require.ErrorIs(t, err, test.bindErr) // Exit early if the bind was supposed to fail. But first, we check that // all fields set during a bind are still unset. This ensure that a // failed bind doesn't have side-effects if the task is retried with a // different session. if test.bindErr != nil { - if task.blobType != 0 { - t.Fatalf("blob type should not be set on failed bind, "+ - "found: %s", task.blobType) - } + require.Zerof(t, task.blobType, "blob type should not be set "+ + "on failed bind, found: %s", task.blobType) - if task.outputs != nil { - t.Fatalf("justice outputs should not be set on failed bind, "+ - "found: %v", task.outputs) - } + require.Nilf(t, task.outputs, "justice outputs should not be "+ + " set on failed bind, found: %v", task.outputs) return } @@ -580,10 +535,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Otherwise, the binding succeeded. Assert that all values set during // the bind are properly populated. policy := test.session.Policy - if task.blobType != policy.BlobType { - t.Fatalf("blob type mismatch, want: %s, got %s", - policy.BlobType, task.blobType) - } + require.Equal(t, policy.BlobType, task.blobType) // Compute the expected outputs on the justice transaction. var expOutputs = []*wire.TxOut{ @@ -603,10 +555,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { } // Assert that the computed outputs match our expected outputs. - if !reflect.DeepEqual(expOutputs, task.outputs) { - t.Fatalf("justice txn output mismatch, want: %v,\ngot: %v", - spew.Sdump(expOutputs), spew.Sdump(task.outputs)) - } + require.Equal(t, expOutputs, task.outputs) // Now, we'll construct, sign, and encrypt the blob containing the parts // needed to reconstruct the justice transaction. @@ -616,10 +565,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Verify that the breach hint matches the breach txid's prefix. breachTxID := test.breachInfo.BreachTxHash expHint := blob.NewBreachHintFromHash(&breachTxID) - if hint != expHint { - t.Fatalf("breach hint mismatch, want: %x, got: %v", - expHint, hint) - } + require.Equal(t, expHint, hint) // Decrypt the return blob to obtain the JusticeKit containing its // contents. @@ -634,14 +580,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the blob contained the serialized revocation and to-local // pubkeys. - if !bytes.Equal(jKit.RevocationPubKey[:], expRevPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expRevPK, jKit.RevocationPubKey[:]) - } - if !bytes.Equal(jKit.LocalDelayPubKey[:], expToLocalPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expToLocalPK, jKit.LocalDelayPubKey[:]) - } + require.Equal(t, expRevPK, jKit.RevocationPubKey[:]) + require.Equal(t, expToLocalPK, jKit.LocalDelayPubKey[:]) // Determine if the breach transaction has a to-remote output and/or // to-local output to spend from. Note the seemingly-reversed @@ -650,32 +590,19 @@ func testBackupTask(t *testing.T, test backupTaskTest) { hasToLocal := test.breachInfo.RemoteOutputSignDesc != nil // If the to-remote output is present, assert that the to-remote public - // key was included in the blob. - if hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], expToRemotePK) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - expToRemotePK, jKit.CommitToRemotePubKey) - } - - // Otherwise if the to-local output is not present, assert that a blank - // public key was inserted. - if !hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], zeroPK[:]) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - zeroPK, jKit.CommitToRemotePubKey) + // key was included in the blob. Otherwise assert that a blank public + // key was inserted. + if hasToRemote { + require.Equal(t, expToRemotePK, jKit.CommitToRemotePubKey[:]) + } else { + require.Equal(t, zeroPK[:], jKit.CommitToRemotePubKey[:]) } // Assert that the CSV is encoded in the blob. - if jKit.CSVDelay != test.breachInfo.RemoteDelay { - t.Fatalf("mismatch remote delay, want: %d, got: %v", - test.breachInfo.RemoteDelay, jKit.CSVDelay) - } + require.Equal(t, test.breachInfo.RemoteDelay, jKit.CSVDelay) // Assert that the sweep pkscript is included. - if !bytes.Equal(jKit.SweepAddress, test.expSweepScript) { - t.Fatalf("sweep pkscript mismatch, want: %x, got: %x", - test.expSweepScript, jKit.SweepAddress) - } + require.Equal(t, test.expSweepScript, jKit.SweepAddress) // Finally, verify that the signatures are encoded in the justice kit. // We don't validate the actual signatures produced here, since at the @@ -684,18 +611,20 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // TODO(conner): include signature validation checks emptyToLocalSig := bytes.Equal(jKit.CommitToLocalSig[:], zeroSig[:]) - switch { - case hasToLocal && emptyToLocalSig: - t.Fatalf("to-local signature should not be empty") - case !hasToLocal && !emptyToLocalSig: - t.Fatalf("to-local signature should be empty") + if hasToLocal { + require.False(t, emptyToLocalSig, "to-local signature should "+ + "not be empty") + } else { + require.True(t, emptyToLocalSig, "to-local signature should "+ + "be empty") } emptyToRemoteSig := bytes.Equal(jKit.CommitToRemoteSig[:], zeroSig[:]) - switch { - case hasToRemote && emptyToRemoteSig: - t.Fatalf("to-remote signature should not be empty") - case !hasToRemote && !emptyToRemoteSig: - t.Fatalf("to-remote signature should be empty") + if hasToRemote { + require.False(t, emptyToRemoteSig, "to-remote signature "+ + "should not be empty") + } else { + require.True(t, emptyToRemoteSig, "to-remote signature "+ + "should be empty") } } diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 99547d794..9a919e103 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -4,12 +4,10 @@ import ( "encoding/binary" "math/rand" "net" - "reflect" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/stretchr/testify/require" ) @@ -19,15 +17,16 @@ func init() { } func randAddr(t *testing.T) net.Addr { - var ip [4]byte - if _, err := rand.Read(ip[:]); err != nil { - t.Fatal(err) - } - var port [2]byte - if _, err := rand.Read(port[:]); err != nil { - t.Fatal(err) + t.Helper() + + var ip [4]byte + _, err := rand.Read(ip[:]) + require.NoError(t, err) + + var port [2]byte + _, err = rand.Read(port[:]) + require.NoError(t, err) - } return &net.TCPAddr{ IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(port[:])), @@ -35,6 +34,8 @@ func randAddr(t *testing.T) net.Addr { } func randTower(t *testing.T) *wtdb.Tower { + t.Helper() + priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") pubKey := priv.PubKey() @@ -58,27 +59,24 @@ func copyTower(tower *wtdb.Tower) *wtdb.Tower { func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower, active bool) { + t.Helper() + isCandidate := i.IsActive(c.ID) - if isCandidate && !active { - t.Fatalf("expected tower %v to no longer be an active candidate", - c.ID) - } - if !isCandidate && active { - t.Fatalf("expected tower %v to be an active candidate", c.ID) + if isCandidate { + require.Truef(t, active, "expected tower %v to no longer be "+ + "an active candidate", c.ID) + return } + require.Falsef(t, active, "expected tower %v to be an active candidate", + c.ID) } func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { t.Helper() tower, err := i.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, c) { - t.Fatalf("expected tower: %v\ngot: %v", spew.Sdump(c), - spew.Sdump(tower)) - } + require.NoError(t, err) + require.Equal(t, c, tower) } // TestTowerCandidateIterator asserts the internal state of a @@ -104,18 +102,13 @@ func TestTowerCandidateIterator(t *testing.T) { // were added. for _, expTower := range towers { tower, err := towerIterator.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, expTower) { - t.Fatalf("expected tower: %v\ngot: %v", - spew.Sdump(expTower), spew.Sdump(tower)) - } + require.NoError(t, err) + require.Equal(t, expTower, tower) } - if _, err := towerIterator.Next(); err != ErrTowerCandidatesExhausted { - t.Fatalf("expected ErrTowerCandidatesExhausted, got %v", err) - } + _, err := towerIterator.Next() + require.ErrorIs(t, err, ErrTowerCandidatesExhausted) + towerIterator.Reset() // We'll then attempt to test the RemoveCandidate behavior of the diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 53ef15d15..a9c4a330b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -325,10 +325,8 @@ func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.localBalance < amt { - t.Fatalf("insufficient funds to send, need: %v, have: %v", - amt, c.localBalance) - } + require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+ + "to send, need: %v, have: %v", amt, c.localBalance) c.localBalance -= amt c.remoteBalance += amt @@ -343,10 +341,8 @@ func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.remoteBalance < amt { - t.Fatalf("insufficient funds to recv, need: %v, have: %v", - amt, c.remoteBalance) - } + require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+ + "to recv, need: %v, have: %v", amt, c.remoteBalance) c.localBalance += amt c.remoteBalance -= amt @@ -446,21 +442,18 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { client, err := wtclient.New(clientCfg) require.NoError(t, err, "Unable to create wtclient") - if err := server.Start(); err != nil { - t.Fatalf("Unable to start wtserver: %v", err) - } + err = server.Start() + require.NoError(t, err) t.Cleanup(func() { _ = server.Stop() }) - if err = client.Start(); err != nil { - t.Fatalf("Unable to start wtclient: %v", err) - } + err = client.Start() + require.NoError(t, err) t.Cleanup(client.ForceQuit) - if err := client.AddTower(towerAddr); err != nil { - t.Fatalf("Unable to add tower to wtclient: %v", err) - } + err = client.AddTower(towerAddr) + require.NoError(t, err) h := &testHarness{ t: t, @@ -493,15 +486,11 @@ func (h *testHarness) startServer() { var err error h.server, err = wtserver.New(h.serverCfg) - if err != nil { - h.t.Fatalf("unable to create wtserver: %v", err) - } + require.NoError(h.t, err) h.net.setConnCallback(h.server.InboundPeerConnected) - if err := h.server.Start(); err != nil { - h.t.Fatalf("unable to start wtserver: %v", err) - } + require.NoError(h.t, h.server.Start()) } // startClient creates a new server using the harness's current clientCf and @@ -510,24 +499,16 @@ func (h *testHarness) startClient() { h.t.Helper() towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr) - if err != nil { - h.t.Fatalf("Unable to resolve tower TCP addr: %v", err) - } + require.NoError(h.t, err) towerAddr := &lnwire.NetAddress{ IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(), Address: towerTCPAddr, } h.client, err = wtclient.New(h.clientCfg) - if err != nil { - h.t.Fatalf("unable to create wtclient: %v", err) - } - if err := h.client.Start(); err != nil { - h.t.Fatalf("unable to start wtclient: %v", err) - } - if err := h.client.AddTower(towerAddr); err != nil { - h.t.Fatalf("unable to add tower to wtclient: %v", err) - } + require.NoError(h.t, err) + require.NoError(h.t, h.client.Start()) + require.NoError(h.t, h.client.AddTower(towerAddr)) } // chanIDFromInt creates a unique channel id given a unique integral id. @@ -556,9 +537,7 @@ func (h *testHarness) makeChannel(id uint64, } c.mu.Unlock() - if ok { - h.t.Fatalf("channel %d already created", id) - } + require.Falsef(h.t, ok, "channel %d already created", id) } // channel retrieves the channel corresponding to id. @@ -570,9 +549,7 @@ func (h *testHarness) channel(id uint64) *mockChannel { h.mu.Lock() c, ok := h.channels[chanIDFromInt(id)] h.mu.Unlock() - if !ok { - h.t.Fatalf("unable to fetch channel %d", id) - } + require.Truef(h.t, ok, "unable to fetch channel %d", id) return c } @@ -583,9 +560,7 @@ func (h *testHarness) registerChannel(id uint64) { chanID := chanIDFromInt(id) err := h.client.RegisterChannel(chanID) - if err != nil { - h.t.Fatalf("unable to register channel %d: %v", id, err) - } + require.NoError(h.t, err) } // advanceChannelN calls advanceState on the channel identified by id the number @@ -624,11 +599,10 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit) - if err != expErr { - h.t.Fatalf("back error mismatch, want: %v, got: %v", - expErr, err) - } + err := h.client.BackupState( + &chanID, retribution, channeldb.SingleFunderBit, + ) + require.ErrorIs(h.t, expErr, err) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -688,10 +662,8 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, hintSet[hint] = struct{}{} } - if len(hints) != len(hintSet) { - h.t.Fatalf("breach hints are not unique, list-len: %d "+ - "set-len: %d", len(hints), len(hintSet)) - } + require.Lenf(h.t, hints, len(hintSet), "breach hints are not unique, "+ + "list-len: %d set-len: %d", len(hints), len(hintSet)) // Closure to assert the server's matches are consistent with the hint // set. @@ -701,12 +673,9 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, } for _, match := range matches { - if _, ok := hintSet[match.Hint]; ok { - continue - } - - h.t.Fatalf("match %v in db is not in hint set", - match.Hint) + _, ok := hintSet[match.Hint] + require.Truef(h.t, ok, "match %v in db is not in "+ + "hint set", match.Hint) } return true @@ -717,31 +686,24 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, select { case <-time.After(time.Second): matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) + require.NoError(h.t, err, "unable to query for hints") - case wantUpdates && serverHasHints(matches): + if wantUpdates && serverHasHints(matches) { return + } - case wantUpdates: + if wantUpdates { h.t.Logf("Received %d/%d\n", len(matches), len(hints)) } case <-failTimeout: matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) - - case serverHasHints(matches): - return - - default: - h.t.Fatalf("breach hints not received, only "+ - "got %d/%d", len(matches), len(hints)) - } + require.NoError(h.t, err, "unable to query for hints") + require.Truef(h.t, serverHasHints(matches), "breach "+ + "hints not received, only got %d/%d", + len(matches), len(hints)) + return } } } @@ -754,25 +716,18 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, // Query for matches on the provided hints. matches, err := h.serverDB.QueryMatches(hints) - if err != nil { - h.t.Fatalf("unable to query for matches: %v", err) - } + require.NoError(h.t, err) // Assert that the number of matches is exactly the number of provided // hints. - if len(matches) != len(hints) { - h.t.Fatalf("expected: %d matches, got: %d", len(hints), - len(matches)) - } + require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d", + len(hints), len(matches)) // Assert that all of the matches correspond to a session with the // expected policy. for _, match := range matches { matchPolicy := match.SessionInfo.Policy - if expPolicy != matchPolicy { - h.t.Fatalf("expected session to have policy: %v, "+ - "got: %v", expPolicy, matchPolicy) - } + require.Equal(h.t, expPolicy, matchPolicy) } } @@ -780,9 +735,8 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, func (h *testHarness) addTower(addr *lnwire.NetAddress) { h.t.Helper() - if err := h.client.AddTower(addr); err != nil { - h.t.Fatalf("unable to add tower: %v", err) - } + err := h.client.AddTower(addr) + require.NoError(h.t, err) } // removeTower removes a tower from the client. If `addr` is specified, then the @@ -790,9 +744,8 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) { func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { h.t.Helper() - if err := h.client.RemoveTower(pubKey, addr); err != nil { - h.t.Fatalf("unable to remove tower: %v", err) - } + err := h.client.RemoveTower(pubKey, addr) + require.NoError(h.t, err) } const ( From 4828fd902dc8acc281a51134934d344466972962 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 7 Sep 2022 11:47:54 +0200 Subject: [PATCH 03/10] wtclient: allow multiplie callback dial functions This commit is a step towards prepping the watchtower client test harness to be able to handle the case where the client connects to multiple mock servers. --- watchtower/wtclient/client_test.go | 109 ++++++++++++++++------------- 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index a9c4a330b..312a7bc1b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2,6 +2,7 @@ package wtclient_test import ( "encoding/binary" + "fmt" "net" "sync" "testing" @@ -76,37 +77,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey { } type mockNet struct { - mu sync.RWMutex - connCallback func(wtserver.Peer) + mu sync.RWMutex + connCallbacks map[string]func(wtserver.Peer) } -func newMockNet(cb func(wtserver.Peer)) *mockNet { +func newMockNet() *mockNet { return &mockNet{ - connCallback: cb, + connCallbacks: make(map[string]func(peer wtserver.Peer)), } } -func (m *mockNet) Dial(network string, address string, - timeout time.Duration) (net.Conn, error) { - +func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) { return nil, nil } -func (m *mockNet) LookupHost(host string) ([]string, error) { +func (m *mockNet) LookupHost(_ string) ([]string, error) { panic("not implemented") } -func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) { +func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) { panic("not implemented") } -func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) { +func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) { panic("not implemented") } func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, - netAddr *lnwire.NetAddress, - dialer tor.DialFunc) (wtserver.Peer, error) { + netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) { localPk := local.PubKey() localAddr := &net.TCPAddr{ @@ -119,16 +117,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, ) m.mu.RLock() - m.connCallback(remotePeer) - m.mu.RUnlock() + defer m.mu.RUnlock() + cb, ok := m.connCallbacks[netAddr.String()] + if !ok { + return nil, fmt.Errorf("no callback registered for this peer") + } + + cb(remotePeer) return localPeer, nil } -func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) { +func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress, + cb func(wtserver.Peer)) { + m.mu.Lock() defer m.mu.Unlock() - m.connCallback = cb + + m.connCallbacks[netAddr.String()] = cb +} + +func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.connCallbacks, netAddr.String()) } type mockChannel struct { @@ -416,11 +429,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NoAckCreateSession: cfg.noAckCreateSession, } - server, err := wtserver.New(serverCfg) - require.NoError(t, err, "unable to create wtserver") - signer := wtmock.NewMockSigner() - mockNet := newMockNet(server.InboundPeerConnected) + mockNet := newMockNet() clientDB := wtmock.NewClientDB() clientCfg := &wtclient.Config{ @@ -442,19 +452,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { client, err := wtclient.New(clientCfg) require.NoError(t, err, "Unable to create wtclient") - err = server.Start() - require.NoError(t, err) - t.Cleanup(func() { - _ = server.Stop() - }) - - err = client.Start() - require.NoError(t, err) - t.Cleanup(client.ForceQuit) - - err = client.AddTower(towerAddr) - require.NoError(t, err) - h := &testHarness{ t: t, cfg: cfg, @@ -466,11 +463,20 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { serverAddr: towerAddr, serverDB: serverDB, serverCfg: serverCfg, - server: server, net: mockNet, channels: make(map[lnwire.ChannelID]*mockChannel), } + h.startServer() + t.Cleanup(h.stopServer) + + err = client.Start() + require.NoError(t, err) + t.Cleanup(client.ForceQuit) + + err = client.AddTower(towerAddr) + require.NoError(t, err) + h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) if !cfg.noRegisterChan0 { h.registerChannel(0) @@ -488,11 +494,20 @@ func (h *testHarness) startServer() { h.server, err = wtserver.New(h.serverCfg) require.NoError(h.t, err) - h.net.setConnCallback(h.server.InboundPeerConnected) + h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected) require.NoError(h.t, h.server.Start()) } +// stopServer stops the main harness server. +func (h *testHarness) stopServer() { + h.t.Helper() + + h.net.removeConnCallback(h.serverAddr) + + require.NoError(h.t, h.server.Stop()) +} + // startClient creates a new server using the harness's current clientCf and // starts it. func (h *testHarness) startClient() { @@ -932,7 +947,7 @@ var clientTests = []clientTest{ // Now, restart the server and prevent it from acking // state updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() @@ -952,7 +967,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked update. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() @@ -1002,7 +1017,7 @@ var clientTests = []clientTest{ // Restart the server and prevent it from acking state // updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() @@ -1020,7 +1035,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() @@ -1163,7 +1178,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() @@ -1219,7 +1234,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() @@ -1390,8 +1405,7 @@ var clientTests = []clientTest{ // Re-add the tower. We prevent the tower from acking // session creation to ensure the inactive sessions are // not used. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() h.addTower(h.serverAddr) @@ -1400,8 +1414,7 @@ var clientTests = []clientTest{ // Finally, allow the tower to ack session creation, // allowing the state updates to be sent through the new // session. - err = h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) @@ -1440,8 +1453,7 @@ var clientTests = []clientTest{ // Now, restart the tower and prevent it from acking any // new sessions. We do this here as once the last slot // is exhausted the client will attempt to renegotiate. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() @@ -1458,8 +1470,7 @@ var clientTests = []clientTest{ // state to process. After the force quite delay // expires, the client should force quite itself and // allow the test to complete. - err = h.client.Stop() - require.Nil(h.t, err) + h.stopServer() }, }, } From 60f58b7812be3b313043a6f74bb9d5f9eb514eb0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:15:13 +0200 Subject: [PATCH 04/10] watchtower: simplify the newHarness test function --- watchtower/wtclient/client_test.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 312a7bc1b..a680f0a6b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -449,8 +449,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { MaxBackoff: time.Second, ForceQuitDelay: 10 * time.Second, } - client, err := wtclient.New(clientCfg) - require.NoError(t, err, "Unable to create wtclient") h := &testHarness{ t: t, @@ -459,7 +457,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { capacity: cfg.localBalance + cfg.remoteBalance, clientDB: clientDB, clientCfg: clientCfg, - client: client, serverAddr: towerAddr, serverDB: serverDB, serverCfg: serverCfg, @@ -470,12 +467,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { h.startServer() t.Cleanup(h.stopServer) - err = client.Start() - require.NoError(t, err) - t.Cleanup(client.ForceQuit) - - err = client.AddTower(towerAddr) - require.NoError(t, err) + h.startClient() + t.Cleanup(h.client.ForceQuit) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) if !cfg.noRegisterChan0 { From 5bc8ee48fc14ff1c1cad53705572f78d6a108951 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:16:45 +0200 Subject: [PATCH 05/10] watchtower: demo that client cant update tower address In this commit, a test is added to demonstrate that if a client tries to update the address of a tower for an active tower session, then this new address will not be used until the client restarts. This is a bug that will be fixed in a future commit. --- watchtower/wtclient/client_test.go | 104 ++++++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 11 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index a680f0a6b..6312bbbab 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -32,7 +32,8 @@ import ( const ( csvDelay uint32 = 144 - towerAddrStr = "18.28.243.2:9911" + towerAddrStr = "18.28.243.2:9911" + towerAddr2Str = "19.29.244.3:9912" ) var ( @@ -64,6 +65,8 @@ var ( ) addrScript, _ = txscript.PayToAddrScript(addr) + + waitTime = 5 * time.Second ) // randPrivKey generates a new secp keypair, and returns the public key. @@ -1034,7 +1037,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) }, }, { @@ -1185,7 +1188,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. @@ -1242,7 +1245,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. @@ -1302,7 +1305,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the client's // original policy. @@ -1343,7 +1346,7 @@ var clientTests = []clientTest{ // Wait for the first half of the updates to be // populated in the server's database. - h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second) + h.waitServerUpdates(hints[:len(hints)/2], waitTime) // Restart the client, so we can ensure the deduping is // maintained across restarts. @@ -1356,7 +1359,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) }, }, { @@ -1384,7 +1387,7 @@ var clientTests = []clientTest{ // first two. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) - h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second) + h.waitServerUpdates(hints[:numUpdates/2], waitTime) // Fully remove the tower, causing its existing sessions // to be marked inactive. @@ -1410,7 +1413,7 @@ var clientTests = []clientTest{ h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() - h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) + h.waitServerUpdates(hints[numUpdates/2:], waitTime) }, }, { @@ -1441,7 +1444,7 @@ var clientTests = []clientTest{ // Back up 4 of the 5 states for the negotiated session. h.backupStates(chanID, 0, maxUpdates-1, nil) - h.waitServerUpdates(hints[:maxUpdates-1], 5*time.Second) + h.waitServerUpdates(hints[:maxUpdates-1], waitTime) // Now, restart the tower and prevent it from acking any // new sessions. We do this here as once the last slot @@ -1456,7 +1459,7 @@ var clientTests = []clientTest{ // the final state. We'll only wait for the first five // states to arrive at the tower. h.backupStates(chanID, maxUpdates-1, numUpdates, nil) - h.waitServerUpdates(hints[:maxUpdates], 5*time.Second) + h.waitServerUpdates(hints[:maxUpdates], waitTime) // Finally, stop the client which will continue to // attempt session negotiation since it has one more @@ -1466,6 +1469,85 @@ var clientTests = []clientTest{ h.stopServer() }, }, + { + // Assert that if a client changes the address for a server and + // then tries to back up updates then the client will not switch + // to the new address. The client will only use the server's new + // address after a restart. This is a bug that will be fixed in + // a future commit. + name: "change address of existing session", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 6 + maxUpdates = 5 + ) + + // Advance the channel to create all states. + hints := h.advanceChannelN(chanID, numUpdates) + + h.backupStates(chanID, 0, numUpdates/2, nil) + + // Wait for the first half of the updates to be + // populated in the server's database. + h.waitServerUpdates(hints[:len(hints)/2], waitTime) + + // Stop the server. + h.stopServer() + + // Change the address of the server. + towerTCPAddr, err := net.ResolveTCPAddr( + "tcp", towerAddr2Str, + ) + require.NoError(h.t, err) + + oldAddr := h.serverAddr.Address + towerAddr := &lnwire.NetAddress{ + IdentityKey: h.serverAddr.IdentityKey, + Address: towerTCPAddr, + } + h.serverAddr = towerAddr + + // Add the new tower address to the client. + err = h.client.AddTower(towerAddr) + require.NoError(h.t, err) + + // Remove the old tower address from the client. + err = h.client.RemoveTower( + towerAddr.IdentityKey, oldAddr, + ) + require.NoError(h.t, err) + + // Restart the server. + h.startServer() + + // Now attempt to back up the rest of the updates. + h.backupStates(chanID, numUpdates/2, maxUpdates, nil) + + // Assert that the server does not receive the updates. + h.waitServerUpdates(nil, waitTime) + + // Restart the client and attempt to back up the updates + // again. + h.client.Stop() + h.startClient() + h.backupStates(chanID, numUpdates/2, maxUpdates, nil) + + // The server should now receive the updates. + h.waitServerUpdates(hints[:maxUpdates], waitTime) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup From 79245425005dd683848d3c6bc26a6b7b54ef536a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:21:38 +0200 Subject: [PATCH 06/10] watchtower: add AddressIterator and tests In this commit, a new AddressIterator type is added. It is a type that can be used to iterate over a list of addresses. It strictly disallows the list of addresses it holds to be empty. It also allows callers to place locks on certain addresses in order to prevent other callers from removing the addresses in question from the iterator. --- watchtower/wtclient/addr_iterator.go | 344 ++++++++++++++++++++++ watchtower/wtclient/addr_iterator_test.go | 188 ++++++++++++ 2 files changed, 532 insertions(+) create mode 100644 watchtower/wtclient/addr_iterator.go create mode 100644 watchtower/wtclient/addr_iterator_test.go diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go new file mode 100644 index 000000000..87065c011 --- /dev/null +++ b/watchtower/wtclient/addr_iterator.go @@ -0,0 +1,344 @@ +package wtclient + +import ( + "container/list" + "errors" + "fmt" + "net" + "sync" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +var ( + // ErrAddressesExhausted signals that a addressIterator has cycled + // through all available addresses. + ErrAddressesExhausted = errors.New("exhausted all addresses") + + // ErrAddrInUse indicates that an address is locked and cannot be + // removed from the addressIterator. + ErrAddrInUse = errors.New("address in use") +) + +// AddressIterator handles iteration over a list of addresses. It strictly +// disallows the list of addresses it holds to be empty. It also allows callers +// to place locks on certain addresses in order to prevent other callers from +// removing the addresses in question from the iterator. +type AddressIterator interface { + // Next returns the next candidate address. This iterator will always + // return candidates in the order given when the iterator was + // instantiated. If no more candidates are available, + // ErrAddressesExhausted is returned. + Next() (net.Addr, error) + + // NextAndLock does the same as described for Next, and it also places a + // lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + NextAndLock() (net.Addr, error) + + // Peek returns the currently selected address in the iterator. If the + // end of the iterator has been reached then it is reset and the first + // item in the iterator is returned. Since the AddressIterator will + // never have an empty address list, this function will never return a + // nil value. + Peek() net.Addr + + // PeekAndLock does the same as described for Peek, and it also places + // a lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + PeekAndLock() net.Addr + + // ReleaseLock releases the lock held on the given address. + ReleaseLock(addr net.Addr) + + // Add adds a new address to the iterator. + Add(addr net.Addr) + + // Remove removes an existing address from the iterator. It disallows + // the address from being removed if it is the last address in the + // iterator or if there is currently a lock on the address. + Remove(addr net.Addr) error + + // HasLocked returns true if the addressIterator has any locked + // addresses. + HasLocked() bool + + // GetAll returns a copy of all the addresses in the iterator. + GetAll() []net.Addr + + // Reset clears the iterators state, and makes the address at the front + // of the list the next item to be returned. + Reset() +} + +// A compile-time check to ensure that addressIterator implements the +// AddressIterator interface. +var _ AddressIterator = (*addressIterator)(nil) + +// addressIterator is a linked-list implementation of an AddressIterator. +type addressIterator struct { + mu sync.Mutex + addrList *list.List + currentTopAddr *list.Element + candidates map[string]*candidateAddr + totalLockCount int +} + +type candidateAddr struct { + addr net.Addr + numLocks int +} + +// newAddressIterator constructs a new addressIterator. +func newAddressIterator(addrs ...net.Addr) (*addressIterator, error) { + if len(addrs) == 0 { + return nil, fmt.Errorf("must have at least one address") + } + + iter := &addressIterator{ + addrList: list.New(), + candidates: make(map[string]*candidateAddr), + } + + for _, addr := range addrs { + addrID := addr.String() + iter.addrList.PushBack(addrID) + iter.candidates[addrID] = &candidateAddr{addr: addr} + } + iter.Reset() + + return iter, nil +} + +// Reset clears the iterators state, and makes the address at the front of the +// list the next item to be returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Reset() { + a.mu.Lock() + defer a.mu.Unlock() + + a.unsafeReset() +} + +// unsafeReset clears the iterator state and makes the address at the front of +// the list the next item to be returned. +// +// NOTE: this method is not thread safe and so should only be called if the +// appropriate mutex is being held. +func (a *addressIterator) unsafeReset() { + // Reset the next candidate to the front of the linked-list. + a.currentTopAddr = a.addrList.Front() +} + +// Next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Next() (net.Addr, error) { + return a.next(false) +} + +// NextAndLock does the same as described for Next, and it also places a lock on +// the returned address so that the address can not be removed until the lock on +// it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) NextAndLock() (net.Addr, error) { + return a.next(true) +} + +// next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +func (a *addressIterator) next(lock bool) (net.Addr, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // Set the next candidate to the subsequent element. + a.currentTopAddr = a.currentTopAddr.Next() + + for a.currentTopAddr != nil { + // Propose the address at the front of the list. + addrID := a.currentTopAddr.Value.(string) + + // Check whether this address is still considered a candidate. + // If it's not, we'll proceed to the next. + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr, nil + } + + return nil, ErrAddressesExhausted +} + +// Peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Peek() net.Addr { + return a.peek(false) +} + +// PeekAndLock does the same as described for Peek, and it also places a lock on +// the returned address so that the address can not be removed until the lock +// on it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) PeekAndLock() net.Addr { + return a.peek(true) +} + +// peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. If lock is set to +// true, the address will be locked for removal until ReleaseLock has been +// called for the address. +func (a *addressIterator) peek(lock bool) net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + for { + // If currentTopAddr is nil, it means we have reached the end of + // the list, so we reset it here. The iterator always has at + // least one address, so we can be sure that currentTopAddr will + // be non-nil after calling reset here. + if a.currentTopAddr == nil { + a.unsafeReset() + } + + addrID := a.currentTopAddr.Value.(string) + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr + } +} + +// ReleaseLock releases the lock held on the given address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) ReleaseLock(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + candidateAddr, ok := a.candidates[addr.String()] + if !ok { + return + } + + if candidateAddr.numLocks == 0 { + return + } + + candidateAddr.numLocks-- + a.totalLockCount-- +} + +// Add adds a new address to the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Add(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.candidates[addr.String()]; ok { + return + } + + a.addrList.PushBack(addr.String()) + a.candidates[addr.String()] = &candidateAddr{addr: addr} + + // If we've reached the end of our queue, then this candidate + // will become the next. + if a.currentTopAddr == nil { + a.currentTopAddr = a.addrList.Back() + } +} + +// Remove removes an existing address from the iterator. It disallows the +// address from being removed if it is the last address in the iterator or if +// there is currently a lock on the address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Remove(addr net.Addr) error { + a.mu.Lock() + defer a.mu.Unlock() + + candidate, ok := a.candidates[addr.String()] + if !ok { + return nil + } + + if len(a.candidates) == 1 { + return wtdb.ErrLastTowerAddr + } + + if candidate.numLocks > 0 { + return ErrAddrInUse + } + + delete(a.candidates, addr.String()) + return nil +} + +// HasLocked returns true if the addressIterator has any locked addresses. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) HasLocked() bool { + a.mu.Lock() + defer a.mu.Unlock() + + return a.totalLockCount > 0 +} + +// GetAll returns a copy of all the addresses in the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) GetAll() []net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + var addrs []net.Addr + cursor := a.addrList.Front() + + for cursor != nil { + addrID := cursor.Value.(string) + + addr, ok := a.candidates[addrID] + if !ok { + cursor = cursor.Next() + continue + } + + addrs = append(addrs, addr.addr) + cursor = cursor.Next() + } + + return addrs +} diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go new file mode 100644 index 000000000..d3674d985 --- /dev/null +++ b/watchtower/wtclient/addr_iterator_test.go @@ -0,0 +1,188 @@ +package wtclient + +import ( + "net" + "testing" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" +) + +// TestAddrIterator tests the behaviour of the addressIterator. +func TestAddrIterator(t *testing.T) { + // Assert that an iterator can't be initialised with an empty address + // list. + _, err := newAddressIterator() + require.ErrorContains(t, err, "must have at least one address") + + addr1, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8000") + require.NoError(t, err) + + // Initialise the iterator with addr1. + iter, err := newAddressIterator(addr1) + require.NoError(t, err) + + // Attempting to remove addr1 should fail now since it is the only + // address in the iterator. + iter.Add(addr1) + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + // Adding a duplicate of addr1 and then calling Remove should still + // return an error. + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + addr2, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8001") + require.NoError(t, err) + + // Add addr2 to the iterator. + iter.Add(addr2) + + // Check that peek returns addr1. + a1 := iter.Peek() + require.NoError(t, err) + require.Equal(t, addr1, a1) + + // Calling peek multiple times should return the same result. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Calling Next should now return addr2. + a2, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Assert that Peek now returns addr2. + a2 = iter.Peek() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Calling Next should result in reaching the end of th list. + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + // Calling Peek now should reset the queue and return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Wind the list to the end again so that we can test the Reset func. + _, err = iter.Next() + require.NoError(t, err) + + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + iter.Reset() + + // Now Next should return addr 2. + a2, err = iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + addr3, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8002") + require.NoError(t, err) + + // Add addr3 now to ensure that the iteration works even if we are + // midway through the queue. + iter.Add(addr3) + + // Now Next should return addr 3. + a3, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr3, a3) + + // Quickly test that GetAll correctly returns a copy of all the + // addresses in the iterator. + addrList := iter.GetAll() + require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3}) + + // Let's now remove addr3. + err = iter.Remove(addr3) + require.NoError(t, err) + + // Since addr3 is gone, Peek should return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Lastly, we will test the "locking" of addresses. + + // First we test the locking of an address via the PeekAndLock function. + a1 = iter.PeekAndLock() + require.Equal(t, addr1, a1) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr1 if there is a lock on it. + err = iter.Remove(addr1) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr1. + iter.ReleaseLock(addr1) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr1) + require.NoError(t, err) + + // Now we test the locking of an address via the NextAndLock function. + // To do this, we first re-add addr3. + iter.Add(addr3) + + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr2 if there is a lock on it. + err = iter.Remove(addr2) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr2. + iter.ReleaseLock(addr2) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr2) + require.NoError(t, err) + + // Only addr3 should still be left in the iterator. + addrList = iter.GetAll() + require.Len(t, addrList, 1) + require.Contains(t, addrList, addr3) + + // Ensure that HasLocked acts correctly in the case where more than one + // address is being locked and unlock as well as the case where the same + // address is locked more than once. + + require.False(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + iter.Add(addr2) + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Now release addr2 and asset that HasLock is still true. + iter.ReleaseLock(addr2) + require.True(t, iter.HasLocked()) + + // Releasing one of the locks on addr3 now should still result in + // HasLocked returning true. + iter.ReleaseLock(addr3) + require.True(t, iter.HasLocked()) + + // Releasing it again should now result in should still result in + // HasLocked returning false. + iter.ReleaseLock(addr3) + require.False(t, iter.HasLocked()) +} From 8a7329b988811b354b5344e38a3bd19b2feec6dc Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:47:38 +0200 Subject: [PATCH 07/10] watchtower: make use of the new AddressIterator This commit upgrades the wtclient package to make use of the new `AddressIterator`. It does so by first creating new `Tower` and `ClientSession` types. The new `Tower` type has an `AddressIterator` instead of a list of addresses. The `ClientSession` type contains a `Tower`. --- watchtower/wtclient/candidate_iterator.go | 32 ++++--- .../wtclient/candidate_iterator_test.go | 56 +++++++----- watchtower/wtclient/client.go | 89 +++++++++++++------ watchtower/wtclient/client_test.go | 17 +--- watchtower/wtclient/errors.go | 4 - watchtower/wtclient/interface.go | 47 ++++++++++ watchtower/wtclient/session_negotiator.go | 64 +++++++------ watchtower/wtclient/session_queue.go | 84 +++++++++++------ watchtower/wtdb/client_db.go | 25 ++---- watchtower/wtdb/client_db_test.go | 7 +- watchtower/wtdb/client_session.go | 14 --- watchtower/wtdb/tower.go | 18 ---- watchtower/wtmock/client_db.go | 1 - 13 files changed, 274 insertions(+), 184 deletions(-) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index 5b48a68ef..f11ad2e35 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -13,7 +13,7 @@ import ( type TowerCandidateIterator interface { // AddCandidate adds a new candidate tower to the iterator. If the // candidate already exists, then any new addresses are added to it. - AddCandidate(*wtdb.Tower) + AddCandidate(*Tower) // RemoveCandidate removes an existing candidate tower from the // iterator. An optional address can be provided to indicate a stale @@ -32,7 +32,7 @@ type TowerCandidateIterator interface { // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. - Next() (*wtdb.Tower, error) + Next() (*Tower, error) } // towerListIterator is a linked-list backed TowerCandidateIterator. @@ -40,7 +40,7 @@ type towerListIterator struct { mu sync.Mutex queue *list.List nextCandidate *list.Element - candidates map[wtdb.TowerID]*wtdb.Tower + candidates map[wtdb.TowerID]*Tower } // Compile-time constraint to ensure *towerListIterator implements the @@ -49,10 +49,10 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil) // newTowerListIterator initializes a new towerListIterator from a variadic list // of lnwire.NetAddresses. -func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { +func newTowerListIterator(candidates ...*Tower) *towerListIterator { iter := &towerListIterator{ queue: list.New(), - candidates: make(map[wtdb.TowerID]*wtdb.Tower), + candidates: make(map[wtdb.TowerID]*Tower), } for _, candidate := range candidates { @@ -79,7 +79,7 @@ func (t *towerListIterator) Reset() error { // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. -func (t *towerListIterator) Next() (*wtdb.Tower, error) { +func (t *towerListIterator) Next() (*Tower, error) { t.mu.Lock() defer t.mu.Unlock() @@ -107,7 +107,7 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) { // AddCandidate adds a new candidate tower to the iterator. If the candidate // already exists, then any new addresses are added to it. -func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { +func (t *towerListIterator) AddCandidate(candidate *Tower) { t.mu.Lock() defer t.mu.Unlock() @@ -121,8 +121,16 @@ func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { t.nextCandidate = t.queue.Back() } } else { - for _, addr := range candidate.Addresses { - tower.AddAddress(addr) + candidate.Addresses.Reset() + firstAddr := candidate.Addresses.Peek() + tower.Addresses.Add(firstAddr) + for { + next, err := candidate.Addresses.Next() + if err != nil { + break + } + + tower.Addresses.Add(next) } } } @@ -142,9 +150,9 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID, return nil } if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr + err := tower.Addresses.Remove(addr) + if err != nil { + return err } } else { delete(t.candidates, candidate) diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 9a919e103..7fe6ba723 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -33,31 +33,38 @@ func randAddr(t *testing.T) net.Addr { } } -func randTower(t *testing.T) *wtdb.Tower { +func randTower(t *testing.T) *Tower { t.Helper() priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") pubKey := priv.PubKey() - return &wtdb.Tower{ + addrs, err := newAddressIterator(randAddr(t)) + require.NoError(t, err) + + return &Tower{ ID: wtdb.TowerID(rand.Uint64()), IdentityKey: pubKey, - Addresses: []net.Addr{randAddr(t)}, + Addresses: addrs, } } -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ +func copyTower(t *testing.T, tower *Tower) *Tower { + t.Helper() + + addrs := tower.Addresses.GetAll() + addrIterator, err := newAddressIterator(addrs...) + require.NoError(t, err) + + return &Tower{ ID: tower.ID, IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), + Addresses: addrIterator, } - copy(t.Addresses, tower.Addresses) - return t } -func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, - c *wtdb.Tower, active bool) { +func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *Tower, + active bool) { t.Helper() @@ -71,12 +78,14 @@ func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c.ID) } -func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { +func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { t.Helper() tower, err := i.Next() require.NoError(t, err) - require.Equal(t, c, tower) + require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) + require.Equal(t, tower.ID, c.ID) + require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -88,13 +97,13 @@ func TestTowerCandidateIterator(t *testing.T) { // towers. We'll use copies of these towers within the iterator to // ensure the iterator properly updates the state of its candidates. const numTowers = 4 - towers := make([]*wtdb.Tower, 0, numTowers) + towers := make([]*Tower, 0, numTowers) for i := 0; i < numTowers; i++ { towers = append(towers, randTower(t)) } - towerCopies := make([]*wtdb.Tower, 0, numTowers) + towerCopies := make([]*Tower, 0, numTowers) for _, tower := range towers { - towerCopies = append(towerCopies, copyTower(tower)) + towerCopies = append(towerCopies, copyTower(t, tower)) } towerIterator := newTowerListIterator(towerCopies...) @@ -112,13 +121,13 @@ func TestTowerCandidateIterator(t *testing.T) { towerIterator.Reset() // We'll then attempt to test the RemoveCandidate behavior of the - // iterator. We'll remove the address of the first tower, which should - // result in it not having any addresses left, but still being an active - // candidate. + // iterator. We'll attempt to remove the address of the first tower, + // which should result in an error due to it being the last address of + // the tower. firstTower := towers[0] - firstTowerAddr := firstTower.Addresses[0] - firstTower.RemoveAddress(firstTowerAddr) - towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + firstTowerAddr := firstTower.Addresses.Peek() + err = towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) assertActiveCandidate(t, towerIterator, firstTower, true) assertNextCandidate(t, towerIterator, firstTower) @@ -126,7 +135,8 @@ func TestTowerCandidateIterator(t *testing.T) { // not providing the optional address. Since it's been removed, we // should expect to see the third tower next. secondTower, thirdTower := towers[1], towers[2] - towerIterator.RemoveCandidate(secondTower.ID, nil) + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) assertActiveCandidate(t, towerIterator, secondTower, false) assertNextCandidate(t, towerIterator, thirdTower) @@ -135,7 +145,7 @@ func TestTowerCandidateIterator(t *testing.T) { // iterator, but the new address should be. fourthTower := towers[3] assertActiveCandidate(t, towerIterator, fourthTower, true) - fourthTower.AddAddress(randAddr(t)) + fourthTower.Addresses.Add(randAddr(t)) towerIterator.AddCandidate(fourthTower) assertNextCandidate(t, towerIterator, fourthTower) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f9514f8f1..1f6641e28 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -45,8 +45,8 @@ const ( // genActiveSessionFilter generates a filter that selects active sessions that // also match the desired channel type, either legacy or anchor. -func genActiveSessionFilter(anchor bool) func(*wtdb.ClientSession) bool { - return func(s *wtdb.ClientSession) bool { +func genActiveSessionFilter(anchor bool) func(*ClientSession) bool { + return func(s *ClientSession) bool { return s.Status == wtdb.CSessionActive && anchor == s.Policy.IsAnchorChannel() } @@ -241,7 +241,7 @@ type TowerClient struct { negotiator SessionNegotiator candidateTowers TowerCandidateIterator - candidateSessions map[wtdb.SessionID]*wtdb.ClientSession + candidateSessions map[wtdb.SessionID]*ClientSession activeSessions sessionQueueSet sessionQueue *sessionQueue @@ -351,7 +351,7 @@ func New(config *Config) (*TowerClient, error) { activeSessionFilter := genActiveSessionFilter(isAnchorClient) candidateTowers := newTowerListIterator() - perActiveTower := func(tower *wtdb.Tower) { + perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is // no need to add it to the iterator again. if candidateTowers.IsActive(tower.ID) { @@ -400,18 +400,23 @@ func New(config *Config) (*TowerClient, error) { // sessionFilter check then the perActiveTower call-back will be called on that // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, - sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower), + sessionFilter func(*ClientSession) bool, + perActiveTower func(tower *Tower), opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { towers, err := db.ListTowers() if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, tower := range towers { + candidateSessions := make(map[wtdb.SessionID]*ClientSession) + for _, dbTower := range towers { + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err @@ -427,16 +432,24 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH( + + sessionKeyECDH := keychain.NewPubKeyECDH( towerKeyDesc, keyRing, ) - if !sessionFilter(s) { + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } + + if !sessionFilter(cs) { continue } // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = s + candidateSessions[s.ID] = cs perActiveTower(tower) } } @@ -452,11 +465,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool, + passesFilter func(*ClientSession) bool, opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { - sessions, err := db.ListClientSessions(forTower, opts...) + dbSessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -466,7 +479,13 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // be able to communicate with the towers and authenticate session // requests. This prevents us from having to store the private keys on // disk. - for _, s := range sessions { + sessions := make(map[wtdb.SessionID]*ClientSession) + for _, s := range dbSessions { + dbTower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, @@ -474,13 +493,27 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } // If an optional filter was provided, use it to filter out any // undesired sessions. - if passesFilter != nil && !passesFilter(s) { - delete(sessions, s.ID) + if passesFilter != nil && !passesFilter(cs) { + continue } + + sessions[s.ID] = cs } return sessions, nil @@ -710,7 +743,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. - var candidateSession *wtdb.ClientSession + var candidateSession *ClientSession for id, sessionInfo := range c.candidateSessions { delete(c.candidateSessions, id) @@ -1069,7 +1102,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, +func (c *TowerClient) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ @@ -1089,7 +1122,7 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise, a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) getOrInitActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { if sq, ok := c.activeSessions[s.ID]; ok { @@ -1103,7 +1136,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) initActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { // Initialize the session queue, providing it with all the resources it @@ -1156,10 +1189,16 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // We'll start by updating our persisted state, followed by our // in-memory state, with the new tower. This might not actually be a new // tower, but it might include a new address at which it can be reached. - tower, err := c.cfg.DB.CreateTower(msg.addr) + dbTower, err := c.cfg.DB.CreateTower(msg.addr) if err != nil { return err } + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return err + } + c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. @@ -1251,7 +1290,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // If our active session queue corresponds to the stale tower, we'll // proceed to negotiate a new one. if c.sessionQueue != nil { - activeTower := c.sessionQueue.towerAddr.IdentityKey.SerializeCompressed() + activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed() if bytes.Equal(pubKey, activeTower) { c.sessionQueue = nil } diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 6312bbbab..738c8cf02 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1471,10 +1471,8 @@ var clientTests = []clientTest{ }, { // Assert that if a client changes the address for a server and - // then tries to back up updates then the client will not switch - // to the new address. The client will only use the server's new - // address after a restart. This is a bug that will be fixed in - // a future commit. + // then tries to back up updates then the client will switch to + // the new address. name: "change address of existing session", cfg: harnessCfg{ localBalance: localBalance, @@ -1535,16 +1533,7 @@ var clientTests = []clientTest{ // Now attempt to back up the rest of the updates. h.backupStates(chanID, numUpdates/2, maxUpdates, nil) - // Assert that the server does not receive the updates. - h.waitServerUpdates(nil, waitTime) - - // Restart the client and attempt to back up the updates - // again. - h.client.Stop() - h.startClient() - h.backupStates(chanID, numUpdates/2, maxUpdates, nil) - - // The server should now receive the updates. + // Assert that the server does receive the updates. h.waitServerUpdates(hints[:maxUpdates], waitTime) }, }, diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index 857af3087..f496074bf 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -20,10 +20,6 @@ var ( // down. ErrNegotiatorExiting = errors.New("negotiator exiting") - // ErrNoTowerAddrs signals that the client could not be created because - // we have no addresses with which we can reach a tower. - ErrNoTowerAddrs = errors.New("no tower addresses") - // ErrFailedNegotiation signals that the session negotiator could not // acquire a new session as requested. ErrFailedNegotiation = errors.New("session negotiation unsuccessful") diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5f2357950..ba6546328 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -118,3 +118,50 @@ type ECDHKeyRing interface { // key. DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) } + +// Tower represents the info about a watchtower server that a watchtower client +// needs in order to connect to it. +type Tower struct { + // ID is the unique, db-assigned, identifier for this tower. + ID wtdb.TowerID + + // IdentityKey is the public key of the remote node, used to + // authenticate the brontide transport. + IdentityKey *btcec.PublicKey + + // Addresses is an AddressIterator that can be used to manage the + // addresses for this tower. + Addresses AddressIterator +} + +// NewTowerFromDBTower converts a wtdb.Tower, which uses a static address list, +// into a Tower which uses an address iterator. +func NewTowerFromDBTower(t *wtdb.Tower) (*Tower, error) { + addrs, err := newAddressIterator(t.Addresses...) + if err != nil { + return nil, err + } + + return &Tower{ + ID: t.ID, + IdentityKey: t.IdentityKey, + Addresses: addrs, + }, nil +} + +// ClientSession represents the session that a tower client has with a server. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID wtdb.SessionID + + wtdb.ClientSessionBody + + // Tower represents the tower that the client session has been made + // with. + Tower *Tower + + // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret + // key used to connect to the watchtower. + SessionKeyECDH keychain.SingleKeyECDH +} diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 9ccaf5b79..91b568158 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -25,7 +25,7 @@ type SessionNegotiator interface { // NewSessions is a read-only channel where newly negotiated sessions // will be delivered. - NewSessions() <-chan *wtdb.ClientSession + NewSessions() <-chan *ClientSession // Start safely initializes the session negotiator. Start() error @@ -105,8 +105,8 @@ type sessionNegotiator struct { log btclog.Logger dispatcher chan struct{} - newSessions chan *wtdb.ClientSession - successfulNegotiations chan *wtdb.ClientSession + newSessions chan *ClientSession + successfulNegotiations chan *ClientSession wg sync.WaitGroup quit chan struct{} @@ -139,8 +139,8 @@ func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { log: cfg.Log, localInit: localInit, dispatcher: make(chan struct{}, 1), - newSessions: make(chan *wtdb.ClientSession), - successfulNegotiations: make(chan *wtdb.ClientSession), + newSessions: make(chan *ClientSession), + successfulNegotiations: make(chan *ClientSession), quit: make(chan struct{}), } } @@ -171,7 +171,7 @@ func (n *sessionNegotiator) Stop() error { // NewSessions returns a receive-only channel from which newly negotiated // sessions will be returned. -func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession { +func (n *sessionNegotiator) NewSessions() <-chan *ClientSession { return n.newSessions } @@ -333,18 +333,10 @@ retryWithBackoff: } } -// createSession takes a tower an attempts to negotiate a session using any of +// createSession takes a tower and attempts to negotiate a session using any of // its stored addresses. This method returns after the first successful -// negotiation, or after all addresses have failed with ErrFailedNegotiation. If -// the tower has no addresses, ErrNoTowerAddrs is returned. -func (n *sessionNegotiator) createSession(tower *wtdb.Tower, - keyIndex uint32) error { - - // If the tower has no addresses, there's nothing we can do. - if len(tower.Addresses) == 0 { - return ErrNoTowerAddrs - } - +// negotiation, or after all addresses have failed with ErrFailedNegotiation. +func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey( keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, @@ -358,8 +350,14 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, sessionKeyDesc, n.cfg.SecretKeyRing, ) - for _, lnAddr := range tower.LNAddrs() { - err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr) + addr := tower.Addresses.Peek() + for { + lnAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: addr, + } + + err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -370,6 +368,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, n.log.Debugf("Request for session negotiation with "+ "tower=%s failed, trying again -- reason: "+ "%v", lnAddr, err) + + // Get the next tower address if there is one. + addr, err = tower.Addresses.Next() + if err == ErrAddressesExhausted { + tower.Addresses.Reset() + + return ErrFailedNegotiation + } + continue default: @@ -385,7 +392,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, - keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + keyIndex uint32, tower *Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(sessionKey, lnAddr) @@ -456,26 +463,31 @@ func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, rewardPkScript := createSessionReply.Data sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey()) - clientSession := &wtdb.ClientSession{ + dbClientSession := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: tower.ID, KeyIndex: keyIndex, Policy: n.cfg.Policy, RewardPkScript: rewardPkScript, }, - Tower: tower, - SessionKeyECDH: sessionKey, - ID: sessionID, + ID: sessionID, } - err = n.cfg.DB.CreateClientSession(clientSession) + err = n.cfg.DB.CreateClientSession(dbClientSession) if err != nil { return fmt.Errorf("unable to persist ClientSession: %v", err) } n.log.Debugf("New session negotiated with %s, policy: %s", - lnAddr, clientSession.Policy) + lnAddr, dbClientSession.Policy) + + clientSession := &ClientSession{ + ID: sessionID, + ClientSessionBody: dbClientSession.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKey, + } // We have a newly negotiated session, return it to the // dispatcher so that it can update how many outstanding diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index d149d09b6..7d98ec86f 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -34,7 +34,7 @@ const ( type sessionQueueConfig struct { // ClientSession provides access to the negotiated session parameters // and updating its persistent storage. - ClientSession *wtdb.ClientSession + ClientSession *ClientSession // ChainHash identifies the chain for which the session's justice // transactions are targeted. @@ -97,7 +97,7 @@ type sessionQueue struct { queueCond *sync.Cond localInit *wtwire.Init - towerAddr *lnwire.NetAddress + tower *Tower seqNum uint16 @@ -117,18 +117,13 @@ func newSessionQueue(cfg *sessionQueueConfig, cfg.ChainHash, ) - towerAddr := &lnwire.NetAddress{ - IdentityKey: cfg.ClientSession.Tower.IdentityKey, - Address: cfg.ClientSession.Tower.Addresses[0], - } - sq := &sessionQueue{ cfg: cfg, log: cfg.Log, commitQueue: list.New(), pendingQueue: list.New(), localInit: localInit, - towerAddr: towerAddr, + tower: cfg.ClientSession.Tower, seqNum: cfg.ClientSession.SeqNum, retryBackoff: cfg.MinBackoff, quit: make(chan struct{}), @@ -293,18 +288,48 @@ func (q *sessionQueue) sessionManager() { // drainBackups attempts to send all pending updates in the queue to the tower. func (q *sessionQueue) drainBackups() { - // First, check that we are able to dial this session's tower. - conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionKeyECDH, q.towerAddr) - if err != nil { - q.log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v", - q.ID(), q.towerAddr, err) + var ( + conn wtserver.Peer + err error + towerAddr = q.tower.Addresses.Peek() + ) - q.increaseBackoff() - select { - case <-time.After(q.retryBackoff): - case <-q.forceQuit: + for { + q.log.Infof("SessionQueue(%s) attempting to dial tower at %v", + q.ID(), towerAddr) + + // First, check that we are able to dial this session's tower. + conn, err = q.cfg.Dial( + q.cfg.ClientSession.SessionKeyECDH, &lnwire.NetAddress{ + IdentityKey: q.tower.IdentityKey, + Address: towerAddr, + }, + ) + if err != nil { + // If there are more addrs available, immediately try + // those. + nextAddr, iteratorErr := q.tower.Addresses.Next() + if iteratorErr == nil { + towerAddr = nextAddr + continue + } + + // Otherwise, if we have exhausted the address list, + // back off and try again later. + q.tower.Addresses.Reset() + + q.log.Errorf("SessionQueue(%s) unable to dial tower "+ + "at any available Addresses: %v", q.ID(), err) + + q.increaseBackoff() + select { + case <-time.After(q.retryBackoff): + case <-q.forceQuit: + } + return } - return + + break } defer conn.Close() @@ -324,9 +349,7 @@ func (q *sessionQueue) drainBackups() { } // Now, send the state update to the tower and wait for a reply. - err = q.sendStateUpdate( - conn, stateUpdate, q.localInit, sendInit, isPending, - ) + err = q.sendStateUpdate(conn, stateUpdate, sendInit, isPending) if err != nil { q.log.Errorf("SessionQueue(%s) unable to send state "+ "update: %v", q.ID(), err) @@ -483,8 +506,12 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, // variable indicates whether we should back off before attempting to send the // next state update. func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, - stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init, - sendInit, isPending bool) error { + stateUpdate *wtwire.StateUpdate, sendInit, isPending bool) error { + + towerAddr := &lnwire.NetAddress{ + IdentityKey: conn.RemotePub(), + Address: conn.RemoteAddr(), + } // If this is the first message being sent to the tower, we must send an // Init message to establish that server supports the features we @@ -505,7 +532,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, remoteInit, ok := remoteMsg.(*wtwire.Init) if !ok { return fmt.Errorf("watchtower %s responded with %T "+ - "to Init", q.towerAddr, remoteMsg) + "to Init", towerAddr, remoteMsg) } // Validate Init. @@ -532,7 +559,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) if !ok { return fmt.Errorf("watchtower %s responded with %T to "+ - "StateUpdate", q.towerAddr, remoteMsg) + "StateUpdate", towerAddr, remoteMsg) } // Process the reply from the tower. @@ -547,8 +574,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, err := fmt.Errorf("received error code %v in "+ "StateUpdateReply for seqnum=%d", stateUpdateReply.Code, stateUpdate.SeqNum) - q.log.Warnf("SessionQueue(%s) unable to upload state update to "+ - "tower=%s: %v", q.ID(), q.towerAddr, err) + q.log.Warnf("SessionQueue(%s) unable to upload state update "+ + "to tower=%s: %v", q.ID(), towerAddr, err) return err } @@ -559,7 +586,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, // TODO(conner): borked watchtower err = fmt.Errorf("unable to ack seqnum=%d: %v", stateUpdate.SeqNum, err) - q.log.Errorf("SessionQueue(%v) failed to ack update: %v", q.ID(), err) + q.log.Errorf("SessionQueue(%v) failed to ack update: %v", + q.ID(), err) return err case err == wtdb.ErrLastAppliedReversion: diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 94a9c2c74..26d4704d4 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -429,7 +429,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } towerSessions, err := listTowerSessions( - towerID, sessions, towers, towersToSessionsIndex, + towerID, sessions, towersToSessionsIndex, WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -766,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, opts..., + sessions, opts..., ) return err } @@ -778,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, opts..., + *id, sessions, towerToSessionIndex, opts..., ) return err }, func() { @@ -792,7 +792,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, towers kvdb.RBucket, +func listClientAllSessions(sessions kvdb.RBucket, opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -801,7 +801,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k, opts...) + session, err := getClientSession(sessions, k, opts...) if err != nil { return err } @@ -819,7 +819,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. -func listTowerSessions(id TowerID, sessionsBkt, towersBkt, +func listTowerSessions(id TowerID, sessionsBkt, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( map[SessionID]*ClientSession, error) { @@ -834,9 +834,7 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession( - sessionsBkt, towersBkt, k, opts..., - ) + session, err := getClientSession(sessionsBkt, k, opts...) if err != nil { return err } @@ -1248,7 +1246,7 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, +func getClientSession(sessions kvdb.RBucket, idBytes []byte, opts ...ClientSessionListOption) (*ClientSession, error) { cfg := NewClientSessionCfg() @@ -1261,13 +1259,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, return nil, err } - // Fetch the tower associated with this session. - tower, err := getTower(towers, session.TowerID.Bytes()) - if err != nil { - return nil, err - } - session.Tower = tower - // Can't fail because client session body has already been read. sessionBkt := sessions.NestedReadBucket(idBytes) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index aa30cc713..f75a0c2bc 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -343,8 +343,11 @@ func testCreateTower(h *clientDBHarness) { h.loadTowerByID(20, wtdb.ErrTowerNotFound) tower := h.newTower() - require.Len(h.t, tower.LNAddrs(), 1) - towerAddr := tower.LNAddrs()[0] + require.Len(h.t, tower.Addresses, 1) + towerAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: tower.Addresses[0], + } // Load the tower from the database and assert that it matches the tower // we created. diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index a4d5c5ecc..e44331094 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -4,7 +4,6 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -36,19 +35,6 @@ type ClientSession struct { ID SessionID ClientSessionBody - - // Tower holds the pubkey and address of the watchtower. - // - // NOTE: This value is not serialized. It is recovered by looking up the - // tower with TowerID. - Tower *Tower - - // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret - // key used to connect to the watchtower. - // - // NOTE: This value is not serialized. It is derived using the KeyIndex - // on startup to avoid storing private keys on disk. - SessionKeyECDH keychain.SingleKeyECDH } // ClientSessionBody represents the primary components of a ClientSession that diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index 77f452fb5..ca9dbeb28 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -7,7 +7,6 @@ import ( "net" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" ) // TowerID is a unique 64-bit identifier allocated to each unique watchtower. @@ -77,23 +76,6 @@ func (t *Tower) RemoveAddress(addr net.Addr) { } } -// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's -// addresses. This can be used to have a client try multiple addresses for the -// same Tower. -// -// NOTE: This method is NOT safe for concurrent use. -func (t *Tower) LNAddrs() []*lnwire.NetAddress { - addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) - for _, addr := range t.Addresses { - addrs = append(addrs, &lnwire.NetAddress{ - IdentityKey: t.IdentityKey, - Address: addr, - }) - } - - return addrs -} - // String returns a user-friendly identifier of the tower. func (t *Tower) String() string { pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed()) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 8a47bdf7f..b12fe2780 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -231,7 +231,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, if tower != nil && *tower != session.TowerID { continue } - session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session if cfg.PerAckedUpdate != nil { From b2039f245e97f343a112580061f1c530a9ee64c5 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 10:19:16 +0200 Subject: [PATCH 08/10] watchtower: demo cant remove tower during negotiation bug In this commit, a new test is added to demonstrate that an error is thrown if a user attempts to remove a tower during session negotiation even if no current negotiation is taking place with the tower. --- watchtower/wtclient/client_test.go | 41 ++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 738c8cf02..1d5c5b76a 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tor" @@ -401,6 +402,7 @@ type harnessCfg struct { policy wtpolicy.Policy noRegisterChan0 bool noAckCreateSession bool + noServerStart bool } func newHarness(t *testing.T, cfg harnessCfg) *testHarness { @@ -467,8 +469,10 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { channels: make(map[lnwire.ChannelID]*mockChannel), } - h.startServer() - t.Cleanup(h.stopServer) + if !cfg.noServerStart { + h.startServer() + t.Cleanup(h.stopServer) + } h.startClient() t.Cleanup(h.client.ForceQuit) @@ -1537,6 +1541,39 @@ var clientTests = []clientTest{ h.waitServerUpdates(hints[:maxUpdates], waitTime) }, }, + { + // Assert that an error is returned if a user tries to remove + // a tower from the client while a session negotiation is in + // progress. This is a bug that will be fixed in a future + // commit. + name: "cant remove tower while session negotiation in progress", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + noServerStart: true, + }, + fn: func(h *testHarness) { + var err error + waitErr := wait.Predicate(func() bool { + err = h.client.RemoveTower( + h.serverAddr.IdentityKey, nil, + ) + return err != nil + }, time.Second*5) + require.NoError(h.t, waitErr) + + require.ErrorContains(h.t, err, "removing towers is "+ + "disallowed while a new session negotiation "+ + "is in progress") + }, + }, } // TestClient executes the client test suite, asserting the ability to backup From 3ff5abc9e30876d7be7a67276179922b81bc9c3f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 10:56:04 +0200 Subject: [PATCH 09/10] watchtower: allow removal during session negotiation In this commit, the bug demonstrated in the previous commit is fixed. The locking capabilities of the AddressIterator are used to lock addresses if they are being used for session negotiation. So now, when a request comes through to remove a tower address then a check is first done to ensure that the address is not currently in use. If it is not, then the request can go through. --- watchtower/wtclient/candidate_iterator.go | 4 + watchtower/wtclient/client.go | 37 +++++--- watchtower/wtclient/client_test.go | 109 +++++++++++++++++++--- watchtower/wtclient/session_negotiator.go | 5 +- 4 files changed, 125 insertions(+), 30 deletions(-) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index f11ad2e35..faf3169c6 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -155,6 +155,10 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID, return err } } else { + if tower.Addresses.HasLocked() { + return ErrAddrInUse + } + delete(t.candidates, candidate) } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 1f6641e28..3aa84f28c 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,7 +2,6 @@ package wtclient import ( "bytes" - "errors" "fmt" "net" "sync" @@ -826,13 +825,10 @@ func (c *TowerClient) backupDispatcher() { msg.errChan <- c.handleNewTower(msg) // A tower has been requested to be removed. We'll - // immediately return an error as we want to avoid the - // possibility of a new session being negotiated with - // this request's tower. + // only allow removal of it if the address in question + // is not currently being used for session negotiation. case msg := <-c.staleTowers: - msg.errChan <- errors.New("removing towers " + - "is disallowed while a new session " + - "negotiation is in progress") + msg.errChan <- c.handleStaleTower(msg) case <-c.forceQuit: return @@ -1254,18 +1250,31 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // We'll load the tower before potentially removing it in order to // retrieve its ID within the database. - tower, err := c.cfg.DB.LoadTower(msg.pubKey) + dbTower, err := c.cfg.DB.LoadTower(msg.pubKey) if err != nil { return err } - // We'll update our persisted state, followed by our in-memory state, - // with the stale tower. - if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil { + // We'll first update our in-memory state followed by our persisted + // state, with the stale tower. The removal of the tower address from + // the in-memory state will fail if the address is currently being used + // for a session negotiation. + err = c.candidateTowers.RemoveCandidate(dbTower.ID, msg.addr) + if err != nil { return err } - err = c.candidateTowers.RemoveCandidate(tower.ID, msg.addr) - if err != nil { + + if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil { + // If the persisted state update fails, re-add the address to + // our in-memory state. + tower, newTowerErr := NewTowerFromDBTower(dbTower) + if newTowerErr != nil { + log.Errorf("could not create new in-memory tower: %v", + newTowerErr) + } else { + c.candidateTowers.AddCandidate(tower) + } + return err } @@ -1278,7 +1287,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // Otherwise, the tower should no longer be used for future session // negotiations and backups. pubKey := msg.pubKey.SerializeCompressed() - sessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID) if err != nil { return fmt.Errorf("unable to retrieve sessions for tower %x: "+ "%v", pubKey, err) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 1d5c5b76a..1490c6d10 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2,6 +2,7 @@ package wtclient_test import ( "encoding/binary" + "errors" "fmt" "net" "sync" @@ -394,6 +395,8 @@ type testHarness struct { mu sync.Mutex channels map[lnwire.ChannelID]*mockChannel + + quit chan struct{} } type harnessCfg struct { @@ -467,7 +470,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { serverCfg: serverCfg, net: mockNet, channels: make(map[lnwire.ChannelID]*mockChannel), + quit: make(chan struct{}), } + t.Cleanup(func() { + close(h.quit) + }) if !cfg.noServerStart { h.startServer() @@ -1542,11 +1549,10 @@ var clientTests = []clientTest{ }, }, { - // Assert that an error is returned if a user tries to remove - // a tower from the client while a session negotiation is in - // progress. This is a bug that will be fixed in a future - // commit. - name: "cant remove tower while session negotiation in progress", + // Assert that a user is able to remove a tower address during + // session negotiation as long as the address in question is not + // currently being used. + name: "removing a tower during session negotiation", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, @@ -1560,18 +1566,93 @@ var clientTests = []clientTest{ noServerStart: true, }, fn: func(h *testHarness) { - var err error - waitErr := wait.Predicate(func() bool { - err = h.client.RemoveTower( + // The server has not started yet and so no session + // negotiation with the server will be in progress, so + // the client should be able to remove the server. + err := wait.NoError(func() error { + return h.client.RemoveTower( h.serverAddr.IdentityKey, nil, ) - return err != nil - }, time.Second*5) - require.NoError(h.t, waitErr) + }, waitTime) + require.NoError(h.t, err) - require.ErrorContains(h.t, err, "removing towers is "+ - "disallowed while a new session negotiation "+ - "is in progress") + // Set the server up so that its Dial function hangs + // when the client calls it. This will force the client + // to remain in the state where it has locked the + // address of the server. + h.server, err = wtserver.New(h.serverCfg) + require.NoError(h.t, err) + + cancel := make(chan struct{}) + h.net.registerConnCallback( + h.serverAddr, func(peer wtserver.Peer) { + select { + case <-h.quit: + case <-cancel: + } + }, + ) + + // Also add a new tower address. + towerTCPAddr, err := net.ResolveTCPAddr( + "tcp", towerAddr2Str, + ) + require.NoError(h.t, err) + + towerAddr := &lnwire.NetAddress{ + IdentityKey: h.serverAddr.IdentityKey, + Address: towerTCPAddr, + } + + // Register the new address in the mock-net. + h.net.registerConnCallback( + towerAddr, h.server.InboundPeerConnected, + ) + + // Now start the server. + require.NoError(h.t, h.server.Start()) + + // Re-add the server to the client + err = h.client.AddTower(h.serverAddr) + require.NoError(h.t, err) + + // Also add the new tower address. + err = h.client.AddTower(towerAddr) + require.NoError(h.t, err) + + // Assert that if the client attempts to remove the + // tower's first address, then it will error due to + // address currently being locked for session + // negotiation. + err = wait.Predicate(func() bool { + err = h.client.RemoveTower( + h.serverAddr.IdentityKey, + h.serverAddr.Address, + ) + return errors.Is(err, wtclient.ErrAddrInUse) + }, waitTime) + require.NoError(h.t, err) + + // Assert that the second address can be removed since + // it is not being used for session negotiation. + err = wait.NoError(func() error { + return h.client.RemoveTower( + h.serverAddr.IdentityKey, towerTCPAddr, + ) + }, waitTime) + require.NoError(h.t, err) + + // Allow the dial to the first address to stop hanging. + close(cancel) + + // Assert that the client can now remove the first + // address. + err = wait.NoError(func() error { + return h.client.RemoveTower( + h.serverAddr.IdentityKey, nil, + ) + }, waitTime) + require.NoError(h.t, err) }, }, } diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 91b568158..db0f543a1 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -350,7 +350,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { sessionKeyDesc, n.cfg.SecretKeyRing, ) - addr := tower.Addresses.Peek() + addr := tower.Addresses.PeekAndLock() for { lnAddr := &lnwire.NetAddress{ IdentityKey: tower.IdentityKey, @@ -358,6 +358,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { } err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr) + tower.Addresses.ReleaseLock(addr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -370,7 +371,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { "%v", lnAddr, err) // Get the next tower address if there is one. - addr, err = tower.Addresses.Next() + addr, err = tower.Addresses.NextAndLock() if err == ErrAddressesExhausted { tower.Addresses.Reset() From ca05335083425dec43ef6d5f926a07ae503611a0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 11:50:13 +0200 Subject: [PATCH 10/10] docs: update release notes with #7025 --- docs/release-notes/release-notes-0.16.0.md | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 48ec8b807..ca19670c6 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -89,6 +89,9 @@ https://github.com/lightningnetwork/lnd/pull/6963/) * [Fixed a flake in the TestBlockCacheMutexes unit test](https://github.com/lightningnetwork/lnd/pull/7029). +* [Create a helper function to wait for peer to come + online](https://github.com/lightningnetwork/lnd/pull/6931). + ## `lncli` * [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the @@ -119,6 +122,12 @@ https://github.com/lightningnetwork/lnd/pull/6963/) caller is expected to know that doing so with untrusted input is unsafe.](https://github.com/lightningnetwork/lnd/pull/6779) +* [test: replace defer cleanup with + `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864). + +* [test: fix loop variables being accessed in + closures](https://github.com/lightningnetwork/lnd/pull/7032). + ## Watchtowers * [Create a towerID-to-sessionID index in the wtclient DB to improve the @@ -131,14 +140,10 @@ https://github.com/lightningnetwork/lnd/pull/6963/) struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to improve the performance of fetching a `ClientSession` from the DB. -* [Create a helper function to wait for peer to come - online](https://github.com/lightningnetwork/lnd/pull/6931). - -* [test: replace defer cleanup with - `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864). - -* [test: fix loop variables being accessed in - closures](https://github.com/lightningnetwork/lnd/pull/7032). +* [Allow user to update tower address without requiring a restart. Also allow + the removal of a tower address if the current session negotiation is not + using the address in question]( + https://github.com/lightningnetwork/lnd/pull/7025) ### Tooling and documentation