From 79245425005dd683848d3c6bc26a6b7b54ef536a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:21:38 +0200 Subject: [PATCH] watchtower: add AddressIterator and tests In this commit, a new AddressIterator type is added. It is a type that can be used to iterate over a list of addresses. It strictly disallows the list of addresses it holds to be empty. It also allows callers to place locks on certain addresses in order to prevent other callers from removing the addresses in question from the iterator. --- watchtower/wtclient/addr_iterator.go | 344 ++++++++++++++++++++++ watchtower/wtclient/addr_iterator_test.go | 188 ++++++++++++ 2 files changed, 532 insertions(+) create mode 100644 watchtower/wtclient/addr_iterator.go create mode 100644 watchtower/wtclient/addr_iterator_test.go diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go new file mode 100644 index 000000000..87065c011 --- /dev/null +++ b/watchtower/wtclient/addr_iterator.go @@ -0,0 +1,344 @@ +package wtclient + +import ( + "container/list" + "errors" + "fmt" + "net" + "sync" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +var ( + // ErrAddressesExhausted signals that a addressIterator has cycled + // through all available addresses. + ErrAddressesExhausted = errors.New("exhausted all addresses") + + // ErrAddrInUse indicates that an address is locked and cannot be + // removed from the addressIterator. + ErrAddrInUse = errors.New("address in use") +) + +// AddressIterator handles iteration over a list of addresses. It strictly +// disallows the list of addresses it holds to be empty. It also allows callers +// to place locks on certain addresses in order to prevent other callers from +// removing the addresses in question from the iterator. +type AddressIterator interface { + // Next returns the next candidate address. This iterator will always + // return candidates in the order given when the iterator was + // instantiated. If no more candidates are available, + // ErrAddressesExhausted is returned. + Next() (net.Addr, error) + + // NextAndLock does the same as described for Next, and it also places a + // lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + NextAndLock() (net.Addr, error) + + // Peek returns the currently selected address in the iterator. If the + // end of the iterator has been reached then it is reset and the first + // item in the iterator is returned. Since the AddressIterator will + // never have an empty address list, this function will never return a + // nil value. + Peek() net.Addr + + // PeekAndLock does the same as described for Peek, and it also places + // a lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + PeekAndLock() net.Addr + + // ReleaseLock releases the lock held on the given address. + ReleaseLock(addr net.Addr) + + // Add adds a new address to the iterator. + Add(addr net.Addr) + + // Remove removes an existing address from the iterator. It disallows + // the address from being removed if it is the last address in the + // iterator or if there is currently a lock on the address. + Remove(addr net.Addr) error + + // HasLocked returns true if the addressIterator has any locked + // addresses. + HasLocked() bool + + // GetAll returns a copy of all the addresses in the iterator. + GetAll() []net.Addr + + // Reset clears the iterators state, and makes the address at the front + // of the list the next item to be returned. + Reset() +} + +// A compile-time check to ensure that addressIterator implements the +// AddressIterator interface. +var _ AddressIterator = (*addressIterator)(nil) + +// addressIterator is a linked-list implementation of an AddressIterator. +type addressIterator struct { + mu sync.Mutex + addrList *list.List + currentTopAddr *list.Element + candidates map[string]*candidateAddr + totalLockCount int +} + +type candidateAddr struct { + addr net.Addr + numLocks int +} + +// newAddressIterator constructs a new addressIterator. +func newAddressIterator(addrs ...net.Addr) (*addressIterator, error) { + if len(addrs) == 0 { + return nil, fmt.Errorf("must have at least one address") + } + + iter := &addressIterator{ + addrList: list.New(), + candidates: make(map[string]*candidateAddr), + } + + for _, addr := range addrs { + addrID := addr.String() + iter.addrList.PushBack(addrID) + iter.candidates[addrID] = &candidateAddr{addr: addr} + } + iter.Reset() + + return iter, nil +} + +// Reset clears the iterators state, and makes the address at the front of the +// list the next item to be returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Reset() { + a.mu.Lock() + defer a.mu.Unlock() + + a.unsafeReset() +} + +// unsafeReset clears the iterator state and makes the address at the front of +// the list the next item to be returned. +// +// NOTE: this method is not thread safe and so should only be called if the +// appropriate mutex is being held. +func (a *addressIterator) unsafeReset() { + // Reset the next candidate to the front of the linked-list. + a.currentTopAddr = a.addrList.Front() +} + +// Next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Next() (net.Addr, error) { + return a.next(false) +} + +// NextAndLock does the same as described for Next, and it also places a lock on +// the returned address so that the address can not be removed until the lock on +// it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) NextAndLock() (net.Addr, error) { + return a.next(true) +} + +// next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +func (a *addressIterator) next(lock bool) (net.Addr, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // Set the next candidate to the subsequent element. + a.currentTopAddr = a.currentTopAddr.Next() + + for a.currentTopAddr != nil { + // Propose the address at the front of the list. + addrID := a.currentTopAddr.Value.(string) + + // Check whether this address is still considered a candidate. + // If it's not, we'll proceed to the next. + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr, nil + } + + return nil, ErrAddressesExhausted +} + +// Peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Peek() net.Addr { + return a.peek(false) +} + +// PeekAndLock does the same as described for Peek, and it also places a lock on +// the returned address so that the address can not be removed until the lock +// on it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) PeekAndLock() net.Addr { + return a.peek(true) +} + +// peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. If lock is set to +// true, the address will be locked for removal until ReleaseLock has been +// called for the address. +func (a *addressIterator) peek(lock bool) net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + for { + // If currentTopAddr is nil, it means we have reached the end of + // the list, so we reset it here. The iterator always has at + // least one address, so we can be sure that currentTopAddr will + // be non-nil after calling reset here. + if a.currentTopAddr == nil { + a.unsafeReset() + } + + addrID := a.currentTopAddr.Value.(string) + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr + } +} + +// ReleaseLock releases the lock held on the given address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) ReleaseLock(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + candidateAddr, ok := a.candidates[addr.String()] + if !ok { + return + } + + if candidateAddr.numLocks == 0 { + return + } + + candidateAddr.numLocks-- + a.totalLockCount-- +} + +// Add adds a new address to the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Add(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.candidates[addr.String()]; ok { + return + } + + a.addrList.PushBack(addr.String()) + a.candidates[addr.String()] = &candidateAddr{addr: addr} + + // If we've reached the end of our queue, then this candidate + // will become the next. + if a.currentTopAddr == nil { + a.currentTopAddr = a.addrList.Back() + } +} + +// Remove removes an existing address from the iterator. It disallows the +// address from being removed if it is the last address in the iterator or if +// there is currently a lock on the address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Remove(addr net.Addr) error { + a.mu.Lock() + defer a.mu.Unlock() + + candidate, ok := a.candidates[addr.String()] + if !ok { + return nil + } + + if len(a.candidates) == 1 { + return wtdb.ErrLastTowerAddr + } + + if candidate.numLocks > 0 { + return ErrAddrInUse + } + + delete(a.candidates, addr.String()) + return nil +} + +// HasLocked returns true if the addressIterator has any locked addresses. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) HasLocked() bool { + a.mu.Lock() + defer a.mu.Unlock() + + return a.totalLockCount > 0 +} + +// GetAll returns a copy of all the addresses in the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) GetAll() []net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + var addrs []net.Addr + cursor := a.addrList.Front() + + for cursor != nil { + addrID := cursor.Value.(string) + + addr, ok := a.candidates[addrID] + if !ok { + cursor = cursor.Next() + continue + } + + addrs = append(addrs, addr.addr) + cursor = cursor.Next() + } + + return addrs +} diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go new file mode 100644 index 000000000..d3674d985 --- /dev/null +++ b/watchtower/wtclient/addr_iterator_test.go @@ -0,0 +1,188 @@ +package wtclient + +import ( + "net" + "testing" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" +) + +// TestAddrIterator tests the behaviour of the addressIterator. +func TestAddrIterator(t *testing.T) { + // Assert that an iterator can't be initialised with an empty address + // list. + _, err := newAddressIterator() + require.ErrorContains(t, err, "must have at least one address") + + addr1, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8000") + require.NoError(t, err) + + // Initialise the iterator with addr1. + iter, err := newAddressIterator(addr1) + require.NoError(t, err) + + // Attempting to remove addr1 should fail now since it is the only + // address in the iterator. + iter.Add(addr1) + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + // Adding a duplicate of addr1 and then calling Remove should still + // return an error. + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + addr2, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8001") + require.NoError(t, err) + + // Add addr2 to the iterator. + iter.Add(addr2) + + // Check that peek returns addr1. + a1 := iter.Peek() + require.NoError(t, err) + require.Equal(t, addr1, a1) + + // Calling peek multiple times should return the same result. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Calling Next should now return addr2. + a2, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Assert that Peek now returns addr2. + a2 = iter.Peek() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Calling Next should result in reaching the end of th list. + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + // Calling Peek now should reset the queue and return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Wind the list to the end again so that we can test the Reset func. + _, err = iter.Next() + require.NoError(t, err) + + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + iter.Reset() + + // Now Next should return addr 2. + a2, err = iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + addr3, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8002") + require.NoError(t, err) + + // Add addr3 now to ensure that the iteration works even if we are + // midway through the queue. + iter.Add(addr3) + + // Now Next should return addr 3. + a3, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr3, a3) + + // Quickly test that GetAll correctly returns a copy of all the + // addresses in the iterator. + addrList := iter.GetAll() + require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3}) + + // Let's now remove addr3. + err = iter.Remove(addr3) + require.NoError(t, err) + + // Since addr3 is gone, Peek should return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Lastly, we will test the "locking" of addresses. + + // First we test the locking of an address via the PeekAndLock function. + a1 = iter.PeekAndLock() + require.Equal(t, addr1, a1) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr1 if there is a lock on it. + err = iter.Remove(addr1) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr1. + iter.ReleaseLock(addr1) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr1) + require.NoError(t, err) + + // Now we test the locking of an address via the NextAndLock function. + // To do this, we first re-add addr3. + iter.Add(addr3) + + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr2 if there is a lock on it. + err = iter.Remove(addr2) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr2. + iter.ReleaseLock(addr2) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr2) + require.NoError(t, err) + + // Only addr3 should still be left in the iterator. + addrList = iter.GetAll() + require.Len(t, addrList, 1) + require.Contains(t, addrList, addr3) + + // Ensure that HasLocked acts correctly in the case where more than one + // address is being locked and unlock as well as the case where the same + // address is locked more than once. + + require.False(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + iter.Add(addr2) + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Now release addr2 and asset that HasLock is still true. + iter.ReleaseLock(addr2) + require.True(t, iter.HasLocked()) + + // Releasing one of the locks on addr3 now should still result in + // HasLocked returning true. + iter.ReleaseLock(addr3) + require.True(t, iter.HasLocked()) + + // Releasing it again should now result in should still result in + // HasLocked returning false. + iter.ReleaseLock(addr3) + require.False(t, iter.HasLocked()) +}