mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
lnwire: add QueryOptions to QueryChannelRange
This commit is contained in:
parent
8efd141347
commit
341bae098c
@ -1167,6 +1167,23 @@ func TestLightningWireProtocol(t *testing.T) {
|
||||
|
||||
v[0] = reflect.ValueOf(req)
|
||||
},
|
||||
MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) {
|
||||
req := QueryChannelRange{
|
||||
FirstBlockHeight: uint32(r.Int31()),
|
||||
NumBlocks: uint32(r.Int31()),
|
||||
ExtraData: make([]byte, 0),
|
||||
}
|
||||
|
||||
_, err := rand.Read(req.ChainHash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// With a 50/50 change, we'll set a query option.
|
||||
if r.Int31()%2 == 0 {
|
||||
req.QueryOptions = NewTimestampQueryOption()
|
||||
}
|
||||
|
||||
v[0] = reflect.ValueOf(req)
|
||||
},
|
||||
MsgPing: func(v []reflect.Value, r *rand.Rand) {
|
||||
// We use a special message generator here to ensure we
|
||||
// don't generate ping messages that are too large,
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"math"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// QueryChannelRange is a message sent by a node in order to query the
|
||||
@ -27,6 +28,10 @@ type QueryChannelRange struct {
|
||||
// channel ID's should be sent for.
|
||||
NumBlocks uint32
|
||||
|
||||
// QueryOptions is an optional feature bit vector that can be used to
|
||||
// specify additional query options.
|
||||
QueryOptions *QueryOptions
|
||||
|
||||
// ExtraData is the set of data that was appended to this message to
|
||||
// fill out the full maximum transport message size. These fields can
|
||||
// be used to specify optional data such as custom TLV fields.
|
||||
@ -35,7 +40,9 @@ type QueryChannelRange struct {
|
||||
|
||||
// NewQueryChannelRange creates a new empty QueryChannelRange message.
|
||||
func NewQueryChannelRange() *QueryChannelRange {
|
||||
return &QueryChannelRange{}
|
||||
return &QueryChannelRange{
|
||||
ExtraData: make([]byte, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// A compile time check to ensure QueryChannelRange implements the
|
||||
@ -46,20 +53,42 @@ var _ Message = (*QueryChannelRange)(nil)
|
||||
// passed io.Reader observing the specified protocol version.
|
||||
//
|
||||
// This is part of the lnwire.Message interface.
|
||||
func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error {
|
||||
return ReadElements(r,
|
||||
q.ChainHash[:],
|
||||
&q.FirstBlockHeight,
|
||||
&q.NumBlocks,
|
||||
&q.ExtraData,
|
||||
func (q *QueryChannelRange) Decode(r io.Reader, _ uint32) error {
|
||||
err := ReadElements(
|
||||
r, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tlvRecords ExtraOpaqueData
|
||||
if err := ReadElements(r, &tlvRecords); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var queryOptions QueryOptions
|
||||
typeMap, err := tlvRecords.ExtractRecords(&queryOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the corresponding TLV types if they were included in the stream.
|
||||
if val, ok := typeMap[QueryOptionsRecordType]; ok && val == nil {
|
||||
q.QueryOptions = &queryOptions
|
||||
}
|
||||
|
||||
if len(tlvRecords) != 0 {
|
||||
q.ExtraData = tlvRecords
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode serializes the target QueryChannelRange into the passed io.Writer
|
||||
// observing the protocol version specified.
|
||||
//
|
||||
// This is part of the lnwire.Message interface.
|
||||
func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
|
||||
func (q *QueryChannelRange) Encode(w *bytes.Buffer, _ uint32) error {
|
||||
if err := WriteBytes(w, q.ChainHash[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -72,6 +101,15 @@ func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
|
||||
return err
|
||||
}
|
||||
|
||||
recordProducers := make([]tlv.RecordProducer, 0, 1)
|
||||
if q.QueryOptions != nil {
|
||||
recordProducers = append(recordProducers, q.QueryOptions)
|
||||
}
|
||||
err := EncodeMessageExtraData(&q.ExtraData, recordProducers...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteBytes(w, q.ExtraData)
|
||||
}
|
||||
|
||||
@ -93,3 +131,14 @@ func (q *QueryChannelRange) LastBlockHeight() uint32 {
|
||||
}
|
||||
return uint32(lastBlockHeight)
|
||||
}
|
||||
|
||||
// WithTimestamps returns true if the query has asked for timestamps too.
|
||||
func (q *QueryChannelRange) WithTimestamps() bool {
|
||||
if q.QueryOptions == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
queryOpts := RawFeatureVector(*q.QueryOptions)
|
||||
|
||||
return queryOpts.IsSet(QueryOptionTimestampBit)
|
||||
}
|
||||
|
79
lnwire/query_channel_range_test.go
Normal file
79
lnwire/query_channel_range_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestQueryChannelRange tests that a few query_channel_range test vectors can
|
||||
// correctly be decoded and encoded.
|
||||
func TestQueryChannelRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expFirstBlockNum int
|
||||
expNumOfBlocks int
|
||||
expWantTimestamps bool
|
||||
}{
|
||||
{
|
||||
name: "without timestamps query option",
|
||||
input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" +
|
||||
"012afca590b1a11466e2206000186a0000005dc",
|
||||
expFirstBlockNum: 100000,
|
||||
expNumOfBlocks: 1500,
|
||||
expWantTimestamps: false,
|
||||
},
|
||||
{
|
||||
name: "with timestamps query option",
|
||||
input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" +
|
||||
"012afca590b1a11466e2206000088b800000064010103",
|
||||
expFirstBlockNum: 35000,
|
||||
expNumOfBlocks: 100,
|
||||
expWantTimestamps: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b, err := hex.DecodeString(test.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := bytes.NewBuffer(b)
|
||||
|
||||
msg, err := ReadMessage(r, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
queryMsg, ok := msg.(*QueryChannelRange)
|
||||
require.True(t, ok)
|
||||
|
||||
require.EqualValues(
|
||||
t, test.expFirstBlockNum,
|
||||
queryMsg.FirstBlockHeight,
|
||||
)
|
||||
|
||||
require.EqualValues(
|
||||
t, test.expNumOfBlocks, queryMsg.NumBlocks,
|
||||
)
|
||||
|
||||
require.Equal(
|
||||
t, test.expWantTimestamps,
|
||||
queryMsg.WithTimestamps(),
|
||||
)
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = WriteMessage(&buf, queryMsg, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, buf.Bytes(), b)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user