mirror of https://github.com/lightningnetwork/lnd.git (synced 2025-02-23 14:40:30 +01:00)

Merge pull request #7025 from ellemouton/cantRemoveTowerDuringNegotiation

watchtower: introduce an AddressIterator

Commit: dfc526b8b4

17 changed files with 1201 additions and 493 deletions
docs/release-notes/release-notes-0.16.0.md

@@ -89,6 +89,9 @@ https://github.com/lightningnetwork/lnd/pull/6963/)
 * [Fixed a flake in the TestBlockCacheMutexes unit
   test](https://github.com/lightningnetwork/lnd/pull/7029).
 
+* [Create a helper function to wait for peer to come
+  online](https://github.com/lightningnetwork/lnd/pull/6931).
+
 ## `lncli`
 
 * [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice
   flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the
@@ -119,6 +122,12 @@ https://github.com/lightningnetwork/lnd/pull/6963/)
   caller is expected to know that doing so with untrusted input is
   unsafe.](https://github.com/lightningnetwork/lnd/pull/6779)
 
+* [test: replace defer cleanup with
+  `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864).
+
+* [test: fix loop variables being accessed in
+  closures](https://github.com/lightningnetwork/lnd/pull/7032).
+
 ## Watchtowers
 
 * [Create a towerID-to-sessionID index in the wtclient DB to improve the
@@ -131,14 +140,10 @@ https://github.com/lightningnetwork/lnd/pull/6963/)
   struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to
   improve the performance of fetching a `ClientSession` from the DB.
 
-* [Create a helper function to wait for peer to come
-  online](https://github.com/lightningnetwork/lnd/pull/6931).
-
-* [test: replace defer cleanup with
-  `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864).
-
-* [test: fix loop variables being accessed in
-  closures](https://github.com/lightningnetwork/lnd/pull/7032).
+* [Allow user to update tower address without requiring a restart. Also allow
+  the removal of a tower address if the current session negotiation is not
+  using the address in question](
+  https://github.com/lightningnetwork/lnd/pull/7025)
 
 ### Tooling and documentation
watchtower/wtclient/addr_iterator.go (new file, 344 lines)

@@ -0,0 +1,344 @@
package wtclient

import (
	"container/list"
	"errors"
	"fmt"
	"net"
	"sync"

	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

var (
	// ErrAddressesExhausted signals that an addressIterator has cycled
	// through all available addresses.
	ErrAddressesExhausted = errors.New("exhausted all addresses")

	// ErrAddrInUse indicates that an address is locked and cannot be
	// removed from the addressIterator.
	ErrAddrInUse = errors.New("address in use")
)

// AddressIterator handles iteration over a list of addresses. It strictly
// disallows the list of addresses it holds to be empty. It also allows callers
// to place locks on certain addresses in order to prevent other callers from
// removing the addresses in question from the iterator.
type AddressIterator interface {
	// Next returns the next candidate address. This iterator will always
	// return candidates in the order given when the iterator was
	// instantiated. If no more candidates are available,
	// ErrAddressesExhausted is returned.
	Next() (net.Addr, error)

	// NextAndLock does the same as described for Next, and it also places
	// a lock on the returned address so that the address can not be
	// removed until the lock on it has been released via ReleaseLock.
	NextAndLock() (net.Addr, error)

	// Peek returns the currently selected address in the iterator. If the
	// end of the iterator has been reached then it is reset and the first
	// item in the iterator is returned. Since the AddressIterator will
	// never have an empty address list, this function will never return a
	// nil value.
	Peek() net.Addr

	// PeekAndLock does the same as described for Peek, and it also places
	// a lock on the returned address so that the address can not be
	// removed until the lock on it has been released via ReleaseLock.
	PeekAndLock() net.Addr

	// ReleaseLock releases the lock held on the given address.
	ReleaseLock(addr net.Addr)

	// Add adds a new address to the iterator.
	Add(addr net.Addr)

	// Remove removes an existing address from the iterator. It disallows
	// the address from being removed if it is the last address in the
	// iterator or if there is currently a lock on the address.
	Remove(addr net.Addr) error

	// HasLocked returns true if the addressIterator has any locked
	// addresses.
	HasLocked() bool

	// GetAll returns a copy of all the addresses in the iterator.
	GetAll() []net.Addr

	// Reset clears the iterator's state, and makes the address at the
	// front of the list the next item to be returned.
	Reset()
}

// A compile-time check to ensure that addressIterator implements the
// AddressIterator interface.
var _ AddressIterator = (*addressIterator)(nil)

// addressIterator is a linked-list implementation of an AddressIterator.
type addressIterator struct {
	mu             sync.Mutex
	addrList       *list.List
	currentTopAddr *list.Element
	candidates     map[string]*candidateAddr
	totalLockCount int
}

type candidateAddr struct {
	addr     net.Addr
	numLocks int
}

// newAddressIterator constructs a new addressIterator.
func newAddressIterator(addrs ...net.Addr) (*addressIterator, error) {
	if len(addrs) == 0 {
		return nil, fmt.Errorf("must have at least one address")
	}

	iter := &addressIterator{
		addrList:   list.New(),
		candidates: make(map[string]*candidateAddr),
	}

	for _, addr := range addrs {
		addrID := addr.String()
		iter.addrList.PushBack(addrID)
		iter.candidates[addrID] = &candidateAddr{addr: addr}
	}
	iter.Reset()

	return iter, nil
}

// Reset clears the iterator's state, and makes the address at the front of the
// list the next item to be returned.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) Reset() {
	a.mu.Lock()
	defer a.mu.Unlock()

	a.unsafeReset()
}

// unsafeReset clears the iterator state and makes the address at the front of
// the list the next item to be returned.
//
// NOTE: this method is not thread safe and so should only be called if the
// appropriate mutex is being held.
func (a *addressIterator) unsafeReset() {
	// Reset the next candidate to the front of the linked-list.
	a.currentTopAddr = a.addrList.Front()
}

// Next returns the next candidate address. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrAddressesExhausted is returned.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) Next() (net.Addr, error) {
	return a.next(false)
}

// NextAndLock does the same as described for Next, and it also places a lock on
// the returned address so that the address can not be removed until the lock on
// it has been released via ReleaseLock.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) NextAndLock() (net.Addr, error) {
	return a.next(true)
}

// next returns the next candidate address. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrAddressesExhausted is returned.
func (a *addressIterator) next(lock bool) (net.Addr, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// Set the next candidate to the subsequent element.
	a.currentTopAddr = a.currentTopAddr.Next()

	for a.currentTopAddr != nil {
		// Propose the address at the front of the list.
		addrID := a.currentTopAddr.Value.(string)

		// Check whether this address is still considered a candidate.
		// If it's not, we'll proceed to the next.
		candidate, ok := a.candidates[addrID]
		if !ok {
			nextCandidate := a.currentTopAddr.Next()
			a.addrList.Remove(a.currentTopAddr)
			a.currentTopAddr = nextCandidate
			continue
		}

		if lock {
			candidate.numLocks++
			a.totalLockCount++
		}

		return candidate.addr, nil
	}

	return nil, ErrAddressesExhausted
}

// Peek returns the currently selected address in the iterator. If the end of
// the list has been reached then the iterator is reset and the first item in
// the list is returned. Since the addressIterator will never have an empty
// address list, this function will never return a nil value.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) Peek() net.Addr {
	return a.peek(false)
}

// PeekAndLock does the same as described for Peek, and it also places a lock on
// the returned address so that the address can not be removed until the lock
// on it has been released via ReleaseLock.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) PeekAndLock() net.Addr {
	return a.peek(true)
}

// peek returns the currently selected address in the iterator. If the end of
// the list has been reached then the iterator is reset and the first item in
// the list is returned. Since the addressIterator will never have an empty
// address list, this function will never return a nil value. If lock is set to
// true, the address will be locked for removal until ReleaseLock has been
// called for the address.
func (a *addressIterator) peek(lock bool) net.Addr {
	a.mu.Lock()
	defer a.mu.Unlock()

	for {
		// If currentTopAddr is nil, it means we have reached the end of
		// the list, so we reset it here. The iterator always has at
		// least one address, so we can be sure that currentTopAddr will
		// be non-nil after calling reset here.
		if a.currentTopAddr == nil {
			a.unsafeReset()
		}

		addrID := a.currentTopAddr.Value.(string)
		candidate, ok := a.candidates[addrID]
		if !ok {
			nextCandidate := a.currentTopAddr.Next()
			a.addrList.Remove(a.currentTopAddr)
			a.currentTopAddr = nextCandidate
			continue
		}

		if lock {
			candidate.numLocks++
			a.totalLockCount++
		}

		return candidate.addr
	}
}

// ReleaseLock releases the lock held on the given address.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) ReleaseLock(addr net.Addr) {
	a.mu.Lock()
	defer a.mu.Unlock()

	candidateAddr, ok := a.candidates[addr.String()]
	if !ok {
		return
	}

	if candidateAddr.numLocks == 0 {
		return
	}

	candidateAddr.numLocks--
	a.totalLockCount--
}

// Add adds a new address to the iterator.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) Add(addr net.Addr) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if _, ok := a.candidates[addr.String()]; ok {
		return
	}

	a.addrList.PushBack(addr.String())
	a.candidates[addr.String()] = &candidateAddr{addr: addr}

	// If we've reached the end of our queue, then this candidate
	// will become the next.
	if a.currentTopAddr == nil {
		a.currentTopAddr = a.addrList.Back()
	}
}

// Remove removes an existing address from the iterator. It disallows the
// address from being removed if it is the last address in the iterator or if
// there is currently a lock on the address.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) Remove(addr net.Addr) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	candidate, ok := a.candidates[addr.String()]
	if !ok {
		return nil
	}

	if len(a.candidates) == 1 {
		return wtdb.ErrLastTowerAddr
	}

	if candidate.numLocks > 0 {
		return ErrAddrInUse
	}

	delete(a.candidates, addr.String())
	return nil
}

// HasLocked returns true if the addressIterator has any locked addresses.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) HasLocked() bool {
	a.mu.Lock()
	defer a.mu.Unlock()

	return a.totalLockCount > 0
}

// GetAll returns a copy of all the addresses in the iterator.
//
// NOTE: This is part of the AddressIterator interface.
func (a *addressIterator) GetAll() []net.Addr {
	a.mu.Lock()
	defer a.mu.Unlock()

	var addrs []net.Addr
	cursor := a.addrList.Front()

	for cursor != nil {
		addrID := cursor.Value.(string)

		addr, ok := a.candidates[addrID]
		if !ok {
			cursor = cursor.Next()
			continue
		}

		addrs = append(addrs, addr.addr)
		cursor = cursor.Next()
	}

	return addrs
}
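For orientation, here is a minimal usage sketch, not part of this PR, showing how a caller might combine the lock-taking variants with ReleaseLock so that the address currently being dialled cannot be removed out from under it. The tryDial callback is a hypothetical stand-in; only the AddressIterator methods above are real.

// dialFirstReachable is an illustrative helper, NOT part of this diff.
// It walks the iterator's candidates in order while holding a removal
// lock on whichever address is currently being attempted. tryDial is a
// hypothetical stand-in for a real connection attempt.
func dialFirstReachable(iter AddressIterator,
	tryDial func(net.Addr) error) error {

	// Peek never returns nil: the iterator is guaranteed non-empty.
	addr := iter.PeekAndLock()
	for {
		dialErr := tryDial(addr)

		// The attempt is over, so the address may be unlocked and
		// become removable again.
		iter.ReleaseLock(addr)

		if dialErr == nil {
			return nil
		}

		// Advance to the next candidate, locking it for the duration
		// of the next attempt. Only ErrAddressesExhausted can be
		// returned here.
		var err error
		addr, err = iter.NextAndLock()
		if err != nil {
			return err
		}
	}
}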
watchtower/wtclient/addr_iterator_test.go (new file, 188 lines)

@@ -0,0 +1,188 @@
package wtclient

import (
	"net"
	"testing"

	"github.com/lightningnetwork/lnd/watchtower/wtdb"
	"github.com/stretchr/testify/require"
)

// TestAddrIterator tests the behaviour of the addressIterator.
func TestAddrIterator(t *testing.T) {
	// Assert that an iterator can't be initialised with an empty address
	// list.
	_, err := newAddressIterator()
	require.ErrorContains(t, err, "must have at least one address")

	addr1, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8000")
	require.NoError(t, err)

	// Initialise the iterator with addr1.
	iter, err := newAddressIterator(addr1)
	require.NoError(t, err)

	// Attempting to remove addr1 should fail now since it is the only
	// address in the iterator.
	err = iter.Remove(addr1)
	require.ErrorIs(t, err, wtdb.ErrLastTowerAddr)

	// Adding a duplicate of addr1 and then calling Remove should still
	// return an error.
	iter.Add(addr1)
	err = iter.Remove(addr1)
	require.ErrorIs(t, err, wtdb.ErrLastTowerAddr)

	addr2, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8001")
	require.NoError(t, err)

	// Add addr2 to the iterator.
	iter.Add(addr2)

	// Check that Peek returns addr1.
	a1 := iter.Peek()
	require.Equal(t, addr1, a1)

	// Calling Peek multiple times should return the same result.
	a1 = iter.Peek()
	require.Equal(t, addr1, a1)

	// Calling Next should now return addr2.
	a2, err := iter.Next()
	require.NoError(t, err)
	require.Equal(t, addr2, a2)

	// Assert that Peek now returns addr2.
	a2 = iter.Peek()
	require.Equal(t, addr2, a2)

	// Calling Next should result in reaching the end of the list.
	_, err = iter.Next()
	require.ErrorIs(t, err, ErrAddressesExhausted)

	// Calling Peek now should reset the queue and return addr1.
	a1 = iter.Peek()
	require.Equal(t, addr1, a1)

	// Wind the list to the end again so that we can test the Reset func.
	_, err = iter.Next()
	require.NoError(t, err)

	_, err = iter.Next()
	require.ErrorIs(t, err, ErrAddressesExhausted)

	iter.Reset()

	// Now Next should return addr2.
	a2, err = iter.Next()
	require.NoError(t, err)
	require.Equal(t, addr2, a2)

	addr3, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8002")
	require.NoError(t, err)

	// Add addr3 now to ensure that the iteration works even if we are
	// midway through the queue.
	iter.Add(addr3)

	// Now Next should return addr3.
	a3, err := iter.Next()
	require.NoError(t, err)
	require.Equal(t, addr3, a3)

	// Quickly test that GetAll correctly returns a copy of all the
	// addresses in the iterator.
	addrList := iter.GetAll()
	require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3})

	// Let's now remove addr3.
	err = iter.Remove(addr3)
	require.NoError(t, err)

	// Since addr3 is gone, Peek should return addr1.
	a1 = iter.Peek()
	require.Equal(t, addr1, a1)

	// Lastly, we will test the "locking" of addresses.

	// First we test the locking of an address via the PeekAndLock
	// function.
	a1 = iter.PeekAndLock()
	require.Equal(t, addr1, a1)
	require.True(t, iter.HasLocked())

	// Assert that we can't remove addr1 if there is a lock on it.
	err = iter.Remove(addr1)
	require.ErrorIs(t, err, ErrAddrInUse)

	// Now release the lock on addr1.
	iter.ReleaseLock(addr1)
	require.False(t, iter.HasLocked())

	// Since the lock has been released, we should now be able to remove
	// addr1.
	err = iter.Remove(addr1)
	require.NoError(t, err)

	// Now we test the locking of an address via the NextAndLock function.
	// To do this, we first re-add addr3.
	iter.Add(addr3)

	a2, err = iter.NextAndLock()
	require.NoError(t, err)
	require.Equal(t, addr2, a2)
	require.True(t, iter.HasLocked())

	// Assert that we can't remove addr2 if there is a lock on it.
	err = iter.Remove(addr2)
	require.ErrorIs(t, err, ErrAddrInUse)

	// Now release the lock on addr2.
	iter.ReleaseLock(addr2)
	require.False(t, iter.HasLocked())

	// Since the lock has been released, we should now be able to remove
	// addr2.
	err = iter.Remove(addr2)
	require.NoError(t, err)

	// Only addr3 should still be left in the iterator.
	addrList = iter.GetAll()
	require.Len(t, addrList, 1)
	require.Contains(t, addrList, addr3)

	// Ensure that HasLocked acts correctly in the case where more than one
	// address is locked and unlocked, as well as the case where the same
	// address is locked more than once.
	require.False(t, iter.HasLocked())

	a3 = iter.PeekAndLock()
	require.Equal(t, addr3, a3)
	require.True(t, iter.HasLocked())

	a3 = iter.PeekAndLock()
	require.Equal(t, addr3, a3)
	require.True(t, iter.HasLocked())

	iter.Add(addr2)
	a2, err = iter.NextAndLock()
	require.NoError(t, err)
	require.Equal(t, addr2, a2)
	require.True(t, iter.HasLocked())

	// Now release addr2 and assert that HasLocked still returns true.
	iter.ReleaseLock(addr2)
	require.True(t, iter.HasLocked())

	// Releasing one of the locks on addr3 now should still result in
	// HasLocked returning true.
	iter.ReleaseLock(addr3)
	require.True(t, iter.HasLocked())

	// Releasing it again should now result in HasLocked returning false.
	iter.ReleaseLock(addr3)
	require.False(t, iter.HasLocked())
}
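One subtlety the last part of this test exercises is that address locks are reference counted rather than boolean. A compressed sketch of that behaviour, assuming an iter and an addr of type net.Addr that it already holds:

// Locks are counted per address: each PeekAndLock/NextAndLock call
// increments the count, and each ReleaseLock decrements it.
_ = iter.PeekAndLock() // lock count for addr: 1
_ = iter.PeekAndLock() // lock count for addr: 2

iter.ReleaseLock(addr) // count: 1, Remove(addr) still returns ErrAddrInUse
iter.ReleaseLock(addr) // count: 0, addr is removable again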
watchtower/wtclient/backup_task_internal_test.go

@@ -2,9 +2,6 @@ package wtclient
 
 import (
 	"bytes"
-	"crypto/rand"
-	"io"
-	"reflect"
 	"testing"
 
 	"github.com/btcsuite/btcd/btcec/v2"
@@ -12,7 +9,6 @@ import (
 	"github.com/btcsuite/btcd/chaincfg"
 	"github.com/btcsuite/btcd/txscript"
 	"github.com/btcsuite/btcd/wire"
-	"github.com/davecgh/go-spew/spew"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/input"
 	"github.com/lightningnetwork/lnd/keychain"
@@ -54,14 +50,6 @@ var (
 	}
 )
 
-func makeAddrSlice(size int) []byte {
-	addr := make([]byte, size)
-	if _, err := io.ReadFull(rand.Reader, addr); err != nil {
-		panic("cannot make addr")
-	}
-	return addr
-}
-
 type backupTaskTest struct {
 	name   string
 	chanID lnwire.ChannelID
@@ -502,35 +490,12 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 
 	// Assert that all parameters set during initialization are properly
 	// populated.
-	if task.id.ChanID != test.chanID {
-		t.Fatalf("channel id mismatch, want: %s, got: %s",
-			test.chanID, task.id.ChanID)
-	}
-
-	if task.id.CommitHeight != test.breachInfo.RevokedStateNum {
-		t.Fatalf("commit height mismatch, want: %d, got: %d",
-			test.breachInfo.RevokedStateNum, task.id.CommitHeight)
-	}
-
-	if task.totalAmt != test.expTotalAmt {
-		t.Fatalf("total amount mismatch, want: %d, got: %v",
-			test.expTotalAmt, task.totalAmt)
-	}
-
-	if !reflect.DeepEqual(task.breachInfo, test.breachInfo) {
-		t.Fatalf("breach info mismatch, want: %v, got: %v",
-			test.breachInfo, task.breachInfo)
-	}
-
-	if !reflect.DeepEqual(task.toLocalInput, test.expToLocalInput) {
-		t.Fatalf("to-local input mismatch, want: %v, got: %v",
-			test.expToLocalInput, task.toLocalInput)
-	}
-
-	if !reflect.DeepEqual(task.toRemoteInput, test.expToRemoteInput) {
-		t.Fatalf("to-local input mismatch, want: %v, got: %v",
-			test.expToRemoteInput, task.toRemoteInput)
-	}
+	require.Equal(t, test.chanID, task.id.ChanID)
+	require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
+	require.Equal(t, test.expTotalAmt, task.totalAmt)
+	require.Equal(t, test.breachInfo, task.breachInfo)
+	require.Equal(t, test.expToLocalInput, task.toLocalInput)
+	require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
 
 	// Reconstruct the expected input.Inputs that will be returned by the
 	// task's inputs() method.
@@ -545,34 +510,24 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	// Assert that the inputs method returns the correct slice of
 	// input.Inputs.
 	inputs := task.inputs()
-	if !reflect.DeepEqual(expInputs, inputs) {
-		t.Fatalf("inputs mismatch, want: %v, got: %v",
-			expInputs, inputs)
-	}
+	require.Equal(t, expInputs, inputs)
 
 	// Now, bind the session to the task. If successful, this locks in the
 	// session's negotiated parameters and allows the backup task to derive
 	// the final free variables in the justice transaction.
 	err := task.bindSession(test.session)
-	if err != test.bindErr {
-		t.Fatalf("expected: %v when binding session, got: %v",
-			test.bindErr, err)
-	}
+	require.ErrorIs(t, err, test.bindErr)
 
 	// Exit early if the bind was supposed to fail. But first, we check that
 	// all fields set during a bind are still unset. This ensure that a
 	// failed bind doesn't have side-effects if the task is retried with a
 	// different session.
 	if test.bindErr != nil {
-		if task.blobType != 0 {
-			t.Fatalf("blob type should not be set on failed bind, "+
-				"found: %s", task.blobType)
-		}
+		require.Zerof(t, task.blobType, "blob type should not be set "+
			"on failed bind, found: %s", task.blobType)
 
-		if task.outputs != nil {
-			t.Fatalf("justice outputs should not be set on failed bind, "+
-				"found: %v", task.outputs)
-		}
+		require.Nilf(t, task.outputs, "justice outputs should not be "+
			"set on failed bind, found: %v", task.outputs)
 
 		return
 	}
@@ -580,10 +535,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	// Otherwise, the binding succeeded. Assert that all values set during
 	// the bind are properly populated.
 	policy := test.session.Policy
-	if task.blobType != policy.BlobType {
-		t.Fatalf("blob type mismatch, want: %s, got %s",
-			policy.BlobType, task.blobType)
-	}
+	require.Equal(t, policy.BlobType, task.blobType)
 
 	// Compute the expected outputs on the justice transaction.
 	var expOutputs = []*wire.TxOut{
@@ -603,10 +555,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	}
 
 	// Assert that the computed outputs match our expected outputs.
-	if !reflect.DeepEqual(expOutputs, task.outputs) {
-		t.Fatalf("justice txn output mismatch, want: %v,\ngot: %v",
-			spew.Sdump(expOutputs), spew.Sdump(task.outputs))
-	}
+	require.Equal(t, expOutputs, task.outputs)
 
 	// Now, we'll construct, sign, and encrypt the blob containing the parts
 	// needed to reconstruct the justice transaction.
@@ -616,10 +565,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	// Verify that the breach hint matches the breach txid's prefix.
 	breachTxID := test.breachInfo.BreachTxHash
 	expHint := blob.NewBreachHintFromHash(&breachTxID)
-	if hint != expHint {
-		t.Fatalf("breach hint mismatch, want: %x, got: %v",
-			expHint, hint)
-	}
+	require.Equal(t, expHint, hint)
 
 	// Decrypt the return blob to obtain the JusticeKit containing its
 	// contents.
@@ -634,14 +580,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 
 	// Assert that the blob contained the serialized revocation and to-local
 	// pubkeys.
-	if !bytes.Equal(jKit.RevocationPubKey[:], expRevPK) {
-		t.Fatalf("revocation pk mismatch, want: %x, got: %x",
-			expRevPK, jKit.RevocationPubKey[:])
-	}
-	if !bytes.Equal(jKit.LocalDelayPubKey[:], expToLocalPK) {
-		t.Fatalf("revocation pk mismatch, want: %x, got: %x",
-			expToLocalPK, jKit.LocalDelayPubKey[:])
-	}
+	require.Equal(t, expRevPK, jKit.RevocationPubKey[:])
+	require.Equal(t, expToLocalPK, jKit.LocalDelayPubKey[:])
 
 	// Determine if the breach transaction has a to-remote output and/or
 	// to-local output to spend from. Note the seemingly-reversed
@@ -650,32 +590,19 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	hasToLocal := test.breachInfo.RemoteOutputSignDesc != nil
 
 	// If the to-remote output is present, assert that the to-remote public
-	// key was included in the blob.
-	if hasToRemote &&
-		!bytes.Equal(jKit.CommitToRemotePubKey[:], expToRemotePK) {
-		t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x",
-			expToRemotePK, jKit.CommitToRemotePubKey)
-	}
-
-	// Otherwise if the to-local output is not present, assert that a blank
-	// public key was inserted.
-	if !hasToRemote &&
-		!bytes.Equal(jKit.CommitToRemotePubKey[:], zeroPK[:]) {
-		t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x",
-			zeroPK, jKit.CommitToRemotePubKey)
-	}
+	// key was included in the blob. Otherwise assert that a blank public
+	// key was inserted.
+	if hasToRemote {
+		require.Equal(t, expToRemotePK, jKit.CommitToRemotePubKey[:])
+	} else {
+		require.Equal(t, zeroPK[:], jKit.CommitToRemotePubKey[:])
+	}
 
 	// Assert that the CSV is encoded in the blob.
-	if jKit.CSVDelay != test.breachInfo.RemoteDelay {
-		t.Fatalf("mismatch remote delay, want: %d, got: %v",
-			test.breachInfo.RemoteDelay, jKit.CSVDelay)
-	}
+	require.Equal(t, test.breachInfo.RemoteDelay, jKit.CSVDelay)
 
 	// Assert that the sweep pkscript is included.
-	if !bytes.Equal(jKit.SweepAddress, test.expSweepScript) {
-		t.Fatalf("sweep pkscript mismatch, want: %x, got: %x",
-			test.expSweepScript, jKit.SweepAddress)
-	}
+	require.Equal(t, test.expSweepScript, jKit.SweepAddress)
 
 	// Finally, verify that the signatures are encoded in the justice kit.
 	// We don't validate the actual signatures produced here, since at the
@@ -684,18 +611,20 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
 	// TODO(conner): include signature validation checks
 
 	emptyToLocalSig := bytes.Equal(jKit.CommitToLocalSig[:], zeroSig[:])
-	switch {
-	case hasToLocal && emptyToLocalSig:
-		t.Fatalf("to-local signature should not be empty")
-	case !hasToLocal && !emptyToLocalSig:
-		t.Fatalf("to-local signature should be empty")
+	if hasToLocal {
+		require.False(t, emptyToLocalSig, "to-local signature should "+
			"not be empty")
+	} else {
+		require.True(t, emptyToLocalSig, "to-local signature should "+
			"be empty")
 	}
 
 	emptyToRemoteSig := bytes.Equal(jKit.CommitToRemoteSig[:], zeroSig[:])
-	switch {
-	case hasToRemote && emptyToRemoteSig:
-		t.Fatalf("to-remote signature should not be empty")
-	case !hasToRemote && !emptyToRemoteSig:
-		t.Fatalf("to-remote signature should be empty")
+	if hasToRemote {
+		require.False(t, emptyToRemoteSig, "to-remote signature "+
			"should not be empty")
+	} else {
+		require.True(t, emptyToRemoteSig, "to-remote signature "+
			"should be empty")
 	}
 }
watchtower/wtclient/candidate_iterator.go

@@ -13,7 +13,7 @@ import (
 type TowerCandidateIterator interface {
 	// AddCandidate adds a new candidate tower to the iterator. If the
 	// candidate already exists, then any new addresses are added to it.
-	AddCandidate(*wtdb.Tower)
+	AddCandidate(*Tower)
 
 	// RemoveCandidate removes an existing candidate tower from the
 	// iterator. An optional address can be provided to indicate a stale
@@ -32,7 +32,7 @@ type TowerCandidateIterator interface {
 	// Next returns the next candidate tower. The iterator is not required
 	// to return results in any particular order. If no more candidates are
 	// available, ErrTowerCandidatesExhausted is returned.
-	Next() (*wtdb.Tower, error)
+	Next() (*Tower, error)
 }
 
 // towerListIterator is a linked-list backed TowerCandidateIterator.
@@ -40,7 +40,7 @@ type towerListIterator struct {
 	mu            sync.Mutex
 	queue         *list.List
 	nextCandidate *list.Element
-	candidates    map[wtdb.TowerID]*wtdb.Tower
+	candidates    map[wtdb.TowerID]*Tower
 }
 
 // Compile-time constraint to ensure *towerListIterator implements the
@@ -49,10 +49,10 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil)
 
 // newTowerListIterator initializes a new towerListIterator from a variadic list
 // of lnwire.NetAddresses.
-func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator {
+func newTowerListIterator(candidates ...*Tower) *towerListIterator {
 	iter := &towerListIterator{
 		queue:      list.New(),
-		candidates: make(map[wtdb.TowerID]*wtdb.Tower),
+		candidates: make(map[wtdb.TowerID]*Tower),
 	}
 
 	for _, candidate := range candidates {
@@ -79,7 +79,7 @@ func (t *towerListIterator) Reset() error {
 // Next returns the next candidate tower. This iterator will always return
 // candidates in the order given when the iterator was instantiated. If no more
 // candidates are available, ErrTowerCandidatesExhausted is returned.
-func (t *towerListIterator) Next() (*wtdb.Tower, error) {
+func (t *towerListIterator) Next() (*Tower, error) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
@@ -107,7 +107,7 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) {
 
 // AddCandidate adds a new candidate tower to the iterator. If the candidate
 // already exists, then any new addresses are added to it.
-func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) {
+func (t *towerListIterator) AddCandidate(candidate *Tower) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
@@ -121,8 +121,16 @@ func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) {
 			t.nextCandidate = t.queue.Back()
 		}
 	} else {
-		for _, addr := range candidate.Addresses {
-			tower.AddAddress(addr)
+		candidate.Addresses.Reset()
+		firstAddr := candidate.Addresses.Peek()
+		tower.Addresses.Add(firstAddr)
+		for {
+			next, err := candidate.Addresses.Next()
+			if err != nil {
+				break
+			}
+
+			tower.Addresses.Add(next)
 		}
 	}
 }
@@ -142,11 +150,15 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID,
 		return nil
 	}
 	if addr != nil {
-		tower.RemoveAddress(addr)
-		if len(tower.Addresses) == 0 {
-			return wtdb.ErrLastTowerAddr
+		err := tower.Addresses.Remove(addr)
+		if err != nil {
+			return err
 		}
 	} else {
+		if tower.Addresses.HasLocked() {
+			return ErrAddrInUse
+		}
+
 		delete(t.candidates, candidate)
 	}
 
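The Reset/Peek/Next dance in the new AddCandidate branch is deliberate: next() advances the cursor before reading it, so after a Reset the front element is only reachable via Peek. A small sketch of that semantic, where a1 and a2 are placeholder net.Addr values:

// Why AddCandidate pairs Peek with a Next loop: Next moves past the
// current element before returning, so the front of a freshly reset
// iterator would be skipped if only Next were used.
it, _ := newAddressIterator(a1, a2)
it.Reset()

front := it.Peek()     // returns a1, the front element
second, _ := it.Next() // returns a2: Next advances first, then returns

_, _ = front, second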
watchtower/wtclient/candidate_iterator_test.go

@@ -4,12 +4,10 @@ import (
 	"encoding/binary"
 	"math/rand"
 	"net"
-	"reflect"
 	"testing"
 	"time"
 
 	"github.com/btcsuite/btcd/btcec/v2"
-	"github.com/davecgh/go-spew/spew"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb"
 	"github.com/stretchr/testify/require"
 )
@@ -19,66 +17,75 @@ func init() {
 }
 
 func randAddr(t *testing.T) net.Addr {
-	var ip [4]byte
-	if _, err := rand.Read(ip[:]); err != nil {
-		t.Fatal(err)
-	}
-	var port [2]byte
-	if _, err := rand.Read(port[:]); err != nil {
-		t.Fatal(err)
-	}
+	t.Helper()
+
+	var ip [4]byte
+	_, err := rand.Read(ip[:])
+	require.NoError(t, err)
+
+	var port [2]byte
+	_, err = rand.Read(port[:])
+	require.NoError(t, err)
+
 	return &net.TCPAddr{
 		IP:   net.IP(ip[:]),
 		Port: int(binary.BigEndian.Uint16(port[:])),
 	}
 }
 
-func randTower(t *testing.T) *wtdb.Tower {
+func randTower(t *testing.T) *Tower {
+	t.Helper()
+
 	priv, err := btcec.NewPrivateKey()
 	require.NoError(t, err, "unable to create private key")
 	pubKey := priv.PubKey()
-	return &wtdb.Tower{
+	addrs, err := newAddressIterator(randAddr(t))
+	require.NoError(t, err)
+
+	return &Tower{
 		ID:          wtdb.TowerID(rand.Uint64()),
 		IdentityKey: pubKey,
-		Addresses:   []net.Addr{randAddr(t)},
+		Addresses:   addrs,
 	}
 }
 
-func copyTower(tower *wtdb.Tower) *wtdb.Tower {
-	t := &wtdb.Tower{
+func copyTower(t *testing.T, tower *Tower) *Tower {
+	t.Helper()
+
+	addrs := tower.Addresses.GetAll()
+	addrIterator, err := newAddressIterator(addrs...)
+	require.NoError(t, err)
+
+	return &Tower{
 		ID:          tower.ID,
 		IdentityKey: tower.IdentityKey,
-		Addresses:   make([]net.Addr, len(tower.Addresses)),
+		Addresses:   addrIterator,
 	}
-	copy(t.Addresses, tower.Addresses)
-	return t
 }
 
-func assertActiveCandidate(t *testing.T, i TowerCandidateIterator,
-	c *wtdb.Tower, active bool) {
+func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *Tower,
+	active bool) {
+
+	t.Helper()
+
 	isCandidate := i.IsActive(c.ID)
-	if isCandidate && !active {
-		t.Fatalf("expected tower %v to no longer be an active candidate",
-			c.ID)
-	}
-	if !isCandidate && active {
-		t.Fatalf("expected tower %v to be an active candidate", c.ID)
-	}
+	if isCandidate {
+		require.Truef(t, active, "expected tower %v to no longer be "+
			"an active candidate", c.ID)
+		return
+	}
+
+	require.Falsef(t, active, "expected tower %v to be an active candidate",
		c.ID)
 }
 
-func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) {
+func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) {
 	t.Helper()
 
 	tower, err := i.Next()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !reflect.DeepEqual(tower, c) {
-		t.Fatalf("expected tower: %v\ngot: %v", spew.Sdump(c),
-			spew.Sdump(tower))
-	}
+	require.NoError(t, err)
+	require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey))
+	require.Equal(t, tower.ID, c.ID)
+	require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll())
 }
 
 // TestTowerCandidateIterator asserts the internal state of a
@@ -90,13 +97,13 @@ func TestTowerCandidateIterator(t *testing.T) {
 	// towers. We'll use copies of these towers within the iterator to
 	// ensure the iterator properly updates the state of its candidates.
 	const numTowers = 4
-	towers := make([]*wtdb.Tower, 0, numTowers)
+	towers := make([]*Tower, 0, numTowers)
 	for i := 0; i < numTowers; i++ {
 		towers = append(towers, randTower(t))
 	}
-	towerCopies := make([]*wtdb.Tower, 0, numTowers)
+	towerCopies := make([]*Tower, 0, numTowers)
 	for _, tower := range towers {
-		towerCopies = append(towerCopies, copyTower(tower))
+		towerCopies = append(towerCopies, copyTower(t, tower))
 	}
 	towerIterator := newTowerListIterator(towerCopies...)
 
@@ -104,28 +111,23 @@ func TestTowerCandidateIterator(t *testing.T) {
 	// were added.
 	for _, expTower := range towers {
 		tower, err := towerIterator.Next()
-		if err != nil {
-			t.Fatal(err)
-		}
-		if !reflect.DeepEqual(tower, expTower) {
-			t.Fatalf("expected tower: %v\ngot: %v",
-				spew.Sdump(expTower), spew.Sdump(tower))
-		}
+		require.NoError(t, err)
+		require.Equal(t, expTower, tower)
 	}
 
-	if _, err := towerIterator.Next(); err != ErrTowerCandidatesExhausted {
-		t.Fatalf("expected ErrTowerCandidatesExhausted, got %v", err)
-	}
+	_, err := towerIterator.Next()
+	require.ErrorIs(t, err, ErrTowerCandidatesExhausted)
 	towerIterator.Reset()
 
 	// We'll then attempt to test the RemoveCandidate behavior of the
-	// iterator. We'll remove the address of the first tower, which should
-	// result in it not having any addresses left, but still being an active
-	// candidate.
+	// iterator. We'll attempt to remove the address of the first tower,
+	// which should result in an error due to it being the last address of
+	// the tower.
 	firstTower := towers[0]
-	firstTowerAddr := firstTower.Addresses[0]
-	firstTower.RemoveAddress(firstTowerAddr)
-	towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr)
+	firstTowerAddr := firstTower.Addresses.Peek()
+	err = towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr)
+	require.ErrorIs(t, err, wtdb.ErrLastTowerAddr)
 	assertActiveCandidate(t, towerIterator, firstTower, true)
 	assertNextCandidate(t, towerIterator, firstTower)
 
@@ -133,7 +135,8 @@ func TestTowerCandidateIterator(t *testing.T) {
 	// not providing the optional address. Since it's been removed, we
 	// should expect to see the third tower next.
 	secondTower, thirdTower := towers[1], towers[2]
-	towerIterator.RemoveCandidate(secondTower.ID, nil)
+	err = towerIterator.RemoveCandidate(secondTower.ID, nil)
+	require.NoError(t, err)
 	assertActiveCandidate(t, towerIterator, secondTower, false)
 	assertNextCandidate(t, towerIterator, thirdTower)
 
@@ -142,7 +145,7 @@ func TestTowerCandidateIterator(t *testing.T) {
 	// iterator, but the new address should be.
 	fourthTower := towers[3]
 	assertActiveCandidate(t, towerIterator, fourthTower, true)
-	fourthTower.AddAddress(randAddr(t))
+	fourthTower.Addresses.Add(randAddr(t))
 	towerIterator.AddCandidate(fourthTower)
 	assertNextCandidate(t, towerIterator, fourthTower)
 
watchtower/wtclient/client.go

@@ -2,7 +2,6 @@ package wtclient
 
 import (
 	"bytes"
-	"errors"
 	"fmt"
 	"net"
 	"sync"
@@ -45,8 +44,8 @@ const (
 
 // genActiveSessionFilter generates a filter that selects active sessions that
 // also match the desired channel type, either legacy or anchor.
-func genActiveSessionFilter(anchor bool) func(*wtdb.ClientSession) bool {
-	return func(s *wtdb.ClientSession) bool {
+func genActiveSessionFilter(anchor bool) func(*ClientSession) bool {
+	return func(s *ClientSession) bool {
 		return s.Status == wtdb.CSessionActive &&
 			anchor == s.Policy.IsAnchorChannel()
 	}
@@ -241,7 +240,7 @@ type TowerClient struct {
 
 	negotiator        SessionNegotiator
 	candidateTowers   TowerCandidateIterator
-	candidateSessions map[wtdb.SessionID]*wtdb.ClientSession
+	candidateSessions map[wtdb.SessionID]*ClientSession
 	activeSessions    sessionQueueSet
 
 	sessionQueue *sessionQueue
@@ -351,7 +350,7 @@ func New(config *Config) (*TowerClient, error) {
 	activeSessionFilter := genActiveSessionFilter(isAnchorClient)
 
 	candidateTowers := newTowerListIterator()
-	perActiveTower := func(tower *wtdb.Tower) {
+	perActiveTower := func(tower *Tower) {
 		// If the tower has already been marked as active, then there is
 		// no need to add it to the iterator again.
 		if candidateTowers.IsActive(tower.ID) {
@@ -400,18 +399,23 @@ func New(config *Config) (*TowerClient, error) {
 // sessionFilter check then the perActiveTower call-back will be called on that
 // tower.
 func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
-	sessionFilter func(*wtdb.ClientSession) bool,
-	perActiveTower func(tower *wtdb.Tower),
+	sessionFilter func(*ClientSession) bool,
+	perActiveTower func(tower *Tower),
 	opts ...wtdb.ClientSessionListOption) (
-	map[wtdb.SessionID]*wtdb.ClientSession, error) {
+	map[wtdb.SessionID]*ClientSession, error) {
 
 	towers, err := db.ListTowers()
 	if err != nil {
 		return nil, err
 	}
 
-	candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
-	for _, tower := range towers {
+	candidateSessions := make(map[wtdb.SessionID]*ClientSession)
+	for _, dbTower := range towers {
+		tower, err := NewTowerFromDBTower(dbTower)
+		if err != nil {
+			return nil, err
+		}
+
 		sessions, err := db.ListClientSessions(&tower.ID, opts...)
 		if err != nil {
 			return nil, err
@@ -427,16 +431,24 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
 			if err != nil {
 				return nil, err
 			}
-			s.SessionKeyECDH = keychain.NewPubKeyECDH(
+
+			sessionKeyECDH := keychain.NewPubKeyECDH(
 				towerKeyDesc, keyRing,
 			)
 
-			if !sessionFilter(s) {
+			cs := &ClientSession{
+				ID:                s.ID,
+				ClientSessionBody: s.ClientSessionBody,
+				Tower:             tower,
+				SessionKeyECDH:    sessionKeyECDH,
+			}
+
+			if !sessionFilter(cs) {
 				continue
 			}
 
 			// Add the session to the set of candidate sessions.
-			candidateSessions[s.ID] = s
+			candidateSessions[s.ID] = cs
 			perActiveTower(tower)
 		}
 	}
@@ -452,11 +464,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
 // ClientSession's SessionPrivKey field is desired, otherwise, the existing
 // ListClientSessions method should be used.
 func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
-	passesFilter func(*wtdb.ClientSession) bool,
+	passesFilter func(*ClientSession) bool,
 	opts ...wtdb.ClientSessionListOption) (
-	map[wtdb.SessionID]*wtdb.ClientSession, error) {
+	map[wtdb.SessionID]*ClientSession, error) {
 
-	sessions, err := db.ListClientSessions(forTower, opts...)
+	dbSessions, err := db.ListClientSessions(forTower, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -466,7 +478,13 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
 	// be able to communicate with the towers and authenticate session
 	// requests. This prevents us from having to store the private keys on
 	// disk.
-	for _, s := range sessions {
+	sessions := make(map[wtdb.SessionID]*ClientSession)
+	for _, s := range dbSessions {
+		dbTower, err := db.LoadTowerByID(s.TowerID)
+		if err != nil {
+			return nil, err
+		}
+
 		towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
 			Family: keychain.KeyFamilyTowerSession,
 			Index:  s.KeyIndex,
@@ -474,13 +492,27 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
 		if err != nil {
 			return nil, err
 		}
-		s.SessionKeyECDH = keychain.NewPubKeyECDH(towerKeyDesc, keyRing)
+		sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing)
+
+		tower, err := NewTowerFromDBTower(dbTower)
+		if err != nil {
+			return nil, err
+		}
+
+		cs := &ClientSession{
+			ID:                s.ID,
+			ClientSessionBody: s.ClientSessionBody,
+			Tower:             tower,
+			SessionKeyECDH:    sessionKeyECDH,
+		}
 
 		// If an optional filter was provided, use it to filter out any
 		// undesired sessions.
-		if passesFilter != nil && !passesFilter(s) {
-			delete(sessions, s.ID)
+		if passesFilter != nil && !passesFilter(cs) {
+			continue
 		}
+
+		sessions[s.ID] = cs
 	}
 
 	return sessions, nil
@@ -710,7 +742,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
 func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
 	// Select any candidate session at random, and remove it from the set of
 	// candidate sessions.
-	var candidateSession *wtdb.ClientSession
+	var candidateSession *ClientSession
 	for id, sessionInfo := range c.candidateSessions {
 		delete(c.candidateSessions, id)
 
@@ -793,13 +825,10 @@ func (c *TowerClient) backupDispatcher() {
 			msg.errChan <- c.handleNewTower(msg)
 
 		// A tower has been requested to be removed. We'll
|
||||||
// immediately return an error as we want to avoid the
|
// only allow removal of it if the address in question
|
||||||
// possibility of a new session being negotiated with
|
// is not currently being used for session negotiation.
|
||||||
// this request's tower.
|
|
||||||
case msg := <-c.staleTowers:
|
case msg := <-c.staleTowers:
|
||||||
msg.errChan <- errors.New("removing towers " +
|
msg.errChan <- c.handleStaleTower(msg)
|
||||||
"is disallowed while a new session " +
|
|
||||||
"negotiation is in progress")
|
|
||||||
|
|
||||||
case <-c.forceQuit:
|
case <-c.forceQuit:
|
||||||
return
|
return
|
||||||
|
@ -1069,7 +1098,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
|
||||||
|
|
||||||
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
|
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
|
||||||
// database and supplying it with the resources needed by the client.
|
// database and supplying it with the resources needed by the client.
|
||||||
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
|
func (c *TowerClient) newSessionQueue(s *ClientSession,
|
||||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||||
|
|
||||||
return newSessionQueue(&sessionQueueConfig{
|
return newSessionQueue(&sessionQueueConfig{
|
||||||
|
@ -1089,7 +1118,7 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
|
||||||
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
|
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
|
||||||
// passed ClientSession. If it exists, the active sessionQueue is returned.
|
// passed ClientSession. If it exists, the active sessionQueue is returned.
|
||||||
// Otherwise, a new sessionQueue is initialized and added to the set.
|
// Otherwise, a new sessionQueue is initialized and added to the set.
|
||||||
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
|
func (c *TowerClient) getOrInitActiveQueue(s *ClientSession,
|
||||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||||
|
|
||||||
if sq, ok := c.activeSessions[s.ID]; ok {
|
if sq, ok := c.activeSessions[s.ID]; ok {
|
||||||
|
@ -1103,7 +1132,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
|
||||||
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
|
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
|
||||||
// so that it can deliver any committed updates or begin accepting newly
|
// so that it can deliver any committed updates or begin accepting newly
|
||||||
// assigned tasks.
|
// assigned tasks.
|
||||||
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
|
func (c *TowerClient) initActiveQueue(s *ClientSession,
|
||||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||||
|
|
||||||
// Initialize the session queue, providing it with all the resources it
|
// Initialize the session queue, providing it with all the resources it
|
||||||
|
@ -1156,10 +1185,16 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
|
||||||
// We'll start by updating our persisted state, followed by our
|
// We'll start by updating our persisted state, followed by our
|
||||||
// in-memory state, with the new tower. This might not actually be a new
|
// in-memory state, with the new tower. This might not actually be a new
|
||||||
// tower, but it might include a new address at which it can be reached.
|
// tower, but it might include a new address at which it can be reached.
|
||||||
tower, err := c.cfg.DB.CreateTower(msg.addr)
|
dbTower, err := c.cfg.DB.CreateTower(msg.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tower, err := NewTowerFromDBTower(dbTower)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
c.candidateTowers.AddCandidate(tower)
|
c.candidateTowers.AddCandidate(tower)
|
||||||
|
|
||||||
// Include all of its corresponding sessions to our set of candidates.
|
// Include all of its corresponding sessions to our set of candidates.
|
||||||
|
@ -1215,18 +1250,31 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error
|
||||||
func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||||
// We'll load the tower before potentially removing it in order to
|
// We'll load the tower before potentially removing it in order to
|
||||||
// retrieve its ID within the database.
|
// retrieve its ID within the database.
|
||||||
tower, err := c.cfg.DB.LoadTower(msg.pubKey)
|
dbTower, err := c.cfg.DB.LoadTower(msg.pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We'll update our persisted state, followed by our in-memory state,
|
// We'll first update our in-memory state followed by our persisted
|
||||||
// with the stale tower.
|
// state, with the stale tower. The removal of the tower address from
|
||||||
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
|
// the in-memory state will fail if the address is currently being used
|
||||||
|
// for a session negotiation.
|
||||||
|
err = c.candidateTowers.RemoveCandidate(dbTower.ID, msg.addr)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = c.candidateTowers.RemoveCandidate(tower.ID, msg.addr)
|
|
||||||
if err != nil {
|
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
|
||||||
|
// If the persisted state update fails, re-add the address to
|
||||||
|
// our in-memory state.
|
||||||
|
tower, newTowerErr := NewTowerFromDBTower(dbTower)
|
||||||
|
if newTowerErr != nil {
|
||||||
|
log.Errorf("could not create new in-memory tower: %v",
|
||||||
|
newTowerErr)
|
||||||
|
} else {
|
||||||
|
c.candidateTowers.AddCandidate(tower)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1239,7 +1287,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||||
// Otherwise, the tower should no longer be used for future session
|
// Otherwise, the tower should no longer be used for future session
|
||||||
// negotiations and backups.
|
// negotiations and backups.
|
||||||
pubKey := msg.pubKey.SerializeCompressed()
|
pubKey := msg.pubKey.SerializeCompressed()
|
||||||
sessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
|
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
||||||
"%v", pubKey, err)
|
"%v", pubKey, err)
|
||||||
|
@ -1251,7 +1299,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||||
// If our active session queue corresponds to the stale tower, we'll
|
// If our active session queue corresponds to the stale tower, we'll
|
||||||
// proceed to negotiate a new one.
|
// proceed to negotiate a new one.
|
||||||
if c.sessionQueue != nil {
|
if c.sessionQueue != nil {
|
||||||
activeTower := c.sessionQueue.towerAddr.IdentityKey.SerializeCompressed()
|
activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed()
|
||||||
if bytes.Equal(pubKey, activeTower) {
|
if bytes.Equal(pubKey, activeTower) {
|
||||||
c.sessionQueue = nil
|
c.sessionQueue = nil
|
||||||
}
|
}
|
||||||
|
|
|
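The ordering in `handleStaleTower` above is the heart of the fix: the in-memory candidate iterator is asked to release the address first (which fails if a negotiation currently holds it), and only then is the database touched, with a best-effort rollback if persistence fails. Below is a minimal sketch of that two-phase removal, using hypothetical `candidateSet`/`towerStore` interfaces rather than the real `TowerCandidateIterator` and `DB` types:

    package sketch

    import (
        "fmt"
        "net"
    )

    // candidateSet and towerStore are illustrative stand-ins for the
    // client's in-memory iterator and its persisted database.
    type candidateSet interface {
        RemoveCandidate(id uint64, addr net.Addr) error
        AddCandidate(id uint64, addr net.Addr)
    }

    type towerStore interface {
        RemoveTower(id uint64, addr net.Addr) error
    }

    // removeAddr removes addr from memory first, so a negotiation that has
    // the address locked rejects the removal before any persisted state is
    // modified; a failed DB update rolls the in-memory removal back.
    func removeAddr(c candidateSet, db towerStore, id uint64,
        addr net.Addr) error {

        if err := c.RemoveCandidate(id, addr); err != nil {
            // E.g. the address is in use by a session negotiation.
            return err
        }

        if err := db.RemoveTower(id, addr); err != nil {
            // Restore the in-memory view so both views stay in sync.
            c.AddCandidate(id, addr)
            return fmt.Errorf("db removal failed: %w", err)
        }

        return nil
    }

Doing the fallible in-memory step first means the error surfaces to the caller before the two views of the tower can diverge.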
@@ -2,6 +2,8 @@ package wtclient_test

 import (
     "encoding/binary"
+    "errors"
+    "fmt"
     "net"
     "sync"
     "testing"
@@ -16,6 +18,7 @@ import (
     "github.com/lightningnetwork/lnd/channeldb"
     "github.com/lightningnetwork/lnd/input"
     "github.com/lightningnetwork/lnd/keychain"
+    "github.com/lightningnetwork/lnd/lntest/wait"
     "github.com/lightningnetwork/lnd/lnwallet"
     "github.com/lightningnetwork/lnd/lnwire"
     "github.com/lightningnetwork/lnd/tor"
@@ -31,7 +34,8 @@ import (
 const (
     csvDelay uint32 = 144

     towerAddrStr = "18.28.243.2:9911"
+    towerAddr2Str = "19.29.244.3:9912"
 )

 var (
@@ -63,6 +67,8 @@ var (
     )

     addrScript, _ = txscript.PayToAddrScript(addr)
+
+    waitTime = 5 * time.Second
 )

 // randPrivKey generates a new secp keypair, and returns the public key.
@@ -76,37 +82,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey {
 }

 type mockNet struct {
     mu sync.RWMutex
-    connCallback func(wtserver.Peer)
+    connCallbacks map[string]func(wtserver.Peer)
 }

-func newMockNet(cb func(wtserver.Peer)) *mockNet {
+func newMockNet() *mockNet {
     return &mockNet{
-        connCallback: cb,
+        connCallbacks: make(map[string]func(peer wtserver.Peer)),
     }
 }

-func (m *mockNet) Dial(network string, address string,
-    timeout time.Duration) (net.Conn, error) {
-
+func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) {
     return nil, nil
 }

-func (m *mockNet) LookupHost(host string) ([]string, error) {
+func (m *mockNet) LookupHost(_ string) ([]string, error) {
     panic("not implemented")
 }

-func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) {
+func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) {
     panic("not implemented")
 }

-func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) {
+func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) {
     panic("not implemented")
 }

 func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
-    netAddr *lnwire.NetAddress,
-    dialer tor.DialFunc) (wtserver.Peer, error) {
+    netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) {

     localPk := local.PubKey()
     localAddr := &net.TCPAddr{
@@ -119,16 +122,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
     )

     m.mu.RLock()
-    m.connCallback(remotePeer)
-    m.mu.RUnlock()
+    defer m.mu.RUnlock()
+    cb, ok := m.connCallbacks[netAddr.String()]
+    if !ok {
+        return nil, fmt.Errorf("no callback registered for this peer")
+    }
+
+    cb(remotePeer)

     return localPeer, nil
 }

-func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) {
+func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress,
+    cb func(wtserver.Peer)) {
+
     m.mu.Lock()
     defer m.mu.Unlock()
-    m.connCallback = cb
+
+    m.connCallbacks[netAddr.String()] = cb
+}
+
+func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+
+    delete(m.connCallbacks, netAddr.String())
 }

 type mockChannel struct {
@@ -325,10 +343,8 @@ func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) {
     c.mu.Lock()
     defer c.mu.Unlock()

-    if c.localBalance < amt {
-        t.Fatalf("insufficient funds to send, need: %v, have: %v",
-            amt, c.localBalance)
-    }
+    require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+
+        "to send, need: %v, have: %v", amt, c.localBalance)

     c.localBalance -= amt
     c.remoteBalance += amt
@@ -343,10 +359,8 @@ func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) {
     c.mu.Lock()
     defer c.mu.Unlock()

-    if c.remoteBalance < amt {
-        t.Fatalf("insufficient funds to recv, need: %v, have: %v",
-            amt, c.remoteBalance)
-    }
+    require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+
+        "to recv, need: %v, have: %v", amt, c.remoteBalance)

     c.localBalance += amt
     c.remoteBalance -= amt
@@ -381,6 +395,8 @@ type testHarness struct {

     mu       sync.Mutex
     channels map[lnwire.ChannelID]*mockChannel
+
+    quit chan struct{}
 }

 type harnessCfg struct {
@@ -389,6 +405,7 @@ type harnessCfg struct {
     policy             wtpolicy.Policy
     noRegisterChan0    bool
     noAckCreateSession bool
+    noServerStart      bool
 }

 func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
@@ -420,11 +437,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
         NoAckCreateSession: cfg.noAckCreateSession,
     }

-    server, err := wtserver.New(serverCfg)
-    require.NoError(t, err, "unable to create wtserver")
-
     signer := wtmock.NewMockSigner()
-    mockNet := newMockNet(server.InboundPeerConnected)
+    mockNet := newMockNet()
     clientDB := wtmock.NewClientDB()

     clientCfg := &wtclient.Config{
@@ -443,21 +457,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
         MaxBackoff:     time.Second,
         ForceQuitDelay: 10 * time.Second,
     }
-    client, err := wtclient.New(clientCfg)
-    require.NoError(t, err, "Unable to create wtclient")
-
-    if err := server.Start(); err != nil {
-        t.Fatalf("Unable to start wtserver: %v", err)
-    }
-
-    if err = client.Start(); err != nil {
-        server.Stop()
-        t.Fatalf("Unable to start wtclient: %v", err)
-    }
-    if err := client.AddTower(towerAddr); err != nil {
-        server.Stop()
-        t.Fatalf("Unable to add tower to wtclient: %v", err)
-    }

     h := &testHarness{
         t: t,
@@ -466,14 +465,24 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
         capacity:   cfg.localBalance + cfg.remoteBalance,
         clientDB:   clientDB,
         clientCfg:  clientCfg,
-        client:     client,
         serverAddr: towerAddr,
         serverDB:   serverDB,
         serverCfg:  serverCfg,
-        server:     server,
         net:        mockNet,
         channels:   make(map[lnwire.ChannelID]*mockChannel),
+        quit:       make(chan struct{}),
     }

+    t.Cleanup(func() {
+        close(h.quit)
+    })
+
+    if !cfg.noServerStart {
+        h.startServer()
+        t.Cleanup(h.stopServer)
+    }
+
+    h.startClient()
+    t.Cleanup(h.client.ForceQuit)
+
     h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
     if !cfg.noRegisterChan0 {
@@ -490,15 +499,20 @@ func (h *testHarness) startServer() {

     var err error
     h.server, err = wtserver.New(h.serverCfg)
-    if err != nil {
-        h.t.Fatalf("unable to create wtserver: %v", err)
-    }
+    require.NoError(h.t, err)

-    h.net.setConnCallback(h.server.InboundPeerConnected)
+    h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected)

-    if err := h.server.Start(); err != nil {
-        h.t.Fatalf("unable to start wtserver: %v", err)
-    }
+    require.NoError(h.t, h.server.Start())
+}
+
+// stopServer stops the main harness server.
+func (h *testHarness) stopServer() {
+    h.t.Helper()
+
+    h.net.removeConnCallback(h.serverAddr)
+
+    require.NoError(h.t, h.server.Stop())
 }

 // startClient creates a new server using the harness's current clientCf and
@@ -507,24 +521,16 @@ func (h *testHarness) startClient() {
     h.t.Helper()

     towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
-    if err != nil {
-        h.t.Fatalf("Unable to resolve tower TCP addr: %v", err)
-    }
+    require.NoError(h.t, err)
     towerAddr := &lnwire.NetAddress{
         IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(),
         Address:     towerTCPAddr,
     }

     h.client, err = wtclient.New(h.clientCfg)
-    if err != nil {
-        h.t.Fatalf("unable to create wtclient: %v", err)
-    }
-    if err := h.client.Start(); err != nil {
-        h.t.Fatalf("unable to start wtclient: %v", err)
-    }
-    if err := h.client.AddTower(towerAddr); err != nil {
-        h.t.Fatalf("unable to add tower to wtclient: %v", err)
-    }
+    require.NoError(h.t, err)
+    require.NoError(h.t, h.client.Start())
+    require.NoError(h.t, h.client.AddTower(towerAddr))
 }

 // chanIDFromInt creates a unique channel id given a unique integral id.
@@ -553,9 +559,7 @@ func (h *testHarness) makeChannel(id uint64,
     }
     c.mu.Unlock()

-    if ok {
-        h.t.Fatalf("channel %d already created", id)
-    }
+    require.Falsef(h.t, ok, "channel %d already created", id)
 }

 // channel retrieves the channel corresponding to id.
@@ -567,9 +571,7 @@ func (h *testHarness) channel(id uint64) *mockChannel {
     h.mu.Lock()
     c, ok := h.channels[chanIDFromInt(id)]
     h.mu.Unlock()
-    if !ok {
-        h.t.Fatalf("unable to fetch channel %d", id)
-    }
+    require.Truef(h.t, ok, "unable to fetch channel %d", id)

     return c
 }
@@ -580,9 +582,7 @@ func (h *testHarness) registerChannel(id uint64) {

     chanID := chanIDFromInt(id)
     err := h.client.RegisterChannel(chanID)
-    if err != nil {
-        h.t.Fatalf("unable to register channel %d: %v", id, err)
-    }
+    require.NoError(h.t, err)
 }

 // advanceChannelN calls advanceState on the channel identified by id the number
@@ -621,11 +621,10 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
     _, retribution := h.channel(id).getState(i)

     chanID := chanIDFromInt(id)
-    err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit)
-    if err != expErr {
-        h.t.Fatalf("back error mismatch, want: %v, got: %v",
-            expErr, err)
-    }
+    err := h.client.BackupState(
+        &chanID, retribution, channeldb.SingleFunderBit,
+    )
+    require.ErrorIs(h.t, expErr, err)
 }

 // sendPayments instructs the channel identified by id to send amt to the remote
@@ -685,10 +684,8 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
         hintSet[hint] = struct{}{}
     }

-    if len(hints) != len(hintSet) {
-        h.t.Fatalf("breach hints are not unique, list-len: %d "+
-            "set-len: %d", len(hints), len(hintSet))
-    }
+    require.Lenf(h.t, hints, len(hintSet), "breach hints are not unique, "+
+        "list-len: %d set-len: %d", len(hints), len(hintSet))

     // Closure to assert the server's matches are consistent with the hint
     // set.
@@ -698,12 +695,9 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
     }

     for _, match := range matches {
-        if _, ok := hintSet[match.Hint]; ok {
-            continue
-        }
-
-        h.t.Fatalf("match %v in db is not in hint set",
-            match.Hint)
+        _, ok := hintSet[match.Hint]
+        require.Truef(h.t, ok, "match %v in db is not in "+
+            "hint set", match.Hint)
     }

     return true
@@ -714,31 +708,24 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
     select {
     case <-time.After(time.Second):
         matches, err := h.serverDB.QueryMatches(hints)
-        switch {
-        case err != nil:
-            h.t.Fatalf("unable to query for hints: %v", err)
-
-        case wantUpdates && serverHasHints(matches):
-            return
-
-        case wantUpdates:
+        require.NoError(h.t, err, "unable to query for hints")
+
+        if wantUpdates && serverHasHints(matches) {
+            return
+        }
+
+        if wantUpdates {
             h.t.Logf("Received %d/%d\n", len(matches),
                 len(hints))
         }

     case <-failTimeout:
         matches, err := h.serverDB.QueryMatches(hints)
-        switch {
-        case err != nil:
-            h.t.Fatalf("unable to query for hints: %v", err)
-
-        case serverHasHints(matches):
-            return
-
-        default:
-            h.t.Fatalf("breach hints not received, only "+
-                "got %d/%d", len(matches), len(hints))
-        }
+        require.NoError(h.t, err, "unable to query for hints")
+        require.Truef(h.t, serverHasHints(matches), "breach "+
+            "hints not received, only got %d/%d",
+            len(matches), len(hints))
+        return
     }
 }
 }
@@ -751,25 +738,18 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,

     // Query for matches on the provided hints.
     matches, err := h.serverDB.QueryMatches(hints)
-    if err != nil {
-        h.t.Fatalf("unable to query for matches: %v", err)
-    }
+    require.NoError(h.t, err)

     // Assert that the number of matches is exactly the number of provided
     // hints.
-    if len(matches) != len(hints) {
-        h.t.Fatalf("expected: %d matches, got: %d", len(hints),
-            len(matches))
-    }
+    require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d",
+        len(hints), len(matches))

     // Assert that all of the matches correspond to a session with the
     // expected policy.
     for _, match := range matches {
         matchPolicy := match.SessionInfo.Policy
-        if expPolicy != matchPolicy {
-            h.t.Fatalf("expected session to have policy: %v, "+
-                "got: %v", expPolicy, matchPolicy)
-        }
+        require.Equal(h.t, expPolicy, matchPolicy)
     }
 }
@@ -777,9 +757,8 @@
 func (h *testHarness) addTower(addr *lnwire.NetAddress) {
     h.t.Helper()

-    if err := h.client.AddTower(addr); err != nil {
-        h.t.Fatalf("unable to add tower: %v", err)
-    }
+    err := h.client.AddTower(addr)
+    require.NoError(h.t, err)
 }

 // removeTower removes a tower from the client. If `addr` is specified, then the
@@ -787,9 +766,8 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) {
 func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
     h.t.Helper()

-    if err := h.client.RemoveTower(pubKey, addr); err != nil {
-        h.t.Fatalf("unable to remove tower: %v", err)
-    }
+    err := h.client.RemoveTower(pubKey, addr)
+    require.NoError(h.t, err)
 }

 const (
@@ -976,10 +954,9 @@ var clientTests = []clientTest{

             // Now, restart the server and prevent it from acking
             // state updates.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckUpdates = true
             h.startServer()
-            defer h.server.Stop()

             // Send the next state update to the tower. Since the
             // tower isn't acking state updates, we expect this
@@ -997,15 +974,13 @@ var clientTests = []clientTest{

             // Restart the server and allow it to ack the updates
             // after the client retransmits the unacked update.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckUpdates = false
             h.startServer()
-            defer h.server.Stop()

             // Restart the client and allow it to process the
             // committed update.
             h.startClient()
-            defer h.client.ForceQuit()

             // Wait for the committed update to be accepted by the
             // tower.
@@ -1049,10 +1024,9 @@ var clientTests = []clientTest{

             // Restart the server and prevent it from acking state
             // updates.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckUpdates = true
             h.startServer()
-            defer h.server.Stop()

             // Now, queue the retributions for backup.
             h.backupStates(chanID, 0, numUpdates, nil)
@@ -1068,14 +1042,13 @@ var clientTests = []clientTest{

             // Restart the server and allow it to ack the updates
             // after the client retransmits the unacked updates.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckUpdates = false
             h.startServer()
-            defer h.server.Stop()

             // Wait for all of the updates to be populated in the
             // server's database.
-            h.waitServerUpdates(hints, 5*time.Second)
+            h.waitServerUpdates(hints, waitTime)
         },
     },
     {
@@ -1212,23 +1185,21 @@ var clientTests = []clientTest{

             // Restart the server and allow it to ack session
             // creation.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckCreateSession = false
             h.startServer()
-            defer h.server.Stop()

             // Restart the client with the same policy, which will
             // immediately try to overwrite the old session with an
             // identical one.
             h.startClient()
-            defer h.client.ForceQuit()

             // Now, queue the retributions for backup.
             h.backupStates(chanID, 0, numUpdates, nil)

             // Wait for all of the updates to be populated in the
             // server's database.
-            h.waitServerUpdates(hints, 5*time.Second)
+            h.waitServerUpdates(hints, waitTime)

             // Assert that the server has updates for the clients
             // most recent policy.
@@ -1270,24 +1241,22 @@ var clientTests = []clientTest{

             // Restart the server and allow it to ack session
             // creation.
-            h.server.Stop()
+            h.stopServer()
             h.serverCfg.NoAckCreateSession = false
             h.startServer()
-            defer h.server.Stop()

             // Restart the client with a new policy, which will
             // immediately try to overwrite the prior session with
             // the old policy.
             h.clientCfg.Policy.SweepFeeRate *= 2
             h.startClient()
-            defer h.client.ForceQuit()

             // Now, queue the retributions for backup.
             h.backupStates(chanID, 0, numUpdates, nil)

             // Wait for all of the updates to be populated in the
             // server's database.
-            h.waitServerUpdates(hints, 5*time.Second)
+            h.waitServerUpdates(hints, waitTime)

             // Assert that the server has updates for the clients
             // most recent policy.
@@ -1341,14 +1310,13 @@ var clientTests = []clientTest{
             // Restart the client with a new policy.
             h.clientCfg.Policy.MaxUpdates = 20
             h.startClient()
-            defer h.client.ForceQuit()

             // Now, queue the second half of the retributions.
             h.backupStates(chanID, numUpdates/2, numUpdates, nil)

             // Wait for all of the updates to be populated in the
             // server's database.
-            h.waitServerUpdates(hints, 5*time.Second)
+            h.waitServerUpdates(hints, waitTime)

             // Assert that the server has updates for the client's
             // original policy.
@@ -1389,13 +1357,12 @@ var clientTests = []clientTest{

             // Wait for the first half of the updates to be
             // populated in the server's database.
-            h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second)
+            h.waitServerUpdates(hints[:len(hints)/2], waitTime)

             // Restart the client, so we can ensure the deduping is
             // maintained across restarts.
             h.client.Stop()
             h.startClient()
-            defer h.client.ForceQuit()

             // Try to back up the full range of retributions. Only
             // the second half should actually be sent.
@@ -1403,7 +1370,7 @@ var clientTests = []clientTest{

             // Wait for all of the updates to be populated in the
             // server's database.
-            h.waitServerUpdates(hints, 5*time.Second)
+            h.waitServerUpdates(hints, waitTime)
         },
     },
     {
@@ -1431,7 +1398,7 @@ var clientTests = []clientTest{
             // first two.
             hints := h.advanceChannelN(chanID, numUpdates)
             h.backupStates(chanID, 0, numUpdates/2, nil)
-            h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second)
+            h.waitServerUpdates(hints[:numUpdates/2], waitTime)

             // Fully remove the tower, causing its existing sessions
             // to be marked inactive.
@@ -1445,8 +1412,7 @@ var clientTests = []clientTest{
             // Re-add the tower. We prevent the tower from acking
             // session creation to ensure the inactive sessions are
             // not used.
-            err := h.server.Stop()
-            require.Nil(h.t, err)
+            h.stopServer()
             h.serverCfg.NoAckCreateSession = true
             h.startServer()
             h.addTower(h.serverAddr)
@@ -1455,11 +1421,10 @@ var clientTests = []clientTest{
             // Finally, allow the tower to ack session creation,
             // allowing the state updates to be sent through the new
             // session.
-            err = h.server.Stop()
-            require.Nil(h.t, err)
+            h.stopServer()
             h.serverCfg.NoAckCreateSession = false
             h.startServer()
-            h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
+            h.waitServerUpdates(hints[numUpdates/2:], waitTime)
         },
     },
     {
@@ -1490,13 +1455,12 @@ var clientTests = []clientTest{

             // Back up 4 of the 5 states for the negotiated session.
             h.backupStates(chanID, 0, maxUpdates-1, nil)
-            h.waitServerUpdates(hints[:maxUpdates-1], 5*time.Second)
+            h.waitServerUpdates(hints[:maxUpdates-1], waitTime)

             // Now, restart the tower and prevent it from acking any
             // new sessions. We do this here as once the last slot
             // is exhausted the client will attempt to renegotiate.
-            err := h.server.Stop()
-            require.Nil(h.t, err)
+            h.stopServer()
             h.serverCfg.NoAckCreateSession = true
             h.startServer()

@@ -1506,15 +1470,189 @@ var clientTests = []clientTest{
             // the final state. We'll only wait for the first five
             // states to arrive at the tower.
             h.backupStates(chanID, maxUpdates-1, numUpdates, nil)
-            h.waitServerUpdates(hints[:maxUpdates], 5*time.Second)
+            h.waitServerUpdates(hints[:maxUpdates], waitTime)

             // Finally, stop the client which will continue to
             // attempt session negotiation since it has one more
             // state to process. After the force quite delay
             // expires, the client should force quite itself and
             // allow the test to complete.
-            err = h.client.Stop()
-            require.Nil(h.t, err)
+            h.stopServer()
+        },
+    },
+    {
+        // Assert that if a client changes the address for a server and
+        // then tries to back up updates then the client will switch to
+        // the new address.
+        name: "change address of existing session",
+        cfg: harnessCfg{
+            localBalance:  localBalance,
+            remoteBalance: remoteBalance,
+            policy: wtpolicy.Policy{
+                TxPolicy: wtpolicy.TxPolicy{
+                    BlobType:     blob.TypeAltruistCommit,
+                    SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
+                },
+                MaxUpdates: 5,
+            },
+        },
+        fn: func(h *testHarness) {
+            const (
+                chanID     = 0
+                numUpdates = 6
+                maxUpdates = 5
+            )
+
+            // Advance the channel to create all states.
+            hints := h.advanceChannelN(chanID, numUpdates)
+
+            h.backupStates(chanID, 0, numUpdates/2, nil)
+
+            // Wait for the first half of the updates to be
+            // populated in the server's database.
+            h.waitServerUpdates(hints[:len(hints)/2], waitTime)
+
+            // Stop the server.
+            h.stopServer()
+
+            // Change the address of the server.
+            towerTCPAddr, err := net.ResolveTCPAddr(
+                "tcp", towerAddr2Str,
+            )
+            require.NoError(h.t, err)
+
+            oldAddr := h.serverAddr.Address
+            towerAddr := &lnwire.NetAddress{
+                IdentityKey: h.serverAddr.IdentityKey,
+                Address:     towerTCPAddr,
+            }
+            h.serverAddr = towerAddr
+
+            // Add the new tower address to the client.
+            err = h.client.AddTower(towerAddr)
+            require.NoError(h.t, err)
+
+            // Remove the old tower address from the client.
+            err = h.client.RemoveTower(
+                towerAddr.IdentityKey, oldAddr,
+            )
+            require.NoError(h.t, err)
+
+            // Restart the server.
+            h.startServer()
+
+            // Now attempt to back up the rest of the updates.
+            h.backupStates(chanID, numUpdates/2, maxUpdates, nil)
+
+            // Assert that the server does receive the updates.
+            h.waitServerUpdates(hints[:maxUpdates], waitTime)
+        },
+    },
+    {
+        // Assert that a user is able to remove a tower address during
+        // session negotiation as long as the address in question is not
+        // currently being used.
+        name: "removing a tower during session negotiation",
+        cfg: harnessCfg{
+            localBalance:  localBalance,
+            remoteBalance: remoteBalance,
+            policy: wtpolicy.Policy{
+                TxPolicy: wtpolicy.TxPolicy{
+                    BlobType:     blob.TypeAltruistCommit,
+                    SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
+                },
+                MaxUpdates: 5,
+            },
+            noServerStart: true,
+        },
+        fn: func(h *testHarness) {
+            // The server has not started yet and so no session
+            // negotiation with the server will be in progress, so
+            // the client should be able to remove the server.
+            err := wait.NoError(func() error {
+                return h.client.RemoveTower(
+                    h.serverAddr.IdentityKey, nil,
+                )
+            }, waitTime)
+            require.NoError(h.t, err)
+
+            // Set the server up so that its Dial function hangs
+            // when the client calls it. This will force the client
+            // to remain in the state where it has locked the
+            // address of the server.
+            h.server, err = wtserver.New(h.serverCfg)
+            require.NoError(h.t, err)
+
+            cancel := make(chan struct{})
+            h.net.registerConnCallback(
+                h.serverAddr, func(peer wtserver.Peer) {
+                    select {
+                    case <-h.quit:
+                    case <-cancel:
+                    }
+                },
+            )
+
+            // Also add a new tower address.
+            towerTCPAddr, err := net.ResolveTCPAddr(
+                "tcp", towerAddr2Str,
+            )
+            require.NoError(h.t, err)
+
+            towerAddr := &lnwire.NetAddress{
+                IdentityKey: h.serverAddr.IdentityKey,
+                Address:     towerTCPAddr,
+            }
+
+            // Register the new address in the mock-net.
+            h.net.registerConnCallback(
+                towerAddr, h.server.InboundPeerConnected,
+            )
+
+            // Now start the server.
+            require.NoError(h.t, h.server.Start())
+
+            // Re-add the server to the client
+            err = h.client.AddTower(h.serverAddr)
+            require.NoError(h.t, err)
+
+            // Also add the new tower address.
+            err = h.client.AddTower(towerAddr)
+            require.NoError(h.t, err)
+
+            // Assert that if the client attempts to remove the
+            // tower's first address, then it will error due to
+            // address currently being locked for session
+            // negotiation.
+            err = wait.Predicate(func() bool {
+                err = h.client.RemoveTower(
+                    h.serverAddr.IdentityKey,
+                    h.serverAddr.Address,
+                )
+                return errors.Is(err, wtclient.ErrAddrInUse)
+            }, waitTime)
+            require.NoError(h.t, err)
+
+            // Assert that the second address can be removed since
+            // it is not being used for session negotiation.
+            err = wait.NoError(func() error {
+                return h.client.RemoveTower(
+                    h.serverAddr.IdentityKey, towerTCPAddr,
+                )
+            }, waitTime)
+            require.NoError(h.t, err)
+
+            // Allow the dial to the first address to stop hanging.
+            close(cancel)
+
+            // Assert that the client can now remove the first
+            // address.
+            err = wait.NoError(func() error {
+                return h.client.RemoveTower(
+                    h.serverAddr.IdentityKey, nil,
+                )
+            }, waitTime)
+            require.NoError(h.t, err)
         },
     },
 }
@@ -1528,10 +1666,6 @@ func TestClient(t *testing.T) {
         t.Parallel()

         h := newHarness(t, tc.cfg)
-        t.Cleanup(func() {
-            require.NoError(t, h.server.Stop())
-            h.client.ForceQuit()
-        })

         tc.fn(h)
     })
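One detail worth noting from the harness changes above: replacing the single `connCallback` with a per-address callback map is what lets the new negotiation test park the dial to one address while a second address stays reachable. A self-contained sketch of that mock shape follows; the names are illustrative, not the harness's actual types:

    package main

    import (
        "fmt"
        "sync"
    )

    // connMock routes an inbound "dial" to the callback registered for the
    // dialed address, mirroring the map-based mockNet in the test harness.
    type connMock struct {
        mu        sync.RWMutex
        callbacks map[string]func()
    }

    func newConnMock() *connMock {
        return &connMock{callbacks: make(map[string]func())}
    }

    func (m *connMock) register(addr string, cb func()) {
        m.mu.Lock()
        defer m.mu.Unlock()
        m.callbacks[addr] = cb
    }

    func (m *connMock) dial(addr string) error {
        m.mu.RLock()
        defer m.mu.RUnlock()

        cb, ok := m.callbacks[addr]
        if !ok {
            return fmt.Errorf("no callback registered for %s", addr)
        }
        cb()
        return nil
    }

    func main() {
        m := newConnMock()
        block := make(chan struct{})

        // The first address hangs its caller, keeping that address "in use"...
        m.register("18.28.243.2:9911", func() { <-block })
        // ...while the second address connects immediately.
        m.register("19.29.244.3:9912", func() {})

        go func() { _ = m.dial("18.28.243.2:9911") }() // parked until close(block)
        _ = m.dial("19.29.244.3:9912")
        close(block)
    }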
@@ -20,10 +20,6 @@ var (
     // down.
     ErrNegotiatorExiting = errors.New("negotiator exiting")

-    // ErrNoTowerAddrs signals that the client could not be created because
-    // we have no addresses with which we can reach a tower.
-    ErrNoTowerAddrs = errors.New("no tower addresses")
-
     // ErrFailedNegotiation signals that the session negotiator could not
     // acquire a new session as requested.
     ErrFailedNegotiation = errors.New("session negotiation unsuccessful")
@@ -118,3 +118,50 @@ type ECDHKeyRing interface {
     // key.
     DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error)
 }
+
+// Tower represents the info about a watchtower server that a watchtower client
+// needs in order to connect to it.
+type Tower struct {
+    // ID is the unique, db-assigned, identifier for this tower.
+    ID wtdb.TowerID
+
+    // IdentityKey is the public key of the remote node, used to
+    // authenticate the brontide transport.
+    IdentityKey *btcec.PublicKey
+
+    // Addresses is an AddressIterator that can be used to manage the
+    // addresses for this tower.
+    Addresses AddressIterator
+}
+
+// NewTowerFromDBTower converts a wtdb.Tower, which uses a static address list,
+// into a Tower which uses an address iterator.
+func NewTowerFromDBTower(t *wtdb.Tower) (*Tower, error) {
+    addrs, err := newAddressIterator(t.Addresses...)
+    if err != nil {
+        return nil, err
+    }
+
+    return &Tower{
+        ID:          t.ID,
+        IdentityKey: t.IdentityKey,
+        Addresses:   addrs,
+    }, nil
+}
+
+// ClientSession represents the session that a tower client has with a server.
+type ClientSession struct {
+    // ID is the client's public key used when authenticating with the
+    // tower.
+    ID wtdb.SessionID
+
+    wtdb.ClientSessionBody
+
+    // Tower represents the tower that the client session has been made
+    // with.
+    Tower *Tower
+
+    // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret
+    // key used to connect to the watchtower.
+    SessionKeyECDH keychain.SingleKeyECDH
+}
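The new `Tower` type swaps `wtdb.Tower`'s static address slice for an `AddressIterator`, and `NewTowerFromDBTower` is the bridge between the persisted record and that runtime view. The following is a toy model of the same split with simplified types; the real iterator added in `addr_iterator.go` additionally supports address locking and removal:

    package main

    import "fmt"

    // dbTower mimics the persisted record: identity plus a flat address list.
    type dbTower struct {
        ID    uint64
        Addrs []string
    }

    // addrIterator mimics the runtime view: a cyclable address set.
    type addrIterator struct {
        addrs []string
        next  int
    }

    var errExhausted = fmt.Errorf("addresses exhausted")

    func newAddrIterator(addrs ...string) (*addrIterator, error) {
        if len(addrs) == 0 {
            return nil, fmt.Errorf("no addresses")
        }
        return &addrIterator{addrs: addrs}, nil
    }

    // Next returns the following address, or errExhausted once every
    // address has been handed out.
    func (it *addrIterator) Next() (string, error) {
        if it.next >= len(it.addrs) {
            return "", errExhausted
        }
        addr := it.addrs[it.next]
        it.next++
        return addr, nil
    }

    func (it *addrIterator) Reset() { it.next = 0 }

    // runtimeTower mirrors NewTowerFromDBTower: same identity, but the
    // addresses live behind an iterator so they can be cycled and updated.
    type runtimeTower struct {
        ID    uint64
        Addrs *addrIterator
    }

    func fromDB(t *dbTower) (*runtimeTower, error) {
        it, err := newAddrIterator(t.Addrs...)
        if err != nil {
            return nil, err
        }
        return &runtimeTower{ID: t.ID, Addrs: it}, nil
    }

    func main() {
        rt, err := fromDB(&dbTower{ID: 1, Addrs: []string{"a:9911", "b:9911"}})
        if err != nil {
            panic(err)
        }
        for {
            addr, err := rt.Addrs.Next()
            if err != nil {
                break
            }
            fmt.Println("try", addr)
        }
    }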
@ -25,7 +25,7 @@ type SessionNegotiator interface {
|
||||||
|
|
||||||
// NewSessions is a read-only channel where newly negotiated sessions
|
// NewSessions is a read-only channel where newly negotiated sessions
|
||||||
// will be delivered.
|
// will be delivered.
|
||||||
NewSessions() <-chan *wtdb.ClientSession
|
NewSessions() <-chan *ClientSession
|
||||||
|
|
||||||
// Start safely initializes the session negotiator.
|
// Start safely initializes the session negotiator.
|
||||||
Start() error
|
Start() error
|
||||||
|
@ -105,8 +105,8 @@ type sessionNegotiator struct {
|
||||||
log btclog.Logger
|
log btclog.Logger
|
||||||
|
|
||||||
dispatcher chan struct{}
|
dispatcher chan struct{}
|
||||||
newSessions chan *wtdb.ClientSession
|
newSessions chan *ClientSession
|
||||||
successfulNegotiations chan *wtdb.ClientSession
|
successfulNegotiations chan *ClientSession
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
|
@ -139,8 +139,8 @@ func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
|
||||||
log: cfg.Log,
|
log: cfg.Log,
|
||||||
localInit: localInit,
|
localInit: localInit,
|
||||||
dispatcher: make(chan struct{}, 1),
|
dispatcher: make(chan struct{}, 1),
|
||||||
newSessions: make(chan *wtdb.ClientSession),
|
newSessions: make(chan *ClientSession),
|
||||||
successfulNegotiations: make(chan *wtdb.ClientSession),
|
successfulNegotiations: make(chan *ClientSession),
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -171,7 +171,7 @@ func (n *sessionNegotiator) Stop() error {
|
||||||
|
|
||||||
// NewSessions returns a receive-only channel from which newly negotiated
|
// NewSessions returns a receive-only channel from which newly negotiated
|
||||||
// sessions will be returned.
|
// sessions will be returned.
|
||||||
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
|
func (n *sessionNegotiator) NewSessions() <-chan *ClientSession {
|
||||||
return n.newSessions
|
return n.newSessions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,18 +333,10 @@ retryWithBackoff:
		}
	}

-// createSession takes a tower an attempts to negotiate a session using any of
+// createSession takes a tower and attempts to negotiate a session using any of
// its stored addresses. This method returns after the first successful
-// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
-// the tower has no addresses, ErrNoTowerAddrs is returned.
-func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
-	keyIndex uint32) error {
-
-	// If the tower has no addresses, there's nothing we can do.
-	if len(tower.Addresses) == 0 {
-		return ErrNoTowerAddrs
-	}
-
+// negotiation, or after all addresses have failed with ErrFailedNegotiation.
+func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
	sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey(
		keychain.KeyLocator{
			Family: keychain.KeyFamilyTowerSession,

@ -358,8 +350,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
		sessionKeyDesc, n.cfg.SecretKeyRing,
	)

-	for _, lnAddr := range tower.LNAddrs() {
-		err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
+	addr := tower.Addresses.PeekAndLock()
+	for {
+		lnAddr := &lnwire.NetAddress{
+			IdentityKey: tower.IdentityKey,
+			Address:     addr,
+		}
+
+		err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
+		tower.Addresses.ReleaseLock(addr)
		switch {
		case err == ErrPermanentTowerFailure:
			// TODO(conner): report to iterator? can then be reset

@ -370,6 +369,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
			n.log.Debugf("Request for session negotiation with "+
				"tower=%s failed, trying again -- reason: "+
				"%v", lnAddr, err)
+
+			// Get the next tower address if there is one.
+			addr, err = tower.Addresses.NextAndLock()
+			if err == ErrAddressesExhausted {
+				tower.Addresses.Reset()
+
+				return ErrFailedNegotiation
+			}
+
			continue

		default:
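Editorial aside, not part of the diff: the lock/release choreography in the new `createSession` loop is the heart of this change, so a distilled model may help. The standalone Go sketch below mimics that control flow against a toy slice-backed iterator; the `sliceIterator` type and `errExhausted` sentinel are simplified stand-ins for the real `AddressIterator` and `ErrAddressesExhausted`, the lock bookkeeping that protects a dialed address from concurrent removal is elided, and a non-empty address list is assumed (as the real iterator guarantees).

```go
package main

import (
	"errors"
	"fmt"
)

// errExhausted stands in for ErrAddressesExhausted in this sketch.
var errExhausted = errors.New("addresses exhausted")

// sliceIterator is a toy, single-goroutine stand-in for the AddressIterator:
// it walks a fixed list of addresses and reports exhaustion at the end.
type sliceIterator struct {
	addrs []string
	i     int
}

func (s *sliceIterator) PeekAndLock() string { return s.addrs[s.i] }

func (s *sliceIterator) NextAndLock() (string, error) {
	if s.i+1 >= len(s.addrs) {
		return "", errExhausted
	}
	s.i++
	return s.addrs[s.i], nil
}

func (s *sliceIterator) ReleaseLock(string) {}

func (s *sliceIterator) Reset() { s.i = 0 }

// tryAll mirrors the shape of createSession above: lock the current address
// while a dial attempt is in flight, release it afterwards, advance on
// failure, and reset the iterator before giving up so that the next round
// of negotiation starts from the first address again.
func tryAll(it *sliceIterator, dial func(string) error) error {
	addr := it.PeekAndLock()
	for {
		err := dial(addr)
		it.ReleaseLock(addr)
		if err == nil {
			return nil
		}

		addr, err = it.NextAndLock()
		if errors.Is(err, errExhausted) {
			it.Reset()
			return errors.New("negotiation failed on all addresses")
		}
	}
}

func main() {
	it := &sliceIterator{addrs: []string{"tower-a:9911", "tower-b:9911"}}
	err := tryAll(it, func(addr string) error {
		fmt.Println("dialing", addr)
		return errors.New("connection refused") // simulate failure
	})
	fmt.Println("result:", err)
}
```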
@ -385,7 +393,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
// returns true if all steps succeed and the new session has been persisted, and
// fails otherwise.
func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
-	keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
+	keyIndex uint32, tower *Tower, lnAddr *lnwire.NetAddress) error {

	// Connect to the tower address using our generated session key.
	conn, err := n.cfg.Dial(sessionKey, lnAddr)

@ -456,26 +464,31 @@ func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
	rewardPkScript := createSessionReply.Data

	sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey())
-	clientSession := &wtdb.ClientSession{
+	dbClientSession := &wtdb.ClientSession{
		ClientSessionBody: wtdb.ClientSessionBody{
			TowerID:        tower.ID,
			KeyIndex:       keyIndex,
			Policy:         n.cfg.Policy,
			RewardPkScript: rewardPkScript,
		},
-		Tower:          tower,
-		SessionKeyECDH: sessionKey,
-		ID:             sessionID,
+		ID: sessionID,
	}

-	err = n.cfg.DB.CreateClientSession(clientSession)
+	err = n.cfg.DB.CreateClientSession(dbClientSession)
	if err != nil {
		return fmt.Errorf("unable to persist ClientSession: %v",
			err)
	}

	n.log.Debugf("New session negotiated with %s, policy: %s",
-		lnAddr, clientSession.Policy)
+		lnAddr, dbClientSession.Policy)

+	clientSession := &ClientSession{
+		ID:                sessionID,
+		ClientSessionBody: dbClientSession.ClientSessionBody,
+		Tower:             tower,
+		SessionKeyECDH:    sessionKey,
+	}
+
	// We have a newly negotiated session, return it to the
	// dispatcher so that it can update how many outstanding
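Editorial aside, not part of the diff: the hunk above shows the pattern this PR introduces of keeping two session views, a `wtdb.ClientSession` containing only what is persisted, and a wtclient-side `ClientSession` that additionally carries runtime-only values (the tower record and the derived session key). A minimal standalone sketch of that split, with invented simplified types standing in for the real ones:

```go
package main

import "fmt"

// sessionBody models the serialized portion of a session (compare
// wtdb.ClientSessionBody): everything here is safe to write to disk.
type sessionBody struct {
	TowerID  uint64
	KeyIndex uint32
}

// dbSession models the persisted record: the body plus the session ID,
// with no runtime-only fields.
type dbSession struct {
	ID string
	sessionBody
}

// runtimeSession models the client-side view: it embeds the same persisted
// body but also carries values recovered or derived at startup and never
// serialized (compare the Tower and SessionKeyECDH fields in the diff).
type runtimeSession struct {
	ID string
	sessionBody

	Tower      string // stand-in for *Tower, looked up via TowerID
	SessionKey string // stand-in for the key derived from KeyIndex
}

func main() {
	db := dbSession{ID: "abc", sessionBody: sessionBody{TowerID: 7, KeyIndex: 1}}

	// On load, the client hydrates the runtime view from the persisted body.
	rt := runtimeSession{
		ID:          db.ID,
		sessionBody: db.sessionBody,
		Tower:       "tower-7",
		SessionKey:  "derived-from-key-index-1",
	}
	fmt.Printf("%+v\n", rt)
}
```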
@ -34,7 +34,7 @@ const (
type sessionQueueConfig struct {
	// ClientSession provides access to the negotiated session parameters
	// and updating its persistent storage.
-	ClientSession *wtdb.ClientSession
+	ClientSession *ClientSession

	// ChainHash identifies the chain for which the session's justice
	// transactions are targeted.

@ -97,7 +97,7 @@ type sessionQueue struct {
	queueCond *sync.Cond

	localInit *wtwire.Init
-	towerAddr *lnwire.NetAddress
+	tower     *Tower

	seqNum uint16

@ -117,18 +117,13 @@ func newSessionQueue(cfg *sessionQueueConfig,
		cfg.ChainHash,
	)

-	towerAddr := &lnwire.NetAddress{
-		IdentityKey: cfg.ClientSession.Tower.IdentityKey,
-		Address:     cfg.ClientSession.Tower.Addresses[0],
-	}
-
	sq := &sessionQueue{
		cfg:          cfg,
		log:          cfg.Log,
		commitQueue:  list.New(),
		pendingQueue: list.New(),
		localInit:    localInit,
-		towerAddr:    towerAddr,
+		tower:        cfg.ClientSession.Tower,
		seqNum:       cfg.ClientSession.SeqNum,
		retryBackoff: cfg.MinBackoff,
		quit:         make(chan struct{}),
@ -293,18 +288,48 @@ func (q *sessionQueue) sessionManager() {

// drainBackups attempts to send all pending updates in the queue to the tower.
func (q *sessionQueue) drainBackups() {
-	// First, check that we are able to dial this session's tower.
-	conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionKeyECDH, q.towerAddr)
-	if err != nil {
-		q.log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v",
-			q.ID(), q.towerAddr, err)
+	var (
+		conn      wtserver.Peer
+		err       error
+		towerAddr = q.tower.Addresses.Peek()
+	)

-		q.increaseBackoff()
-		select {
-		case <-time.After(q.retryBackoff):
-		case <-q.forceQuit:
-		}
-		return
+	for {
+		q.log.Infof("SessionQueue(%s) attempting to dial tower at %v",
+			q.ID(), towerAddr)
+
+		// First, check that we are able to dial this session's tower.
+		conn, err = q.cfg.Dial(
+			q.cfg.ClientSession.SessionKeyECDH, &lnwire.NetAddress{
+				IdentityKey: q.tower.IdentityKey,
+				Address:     towerAddr,
+			},
+		)
+		if err != nil {
+			// If there are more addrs available, immediately try
+			// those.
+			nextAddr, iteratorErr := q.tower.Addresses.Next()
+			if iteratorErr == nil {
+				towerAddr = nextAddr
+				continue
+			}
+
+			// Otherwise, if we have exhausted the address list,
+			// back off and try again later.
+			q.tower.Addresses.Reset()
+
+			q.log.Errorf("SessionQueue(%s) unable to dial tower "+
+				"at any available Addresses: %v", q.ID(), err)
+
+			q.increaseBackoff()
+			select {
+			case <-time.After(q.retryBackoff):
+			case <-q.forceQuit:
+			}
+			return
+		}
+
+		break
	}
	defer conn.Close()
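Editorial aside, not part of the diff: a detail worth calling out in the new dial loop is the cancelable backoff. Once every address has failed, the queue resets the iterator, sleeps via `time.After`, and aborts promptly if the force-quit channel fires first. A minimal standalone sketch of that shape, assuming a plain address slice and toy `dial` function rather than the real `wtserver.Peer` and iterator types:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// dialFirstReachable walks addrs until one dials successfully. If every
// address fails, it waits out backoff but returns immediately if quit
// closes first, mirroring the select on retryBackoff/forceQuit above.
func dialFirstReachable(addrs []string, backoff time.Duration,
	quit <-chan struct{}, dial func(string) error) (string, error) {

	for _, addr := range addrs {
		if err := dial(addr); err != nil {
			fmt.Printf("dial %s failed: %v\n", addr, err)
			continue
		}
		return addr, nil
	}

	// All addresses failed: back off, but stay responsive to shutdown.
	select {
	case <-time.After(backoff):
	case <-quit:
	}
	return "", errors.New("no reachable address")
}

func main() {
	quit := make(chan struct{})
	addr, err := dialFirstReachable(
		[]string{"tower-a:9911", "tower-b:9911"},
		100*time.Millisecond,
		quit,
		func(string) error { return errors.New("refused") },
	)
	fmt.Println(addr, err)
}
```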
@ -324,9 +349,7 @@ func (q *sessionQueue) drainBackups() {
		}

		// Now, send the state update to the tower and wait for a reply.
-		err = q.sendStateUpdate(
-			conn, stateUpdate, q.localInit, sendInit, isPending,
-		)
+		err = q.sendStateUpdate(conn, stateUpdate, sendInit, isPending)
		if err != nil {
			q.log.Errorf("SessionQueue(%s) unable to send state "+
				"update: %v", q.ID(), err)

@ -483,8 +506,12 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool,
// variable indicates whether we should back off before attempting to send the
// next state update.
func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
-	stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init,
-	sendInit, isPending bool) error {
+	stateUpdate *wtwire.StateUpdate, sendInit, isPending bool) error {

+	towerAddr := &lnwire.NetAddress{
+		IdentityKey: conn.RemotePub(),
+		Address:     conn.RemoteAddr(),
+	}
+
	// If this is the first message being sent to the tower, we must send an
	// Init message to establish that server supports the features we

@ -505,7 +532,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
	remoteInit, ok := remoteMsg.(*wtwire.Init)
	if !ok {
		return fmt.Errorf("watchtower %s responded with %T "+
-			"to Init", q.towerAddr, remoteMsg)
+			"to Init", towerAddr, remoteMsg)
	}

	// Validate Init.

@ -532,7 +559,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
	stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply)
	if !ok {
		return fmt.Errorf("watchtower %s responded with %T to "+
-			"StateUpdate", q.towerAddr, remoteMsg)
+			"StateUpdate", towerAddr, remoteMsg)
	}

	// Process the reply from the tower.

@ -547,8 +574,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
		err := fmt.Errorf("received error code %v in "+
			"StateUpdateReply for seqnum=%d",
			stateUpdateReply.Code, stateUpdate.SeqNum)
-		q.log.Warnf("SessionQueue(%s) unable to upload state update to "+
-			"tower=%s: %v", q.ID(), q.towerAddr, err)
+		q.log.Warnf("SessionQueue(%s) unable to upload state update "+
+			"to tower=%s: %v", q.ID(), towerAddr, err)
		return err
	}

@ -559,7 +586,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
		// TODO(conner): borked watchtower
		err = fmt.Errorf("unable to ack seqnum=%d: %v",
			stateUpdate.SeqNum, err)
-		q.log.Errorf("SessionQueue(%v) failed to ack update: %v", q.ID(), err)
+		q.log.Errorf("SessionQueue(%v) failed to ack update: %v",
+			q.ID(), err)
		return err

	case err == wtdb.ErrLastAppliedReversion:
@ -429,7 +429,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
		}

		towerSessions, err := listTowerSessions(
-			towerID, sessions, towers, towersToSessionsIndex,
+			towerID, sessions, towersToSessionsIndex,
			WithPerCommittedUpdate(perCommittedUpdate),
		)
		if err != nil {

@ -766,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
		// known to the db.
		if id == nil {
			clientSessions, err = listClientAllSessions(
-				sessions, towers, opts...,
+				sessions, opts...,
			)
			return err
		}

@ -778,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
		}

		clientSessions, err = listTowerSessions(
-			*id, sessions, towers, towerToSessionIndex, opts...,
+			*id, sessions, towerToSessionIndex, opts...,
		)
		return err
	}, func() {

@ -792,7 +792,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
}

// listClientAllSessions returns the set of all client sessions known to the db.
-func listClientAllSessions(sessions, towers kvdb.RBucket,
+func listClientAllSessions(sessions kvdb.RBucket,
	opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {

	clientSessions := make(map[SessionID]*ClientSession)

@ -801,7 +801,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket,
		// the CommittedUpdates and AckedUpdates on startup to resume
		// committed updates and compute the highest known commit height
		// for each channel.
-		session, err := getClientSession(sessions, towers, k, opts...)
+		session, err := getClientSession(sessions, k, opts...)
		if err != nil {
			return err
		}

@ -819,7 +819,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket,

// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
-func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
+func listTowerSessions(id TowerID, sessionsBkt,
	towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
	map[SessionID]*ClientSession, error) {

@ -834,9 +834,7 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
		// the CommittedUpdates and AckedUpdates on startup to resume
		// committed updates and compute the highest known commit height
		// for each channel.
-		session, err := getClientSession(
-			sessionsBkt, towersBkt, k, opts...,
-		)
+		session, err := getClientSession(sessionsBkt, k, opts...)
		if err != nil {
			return err
		}

@ -1248,7 +1246,7 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
-func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
+func getClientSession(sessions kvdb.RBucket, idBytes []byte,
	opts ...ClientSessionListOption) (*ClientSession, error) {

	cfg := NewClientSessionCfg()

@ -1261,13 +1259,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
		return nil, err
	}

-	// Fetch the tower associated with this session.
-	tower, err := getTower(towers, session.TowerID.Bytes())
-	if err != nil {
-		return nil, err
-	}
-	session.Tower = tower
-
	// Can't fail because client session body has already been read.
	sessionBkt := sessions.NestedReadBucket(idBytes)
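Editorial aside, not part of the diff: the `ClientSessionListOption` / `WithPerCommittedUpdate` plumbing threaded through these hunks is the standard Go functional-options pattern. A minimal standalone sketch of that pattern, with invented names standing in for the real types:

```go
package main

import "fmt"

// listCfg collects optional callbacks for a listing operation; compare the
// config populated by ClientSessionListOption in the diff.
type listCfg struct {
	perItem func(string)
}

// listOption mutates a listCfg, mirroring ClientSessionListOption.
type listOption func(*listCfg)

// withPerItem registers a callback invoked for every listed item,
// analogous to WithPerCommittedUpdate.
func withPerItem(cb func(string)) listOption {
	return func(cfg *listCfg) { cfg.perItem = cb }
}

// list applies the options, then walks the items, firing the optional
// callback only when the caller asked for it.
func list(items []string, opts ...listOption) {
	cfg := &listCfg{}
	for _, opt := range opts {
		opt(cfg)
	}
	for _, it := range items {
		if cfg.perItem != nil {
			cfg.perItem(it)
		}
	}
}

func main() {
	list([]string{"session-1", "session-2"}, withPerItem(func(s string) {
		fmt.Println("visiting", s)
	}))
}
```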
@ -343,8 +343,11 @@ func testCreateTower(h *clientDBHarness) {
	h.loadTowerByID(20, wtdb.ErrTowerNotFound)

	tower := h.newTower()
-	require.Len(h.t, tower.LNAddrs(), 1)
-	towerAddr := tower.LNAddrs()[0]
+	require.Len(h.t, tower.Addresses, 1)
+	towerAddr := &lnwire.NetAddress{
+		IdentityKey: tower.IdentityKey,
+		Address:     tower.Addresses[0],
+	}

	// Load the tower from the database and assert that it matches the tower
	// we created.
@ -4,7 +4,6 @@ import (
	"fmt"
	"io"

-	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/blob"
	"github.com/lightningnetwork/lnd/watchtower/wtpolicy"

@ -36,19 +35,6 @@ type ClientSession struct {
	ID SessionID

	ClientSessionBody
-
-	// Tower holds the pubkey and address of the watchtower.
-	//
-	// NOTE: This value is not serialized. It is recovered by looking up the
-	// tower with TowerID.
-	Tower *Tower
-
-	// SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret
-	// key used to connect to the watchtower.
-	//
-	// NOTE: This value is not serialized. It is derived using the KeyIndex
-	// on startup to avoid storing private keys on disk.
-	SessionKeyECDH keychain.SingleKeyECDH
}

// ClientSessionBody represents the primary components of a ClientSession that
@ -7,7 +7,6 @@ import (
	"net"

	"github.com/btcsuite/btcd/btcec/v2"
-	"github.com/lightningnetwork/lnd/lnwire"
)

// TowerID is a unique 64-bit identifier allocated to each unique watchtower.

@ -77,23 +76,6 @@ func (t *Tower) RemoveAddress(addr net.Addr) {
	}
}

-// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
-// addresses. This can be used to have a client try multiple addresses for the
-// same Tower.
-//
-// NOTE: This method is NOT safe for concurrent use.
-func (t *Tower) LNAddrs() []*lnwire.NetAddress {
-	addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
-	for _, addr := range t.Addresses {
-		addrs = append(addrs, &lnwire.NetAddress{
-			IdentityKey: t.IdentityKey,
-			Address:     addr,
-		})
-	}
-
-	return addrs
-}
-
// String returns a user-friendly identifier of the tower.
func (t *Tower) String() string {
	pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed())
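Editorial aside, not part of the diff: removing `LNAddrs` is what makes the iterator's locking meaningful. Callers no longer take a snapshot of the address slice (which could go stale and, per its own NOTE, was not safe for concurrent use); they walk the live iterator instead, and an address that a negotiation is currently dialing cannot be removed out from under it. The sketch below models that behavior, described in the PR summary as allowing removal of a tower address only if the current session negotiation is not using it; the `lockedAddrs` type and its method names are invented for illustration and do not appear in the diff.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// lockedAddrs is a toy model of the locking semantics the iterator adds:
// an address that is locked (because a negotiation is dialing it) cannot
// be removed until the lock is released.
type lockedAddrs struct {
	mu     sync.Mutex
	locked map[string]bool
}

// Lock marks an address as in use by a dial attempt.
func (l *lockedAddrs) Lock(addr string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.locked[addr] = true
}

// Release clears the in-use mark once the dial attempt finishes.
func (l *lockedAddrs) Release(addr string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	delete(l.locked, addr)
}

// Remove refuses to delete an address while it is locked; the actual
// deletion from the underlying address list is elided here.
func (l *lockedAddrs) Remove(addr string) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.locked[addr] {
		return errors.New("address is currently in use")
	}
	return nil
}

func main() {
	l := &lockedAddrs{locked: make(map[string]bool)}
	l.Lock("tower-a:9911")
	fmt.Println(l.Remove("tower-a:9911")) // address is currently in use
	l.Release("tower-a:9911")
	fmt.Println(l.Remove("tower-a:9911")) // <nil>
}
```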
@ -231,7 +231,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
		if tower != nil && *tower != session.TowerID {
			continue
		}
-		session.Tower = m.towers[session.TowerID]
		sessions[session.ID] = &session

		if cfg.PerAckedUpdate != nil {