2018-04-17 03:46:15 +02:00
|
|
|
package lnwire
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2018-06-16 03:31:23 +02:00
|
|
|
"compress/zlib"
|
2018-04-17 03:46:15 +02:00
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"sort"
|
2018-06-16 03:31:23 +02:00
|
|
|
"sync"
|
2018-04-17 03:46:15 +02:00
|
|
|
|
2018-06-05 03:34:16 +02:00
|
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
2018-04-17 03:46:15 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
// ShortChanIDEncoding is an enum-like type that represents exactly how a set
// of short channel ID's is encoded on the wire. Offering a choice of encodings
// lets us take advantage of the structure of a list of short channel ID's to
// achieve a high degree of compression.
type ShortChanIDEncoding uint8

// The known short channel ID encodings. Each value is carried on the wire as
// the single type byte that precedes an encoded set of short channel ID's, so
// the numeric values are spelled out explicitly.
const (
	// EncodingSortedPlain signals that the set of short channel ID's is
	// encoded using the regular encoding, in a sorted order.
	EncodingSortedPlain ShortChanIDEncoding = 0

	// EncodingSortedZlib signals that the set of short channel ID's is
	// encoded by first sorting the set of channel ID's, and then
	// compressing them using zlib.
	EncodingSortedZlib ShortChanIDEncoding = 1
)
|
|
|
|
|
|
|
|
// maxZlibBufSize bounds the number of bytes read during a single zlib
// decoding instance. We do this in order to limit the total amount of memory
// allocated while decoding a peer-supplied payload.
const maxZlibBufSize = 67413630

