lnwire: add QueryOptions to QueryChannelRange

This commit is contained in:
Elle Mouton 2023-09-19 20:01:24 +02:00
parent 8efd141347
commit 341bae098c
No known key found for this signature in database
GPG Key ID: D7D916376026F177
3 changed files with 153 additions and 8 deletions

View File

@ -1167,6 +1167,23 @@ func TestLightningWireProtocol(t *testing.T) {
v[0] = reflect.ValueOf(req)
},
MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) {
req := QueryChannelRange{
FirstBlockHeight: uint32(r.Int31()),
NumBlocks: uint32(r.Int31()),
ExtraData: make([]byte, 0),
}
_, err := rand.Read(req.ChainHash[:])
require.NoError(t, err)
// With a 50/50 change, we'll set a query option.
if r.Int31()%2 == 0 {
req.QueryOptions = NewTimestampQueryOption()
}
v[0] = reflect.ValueOf(req)
},
MsgPing: func(v []reflect.Value, r *rand.Rand) {
// We use a special message generator here to ensure we
// don't generate ping messages that are too large,

View File

@ -6,6 +6,7 @@ import (
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// QueryChannelRange is a message sent by a node in order to query the
@ -27,6 +28,10 @@ type QueryChannelRange struct {
// channel ID's should be sent for.
NumBlocks uint32
// QueryOptions is an optional feature bit vector that can be used to
// specify additional query options.
QueryOptions *QueryOptions
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -35,7 +40,9 @@ type QueryChannelRange struct {
// NewQueryChannelRange creates a new empty QueryChannelRange message.
func NewQueryChannelRange() *QueryChannelRange {
return &QueryChannelRange{}
return &QueryChannelRange{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure QueryChannelRange implements the
@ -46,20 +53,42 @@ var _ Message = (*QueryChannelRange)(nil)
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
q.ChainHash[:],
&q.FirstBlockHeight,
&q.NumBlocks,
&q.ExtraData,
func (q *QueryChannelRange) Decode(r io.Reader, _ uint32) error {
err := ReadElements(
r, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var queryOptions QueryOptions
typeMap, err := tlvRecords.ExtractRecords(&queryOptions)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[QueryOptionsRecordType]; ok && val == nil {
q.QueryOptions = &queryOptions
}
if len(tlvRecords) != 0 {
q.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target QueryChannelRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
func (q *QueryChannelRange) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteBytes(w, q.ChainHash[:]); err != nil {
return err
}
@ -72,6 +101,15 @@ func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
if q.QueryOptions != nil {
recordProducers = append(recordProducers, q.QueryOptions)
}
err := EncodeMessageExtraData(&q.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, q.ExtraData)
}
@ -93,3 +131,14 @@ func (q *QueryChannelRange) LastBlockHeight() uint32 {
}
return uint32(lastBlockHeight)
}
// WithTimestamps returns true if the query has asked for timestamps too.
func (q *QueryChannelRange) WithTimestamps() bool {
if q.QueryOptions == nil {
return false
}
queryOpts := RawFeatureVector(*q.QueryOptions)
return queryOpts.IsSet(QueryOptionTimestampBit)
}

View File

@ -0,0 +1,79 @@
package lnwire
import (
"bytes"
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
)
// TestQueryChannelRange tests that a few query_channel_range test vectors can
// correctly be decoded and encoded.
func TestQueryChannelRange(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expFirstBlockNum int
expNumOfBlocks int
expWantTimestamps bool
}{
{
name: "without timestamps query option",
input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" +
"012afca590b1a11466e2206000186a0000005dc",
expFirstBlockNum: 100000,
expNumOfBlocks: 1500,
expWantTimestamps: false,
},
{
name: "with timestamps query option",
input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" +
"012afca590b1a11466e2206000088b800000064010103",
expFirstBlockNum: 35000,
expNumOfBlocks: 100,
expWantTimestamps: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
b, err := hex.DecodeString(test.input)
require.NoError(t, err)
r := bytes.NewBuffer(b)
msg, err := ReadMessage(r, 0)
require.NoError(t, err)
queryMsg, ok := msg.(*QueryChannelRange)
require.True(t, ok)
require.EqualValues(
t, test.expFirstBlockNum,
queryMsg.FirstBlockHeight,
)
require.EqualValues(
t, test.expNumOfBlocks, queryMsg.NumBlocks,
)
require.Equal(
t, test.expWantTimestamps,
queryMsg.WithTimestamps(),
)
var buf bytes.Buffer
_, err = WriteMessage(&buf, queryMsg, 0)
require.NoError(t, err)
require.Equal(t, buf.Bytes(), b)
})
}
}