mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
lnwire: refactor WriteMessage to use bytes.Buffer
This commit changes the method WriteMessage to use bytes.Buffer to save heap allocations. A unit test is added to check the method is implemented as expected.
This commit is contained in:
parent
77862e45af
commit
f212f1aa20
@ -52,6 +52,27 @@ const (
|
||||
MsgGossipTimestampRange = 265
|
||||
)
|
||||
|
||||
// ErrorEncodeMessage is used when failed to encode the message payload.
|
||||
func ErrorEncodeMessage(err error) error {
|
||||
return fmt.Errorf("failed to encode message to buffer, got %w", err)
|
||||
}
|
||||
|
||||
// ErrorWriteMessageType is used when failed to write the message type.
|
||||
func ErrorWriteMessageType(err error) error {
|
||||
return fmt.Errorf("failed to write message type, got %w", err)
|
||||
}
|
||||
|
||||
// ErrorPayloadTooLarge is used when the payload size exceeds the
|
||||
// MaxMsgBody.
|
||||
func ErrorPayloadTooLarge(size int) error {
|
||||
return fmt.Errorf(
|
||||
"message payload is too large - encoded %d bytes, "+
|
||||
"but maximum message payload is %d bytes",
|
||||
size, MaxMsgBody,
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
// String return the string representation of message type.
|
||||
func (t MessageType) String() string {
|
||||
switch t {
|
||||
@ -218,44 +239,49 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WriteMessage writes a lightning Message to w including the necessary header
|
||||
// information and returns the number of bytes written.
|
||||
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
||||
totalBytes := 0
|
||||
// WriteMessage writes a lightning Message to a buffer including the necessary
|
||||
// header information and returns the number of bytes written. If any error is
|
||||
// encountered, the buffer passed will be reset to its original state since we
|
||||
// don't want any broken bytes left. In other words, no bytes will be written
|
||||
// if there's an error. Either all or none of the message bytes will be written
|
||||
// to the buffer.
|
||||
//
|
||||
// NOTE: this method is not concurrent safe.
|
||||
func WriteMessage(buf *bytes.Buffer, msg Message, pver uint32) (int, error) {
|
||||
// Record the size of the bytes already written in buffer.
|
||||
oldByteSize := buf.Len()
|
||||
|
||||
// Encode the message payload itself into a temporary buffer.
|
||||
// TODO(roasbeef): create buffer pool
|
||||
var bw bytes.Buffer
|
||||
if err := msg.Encode(&bw, pver); err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
payload := bw.Bytes()
|
||||
lenp := len(payload)
|
||||
|
||||
// Enforce maximum message payload, which means the body cannot be
|
||||
// greater than MaxMsgBody.
|
||||
if lenp > MaxMsgBody {
|
||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
||||
"encoded %d bytes, but maximum message body is %d bytes",
|
||||
lenp, MaxMsgBody)
|
||||
// cleanBrokenBytes is a helper closure that helps reset the buffer to
|
||||
// its original state. It truncates all the bytes written in current
|
||||
// scope.
|
||||
var cleanBrokenBytes = func(b *bytes.Buffer) int {
|
||||
b.Truncate(oldByteSize)
|
||||
return 0
|
||||
}
|
||||
|
||||
// With the initial sanity checks complete, we'll now write out the
|
||||
// message type itself.
|
||||
// Write the message type.
|
||||
var mType [2]byte
|
||||
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
|
||||
n, err := w.Write(mType[:])
|
||||
totalBytes += n
|
||||
msgTypeBytes, err := buf.Write(mType[:])
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
return cleanBrokenBytes(buf), ErrorWriteMessageType(err)
|
||||
}
|
||||
|
||||
// With the message type written, we'll now write out the raw payload
|
||||
// itself.
|
||||
n, err = w.Write(payload)
|
||||
totalBytes += n
|
||||
// Use the write buffer to encode our message.
|
||||
if err := msg.Encode(buf, pver); err != nil {
|
||||
return cleanBrokenBytes(buf), ErrorEncodeMessage(err)
|
||||
}
|
||||
|
||||
return totalBytes, err
|
||||
// Enforce maximum overall message payload. The write buffer now has
|
||||
// the size of len(originalBytes) + len(payload) + len(type). We want
|
||||
// to enforce the payload here, so we subtract it by the length of the
|
||||
// type and old bytes.
|
||||
lenp := buf.Len() - oldByteSize - msgTypeBytes
|
||||
if lenp > MaxMsgBody {
|
||||
return cleanBrokenBytes(buf), ErrorPayloadTooLarge(lenp)
|
||||
}
|
||||
|
||||
return buf.Len() - oldByteSize, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads, validates, and parses the next Lightning message from r
|
||||
|
@ -3,6 +3,7 @@ package lnwire_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"image/color"
|
||||
"io"
|
||||
"math"
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -41,6 +43,148 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type mockMsg struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockMsg) Decode(r io.Reader, pver uint32) error {
|
||||
args := m.Called(r, pver)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockMsg) Encode(w io.Writer, pver uint32) error {
|
||||
args := m.Called(w, pver)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockMsg) MsgType() lnwire.MessageType {
|
||||
args := m.Called()
|
||||
return lnwire.MessageType(args.Int(0))
|
||||
}
|
||||
|
||||
// A compile time check to ensure mockMsg implements the lnwire.Message
|
||||
// interface.
|
||||
var _ lnwire.Message = (*mockMsg)(nil)
|
||||
|
||||
// TestWriteMessage tests the function lnwire.WriteMessage.
|
||||
func TestWriteMessage(t *testing.T) {
|
||||
var (
|
||||
buf = new(bytes.Buffer)
|
||||
|
||||
// encodeNormalSize specifies a message size that is normal.
|
||||
encodeNormalSize = 1000
|
||||
|
||||
// encodeOversize specifies a message size that's too big.
|
||||
encodeOversize = lnwire.MaxMsgBody + 1
|
||||
|
||||
// errDummy is returned by the msg.Encode when specified.
|
||||
errDummy = errors.New("test error")
|
||||
|
||||
// oneByte is a dummy byte used to fill up the buffer.
|
||||
oneByte = [1]byte{}
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
// encodeSize controls how many bytes are written to the buffer
|
||||
// by the method msg.Encode(buf, pver).
|
||||
encodeSize int
|
||||
|
||||
// encodeErr determines the return value of the method
|
||||
// msg.Encode(buf, pver).
|
||||
encodeErr error
|
||||
|
||||
errorExpected error
|
||||
}{
|
||||
|
||||
{
|
||||
name: "successful write",
|
||||
encodeSize: encodeNormalSize,
|
||||
encodeErr: nil,
|
||||
errorExpected: nil,
|
||||
},
|
||||
{
|
||||
name: "failed to encode payload",
|
||||
encodeSize: encodeNormalSize,
|
||||
encodeErr: errDummy,
|
||||
errorExpected: lnwire.ErrorEncodeMessage(errDummy),
|
||||
},
|
||||
{
|
||||
name: "exceeds MaxMsgBody",
|
||||
encodeSize: encodeOversize,
|
||||
encodeErr: nil,
|
||||
errorExpected: lnwire.ErrorPayloadTooLarge(
|
||||
encodeOversize,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tc := test
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Start the test by creating a mock message and patch
|
||||
// the relevant methods.
|
||||
msg := &mockMsg{}
|
||||
|
||||
// Use message type Ping here since all types are
|
||||
// encoded using 2 bytes, it won't affect anything
|
||||
// here.
|
||||
msg.On("MsgType").Return(lnwire.MsgPing)
|
||||
|
||||
// Encode will return the specified error (could be
|
||||
// nil) and has the side effect of filling up the
|
||||
// buffer by repeating the oneByte encodeSize times.
|
||||
msg.On("Encode", mock.Anything, mock.Anything).Return(
|
||||
tc.encodeErr,
|
||||
).Run(func(_ mock.Arguments) {
|
||||
for i := 0; i < tc.encodeSize; i++ {
|
||||
_, err := buf.Write(oneByte[:])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
// Record the initial state of the buffer and write the
|
||||
// message.
|
||||
oldBytesSize := buf.Len()
|
||||
bytesWritten, err := lnwire.WriteMessage(
|
||||
buf, msg, 1,
|
||||
)
|
||||
|
||||
// Check that the returned error is expected.
|
||||
require.Equal(
|
||||
t, tc.errorExpected, err, "unexpected err",
|
||||
)
|
||||
|
||||
// If there's an error, no bytes should be written to
|
||||
// the buf.
|
||||
if tc.errorExpected != nil {
|
||||
require.Equal(
|
||||
t, 0, bytesWritten,
|
||||
"bytes written should be 0",
|
||||
)
|
||||
|
||||
// We also check that the old buf was not
|
||||
// affected.
|
||||
require.Equal(
|
||||
t, oldBytesSize, buf.Len(),
|
||||
"original buffer should not change",
|
||||
)
|
||||
} else {
|
||||
expected := buf.Len() - oldBytesSize
|
||||
require.Equal(
|
||||
t, expected, bytesWritten,
|
||||
"bytes written not matched",
|
||||
)
|
||||
}
|
||||
|
||||
// Finally, check the mocked methods are called as
|
||||
// expected.
|
||||
msg.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWriteMessage benchmarks the performance of lnwire.WriteMessage. It
|
||||
// generates a test message for each of the lnwire.Message, calls the
|
||||
// WriteMessage method and benchmark it.
|
||||
|
Loading…
Reference in New Issue
Block a user