From 9e4c8dd5090666f008361390d19a9c2bc8ecba17 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 30 Mar 2023 11:49:58 +0200 Subject: [PATCH] wtclient: make addr iterator panic safe Ensure that calling Next twice in a row without first calling Reset is safe when the iterator is at the end of its list. Also alter the towerListIterator to call Reset after hitting an error on Next. --- watchtower/wtclient/addr_iterator.go | 4 ++++ watchtower/wtclient/addr_iterator_test.go | 8 +------- watchtower/wtclient/candidate_iterator.go | 1 + 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go index cb16d335a..8abcb3891 100644 --- a/watchtower/wtclient/addr_iterator.go +++ b/watchtower/wtclient/addr_iterator.go @@ -162,6 +162,10 @@ func (a *addressIterator) next(lock bool) (net.Addr, error) { a.mu.Lock() defer a.mu.Unlock() + if a.currentTopAddr == nil { + return nil, ErrAddressesExhausted + } + // Set the next candidate to the subsequent element. a.currentTopAddr = a.currentTopAddr.Next() diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go index 436be2a7e..a2e49ff90 100644 --- a/watchtower/wtclient/addr_iterator_test.go +++ b/watchtower/wtclient/addr_iterator_test.go @@ -217,15 +217,9 @@ func TestAddrIterator(t *testing.T) { require.False(t, iter.HasLocked()) }) - t.Run("calling Next twice without Reset panics", func(t *testing.T) { + t.Run("calling Next twice without Reset is safe", func(t *testing.T) { t.Parallel() - // This defer-function asserts that a panic does occur. - defer func() { - r := recover() - require.NotNilf(t, r, "the code did not panic") - }() - // Initialise the iterator with addr1. iter, err := newAddressIterator(addr1) require.NoError(t, err) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index 10ef86465..29c4719f0 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -145,6 +145,7 @@ func (t *towerListIterator) AddCandidate(candidate *Tower) { for { next, err := candidate.Addresses.Next() if err != nil { + candidate.Addresses.Reset() break }