diff --git a/lnwire/message.go b/lnwire/message.go index fbb809248..59e230edf 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -9,13 +9,6 @@ import ( "io" ) -// MessageHeaderSize is the number of bytes in a lightning message header. -// The bytes are allocated as follows: network magic 4 bytes + command 4 -// bytes + payload length 4 bytes. Note that a checksum is omitted as lightning -// messages are assumed to be transmitted over an AEAD secured connection which -// provides integrity over the entire message. -const MessageHeaderSize = 12 - // MaxMessagePayload is the maximum bytes a message can be regardless of other // individual limits imposed by messages themselves. const MaxMessagePayload = 65535 // 65KB @@ -127,182 +120,82 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { return msg, nil } -// messageHeader represents the header structure for all lightning protocol -// messages. -type messageHeader struct { - // magic represents Which Blockchain Technology(TM) to use. - // NOTE(j): We don't need to worry about the magic overlapping with - // bitcoin since this is inside encrypted comms anyway, but maybe we - // should use the XOR (^wire.TestNet3) just in case??? - magic wire.BitcoinNet // 4 bytes - command uint32 // 4 bytes - length uint32 // 4 bytes -} - -// readMessageHeader reads a lightning protocol message header from r. -func readMessageHeader(r io.Reader) (int, *messageHeader, error) { - // As the message header is a fixed size structure, read bytes for the - // entire header at once. - var headerBytes [MessageHeaderSize]byte - n, err := io.ReadFull(r, headerBytes[:]) - if err != nil { - return n, nil, err - } - hr := bytes.NewReader(headerBytes[:]) - - // Create and populate the message header from the raw header bytes. - hdr := messageHeader{} - err = readElements(hr, - &hdr.magic, - &hdr.command, - &hdr.length) - if err != nil { - return n, nil, err - } - - return n, &hdr, nil -} - -// discardInput reads n bytes from reader r in chunks and discards the read -// bytes. This is used to skip payloads when various errors occur and helps -// prevent rogue nodes from causing massive memory allocation through forging -// header length. -func discardInput(r io.Reader, n uint32) { - maxSize := uint32(10 * 1024) // 10k at a time - numReads := n / maxSize - bytesRemaining := n % maxSize - if n > 0 { - buf := make([]byte, maxSize) - for i := uint32(0); i < numReads; i++ { - io.ReadFull(r, buf) - } - } - if bytesRemaining > 0 { - buf := make([]byte, bytesRemaining) - io.ReadFull(r, buf) - } -} - // WriteMessage writes a lightning Message to w including the necessary header // information and returns the number of bytes written. -func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet wire.BitcoinNet) (int, error) { +func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { totalBytes := 0 - cmd := msg.Command() - - // Encode the message payload + // Encode the message payload itself into a temporary buffer. + // TODO(roasbeef): create buffer pool var bw bytes.Buffer - err := msg.Encode(&bw, pver) - if err != nil { + if err := msg.Encode(&bw, pver); err != nil { return totalBytes, err } payload := bw.Bytes() lenp := len(payload) - // Enforce maximum overall message payload + // Enforce maximum overall message payload. if lenp > MaxMessagePayload { return totalBytes, fmt.Errorf("message payload is too large - "+ "encoded %d bytes, but maximum message payload is %d bytes", lenp, MaxMessagePayload) } - // Enforce maximum message payload on the message type + // Enforce maximum message payload on the message type. mpl := msg.MaxPayloadLength(pver) if uint32(lenp) > mpl { return totalBytes, fmt.Errorf("message payload is too large - "+ "encoded %d bytes, but maximum message payload of "+ - "type %x is %d bytes", lenp, cmd, mpl) + "type %x is %d bytes", lenp, msg.MsgType(), mpl) } - // Create header for the message. - hdr := messageHeader{magic: btcnet, command: cmd, length: uint32(lenp)} - - // Encode the header for the message. This is done to a buffer - // rather than directly to the writer since writeElements doesn't - // return the number of bytes written. - hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize)) - if err := writeElements(hw, hdr.magic, hdr.command, hdr.length); err != nil { - return 0, nil - } - - // Write the header first. - n, err := w.Write(hw.Bytes()) + // With the initial sanity checks complete, we'll now write out the + // message type itself. + var mType [2]byte + binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType())) + n, err := w.Write(mType[:]) totalBytes += n if err != nil { return totalBytes, err } - // Write payload the payload itself after the header. + // With the message type written, we'll now write out the raw payload + // itself. n, err = w.Write(payload) totalBytes += n + return totalBytes, err } // ReadMessage reads, validates, and parses the next bitcoin Message from r for -// the provided protocol version and bitcoin network. It returns the number of -// bytes read in addition to the parsed Message and raw bytes which comprise the -// message. This function is the same as ReadMessage except it also returns the -// number of bytes read. -func ReadMessage(r io.Reader, pver uint32, btcnet wire.BitcoinNet) (int, Message, []byte, error) { +// the provided protocol version. It returns the number of bytes read in +// addition to the parsed Message and raw bytes which comprise the message. +func ReadMessage(r io.Reader, pver uint32) (int, Message, error) { + // TODO(roasbeef): need to explicitly enforce max message payload, or + // just allow it to be done by the MaxPayloadLength? totalBytes := 0 - n, hdr, err := readMessageHeader(r) + + // First, we'll read out the first two bytes of the message so we can + // create the proper empty message. + var mType [2]byte + n, err := io.ReadFull(r, mType[:]) totalBytes += n if err != nil { - return totalBytes, nil, nil, err + return totalBytes, nil, err } - // Enforce maximum message payload - if hdr.length > MaxMessagePayload { - return totalBytes, nil, nil, fmt.Errorf("message payload is "+ - "too large - header indicates %d bytes, but max "+ - "message payload is %d bytes.", hdr.length, - MaxMessagePayload) - } + msgType := MessageType(binary.BigEndian.Uint16(mType[:])) - // Check for messages in the wrong network. - if hdr.magic != btcnet { - discardInput(r, hdr.length) - return totalBytes, nil, nil, fmt.Errorf("message from other "+ - "network [%v]", hdr.magic) - } - - // Create struct of appropriate message type based on the command. - command := hdr.command - msg, err := makeEmptyMessage(command) + // Now that we know the target message type, we can create the proper + // empty message type and decode the message into it. + msg, err := makeEmptyMessage(msgType) if err != nil { - discardInput(r, hdr.length) - return totalBytes, nil, nil, &UnknownMessage{ - messageType: command, - } + return totalBytes, nil, err } - - // Check for maximum length based on the message type. - mpl := msg.MaxPayloadLength(pver) - if hdr.length > mpl { - discardInput(r, hdr.length) - return totalBytes, nil, nil, fmt.Errorf("payload exceeds max "+ - "length. indicates %v bytes, but max of message type %v is %v.", - hdr.length, command, mpl) + if err := msg.Decode(r, pver); err != nil { + return totalBytes, nil, err } + totalBytes += int(msg.MaxPayloadLength(pver)) - // Read payload. - payload := make([]byte, hdr.length) - n, err = io.ReadFull(r, payload) - totalBytes += n - if err != nil { - return totalBytes, nil, nil, err - } - - // Unmarshal message. - pr := bytes.NewBuffer(payload) - if err = msg.Decode(pr, pver); err != nil { - return totalBytes, nil, nil, err - } - - // Validate the data. - if err = msg.Validate(); err != nil { - return totalBytes, nil, nil, err - } - - return totalBytes, msg, payload, nil + return totalBytes, msg, nil }