diff --git a/channeldb/channel.go b/channeldb/channel.go index 18db1d207..e7e850ae6 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -225,28 +225,117 @@ const ( // A tlv type definition used to serialize an outpoint's indexStatus // for use in the outpoint index. indexStatusType tlv.Type = 0 - - // A tlv type definition used to serialize and deserialize a KeyLocator - // from the database. - keyLocType tlv.Type = 1 - - // A tlv type used to serialize and deserialize the - // `InitialLocalBalance` field. - initialLocalBalanceType tlv.Type = 2 - - // A tlv type used to serialize and deserialize the - // `InitialRemoteBalance` field. - initialRemoteBalanceType tlv.Type = 3 - - // A tlv type definition used to serialize and deserialize the - // confirmed ShortChannelID for a zero-conf channel. - realScidType tlv.Type = 4 - - // A tlv type definition used to serialize and deserialize the - // Memo for the channel channel. - channelMemoType tlv.Type = 5 ) +// chanAuxData houses the auxiliary data that is stored for each channel in a +// TLV stream within the root bucket. This is stored as a TLV stream appended +// to the existing hard-coded fields in the channel's root bucket. +type chanAuxData struct { + // revokeKeyLoc is the key locator for the revocation key. + revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] + + // initialLocalBalance is the initial local balance of the channel. + initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] + + // initialRemoteBalance is the initial remote balance of the channel. + initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] + + // realScid is the real short channel ID of the channel corresponding to + // the on-chain outpoint. + realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] + + // memo is an optional text field that gives context to the user about + // the channel. + memo tlv.OptionalRecordT[tlv.TlvType5, []byte] +} + +// encode serializes the chanAuxData to the given io.Writer. +func (c *chanAuxData) encode(w io.Writer) error { + tlvRecords := []tlv.Record{ + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + } + c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { + tlvRecords = append(tlvRecords, memo.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode deserializes the chanAuxData from the given io.Reader. +func (c *chanAuxData) decode(r io.Reader) error { + memo := c.memo.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + memo.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[memo.TlvType()]; ok { + c.memo = tlv.SomeRecordT(memo) + } + + return nil +} + +// toOpeChan converts the chanAuxData to an OpenChannel by setting the relevant +// fields in the OpenChannel struct. +func (c *chanAuxData) toOpenChan(o *OpenChannel) { + o.RevocationKeyLocator = c.revokeKeyLoc.Val.KeyLocator + o.InitialLocalBalance = lnwire.MilliSatoshi(c.initialLocalBalance.Val) + o.InitialRemoteBalance = lnwire.MilliSatoshi(c.initialRemoteBalance.Val) + o.confirmedScid = c.realScid.Val + c.memo.WhenSomeV(func(memo []byte) { + o.Memo = memo + }) +} + +// newChanAuxDataFromChan creates a new chanAuxData from the given channel. +func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData { + c := &chanAuxData{ + revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( + keyLocRecord{openChan.RevocationKeyLocator}, + ), + initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint64(openChan.InitialLocalBalance), + ), + initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( + uint64(openChan.InitialRemoteBalance), + ), + realScid: tlv.NewRecordT[tlv.TlvType4]( + openChan.confirmedScid, + ), + } + + if len(openChan.Memo) != 0 { + c.memo = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](openChan.Memo), + ) + } + + return c +} + // indexStatus is an enum-like type that describes what state the // outpoint is in. Currently only two possible values. type indexStatus uint8 @@ -856,6 +945,10 @@ type OpenChannel struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -4007,32 +4100,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { return err } - // Convert balance fields into uint64. - localBalance := uint64(channel.InitialLocalBalance) - remoteBalance := uint64(channel.InitialRemoteBalance) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo), - ) - if err != nil { - return err - } - - if err := tlvStream.Encode(&w); err != nil { - return err + auxData := newChanAuxDataFromChan(channel) + if err := auxData.encode(&w); err != nil { + return fmt.Errorf("unable to encode aux data: %w", err) } if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { @@ -4221,45 +4291,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { } } - // Create balance fields in uint64, and Memo field as byte slice. - var ( - localBalance uint64 - remoteBalance uint64 - memo []byte - ) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &memo), - ) - if err != nil { - return err + var auxData chanAuxData + if err := auxData.decode(r); err != nil { + return fmt.Errorf("unable to decode aux data: %w", err) } - if err := tlvStream.Decode(r); err != nil { - return err - } - - // Attach the balance fields. - channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance) - channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance) - - // Attach the memo field if non-empty. - if len(memo) > 0 { - channel.Memo = memo - } + // Assign all the relevant fields from the aux data into the actual + // open channel. + auxData.toOpenChan(channel) channel.Packager = NewChannelPackager(channel.ShortChannelID) @@ -4417,6 +4456,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error { return chanBucket.Delete(frozenChanKey) } +// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the +// tlv.RecordProducer interface. +type keyLocRecord struct { + keychain.KeyLocator +} + +// Record creates a Record out of a KeyLocator using the passed Type and the +// EKeyLocator and DKeyLocator functions. The size will always be 8 as +// KeyFamily is uint32 and the Index is uint32. +// +// NOTE: This is part of the tlv.RecordProducer interface. +func (k *keyLocRecord) Record() tlv.Record { + // Note that we set the type here as zero, as when used with a + // tlv.RecordT, the type param will be used as the type. + return tlv.MakeStaticRecord( + 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, + ) +} + // EKeyLocator is an encoder for keychain.KeyLocator. func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*keychain.KeyLocator); ok { @@ -4445,22 +4503,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) } -// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed -// Type and the EKeyLocator and DKeyLocator functions. The size will always be -// 8 as KeyFamily is uint32 and the Index is uint32. -func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record { - return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator) -} - -// MakeScidRecord creates a Record out of a ShortChannelID using the passed -// Type and the EShortChannelID and DShortChannelID functions. The size will -// always be 8 for the ShortChannelID. -func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record { - return tlv.MakeStaticRecord( - typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, - ) -} - // ShutdownInfo contains various info about the shutdown initiation of a // channel. type ShutdownInfo struct {