diff --git a/peer.go b/peer.go index f6aa3a44..c8fe87ba 100644 --- a/peer.go +++ b/peer.go @@ -18,6 +18,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" ) @@ -101,7 +102,7 @@ type peer struct { na *btcwire.NetAddress timeConnected time.Time inbound bool - disconnect bool + disconnect int32 // only to be used atomically persistent bool versionKnown bool knownAddresses map[string]bool @@ -116,7 +117,6 @@ type peer struct { requestQueue *list.List invSendQueue *list.List continueHash *btcwire.ShaHash - wg sync.WaitGroup outputQueue chan btcwire.Message outputInvChan chan *btcwire.InvVect blockProcessed chan bool @@ -692,7 +692,7 @@ func (p *peer) handleAddrMsg(msg *btcwire.MsgAddr) { for _, na := range msg.AddrList { // Don't add more address if we're disconnecting. - if p.disconnect { + if atomic.LoadInt32(&p.disconnect) != 0 { return } @@ -752,7 +752,7 @@ func (p *peer) readMessage() (msg btcwire.Message, buf []byte, err error) { // writeMessage sends a bitcoin Message to the peer with logging. func (p *peer) writeMessage(msg btcwire.Message) { // Don't do anything if we're disconnecting. - if p.disconnect == true { + if atomic.LoadInt32(&p.disconnect) != 0 { return } @@ -812,7 +812,7 @@ func (p *peer) isAllowedByRegression(err error) bool { // goroutine. func (p *peer) inHandler() { out: - for !p.disconnect { + for atomic.LoadInt32(&p.disconnect) == 0 { rmsg, buf, err := p.readMessage() if err != nil { // In order to allow regression tests with malformed @@ -826,7 +826,7 @@ out: } // Only log the error if we're not forcibly disconnecting. - if !p.disconnect { + if atomic.LoadInt32(&p.disconnect) == 0 { log.Errorf("[PEER] Can't read message: %v", err) } break out @@ -890,7 +890,7 @@ out: // Mark the address as currently connected and working as of // now if one of the messages that trigger - if markConnected && !p.disconnect { + if markConnected && atomic.LoadInt32(&p.disconnect) == 0 { if p.na == nil { log.Warnf("we're getting stuff before we " + "got a version message. that's bad") @@ -905,9 +905,7 @@ out: p.Disconnect() p.server.donePeers <- p p.server.blockManager.DonePeer(p) - p.quit <- true - p.wg.Done() log.Tracef("[PEER] Peer input handler done for %s", p.conn.RemoteAddr()) } @@ -928,7 +926,8 @@ out: case <-trickleTicker.C: // Don't send anything if we're disconnecting or there // is no queued inventory. - if p.disconnect || p.invSendQueue.Len() == 0 { + if atomic.LoadInt32(&p.disconnect) != 0 || + p.invSendQueue.Len() == 0 { continue } @@ -962,7 +961,6 @@ out: break out } } - p.wg.Done() log.Tracef("[PEER] Peer output handler done for %s", p.conn.RemoteAddr()) } @@ -1011,7 +1009,6 @@ func (p *peer) Start() error { // Start processing input and output. go p.inHandler() go p.outHandler() - p.wg.Add(2) p.started = true return nil @@ -1020,7 +1017,11 @@ func (p *peer) Start() error { // Disconnect disconnects the peer by closing the connection. It also sets // a flag so the impending shutdown can be detected. func (p *peer) Disconnect() { - p.disconnect = true + // did we win the race? + if atomic.AddInt32(&p.disconnect, 1) != 1 { + return + } + close(p.quit) if p.conn != nil { p.conn.Close() } @@ -1031,7 +1032,6 @@ func (p *peer) Disconnect() { func (p *peer) Shutdown() { log.Tracef("[PEER] Shutdown peer %s", p.addr) p.Disconnect() - p.wg.Wait() } // newPeerBase returns a new base bitcoin peer for the provided server and @@ -1098,7 +1098,6 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer { } p.na = btcwire.NewNetAddressIPPort(net.ParseIP(ip), uint16(port), 0) - p.wg.Add(1) go func() { // Select which dial method to call depending on whether or // not a proxy is configured. Also, add proxy information to @@ -1118,7 +1117,7 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer { // Attempt to connect to the peer. If the connection fails and // this is a persistent connection, retry after the retry // interval. - for !s.shutdown { + for atomic.LoadInt32(&p.disconnect) == 0 { log.Debugf("[SRVR] Attempting to connect to %s", faddr) conn, err := dial("tcp", addr) if err != nil { @@ -1127,7 +1126,6 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer { faddr, err) if !persistent { p.server.donePeers <- p - p.wg.Done() return } scaledInterval := connectionRetryInterval.Nanoseconds() * p.retrycount / 2 @@ -1141,7 +1139,7 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer { // While we were sleeping trying to connect, the server // may have scheduled a shutdown. In that case ditch // the peer immediately. - if !s.shutdown { + if atomic.LoadInt32(&p.disconnect) == 0 { p.server.addrManager.Attempt(p.na) // Connection was successful so log it and start peer. @@ -1149,14 +1147,8 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer { p.conn = conn p.retrycount = 0 p.Start() - } else { - p.server.donePeers <- p } - // We are done here, Start() will have grabbed - // additional waitgroup entries if we are not shutting - // down. - p.wg.Done() return } }() diff --git a/server.go b/server.go index 4d57ba75..334ac5e9 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" ) @@ -48,8 +49,8 @@ type server struct { listeners []net.Listener btcnet btcwire.BitcoinNet started bool - shutdown bool - shutdownSched bool + shutdown int32 // atomic + shutdownSched int32 // atomic addrManager *AddrManager rpcServer *rpcServer blockManager *blockManager @@ -73,7 +74,7 @@ func (s *server) handleAddPeerMsg(peers *list.List, banned map[string]time.Time, // Ignore new peers if we're shutting down. direction := directionString(p.inbound) - if s.shutdown { + if atomic.LoadInt32(&s.shutdown) != 0 { log.Infof("[SRVR] New peer %s (%s) ignored - server is "+ "shutting down", p.addr, direction) p.Shutdown() @@ -130,7 +131,8 @@ func (s *server) handleDonePeerMsg(peers *list.List, p *peer) bool { // Issue an asynchronous reconnect if the peer was a // persistent outbound connection. - if !p.inbound && p.persistent && !s.shutdown { + if !p.inbound && p.persistent && + atomic.LoadInt32(&s.shutdown) == 0 { // attempt reconnect. addr := p.addr e.Value = newOutboundPeer(s, addr, true) @@ -207,11 +209,11 @@ func (s *server) handleBroadcastMsg(peers *list.List, bmsg *broadcastMsg) { // server. It must be run as a goroutine. func (s *server) listenHandler(listener net.Listener) { log.Infof("[SRVR] Server listening on %s", listener.Addr()) - for !s.shutdown { + for atomic.LoadInt32(&s.shutdown) == 0 { conn, err := listener.Accept() if err != nil { // Only log the error if we're not forcibly shutting down. - if !s.shutdown { + if atomic.LoadInt32(&s.shutdown) == 0 { log.Errorf("[SRVR] %v", err) } continue @@ -295,8 +297,8 @@ func (s *server) peerHandler() { // if nothing else happens, wake us up soon. time.AfterFunc(10*time.Second, func() { s.wakeup <- true }) - // Live while we're not shutting down or there are still connected peers. - for !s.shutdown || peers.Len() != 0 { +out: + for { select { // New peers connected to the server. case p := <-s.newPeers: @@ -336,6 +338,7 @@ func (s *server) peerHandler() { p := e.Value.(*peer) p.Shutdown() } + break out } // Timer was just to make sure we woke up again soon. so cancel @@ -346,7 +349,8 @@ func (s *server) peerHandler() { } // Only try connect to more peers if we actually need more - if outboundPeers >= maxOutbound || s.shutdown { + if outboundPeers >= maxOutbound || + atomic.LoadInt32(&s.shutdown) != 0 { continue } groups := make(map[string]int) @@ -359,7 +363,8 @@ func (s *server) peerHandler() { tries := 0 for outboundPeers < maxOutbound && - peers.Len() < cfg.MaxPeers && !s.shutdown { + peers.Len() < cfg.MaxPeers && + atomic.LoadInt32(&s.shutdown) == 0 { // We bias like bitcoind does, 10 for no outgoing // up to 90 (8) for the selection of new vs tried //addresses. @@ -494,16 +499,16 @@ func (s *server) Start() { // Stop gracefully shuts down the server by stopping and disconnecting all // peers and the main listener. func (s *server) Stop() error { - if s.shutdown { + // Make sure this only happens once. + if atomic.AddInt32(&s.shutdown, 1) != 1 { log.Infof("[SRVR] Server is already in the process of shutting down") return nil } log.Warnf("[SRVR] Server shutting down") - // Set the shutdown flag and stop all the listeners. There will not be - // any listeners if listening is disabled. - s.shutdown = true + // Stop all the listeners. There will not be any listeners if + // listening is disabled. for _, listener := range s.listeners { err := listener.Close() if err != nil { @@ -532,7 +537,7 @@ func (s *server) WaitForShutdown() { // on remaining duration. func (s *server) ScheduleShutdown(duration time.Duration) { // Don't schedule shutdown more than once. - if s.shutdownSched { + if atomic.AddInt32(&s.shutdownSched, 1) != 1 { return } log.Warnf("[SRVR] Server shutdown in %v", duration) @@ -565,7 +570,6 @@ func (s *server) ScheduleShutdown(duration time.Duration) { } } }() - s.shutdownSched = true } // newServer returns a new btcd server configured to listen on addr for the