lnwire: refactor Encode to use specific writers - III

This commit refactors the remaining usage of WriteElements. By
replacing the interface types with concrete types for the params used in
the methods, most of the encoding of the messages now takes zero heap
allocations.
This commit is contained in:
yyforyongyu 2021-06-18 15:15:44 +08:00
parent c1ad9cc60f
commit 2cf6969dbc
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
12 changed files with 289 additions and 133 deletions

View File

@ -88,20 +88,51 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
a.NodeSig1,
a.NodeSig2,
a.BitcoinSig1,
a.BitcoinSig2,
a.Features,
a.ChainHash[:],
a.ShortChannelID,
a.NodeID1,
a.NodeID2,
a.BitcoinKey1,
a.BitcoinKey2,
a.ExtraOpaqueData,
)
if err := WriteSig(w, a.NodeSig1); err != nil {
return err
}
if err := WriteSig(w, a.NodeSig2); err != nil {
return err
}
if err := WriteSig(w, a.BitcoinSig1); err != nil {
return err
}
if err := WriteSig(w, a.BitcoinSig2); err != nil {
return err
}
if err := WriteRawFeatureVector(w, a.Features); err != nil {
return err
}
if err := WriteBytes(w, a.ChainHash[:]); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID2[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey2[:]); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
@ -116,20 +147,40 @@ func (a *ChannelAnnouncement) MsgType() MessageType {
// be signed.
func (a *ChannelAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signatures itself.
var w bytes.Buffer
err := WriteElements(&w,
a.Features,
a.ChainHash[:],
a.ShortChannelID,
a.NodeID1,
a.NodeID2,
a.BitcoinKey1,
a.BitcoinKey2,
a.ExtraOpaqueData,
)
if err != nil {
b := make([]byte, 0, MaxMsgBody)
buf := bytes.NewBuffer(b)
if err := WriteRawFeatureVector(buf, a.Features); err != nil {
return nil, err
}
return w.Bytes(), nil
if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
return nil, err
}
if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@ -77,12 +77,15 @@ var _ Message = (*ChannelReestablish)(nil)
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w,
a.ChanID,
a.NextLocalCommitHeight,
a.RemoteCommitTailHeight,
)
if err != nil {
if err := WriteChannelID(w, a.ChanID); err != nil {
return err
}
if err := WriteUint64(w, a.NextLocalCommitHeight); err != nil {
return err
}
if err := WriteUint64(w, a.RemoteCommitTailHeight); err != nil {
return err
}
@ -94,15 +97,18 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
//
// NOTE: This is here primarily for the quickcheck tests, in
// practice, we'll always populate this field.
return WriteElements(w, a.ExtraData)
return WriteBytes(w, a.ExtraData)
}
// Otherwise, we'll write out the remaining elements.
return WriteElements(w,
a.LastRemoteCommitSecret[:],
a.LocalUnrevokedCommitPoint,
a.ExtraData,
)
if err := WriteBytes(w, a.LastRemoteCommitSecret[:]); err != nil {
return err
}
if err := WritePublicKey(w, a.LocalUnrevokedCommitPoint); err != nil {
return err
}
return WriteBytes(w, a.ExtraData)
}
// Decode deserializes a serialized ChannelReestablish stored in the passed

View File

