diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 48ec8b807..ca19670c6 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -89,6 +89,9 @@ https://github.com/lightningnetwork/lnd/pull/6963/) * [Fixed a flake in the TestBlockCacheMutexes unit test](https://github.com/lightningnetwork/lnd/pull/7029). +* [Create a helper function to wait for peer to come + online](https://github.com/lightningnetwork/lnd/pull/6931). + ## `lncli` * [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the @@ -119,6 +122,12 @@ https://github.com/lightningnetwork/lnd/pull/6963/) caller is expected to know that doing so with untrusted input is unsafe.](https://github.com/lightningnetwork/lnd/pull/6779) +* [test: replace defer cleanup with + `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864). + +* [test: fix loop variables being accessed in + closures](https://github.com/lightningnetwork/lnd/pull/7032). + ## Watchtowers * [Create a towerID-to-sessionID index in the wtclient DB to improve the @@ -131,14 +140,10 @@ https://github.com/lightningnetwork/lnd/pull/6963/) struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to improve the performance of fetching a `ClientSession` from the DB. -* [Create a helper function to wait for peer to come - online](https://github.com/lightningnetwork/lnd/pull/6931). - -* [test: replace defer cleanup with - `t.Cleanup`](https://github.com/lightningnetwork/lnd/pull/6864). - -* [test: fix loop variables being accessed in - closures](https://github.com/lightningnetwork/lnd/pull/7032). +* [Allow user to update tower address without requiring a restart. Also allow + the removal of a tower address if the current session negotiation is not + using the address in question]( + https://github.com/lightningnetwork/lnd/pull/7025) ### Tooling and documentation diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go new file mode 100644 index 000000000..87065c011 --- /dev/null +++ b/watchtower/wtclient/addr_iterator.go @@ -0,0 +1,344 @@ +package wtclient + +import ( + "container/list" + "errors" + "fmt" + "net" + "sync" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +var ( + // ErrAddressesExhausted signals that a addressIterator has cycled + // through all available addresses. + ErrAddressesExhausted = errors.New("exhausted all addresses") + + // ErrAddrInUse indicates that an address is locked and cannot be + // removed from the addressIterator. + ErrAddrInUse = errors.New("address in use") +) + +// AddressIterator handles iteration over a list of addresses. It strictly +// disallows the list of addresses it holds to be empty. It also allows callers +// to place locks on certain addresses in order to prevent other callers from +// removing the addresses in question from the iterator. +type AddressIterator interface { + // Next returns the next candidate address. This iterator will always + // return candidates in the order given when the iterator was + // instantiated. If no more candidates are available, + // ErrAddressesExhausted is returned. + Next() (net.Addr, error) + + // NextAndLock does the same as described for Next, and it also places a + // lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + NextAndLock() (net.Addr, error) + + // Peek returns the currently selected address in the iterator. If the + // end of the iterator has been reached then it is reset and the first + // item in the iterator is returned. Since the AddressIterator will + // never have an empty address list, this function will never return a + // nil value. + Peek() net.Addr + + // PeekAndLock does the same as described for Peek, and it also places + // a lock on the returned address so that the address can not be removed + // until the lock on it has been released via ReleaseLock. + PeekAndLock() net.Addr + + // ReleaseLock releases the lock held on the given address. + ReleaseLock(addr net.Addr) + + // Add adds a new address to the iterator. + Add(addr net.Addr) + + // Remove removes an existing address from the iterator. It disallows + // the address from being removed if it is the last address in the + // iterator or if there is currently a lock on the address. + Remove(addr net.Addr) error + + // HasLocked returns true if the addressIterator has any locked + // addresses. + HasLocked() bool + + // GetAll returns a copy of all the addresses in the iterator. + GetAll() []net.Addr + + // Reset clears the iterators state, and makes the address at the front + // of the list the next item to be returned. + Reset() +} + +// A compile-time check to ensure that addressIterator implements the +// AddressIterator interface. +var _ AddressIterator = (*addressIterator)(nil) + +// addressIterator is a linked-list implementation of an AddressIterator. +type addressIterator struct { + mu sync.Mutex + addrList *list.List + currentTopAddr *list.Element + candidates map[string]*candidateAddr + totalLockCount int +} + +type candidateAddr struct { + addr net.Addr + numLocks int +} + +// newAddressIterator constructs a new addressIterator. +func newAddressIterator(addrs ...net.Addr) (*addressIterator, error) { + if len(addrs) == 0 { + return nil, fmt.Errorf("must have at least one address") + } + + iter := &addressIterator{ + addrList: list.New(), + candidates: make(map[string]*candidateAddr), + } + + for _, addr := range addrs { + addrID := addr.String() + iter.addrList.PushBack(addrID) + iter.candidates[addrID] = &candidateAddr{addr: addr} + } + iter.Reset() + + return iter, nil +} + +// Reset clears the iterators state, and makes the address at the front of the +// list the next item to be returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Reset() { + a.mu.Lock() + defer a.mu.Unlock() + + a.unsafeReset() +} + +// unsafeReset clears the iterator state and makes the address at the front of +// the list the next item to be returned. +// +// NOTE: this method is not thread safe and so should only be called if the +// appropriate mutex is being held. +func (a *addressIterator) unsafeReset() { + // Reset the next candidate to the front of the linked-list. + a.currentTopAddr = a.addrList.Front() +} + +// Next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Next() (net.Addr, error) { + return a.next(false) +} + +// NextAndLock does the same as described for Next, and it also places a lock on +// the returned address so that the address can not be removed until the lock on +// it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) NextAndLock() (net.Addr, error) { + return a.next(true) +} + +// next returns the next candidate address. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrAddressesExhausted is returned. +func (a *addressIterator) next(lock bool) (net.Addr, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // Set the next candidate to the subsequent element. + a.currentTopAddr = a.currentTopAddr.Next() + + for a.currentTopAddr != nil { + // Propose the address at the front of the list. + addrID := a.currentTopAddr.Value.(string) + + // Check whether this address is still considered a candidate. + // If it's not, we'll proceed to the next. + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr, nil + } + + return nil, ErrAddressesExhausted +} + +// Peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Peek() net.Addr { + return a.peek(false) +} + +// PeekAndLock does the same as described for Peek, and it also places a lock on +// the returned address so that the address can not be removed until the lock +// on it has been released via ReleaseLock. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) PeekAndLock() net.Addr { + return a.peek(true) +} + +// peek returns the currently selected address in the iterator. If the end of +// the list has been reached then the iterator is reset and the first item in +// the list is returned. Since the addressIterator will never have an empty +// address list, this function will never return a nil value. If lock is set to +// true, the address will be locked for removal until ReleaseLock has been +// called for the address. +func (a *addressIterator) peek(lock bool) net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + for { + // If currentTopAddr is nil, it means we have reached the end of + // the list, so we reset it here. The iterator always has at + // least one address, so we can be sure that currentTopAddr will + // be non-nil after calling reset here. + if a.currentTopAddr == nil { + a.unsafeReset() + } + + addrID := a.currentTopAddr.Value.(string) + candidate, ok := a.candidates[addrID] + if !ok { + nextCandidate := a.currentTopAddr.Next() + a.addrList.Remove(a.currentTopAddr) + a.currentTopAddr = nextCandidate + continue + } + + if lock { + candidate.numLocks++ + a.totalLockCount++ + } + + return candidate.addr + } +} + +// ReleaseLock releases the lock held on the given address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) ReleaseLock(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + candidateAddr, ok := a.candidates[addr.String()] + if !ok { + return + } + + if candidateAddr.numLocks == 0 { + return + } + + candidateAddr.numLocks-- + a.totalLockCount-- +} + +// Add adds a new address to the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Add(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.candidates[addr.String()]; ok { + return + } + + a.addrList.PushBack(addr.String()) + a.candidates[addr.String()] = &candidateAddr{addr: addr} + + // If we've reached the end of our queue, then this candidate + // will become the next. + if a.currentTopAddr == nil { + a.currentTopAddr = a.addrList.Back() + } +} + +// Remove removes an existing address from the iterator. It disallows the +// address from being removed if it is the last address in the iterator or if +// there is currently a lock on the address. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) Remove(addr net.Addr) error { + a.mu.Lock() + defer a.mu.Unlock() + + candidate, ok := a.candidates[addr.String()] + if !ok { + return nil + } + + if len(a.candidates) == 1 { + return wtdb.ErrLastTowerAddr + } + + if candidate.numLocks > 0 { + return ErrAddrInUse + } + + delete(a.candidates, addr.String()) + return nil +} + +// HasLocked returns true if the addressIterator has any locked addresses. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) HasLocked() bool { + a.mu.Lock() + defer a.mu.Unlock() + + return a.totalLockCount > 0 +} + +// GetAll returns a copy of all the addresses in the iterator. +// +// NOTE: This is part of the AddressIterator interface. +func (a *addressIterator) GetAll() []net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + var addrs []net.Addr + cursor := a.addrList.Front() + + for cursor != nil { + addrID := cursor.Value.(string) + + addr, ok := a.candidates[addrID] + if !ok { + cursor = cursor.Next() + continue + } + + addrs = append(addrs, addr.addr) + cursor = cursor.Next() + } + + return addrs +} diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go new file mode 100644 index 000000000..d3674d985 --- /dev/null +++ b/watchtower/wtclient/addr_iterator_test.go @@ -0,0 +1,188 @@ +package wtclient + +import ( + "net" + "testing" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" +) + +// TestAddrIterator tests the behaviour of the addressIterator. +func TestAddrIterator(t *testing.T) { + // Assert that an iterator can't be initialised with an empty address + // list. + _, err := newAddressIterator() + require.ErrorContains(t, err, "must have at least one address") + + addr1, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8000") + require.NoError(t, err) + + // Initialise the iterator with addr1. + iter, err := newAddressIterator(addr1) + require.NoError(t, err) + + // Attempting to remove addr1 should fail now since it is the only + // address in the iterator. + iter.Add(addr1) + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + // Adding a duplicate of addr1 and then calling Remove should still + // return an error. + err = iter.Remove(addr1) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) + + addr2, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8001") + require.NoError(t, err) + + // Add addr2 to the iterator. + iter.Add(addr2) + + // Check that peek returns addr1. + a1 := iter.Peek() + require.NoError(t, err) + require.Equal(t, addr1, a1) + + // Calling peek multiple times should return the same result. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Calling Next should now return addr2. + a2, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Assert that Peek now returns addr2. + a2 = iter.Peek() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + // Calling Next should result in reaching the end of th list. + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + // Calling Peek now should reset the queue and return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Wind the list to the end again so that we can test the Reset func. + _, err = iter.Next() + require.NoError(t, err) + + _, err = iter.Next() + require.ErrorIs(t, err, ErrAddressesExhausted) + + iter.Reset() + + // Now Next should return addr 2. + a2, err = iter.Next() + require.NoError(t, err) + require.Equal(t, addr2, a2) + + addr3, err := net.ResolveTCPAddr("tcp", "1.2.3.4:8002") + require.NoError(t, err) + + // Add addr3 now to ensure that the iteration works even if we are + // midway through the queue. + iter.Add(addr3) + + // Now Next should return addr 3. + a3, err := iter.Next() + require.NoError(t, err) + require.Equal(t, addr3, a3) + + // Quickly test that GetAll correctly returns a copy of all the + // addresses in the iterator. + addrList := iter.GetAll() + require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3}) + + // Let's now remove addr3. + err = iter.Remove(addr3) + require.NoError(t, err) + + // Since addr3 is gone, Peek should return addr1. + a1 = iter.Peek() + require.Equal(t, addr1, a1) + + // Lastly, we will test the "locking" of addresses. + + // First we test the locking of an address via the PeekAndLock function. + a1 = iter.PeekAndLock() + require.Equal(t, addr1, a1) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr1 if there is a lock on it. + err = iter.Remove(addr1) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr1. + iter.ReleaseLock(addr1) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr1) + require.NoError(t, err) + + // Now we test the locking of an address via the NextAndLock function. + // To do this, we first re-add addr3. + iter.Add(addr3) + + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Assert that we can't remove addr2 if there is a lock on it. + err = iter.Remove(addr2) + require.ErrorIs(t, err, ErrAddrInUse) + + // Now release the lock on addr2. + iter.ReleaseLock(addr2) + require.False(t, iter.HasLocked()) + + // Since the lock has been released, we should now be able to remove + // addr1. + err = iter.Remove(addr2) + require.NoError(t, err) + + // Only addr3 should still be left in the iterator. + addrList = iter.GetAll() + require.Len(t, addrList, 1) + require.Contains(t, addrList, addr3) + + // Ensure that HasLocked acts correctly in the case where more than one + // address is being locked and unlock as well as the case where the same + // address is locked more than once. + + require.False(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + a3 = iter.PeekAndLock() + require.Equal(t, addr3, a3) + require.True(t, iter.HasLocked()) + + iter.Add(addr2) + a2, err = iter.NextAndLock() + require.NoError(t, err) + require.Equal(t, addr2, a2) + require.True(t, iter.HasLocked()) + + // Now release addr2 and asset that HasLock is still true. + iter.ReleaseLock(addr2) + require.True(t, iter.HasLocked()) + + // Releasing one of the locks on addr3 now should still result in + // HasLocked returning true. + iter.ReleaseLock(addr3) + require.True(t, iter.HasLocked()) + + // Releasing it again should now result in should still result in + // HasLocked returning false. + iter.ReleaseLock(addr3) + require.False(t, iter.HasLocked()) +} diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 7d3178f3e..c536c433b 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -2,9 +2,6 @@ package wtclient import ( "bytes" - "crypto/rand" - "io" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -12,7 +9,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -54,14 +50,6 @@ var ( } ) -func makeAddrSlice(size int) []byte { - addr := make([]byte, size) - if _, err := io.ReadFull(rand.Reader, addr); err != nil { - panic("cannot make addr") - } - return addr -} - type backupTaskTest struct { name string chanID lnwire.ChannelID @@ -502,35 +490,12 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that all parameters set during initialization are properly // populated. - if task.id.ChanID != test.chanID { - t.Fatalf("channel id mismatch, want: %s, got: %s", - test.chanID, task.id.ChanID) - } - - if task.id.CommitHeight != test.breachInfo.RevokedStateNum { - t.Fatalf("commit height mismatch, want: %d, got: %d", - test.breachInfo.RevokedStateNum, task.id.CommitHeight) - } - - if task.totalAmt != test.expTotalAmt { - t.Fatalf("total amount mismatch, want: %d, got: %v", - test.expTotalAmt, task.totalAmt) - } - - if !reflect.DeepEqual(task.breachInfo, test.breachInfo) { - t.Fatalf("breach info mismatch, want: %v, got: %v", - test.breachInfo, task.breachInfo) - } - - if !reflect.DeepEqual(task.toLocalInput, test.expToLocalInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToLocalInput, task.toLocalInput) - } - - if !reflect.DeepEqual(task.toRemoteInput, test.expToRemoteInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToRemoteInput, task.toRemoteInput) - } + require.Equal(t, test.chanID, task.id.ChanID) + require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) + require.Equal(t, test.expTotalAmt, task.totalAmt) + require.Equal(t, test.breachInfo, task.breachInfo) + require.Equal(t, test.expToLocalInput, task.toLocalInput) + require.Equal(t, test.expToRemoteInput, task.toRemoteInput) // Reconstruct the expected input.Inputs that will be returned by the // task's inputs() method. @@ -545,34 +510,24 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the inputs method returns the correct slice of // input.Inputs. inputs := task.inputs() - if !reflect.DeepEqual(expInputs, inputs) { - t.Fatalf("inputs mismatch, want: %v, got: %v", - expInputs, inputs) - } + require.Equal(t, expInputs, inputs) // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. err := task.bindSession(test.session) - if err != test.bindErr { - t.Fatalf("expected: %v when binding session, got: %v", - test.bindErr, err) - } + require.ErrorIs(t, err, test.bindErr) // Exit early if the bind was supposed to fail. But first, we check that // all fields set during a bind are still unset. This ensure that a // failed bind doesn't have side-effects if the task is retried with a // different session. if test.bindErr != nil { - if task.blobType != 0 { - t.Fatalf("blob type should not be set on failed bind, "+ - "found: %s", task.blobType) - } + require.Zerof(t, task.blobType, "blob type should not be set "+ + "on failed bind, found: %s", task.blobType) - if task.outputs != nil { - t.Fatalf("justice outputs should not be set on failed bind, "+ - "found: %v", task.outputs) - } + require.Nilf(t, task.outputs, "justice outputs should not be "+ + " set on failed bind, found: %v", task.outputs) return } @@ -580,10 +535,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Otherwise, the binding succeeded. Assert that all values set during // the bind are properly populated. policy := test.session.Policy - if task.blobType != policy.BlobType { - t.Fatalf("blob type mismatch, want: %s, got %s", - policy.BlobType, task.blobType) - } + require.Equal(t, policy.BlobType, task.blobType) // Compute the expected outputs on the justice transaction. var expOutputs = []*wire.TxOut{ @@ -603,10 +555,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { } // Assert that the computed outputs match our expected outputs. - if !reflect.DeepEqual(expOutputs, task.outputs) { - t.Fatalf("justice txn output mismatch, want: %v,\ngot: %v", - spew.Sdump(expOutputs), spew.Sdump(task.outputs)) - } + require.Equal(t, expOutputs, task.outputs) // Now, we'll construct, sign, and encrypt the blob containing the parts // needed to reconstruct the justice transaction. @@ -616,10 +565,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Verify that the breach hint matches the breach txid's prefix. breachTxID := test.breachInfo.BreachTxHash expHint := blob.NewBreachHintFromHash(&breachTxID) - if hint != expHint { - t.Fatalf("breach hint mismatch, want: %x, got: %v", - expHint, hint) - } + require.Equal(t, expHint, hint) // Decrypt the return blob to obtain the JusticeKit containing its // contents. @@ -634,14 +580,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the blob contained the serialized revocation and to-local // pubkeys. - if !bytes.Equal(jKit.RevocationPubKey[:], expRevPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expRevPK, jKit.RevocationPubKey[:]) - } - if !bytes.Equal(jKit.LocalDelayPubKey[:], expToLocalPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expToLocalPK, jKit.LocalDelayPubKey[:]) - } + require.Equal(t, expRevPK, jKit.RevocationPubKey[:]) + require.Equal(t, expToLocalPK, jKit.LocalDelayPubKey[:]) // Determine if the breach transaction has a to-remote output and/or // to-local output to spend from. Note the seemingly-reversed @@ -650,32 +590,19 @@ func testBackupTask(t *testing.T, test backupTaskTest) { hasToLocal := test.breachInfo.RemoteOutputSignDesc != nil // If the to-remote output is present, assert that the to-remote public - // key was included in the blob. - if hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], expToRemotePK) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - expToRemotePK, jKit.CommitToRemotePubKey) - } - - // Otherwise if the to-local output is not present, assert that a blank - // public key was inserted. - if !hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], zeroPK[:]) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - zeroPK, jKit.CommitToRemotePubKey) + // key was included in the blob. Otherwise assert that a blank public + // key was inserted. + if hasToRemote { + require.Equal(t, expToRemotePK, jKit.CommitToRemotePubKey[:]) + } else { + require.Equal(t, zeroPK[:], jKit.CommitToRemotePubKey[:]) } // Assert that the CSV is encoded in the blob. - if jKit.CSVDelay != test.breachInfo.RemoteDelay { - t.Fatalf("mismatch remote delay, want: %d, got: %v", - test.breachInfo.RemoteDelay, jKit.CSVDelay) - } + require.Equal(t, test.breachInfo.RemoteDelay, jKit.CSVDelay) // Assert that the sweep pkscript is included. - if !bytes.Equal(jKit.SweepAddress, test.expSweepScript) { - t.Fatalf("sweep pkscript mismatch, want: %x, got: %x", - test.expSweepScript, jKit.SweepAddress) - } + require.Equal(t, test.expSweepScript, jKit.SweepAddress) // Finally, verify that the signatures are encoded in the justice kit. // We don't validate the actual signatures produced here, since at the @@ -684,18 +611,20 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // TODO(conner): include signature validation checks emptyToLocalSig := bytes.Equal(jKit.CommitToLocalSig[:], zeroSig[:]) - switch { - case hasToLocal && emptyToLocalSig: - t.Fatalf("to-local signature should not be empty") - case !hasToLocal && !emptyToLocalSig: - t.Fatalf("to-local signature should be empty") + if hasToLocal { + require.False(t, emptyToLocalSig, "to-local signature should "+ + "not be empty") + } else { + require.True(t, emptyToLocalSig, "to-local signature should "+ + "be empty") } emptyToRemoteSig := bytes.Equal(jKit.CommitToRemoteSig[:], zeroSig[:]) - switch { - case hasToRemote && emptyToRemoteSig: - t.Fatalf("to-remote signature should not be empty") - case !hasToRemote && !emptyToRemoteSig: - t.Fatalf("to-remote signature should be empty") + if hasToRemote { + require.False(t, emptyToRemoteSig, "to-remote signature "+ + "should not be empty") + } else { + require.True(t, emptyToRemoteSig, "to-remote signature "+ + "should be empty") } } diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index 5b48a68ef..faf3169c6 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -13,7 +13,7 @@ import ( type TowerCandidateIterator interface { // AddCandidate adds a new candidate tower to the iterator. If the // candidate already exists, then any new addresses are added to it. - AddCandidate(*wtdb.Tower) + AddCandidate(*Tower) // RemoveCandidate removes an existing candidate tower from the // iterator. An optional address can be provided to indicate a stale @@ -32,7 +32,7 @@ type TowerCandidateIterator interface { // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. - Next() (*wtdb.Tower, error) + Next() (*Tower, error) } // towerListIterator is a linked-list backed TowerCandidateIterator. @@ -40,7 +40,7 @@ type towerListIterator struct { mu sync.Mutex queue *list.List nextCandidate *list.Element - candidates map[wtdb.TowerID]*wtdb.Tower + candidates map[wtdb.TowerID]*Tower } // Compile-time constraint to ensure *towerListIterator implements the @@ -49,10 +49,10 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil) // newTowerListIterator initializes a new towerListIterator from a variadic list // of lnwire.NetAddresses. -func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { +func newTowerListIterator(candidates ...*Tower) *towerListIterator { iter := &towerListIterator{ queue: list.New(), - candidates: make(map[wtdb.TowerID]*wtdb.Tower), + candidates: make(map[wtdb.TowerID]*Tower), } for _, candidate := range candidates { @@ -79,7 +79,7 @@ func (t *towerListIterator) Reset() error { // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. -func (t *towerListIterator) Next() (*wtdb.Tower, error) { +func (t *towerListIterator) Next() (*Tower, error) { t.mu.Lock() defer t.mu.Unlock() @@ -107,7 +107,7 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) { // AddCandidate adds a new candidate tower to the iterator. If the candidate // already exists, then any new addresses are added to it. -func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { +func (t *towerListIterator) AddCandidate(candidate *Tower) { t.mu.Lock() defer t.mu.Unlock() @@ -121,8 +121,16 @@ func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { t.nextCandidate = t.queue.Back() } } else { - for _, addr := range candidate.Addresses { - tower.AddAddress(addr) + candidate.Addresses.Reset() + firstAddr := candidate.Addresses.Peek() + tower.Addresses.Add(firstAddr) + for { + next, err := candidate.Addresses.Next() + if err != nil { + break + } + + tower.Addresses.Add(next) } } } @@ -142,11 +150,15 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID, return nil } if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr + err := tower.Addresses.Remove(addr) + if err != nil { + return err } } else { + if tower.Addresses.HasLocked() { + return ErrAddrInUse + } + delete(t.candidates, candidate) } diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 99547d794..7fe6ba723 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -4,12 +4,10 @@ import ( "encoding/binary" "math/rand" "net" - "reflect" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/stretchr/testify/require" ) @@ -19,66 +17,75 @@ func init() { } func randAddr(t *testing.T) net.Addr { - var ip [4]byte - if _, err := rand.Read(ip[:]); err != nil { - t.Fatal(err) - } - var port [2]byte - if _, err := rand.Read(port[:]); err != nil { - t.Fatal(err) + t.Helper() + + var ip [4]byte + _, err := rand.Read(ip[:]) + require.NoError(t, err) + + var port [2]byte + _, err = rand.Read(port[:]) + require.NoError(t, err) - } return &net.TCPAddr{ IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(port[:])), } } -func randTower(t *testing.T) *wtdb.Tower { +func randTower(t *testing.T) *Tower { + t.Helper() + priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") pubKey := priv.PubKey() - return &wtdb.Tower{ + addrs, err := newAddressIterator(randAddr(t)) + require.NoError(t, err) + + return &Tower{ ID: wtdb.TowerID(rand.Uint64()), IdentityKey: pubKey, - Addresses: []net.Addr{randAddr(t)}, + Addresses: addrs, } } -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ +func copyTower(t *testing.T, tower *Tower) *Tower { + t.Helper() + + addrs := tower.Addresses.GetAll() + addrIterator, err := newAddressIterator(addrs...) + require.NoError(t, err) + + return &Tower{ ID: tower.ID, IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), + Addresses: addrIterator, } - copy(t.Addresses, tower.Addresses) - return t } -func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, - c *wtdb.Tower, active bool) { +func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *Tower, + active bool) { + + t.Helper() isCandidate := i.IsActive(c.ID) - if isCandidate && !active { - t.Fatalf("expected tower %v to no longer be an active candidate", - c.ID) - } - if !isCandidate && active { - t.Fatalf("expected tower %v to be an active candidate", c.ID) + if isCandidate { + require.Truef(t, active, "expected tower %v to no longer be "+ + "an active candidate", c.ID) + return } + require.Falsef(t, active, "expected tower %v to be an active candidate", + c.ID) } -func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { +func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { t.Helper() tower, err := i.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, c) { - t.Fatalf("expected tower: %v\ngot: %v", spew.Sdump(c), - spew.Sdump(tower)) - } + require.NoError(t, err) + require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) + require.Equal(t, tower.ID, c.ID) + require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -90,13 +97,13 @@ func TestTowerCandidateIterator(t *testing.T) { // towers. We'll use copies of these towers within the iterator to // ensure the iterator properly updates the state of its candidates. const numTowers = 4 - towers := make([]*wtdb.Tower, 0, numTowers) + towers := make([]*Tower, 0, numTowers) for i := 0; i < numTowers; i++ { towers = append(towers, randTower(t)) } - towerCopies := make([]*wtdb.Tower, 0, numTowers) + towerCopies := make([]*Tower, 0, numTowers) for _, tower := range towers { - towerCopies = append(towerCopies, copyTower(tower)) + towerCopies = append(towerCopies, copyTower(t, tower)) } towerIterator := newTowerListIterator(towerCopies...) @@ -104,28 +111,23 @@ func TestTowerCandidateIterator(t *testing.T) { // were added. for _, expTower := range towers { tower, err := towerIterator.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, expTower) { - t.Fatalf("expected tower: %v\ngot: %v", - spew.Sdump(expTower), spew.Sdump(tower)) - } + require.NoError(t, err) + require.Equal(t, expTower, tower) } - if _, err := towerIterator.Next(); err != ErrTowerCandidatesExhausted { - t.Fatalf("expected ErrTowerCandidatesExhausted, got %v", err) - } + _, err := towerIterator.Next() + require.ErrorIs(t, err, ErrTowerCandidatesExhausted) + towerIterator.Reset() // We'll then attempt to test the RemoveCandidate behavior of the - // iterator. We'll remove the address of the first tower, which should - // result in it not having any addresses left, but still being an active - // candidate. + // iterator. We'll attempt to remove the address of the first tower, + // which should result in an error due to it being the last address of + // the tower. firstTower := towers[0] - firstTowerAddr := firstTower.Addresses[0] - firstTower.RemoveAddress(firstTowerAddr) - towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + firstTowerAddr := firstTower.Addresses.Peek() + err = towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) assertActiveCandidate(t, towerIterator, firstTower, true) assertNextCandidate(t, towerIterator, firstTower) @@ -133,7 +135,8 @@ func TestTowerCandidateIterator(t *testing.T) { // not providing the optional address. Since it's been removed, we // should expect to see the third tower next. secondTower, thirdTower := towers[1], towers[2] - towerIterator.RemoveCandidate(secondTower.ID, nil) + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) assertActiveCandidate(t, towerIterator, secondTower, false) assertNextCandidate(t, towerIterator, thirdTower) @@ -142,7 +145,7 @@ func TestTowerCandidateIterator(t *testing.T) { // iterator, but the new address should be. fourthTower := towers[3] assertActiveCandidate(t, towerIterator, fourthTower, true) - fourthTower.AddAddress(randAddr(t)) + fourthTower.Addresses.Add(randAddr(t)) towerIterator.AddCandidate(fourthTower) assertNextCandidate(t, towerIterator, fourthTower) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f9514f8f1..3aa84f28c 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,7 +2,6 @@ package wtclient import ( "bytes" - "errors" "fmt" "net" "sync" @@ -45,8 +44,8 @@ const ( // genActiveSessionFilter generates a filter that selects active sessions that // also match the desired channel type, either legacy or anchor. -func genActiveSessionFilter(anchor bool) func(*wtdb.ClientSession) bool { - return func(s *wtdb.ClientSession) bool { +func genActiveSessionFilter(anchor bool) func(*ClientSession) bool { + return func(s *ClientSession) bool { return s.Status == wtdb.CSessionActive && anchor == s.Policy.IsAnchorChannel() } @@ -241,7 +240,7 @@ type TowerClient struct { negotiator SessionNegotiator candidateTowers TowerCandidateIterator - candidateSessions map[wtdb.SessionID]*wtdb.ClientSession + candidateSessions map[wtdb.SessionID]*ClientSession activeSessions sessionQueueSet sessionQueue *sessionQueue @@ -351,7 +350,7 @@ func New(config *Config) (*TowerClient, error) { activeSessionFilter := genActiveSessionFilter(isAnchorClient) candidateTowers := newTowerListIterator() - perActiveTower := func(tower *wtdb.Tower) { + perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is // no need to add it to the iterator again. if candidateTowers.IsActive(tower.ID) { @@ -400,18 +399,23 @@ func New(config *Config) (*TowerClient, error) { // sessionFilter check then the perActiveTower call-back will be called on that // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, - sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower), + sessionFilter func(*ClientSession) bool, + perActiveTower func(tower *Tower), opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { towers, err := db.ListTowers() if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, tower := range towers { + candidateSessions := make(map[wtdb.SessionID]*ClientSession) + for _, dbTower := range towers { + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err @@ -427,16 +431,24 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH( + + sessionKeyECDH := keychain.NewPubKeyECDH( towerKeyDesc, keyRing, ) - if !sessionFilter(s) { + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } + + if !sessionFilter(cs) { continue } // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = s + candidateSessions[s.ID] = cs perActiveTower(tower) } } @@ -452,11 +464,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool, + passesFilter func(*ClientSession) bool, opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { - sessions, err := db.ListClientSessions(forTower, opts...) + dbSessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -466,7 +478,13 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // be able to communicate with the towers and authenticate session // requests. This prevents us from having to store the private keys on // disk. - for _, s := range sessions { + sessions := make(map[wtdb.SessionID]*ClientSession) + for _, s := range dbSessions { + dbTower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, @@ -474,13 +492,27 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } // If an optional filter was provided, use it to filter out any // undesired sessions. - if passesFilter != nil && !passesFilter(s) { - delete(sessions, s.ID) + if passesFilter != nil && !passesFilter(cs) { + continue } + + sessions[s.ID] = cs } return sessions, nil @@ -710,7 +742,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. - var candidateSession *wtdb.ClientSession + var candidateSession *ClientSession for id, sessionInfo := range c.candidateSessions { delete(c.candidateSessions, id) @@ -793,13 +825,10 @@ func (c *TowerClient) backupDispatcher() { msg.errChan <- c.handleNewTower(msg) // A tower has been requested to be removed. We'll - // immediately return an error as we want to avoid the - // possibility of a new session being negotiated with - // this request's tower. + // only allow removal of it if the address in question + // is not currently being used for session negotiation. case msg := <-c.staleTowers: - msg.errChan <- errors.New("removing towers " + - "is disallowed while a new session " + - "negotiation is in progress") + msg.errChan <- c.handleStaleTower(msg) case <-c.forceQuit: return @@ -1069,7 +1098,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, +func (c *TowerClient) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ @@ -1089,7 +1118,7 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise, a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) getOrInitActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { if sq, ok := c.activeSessions[s.ID]; ok { @@ -1103,7 +1132,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) initActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { // Initialize the session queue, providing it with all the resources it @@ -1156,10 +1185,16 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // We'll start by updating our persisted state, followed by our // in-memory state, with the new tower. This might not actually be a new // tower, but it might include a new address at which it can be reached. - tower, err := c.cfg.DB.CreateTower(msg.addr) + dbTower, err := c.cfg.DB.CreateTower(msg.addr) if err != nil { return err } + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return err + } + c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. @@ -1215,18 +1250,31 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // We'll load the tower before potentially removing it in order to // retrieve its ID within the database. - tower, err := c.cfg.DB.LoadTower(msg.pubKey) + dbTower, err := c.cfg.DB.LoadTower(msg.pubKey) if err != nil { return err } - // We'll update our persisted state, followed by our in-memory state, - // with the stale tower. - if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil { + // We'll first update our in-memory state followed by our persisted + // state, with the stale tower. The removal of the tower address from + // the in-memory state will fail if the address is currently being used + // for a session negotiation. + err = c.candidateTowers.RemoveCandidate(dbTower.ID, msg.addr) + if err != nil { return err } - err = c.candidateTowers.RemoveCandidate(tower.ID, msg.addr) - if err != nil { + + if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil { + // If the persisted state update fails, re-add the address to + // our in-memory state. + tower, newTowerErr := NewTowerFromDBTower(dbTower) + if newTowerErr != nil { + log.Errorf("could not create new in-memory tower: %v", + newTowerErr) + } else { + c.candidateTowers.AddCandidate(tower) + } + return err } @@ -1239,7 +1287,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // Otherwise, the tower should no longer be used for future session // negotiations and backups. pubKey := msg.pubKey.SerializeCompressed() - sessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID) if err != nil { return fmt.Errorf("unable to retrieve sessions for tower %x: "+ "%v", pubKey, err) @@ -1251,7 +1299,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // If our active session queue corresponds to the stale tower, we'll // proceed to negotiate a new one. if c.sessionQueue != nil { - activeTower := c.sessionQueue.towerAddr.IdentityKey.SerializeCompressed() + activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed() if bytes.Equal(pubKey, activeTower) { c.sessionQueue = nil } diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index fcaad1588..1490c6d10 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2,6 +2,8 @@ package wtclient_test import ( "encoding/binary" + "errors" + "fmt" "net" "sync" "testing" @@ -16,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tor" @@ -31,7 +34,8 @@ import ( const ( csvDelay uint32 = 144 - towerAddrStr = "18.28.243.2:9911" + towerAddrStr = "18.28.243.2:9911" + towerAddr2Str = "19.29.244.3:9912" ) var ( @@ -63,6 +67,8 @@ var ( ) addrScript, _ = txscript.PayToAddrScript(addr) + + waitTime = 5 * time.Second ) // randPrivKey generates a new secp keypair, and returns the public key. @@ -76,37 +82,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey { } type mockNet struct { - mu sync.RWMutex - connCallback func(wtserver.Peer) + mu sync.RWMutex + connCallbacks map[string]func(wtserver.Peer) } -func newMockNet(cb func(wtserver.Peer)) *mockNet { +func newMockNet() *mockNet { return &mockNet{ - connCallback: cb, + connCallbacks: make(map[string]func(peer wtserver.Peer)), } } -func (m *mockNet) Dial(network string, address string, - timeout time.Duration) (net.Conn, error) { - +func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) { return nil, nil } -func (m *mockNet) LookupHost(host string) ([]string, error) { +func (m *mockNet) LookupHost(_ string) ([]string, error) { panic("not implemented") } -func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) { +func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) { panic("not implemented") } -func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) { +func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) { panic("not implemented") } func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, - netAddr *lnwire.NetAddress, - dialer tor.DialFunc) (wtserver.Peer, error) { + netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) { localPk := local.PubKey() localAddr := &net.TCPAddr{ @@ -119,16 +122,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, ) m.mu.RLock() - m.connCallback(remotePeer) - m.mu.RUnlock() + defer m.mu.RUnlock() + cb, ok := m.connCallbacks[netAddr.String()] + if !ok { + return nil, fmt.Errorf("no callback registered for this peer") + } + + cb(remotePeer) return localPeer, nil } -func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) { +func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress, + cb func(wtserver.Peer)) { + m.mu.Lock() defer m.mu.Unlock() - m.connCallback = cb + + m.connCallbacks[netAddr.String()] = cb +} + +func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.connCallbacks, netAddr.String()) } type mockChannel struct { @@ -325,10 +343,8 @@ func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.localBalance < amt { - t.Fatalf("insufficient funds to send, need: %v, have: %v", - amt, c.localBalance) - } + require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+ + "to send, need: %v, have: %v", amt, c.localBalance) c.localBalance -= amt c.remoteBalance += amt @@ -343,10 +359,8 @@ func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.remoteBalance < amt { - t.Fatalf("insufficient funds to recv, need: %v, have: %v", - amt, c.remoteBalance) - } + require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+ + "to recv, need: %v, have: %v", amt, c.remoteBalance) c.localBalance += amt c.remoteBalance -= amt @@ -381,6 +395,8 @@ type testHarness struct { mu sync.Mutex channels map[lnwire.ChannelID]*mockChannel + + quit chan struct{} } type harnessCfg struct { @@ -389,6 +405,7 @@ type harnessCfg struct { policy wtpolicy.Policy noRegisterChan0 bool noAckCreateSession bool + noServerStart bool } func newHarness(t *testing.T, cfg harnessCfg) *testHarness { @@ -420,11 +437,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NoAckCreateSession: cfg.noAckCreateSession, } - server, err := wtserver.New(serverCfg) - require.NoError(t, err, "unable to create wtserver") - signer := wtmock.NewMockSigner() - mockNet := newMockNet(server.InboundPeerConnected) + mockNet := newMockNet() clientDB := wtmock.NewClientDB() clientCfg := &wtclient.Config{ @@ -443,21 +457,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { MaxBackoff: time.Second, ForceQuitDelay: 10 * time.Second, } - client, err := wtclient.New(clientCfg) - require.NoError(t, err, "Unable to create wtclient") - - if err := server.Start(); err != nil { - t.Fatalf("Unable to start wtserver: %v", err) - } - - if err = client.Start(); err != nil { - server.Stop() - t.Fatalf("Unable to start wtclient: %v", err) - } - if err := client.AddTower(towerAddr); err != nil { - server.Stop() - t.Fatalf("Unable to add tower to wtclient: %v", err) - } h := &testHarness{ t: t, @@ -466,14 +465,24 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { capacity: cfg.localBalance + cfg.remoteBalance, clientDB: clientDB, clientCfg: clientCfg, - client: client, serverAddr: towerAddr, serverDB: serverDB, serverCfg: serverCfg, - server: server, net: mockNet, channels: make(map[lnwire.ChannelID]*mockChannel), + quit: make(chan struct{}), } + t.Cleanup(func() { + close(h.quit) + }) + + if !cfg.noServerStart { + h.startServer() + t.Cleanup(h.stopServer) + } + + h.startClient() + t.Cleanup(h.client.ForceQuit) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) if !cfg.noRegisterChan0 { @@ -490,15 +499,20 @@ func (h *testHarness) startServer() { var err error h.server, err = wtserver.New(h.serverCfg) - if err != nil { - h.t.Fatalf("unable to create wtserver: %v", err) - } + require.NoError(h.t, err) - h.net.setConnCallback(h.server.InboundPeerConnected) + h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected) - if err := h.server.Start(); err != nil { - h.t.Fatalf("unable to start wtserver: %v", err) - } + require.NoError(h.t, h.server.Start()) +} + +// stopServer stops the main harness server. +func (h *testHarness) stopServer() { + h.t.Helper() + + h.net.removeConnCallback(h.serverAddr) + + require.NoError(h.t, h.server.Stop()) } // startClient creates a new server using the harness's current clientCf and @@ -507,24 +521,16 @@ func (h *testHarness) startClient() { h.t.Helper() towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr) - if err != nil { - h.t.Fatalf("Unable to resolve tower TCP addr: %v", err) - } + require.NoError(h.t, err) towerAddr := &lnwire.NetAddress{ IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(), Address: towerTCPAddr, } h.client, err = wtclient.New(h.clientCfg) - if err != nil { - h.t.Fatalf("unable to create wtclient: %v", err) - } - if err := h.client.Start(); err != nil { - h.t.Fatalf("unable to start wtclient: %v", err) - } - if err := h.client.AddTower(towerAddr); err != nil { - h.t.Fatalf("unable to add tower to wtclient: %v", err) - } + require.NoError(h.t, err) + require.NoError(h.t, h.client.Start()) + require.NoError(h.t, h.client.AddTower(towerAddr)) } // chanIDFromInt creates a unique channel id given a unique integral id. @@ -553,9 +559,7 @@ func (h *testHarness) makeChannel(id uint64, } c.mu.Unlock() - if ok { - h.t.Fatalf("channel %d already created", id) - } + require.Falsef(h.t, ok, "channel %d already created", id) } // channel retrieves the channel corresponding to id. @@ -567,9 +571,7 @@ func (h *testHarness) channel(id uint64) *mockChannel { h.mu.Lock() c, ok := h.channels[chanIDFromInt(id)] h.mu.Unlock() - if !ok { - h.t.Fatalf("unable to fetch channel %d", id) - } + require.Truef(h.t, ok, "unable to fetch channel %d", id) return c } @@ -580,9 +582,7 @@ func (h *testHarness) registerChannel(id uint64) { chanID := chanIDFromInt(id) err := h.client.RegisterChannel(chanID) - if err != nil { - h.t.Fatalf("unable to register channel %d: %v", id, err) - } + require.NoError(h.t, err) } // advanceChannelN calls advanceState on the channel identified by id the number @@ -621,11 +621,10 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit) - if err != expErr { - h.t.Fatalf("back error mismatch, want: %v, got: %v", - expErr, err) - } + err := h.client.BackupState( + &chanID, retribution, channeldb.SingleFunderBit, + ) + require.ErrorIs(h.t, expErr, err) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -685,10 +684,8 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, hintSet[hint] = struct{}{} } - if len(hints) != len(hintSet) { - h.t.Fatalf("breach hints are not unique, list-len: %d "+ - "set-len: %d", len(hints), len(hintSet)) - } + require.Lenf(h.t, hints, len(hintSet), "breach hints are not unique, "+ + "list-len: %d set-len: %d", len(hints), len(hintSet)) // Closure to assert the server's matches are consistent with the hint // set. @@ -698,12 +695,9 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, } for _, match := range matches { - if _, ok := hintSet[match.Hint]; ok { - continue - } - - h.t.Fatalf("match %v in db is not in hint set", - match.Hint) + _, ok := hintSet[match.Hint] + require.Truef(h.t, ok, "match %v in db is not in "+ + "hint set", match.Hint) } return true @@ -714,31 +708,24 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, select { case <-time.After(time.Second): matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) + require.NoError(h.t, err, "unable to query for hints") - case wantUpdates && serverHasHints(matches): + if wantUpdates && serverHasHints(matches) { return + } - case wantUpdates: + if wantUpdates { h.t.Logf("Received %d/%d\n", len(matches), len(hints)) } case <-failTimeout: matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) - - case serverHasHints(matches): - return - - default: - h.t.Fatalf("breach hints not received, only "+ - "got %d/%d", len(matches), len(hints)) - } + require.NoError(h.t, err, "unable to query for hints") + require.Truef(h.t, serverHasHints(matches), "breach "+ + "hints not received, only got %d/%d", + len(matches), len(hints)) + return } } } @@ -751,25 +738,18 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, // Query for matches on the provided hints. matches, err := h.serverDB.QueryMatches(hints) - if err != nil { - h.t.Fatalf("unable to query for matches: %v", err) - } + require.NoError(h.t, err) // Assert that the number of matches is exactly the number of provided // hints. - if len(matches) != len(hints) { - h.t.Fatalf("expected: %d matches, got: %d", len(hints), - len(matches)) - } + require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d", + len(hints), len(matches)) // Assert that all of the matches correspond to a session with the // expected policy. for _, match := range matches { matchPolicy := match.SessionInfo.Policy - if expPolicy != matchPolicy { - h.t.Fatalf("expected session to have policy: %v, "+ - "got: %v", expPolicy, matchPolicy) - } + require.Equal(h.t, expPolicy, matchPolicy) } } @@ -777,9 +757,8 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, func (h *testHarness) addTower(addr *lnwire.NetAddress) { h.t.Helper() - if err := h.client.AddTower(addr); err != nil { - h.t.Fatalf("unable to add tower: %v", err) - } + err := h.client.AddTower(addr) + require.NoError(h.t, err) } // removeTower removes a tower from the client. If `addr` is specified, then the @@ -787,9 +766,8 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) { func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { h.t.Helper() - if err := h.client.RemoveTower(pubKey, addr); err != nil { - h.t.Fatalf("unable to remove tower: %v", err) - } + err := h.client.RemoveTower(pubKey, addr) + require.NoError(h.t, err) } const ( @@ -976,10 +954,9 @@ var clientTests = []clientTest{ // Now, restart the server and prevent it from acking // state updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() - defer h.server.Stop() // Send the next state update to the tower. Since the // tower isn't acking state updates, we expect this @@ -997,15 +974,13 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked update. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() - defer h.server.Stop() // Restart the client and allow it to process the // committed update. h.startClient() - defer h.client.ForceQuit() // Wait for the committed update to be accepted by the // tower. @@ -1049,10 +1024,9 @@ var clientTests = []clientTest{ // Restart the server and prevent it from acking state // updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() - defer h.server.Stop() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) @@ -1068,14 +1042,13 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() - defer h.server.Stop() // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) }, }, { @@ -1212,23 +1185,21 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() - defer h.server.Stop() // Restart the client with the same policy, which will // immediately try to overwrite the old session with an // identical one. h.startClient() - defer h.client.ForceQuit() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. @@ -1270,24 +1241,22 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() - defer h.server.Stop() // Restart the client with a new policy, which will // immediately try to overwrite the prior session with // the old policy. h.clientCfg.Policy.SweepFeeRate *= 2 h.startClient() - defer h.client.ForceQuit() // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. @@ -1341,14 +1310,13 @@ var clientTests = []clientTest{ // Restart the client with a new policy. h.clientCfg.Policy.MaxUpdates = 20 h.startClient() - defer h.client.ForceQuit() // Now, queue the second half of the retributions. h.backupStates(chanID, numUpdates/2, numUpdates, nil) // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) // Assert that the server has updates for the client's // original policy. @@ -1389,13 +1357,12 @@ var clientTests = []clientTest{ // Wait for the first half of the updates to be // populated in the server's database. - h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second) + h.waitServerUpdates(hints[:len(hints)/2], waitTime) // Restart the client, so we can ensure the deduping is // maintained across restarts. h.client.Stop() h.startClient() - defer h.client.ForceQuit() // Try to back up the full range of retributions. Only // the second half should actually be sent. @@ -1403,7 +1370,7 @@ var clientTests = []clientTest{ // Wait for all of the updates to be populated in the // server's database. - h.waitServerUpdates(hints, 5*time.Second) + h.waitServerUpdates(hints, waitTime) }, }, { @@ -1431,7 +1398,7 @@ var clientTests = []clientTest{ // first two. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) - h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second) + h.waitServerUpdates(hints[:numUpdates/2], waitTime) // Fully remove the tower, causing its existing sessions // to be marked inactive. @@ -1445,8 +1412,7 @@ var clientTests = []clientTest{ // Re-add the tower. We prevent the tower from acking // session creation to ensure the inactive sessions are // not used. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() h.addTower(h.serverAddr) @@ -1455,11 +1421,10 @@ var clientTests = []clientTest{ // Finally, allow the tower to ack session creation, // allowing the state updates to be sent through the new // session. - err = h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() - h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) + h.waitServerUpdates(hints[numUpdates/2:], waitTime) }, }, { @@ -1490,13 +1455,12 @@ var clientTests = []clientTest{ // Back up 4 of the 5 states for the negotiated session. h.backupStates(chanID, 0, maxUpdates-1, nil) - h.waitServerUpdates(hints[:maxUpdates-1], 5*time.Second) + h.waitServerUpdates(hints[:maxUpdates-1], waitTime) // Now, restart the tower and prevent it from acking any // new sessions. We do this here as once the last slot // is exhausted the client will attempt to renegotiate. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() @@ -1506,15 +1470,189 @@ var clientTests = []clientTest{ // the final state. We'll only wait for the first five // states to arrive at the tower. h.backupStates(chanID, maxUpdates-1, numUpdates, nil) - h.waitServerUpdates(hints[:maxUpdates], 5*time.Second) + h.waitServerUpdates(hints[:maxUpdates], waitTime) // Finally, stop the client which will continue to // attempt session negotiation since it has one more // state to process. After the force quite delay // expires, the client should force quite itself and // allow the test to complete. - err = h.client.Stop() - require.Nil(h.t, err) + h.stopServer() + }, + }, + { + // Assert that if a client changes the address for a server and + // then tries to back up updates then the client will switch to + // the new address. + name: "change address of existing session", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 6 + maxUpdates = 5 + ) + + // Advance the channel to create all states. + hints := h.advanceChannelN(chanID, numUpdates) + + h.backupStates(chanID, 0, numUpdates/2, nil) + + // Wait for the first half of the updates to be + // populated in the server's database. + h.waitServerUpdates(hints[:len(hints)/2], waitTime) + + // Stop the server. + h.stopServer() + + // Change the address of the server. + towerTCPAddr, err := net.ResolveTCPAddr( + "tcp", towerAddr2Str, + ) + require.NoError(h.t, err) + + oldAddr := h.serverAddr.Address + towerAddr := &lnwire.NetAddress{ + IdentityKey: h.serverAddr.IdentityKey, + Address: towerTCPAddr, + } + h.serverAddr = towerAddr + + // Add the new tower address to the client. + err = h.client.AddTower(towerAddr) + require.NoError(h.t, err) + + // Remove the old tower address from the client. + err = h.client.RemoveTower( + towerAddr.IdentityKey, oldAddr, + ) + require.NoError(h.t, err) + + // Restart the server. + h.startServer() + + // Now attempt to back up the rest of the updates. + h.backupStates(chanID, numUpdates/2, maxUpdates, nil) + + // Assert that the server does receive the updates. + h.waitServerUpdates(hints[:maxUpdates], waitTime) + }, + }, + { + // Assert that a user is able to remove a tower address during + // session negotiation as long as the address in question is not + // currently being used. + name: "removing a tower during session negotiation", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + noServerStart: true, + }, + fn: func(h *testHarness) { + // The server has not started yet and so no session + // negotiation with the server will be in progress, so + // the client should be able to remove the server. + err := wait.NoError(func() error { + return h.client.RemoveTower( + h.serverAddr.IdentityKey, nil, + ) + }, waitTime) + require.NoError(h.t, err) + + // Set the server up so that its Dial function hangs + // when the client calls it. This will force the client + // to remain in the state where it has locked the + // address of the server. + h.server, err = wtserver.New(h.serverCfg) + require.NoError(h.t, err) + + cancel := make(chan struct{}) + h.net.registerConnCallback( + h.serverAddr, func(peer wtserver.Peer) { + select { + case <-h.quit: + case <-cancel: + } + }, + ) + + // Also add a new tower address. + towerTCPAddr, err := net.ResolveTCPAddr( + "tcp", towerAddr2Str, + ) + require.NoError(h.t, err) + + towerAddr := &lnwire.NetAddress{ + IdentityKey: h.serverAddr.IdentityKey, + Address: towerTCPAddr, + } + + // Register the new address in the mock-net. + h.net.registerConnCallback( + towerAddr, h.server.InboundPeerConnected, + ) + + // Now start the server. + require.NoError(h.t, h.server.Start()) + + // Re-add the server to the client + err = h.client.AddTower(h.serverAddr) + require.NoError(h.t, err) + + // Also add the new tower address. + err = h.client.AddTower(towerAddr) + require.NoError(h.t, err) + + // Assert that if the client attempts to remove the + // tower's first address, then it will error due to + // address currently being locked for session + // negotiation. + err = wait.Predicate(func() bool { + err = h.client.RemoveTower( + h.serverAddr.IdentityKey, + h.serverAddr.Address, + ) + return errors.Is(err, wtclient.ErrAddrInUse) + }, waitTime) + require.NoError(h.t, err) + + // Assert that the second address can be removed since + // it is not being used for session negotiation. + err = wait.NoError(func() error { + return h.client.RemoveTower( + h.serverAddr.IdentityKey, towerTCPAddr, + ) + }, waitTime) + require.NoError(h.t, err) + + // Allow the dial to the first address to stop hanging. + close(cancel) + + // Assert that the client can now remove the first + // address. + err = wait.NoError(func() error { + return h.client.RemoveTower( + h.serverAddr.IdentityKey, nil, + ) + }, waitTime) + require.NoError(h.t, err) }, }, } @@ -1528,10 +1666,6 @@ func TestClient(t *testing.T) { t.Parallel() h := newHarness(t, tc.cfg) - t.Cleanup(func() { - require.NoError(t, h.server.Stop()) - h.client.ForceQuit() - }) tc.fn(h) }) diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index 857af3087..f496074bf 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -20,10 +20,6 @@ var ( // down. ErrNegotiatorExiting = errors.New("negotiator exiting") - // ErrNoTowerAddrs signals that the client could not be created because - // we have no addresses with which we can reach a tower. - ErrNoTowerAddrs = errors.New("no tower addresses") - // ErrFailedNegotiation signals that the session negotiator could not // acquire a new session as requested. ErrFailedNegotiation = errors.New("session negotiation unsuccessful") diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5f2357950..ba6546328 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -118,3 +118,50 @@ type ECDHKeyRing interface { // key. DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) } + +// Tower represents the info about a watchtower server that a watchtower client +// needs in order to connect to it. +type Tower struct { + // ID is the unique, db-assigned, identifier for this tower. + ID wtdb.TowerID + + // IdentityKey is the public key of the remote node, used to + // authenticate the brontide transport. + IdentityKey *btcec.PublicKey + + // Addresses is an AddressIterator that can be used to manage the + // addresses for this tower. + Addresses AddressIterator +} + +// NewTowerFromDBTower converts a wtdb.Tower, which uses a static address list, +// into a Tower which uses an address iterator. +func NewTowerFromDBTower(t *wtdb.Tower) (*Tower, error) { + addrs, err := newAddressIterator(t.Addresses...) + if err != nil { + return nil, err + } + + return &Tower{ + ID: t.ID, + IdentityKey: t.IdentityKey, + Addresses: addrs, + }, nil +} + +// ClientSession represents the session that a tower client has with a server. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID wtdb.SessionID + + wtdb.ClientSessionBody + + // Tower represents the tower that the client session has been made + // with. + Tower *Tower + + // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret + // key used to connect to the watchtower. + SessionKeyECDH keychain.SingleKeyECDH +} diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 9ccaf5b79..db0f543a1 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -25,7 +25,7 @@ type SessionNegotiator interface { // NewSessions is a read-only channel where newly negotiated sessions // will be delivered. - NewSessions() <-chan *wtdb.ClientSession + NewSessions() <-chan *ClientSession // Start safely initializes the session negotiator. Start() error @@ -105,8 +105,8 @@ type sessionNegotiator struct { log btclog.Logger dispatcher chan struct{} - newSessions chan *wtdb.ClientSession - successfulNegotiations chan *wtdb.ClientSession + newSessions chan *ClientSession + successfulNegotiations chan *ClientSession wg sync.WaitGroup quit chan struct{} @@ -139,8 +139,8 @@ func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { log: cfg.Log, localInit: localInit, dispatcher: make(chan struct{}, 1), - newSessions: make(chan *wtdb.ClientSession), - successfulNegotiations: make(chan *wtdb.ClientSession), + newSessions: make(chan *ClientSession), + successfulNegotiations: make(chan *ClientSession), quit: make(chan struct{}), } } @@ -171,7 +171,7 @@ func (n *sessionNegotiator) Stop() error { // NewSessions returns a receive-only channel from which newly negotiated // sessions will be returned. -func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession { +func (n *sessionNegotiator) NewSessions() <-chan *ClientSession { return n.newSessions } @@ -333,18 +333,10 @@ retryWithBackoff: } } -// createSession takes a tower an attempts to negotiate a session using any of +// createSession takes a tower and attempts to negotiate a session using any of // its stored addresses. This method returns after the first successful -// negotiation, or after all addresses have failed with ErrFailedNegotiation. If -// the tower has no addresses, ErrNoTowerAddrs is returned. -func (n *sessionNegotiator) createSession(tower *wtdb.Tower, - keyIndex uint32) error { - - // If the tower has no addresses, there's nothing we can do. - if len(tower.Addresses) == 0 { - return ErrNoTowerAddrs - } - +// negotiation, or after all addresses have failed with ErrFailedNegotiation. +func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey( keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, @@ -358,8 +350,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, sessionKeyDesc, n.cfg.SecretKeyRing, ) - for _, lnAddr := range tower.LNAddrs() { - err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr) + addr := tower.Addresses.PeekAndLock() + for { + lnAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: addr, + } + + err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr) + tower.Addresses.ReleaseLock(addr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -370,6 +369,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, n.log.Debugf("Request for session negotiation with "+ "tower=%s failed, trying again -- reason: "+ "%v", lnAddr, err) + + // Get the next tower address if there is one. + addr, err = tower.Addresses.NextAndLock() + if err == ErrAddressesExhausted { + tower.Addresses.Reset() + + return ErrFailedNegotiation + } + continue default: @@ -385,7 +393,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, - keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + keyIndex uint32, tower *Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(sessionKey, lnAddr) @@ -456,26 +464,31 @@ func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, rewardPkScript := createSessionReply.Data sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey()) - clientSession := &wtdb.ClientSession{ + dbClientSession := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: tower.ID, KeyIndex: keyIndex, Policy: n.cfg.Policy, RewardPkScript: rewardPkScript, }, - Tower: tower, - SessionKeyECDH: sessionKey, - ID: sessionID, + ID: sessionID, } - err = n.cfg.DB.CreateClientSession(clientSession) + err = n.cfg.DB.CreateClientSession(dbClientSession) if err != nil { return fmt.Errorf("unable to persist ClientSession: %v", err) } n.log.Debugf("New session negotiated with %s, policy: %s", - lnAddr, clientSession.Policy) + lnAddr, dbClientSession.Policy) + + clientSession := &ClientSession{ + ID: sessionID, + ClientSessionBody: dbClientSession.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKey, + } // We have a newly negotiated session, return it to the // dispatcher so that it can update how many outstanding diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index d149d09b6..7d98ec86f 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -34,7 +34,7 @@ const ( type sessionQueueConfig struct { // ClientSession provides access to the negotiated session parameters // and updating its persistent storage. - ClientSession *wtdb.ClientSession + ClientSession *ClientSession // ChainHash identifies the chain for which the session's justice // transactions are targeted. @@ -97,7 +97,7 @@ type sessionQueue struct { queueCond *sync.Cond localInit *wtwire.Init - towerAddr *lnwire.NetAddress + tower *Tower seqNum uint16 @@ -117,18 +117,13 @@ func newSessionQueue(cfg *sessionQueueConfig, cfg.ChainHash, ) - towerAddr := &lnwire.NetAddress{ - IdentityKey: cfg.ClientSession.Tower.IdentityKey, - Address: cfg.ClientSession.Tower.Addresses[0], - } - sq := &sessionQueue{ cfg: cfg, log: cfg.Log, commitQueue: list.New(), pendingQueue: list.New(), localInit: localInit, - towerAddr: towerAddr, + tower: cfg.ClientSession.Tower, seqNum: cfg.ClientSession.SeqNum, retryBackoff: cfg.MinBackoff, quit: make(chan struct{}), @@ -293,18 +288,48 @@ func (q *sessionQueue) sessionManager() { // drainBackups attempts to send all pending updates in the queue to the tower. func (q *sessionQueue) drainBackups() { - // First, check that we are able to dial this session's tower. - conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionKeyECDH, q.towerAddr) - if err != nil { - q.log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v", - q.ID(), q.towerAddr, err) + var ( + conn wtserver.Peer + err error + towerAddr = q.tower.Addresses.Peek() + ) - q.increaseBackoff() - select { - case <-time.After(q.retryBackoff): - case <-q.forceQuit: + for { + q.log.Infof("SessionQueue(%s) attempting to dial tower at %v", + q.ID(), towerAddr) + + // First, check that we are able to dial this session's tower. + conn, err = q.cfg.Dial( + q.cfg.ClientSession.SessionKeyECDH, &lnwire.NetAddress{ + IdentityKey: q.tower.IdentityKey, + Address: towerAddr, + }, + ) + if err != nil { + // If there are more addrs available, immediately try + // those. + nextAddr, iteratorErr := q.tower.Addresses.Next() + if iteratorErr == nil { + towerAddr = nextAddr + continue + } + + // Otherwise, if we have exhausted the address list, + // back off and try again later. + q.tower.Addresses.Reset() + + q.log.Errorf("SessionQueue(%s) unable to dial tower "+ + "at any available Addresses: %v", q.ID(), err) + + q.increaseBackoff() + select { + case <-time.After(q.retryBackoff): + case <-q.forceQuit: + } + return } - return + + break } defer conn.Close() @@ -324,9 +349,7 @@ func (q *sessionQueue) drainBackups() { } // Now, send the state update to the tower and wait for a reply. - err = q.sendStateUpdate( - conn, stateUpdate, q.localInit, sendInit, isPending, - ) + err = q.sendStateUpdate(conn, stateUpdate, sendInit, isPending) if err != nil { q.log.Errorf("SessionQueue(%s) unable to send state "+ "update: %v", q.ID(), err) @@ -483,8 +506,12 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, // variable indicates whether we should back off before attempting to send the // next state update. func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, - stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init, - sendInit, isPending bool) error { + stateUpdate *wtwire.StateUpdate, sendInit, isPending bool) error { + + towerAddr := &lnwire.NetAddress{ + IdentityKey: conn.RemotePub(), + Address: conn.RemoteAddr(), + } // If this is the first message being sent to the tower, we must send an // Init message to establish that server supports the features we @@ -505,7 +532,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, remoteInit, ok := remoteMsg.(*wtwire.Init) if !ok { return fmt.Errorf("watchtower %s responded with %T "+ - "to Init", q.towerAddr, remoteMsg) + "to Init", towerAddr, remoteMsg) } // Validate Init. @@ -532,7 +559,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) if !ok { return fmt.Errorf("watchtower %s responded with %T to "+ - "StateUpdate", q.towerAddr, remoteMsg) + "StateUpdate", towerAddr, remoteMsg) } // Process the reply from the tower. @@ -547,8 +574,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, err := fmt.Errorf("received error code %v in "+ "StateUpdateReply for seqnum=%d", stateUpdateReply.Code, stateUpdate.SeqNum) - q.log.Warnf("SessionQueue(%s) unable to upload state update to "+ - "tower=%s: %v", q.ID(), q.towerAddr, err) + q.log.Warnf("SessionQueue(%s) unable to upload state update "+ + "to tower=%s: %v", q.ID(), towerAddr, err) return err } @@ -559,7 +586,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, // TODO(conner): borked watchtower err = fmt.Errorf("unable to ack seqnum=%d: %v", stateUpdate.SeqNum, err) - q.log.Errorf("SessionQueue(%v) failed to ack update: %v", q.ID(), err) + q.log.Errorf("SessionQueue(%v) failed to ack update: %v", + q.ID(), err) return err case err == wtdb.ErrLastAppliedReversion: diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 94a9c2c74..26d4704d4 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -429,7 +429,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } towerSessions, err := listTowerSessions( - towerID, sessions, towers, towersToSessionsIndex, + towerID, sessions, towersToSessionsIndex, WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -766,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, opts..., + sessions, opts..., ) return err } @@ -778,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, opts..., + *id, sessions, towerToSessionIndex, opts..., ) return err }, func() { @@ -792,7 +792,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, towers kvdb.RBucket, +func listClientAllSessions(sessions kvdb.RBucket, opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -801,7 +801,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k, opts...) + session, err := getClientSession(sessions, k, opts...) if err != nil { return err } @@ -819,7 +819,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. -func listTowerSessions(id TowerID, sessionsBkt, towersBkt, +func listTowerSessions(id TowerID, sessionsBkt, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( map[SessionID]*ClientSession, error) { @@ -834,9 +834,7 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession( - sessionsBkt, towersBkt, k, opts..., - ) + session, err := getClientSession(sessionsBkt, k, opts...) if err != nil { return err } @@ -1248,7 +1246,7 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, +func getClientSession(sessions kvdb.RBucket, idBytes []byte, opts ...ClientSessionListOption) (*ClientSession, error) { cfg := NewClientSessionCfg() @@ -1261,13 +1259,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, return nil, err } - // Fetch the tower associated with this session. - tower, err := getTower(towers, session.TowerID.Bytes()) - if err != nil { - return nil, err - } - session.Tower = tower - // Can't fail because client session body has already been read. sessionBkt := sessions.NestedReadBucket(idBytes) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index aa30cc713..f75a0c2bc 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -343,8 +343,11 @@ func testCreateTower(h *clientDBHarness) { h.loadTowerByID(20, wtdb.ErrTowerNotFound) tower := h.newTower() - require.Len(h.t, tower.LNAddrs(), 1) - towerAddr := tower.LNAddrs()[0] + require.Len(h.t, tower.Addresses, 1) + towerAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: tower.Addresses[0], + } // Load the tower from the database and assert that it matches the tower // we created. diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index a4d5c5ecc..e44331094 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -4,7 +4,6 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -36,19 +35,6 @@ type ClientSession struct { ID SessionID ClientSessionBody - - // Tower holds the pubkey and address of the watchtower. - // - // NOTE: This value is not serialized. It is recovered by looking up the - // tower with TowerID. - Tower *Tower - - // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret - // key used to connect to the watchtower. - // - // NOTE: This value is not serialized. It is derived using the KeyIndex - // on startup to avoid storing private keys on disk. - SessionKeyECDH keychain.SingleKeyECDH } // ClientSessionBody represents the primary components of a ClientSession that diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index 77f452fb5..ca9dbeb28 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -7,7 +7,6 @@ import ( "net" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" ) // TowerID is a unique 64-bit identifier allocated to each unique watchtower. @@ -77,23 +76,6 @@ func (t *Tower) RemoveAddress(addr net.Addr) { } } -// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's -// addresses. This can be used to have a client try multiple addresses for the -// same Tower. -// -// NOTE: This method is NOT safe for concurrent use. -func (t *Tower) LNAddrs() []*lnwire.NetAddress { - addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) - for _, addr := range t.Addresses { - addrs = append(addrs, &lnwire.NetAddress{ - IdentityKey: t.IdentityKey, - Address: addr, - }) - } - - return addrs -} - // String returns a user-friendly identifier of the tower. func (t *Tower) String() string { pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed()) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 8a47bdf7f..b12fe2780 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -231,7 +231,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, if tower != nil && *tower != session.TowerID { continue } - session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session if cfg.PerAckedUpdate != nil {