diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index faf3169c6..10ef86465 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -29,6 +29,10 @@ type TowerCandidateIterator interface { // candidates available as long as they remain in the set. Reset() error + // GetTower gets the tower with the given ID from the iterator. If no + // such tower is found then ErrTowerNotInIterator is returned. + GetTower(id wtdb.TowerID) (*Tower, error) + // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. @@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error { return nil } +// GetTower gets the tower with the given ID from the iterator. If no such tower +// is found then ErrTowerNotInIterator is returned. +func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) { + t.mu.Lock() + defer t.mu.Unlock() + + tower, ok := t.candidates[id] + if !ok { + return nil, ErrTowerNotInIterator + } + + return tower, nil +} + // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 7fe6ba723..70dfb7505 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -83,9 +83,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { tower, err := i.Next() require.NoError(t, err) - require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) - require.Equal(t, tower.ID, c.ID) - require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) + assertTowersEqual(t, c, tower) +} + +func assertTowersEqual(t *testing.T, expected, actual *Tower) { + t.Helper() + + require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey)) + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -155,4 +161,16 @@ func TestTowerCandidateIterator(t *testing.T) { towerIterator.AddCandidate(secondTower) assertActiveCandidate(t, towerIterator, secondTower, true) assertNextCandidate(t, towerIterator, secondTower) + + // Assert that the GetTower correctly returns the tower too. + tower, err := towerIterator.GetTower(secondTower.ID) + require.NoError(t, err) + assertTowersEqual(t, secondTower, tower) + + // Now remove the tower and assert that GetTower returns expected error. + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) + + _, err = towerIterator.GetTower(secondTower.ID) + require.ErrorIs(t, err, ErrTowerNotInIterator) } diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index f496074bf..c6884bb35 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -1,6 +1,8 @@ package wtclient -import "errors" +import ( + "errors" +) var ( // ErrClientExiting signals that the watchtower client is shutting down. @@ -11,6 +13,10 @@ var ( ErrTowerCandidatesExhausted = errors.New("exhausted all tower " + "candidates") + // ErrTowerNotInIterator is returned when a requested tower was not + // found in the iterator. + ErrTowerNotInIterator = errors.New("tower not in iterator") + // ErrPermanentTowerFailure signals that the tower has reported that it // has permanently failed or the client believes this has happened based // on the tower's behavior.