package lnwire

import (
	"bytes"
	"encoding/hex"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestReplyChannelRangeUnsorted(t *testing.T) {
	for _, test := range unsortedSidTests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			// noSort is set so that the (presumably) default
			// sorting done by Encode is skipped and the
			// unsorted/duplicate SCIDs from the test vector are
			// written as-is; the decoder must then reject them.
			req := &ReplyChannelRange{
				EncodingType: test.encType,
				ShortChanIDs: test.sids,
				noSort:       true,
			}

			var b bytes.Buffer
			err := req.Encode(&b, 0)
			require.NoError(t, err, "unable to encode req")

			var req2 ReplyChannelRange
			err = req2.Decode(bytes.NewReader(b.Bytes()), 0)

			// The decode error must be (or wrap) an
			// ErrUnsortedSIDs; a nil error also fails here.
			var sidErr ErrUnsortedSIDs
			require.ErrorAs(t, err, &sidErr)
		})
	}
}

// TestReplyChannelRangeEmpty tests encoding and decoding a ReplyChannelRange
// that doesn't contain any channel results.
func TestReplyChannelRangeEmpty(t *testing.T) {
	t.Parallel()

	emptyChannelsTests := []struct {
		name       string
		encType    QueryEncoding
		encodedHex string
	}{
		{
			name:    "empty plain encoding",
			encType: EncodingSortedPlain,
			encodedHex: "000000000000000000000000000000000000000" +
				"00000000000000000000000000000000100000002" +
				"01000100",
		},
		{
			name:    "empty zlib encoding",
			encType: EncodingSortedZlib,
			encodedHex: "00000000000000000000000000000000000000" +
				"0000000000000000000000000000000001000000" +
				"0201000101",
		},
	}

	for _, test := range emptyChannelsTests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			req := ReplyChannelRange{
				FirstBlockHeight: 1,
				NumBlocks:        2,
				Complete:         1,
				EncodingType:     test.encType,
				ShortChanIDs:     nil,
				ExtraData:        make([]byte, 0),
			}

			// Fail loudly on a malformed test vector instead of
			// decoding a silently truncated byte slice.
			b, err := hex.DecodeString(test.encodedHex)
			require.NoError(t, err)

			// First decode the raw bytes into a new
			// ReplyChannelRange message. It should be identical
			// to the one created above.
			req2 := NewReplyChannelRange()
			err = req2.Decode(bytes.NewReader(b), 0)
			require.NoError(t, err)
			require.Equal(t, req, *req2)

			// Next, we go in the reverse direction: encode the
			// request created above, and assert that it matches
			// the raw byte encoding.
			var b2 bytes.Buffer
			err = req.Encode(&b2, 0)
			require.NoError(t, err)
			require.Equal(t, b, b2.Bytes())
		})
	}
}

