watchtower: allow removal during session negotiation

In this commit, the bug demonstrated in the previous commit is fixed.
The locking capabilities of the AddressIterator are used to lock
addresses if they are being used for session negotiation. So now, when a
request comes through to remove a tower address then a check is first
done to ensure that the address is not currently in use. If it is not,
then the request can go through.
This commit is contained in:
Elle Mouton 2022-10-12 10:56:04 +02:00
parent b2039f245e
commit 3ff5abc9e3
No known key found for this signature in database
GPG key ID: D7D916376026F177
4 changed files with 125 additions and 30 deletions

View file

@ -155,6 +155,10 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID,
return err
}
} else {
if tower.Addresses.HasLocked() {
return ErrAddrInUse
}
delete(t.candidates, candidate)
}

View file

@ -2,7 +2,6 @@ package wtclient
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
@ -826,13 +825,10 @@ func (c *TowerClient) backupDispatcher() {
msg.errChan <- c.handleNewTower(msg)
// A tower has been requested to be removed. We'll
// immediately return an error as we want to avoid the
// possibility of a new session being negotiated with
// this request's tower.
// only allow removal of it if the address in question
// is not currently being used for session negotiation.
case msg := <-c.staleTowers:
msg.errChan <- errors.New("removing towers " +
"is disallowed while a new session " +
"negotiation is in progress")
msg.errChan <- c.handleStaleTower(msg)
case <-c.forceQuit:
return
@ -1254,18 +1250,31 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error
func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// We'll load the tower before potentially removing it in order to
// retrieve its ID within the database.
tower, err := c.cfg.DB.LoadTower(msg.pubKey)
dbTower, err := c.cfg.DB.LoadTower(msg.pubKey)
if err != nil {
return err
}
// We'll update our persisted state, followed by our in-memory state,
// with the stale tower.
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
// We'll first update our in-memory state followed by our persisted
// state, with the stale tower. The removal of the tower address from
// the in-memory state will fail if the address is currently being used
// for a session negotiation.
err = c.candidateTowers.RemoveCandidate(dbTower.ID, msg.addr)
if err != nil {
return err
}
err = c.candidateTowers.RemoveCandidate(tower.ID, msg.addr)
if err != nil {
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
// If the persisted state update fails, re-add the address to
// our in-memory state.
tower, newTowerErr := NewTowerFromDBTower(dbTower)
if newTowerErr != nil {
log.Errorf("could not create new in-memory tower: %v",
newTowerErr)
} else {
c.candidateTowers.AddCandidate(tower)
}
return err
}
@ -1278,7 +1287,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// Otherwise, the tower should no longer be used for future session
// negotiations and backups.
pubKey := msg.pubKey.SerializeCompressed()
sessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
if err != nil {
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
"%v", pubKey, err)

View file

@ -2,6 +2,7 @@ package wtclient_test
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
@ -394,6 +395,8 @@ type testHarness struct {
mu sync.Mutex
channels map[lnwire.ChannelID]*mockChannel
quit chan struct{}
}
type harnessCfg struct {
@ -467,7 +470,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
serverCfg: serverCfg,
net: mockNet,
channels: make(map[lnwire.ChannelID]*mockChannel),
quit: make(chan struct{}),
}
t.Cleanup(func() {
close(h.quit)
})
if !cfg.noServerStart {
h.startServer()
@ -1542,11 +1549,10 @@ var clientTests = []clientTest{
},
},
{
// Assert that an error is returned if a user tries to remove
// a tower from the client while a session negotiation is in
// progress. This is a bug that will be fixed in a future
// commit.
name: "cant remove tower while session negotiation in progress",
// Assert that a user is able to remove a tower address during
// session negotiation as long as the address in question is not
// currently being used.
name: "removing a tower during session negotiation",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
@ -1560,18 +1566,93 @@ var clientTests = []clientTest{
noServerStart: true,
},
fn: func(h *testHarness) {
var err error
waitErr := wait.Predicate(func() bool {
err = h.client.RemoveTower(
// The server has not started yet and so no session
// negotiation with the server will be in progress, so
// the client should be able to remove the server.
err := wait.NoError(func() error {
return h.client.RemoveTower(
h.serverAddr.IdentityKey, nil,
)
return err != nil
}, time.Second*5)
require.NoError(h.t, waitErr)
}, waitTime)
require.NoError(h.t, err)
require.ErrorContains(h.t, err, "removing towers is "+
"disallowed while a new session negotiation "+
"is in progress")
// Set the server up so that its Dial function hangs
// when the client calls it. This will force the client
// to remain in the state where it has locked the
// address of the server.
h.server, err = wtserver.New(h.serverCfg)
require.NoError(h.t, err)
cancel := make(chan struct{})
h.net.registerConnCallback(
h.serverAddr, func(peer wtserver.Peer) {
select {
case <-h.quit:
case <-cancel:
}
},
)
// Also add a new tower address.
towerTCPAddr, err := net.ResolveTCPAddr(
"tcp", towerAddr2Str,
)
require.NoError(h.t, err)
towerAddr := &lnwire.NetAddress{
IdentityKey: h.serverAddr.IdentityKey,
Address: towerTCPAddr,
}
// Register the new address in the mock-net.
h.net.registerConnCallback(
towerAddr, h.server.InboundPeerConnected,
)
// Now start the server.
require.NoError(h.t, h.server.Start())
// Re-add the server to the client
err = h.client.AddTower(h.serverAddr)
require.NoError(h.t, err)
// Also add the new tower address.
err = h.client.AddTower(towerAddr)
require.NoError(h.t, err)
// Assert that if the client attempts to remove the
// tower's first address, then it will error due to
// address currently being locked for session
// negotiation.
err = wait.Predicate(func() bool {
err = h.client.RemoveTower(
h.serverAddr.IdentityKey,
h.serverAddr.Address,
)
return errors.Is(err, wtclient.ErrAddrInUse)
}, waitTime)
require.NoError(h.t, err)
// Assert that the second address can be removed since
// it is not being used for session negotiation.
err = wait.NoError(func() error {
return h.client.RemoveTower(
h.serverAddr.IdentityKey, towerTCPAddr,
)
}, waitTime)
require.NoError(h.t, err)
// Allow the dial to the first address to stop hanging.
close(cancel)
// Assert that the client can now remove the first
// address.
err = wait.NoError(func() error {
return h.client.RemoveTower(
h.serverAddr.IdentityKey, nil,
)
}, waitTime)
require.NoError(h.t, err)
},
},
}

View file

@ -350,7 +350,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
sessionKeyDesc, n.cfg.SecretKeyRing,
)
addr := tower.Addresses.Peek()
addr := tower.Addresses.PeekAndLock()
for {
lnAddr := &lnwire.NetAddress{
IdentityKey: tower.IdentityKey,
@ -358,6 +358,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
}
err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
tower.Addresses.ReleaseLock(addr)
switch {
case err == ErrPermanentTowerFailure:
// TODO(conner): report to iterator? can then be reset
@ -370,7 +371,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
"%v", lnAddr, err)
// Get the next tower address if there is one.
addr, err = tower.Addresses.Next()
addr, err = tower.Addresses.NextAndLock()
if err == ErrAddressesExhausted {
tower.Addresses.Reset()