peer: Extract protocol negotiation from main read and write code paths.

This allows cleaner separation of the half-duplex version negotiation from the fully duplex message passing between peers.
This commit is contained in:
Jonathan Gillham 2016-02-10 18:01:55 +00:00
parent 777ccdade3
commit f3d759d783
2 changed files with 179 additions and 168 deletions

View File

@ -188,7 +188,7 @@ type MessageListeners struct {
// not directly provide a callback. // not directly provide a callback.
OnRead func(p *Peer, bytesRead int, msg wire.Message, err error) OnRead func(p *Peer, bytesRead int, msg wire.Message, err error)
// OnWrite is invoked when a peer receives a bitcoin message. It // OnWrite is invoked when we write a bitcoin message to a peer. It
// consists of the number of bytes written, the message, and whether or // consists of the number of bytes written, the message, and whether or
// not an error in the write occurred. This can be useful for // not an error in the write occurred. This can be useful for
// circumstances such as keeping track of server-wide byte counts. // circumstances such as keeping track of server-wide byte counts.
@ -735,15 +735,15 @@ func (p *Peer) WantsHeaders() bool {
return p.sendHeadersPreferred return p.sendHeadersPreferred
} }
// pushVersionMsg sends a version message to the connected peer using the // localVersionMsg creates a version message that can be used to send to the
// current state. // remote peer.
func (p *Peer) pushVersionMsg() error { func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
var blockNum int32 var blockNum int32
if p.cfg.NewestBlock != nil { if p.cfg.NewestBlock != nil {
var err error var err error
_, blockNum, err = p.cfg.NewestBlock() _, blockNum, err = p.cfg.NewestBlock()
if err != nil { if err != nil {
return err return nil, err
} }
} }
@ -775,7 +775,7 @@ func (p *Peer) pushVersionMsg() error {
// recently seen nonces. // recently seen nonces.
nonce, err := wire.RandomUint64() nonce, err := wire.RandomUint64()
if err != nil { if err != nil {
return err return nil, err
} }
sentNonces.Add(nonce) sentNonces.Add(nonce)
@ -810,8 +810,7 @@ func (p *Peer) pushVersionMsg() error {
// Advertise if inv messages for transactions are desired. // Advertise if inv messages for transactions are desired.
msg.DisableRelayTx = p.cfg.DisableRelayTx msg.DisableRelayTx = p.cfg.DisableRelayTx
p.QueueMessage(msg, nil) return msg, nil
return nil
} }
// PushAddrMsg sends an addr message to the connected peer using the provided // PushAddrMsg sends an addr message to the connected peer using the provided
@ -913,8 +912,8 @@ func (p *Peer) PushGetHeadersMsg(locator blockchain.BlockLocator, stopHash *wire
p.prevGetHdrsMtx.Unlock() p.prevGetHdrsMtx.Unlock()
if isDuplicate { if isDuplicate {
log.Tracef("Filtering duplicate [getheaders] with begin "+ log.Tracef("Filtering duplicate [getheaders] with begin hash %v",
"hash %v", beginHash) beginHash)
return nil return nil
} }
@ -974,10 +973,10 @@ func (p *Peer) PushRejectMsg(command string, code wire.RejectCode, reason string
<-doneChan <-doneChan
} }
// handleVersionMsg is invoked when a peer receives a version bitcoin message // handleRemoteVersionMsg is invoked when a version bitcoin message is received
// and is used to negotiate the protocol version details as well as kick start // from the remote peer. It will return an error if the remote peer's version
// the communications. // is not compatible with ours.
func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error { func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
// Detect self connections. // Detect self connections.
if !allowSelfConns && sentNonces.Exists(msg.Nonce) { if !allowSelfConns && sentNonces.Exists(msg.Nonce) {
return errors.New("disconnecting peer connected to self") return errors.New("disconnecting peer connected to self")
@ -991,21 +990,9 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
// disconnecting. // disconnecting.
reason := fmt.Sprintf("protocol version must be %d or greater", reason := fmt.Sprintf("protocol version must be %d or greater",
wire.MultipleAddressVersion) wire.MultipleAddressVersion)
p.PushRejectMsg(msg.Command(), wire.RejectObsolete, reason, rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete,
nil, true) reason)
return errors.New(reason) return p.writeMessage(rejectMsg)
}
// Limit to one version message per peer.
// No read lock is necessary because versionKnown is not written to in any
// other goroutine
if p.versionKnown {
// Send an reject message indicating the version message was
// incorrectly sent twice and wait for the message to be sent
// before disconnecting.
p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
"duplicate version message", nil, true)
return errors.New("only one version message per peer is allowed")
} }
// Updating a bunch of stats. // Updating a bunch of stats.
@ -1030,27 +1017,6 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
// Set the remote peer's user agent. // Set the remote peer's user agent.
p.userAgent = msg.UserAgent p.userAgent = msg.UserAgent
p.flagsMtx.Unlock() p.flagsMtx.Unlock()
// Inbound connections.
if p.inbound {
// Set up a NetAddress for the peer to be used with AddrManager.
// We only do this inbound because outbound set this up
// at connection time and no point recomputing.
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
if err != nil {
return err
}
p.na = na
// Send version.
err = p.pushVersionMsg()
if err != nil {
return err
}
}
// Send verack.
p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil return nil
} }
@ -1147,18 +1113,6 @@ func (p *Peer) writeMessage(msg wire.Message) error {
if atomic.LoadInt32(&p.disconnect) != 0 { if atomic.LoadInt32(&p.disconnect) != 0 {
return nil return nil
} }
if !p.VersionKnown() {
switch msg.(type) {
case *wire.MsgVersion:
// This is OK.
case *wire.MsgReject:
// This is OK.
default:
// Drop all messages other than version and reject if
// the handshake has not already been done.
return nil
}
}
// Use closures to log expensive operations so they are only run when // Use closures to log expensive operations so they are only run when
// the logging level requires it. // the logging level requires it.
@ -1194,13 +1148,16 @@ func (p *Peer) writeMessage(msg wire.Message) error {
return err return err
} }
// isAllowedByRegression returns whether or not the passed error is allowed by // isAllowedReadError returns whether or not the passed error is allowed without
// regression tests without disconnecting the peer. In particular, regression // disconnecting the peer. In particular, regression tests need to be allowed
// tests need to be allowed to send malformed messages without the peer being // to send malformed messages without the peer being disconnected.
// disconnected. func (p *Peer) isAllowedReadError(err error) bool {
func (p *Peer) isAllowedByRegression(err error) bool { // Only allow read errors in regression test mode.
// Don't allow the error if it's not specifically a malformed message if p.cfg.ChainParams.Net != wire.TestNet {
// error. return false
}
// Don't allow the error if it's not specifically a malformed message error.
if _, ok := err.(*wire.MessageError); !ok { if _, ok := err.(*wire.MessageError); !ok {
return false return false
} }
@ -1220,12 +1177,6 @@ func (p *Peer) isAllowedByRegression(err error) bool {
return true return true
} }
// isRegTestNetwork returns whether or not the peer is running on the regression
// test network.
func (p *Peer) isRegTestNetwork() bool {
return p.cfg.ChainParams.Net == wire.TestNet
}
// shouldHandleReadError returns whether or not the passed error, which is // shouldHandleReadError returns whether or not the passed error, which is
// expected to have come from reading from the remote peer in the inHandler, // expected to have come from reading from the remote peer in the inHandler,
// should be logged and responded to with a reject message. // should be logged and responded to with a reject message.
@ -1437,14 +1388,8 @@ func (p *Peer) inHandler() {
// Peers must complete the initial version negotiation within a shorter // Peers must complete the initial version negotiation within a shorter
// timeframe than a general idle timeout. The timer is then reset below // timeframe than a general idle timeout. The timer is then reset below
// to idleTimeout for all future messages. // to idleTimeout for all future messages.
idleTimer := time.AfterFunc(negotiateTimeout, func() { idleTimer := time.AfterFunc(idleTimeout, func() {
if p.VersionKnown() { log.Warnf("Peer %s no answer for %s -- disconnecting", p, idleTimeout)
log.Warnf("Peer %s no answer for %s -- disconnecting",
p, idleTimeout)
} else {
log.Debugf("Peer %s no valid version message for %s -- "+
"disconnecting", p, negotiateTimeout)
}
p.Disconnect() p.Disconnect()
}) })
@ -1456,13 +1401,11 @@ out:
rmsg, buf, err := p.readMessage() rmsg, buf, err := p.readMessage()
idleTimer.Stop() idleTimer.Stop()
if err != nil { if err != nil {
// In order to allow regression tests with malformed // In order to allow regression tests with malformed messages, don't
// messages, don't disconnect the peer when we're in // disconnect the peer when we're in regression test mode and the
// regression test mode and the error is one of the // error is one of the allowed errors.
// allowed errors. if p.isAllowedReadError(err) {
if p.isRegTestNetwork() && p.isAllowedByRegression(err) { log.Errorf("Allowed test error from %s: %v", p, err)
log.Errorf("Allowed regression test error "+
"from %s: %v", p, err)
idleTimer.Reset(idleTimeout) idleTimer.Reset(idleTimeout)
continue continue
} }
@ -1471,70 +1414,40 @@ out:
// local peer is not forcibly disconnecting and the // local peer is not forcibly disconnecting and the
// remote peer has not disconnected. // remote peer has not disconnected.
if p.shouldHandleReadError(err) { if p.shouldHandleReadError(err) {
errMsg := fmt.Sprintf("Can't read message "+ errMsg := fmt.Sprintf("Can't read message from %s: %v", p, err)
"from %s: %v", p, err)
log.Errorf(errMsg) log.Errorf(errMsg)
// Push a reject message for the malformed // Push a reject message for the malformed message and wait for
// message and wait for the message to be sent // the message to be sent before disconnecting.
// before disconnecting.
// //
// NOTE: Ideally this would include the command // NOTE: Ideally this would include the command in the header if
// in the header if at least that much of the // at least that much of the message was valid, but that is not
// message was valid, but that is not currently // currently exposed by wire, so just used malformed for the
// exposed by wire, so just used malformed for // command.
// the command. p.PushRejectMsg("malformed", wire.RejectMalformed, errMsg, nil,
p.PushRejectMsg("malformed", true)
wire.RejectMalformed, errMsg, nil, true)
} }
break out break out
} }
atomic.StoreInt64(&p.lastRecv, time.Now().Unix()) atomic.StoreInt64(&p.lastRecv, time.Now().Unix())
p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg} p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}
// Ensure version message comes first.
if vmsg, ok := rmsg.(*wire.MsgVersion); !ok && !p.VersionKnown() {
errStr := "A version message must precede all others"
log.Errorf(errStr)
// Push a reject message and wait for the message to be
// sent before disconnecting.
p.PushRejectMsg(vmsg.Command(), wire.RejectMalformed,
errStr, nil, true)
break out
}
// Handle each supported message type. // Handle each supported message type.
p.stallControl <- stallControlMsg{sccHandlerStart, rmsg} p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}
switch msg := rmsg.(type) { switch msg := rmsg.(type) {
case *wire.MsgVersion: case *wire.MsgVersion:
err := p.handleVersionMsg(msg)
if err != nil { p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
log.Debugf("New peer %v - error negotiating protocol: %v", "duplicate version message", nil, true)
p, err) break out
p.Disconnect()
break out
}
if p.cfg.Listeners.OnVersion != nil {
p.cfg.Listeners.OnVersion(p, msg)
}
case *wire.MsgVerAck: case *wire.MsgVerAck:
p.flagsMtx.Lock()
versionSent := p.versionSent
p.flagsMtx.Unlock()
if !versionSent {
log.Infof("Received 'verack' from peer %v "+
"before version was sent -- "+
"disconnecting", p)
break out
}
// No read lock is necessary because verAckReceived is // No read lock is necessary because verAckReceived is not written
// not written to in any other goroutine. // to in any other goroutine.
if p.verAckReceived { if p.verAckReceived {
log.Infof("Already received 'verack' from "+ log.Infof("Already received 'verack' from peer %v -- "+
"peer %v -- disconnecting", p) "disconnecting", p)
break out break out
} }
p.flagsMtx.Lock() p.flagsMtx.Lock()
@ -1830,13 +1743,6 @@ out:
select { select {
case msg := <-p.sendQueue: case msg := <-p.sendQueue:
switch m := msg.msg.(type) { switch m := msg.msg.(type) {
case *wire.MsgVersion:
// Set the flag which indicates the version has
// been sent.
p.flagsMtx.Lock()
p.versionSent = true
p.flagsMtx.Unlock()
case *wire.MsgPing: case *wire.MsgPing:
// Only expects a pong message in later protocol // Only expects a pong message in later protocol
// versions. Also set up statistics. // versions. Also set up statistics.
@ -1849,8 +1755,7 @@ out:
} }
p.stallControl <- stallControlMsg{sccSendMessage, msg.msg} p.stallControl <- stallControlMsg{sccSendMessage, msg.msg}
err := p.writeMessage(msg.msg) if err := p.writeMessage(msg.msg); err != nil {
if err != nil {
p.Disconnect() p.Disconnect()
if p.shouldLogWriteError(err) { if p.shouldLogWriteError(err) {
log.Errorf("Failed to send message to "+ log.Errorf("Failed to send message to "+
@ -1956,15 +1861,28 @@ func (p *Peer) Connect(conn net.Conn) {
return return
} }
if p.inbound {
p.addr = conn.RemoteAddr().String()
}
p.conn = conn p.conn = conn
p.timeConnected = time.Now() p.timeConnected = time.Now()
if p.inbound {
p.addr = p.conn.RemoteAddr().String()
// Set up a NetAddress for the peer to be used with AddrManager. We
// only do this inbound because outbound set this up at connection time
// and no point recomputing.
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
if err != nil {
log.Errorf("Cannot create remote net address: %v", err)
p.Disconnect()
return
}
p.na = na
}
go func() { go func() {
if err := p.start(); err != nil { if err := p.start(); err != nil {
log.Errorf("Cannot start peer %v: %v", p, err) log.Warnf("Cannot start peer %v: %v", p, err)
p.Disconnect()
} }
}() }()
} }
@ -1992,26 +1910,38 @@ func (p *Peer) Disconnect() {
close(p.quit) close(p.quit)
} }
// Start begins processing input and output messages. It also sends the initial // start begins processing input and output messages.
// version message for outbound connections to start the negotiation process.
func (p *Peer) start() error { func (p *Peer) start() error {
log.Tracef("Starting peer %s", p) log.Tracef("Starting peer %s", p)
// Send an initial version message if this is an outbound connection. negotiateErr := make(chan error)
if !p.inbound { go func() {
if err := p.pushVersionMsg(); err != nil { if p.inbound {
log.Errorf("Can't send outbound version message %v", err) negotiateErr <- p.negotiateInboundProtocol()
p.Disconnect() } else {
negotiateErr <- p.negotiateOutboundProtocol()
}
}()
// Negotiate the protocol within the specified negotiateTimeout.
select {
case err := <-negotiateErr:
if err != nil {
return err return err
} }
case <-time.After(negotiateTimeout):
return errors.New("protocol negotiation timeout")
} }
// Start processing input and output. // The protocol has been negotiated successfully so start processing input
// and output messages.
go p.stallHandler() go p.stallHandler()
go p.inHandler() go p.inHandler()
go p.queueHandler() go p.queueHandler()
go p.outHandler() go p.outHandler()
// Send our verack message now that the IO processing machinery has started.
p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil return nil
} }
@ -2023,6 +1953,79 @@ func (p *Peer) WaitForDisconnect() {
<-p.quit <-p.quit
} }
// readRemoteVersionMsg waits for the next message to arrive from the remote
// peer. If the next message is not a version message or the version is not
// acceptable then return an error.
func (p *Peer) readRemoteVersionMsg() error {
// Read their version message.
msg, _, err := p.readMessage()
if err != nil {
return err
}
remoteVerMsg, ok := msg.(*wire.MsgVersion)
if !ok {
errStr := "A version message must precede all others"
log.Errorf(errStr)
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed,
errStr)
return p.writeMessage(rejectMsg)
}
if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil {
return err
}
if p.cfg.Listeners.OnVersion != nil {
p.cfg.Listeners.OnVersion(p, remoteVerMsg)
}
return nil
}
// writeLocalVersionMsg writes our version message to the remote peer.
func (p *Peer) writeLocalVersionMsg() error {
localVerMsg, err := p.localVersionMsg()
if err != nil {
return err
}
if err := p.writeMessage(localVerMsg); err != nil {
return err
}
p.flagsMtx.Lock()
p.versionSent = true
p.flagsMtx.Unlock()
return nil
}
// negotiateInboundProtocol waits to receive a version message from the peer
// then sends our version message. If the events do not occur in that order then
// it returns an error.
func (p *Peer) negotiateInboundProtocol() error {
if err := p.readRemoteVersionMsg(); err != nil {
return err
}
return p.writeLocalVersionMsg()
}
// negotiateOutboundProtocol sends our version message then waits to receive a
// version message from the peer. If the events do not occur in that order then
// it returns an error.
func (p *Peer) negotiateOutboundProtocol() error {
if err := p.writeLocalVersionMsg(); err != nil {
return err
}
return p.readRemoteVersionMsg()
}
// newPeerBase returns a new base bitcoin peer based on the inbound flag. This // newPeerBase returns a new base bitcoin peer based on the inbound flag. This
// is used by the NewInboundPeer and NewOutboundPeer functions to perform base // is used by the NewInboundPeer and NewOutboundPeer functions to perform base
// setup needed by both types of peers. // setup needed by both types of peers.

View File

@ -204,12 +204,15 @@ func testPeer(t *testing.T, p *peer.Peer, s peerStats) {
// TestPeerConnection tests connection between inbound and outbound peers. // TestPeerConnection tests connection between inbound and outbound peers.
func TestPeerConnection(t *testing.T) { func TestPeerConnection(t *testing.T) {
verack := make(chan struct{}, 1) verack := make(chan struct{})
peerCfg := &peer.Config{ peerCfg := &peer.Config{
Listeners: peer.MessageListeners{ Listeners: peer.MessageListeners{
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, err error) { OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) {
switch msg.(type) { verack <- struct{}{}
case *wire.MsgVerAck: },
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message,
err error) {
if _, ok := msg.(*wire.MsgVerAck); ok {
verack <- struct{}{} verack <- struct{}{}
} }
}, },
@ -253,10 +256,10 @@ func TestPeerConnection(t *testing.T) {
} }
outPeer.Connect(outConn) outPeer.Connect(outConn)
for i := 0; i < 2; i++ { for i := 0; i < 4; i++ {
select { select {
case <-verack: case <-verack:
case <-time.After(time.Second * 1): case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout") return nil, nil, errors.New("verack timeout")
} }
} }
@ -279,10 +282,10 @@ func TestPeerConnection(t *testing.T) {
} }
outPeer.Connect(outConn) outPeer.Connect(outConn)
for i := 0; i < 2; i++ { for i := 0; i < 4; i++ {
select { select {
case <-verack: case <-verack:
case <-time.After(time.Second * 1): case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout") return nil, nil, errors.New("verack timeout")
} }
} }
@ -294,7 +297,7 @@ func TestPeerConnection(t *testing.T) {
for i, test := range tests { for i, test := range tests {
inPeer, outPeer, err := test.setup() inPeer, outPeer, err := test.setup()
if err != nil { if err != nil {
t.Errorf("TestPeerConnection setup #%d: unexpected err %v\n", i, err) t.Errorf("TestPeerConnection setup #%d: unexpected err %v", i, err)
return return
} }
testPeer(t, inPeer, wantStats) testPeer(t, inPeer, wantStats)
@ -302,6 +305,8 @@ func TestPeerConnection(t *testing.T) {
inPeer.Disconnect() inPeer.Disconnect()
outPeer.Disconnect() outPeer.Disconnect()
inPeer.WaitForDisconnect()
outPeer.WaitForDisconnect()
} }
} }
@ -547,6 +552,7 @@ func TestOutboundPeer(t *testing.T) {
select { select {
case <-disconnected: case <-disconnected:
close(disconnected)
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("Peer did not automatically disconnect.") t.Fatal("Peer did not automatically disconnect.")
} }
@ -580,6 +586,7 @@ func TestOutboundPeer(t *testing.T) {
} }
return hash, 234439, nil return hash, 234439, nil
} }
peerCfg.NewestBlock = newestBlock peerCfg.NewestBlock = newestBlock
r1, w1 := io.Pipe() r1, w1 := io.Pipe()
c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1} c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1}
@ -638,7 +645,8 @@ func TestOutboundPeer(t *testing.T) {
t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err) t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err)
return return
} }
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, true)
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false)
p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false) p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false)
// Test Queue Messages // Test Queue Messages