From 341bae098c43ed4b122ba5385a5c0622e7922676 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 19 Sep 2023 20:01:24 +0200 Subject: [PATCH] lnwire: add QueryOptions to QueryChannelRange --- lnwire/lnwire_test.go | 17 +++++++ lnwire/query_channel_range.go | 65 +++++++++++++++++++++--- lnwire/query_channel_range_test.go | 79 ++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 8 deletions(-) create mode 100644 lnwire/query_channel_range_test.go diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 9d92b970b..9b119c851 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1167,6 +1167,23 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, + MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { + req := QueryChannelRange{ + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + ExtraData: make([]byte, 0), + } + + _, err := rand.Read(req.ChainHash[:]) + require.NoError(t, err) + + // With a 50/50 change, we'll set a query option. + if r.Int31()%2 == 0 { + req.QueryOptions = NewTimestampQueryOption() + } + + v[0] = reflect.ValueOf(req) + }, MsgPing: func(v []reflect.Value, r *rand.Rand) { // We use a special message generator here to ensure we // don't generate ping messages that are too large, diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index cfe88e9bf..1e0dcb0fa 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -6,6 +6,7 @@ import ( "math" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" ) // QueryChannelRange is a message sent by a node in order to query the @@ -27,6 +28,10 @@ type QueryChannelRange struct { // channel ID's should be sent for. NumBlocks uint32 + // QueryOptions is an optional feature bit vector that can be used to + // specify additional query options. + QueryOptions *QueryOptions + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -35,7 +40,9 @@ type QueryChannelRange struct { // NewQueryChannelRange creates a new empty QueryChannelRange message. func NewQueryChannelRange() *QueryChannelRange { - return &QueryChannelRange{} + return &QueryChannelRange{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure QueryChannelRange implements the @@ -46,20 +53,42 @@ var _ Message = (*QueryChannelRange)(nil) // passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, - q.ChainHash[:], - &q.FirstBlockHeight, - &q.NumBlocks, - &q.ExtraData, +func (q *QueryChannelRange) Decode(r io.Reader, _ uint32) error { + err := ReadElements( + r, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, ) + if err != nil { + return err + } + + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var queryOptions QueryOptions + typeMap, err := tlvRecords.ExtractRecords(&queryOptions) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[QueryOptionsRecordType]; ok && val == nil { + q.QueryOptions = &queryOptions + } + + if len(tlvRecords) != 0 { + q.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target QueryChannelRange into the passed io.Writer // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { +func (q *QueryChannelRange) Encode(w *bytes.Buffer, _ uint32) error { if err := WriteBytes(w, q.ChainHash[:]); err != nil { return err } @@ -72,6 +101,15 @@ func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { return err } + recordProducers := make([]tlv.RecordProducer, 0, 1) + if q.QueryOptions != nil { + recordProducers = append(recordProducers, q.QueryOptions) + } + err := EncodeMessageExtraData(&q.ExtraData, recordProducers...) + if err != nil { + return err + } + return WriteBytes(w, q.ExtraData) } @@ -93,3 +131,14 @@ func (q *QueryChannelRange) LastBlockHeight() uint32 { } return uint32(lastBlockHeight) } + +// WithTimestamps returns true if the query has asked for timestamps too. +func (q *QueryChannelRange) WithTimestamps() bool { + if q.QueryOptions == nil { + return false + } + + queryOpts := RawFeatureVector(*q.QueryOptions) + + return queryOpts.IsSet(QueryOptionTimestampBit) +} diff --git a/lnwire/query_channel_range_test.go b/lnwire/query_channel_range_test.go new file mode 100644 index 000000000..5d690f38d --- /dev/null +++ b/lnwire/query_channel_range_test.go @@ -0,0 +1,79 @@ +package lnwire + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestQueryChannelRange tests that a few query_channel_range test vectors can +// correctly be decoded and encoded. +func TestQueryChannelRange(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expFirstBlockNum int + expNumOfBlocks int + expWantTimestamps bool + }{ + { + name: "without timestamps query option", + input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" + + "012afca590b1a11466e2206000186a0000005dc", + expFirstBlockNum: 100000, + expNumOfBlocks: 1500, + expWantTimestamps: false, + }, + { + name: "with timestamps query option", + input: "01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436" + + "012afca590b1a11466e2206000088b800000064010103", + expFirstBlockNum: 35000, + expNumOfBlocks: 100, + expWantTimestamps: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + b, err := hex.DecodeString(test.input) + require.NoError(t, err) + + r := bytes.NewBuffer(b) + + msg, err := ReadMessage(r, 0) + require.NoError(t, err) + + queryMsg, ok := msg.(*QueryChannelRange) + require.True(t, ok) + + require.EqualValues( + t, test.expFirstBlockNum, + queryMsg.FirstBlockHeight, + ) + + require.EqualValues( + t, test.expNumOfBlocks, queryMsg.NumBlocks, + ) + + require.Equal( + t, test.expWantTimestamps, + queryMsg.WithTimestamps(), + ) + + var buf bytes.Buffer + _, err = WriteMessage(&buf, queryMsg, 0) + require.NoError(t, err) + + require.Equal(t, buf.Bytes(), b) + }) + } +}