lnwire: refactor WriteMessage to use bytes.Buffer

This commit changes the method WriteMessage to use bytes.Buffer to save
heap allocations. A unit test is added to check the method is
implemented as expected.
This commit is contained in:
yyforyongyu 2021-06-17 11:29:40 +08:00
parent 77862e45af
commit f212f1aa20
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 199 additions and 29 deletions

View File

@ -52,6 +52,27 @@ const (
MsgGossipTimestampRange = 265
)
// ErrorEncodeMessage is used when failed to encode the message payload.
func ErrorEncodeMessage(err error) error {
return fmt.Errorf("failed to encode message to buffer, got %w", err)
}
// ErrorWriteMessageType is used when failed to write the message type.
func ErrorWriteMessageType(err error) error {
return fmt.Errorf("failed to write message type, got %w", err)
}
// ErrorPayloadTooLarge is used when the payload size exceeds the
// MaxMsgBody.
func ErrorPayloadTooLarge(size int) error {
return fmt.Errorf(
"message payload is too large - encoded %d bytes, "+
"but maximum message payload is %d bytes",
size, MaxMsgBody,
)
}
// String return the string representation of message type.
func (t MessageType) String() string {
switch t {
@ -218,44 +239,49 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
return msg, nil
}
// WriteMessage writes a lightning Message to w including the necessary header
// information and returns the number of bytes written.
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
totalBytes := 0
// WriteMessage writes a lightning Message to a buffer including the necessary
// header information and returns the number of bytes written. If any error is
// encountered, the buffer passed will be reset to its original state since we
// don't want any broken bytes left. In other words, no bytes will be written
// if there's an error. Either all or none of the message bytes will be written
// to the buffer.
//
// NOTE: this method is not concurrent safe.
func WriteMessage(buf *bytes.Buffer, msg Message, pver uint32) (int, error) {
// Record the size of the bytes already written in buffer.
oldByteSize := buf.Len()
// Encode the message payload itself into a temporary buffer.
// TODO(roasbeef): create buffer pool
var bw bytes.Buffer
if err := msg.Encode(&bw, pver); err != nil {
return totalBytes, err
}
payload := bw.Bytes()
lenp := len(payload)
// Enforce maximum message payload, which means the body cannot be
// greater than MaxMsgBody.
if lenp > MaxMsgBody {
return totalBytes, fmt.Errorf("message payload is too large - "+
"encoded %d bytes, but maximum message body is %d bytes",
lenp, MaxMsgBody)
// cleanBrokenBytes is a helper closure that helps reset the buffer to
// its original state. It truncates all the bytes written in current
// scope.
var cleanBrokenBytes = func(b *bytes.Buffer) int {
b.Truncate(oldByteSize)
return 0
}
// With the initial sanity checks complete, we'll now write out the
// message type itself.
// Write the message type.
var mType [2]byte
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
n, err := w.Write(mType[:])
totalBytes += n
msgTypeBytes, err := buf.Write(mType[:])
if err != nil {
return totalBytes, err
return cleanBrokenBytes(buf), ErrorWriteMessageType(err)
}
// With the message type written, we'll now write out the raw payload
// itself.
n, err = w.Write(payload)
totalBytes += n
// Use the write buffer to encode our message.
if err := msg.Encode(buf, pver); err != nil {
return cleanBrokenBytes(buf), ErrorEncodeMessage(err)
}
return totalBytes, err
// Enforce maximum overall message payload. The write buffer now has
// the size of len(originalBytes) + len(payload) + len(type). We want
// to enforce the payload here, so we subtract it by the length of the
// type and old bytes.
lenp := buf.Len() - oldByteSize - msgTypeBytes
if lenp > MaxMsgBody {
return cleanBrokenBytes(buf), ErrorPayloadTooLarge(lenp)
}
return buf.Len() - oldByteSize, nil
}
// ReadMessage reads, validates, and parses the next Lightning message from r

View File

@ -3,6 +3,7 @@ package lnwire_test
import (
"bytes"
"encoding/binary"
"errors"
"image/color"
"io"
"math"
@ -16,6 +17,7 @@ import (
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -41,6 +43,148 @@ var (
}
)
type mockMsg struct {
mock.Mock
}
func (m *mockMsg) Decode(r io.Reader, pver uint32) error {
args := m.Called(r, pver)
return args.Error(0)
}
func (m *mockMsg) Encode(w io.Writer, pver uint32) error {
args := m.Called(w, pver)
return args.Error(0)
}
func (m *mockMsg) MsgType() lnwire.MessageType {
args := m.Called()
return lnwire.MessageType(args.Int(0))
}
// A compile time check to ensure mockMsg implements the lnwire.Message
// interface.
var _ lnwire.Message = (*mockMsg)(nil)
// TestWriteMessage tests the function lnwire.WriteMessage.
func TestWriteMessage(t *testing.T) {
var (
buf = new(bytes.Buffer)
// encodeNormalSize specifies a message size that is normal.
encodeNormalSize = 1000
// encodeOversize specifies a message size that's too big.
encodeOversize = lnwire.MaxMsgBody + 1
// errDummy is returned by the msg.Encode when specified.
errDummy = errors.New("test error")
// oneByte is a dummy byte used to fill up the buffer.
oneByte = [1]byte{}
)
testCases := []struct {
name string
// encodeSize controls how many bytes are written to the buffer
// by the method msg.Encode(buf, pver).
encodeSize int
// encodeErr determines the return value of the method
// msg.Encode(buf, pver).
encodeErr error
errorExpected error
}{
{
name: "successful write",
encodeSize: encodeNormalSize,
encodeErr: nil,
errorExpected: nil,
},
{
name: "failed to encode payload",
encodeSize: encodeNormalSize,
encodeErr: errDummy,
errorExpected: lnwire.ErrorEncodeMessage(errDummy),
},
{
name: "exceeds MaxMsgBody",
encodeSize: encodeOversize,
encodeErr: nil,
errorExpected: lnwire.ErrorPayloadTooLarge(
encodeOversize,
),
},
}
for _, test := range testCases {
tc := test
t.Run(tc.name, func(t *testing.T) {
// Start the test by creating a mock message and patch
// the relevant methods.
msg := &mockMsg{}
// Use message type Ping here since all types are
// encoded using 2 bytes, it won't affect anything
// here.
msg.On("MsgType").Return(lnwire.MsgPing)
// Encode will return the specified error (could be
// nil) and has the side effect of filling up the
// buffer by repeating the oneByte encodeSize times.
msg.On("Encode", mock.Anything, mock.Anything).Return(
tc.encodeErr,
).Run(func(_ mock.Arguments) {
for i := 0; i < tc.encodeSize; i++ {
_, err := buf.Write(oneByte[:])
require.NoError(t, err)
}
})
// Record the initial state of the buffer and write the
// message.
oldBytesSize := buf.Len()
bytesWritten, err := lnwire.WriteMessage(
buf, msg, 1,
)
// Check that the returned error is expected.
require.Equal(
t, tc.errorExpected, err, "unexpected err",
)
// If there's an error, no bytes should be written to
// the buf.
if tc.errorExpected != nil {
require.Equal(
t, 0, bytesWritten,
"bytes written should be 0",
)
// We also check that the old buf was not
// affected.
require.Equal(
t, oldBytesSize, buf.Len(),
"original buffer should not change",
)
} else {
expected := buf.Len() - oldBytesSize
require.Equal(
t, expected, bytesWritten,
"bytes written not matched",
)
}
// Finally, check the mocked methods are called as
// expected.
msg.AssertExpectations(t)
})
}
}
// BenchmarkWriteMessage benchmarks the performance of lnwire.WriteMessage. It
// generates a test message for each of the lnwire.Message, calls the
// WriteMessage method and benchmark it.