peer: enable mockMessageConn to detect data races

We use unsynchronized counters to trigger a report under the race
detector if multiple reads or writes happen concurrently.
This commit is contained in:
Matt Morehouse 2023-11-16 17:32:19 -06:00
parent c4e0daa274
commit 08fff28504
No known key found for this signature in database
GPG key ID: CC8ECA224831C982

View file

@ -497,6 +497,16 @@ type mockMessageConn struct {
readMessages chan []byte readMessages chan []byte
curReadMessage []byte curReadMessage []byte
// writeRaceDetectingCounter is incremented on any function call
// associated with writing to the connection. The race detector will
// trigger on this counter if a data race exists.
writeRaceDetectingCounter int
// readRaceDetectingCounter is incremented on any function call
// associated with reading from the connection. The race detector will
// trigger on this counter if a data race exists.
readRaceDetectingCounter int
} }
func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
@ -509,17 +519,20 @@ func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
// SetWriteDeadline mocks setting write deadline for our conn. // SetWriteDeadline mocks setting write deadline for our conn.
func (m *mockMessageConn) SetWriteDeadline(time.Time) error { func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
m.writeRaceDetectingCounter++
return nil return nil
} }
// Flush mocks a message conn flush. // Flush mocks a message conn flush.
func (m *mockMessageConn) Flush() (int, error) { func (m *mockMessageConn) Flush() (int, error) {
m.writeRaceDetectingCounter++
return 0, nil return 0, nil
} }
// WriteMessage mocks sending of a message on our connection. It will push // WriteMessage mocks sending of a message on our connection. It will push
// the bytes sent into the mock's writtenMessages channel. // the bytes sent into the mock's writtenMessages channel.
func (m *mockMessageConn) WriteMessage(msg []byte) error { func (m *mockMessageConn) WriteMessage(msg []byte) error {
m.writeRaceDetectingCounter++
select { select {
case m.writtenMessages <- msg: case m.writtenMessages <- msg:
case <-time.After(timeout): case <-time.After(timeout):
@ -542,15 +555,18 @@ func (m *mockMessageConn) assertWrite(expected []byte) {
} }
func (m *mockMessageConn) SetReadDeadline(t time.Time) error { func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
m.readRaceDetectingCounter++
return nil return nil
} }
func (m *mockMessageConn) ReadNextHeader() (uint32, error) { func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
m.readRaceDetectingCounter++
m.curReadMessage = <-m.readMessages m.curReadMessage = <-m.readMessages
return uint32(len(m.curReadMessage)), nil return uint32(len(m.curReadMessage)), nil
} }
func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) { func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
m.readRaceDetectingCounter++
return m.curReadMessage, nil return m.curReadMessage, nil
} }