// TestReplyChannelRangeEncode tests that encoding a ReplyChannelRange message
// results in the correct sorting of the SCIDs and Timestamps.
func TestReplyChannelRangeEncode(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name          string
		scids         []ShortChannelID
		timestamps    Timestamps
		expError      string
		expScids      []ShortChannelID
		expTimestamps Timestamps
	}{
		{
			name: "scids only, sorted",
			scids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
			expScids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
		},
		{
			name: "scids only, unsorted",
			scids: []ShortChannelID{
				{BlockHeight: 300},
				{BlockHeight: 100},
				{BlockHeight: 200},
			},
			expScids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
		},
		{
			name: "scids and timestamps, sorted",
			scids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
			timestamps: Timestamps{
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
				{Timestamp1: 5, Timestamp2: 6},
			},
			expScids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
			expTimestamps: Timestamps{
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
				{Timestamp1: 5, Timestamp2: 6},
			},
		},
		{
			// Each timestamp pair must follow its SCID when the
			// SCIDs are reordered.
			name: "scids and timestamps, unsorted",
			scids: []ShortChannelID{
				{BlockHeight: 300},
				{BlockHeight: 100},
				{BlockHeight: 200},
			},
			timestamps: Timestamps{
				{Timestamp1: 5, Timestamp2: 6},
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
			},
			expScids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
			expTimestamps: Timestamps{
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
				{Timestamp1: 5, Timestamp2: 6},
			},
		},
		{
			// expError is matched as a substring of the error
			// returned while writing the message, so it must
			// track the encoder's exact wording.
			name: "scid and timestamp count does not match",
			scids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 300},
			},
			timestamps: Timestamps{
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
			},
			expError: "got must provide a timestamp pair for " +
				"each of the given SCIDs",
		},
		{
			name: "duplicate scids",
			scids: []ShortChannelID{
				{BlockHeight: 100},
				{BlockHeight: 200},
				{BlockHeight: 200},
			},
			timestamps: Timestamps{
				{Timestamp1: 1, Timestamp2: 2},
				{Timestamp1: 3, Timestamp2: 4},
				{Timestamp1: 5, Timestamp2: 6},
			},
			expError: "scid list should not contain duplicates",
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			replyMsg := &ReplyChannelRange{
				FirstBlockHeight: 1,
				NumBlocks:        2,
				Complete:         1,
				EncodingType:     EncodingSortedPlain,
				ShortChanIDs:     test.scids,
				Timestamps:       test.timestamps,
				ExtraData:        make([]byte, 0),
			}

			var buf bytes.Buffer
			_, err := WriteMessage(&buf, replyMsg, 0)
			if len(test.expError) != 0 {
				require.ErrorContains(t, err, test.expError)
				return
			}
			require.NoError(t, err)

			// Round-trip: read the message straight back out of
			// the buffer it was written to.
			msg, err := ReadMessage(&buf, 0)
			require.NoError(t, err)

			msg2, ok := msg.(*ReplyChannelRange)
			require.True(t, ok)

			require.Equal(t, test.expScids, msg2.ShortChanIDs)
			require.Equal(t, test.expTimestamps, msg2.Timestamps)
		})
	}
}

// TestReplyChannelRangeDecode tests the decoding of some ReplyChannelRange
// test vectors.
func TestReplyChannelRangeDecode(t *testing.T) { t.Parallel() tests := []struct { name string hex string expEncoding QueryEncoding expSCIDs []string expTimestamps Timestamps expError string }{ { name: "plain encoding", hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + "12afca590b1a11466e2206000b8a06000005dc01001" + "900000000000000008e0000000000003c6900000000" + "0045a6c4", expEncoding: EncodingSortedPlain, expSCIDs: []string{ "0:0:142", "0:0:15465", "0:69:42692", }, }, { name: "zlib encoding", hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb4360" + "12afca590b1a11466e2206000006400000006e010016" + "01789c636000833e08659309a65878be010010a9023a", expEncoding: EncodingSortedZlib, expSCIDs: []string{ "0:0:142", "0:0:15465", "0:4:3318", }, }, { name: "plain encoding including timestamps", hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb43601" + "2afca590b1a11466e22060001ddde000005dc0100190" + "0000000000000304300000000000778d600000000004" + "6e1c1011900000282c1000e77c5000778ad00490ab00" + "000b57800955bff031800000457000008ae00000d050" + "000115c000015b300001a0a", expEncoding: EncodingSortedPlain, expSCIDs: []string{ "0:0:12355", "0:7:30934", "0:70:57793", }, expTimestamps: Timestamps{ { Timestamp1: 164545, Timestamp2: 948165, }, { Timestamp1: 489645, Timestamp2: 4786864, }, { Timestamp1: 46456, Timestamp2: 9788415, }, }, }, { name: "unsupported encoding", hex: "01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb" + "436012afca590b1a11466e22060001ddde000005dc01" + "001801789c63600001036730c55e710d4cbb3d3c0800" + "17c303b1012201789c63606a3ac8c0577e9481bd622d" + "8327d7060686ad150c53a3ff0300554707db03180000" + "0457000008ae00000d050000115c000015b300001a0a", expError: "unsupported encoding", }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() b, err := hex.DecodeString(test.hex) require.NoError(t, err) r := bytes.NewBuffer(b) msg, err := ReadMessage(r, 0) if len(test.expError) != 0 { require.ErrorContains(t, err, test.expError) 
return } require.NoError(t, err) replyMsg, ok := msg.(*ReplyChannelRange) require.True(t, ok) require.Equal( t, test.expEncoding, replyMsg.EncodingType, ) scids := make([]string, len(replyMsg.ShortChanIDs)) for i, id := range replyMsg.ShortChanIDs { scids[i] = id.String() } require.Equal(t, scids, test.expSCIDs) require.Equal( t, test.expTimestamps, replyMsg.Timestamps, ) }) } }