diff --git a/connmgr/connmanager.go b/connmgr/connmanager.go index 26c8e59c..4acc0b2d 100644 --- a/connmgr/connmanager.go +++ b/connmgr/connmanager.go @@ -46,9 +46,10 @@ type ConnState uint8 // connection which was disconnected is categorized as disconnected. const ( ConnPending ConnState = iota + ConnFailing + ConnCanceled ConnEstablished ConnDisconnected - ConnFailed ) // ConnReq is the connection request to a network address. If permanent, the @@ -225,8 +226,16 @@ func (cm *ConnManager) handleFailedConn(c *ConnReq) { // connections so that we remain connected to the network. Connection requests // are processed and mapped by their assigned ids. func (cm *ConnManager) connHandler() { - conns := make(map[uint64]*ConnReq, cm.cfg.TargetOutbound) - pendingConns := make(map[uint64]*ConnReq) + + var ( + // pending holds all registered conn requests that have yet to + // succeed. + pending = make(map[uint64]*ConnReq) + + // conns represents the set of all actively connected peers. + conns = make(map[uint64]*ConnReq, cm.cfg.TargetOutbound) + ) + out: for { select { @@ -234,11 +243,23 @@ out: switch msg := req.(type) { case registerPending: - pendingConns[msg.c.id] = msg.c + connReq := msg.c + connReq.updateState(ConnPending) + pending[msg.c.id] = connReq close(msg.done) case handleConnected: connReq := msg.c + + if _, ok := pending[connReq.id]; !ok { + if msg.conn != nil { + msg.conn.Close() + } + log.Debugf("Ignoring connection for "+ + "canceled connreq=%v", connReq) + continue + } + connReq.updateState(ConnEstablished) connReq.conn = msg.conn conns[connReq.id] = connReq @@ -246,7 +267,7 @@ out: connReq.retryCount = 0 cm.failedAttempts = 0 - delete(pendingConns, connReq.id) + delete(pending, connReq.id) if cm.cfg.OnConnection != nil { go cm.cfg.OnConnection(connReq, msg.conn) @@ -255,39 +276,74 @@ out: case handleDisconnected: connReq, ok := conns[msg.id] if !ok { - connReq, ok = pendingConns[msg.id] - if ok && !msg.retry { - connReq.updateState(ConnFailed) - - 
log.Debugf("Cancelling: %v", connReq) - delete(pendingConns, msg.id) - return + connReq, ok = pending[msg.id] + if !ok { + log.Errorf("Unknown connid=%d", + msg.id) + continue } + + // Pending connection was found, remove + // it from pending map if we should + // ignore a later, successful + // connection. + connReq.updateState(ConnCanceled) + log.Debugf("Canceling: %v", connReq) + delete(pending, msg.id) + continue + } - if connReq != nil { + // An existing connection was located, mark as + // disconnected and execute disconnection + // callback. + log.Debugf("Disconnected from %v", connReq) + delete(conns, msg.id) + + if connReq.conn != nil { + connReq.conn.Close() + } + + if cm.cfg.OnDisconnection != nil { + go cm.cfg.OnDisconnection(connReq) + } + + // All internal state has been cleaned up, if + // this connection is being removed, we will + // make no further attempts with this request. + if !msg.retry { connReq.updateState(ConnDisconnected) - if connReq.conn != nil { - connReq.conn.Close() - } - log.Debugf("Disconnected from %v", connReq) - delete(conns, msg.id) + continue + } - if cm.cfg.OnDisconnection != nil { - go cm.cfg.OnDisconnection(connReq) - } + // Otherwise, we will attempt a reconnection if + // we do not have enough peers, or if this is a + // persistent peer. The connection request is + // re added to the pending map, so that + // subsequent processing of connections and + // failures do not ignore the request. 
+ if uint32(len(conns)) < cm.cfg.TargetOutbound || + connReq.Permanent { - if uint32(len(conns)) < cm.cfg.TargetOutbound && msg.retry { - cm.handleFailedConn(connReq) - } - } else { - log.Errorf("Unknown connection: %d", msg.id) + connReq.updateState(ConnPending) + log.Debugf("Reconnecting to %v", + connReq) + pending[msg.id] = connReq + cm.handleFailedConn(connReq) } case handleFailed: connReq := msg.c - connReq.updateState(ConnFailed) - log.Debugf("Failed to connect to %v: %v", connReq, msg.err) + + if _, ok := pending[connReq.id]; !ok { + log.Debugf("Ignoring connection for "+ + "canceled conn req: %v", connReq) + continue + } + + connReq.updateState(ConnFailing) + log.Debugf("Failed to connect to %v: %v", + connReq, msg.err) cm.handleFailedConn(connReq) } @@ -313,9 +369,31 @@ func (cm *ConnManager) NewConnReq() { c := &ConnReq{} atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) + // Submit a request of a pending connection attempt to the connection + // manager. By registering the id before the connection is even + // established, we'll be able to later cancel the connection via the + // Remove method. + done := make(chan struct{}) + select { + case cm.requests <- registerPending{c, done}: + case <-cm.quit: + return + } + + // Wait for the registration to successfully add the pending conn req to + // the conn manager's internal state. + select { + case <-done: + case <-cm.quit: + return + } + addr, err := cm.cfg.GetNewAddress() if err != nil { - cm.requests <- handleFailed{c, err} + select { + case cm.requests <- handleFailed{c, err}: + case <-cm.quit: + } return } @@ -338,17 +416,35 @@ func (cm *ConnManager) Connect(c *ConnReq) { // connection is even established, we'll be able to later // cancel the connection via the Remove method. 
done := make(chan struct{}) - cm.requests <- registerPending{c, done} - <-done + select { + case cm.requests <- registerPending{c, done}: + case <-cm.quit: + return + } + + // Wait for the registration to successfully add the pending + // conn req to the conn manager's internal state. + select { + case <-done: + case <-cm.quit: + return + } } log.Debugf("Attempting to connect to %v", c) conn, err := cm.cfg.Dial(c.Addr) if err != nil { - cm.requests <- handleFailed{c, err} - } else { - cm.requests <- handleConnected{c, conn} + select { + case cm.requests <- handleFailed{c, err}: + case <-cm.quit: + } + return + } + + select { + case cm.requests <- handleConnected{c, conn}: + case <-cm.quit: } } @@ -359,7 +455,11 @@ func (cm *ConnManager) Disconnect(id uint64) { if atomic.LoadInt32(&cm.stop) != 0 { return } - cm.requests <- handleDisconnected{id, true} + + select { + case cm.requests <- handleDisconnected{id, true}: + case <-cm.quit: + } } // Remove removes the connection corresponding to the given connection id from @@ -371,7 +471,11 @@ func (cm *ConnManager) Remove(id uint64) { if atomic.LoadInt32(&cm.stop) != 0 { return } - cm.requests <- handleDisconnected{id, false} + + select { + case cm.requests <- handleDisconnected{id, false}: + case <-cm.quit: + } } // listenHandler accepts incoming connections on a given listener. 
It must be diff --git a/connmgr/connmanager_test.go b/connmgr/connmanager_test.go index 99928931..67769deb 100644 --- a/connmgr/connmanager_test.go +++ b/connmgr/connmanager_test.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net" - "runtime" "sync/atomic" "testing" "time" @@ -268,7 +267,7 @@ func TestRetryPermanent(t *testing.T) { t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) } gotState = cr.State() - wantState = ConnDisconnected + wantState = ConnPending if gotState != wantState { t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState) } @@ -451,24 +450,108 @@ func TestRemovePendingConnection(t *testing.T) { } go cmgr.Connect(cr) - runtime.Gosched() + time.Sleep(10 * time.Millisecond) + + if cr.State() != ConnPending { + t.Fatalf("pending request hasn't been registered, status: %v", + cr.State()) + } // The request launched above will actually never be able to establish // a connection. So we'll cancel it _before_ it's able to be completed. cmgr.Remove(cr.ID()) - runtime.Gosched() + time.Sleep(10 * time.Millisecond) // Now examine the status of the connection request, it should read a // status of failed. - if cr.State() != ConnFailed { - t.Fatalf("request wasn't cancelled, status is: %v", cr.State()) + if cr.State() != ConnCanceled { + t.Fatalf("request wasn't canceled, status is: %v", cr.State()) } close(wait) cmgr.Stop() } +// TestCancelIgnoreDelayedConnection tests that a canceled connection request will +// not execute the on connection callback, even if an outstanding retry +// succeeds. +func TestCancelIgnoreDelayedConnection(t *testing.T) { + retryTimeout := 10 * time.Millisecond + + // Setup a dialer that will continue to return an error until the + // connect chan is signaled, the dial attempt immediately after will + // succeed in returning a connection. 
+ connect := make(chan struct{}) + failingDialer := func(addr net.Addr) (net.Conn, error) { + select { + case <-connect: + return mockDialer(addr) + default: + } + + return nil, fmt.Errorf("error") + } + + connected := make(chan *ConnReq) + cmgr, err := New(&Config{ + Dial: failingDialer, + RetryDuration: retryTimeout, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + defer cmgr.Stop() + + // Establish a connection request to a random IP we've chosen. + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + } + cmgr.Connect(cr) + + // Allow for the first retry timeout to elapse. + time.Sleep(2 * retryTimeout) + + // The connection should be marked as failing, even after reattempting to + // connect. + if cr.State() != ConnFailing { + t.Fatalf("failing request should have status failed, status: %v", + cr.State()) + } + + // Remove the connection, and then immediately allow the next connection + // to succeed. + cmgr.Remove(cr.ID()) + close(connect) + + // Allow the connection manager to process the removal. + time.Sleep(5 * time.Millisecond) + + // Now examine the status of the connection request, it should read a + // status of canceled. + if cr.State() != ConnCanceled { + t.Fatalf("request wasn't canceled, status is: %v", cr.State()) + } + + // Finally, the connection manager should not signal the on-connection + // callback, since we explicitly canceled this request. We give a + // generous window to ensure the connection manager's linear backoff is + // allowed to properly elapse. + select { + case <-connected: + t.Fatalf("on-connect should not be called for canceled req") + case <-time.After(5 * retryTimeout): + } + +} + // mockListener implements the net.Listener interface and is used to test // code that deals with net.Listeners without having to actually make any real // connections.