brontide: exclude MAC length from cipher text packet length prefix

Pervasively we would include the length of the MAC in the length prefix
for cipher text packets. As a result, the MAC would eat into the total
payload size. To remedy this, we now exclude the MAC from the length
prefix for cipher text packets, and instead account for the length of
the MAC on the packet when reading messages.
This commit is contained in:
Olaoluwa Osuntokun 2017-01-07 19:15:58 -08:00
parent 387d41e5df
commit d046efb502
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
3 changed files with 18 additions and 17 deletions

View File

@ -117,14 +117,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
func (c *Conn) Write(b []byte) (n int, err error) {
// If the message doesn't require any chunking, then we can go ahead
// with a single write.
if len(b)+macSize <= math.MaxUint16 {
if len(b) <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b)
}
// If we need to split the message into fragments, then we'll write
// chunks which maximize usage of the available payload. To do so, we
// subtract the added overhead of the MAC at the end of the message.
chunkSize := math.MaxUint16 - macSize
// chunks which maximize usage of the available payload.
chunkSize := math.MaxUint16
bytesToWrite := len(b)
bytesWritten := 0

View File

@ -641,12 +641,13 @@ func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error {
// The total length of each message payload including the MAC size
// payload exceed the largest number encodable within a 16-bit unsigned
// integer.
if len(p)+macSize > math.MaxUint16 {
if len(p) > math.MaxUint16 {
return ErrMaxMessageLengthExceeded
}
// The full length of the packet includes the 16 byte MAC.
fullLength := uint16(len(p) + macSize)
// The full length of the packet is only the packet length, and does
// NOT include the MAC.
fullLength := uint16(len(p))
var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength)
@ -684,11 +685,11 @@ func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) {
// Next, using the length read from the packet header, read the
// encrypted packet itself.
pktLen := binary.BigEndian.Uint16(pktLenBytes)
ciperText := make([]byte, pktLen)
if _, err := io.ReadFull(r, ciperText[:]); err != nil {
pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
cipherText := make([]byte, pktLen)
if _, err := io.ReadFull(r, cipherText[:]); err != nil {
return nil, err
}
return b.recvCipher.Decrypt(nil, nil, ciperText)
return b.recvCipher.Decrypt(nil, nil, cipherText)
}

View File

@ -63,6 +63,7 @@ func establishTestConnection() (net.Conn, net.Conn, error) {
return localConn, remoteConn, nil
}
func TestConnectionCorrectness(t *testing.T) {
// Create a test connection, grabbing either side of the connection
// into local variables. If the initial crypto handshake fails, then
@ -130,9 +131,9 @@ func TestMaxPayloadLength(t *testing.T) {
"should have been rejected")
}
// Generate another payload which with the MAC acounted for, should be
// accepted as a valid payload.
payloadToAccept := make([]byte, math.MaxUint16-macSize)
// Generate another payload which should be accepted as a valid
// payload.
payloadToAccept := make([]byte, math.MaxUint16-1)
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
t.Fatalf("write for payload was rejected, should have been " +
"accepted")
@ -140,7 +141,7 @@ func TestMaxPayloadLength(t *testing.T) {
// Generate a final payload which is juuust over the max payload length
// when the MAC is accounted for.
payloadToReject = make([]byte, math.MaxUint16-macSize+1)
payloadToReject = make([]byte, math.MaxUint16+1)
// This payload should be rejected.
err = b.WriteMessage(&buf, payloadToReject)
@ -171,7 +172,7 @@ func TestWriteMessageChunking(t *testing.T) {
go func() {
bytesWritten, err := localConn.Write(largeMessage)
if err != nil {
t.Fatalf("unable to write message")
t.Fatalf("unable to write message: %v", err)
}
// The entire message should have been written out to the remote
@ -186,7 +187,7 @@ func TestWriteMessageChunking(t *testing.T) {
// Attempt to read the entirety of the message generated above.
buf := make([]byte, len(largeMessage))
if _, err := io.ReadFull(remoteConn, buf); err != nil {
t.Fatalf("unable to read message")
t.Fatalf("unable to read message: %v", err)
}
wg.Wait()