Keep track of peers with maps instead of lists.

This commit is contained in:
David Hill 2015-05-14 18:23:23 -04:00
parent 58f29ad939
commit 9d6d0e4006

View file

@ -5,7 +5,6 @@
package main package main
import ( import (
"container/list"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -117,9 +116,9 @@ type server struct {
} }
type peerState struct { type peerState struct {
peers *list.List peers map[*peer]struct{}
outboundPeers *list.List outboundPeers map[*peer]struct{}
persistentPeers *list.List persistentPeers map[*peer]struct{}
banned map[string]time.Time banned map[string]time.Time
outboundGroups map[string]int outboundGroups map[string]int
maxOutboundPeers int maxOutboundPeers int
@ -166,11 +165,11 @@ func (s *server) RemoveRebroadcastInventory(iv *wire.InvVect) {
} }
func (p *peerState) Count() int { func (p *peerState) Count() int {
return p.peers.Len() + p.outboundPeers.Len() + p.persistentPeers.Len() return len(p.peers) + len(p.outboundPeers) + len(p.persistentPeers)
} }
func (p *peerState) OutboundCount() int { func (p *peerState) OutboundCount() int {
return p.outboundPeers.Len() + p.persistentPeers.Len() return len(p.outboundPeers) + len(p.persistentPeers)
} }
func (p *peerState) NeedMoreOutbound() bool { func (p *peerState) NeedMoreOutbound() bool {
@ -181,19 +180,19 @@ func (p *peerState) NeedMoreOutbound() bool {
// forAllOutboundPeers is a helper function that runs closure on all outbound // forAllOutboundPeers is a helper function that runs closure on all outbound
// peers known to peerState. // peers known to peerState.
func (p *peerState) forAllOutboundPeers(closure func(p *peer)) { func (p *peerState) forAllOutboundPeers(closure func(p *peer)) {
for e := p.outboundPeers.Front(); e != nil; e = e.Next() { for e := range p.outboundPeers {
closure(e.Value.(*peer)) closure(e)
} }
for e := p.persistentPeers.Front(); e != nil; e = e.Next() { for e := range p.persistentPeers {
closure(e.Value.(*peer)) closure(e)
} }
} }
// forAllPeers is a helper function that runs closure on all peers known to // forAllPeers is a helper function that runs closure on all peers known to
// peerState. // peerState.
func (p *peerState) forAllPeers(closure func(p *peer)) { func (p *peerState) forAllPeers(closure func(p *peer)) {
for e := p.peers.Front(); e != nil; e = e.Next() { for e := range p.peers {
closure(e.Value.(*peer)) closure(e)
} }
p.forAllOutboundPeers(closure) p.forAllOutboundPeers(closure)
} }
@ -278,14 +277,14 @@ func (s *server) handleAddPeerMsg(state *peerState, p *peer) bool {
// Add the new peer and start it. // Add the new peer and start it.
srvrLog.Debugf("New peer %s", p) srvrLog.Debugf("New peer %s", p)
if p.inbound { if p.inbound {
state.peers.PushBack(p) state.peers[p] = struct{}{}
p.Start() p.Start()
} else { } else {
state.outboundGroups[addrmgr.GroupKey(p.na)]++ state.outboundGroups[addrmgr.GroupKey(p.na)]++
if p.persistent { if p.persistent {
state.persistentPeers.PushBack(p) state.persistentPeers[p] = struct{}{}
} else { } else {
state.outboundPeers.PushBack(p) state.outboundPeers[p] = struct{}{}
} }
} }
@ -295,7 +294,7 @@ func (s *server) handleAddPeerMsg(state *peerState, p *peer) bool {
// handleDonePeerMsg deals with peers that have signalled they are done. It is // handleDonePeerMsg deals with peers that have signalled they are done. It is
// invoked from the peerHandler goroutine. // invoked from the peerHandler goroutine.
func (s *server) handleDonePeerMsg(state *peerState, p *peer) { func (s *server) handleDonePeerMsg(state *peerState, p *peer) {
var list *list.List var list map[*peer]struct{}
if p.persistent { if p.persistent {
list = state.persistentPeers list = state.persistentPeers
} else if p.inbound { } else if p.inbound {
@ -303,18 +302,18 @@ func (s *server) handleDonePeerMsg(state *peerState, p *peer) {
} else { } else {
list = state.outboundPeers list = state.outboundPeers
} }
for e := list.Front(); e != nil; e = e.Next() { for e := range list {
if e.Value == p { if e == p {
// Issue an asynchronous reconnect if the peer was a // Issue an asynchronous reconnect if the peer was a
// persistent outbound connection. // persistent outbound connection.
if !p.inbound && p.persistent && atomic.LoadInt32(&s.shutdown) == 0 { if !p.inbound && p.persistent && atomic.LoadInt32(&s.shutdown) == 0 {
e.Value = newOutboundPeer(s, p.addr, true, p.retryCount+1) e = newOutboundPeer(s, p.addr, true, p.retryCount+1)
return return
} }
if !p.inbound { if !p.inbound {
state.outboundGroups[addrmgr.GroupKey(p.na)]-- state.outboundGroups[addrmgr.GroupKey(p.na)]--
} }
list.Remove(e) delete(list, e)
srvrLog.Debugf("Removed peer %s", p) srvrLog.Debugf("Removed peer %s", p)
return return
} }
@ -439,7 +438,7 @@ func (s *server) handleQuery(querymsg interface{}, state *peerState) {
case getPeerInfoMsg: case getPeerInfoMsg:
syncPeer := s.blockManager.SyncPeer() syncPeer := s.blockManager.SyncPeer()
infos := make([]*btcjson.GetPeerInfoResult, 0, state.peers.Len()) infos := make([]*btcjson.GetPeerInfoResult, 0, len(state.peers))
state.forAllPeers(func(p *peer) { state.forAllPeers(func(p *peer) {
if !p.Connected() { if !p.Connected() {
return return
@ -481,8 +480,7 @@ func (s *server) handleQuery(querymsg interface{}, state *peerState) {
case connectNodeMsg: case connectNodeMsg:
// XXX(oga) duplicate oneshots? // XXX(oga) duplicate oneshots?
for e := state.persistentPeers.Front(); e != nil; e = e.Next() { for peer := range state.persistentPeers {
peer := e.Value.(*peer)
if peer.addr == msg.addr { if peer.addr == msg.addr {
if msg.permanent { if msg.permanent {
msg.reply <- errors.New("peer already connected") msg.reply <- errors.New("peer already connected")
@ -515,9 +513,8 @@ func (s *server) handleQuery(querymsg interface{}, state *peerState) {
// Request a list of the persistent (added) peers. // Request a list of the persistent (added) peers.
case getAddedNodesMsg: case getAddedNodesMsg:
// Respond with a slice of the relavent peers. // Respond with a slice of the relavent peers.
peers := make([]*peer, 0, state.persistentPeers.Len()) peers := make([]*peer, 0, len(state.persistentPeers))
for e := state.persistentPeers.Front(); e != nil; e = e.Next() { for peer := range state.persistentPeers {
peer := e.Value.(*peer)
peers = append(peers, peer) peers = append(peers, peer)
} }
msg.reply <- peers msg.reply <- peers
@ -560,9 +557,8 @@ func (s *server) handleQuery(querymsg interface{}, state *peerState) {
// to be located. If the peer is found, and the passed callback: `whenFound' // to be located. If the peer is found, and the passed callback: `whenFound'
// isn't nil, we call it with the peer as the argument before it is removed // isn't nil, we call it with the peer as the argument before it is removed
// from the peerList, and is disconnected from the server. // from the peerList, and is disconnected from the server.
func disconnectPeer(peerList *list.List, compareFunc func(*peer) bool, whenFound func(*peer)) bool { func disconnectPeer(peerList map[*peer]struct{}, compareFunc func(*peer) bool, whenFound func(*peer)) bool {
for e := peerList.Front(); e != nil; e = e.Next() { for peer := range peerList {
peer := e.Value.(*peer)
if compareFunc(peer) { if compareFunc(peer) {
if whenFound != nil { if whenFound != nil {
whenFound(peer) whenFound(peer)
@ -570,7 +566,7 @@ func disconnectPeer(peerList *list.List, compareFunc func(*peer) bool, whenFound
// This is ok because we are not continuing // This is ok because we are not continuing
// to iterate so won't corrupt the loop. // to iterate so won't corrupt the loop.
peerList.Remove(e) delete(peerList, peer)
peer.Disconnect() peer.Disconnect()
return true return true
} }
@ -659,9 +655,9 @@ func (s *server) peerHandler() {
srvrLog.Tracef("Starting peer handler") srvrLog.Tracef("Starting peer handler")
state := &peerState{ state := &peerState{
peers: list.New(), peers: make(map[*peer]struct{}),
persistentPeers: list.New(), persistentPeers: make(map[*peer]struct{}),
outboundPeers: list.New(), outboundPeers: make(map[*peer]struct{}),
banned: make(map[string]time.Time), banned: make(map[string]time.Time),
maxOutboundPeers: defaultMaxOutbound, maxOutboundPeers: defaultMaxOutbound,
outboundGroups: make(map[string]int), outboundGroups: make(map[string]int),