From 3ff8eb899c0a72fdb1a7144c5b248c89d6b1a592 Mon Sep 17 00:00:00 2001 From: eugene Date: Mon, 4 Apr 2022 15:25:16 -0400 Subject: [PATCH] lnwire: add alias to FundingLocked in TLV This adds an optional short channel id field to the FundingLocked message that is sent/received as a TLV segment inside the ExtraOpaqueData field. --- lnwire/funding_locked.go | 37 ++++++++++++++++++++++++++- lnwire/short_channel_id.go | 44 +++++++++++++++++++++++++++++++++ lnwire/short_channel_id_test.go | 23 +++++++++++++++++ 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index 02ee4d40b..fb47356ba 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -5,6 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" ) // FundingLocked is the message that both parties to a new channel creation @@ -21,6 +22,11 @@ type FundingLocked struct { // next commitment transaction for the channel. NextPerCommitmentPoint *btcec.PublicKey + // AliasScid is an alias ShortChannelID used to refer to the underlying + // channel. It can be used instead of the confirmed on-chain + // ShortChannelID for forwarding. + AliasScid *ShortChannelID + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -47,11 +53,31 @@ var _ Message = (*FundingLocked)(nil) // // This is part of the lnwire.Message interface. func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, + // Read all the mandatory fields in the message. + err := ReadElements(r, &c.ChanID, &c.NextPerCommitmentPoint, &c.ExtraData, ) + if err != nil { + return err + } + + // Next we'll parse out the set of known records. For now, this is just + // the AliasScidRecordType. + var aliasScid ShortChannelID + typeMap, err := c.ExtraData.ExtractRecords(&aliasScid) + if err != nil { + return err + } + + // We'll only set AliasScid if the corresponding TLV type was included + // in the stream. + if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { + c.AliasScid = &aliasScid + } + + return nil } // Encode serializes the target FundingLocked message into the passed io.Writer @@ -68,6 +94,15 @@ func (c *FundingLocked) Encode(w *bytes.Buffer, pver uint32) error { return err } + // We'll only encode the AliasScid in a TLV segment if it exists. + if c.AliasScid != nil { + recordProducers := []tlv.RecordProducer{c.AliasScid} + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + if err != nil { + return err + } + } + return WriteBytes(w, c.ExtraData) } diff --git a/lnwire/short_channel_id.go b/lnwire/short_channel_id.go index b2b980aa2..f07de709f 100644 --- a/lnwire/short_channel_id.go +++ b/lnwire/short_channel_id.go @@ -2,6 +2,15 @@ package lnwire import ( "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // AliasScidRecordType is the type of the experimental record to denote + // the alias being used in an option_scid_alias channel. + AliasScidRecordType tlv.Type = 1 ) // ShortChannelID represents the set of data which is needed to retrieve all @@ -46,3 +55,38 @@ func (c ShortChannelID) ToUint64() uint64 { func (c ShortChannelID) String() string { return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition) } + +// Record returns a TLV record that can be used to encode/decode a +// ShortChannelID to/from a TLV stream. +func (c *ShortChannelID) Record() tlv.Record { + return tlv.MakeStaticRecord( + AliasScidRecordType, c, 8, EShortChannelID, DShortChannelID, + ) +} + +// EShortChannelID is an encoder for ShortChannelID. It is exported so other +// packages can use the encoding scheme. +func EShortChannelID(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ShortChannelID); ok { + return tlv.EUint64T(w, v.ToUint64(), buf) + } + return tlv.NewTypeForEncodingErr(val, "lnwire.ShortChannelID") +} + +// DShortChannelID is a decoder for ShortChannelID. It is exported so other +// packages can use the decoding scheme. +func DShortChannelID(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ShortChannelID); ok { + var scid uint64 + err := tlv.DUint64(r, &scid, buf, 8) + if err != nil { + return err + } + + *v = NewShortChanIDFromInt(scid) + return nil + } + return tlv.NewTypeForDecodingErr(val, "lnwire.ShortChannelID", l, 8) +} diff --git a/lnwire/short_channel_id_test.go b/lnwire/short_channel_id_test.go index 3bab49834..2916f20d1 100644 --- a/lnwire/short_channel_id_test.go +++ b/lnwire/short_channel_id_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" ) func TestShortChannelIDEncoding(t *testing.T) { @@ -39,3 +40,25 @@ func TestShortChannelIDEncoding(t *testing.T) { } } } + +// TestScidTypeEncodeDecode tests that we're able to properly encode and decode +// ShortChannelID within TLV streams. +func TestScidTypeEncodeDecode(t *testing.T) { + t.Parallel() + + aliasScid := ShortChannelID{ + BlockHeight: (1 << 24) - 1, + TxIndex: (1 << 24) - 1, + TxPosition: (1 << 16) - 1, + } + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&aliasScid)) + + var aliasScid2 ShortChannelID + tlvs, err := extraData.ExtractRecords(&aliasScid2) + require.NoError(t, err) + + require.Contains(t, tlvs, AliasScidRecordType) + require.Equal(t, aliasScid, aliasScid2) +}