diff --git a/channeldb/channel.go b/channeldb/channel.go index 4e3db2fc1..ec449613e 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -226,28 +226,83 @@ const ( // A tlv type definition used to serialize an outpoint's indexStatus // for use in the outpoint index. indexStatusType tlv.Type = 0 - - // A tlv type definition used to serialize and deserialize a KeyLocator - // from the database. - keyLocType tlv.Type = 1 - - // A tlv type used to serialize and deserialize the - // `InitialLocalBalance` field. - initialLocalBalanceType tlv.Type = 2 - - // A tlv type used to serialize and deserialize the - // `InitialRemoteBalance` field. - initialRemoteBalanceType tlv.Type = 3 - - // A tlv type definition used to serialize and deserialize the - // confirmed ShortChannelID for a zero-conf channel. - realScidType tlv.Type = 4 - - // A tlv type definition used to serialize and deserialize the - // Memo for the channel channel. - channelMemoType tlv.Type = 5 ) +// openChannelTlvData houses the new data fields that are stored for each +// channel in a TLV stream within the root bucket. This is stored as a TLV +// stream appended to the existing hard-coded fields in the channel's root +// bucket. New fields being added to the channel state should be added here. +// +// NOTE: This struct is used for serialization purposes only and its fields +// should be accessed via the OpenChannel struct while in memory. +type openChannelTlvData struct { + // revokeKeyLoc is the key locator for the revocation key. + revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] + + // initialLocalBalance is the initial local balance of the channel. + initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] + + // initialRemoteBalance is the initial remote balance of the channel. + initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] + + // realScid is the real short channel ID of the channel corresponding to + // the on-chain outpoint. + realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] + + // memo is an optional text field that gives context to the user about + // the channel. + memo tlv.OptionalRecordT[tlv.TlvType5, []byte] +} + +// encode serializes the openChannelTlvData to the given io.Writer. +func (c *openChannelTlvData) encode(w io.Writer) error { + tlvRecords := []tlv.Record{ + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + } + c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { + tlvRecords = append(tlvRecords, memo.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode deserializes the openChannelTlvData from the given io.Reader. +func (c *openChannelTlvData) decode(r io.Reader) error { + memo := c.memo.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + memo.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[memo.TlvType()]; ok { + c.memo = tlv.SomeRecordT(memo) + } + + return nil +} + // indexStatus is an enum-like type that describes what state the // outpoint is in. Currently only two possible values. type indexStatus uint8 @@ -867,6 +922,10 @@ type OpenChannel struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -1025,6 +1084,48 @@ func (c *OpenChannel) SetBroadcastHeight(height uint32) { c.FundingBroadcastHeight = height } +// amendTlvData updates the channel with the given auxiliary TLV data. +func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { + c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator + c.InitialLocalBalance = lnwire.MilliSatoshi( + auxData.initialLocalBalance.Val, + ) + c.InitialRemoteBalance = lnwire.MilliSatoshi( + auxData.initialRemoteBalance.Val, + ) + c.confirmedScid = auxData.realScid.Val + + auxData.memo.WhenSomeV(func(memo []byte) { + c.Memo = memo + }) +} + +// extractTlvData creates a new openChannelTlvData from the given channel. +func (c *OpenChannel) extractTlvData() openChannelTlvData { + auxData := openChannelTlvData{ + revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( + keyLocRecord{c.RevocationKeyLocator}, + ), + initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint64(c.InitialLocalBalance), + ), + initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( + uint64(c.InitialRemoteBalance), + ), + realScid: tlv.NewRecordT[tlv.TlvType4]( + c.confirmedScid, + ), + } + + if len(c.Memo) != 0 { + auxData.memo = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), + ) + } + + return auxData +} + // Refresh updates the in-memory channel state using the latest state observed // on disk. func (c *OpenChannel) Refresh() error { @@ -4030,32 +4131,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { return err } - // Convert balance fields into uint64. - localBalance := uint64(channel.InitialLocalBalance) - remoteBalance := uint64(channel.InitialRemoteBalance) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo), - ) - if err != nil { - return err - } - - if err := tlvStream.Encode(&w); err != nil { - return err + auxData := channel.extractTlvData() + if err := auxData.encode(&w); err != nil { + return fmt.Errorf("unable to encode aux data: %w", err) } if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { @@ -4244,45 +4322,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { } } - // Create balance fields in uint64, and Memo field as byte slice. - var ( - localBalance uint64 - remoteBalance uint64 - memo []byte - ) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &memo), - ) - if err != nil { - return err + var auxData openChannelTlvData + if err := auxData.decode(r); err != nil { + return fmt.Errorf("unable to decode aux data: %w", err) } - if err := tlvStream.Decode(r); err != nil { - return err - } - - // Attach the balance fields. - channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance) - channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance) - - // Attach the memo field if non-empty. - if len(memo) > 0 { - channel.Memo = memo - } + // Assign all the relevant fields from the aux data into the actual + // open channel. + channel.amendTlvData(auxData) channel.Packager = NewChannelPackager(channel.ShortChannelID) @@ -4440,6 +4487,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error { return chanBucket.Delete(frozenChanKey) } +// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the +// tlv.RecordProducer interface. +type keyLocRecord struct { + keychain.KeyLocator +} + +// Record creates a Record out of a KeyLocator using the passed Type and the +// EKeyLocator and DKeyLocator functions. The size will always be 8 as +// KeyFamily is uint32 and the Index is uint32. +// +// NOTE: This is part of the tlv.RecordProducer interface. +func (k *keyLocRecord) Record() tlv.Record { + // Note that we set the type here as zero, as when used with a + // tlv.RecordT, the type param will be used as the type. + return tlv.MakeStaticRecord( + 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, + ) +} + // EKeyLocator is an encoder for keychain.KeyLocator. func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*keychain.KeyLocator); ok { @@ -4468,22 +4534,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) } -// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed -// Type and the EKeyLocator and DKeyLocator functions. The size will always be -// 8 as KeyFamily is uint32 and the Index is uint32. -func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record { - return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator) -} - -// MakeScidRecord creates a Record out of a ShortChannelID using the passed -// Type and the EShortChannelID and DShortChannelID functions. The size will -// always be 8 for the ShortChannelID. -func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record { - return tlv.MakeStaticRecord( - typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, - ) -} - // ShutdownInfo contains various info about the shutdown initiation of a // channel. type ShutdownInfo struct {