@ -160,32 +160,57 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w,
a.Signature,
a.ChainHash[:],
a.ShortChannelID,
a.Timestamp,
a.MessageFlags,
a.ChannelFlags,
a.TimeLockDelta,
a.HtlcMinimumMsat,
a.BaseFee,
a.FeeRate,
)
if err != nil {
if err := WriteSig(w, a.Signature); err != nil {
return err
}
if err := WriteBytes(w, a.ChainHash[:]); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteUint32(w, a.Timestamp); err != nil {
return err
}
if err := WriteChanUpdateMsgFlags(w, a.MessageFlags); err != nil {
return err
}
if err := WriteChanUpdateChanFlags(w, a.ChannelFlags); err != nil {
return err
}
if err := WriteUint16(w, a.TimeLockDelta); err != nil {
return err
}
if err := WriteMilliSatoshi(w, a.HtlcMinimumMsat); err != nil {
return err
}
if err := WriteUint32(w, a.BaseFee); err != nil {
return err
}
if err := WriteUint32(w, a.FeeRate); err != nil {
return err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(w, a.HtlcMaximumMsat); err != nil {
err := WriteMilliSatoshi(w, a.HtlcMaximumMsat)
if err != nil {
return err
}
}
// Finally, append any extra opaque data.
return a.ExtraOpaqueData.Encode(w)
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
@ -199,36 +224,58 @@ func (a *ChannelUpdate) MsgType() MessageType {
// DataToSign is used to retrieve part of the announcement message which should
// be signed.
func (a *ChannelUpdate) DataToSign() ([]byte, error) {
// We should not include the signatures itself.
var w bytes.Buffer
err := WriteElements(&w,
a.ChainHash[:],
a.ShortChannelID,
a.Timestamp,
a.MessageFlags,
a.ChannelFlags,
a.TimeLockDelta,
a.HtlcMinimumMsat,
a.BaseFee,
a.FeeRate,
)
if err != nil {
b := make([]byte, 0, MaxMsgBody)
buf := bytes.NewBuffer(b)
if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
return nil, err
}
if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.Timestamp); err != nil {
return nil, err
}
if err := WriteChanUpdateMsgFlags(buf, a.MessageFlags); err != nil {
return nil, err
}
if err := WriteChanUpdateChanFlags(buf, a.ChannelFlags); err != nil {
return nil, err
}
if err := WriteUint16(buf, a.TimeLockDelta); err != nil {
return nil, err
}
if err := WriteMilliSatoshi(buf, a.HtlcMinimumMsat); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.BaseFee); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.FeeRate); err != nil {
return nil, err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil {
err := WriteMilliSatoshi(buf, a.HtlcMaximumMsat)
if err != nil {
return nil, err
}
}
// Finally, append any extra opaque data.
if err := a.ExtraOpaqueData.Encode(&w); err != nil {
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return w.Bytes(), nil
return buf.Bytes(), nil
}

View File

@ -18,7 +18,7 @@ type ExtraOpaqueData []byte
// Encode attempts to encode the raw extra bytes into the passed io.Writer.
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
eBytes := []byte((*e)[:])
if err := WriteElements(w, eBytes); err != nil {
if err := WriteBytes(w, eBytes); err != nil {
return err
}

View File

@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) {
}
var b bytes.Buffer
if err := WriteElement(&b, op); err == nil {
if err := WriteOutPoint(&b, op); err == nil {
t.Fatalf("write of outPoint should fail, index exceeds 16-bits")
}
}

View File

@ -293,19 +293,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
// First, we'll write out the chain hash.
err := WriteElements(w, q.ChainHash[:])
if err != nil {
if err := WriteBytes(w, q.ChainHash[:]); err != nil {
return err
}
// Base on our encoding type, we'll write out the set of short channel
// ID's.
err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
if err != nil {
return err
}
return q.ExtraData.Encode(w)
return WriteBytes(w, q.ExtraData)
}
// encodeShortChanIDs encodes the passed short channel ID's into the passed
@ -332,20 +331,21 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// body. We add 1 as the response will have the encoding type
// prepended to it.
numBytesBody := uint16(len(shortChanIDs)*8) + 1
if err := WriteElements(w, numBytesBody); err != nil {
if err := WriteUint16(w, numBytesBody); err != nil {
return err
}
// We'll then write out the encoding that that follows the
// actual encoded short channel ID's.
if err := WriteElements(w, encodingType); err != nil {
err := WriteShortChanIDEncoding(w, encodingType)
if err != nil {
return err
}
// Now that we know they're sorted, we can write out each short
// channel ID to the buffer.
for _, chanID := range shortChanIDs {
if err := WriteElements(w, chanID); err != nil {
if err := WriteShortChannelID(w, chanID); err != nil {
return fmt.Errorf("unable to write short chan "+
"ID: %v", err)
}
@ -374,7 +374,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// into the zlib writer, which will do compressing on
// the fly.
for _, chanID := range shortChanIDs {
err := WriteElements(&wb, chanID)
err := WriteShortChannelID(&wb, chanID)
if err != nil {
return fmt.Errorf(
"unable to write short chan "+
@ -418,15 +418,15 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// Finally, we can write out the number of bytes, the
// compression type, and finally the buffer itself.
if err := WriteElements(w, uint16(numBytesBody)); err != nil {
if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
return err
}
if err := WriteElements(w, encodingType); err != nil {
err := WriteShortChanIDEncoding(w, encodingType)
if err != nil {
return err
}
_, err := w.Write(compressedPayload)
return err
return WriteBytes(w, compressedPayload)
default:
// If we're trying to encode with an encoding type that we

View File

@ -87,22 +87,28 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w,
c.ChainHash[:],
c.FirstBlockHeight,
c.NumBlocks,
c.Complete,
)
if err := WriteBytes(w, c.ChainHash[:]); err != nil {
return err
}
if err := WriteUint32(w, c.FirstBlockHeight); err != nil {
return err
}
if err := WriteUint32(w, c.NumBlocks); err != nil {
return err
}
if err := WriteUint8(w, c.Complete); err != nil {
return err
}
err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
if err != nil {
return err
}
err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
if err != nil {
return err
}
return c.ExtraData.Encode(w)
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -85,20 +85,36 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
)
}
// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing
// the protocol version specified.
// Encode serializes the target UpdateAddHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.Amount,
c.PaymentHash[:],
c.Expiry,
c.OnionBlob[:],
c.ExtraData,
)
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteMilliSatoshi(w, c.Amount); err != nil {
return err
}
if err := WriteBytes(w, c.PaymentHash[:]); err != nil {
return err
}
if err := WriteUint32(w, c.Expiry); err != nil {
return err
}
if err := WriteBytes(w, c.OnionBlob[:]); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -56,12 +56,19 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.Reason,
c.ExtraData,
)
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteOpaqueReason(w, c.Reason); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -54,14 +54,26 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error {
// io.Writer observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.ShaOnionBlob[:],
c.FailureCode,
c.ExtraData,
)
func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer,
pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteBytes(w, c.ShaOnionBlob[:]); err != nil {
return err
}
if err := WriteFailCode(w, c.FailureCode); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -53,11 +53,15 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.FeePerKw,
c.ExtraData,
)
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint32(w, c.FeePerKw); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -62,12 +62,19 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.PaymentPreimage[:],
c.ExtraData,
)
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteBytes(w, c.PaymentPreimage[:]); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the