// zlibDecodeMtx is a package level mutex that serializes zlib decoding, so
// that only a single zlib decoding instance runs at a time. Together with
// maxZlibBufSize, this further bounds our memory usage when several peers
// send compressed payloads concurrently.
var zlibDecodeMtx sync.Mutex
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
|
|
|
|
// came across an unknown short channel ID encoding, and therefore were unable
|
|
|
|
// to continue parsing.
|
|
|
|
func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error {
|
|
|
|
return fmt.Errorf("unknown short chan id encoding: %v", encoding)
|
|
|
|
}
|
|
|
|
|
|
|
|
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message, ensuring that the receiver knows how to
// properly decode each encoded short channel ID, which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcements and channel updates, finally sending a ReplyShortChanIDsEnd
// message.
type QueryShortChanIDs struct {
	// ChainHash denotes the target chain that we're querying for the
	// short channel ID's of.
	ChainHash chainhash.Hash

	// EncodingType is a signal to the receiver of the message that
	// indicates exactly how the set of short channel ID's that follow have
	// been encoded.
	EncodingType ShortChanIDEncoding

	// ShortChanIDs is a slice of decoded short channel ID's.
	//
	// NOTE: Encode sorts this slice in place (ascending by ToUint64)
	// before writing it out, so callers will observe the reordering.
	ShortChanIDs []ShortChannelID
}
|
|
|
|
|
|
|
|
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
|
|
|
|
func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding,
|
|
|
|
s []ShortChannelID) *QueryShortChanIDs {
|
|
|
|
|
|
|
|
return &QueryShortChanIDs{
|
|
|
|
ChainHash: h,
|
|
|
|
EncodingType: e,
|
|
|
|
ShortChanIDs: s,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.Message interface. Assigning to the blank identifier means the
// check costs nothing beyond compilation: no value is retained.
var _ Message = (*QueryShortChanIDs)(nil)
|
|
|
|
|
|
|
|
// Decode deserializes a serialized QueryShortChanIDs message stored in the
|
|
|
|
// passed io.Reader observing the specified protocol version.
|
|
|
|
//
|
|
|
|
// This is part of the lnwire.Message interface.
|
|
|
|
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
|
2018-12-10 03:27:41 +01:00
|
|
|
err := ReadElements(r, q.ChainHash[:])
|
2018-04-17 03:46:15 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// decodeShortChanIDs decodes a set of short channel ID's that have been
|
|
|
|
// encoded. The first byte of the body details how the short chan ID's were
|
|
|
|
// encoded. We'll use this type to govern exactly how we go about encoding the
|
|
|
|
// set of short channel ID's.
|
|
|
|
func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
|
|
|
|
// First, we'll attempt to read the number of bytes in the body of the
|
|
|
|
// set of encoded short channel ID's.
|
|
|
|
var numBytesResp uint16
|
2018-12-10 03:27:41 +01:00
|
|
|
err := ReadElements(r, &numBytesResp)
|
2018-04-17 03:46:15 +02:00
|
|
|
if err != nil {
|
|
|
|
return 0, nil, err
|
|
|
|
}
|
|
|
|
|
2018-09-14 00:56:21 +02:00
|
|
|
if numBytesResp == 0 {
|
|
|
|
return 0, nil, fmt.Errorf("No encoding type specified")
|
|
|
|
}
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
queryBody := make([]byte, numBytesResp)
|
|
|
|
if _, err := io.ReadFull(r, queryBody); err != nil {
|
|
|
|
return 0, nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// The first byte is the encoding type, so we'll extract that so we can
|
|
|
|
// continue our parsing.
|
|
|
|
encodingType := ShortChanIDEncoding(queryBody[0])
|
|
|
|
|
|
|
|
// Before continuing, we'll snip off the first byte of the query body
|
|
|
|
// as that was just the encoding type.
|
|
|
|
queryBody = queryBody[1:]
|
|
|
|
|
|
|
|
// Otherwise, depending on the encoding type, we'll decode the encode
|
|
|
|
// short channel ID's in a different manner.
|
|
|
|
switch encodingType {
|
|
|
|
|
|
|
|
// In this encoding, we'll simply read a sort array of encoded short
|
|
|
|
// channel ID's from the buffer.
|
|
|
|
case EncodingSortedPlain:
|
2018-06-05 01:31:18 +02:00
|
|
|
// If after extracting the encoding type, then number of
|
|
|
|
// remaining bytes instead a whole multiple of the size of an
|
|
|
|
// encoded short channel ID (8 bytes), then we'll return a
|
|
|
|
// parsing error.
|
|
|
|
if len(queryBody)%8 != 0 {
|
|
|
|
return 0, nil, fmt.Errorf("whole number of short "+
|
|
|
|
"chan ID's cannot be encoded in len=%v",
|
|
|
|
len(queryBody))
|
|
|
|
}
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
// As each short channel ID is encoded as 8 bytes, we can
|
|
|
|
// compute the number of bytes encoded based on the size of the
|
|
|
|
// query body.
|
|
|
|
numShortChanIDs := len(queryBody) / 8
|
2018-06-29 04:04:55 +02:00
|
|
|
if numShortChanIDs == 0 {
|
|
|
|
return encodingType, nil, nil
|
|
|
|
}
|
2018-04-17 03:46:15 +02:00
|
|
|
|
|
|
|
// Finally, we'll read out the exact number of short channel
|
|
|
|
// ID's to conclude our parsing.
|
2018-06-29 04:04:55 +02:00
|
|
|
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
|
2018-04-17 03:46:15 +02:00
|
|
|
bodyReader := bytes.NewReader(queryBody)
|
|
|
|
for i := 0; i < numShortChanIDs; i++ {
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
|
2018-04-17 03:46:15 +02:00
|
|
|
return 0, nil, fmt.Errorf("unable to parse "+
|
|
|
|
"short chan ID: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return encodingType, shortChanIDs, nil
|
|
|
|
|
2018-06-16 03:31:23 +02:00
|
|
|
// In this encoding, we'll use zlib to decode the compressed payload.
|
|
|
|
// However, we'll pay attention to ensure that we don't open our selves
|
|
|
|
// up to a memory exhaustion attack.
|
|
|
|
case EncodingSortedZlib:
|
2018-06-16 03:33:04 +02:00
|
|
|
// We'll obtain an ultimately release the zlib decode mutex.
|
|
|
|
// This guards us against allocating too much memory to decode
|
|
|
|
// each instance from concurrent peers.
|
|
|
|
zlibDecodeMtx.Lock()
|
|
|
|
defer zlibDecodeMtx.Unlock()
|
|
|
|
|
2018-06-16 03:31:23 +02:00
|
|
|
// Before we start to decode, we'll create a limit reader over
|
|
|
|
// the current reader. This will ensure that we can control how
|
|
|
|
// much memory we're allocating during the decoding process.
|
|
|
|
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
|
|
|
|
R: bytes.NewReader(queryBody),
|
|
|
|
N: maxZlibBufSize,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
|
|
|
|
}
|
|
|
|
|
2018-06-26 01:15:30 +02:00
|
|
|
var (
|
|
|
|
shortChanIDs []ShortChannelID
|
|
|
|
lastChanID ShortChannelID
|
|
|
|
)
|
2018-06-16 03:31:23 +02:00
|
|
|
for {
|
|
|
|
// We'll now attempt to read the next short channel ID
|
|
|
|
// encoded in the payload.
|
|
|
|
var cid ShortChannelID
|
2018-12-10 03:27:41 +01:00
|
|
|
err := ReadElements(limitedDecompressor, &cid)
|
2018-06-16 03:31:23 +02:00
|
|
|
|
|
|
|
switch {
|
|
|
|
// If we get an EOF error, then that either means we've
|
|
|
|
// read all that's contained in the buffer, or have hit
|
|
|
|
// our limit on the number of bytes we'll read. In
|
|
|
|
// either case, we'll return what we have so far.
|
|
|
|
case err == io.ErrUnexpectedEOF || err == io.EOF:
|
|
|
|
return encodingType, shortChanIDs, nil
|
|
|
|
|
|
|
|
// Otherwise, we hit some other sort of error, possibly
|
|
|
|
// an invalid payload, so we'll exit early with the
|
|
|
|
// error.
|
|
|
|
case err != nil:
|
|
|
|
return 0, nil, fmt.Errorf("unable to "+
|
|
|
|
"deflate next short chan "+
|
|
|
|
"ID: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// We successfully read the next ID, so well collect
|
|
|
|
// that in the set of final ID's to return.
|
|
|
|
shortChanIDs = append(shortChanIDs, cid)
|
2018-06-26 01:15:30 +02:00
|
|
|
|
|
|
|
// Finally, we'll ensure that this short chan ID is
|
|
|
|
// greater than the last one. This is a requirement
|
|
|
|
// within the encoding, and if violated can aide us in
|
|
|
|
// detecting malicious payloads.
|
|
|
|
if cid.ToUint64() <= lastChanID.ToUint64() {
|
|
|
|
return 0, nil, fmt.Errorf("current sid of %v "+
|
|
|
|
"isn't greater than last sid of %v", cid,
|
|
|
|
lastChanID)
|
|
|
|
}
|
|
|
|
|
|
|
|
lastChanID = cid
|
2018-06-16 03:31:23 +02:00
|
|
|
}
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
default:
|
|
|
|
// If we've been sent an encoding type that we don't know of,
|
|
|
|
// then we'll return a parsing error as we can't continue if
|
|
|
|
// we're unable to encode them.
|
|
|
|
return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
|
|
|
|
// observing the protocol version specified.
|
|
|
|
//
|
|
|
|
// This is part of the lnwire.Message interface.
|
|
|
|
func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
|
|
|
|
// First, we'll write out the chain hash.
|
2018-12-10 03:27:41 +01:00
|
|
|
err := WriteElements(w, q.ChainHash[:])
|
2018-04-17 03:46:15 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Base on our encoding type, we'll write out the set of short channel
|
|
|
|
// ID's.
|
|
|
|
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
|
|
|
|
}
|
|
|
|
|
|
|
|
// encodeShortChanIDs encodes the passed short channel ID's into the passed
|
|
|
|
// io.Writer, respecting the specified encoding type.
|
|
|
|
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
|
|
|
shortChanIDs []ShortChannelID) error {
|
|
|
|
|
2018-06-16 03:31:23 +02:00
|
|
|
// For both of the current encoding types, the channel ID's are to be
|
|
|
|
// sorted in place, so we'll do that now.
|
|
|
|
sort.Slice(shortChanIDs, func(i, j int) bool {
|
|
|
|
return shortChanIDs[i].ToUint64() <
|
|
|
|
shortChanIDs[j].ToUint64()
|
|
|
|
})
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
switch encodingType {
|
|
|
|
|
|
|
|
// In this encoding, we'll simply write a sorted array of encoded short
|
|
|
|
// channel ID's from the buffer.
|
|
|
|
case EncodingSortedPlain:
|
|
|
|
// First, we'll write out the number of bytes of the query
|
|
|
|
// body. We add 1 as the response will have the encoding type
|
|
|
|
// prepended to it.
|
|
|
|
numBytesBody := uint16(len(shortChanIDs)*8) + 1
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := WriteElements(w, numBytesBody); err != nil {
|
2018-04-17 03:46:15 +02:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// We'll then write out the encoding that that follows the
|
|
|
|
// actual encoded short channel ID's.
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := WriteElements(w, encodingType); err != nil {
|
2018-04-17 03:46:15 +02:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Now that we know they're sorted, we can write out each short
|
|
|
|
// channel ID to the buffer.
|
|
|
|
for _, chanID := range shortChanIDs {
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := WriteElements(w, chanID); err != nil {
|
2018-04-17 03:46:15 +02:00
|
|
|
return fmt.Errorf("unable to write short chan "+
|
|
|
|
"ID: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
2018-06-16 03:31:23 +02:00
|
|
|
// For this encoding we'll first write out a serialized version of all
|
|
|
|
// the channel ID's into a buffer, then zlib encode that. The final
|
|
|
|
// payload is what we'll write out to the passed io.Writer.
|
|
|
|
//
|
|
|
|
// TODO(roasbeef): assumes the caller knows the proper chunk size to
|
|
|
|
// pass to avoid bin-packing here
|
|
|
|
case EncodingSortedZlib:
|
|
|
|
// We'll make a new buffer, then wrap that with a zlib writer
|
|
|
|
// so we can write directly to the buffer and encode in a
|
|
|
|
// streaming manner.
|
|
|
|
var buf bytes.Buffer
|
|
|
|
zlibWriter := zlib.NewWriter(&buf)
|
|
|
|
|
|
|
|
// Next, we'll write out all the channel ID's directly into the
|
|
|
|
// zlib writer, which will do compressing on the fly.
|
|
|
|
for _, chanID := range shortChanIDs {
|
2018-12-10 03:27:41 +01:00
|
|
|
err := WriteElements(zlibWriter, chanID)
|
2018-06-16 03:31:23 +02:00
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("unable to write short chan "+
|
|
|
|
"ID: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Now that we've written all the elements, we'll ensure the
|
|
|
|
// compressed stream is written to the underlying buffer.
|
|
|
|
if err := zlibWriter.Close(); err != nil {
|
|
|
|
return fmt.Errorf("unable to finalize "+
|
|
|
|
"compression: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Now that we have all the items compressed, we can compute
|
|
|
|
// what the total payload size will be. We add one to account
|
|
|
|
// for the byte to encode the type.
|
|
|
|
compressedPayload := buf.Bytes()
|
|
|
|
numBytesBody := len(compressedPayload) + 1
|
|
|
|
|
|
|
|
// Finally, we can write out the number of bytes, the
|
|
|
|
// compression type, and finally the buffer itself.
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := WriteElements(w, uint16(numBytesBody)); err != nil {
|
2018-06-16 03:31:23 +02:00
|
|
|
return err
|
|
|
|
}
|
2018-12-10 03:27:41 +01:00
|
|
|
if err := WriteElements(w, encodingType); err != nil {
|
2018-06-16 03:31:23 +02:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err := w.Write(compressedPayload)
|
|
|
|
return err
|
|
|
|
|
2018-04-17 03:46:15 +02:00
|
|
|
default:
|
|
|
|
// If we're trying to encode with an encoding type that we
|
|
|
|
// don't know of, then we'll return a parsing error as we can't
|
|
|
|
// continue if we're unable to encode them.
|
|
|
|
return ErrUnknownShortChanIDEncoding(encodingType)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// MsgType returns the integer uniquely identifying this message type on the
// wire. For QueryShortChanIDs this is always MsgQueryShortChanIDs.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MsgType() MessageType {
	return MsgQueryShortChanIDs
}
|
|
|
|
|
|
|
|
// MaxPayloadLength returns the maximum allowed payload size for a
// QueryShortChanIDs complete message observing the specified protocol version.
// The protocol version argument is currently ignored: the limit is always the
// generic MaxMessagePayload.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 {
	return MaxMessagePayload
}
|