watchtower: add GetTower to tower iterator

Add a GetTower method to the tower iterator.
This commit is contained in:
Elle Mouton 2022-10-21 11:24:02 +02:00
parent e432261dab
commit 0ed5c750c8
No known key found for this signature in database
GPG key ID: D7D916376026F177
3 changed files with 46 additions and 4 deletions

View file

@ -29,6 +29,10 @@ type TowerCandidateIterator interface {
// candidates available as long as they remain in the set.
Reset() error
// GetTower gets the tower with the given ID from the iterator. If no
// such tower is found then ErrTowerNotInIterator is returned.
GetTower(id wtdb.TowerID) (*Tower, error)
// Next returns the next candidate tower. The iterator is not required
// to return results in any particular order. If no more candidates are
// available, ErrTowerCandidatesExhausted is returned.
@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error {
return nil
}
// GetTower gets the tower with the given ID from the iterator. If no such tower
// is found then ErrTowerNotInIterator is returned.
func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) {
t.mu.Lock()
defer t.mu.Unlock()
tower, ok := t.candidates[id]
if !ok {
return nil, ErrTowerNotInIterator
}
return tower, nil
}
// Next returns the next candidate tower. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrTowerCandidatesExhausted is returned.

View file

@ -83,9 +83,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) {
tower, err := i.Next()
require.NoError(t, err)
require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey))
require.Equal(t, tower.ID, c.ID)
require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll())
assertTowersEqual(t, c, tower)
}
func assertTowersEqual(t *testing.T, expected, actual *Tower) {
t.Helper()
require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey))
require.Equal(t, expected.ID, actual.ID)
require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll())
}
// TestTowerCandidateIterator asserts the internal state of a
@ -155,4 +161,16 @@ func TestTowerCandidateIterator(t *testing.T) {
towerIterator.AddCandidate(secondTower)
assertActiveCandidate(t, towerIterator, secondTower, true)
assertNextCandidate(t, towerIterator, secondTower)
// Assert that the GetTower correctly returns the tower too.
tower, err := towerIterator.GetTower(secondTower.ID)
require.NoError(t, err)
assertTowersEqual(t, secondTower, tower)
// Now remove the tower and assert that GetTower returns expected error.
err = towerIterator.RemoveCandidate(secondTower.ID, nil)
require.NoError(t, err)
_, err = towerIterator.GetTower(secondTower.ID)
require.ErrorIs(t, err, ErrTowerNotInIterator)
}

View file

@ -1,6 +1,8 @@
package wtclient
import "errors"
import (
"errors"
)
var (
// ErrClientExiting signals that the watchtower client is shutting down.
@ -11,6 +13,10 @@ var (
ErrTowerCandidatesExhausted = errors.New("exhausted all tower " +
"candidates")
// ErrTowerNotInIterator is returned when a requested tower was not
// found in the iterator.
ErrTowerNotInIterator = errors.New("tower not in iterator")
// ErrPermanentTowerFailure signals that the tower has reported that it
// has permanently failed or the client believes this has happened based
// on the tower's behavior.