mirror of
https://github.com/btcsuite/btcd.git
synced 2024-11-19 09:50:08 +01:00
peer: Extract protocol negotiation from main read and write code paths.
This allows cleaner separation of the half-duplex version negotiation from the fully duplex message passing between peers.
This commit is contained in:
parent
777ccdade3
commit
f3d759d783
319
peer/peer.go
319
peer/peer.go
@ -188,7 +188,7 @@ type MessageListeners struct {
|
|||||||
// not directly provide a callback.
|
// not directly provide a callback.
|
||||||
OnRead func(p *Peer, bytesRead int, msg wire.Message, err error)
|
OnRead func(p *Peer, bytesRead int, msg wire.Message, err error)
|
||||||
|
|
||||||
// OnWrite is invoked when a peer receives a bitcoin message. It
|
// OnWrite is invoked when we write a bitcoin message to a peer. It
|
||||||
// consists of the number of bytes written, the message, and whether or
|
// consists of the number of bytes written, the message, and whether or
|
||||||
// not an error in the write occurred. This can be useful for
|
// not an error in the write occurred. This can be useful for
|
||||||
// circumstances such as keeping track of server-wide byte counts.
|
// circumstances such as keeping track of server-wide byte counts.
|
||||||
@ -735,15 +735,15 @@ func (p *Peer) WantsHeaders() bool {
|
|||||||
return p.sendHeadersPreferred
|
return p.sendHeadersPreferred
|
||||||
}
|
}
|
||||||
|
|
||||||
// pushVersionMsg sends a version message to the connected peer using the
|
// localVersionMsg creates a version message that can be used to send to the
|
||||||
// current state.
|
// remote peer.
|
||||||
func (p *Peer) pushVersionMsg() error {
|
func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
|
||||||
var blockNum int32
|
var blockNum int32
|
||||||
if p.cfg.NewestBlock != nil {
|
if p.cfg.NewestBlock != nil {
|
||||||
var err error
|
var err error
|
||||||
_, blockNum, err = p.cfg.NewestBlock()
|
_, blockNum, err = p.cfg.NewestBlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -775,7 +775,7 @@ func (p *Peer) pushVersionMsg() error {
|
|||||||
// recently seen nonces.
|
// recently seen nonces.
|
||||||
nonce, err := wire.RandomUint64()
|
nonce, err := wire.RandomUint64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
sentNonces.Add(nonce)
|
sentNonces.Add(nonce)
|
||||||
|
|
||||||
@ -810,8 +810,7 @@ func (p *Peer) pushVersionMsg() error {
|
|||||||
// Advertise if inv messages for transactions are desired.
|
// Advertise if inv messages for transactions are desired.
|
||||||
msg.DisableRelayTx = p.cfg.DisableRelayTx
|
msg.DisableRelayTx = p.cfg.DisableRelayTx
|
||||||
|
|
||||||
p.QueueMessage(msg, nil)
|
return msg, nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushAddrMsg sends an addr message to the connected peer using the provided
|
// PushAddrMsg sends an addr message to the connected peer using the provided
|
||||||
@ -913,8 +912,8 @@ func (p *Peer) PushGetHeadersMsg(locator blockchain.BlockLocator, stopHash *wire
|
|||||||
p.prevGetHdrsMtx.Unlock()
|
p.prevGetHdrsMtx.Unlock()
|
||||||
|
|
||||||
if isDuplicate {
|
if isDuplicate {
|
||||||
log.Tracef("Filtering duplicate [getheaders] with begin "+
|
log.Tracef("Filtering duplicate [getheaders] with begin hash %v",
|
||||||
"hash %v", beginHash)
|
beginHash)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -974,10 +973,10 @@ func (p *Peer) PushRejectMsg(command string, code wire.RejectCode, reason string
|
|||||||
<-doneChan
|
<-doneChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleVersionMsg is invoked when a peer receives a version bitcoin message
|
// handleRemoteVersionMsg is invoked when a version bitcoin message is received
|
||||||
// and is used to negotiate the protocol version details as well as kick start
|
// from the remote peer. It will return an error if the remote peer's version
|
||||||
// the communications.
|
// is not compatible with ours.
|
||||||
func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
|
func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
|
||||||
// Detect self connections.
|
// Detect self connections.
|
||||||
if !allowSelfConns && sentNonces.Exists(msg.Nonce) {
|
if !allowSelfConns && sentNonces.Exists(msg.Nonce) {
|
||||||
return errors.New("disconnecting peer connected to self")
|
return errors.New("disconnecting peer connected to self")
|
||||||
@ -991,21 +990,9 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
|
|||||||
// disconnecting.
|
// disconnecting.
|
||||||
reason := fmt.Sprintf("protocol version must be %d or greater",
|
reason := fmt.Sprintf("protocol version must be %d or greater",
|
||||||
wire.MultipleAddressVersion)
|
wire.MultipleAddressVersion)
|
||||||
p.PushRejectMsg(msg.Command(), wire.RejectObsolete, reason,
|
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete,
|
||||||
nil, true)
|
reason)
|
||||||
return errors.New(reason)
|
return p.writeMessage(rejectMsg)
|
||||||
}
|
|
||||||
|
|
||||||
// Limit to one version message per peer.
|
|
||||||
// No read lock is necessary because versionKnown is not written to in any
|
|
||||||
// other goroutine
|
|
||||||
if p.versionKnown {
|
|
||||||
// Send an reject message indicating the version message was
|
|
||||||
// incorrectly sent twice and wait for the message to be sent
|
|
||||||
// before disconnecting.
|
|
||||||
p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
|
|
||||||
"duplicate version message", nil, true)
|
|
||||||
return errors.New("only one version message per peer is allowed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Updating a bunch of stats.
|
// Updating a bunch of stats.
|
||||||
@ -1030,27 +1017,6 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
|
|||||||
// Set the remote peer's user agent.
|
// Set the remote peer's user agent.
|
||||||
p.userAgent = msg.UserAgent
|
p.userAgent = msg.UserAgent
|
||||||
p.flagsMtx.Unlock()
|
p.flagsMtx.Unlock()
|
||||||
|
|
||||||
// Inbound connections.
|
|
||||||
if p.inbound {
|
|
||||||
// Set up a NetAddress for the peer to be used with AddrManager.
|
|
||||||
// We only do this inbound because outbound set this up
|
|
||||||
// at connection time and no point recomputing.
|
|
||||||
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.na = na
|
|
||||||
|
|
||||||
// Send version.
|
|
||||||
err = p.pushVersionMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send verack.
|
|
||||||
p.QueueMessage(wire.NewMsgVerAck(), nil)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1147,18 +1113,6 @@ func (p *Peer) writeMessage(msg wire.Message) error {
|
|||||||
if atomic.LoadInt32(&p.disconnect) != 0 {
|
if atomic.LoadInt32(&p.disconnect) != 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if !p.VersionKnown() {
|
|
||||||
switch msg.(type) {
|
|
||||||
case *wire.MsgVersion:
|
|
||||||
// This is OK.
|
|
||||||
case *wire.MsgReject:
|
|
||||||
// This is OK.
|
|
||||||
default:
|
|
||||||
// Drop all messages other than version and reject if
|
|
||||||
// the handshake has not already been done.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use closures to log expensive operations so they are only run when
|
// Use closures to log expensive operations so they are only run when
|
||||||
// the logging level requires it.
|
// the logging level requires it.
|
||||||
@ -1194,13 +1148,16 @@ func (p *Peer) writeMessage(msg wire.Message) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// isAllowedByRegression returns whether or not the passed error is allowed by
|
// isAllowedReadError returns whether or not the passed error is allowed without
|
||||||
// regression tests without disconnecting the peer. In particular, regression
|
// disconnecting the peer. In particular, regression tests need to be allowed
|
||||||
// tests need to be allowed to send malformed messages without the peer being
|
// to send malformed messages without the peer being disconnected.
|
||||||
// disconnected.
|
func (p *Peer) isAllowedReadError(err error) bool {
|
||||||
func (p *Peer) isAllowedByRegression(err error) bool {
|
// Only allow read errors in regression test mode.
|
||||||
// Don't allow the error if it's not specifically a malformed message
|
if p.cfg.ChainParams.Net != wire.TestNet {
|
||||||
// error.
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't allow the error if it's not specifically a malformed message error.
|
||||||
if _, ok := err.(*wire.MessageError); !ok {
|
if _, ok := err.(*wire.MessageError); !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -1220,12 +1177,6 @@ func (p *Peer) isAllowedByRegression(err error) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRegTestNetwork returns whether or not the peer is running on the regression
|
|
||||||
// test network.
|
|
||||||
func (p *Peer) isRegTestNetwork() bool {
|
|
||||||
return p.cfg.ChainParams.Net == wire.TestNet
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldHandleReadError returns whether or not the passed error, which is
|
// shouldHandleReadError returns whether or not the passed error, which is
|
||||||
// expected to have come from reading from the remote peer in the inHandler,
|
// expected to have come from reading from the remote peer in the inHandler,
|
||||||
// should be logged and responded to with a reject message.
|
// should be logged and responded to with a reject message.
|
||||||
@ -1437,14 +1388,8 @@ func (p *Peer) inHandler() {
|
|||||||
// Peers must complete the initial version negotiation within a shorter
|
// Peers must complete the initial version negotiation within a shorter
|
||||||
// timeframe than a general idle timeout. The timer is then reset below
|
// timeframe than a general idle timeout. The timer is then reset below
|
||||||
// to idleTimeout for all future messages.
|
// to idleTimeout for all future messages.
|
||||||
idleTimer := time.AfterFunc(negotiateTimeout, func() {
|
idleTimer := time.AfterFunc(idleTimeout, func() {
|
||||||
if p.VersionKnown() {
|
log.Warnf("Peer %s no answer for %s -- disconnecting", p, idleTimeout)
|
||||||
log.Warnf("Peer %s no answer for %s -- disconnecting",
|
|
||||||
p, idleTimeout)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Peer %s no valid version message for %s -- "+
|
|
||||||
"disconnecting", p, negotiateTimeout)
|
|
||||||
}
|
|
||||||
p.Disconnect()
|
p.Disconnect()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -1456,13 +1401,11 @@ out:
|
|||||||
rmsg, buf, err := p.readMessage()
|
rmsg, buf, err := p.readMessage()
|
||||||
idleTimer.Stop()
|
idleTimer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// In order to allow regression tests with malformed
|
// In order to allow regression tests with malformed messages, don't
|
||||||
// messages, don't disconnect the peer when we're in
|
// disconnect the peer when we're in regression test mode and the
|
||||||
// regression test mode and the error is one of the
|
// error is one of the allowed errors.
|
||||||
// allowed errors.
|
if p.isAllowedReadError(err) {
|
||||||
if p.isRegTestNetwork() && p.isAllowedByRegression(err) {
|
log.Errorf("Allowed test error from %s: %v", p, err)
|
||||||
log.Errorf("Allowed regression test error "+
|
|
||||||
"from %s: %v", p, err)
|
|
||||||
idleTimer.Reset(idleTimeout)
|
idleTimer.Reset(idleTimeout)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1471,70 +1414,40 @@ out:
|
|||||||
// local peer is not forcibly disconnecting and the
|
// local peer is not forcibly disconnecting and the
|
||||||
// remote peer has not disconnected.
|
// remote peer has not disconnected.
|
||||||
if p.shouldHandleReadError(err) {
|
if p.shouldHandleReadError(err) {
|
||||||
errMsg := fmt.Sprintf("Can't read message "+
|
errMsg := fmt.Sprintf("Can't read message from %s: %v", p, err)
|
||||||
"from %s: %v", p, err)
|
|
||||||
log.Errorf(errMsg)
|
log.Errorf(errMsg)
|
||||||
|
|
||||||
// Push a reject message for the malformed
|
// Push a reject message for the malformed message and wait for
|
||||||
// message and wait for the message to be sent
|
// the message to be sent before disconnecting.
|
||||||
// before disconnecting.
|
|
||||||
//
|
//
|
||||||
// NOTE: Ideally this would include the command
|
// NOTE: Ideally this would include the command in the header if
|
||||||
// in the header if at least that much of the
|
// at least that much of the message was valid, but that is not
|
||||||
// message was valid, but that is not currently
|
// currently exposed by wire, so just used malformed for the
|
||||||
// exposed by wire, so just used malformed for
|
// command.
|
||||||
// the command.
|
p.PushRejectMsg("malformed", wire.RejectMalformed, errMsg, nil,
|
||||||
p.PushRejectMsg("malformed",
|
true)
|
||||||
wire.RejectMalformed, errMsg, nil, true)
|
|
||||||
}
|
}
|
||||||
break out
|
break out
|
||||||
}
|
}
|
||||||
atomic.StoreInt64(&p.lastRecv, time.Now().Unix())
|
atomic.StoreInt64(&p.lastRecv, time.Now().Unix())
|
||||||
p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}
|
p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}
|
||||||
|
|
||||||
// Ensure version message comes first.
|
|
||||||
if vmsg, ok := rmsg.(*wire.MsgVersion); !ok && !p.VersionKnown() {
|
|
||||||
errStr := "A version message must precede all others"
|
|
||||||
log.Errorf(errStr)
|
|
||||||
|
|
||||||
// Push a reject message and wait for the message to be
|
|
||||||
// sent before disconnecting.
|
|
||||||
p.PushRejectMsg(vmsg.Command(), wire.RejectMalformed,
|
|
||||||
errStr, nil, true)
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle each supported message type.
|
// Handle each supported message type.
|
||||||
p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}
|
p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}
|
||||||
switch msg := rmsg.(type) {
|
switch msg := rmsg.(type) {
|
||||||
case *wire.MsgVersion:
|
case *wire.MsgVersion:
|
||||||
err := p.handleVersionMsg(msg)
|
|
||||||
if err != nil {
|
p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
|
||||||
log.Debugf("New peer %v - error negotiating protocol: %v",
|
"duplicate version message", nil, true)
|
||||||
p, err)
|
break out
|
||||||
p.Disconnect()
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
if p.cfg.Listeners.OnVersion != nil {
|
|
||||||
p.cfg.Listeners.OnVersion(p, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
case *wire.MsgVerAck:
|
case *wire.MsgVerAck:
|
||||||
p.flagsMtx.Lock()
|
|
||||||
versionSent := p.versionSent
|
|
||||||
p.flagsMtx.Unlock()
|
|
||||||
if !versionSent {
|
|
||||||
log.Infof("Received 'verack' from peer %v "+
|
|
||||||
"before version was sent -- "+
|
|
||||||
"disconnecting", p)
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
|
|
||||||
// No read lock is necessary because verAckReceived is
|
// No read lock is necessary because verAckReceived is not written
|
||||||
// not written to in any other goroutine.
|
// to in any other goroutine.
|
||||||
if p.verAckReceived {
|
if p.verAckReceived {
|
||||||
log.Infof("Already received 'verack' from "+
|
log.Infof("Already received 'verack' from peer %v -- "+
|
||||||
"peer %v -- disconnecting", p)
|
"disconnecting", p)
|
||||||
break out
|
break out
|
||||||
}
|
}
|
||||||
p.flagsMtx.Lock()
|
p.flagsMtx.Lock()
|
||||||
@ -1830,13 +1743,6 @@ out:
|
|||||||
select {
|
select {
|
||||||
case msg := <-p.sendQueue:
|
case msg := <-p.sendQueue:
|
||||||
switch m := msg.msg.(type) {
|
switch m := msg.msg.(type) {
|
||||||
case *wire.MsgVersion:
|
|
||||||
// Set the flag which indicates the version has
|
|
||||||
// been sent.
|
|
||||||
p.flagsMtx.Lock()
|
|
||||||
p.versionSent = true
|
|
||||||
p.flagsMtx.Unlock()
|
|
||||||
|
|
||||||
case *wire.MsgPing:
|
case *wire.MsgPing:
|
||||||
// Only expects a pong message in later protocol
|
// Only expects a pong message in later protocol
|
||||||
// versions. Also set up statistics.
|
// versions. Also set up statistics.
|
||||||
@ -1849,8 +1755,7 @@ out:
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.stallControl <- stallControlMsg{sccSendMessage, msg.msg}
|
p.stallControl <- stallControlMsg{sccSendMessage, msg.msg}
|
||||||
err := p.writeMessage(msg.msg)
|
if err := p.writeMessage(msg.msg); err != nil {
|
||||||
if err != nil {
|
|
||||||
p.Disconnect()
|
p.Disconnect()
|
||||||
if p.shouldLogWriteError(err) {
|
if p.shouldLogWriteError(err) {
|
||||||
log.Errorf("Failed to send message to "+
|
log.Errorf("Failed to send message to "+
|
||||||
@ -1956,15 +1861,28 @@ func (p *Peer) Connect(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.inbound {
|
|
||||||
p.addr = conn.RemoteAddr().String()
|
|
||||||
}
|
|
||||||
p.conn = conn
|
p.conn = conn
|
||||||
p.timeConnected = time.Now()
|
p.timeConnected = time.Now()
|
||||||
|
|
||||||
|
if p.inbound {
|
||||||
|
p.addr = p.conn.RemoteAddr().String()
|
||||||
|
|
||||||
|
// Set up a NetAddress for the peer to be used with AddrManager. We
|
||||||
|
// only do this inbound because outbound set this up at connection time
|
||||||
|
// and no point recomputing.
|
||||||
|
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Cannot create remote net address: %v", err)
|
||||||
|
p.Disconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.na = na
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
log.Errorf("Cannot start peer %v: %v", p, err)
|
log.Warnf("Cannot start peer %v: %v", p, err)
|
||||||
|
p.Disconnect()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@ -1992,26 +1910,38 @@ func (p *Peer) Disconnect() {
|
|||||||
close(p.quit)
|
close(p.quit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins processing input and output messages. It also sends the initial
|
// start begins processing input and output messages.
|
||||||
// version message for outbound connections to start the negotiation process.
|
|
||||||
func (p *Peer) start() error {
|
func (p *Peer) start() error {
|
||||||
log.Tracef("Starting peer %s", p)
|
log.Tracef("Starting peer %s", p)
|
||||||
|
|
||||||
// Send an initial version message if this is an outbound connection.
|
negotiateErr := make(chan error)
|
||||||
if !p.inbound {
|
go func() {
|
||||||
if err := p.pushVersionMsg(); err != nil {
|
if p.inbound {
|
||||||
log.Errorf("Can't send outbound version message %v", err)
|
negotiateErr <- p.negotiateInboundProtocol()
|
||||||
p.Disconnect()
|
} else {
|
||||||
|
negotiateErr <- p.negotiateOutboundProtocol()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Negotiate the protocol within the specified negotiateTimeout.
|
||||||
|
select {
|
||||||
|
case err := <-negotiateErr:
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case <-time.After(negotiateTimeout):
|
||||||
|
return errors.New("protocol negotiation timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start processing input and output.
|
// The protocol has been negotiated successfully so start processing input
|
||||||
|
// and output messages.
|
||||||
go p.stallHandler()
|
go p.stallHandler()
|
||||||
go p.inHandler()
|
go p.inHandler()
|
||||||
go p.queueHandler()
|
go p.queueHandler()
|
||||||
go p.outHandler()
|
go p.outHandler()
|
||||||
|
|
||||||
|
// Send our verack message now that the IO processing machinery has started.
|
||||||
|
p.QueueMessage(wire.NewMsgVerAck(), nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2023,6 +1953,79 @@ func (p *Peer) WaitForDisconnect() {
|
|||||||
<-p.quit
|
<-p.quit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readRemoteVersionMsg waits for the next message to arrive from the remote
|
||||||
|
// peer. If the next message is not a version message or the version is not
|
||||||
|
// acceptable then return an error.
|
||||||
|
func (p *Peer) readRemoteVersionMsg() error {
|
||||||
|
|
||||||
|
// Read their version message.
|
||||||
|
msg, _, err := p.readMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteVerMsg, ok := msg.(*wire.MsgVersion)
|
||||||
|
if !ok {
|
||||||
|
errStr := "A version message must precede all others"
|
||||||
|
log.Errorf(errStr)
|
||||||
|
|
||||||
|
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed,
|
||||||
|
errStr)
|
||||||
|
return p.writeMessage(rejectMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.cfg.Listeners.OnVersion != nil {
|
||||||
|
p.cfg.Listeners.OnVersion(p, remoteVerMsg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeLocalVersionMsg writes our version message to the remote peer.
|
||||||
|
func (p *Peer) writeLocalVersionMsg() error {
|
||||||
|
|
||||||
|
localVerMsg, err := p.localVersionMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.writeMessage(localVerMsg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.flagsMtx.Lock()
|
||||||
|
p.versionSent = true
|
||||||
|
p.flagsMtx.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// negotiateInboundProtocol waits to receive a version message from the peer
|
||||||
|
// then sends our version message. If the events do not occur in that order then
|
||||||
|
// it returns an error.
|
||||||
|
func (p *Peer) negotiateInboundProtocol() error {
|
||||||
|
|
||||||
|
if err := p.readRemoteVersionMsg(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.writeLocalVersionMsg()
|
||||||
|
}
|
||||||
|
|
||||||
|
// negotiateOutboundProtocol sends our version message then waits to receive a
|
||||||
|
// version message from the peer. If the events do not occur in that order then
|
||||||
|
// it returns an error.
|
||||||
|
func (p *Peer) negotiateOutboundProtocol() error {
|
||||||
|
|
||||||
|
if err := p.writeLocalVersionMsg(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.readRemoteVersionMsg()
|
||||||
|
}
|
||||||
|
|
||||||
// newPeerBase returns a new base bitcoin peer based on the inbound flag. This
|
// newPeerBase returns a new base bitcoin peer based on the inbound flag. This
|
||||||
// is used by the NewInboundPeer and NewOutboundPeer functions to perform base
|
// is used by the NewInboundPeer and NewOutboundPeer functions to perform base
|
||||||
// setup needed by both types of peers.
|
// setup needed by both types of peers.
|
||||||
|
@ -204,12 +204,15 @@ func testPeer(t *testing.T, p *peer.Peer, s peerStats) {
|
|||||||
|
|
||||||
// TestPeerConnection tests connection between inbound and outbound peers.
|
// TestPeerConnection tests connection between inbound and outbound peers.
|
||||||
func TestPeerConnection(t *testing.T) {
|
func TestPeerConnection(t *testing.T) {
|
||||||
verack := make(chan struct{}, 1)
|
verack := make(chan struct{})
|
||||||
peerCfg := &peer.Config{
|
peerCfg := &peer.Config{
|
||||||
Listeners: peer.MessageListeners{
|
Listeners: peer.MessageListeners{
|
||||||
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, err error) {
|
OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) {
|
||||||
switch msg.(type) {
|
verack <- struct{}{}
|
||||||
case *wire.MsgVerAck:
|
},
|
||||||
|
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message,
|
||||||
|
err error) {
|
||||||
|
if _, ok := msg.(*wire.MsgVerAck); ok {
|
||||||
verack <- struct{}{}
|
verack <- struct{}{}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -253,10 +256,10 @@ func TestPeerConnection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
outPeer.Connect(outConn)
|
outPeer.Connect(outConn)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 4; i++ {
|
||||||
select {
|
select {
|
||||||
case <-verack:
|
case <-verack:
|
||||||
case <-time.After(time.Second * 1):
|
case <-time.After(time.Second):
|
||||||
return nil, nil, errors.New("verack timeout")
|
return nil, nil, errors.New("verack timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -279,10 +282,10 @@ func TestPeerConnection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
outPeer.Connect(outConn)
|
outPeer.Connect(outConn)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 4; i++ {
|
||||||
select {
|
select {
|
||||||
case <-verack:
|
case <-verack:
|
||||||
case <-time.After(time.Second * 1):
|
case <-time.After(time.Second):
|
||||||
return nil, nil, errors.New("verack timeout")
|
return nil, nil, errors.New("verack timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -294,7 +297,7 @@ func TestPeerConnection(t *testing.T) {
|
|||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
inPeer, outPeer, err := test.setup()
|
inPeer, outPeer, err := test.setup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestPeerConnection setup #%d: unexpected err %v\n", i, err)
|
t.Errorf("TestPeerConnection setup #%d: unexpected err %v", i, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
testPeer(t, inPeer, wantStats)
|
testPeer(t, inPeer, wantStats)
|
||||||
@ -302,6 +305,8 @@ func TestPeerConnection(t *testing.T) {
|
|||||||
|
|
||||||
inPeer.Disconnect()
|
inPeer.Disconnect()
|
||||||
outPeer.Disconnect()
|
outPeer.Disconnect()
|
||||||
|
inPeer.WaitForDisconnect()
|
||||||
|
outPeer.WaitForDisconnect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -547,6 +552,7 @@ func TestOutboundPeer(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-disconnected:
|
case <-disconnected:
|
||||||
|
close(disconnected)
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
t.Fatal("Peer did not automatically disconnect.")
|
t.Fatal("Peer did not automatically disconnect.")
|
||||||
}
|
}
|
||||||
@ -580,6 +586,7 @@ func TestOutboundPeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return hash, 234439, nil
|
return hash, 234439, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.NewestBlock = newestBlock
|
peerCfg.NewestBlock = newestBlock
|
||||||
r1, w1 := io.Pipe()
|
r1, w1 := io.Pipe()
|
||||||
c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1}
|
c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1}
|
||||||
@ -638,7 +645,8 @@ func TestOutboundPeer(t *testing.T) {
|
|||||||
t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err)
|
t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, true)
|
|
||||||
|
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false)
|
||||||
p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false)
|
p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false)
|
||||||
|
|
||||||
// Test Queue Messages
|
// Test Queue Messages
|
||||||
|
Loading…
Reference in New Issue
Block a user