diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 37e72fdee..41a2fcbb7 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -63,6 +64,10 @@ type Config struct { // state. ChanStateDB *channeldb.ChannelStateDB + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // BlockCache is the main cache for storing block information. BlockCache *blockcache.BlockCache diff --git a/channeldb/channel.go b/channeldb/channel.go index 4e3db2fc1..c21716a45 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -226,28 +226,109 @@ const ( // A tlv type definition used to serialize an outpoint's indexStatus // for use in the outpoint index. indexStatusType tlv.Type = 0 - - // A tlv type definition used to serialize and deserialize a KeyLocator - // from the database. - keyLocType tlv.Type = 1 - - // A tlv type used to serialize and deserialize the - // `InitialLocalBalance` field. - initialLocalBalanceType tlv.Type = 2 - - // A tlv type used to serialize and deserialize the - // `InitialRemoteBalance` field. - initialRemoteBalanceType tlv.Type = 3 - - // A tlv type definition used to serialize and deserialize the - // confirmed ShortChannelID for a zero-conf channel. - realScidType tlv.Type = 4 - - // A tlv type definition used to serialize and deserialize the - // Memo for the channel channel. - channelMemoType tlv.Type = 5 ) +// openChannelTlvData houses the new data fields that are stored for each +// channel in a TLV stream within the root bucket. This is stored as a TLV +// stream appended to the existing hard-coded fields in the channel's root +// bucket. New fields being added to the channel state should be added here. +// +// NOTE: This struct is used for serialization purposes only and its fields +// should be accessed via the OpenChannel struct while in memory. +type openChannelTlvData struct { + // revokeKeyLoc is the key locator for the revocation key. + revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] + + // initialLocalBalance is the initial local balance of the channel. + initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] + + // initialRemoteBalance is the initial remote balance of the channel. + initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] + + // realScid is the real short channel ID of the channel corresponding to + // the on-chain outpoint. + realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] + + // memo is an optional text field that gives context to the user about + // the channel. + memo tlv.OptionalRecordT[tlv.TlvType5, []byte] + + // tapscriptRoot is the optional Tapscript root the channel funding + // output commits to. + tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] + + // customBlob is an optional TLV encoded blob of data representing + // custom channel funding information. + customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] +} + +// encode serializes the openChannelTlvData to the given io.Writer. +func (c *openChannelTlvData) encode(w io.Writer) error { + tlvRecords := []tlv.Record{ + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + } + c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { + tlvRecords = append(tlvRecords, memo.Record()) + }) + c.tapscriptRoot.WhenSome( + func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { + tlvRecords = append(tlvRecords, root.Record()) + }, + ) + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode deserializes the openChannelTlvData from the given io.Reader. +func (c *openChannelTlvData) decode(r io.Reader) error { + memo := c.memo.Zero() + tapscriptRoot := c.tapscriptRoot.Zero() + blob := c.customBlob.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + memo.Record(), + tapscriptRoot.Record(), + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[memo.TlvType()]; ok { + c.memo = tlv.SomeRecordT(memo) + } + if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { + c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) + } + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + // indexStatus is an enum-like type that describes what state the // outpoint is in. Currently only two possible values. type indexStatus uint8 @@ -325,6 +406,11 @@ const ( // SimpleTaprootFeatureBit indicates that the simple-taproot-chans // feature bit was negotiated during the lifetime of the channel. SimpleTaprootFeatureBit ChannelType = 1 << 10 + + // TapscriptRootBit indicates that this is a MuSig2 channel with a top + // level tapscript commitment. This MUST be set along with the + // SimpleTaprootFeatureBit. + TapscriptRootBit ChannelType = 1 << 11 ) // IsSingleFunder returns true if the channel type if one of the known single @@ -395,6 +481,12 @@ func (c ChannelType) IsTaproot() bool { return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit } +// HasTapscriptRoot returns true if the channel is using a top level tapscript +// root commitment. +func (c ChannelType) HasTapscriptRoot() bool { + return c&TapscriptRootBit == TapscriptRootBit +} + // ChannelStateBounds are the parameters from OpenChannel and AcceptChannel // that are responsible for providing bounds on the state space of the abstract // channel state. These values must be remembered for normal channel operation @@ -496,6 +588,53 @@ type ChannelConfig struct { HtlcBasePoint keychain.KeyDescriptor } +// commitTlvData stores all the optional data that may be stored as a TLV stream +// at the _end_ of the normal serialized commit on disk. +type commitTlvData struct { + // customBlob is a custom blob that may store extra data for custom + // channels. + customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] +} + +// encode encodes the aux data into the passed io.Writer. +func (c *commitTlvData) encode(w io.Writer) error { + var tlvRecords []tlv.Record + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode attempts to decode the aux data from the passed io.Reader. +func (c *commitTlvData) decode(r io.Reader) error { + blob := c.customBlob.Zero() + + tlvStream, err := tlv.NewStream( + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + // ChannelCommitment is a snapshot of the commitment state at a particular // point in the commitment chain. With each state transition, a snapshot of the // current state along with all non-settled HTLCs are recorded. These snapshots @@ -562,6 +701,11 @@ type ChannelCommitment struct { // able by us. CommitTx *wire.MsgTx + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This may track some custom + // specific state for this given commitment. + CustomBlob fn.Option[tlv.Blob] + // CommitSig is one half of the signature required to fully complete // the script for the commitment transaction above. This is the // signature signed by the remote party for our version of the @@ -571,9 +715,26 @@ type ChannelCommitment struct { // Htlcs is the set of HTLC's that are pending at this particular // commitment height. Htlcs []HTLC +} - // TODO(roasbeef): pending commit pointer? - // * lets just walk through +// amendTlvData updates the channel with the given auxiliary TLV data. +func (c *ChannelCommitment) amendTlvData(auxData commitTlvData) { + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + c.CustomBlob = fn.Some(blob) + }) +} + +// extractTlvData creates a new commitTlvData from the given commitment. +func (c *ChannelCommitment) extractTlvData() commitTlvData { + var auxData commitTlvData + + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](blob), + ) + }) + + return auxData } // ChannelStatus is a bit vector used to indicate whether an OpenChannel is in @@ -867,6 +1028,16 @@ type OpenChannel struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -1025,6 +1196,64 @@ func (c *OpenChannel) SetBroadcastHeight(height uint32) { c.FundingBroadcastHeight = height } +// amendTlvData updates the channel with the given auxiliary TLV data. +func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { + c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator + c.InitialLocalBalance = lnwire.MilliSatoshi( + auxData.initialLocalBalance.Val, + ) + c.InitialRemoteBalance = lnwire.MilliSatoshi( + auxData.initialRemoteBalance.Val, + ) + c.confirmedScid = auxData.realScid.Val + + auxData.memo.WhenSomeV(func(memo []byte) { + c.Memo = memo + }) + auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { + c.TapscriptRoot = fn.Some[chainhash.Hash](h) + }) + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + c.CustomBlob = fn.Some(blob) + }) +} + +// extractTlvData creates a new openChannelTlvData from the given channel. +func (c *OpenChannel) extractTlvData() openChannelTlvData { + auxData := openChannelTlvData{ + revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( + keyLocRecord{c.RevocationKeyLocator}, + ), + initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint64(c.InitialLocalBalance), + ), + initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( + uint64(c.InitialRemoteBalance), + ), + realScid: tlv.NewRecordT[tlv.TlvType4]( + c.confirmedScid, + ), + } + + if len(c.Memo) != 0 { + auxData.memo = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), + ) + } + c.TapscriptRoot.WhenSome(func(h chainhash.Hash) { + auxData.tapscriptRoot = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), + ) + }) + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType7](blob), + ) + }) + + return auxData +} + // Refresh updates the in-memory channel state using the latest state observed // on disk. func (c *OpenChannel) Refresh() error { @@ -2351,6 +2580,12 @@ type HTLC struct { // HTLC. It is stored in the ExtraData field, which is used to store // a TLV stream of additional information associated with the HTLC. BlindingPoint lnwire.BlindingPointRecord + + // CustomRecords is a set of custom TLV records that are associated with + // this HTLC. These records are used to store additional information + // about the HTLC that is not part of the standard HTLC fields. This + // field is encoded within the ExtraData field. + CustomRecords lnwire.CustomRecords } // serializeExtraData encodes a TLV stream of extra data to be stored with a @@ -2369,6 +2604,11 @@ func (h *HTLC) serializeExtraData() error { records = append(records, &b) }) + records, err := h.CustomRecords.ExtendRecordProducers(records) + if err != nil { + return err + } + return h.ExtraData.PackRecords(records...) } @@ -2390,8 +2630,19 @@ func (h *HTLC) deserializeExtraData() error { if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { h.BlindingPoint = tlv.SomeRecordT(blindingPoint) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(tlvMap, h.BlindingPoint.TlvType()) } + // Set the custom records field to the remaining TLV records. + customRecords, err := lnwire.NewCustomRecords(tlvMap) + if err != nil { + return err + } + h.CustomRecords = customRecords + return nil } @@ -2529,6 +2780,8 @@ func (h *HTLC) Copy() HTLC { copy(clone.Signature[:], h.Signature) copy(clone.RHash[:], h.RHash[:]) copy(clone.ExtraData, h.ExtraData) + clone.BlindingPoint = h.BlindingPoint + clone.CustomRecords = h.CustomRecords.Copy() return clone } @@ -2690,6 +2943,14 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl } } + // We'll also encode the commit aux data stream here. We do this here + // rather than above (at the call to serializeChanCommit), to ensure + // backwards compat for reads to existing non-custom channels. + auxData := diff.Commitment.extractTlvData() + if err := auxData.encode(w); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return nil } @@ -2750,6 +3011,17 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } } + // As a final step, we'll read out any aux commit data that we have at + // the end of this byte stream. We do this here to ensure backward + // compatibility, as otherwise we risk erroneously reading into the + // wrong field. + var auxData commitTlvData + if err := auxData.decode(r); err != nil { + return nil, fmt.Errorf("unable to decode aux data: %w", err) + } + + d.Commitment.amendTlvData(auxData) + return &d, nil } @@ -3728,6 +4000,13 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { }, } + localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + + snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) + }) + // Copy over the current set of HTLCs to ensure the caller can't mutate // our internal state. snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) @@ -4030,32 +4309,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { return err } - // Convert balance fields into uint64. - localBalance := uint64(channel.InitialLocalBalance) - remoteBalance := uint64(channel.InitialRemoteBalance) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo), - ) - if err != nil { - return err - } - - if err := tlvStream.Encode(&w); err != nil { - return err + auxData := channel.extractTlvData() + if err := auxData.encode(&w); err != nil { + return fmt.Errorf("unable to encode aux data: %w", err) } if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { @@ -4142,6 +4398,12 @@ func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, return err } + // Before we write to disk, we'll also write our aux data as well. + auxData := c.extractTlvData() + if err := auxData.encode(&b); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return chanBucket.Put(commitKey, b.Bytes()) } @@ -4244,45 +4506,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { } } - // Create balance fields in uint64, and Memo field as byte slice. - var ( - localBalance uint64 - remoteBalance uint64 - memo []byte - ) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &memo), - ) - if err != nil { - return err + var auxData openChannelTlvData + if err := auxData.decode(r); err != nil { + return fmt.Errorf("unable to decode aux data: %w", err) } - if err := tlvStream.Decode(r); err != nil { - return err - } - - // Attach the balance fields. - channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance) - channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance) - - // Attach the memo field if non-empty. - if len(memo) > 0 { - channel.Memo = memo - } + // Assign all the relevant fields from the aux data into the actual + // open channel. + channel.amendTlvData(auxData) channel.Packager = NewChannelPackager(channel.ShortChannelID) @@ -4318,7 +4549,9 @@ func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { return c, nil } -func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment, error) { +func fetchChanCommitment(chanBucket kvdb.RBucket, + local bool) (ChannelCommitment, error) { + var commitKey []byte if local { commitKey = append(chanCommitmentKey, byte(0x00)) @@ -4332,7 +4565,23 @@ func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment } r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) + chanCommit, err := deserializeChanCommit(r) + if err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan commit: %w", err) + } + + // We'll also check to see if we have any aux data stored as the end of + // the stream. + var auxData commitTlvData + if err := auxData.decode(r); err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan aux data: %w", err) + } + + chanCommit.amendTlvData(auxData) + + return chanCommit, nil } func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error { @@ -4440,6 +4689,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error { return chanBucket.Delete(frozenChanKey) } +// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the +// tlv.RecordProducer interface. +type keyLocRecord struct { + keychain.KeyLocator +} + +// Record creates a Record out of a KeyLocator using the passed Type and the +// EKeyLocator and DKeyLocator functions. The size will always be 8 as +// KeyFamily is uint32 and the Index is uint32. +// +// NOTE: This is part of the tlv.RecordProducer interface. +func (k *keyLocRecord) Record() tlv.Record { + // Note that we set the type here as zero, as when used with a + // tlv.RecordT, the type param will be used as the type. + return tlv.MakeStaticRecord( + 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, + ) +} + // EKeyLocator is an encoder for keychain.KeyLocator. func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*keychain.KeyLocator); ok { @@ -4468,22 +4736,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) } -// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed -// Type and the EKeyLocator and DKeyLocator functions. The size will always be -// 8 as KeyFamily is uint32 and the Index is uint32. -func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record { - return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator) -} - -// MakeScidRecord creates a Record out of a ShortChannelID using the passed -// Type and the EShortChannelID and DShortChannelID functions. The size will -// always be 8 for the ShortChannelID. -func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record { - return tlv.MakeStaticRecord( - typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, - ) -} - // ShutdownInfo contains various info about the shutdown initiation of a // channel. type ShutdownInfo struct { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index a7f3c1ebe..e92692201 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -17,6 +17,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -173,7 +174,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { } // channelIDOption is an option which sets the short channel ID of the channel. -var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { +func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption { return func(params *testChannelParams) { params.channel.ShortChannelID = chanID } @@ -326,6 +327,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { uniqueOutputIndex.Add(1) op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()} + var tapscriptRoot chainhash.Hash + copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32)) + return &OpenChannel{ ChanType: SingleFunderBit | FrozenBit, ChainHash: key, @@ -347,6 +351,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{1, 2, 3}), }, RemoteCommitment: ChannelCommitment{ CommitHeight: 0, @@ -356,6 +361,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{4, 5, 6}), }, NumConfsRequired: 4, RemoteCurrentRevocation: privKey.PubKey(), @@ -368,6 +374,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { ThawHeight: uint32(defaultPendingHeight), InitialLocalBalance: lnwire.MilliSatoshi(9000), InitialRemoteBalance: lnwire.MilliSatoshi(3000), + Memo: []byte("test"), + TapscriptRoot: fn.Some(tapscriptRoot), + CustomBlob: fn.Some([]byte{1, 2, 3}), } } @@ -575,24 +584,32 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, r *RevocationLog) { + t.Helper() + // Check the common fields. require.EqualValues( - t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch", + t, r.CommitTxHash.Val, c.CommitTx.TxHash(), "CommitTx mismatch", ) // Now check the common fields from the HTLCs. require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] - require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") - require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), - "Amt mismatch") - require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, - "RefundTimeout mismatch") - require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, - "OutputIndex mismatch") - require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, - "Incoming mismatch") + require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash") + require.Equal( + t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt", + ) + require.Equal( + t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout, + "RefundTimeout", + ) + require.EqualValues( + t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex, + "OutputIndex", + ) + require.Equal( + t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming", + ) } } @@ -657,6 +674,7 @@ func TestChannelStateTransition(t *testing.T) { CommitTx: newTx, CommitSig: newSig, Htlcs: htlcs, + CustomBlob: fn.Some([]byte{4, 5, 6}), } // First update the local node's broadcastable state and also add a @@ -694,9 +712,14 @@ func TestChannelStateTransition(t *testing.T) { // have been updated. updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + + assertCommitmentEqual( + t, &commitment, &updatedChannel[0].LocalCommitment, + ) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() require.NoError(t, err, "unable to read commitment height from disk") + if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", numDiskUpdates, commitment.CommitHeight) @@ -799,10 +822,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) // The two deltas (the original vs the on-disk version) should @@ -844,10 +867,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit) @@ -1642,6 +1665,24 @@ func TestHTLCsExtraData(t *testing.T) { ), } + // Custom channel data htlc with a blinding point. + customDataHTLC := HTLC{ + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: lnmock.MockOnion(), + BlindingPoint: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pubKey, + ), + ), + CustomRecords: map[uint64][]byte{ + uint64(lnwire.MinCustomRecordsTlvType + 3): {1, 2, 3}, + }, + } + testCases := []struct { name string htlcs []HTLC @@ -1663,6 +1704,7 @@ func TestHTLCsExtraData(t *testing.T) { mockHtlc, blindingPointHTLC, mockHtlc, + customDataHTLC, }, }, } diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index f062ac086..3abc73f81 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,6 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -16,16 +17,15 @@ import ( const ( // OutputIndexEmpty is used when the output index doesn't exist. OutputIndexEmpty = math.MaxUint16 +) - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. - // - // NOTE: A migration should be added whenever this list changes. - revLogOurOutputIndexType tlv.Type = 0 - revLogTheirOutputIndexType tlv.Type = 1 - revLogCommitTxHashType tlv.Type = 2 - revLogOurBalanceType tlv.Type = 3 - revLogTheirBalanceType tlv.Type = 4 +type ( + // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. + BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + + // BigSizeMilliSatoshi is a type alias for a TLV record of a + // lnwire.MilliSatoshi. + BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] ) var ( @@ -54,6 +54,74 @@ var ( ErrOutputIndexTooBig = errors.New("output index is over uint16") ) +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.DBytes32(r, vArray, buf, 32) +} + // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // historical HTLCs, which is useful for constructing RevocationLog when a // breach is detected. @@ -72,116 +140,90 @@ var ( // made into tlv records without further conversion. type HTLCEntry struct { // RHash is the payment hash of the HTLC. - RHash [32]byte + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] // RefundTimeout is the absolute timeout on the HTLC that the sender // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] // OutputIndex is the output index for this particular HTLC output // within the commitment transaction. // // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // gives us a max number of HTLCs of 65K. - OutputIndex uint16 + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. - Incoming bool + Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. - Amt btcutil.Amount + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] - // amtTlv is the uint64 format of Amt. This field is created so we can - // easily make it into a tlv record and save it to disk. - // - // NOTE: we keep this field for accounting purpose only. If the disk - // space becomes an issue, we could delete this field to save us extra - // 8 bytes. - amtTlv uint64 + // CustomBlob is an optional blob that can be used to store information + // specific to revocation handling for a custom channel type. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] - // incomingTlv is the uint8 format of Incoming. This field is created - // so we can easily make it into a tlv record and save it to disk. - incomingTlv uint8 -} - -// RHashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (h *HTLCEntry) RHashLen() uint64 { - if h.RHash == lntypes.ZeroHash { - return 0 - } - return 32 -} - -// RHashEncoder is the customized encoder which skips encoding the empty hash. -func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the value is an empty hash, we will skip encoding it. - if *v == lntypes.ZeroHash { - return nil - } - - return tlv.EBytes32(w, v, buf) -} - -// RHashDecoder is the customized decoder which skips decoding the empty hash. -func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } - - return tlv.DBytes32(r, v, buf, 32) + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, uint16] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize htlc entries - // to the database. We define it here instead of the head of - // the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - rHashType tlv.Type = 0 - refundTimeoutType tlv.Type = 1 - outputIndexType tlv.Type = 2 - incomingType tlv.Type = 3 - amtType tlv.Type = 4 - ) + records := []tlv.Record{ + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), + } - return tlv.NewStream( - tlv.MakeDynamicRecord( - rHashType, &h.RHash, h.RHashLen, - RHashEncoder, RHashDecoder, + h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + + h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, uint16]) { + records = append(records, r.Record()) + }) + + tlv.SortRecords(records) + + return tlv.NewStream(records...) +} + +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { + h := &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), ), - tlv.MakePrimitiveRecord( - refundTimeoutType, &h.RefundTimeout, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, ), - tlv.MakePrimitiveRecord( - outputIndexType, &h.OutputIndex, + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), ), - tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), - // We will save 3 bytes if the amount is less or equal to - // 4,294,967,295 msat, or roughly 0.043 bitcoin. - tlv.MakeBigSizeRecord(amtType, &h.amtTlv), - ) + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + HtlcIndex: tlv.SomeRecordT(tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + )), + } + + if len(htlc.CustomRecords) != 0 { + blob, err := htlc.CustomRecords.Serialize() + if err != nil { + return nil, err + } + + h.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + } + + return h, nil } // RevocationLog stores the info needed to construct a breach retribution. Its @@ -191,15 +233,15 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { type RevocationLog struct { // OurOutputIndex specifies our output index in this commitment. In a // remote commitment transaction, this is the to remote output index. - OurOutputIndex uint16 + OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] // TheirOutputIndex specifies their output index in this commitment. In // a remote commitment transaction, this is the to local output index. - TheirOutputIndex uint16 + TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] // CommitTxHash is the hash of the latest version of the commitment // state, broadcast able by us. - CommitTxHash [32]byte + CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] // HTLCEntries is the set of HTLCEntry's that are pending at this // particular commitment height. @@ -209,21 +251,65 @@ type RevocationLog struct { // directly spendable by us. In other words, it is the value of the // to_remote output on the remote parties' commitment transaction. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - OurBalance *lnwire.MilliSatoshi + OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] // TheirBalance is the current available balance within the channel // directly spendable by the remote node. In other words, it is the // value of the to_local output on the remote parties' commitment. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - TheirBalance *lnwire.MilliSatoshi + TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] +} + +// NewRevocationLog creates a new RevocationLog from the given parameters. +func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, + commitHash [32]byte, ourBalance, + theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, + customBlob fn.Option[tlv.Blob]) RevocationLog { + + rl := RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + ourOutputIndex, + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + theirOutputIndex, + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), + HTLCEntries: htlcs, + } + + ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(balance), + )) + }) + + theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(balance), + )) + }) + + customBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + + return rl } // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a @@ -242,15 +328,32 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, } rl := &RevocationLog{ - OurOutputIndex: uint16(ourOutputIndex), - TheirOutputIndex: uint16(theirOutputIndex), - CommitTxHash: commit.CommitTx.TxHash(), - HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(ourOutputIndex), + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(theirOutputIndex), + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( + commit.CommitTx.TxHash(), + ), + HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + if !noAmtData { - rl.OurBalance = &commit.LocalBalance - rl.TheirBalance = &commit.RemoteBalance + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(commit.LocalBalance), + )) + + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(commit.RemoteBalance), + )) } for _, htlc := range commit.Htlcs { @@ -265,12 +368,9 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := &HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - Incoming: htlc.Incoming, - OutputIndex: uint16(htlc.OutputIndex), - Amt: htlc.Amt.ToSatoshis(), + entry, err := NewHTLCEntryFromHTLC(htlc) + if err != nil { + return err } rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -306,31 +406,27 @@ func fetchRevocationLog(log kvdb.RBucket, func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // Add the tlv records for all non-optional fields. records := []tlv.Record{ - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), } // Now we add any optional fields that are non-nil. - if rl.OurBalance != nil { - lb := uint64(*rl.OurBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogOurBalanceType, &lb, - )) - } + rl.OurBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) - if rl.TheirBalance != nil { - rb := uint64(*rl.TheirBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &rb, - )) - } + rl.TheirBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) + + rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) @@ -351,14 +447,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { for _, htlc := range htlcs { - // Patch the incomingTlv field. - if htlc.Incoming { - htlc.incomingTlv = 1 - } - - // Patch the amtTlv field. - htlc.amtTlv = uint64(htlc.Amt) - // Create the tlv stream. tlvStream, err := htlc.toTlvStream() if err != nil { @@ -376,27 +464,20 @@ func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { // deserializeRevocationLog deserializes a RevocationLog based on tlv format. func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { - var ( - rl RevocationLog - ourBalance uint64 - theirBalance uint64 - ) + var rl RevocationLog + + ourBalance := rl.OurBalance.Zero() + theirBalance := rl.TheirBalance.Zero() + customBlob := rl.CustomBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), - tlv.MakeBigSizeRecord(revLogOurBalanceType, &ourBalance), - tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &theirBalance, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + ourBalance.Record(), + theirBalance.Record(), + customBlob.Record(), ) if err != nil { return rl, err @@ -408,14 +489,16 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { return rl, err } - if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil { - lb := lnwire.MilliSatoshi(ourBalance) - rl.OurBalance = &lb + if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { + rl.OurBalance = tlv.SomeRecordT(ourBalance) } - if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil { - rb := lnwire.MilliSatoshi(theirBalance) - rl.TheirBalance = &rb + if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { + rl.TheirBalance = tlv.SomeRecordT(theirBalance) + } + + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + rl.CustomBlob = tlv.SomeRecordT(customBlob) } // Read the HTLC entries. @@ -432,14 +515,28 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { for { var htlc HTLCEntry + customBlob := htlc.CustomBlob.Zero() + htlcIndex := htlc.HtlcIndex.Zero() + // Create the tlv stream. - tlvStream, err := htlc.toTlvStream() + records := []tlv.Record{ + htlc.RHash.Record(), + htlc.RefundTimeout.Record(), + htlc.OutputIndex.Record(), + htlc.Incoming.Record(), + htlc.Amt.Record(), + customBlob.Record(), + htlcIndex.Record(), + } + + tlvStream, err := tlv.NewStream(records...) if err != nil { return nil, err } // Read the HTLC entry. - if _, err := readTlvStream(r, tlvStream); err != nil { + parsedTypes, err := readTlvStream(r, tlvStream) + if err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -447,13 +544,13 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } - // Patch the Incoming field. - if htlc.incomingTlv == 1 { - htlc.Incoming = true + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + htlc.CustomBlob = tlv.SomeRecordT(customBlob) } - // Patch the Amt field. - htlc.Amt = btcutil.Amount(htlc.amtTlv) + if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil { + htlc.HtlcIndex = tlv.SomeRecordT(htlcIndex) + } // Append the entry. htlcs = append(htlcs, &htlc) @@ -469,6 +566,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { if err := s.Encode(&b); err != nil { return err } + // Write the stream's length as a varint. err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) if err != nil { diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index fc5303a48..4290552ee 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" @@ -33,17 +34,38 @@ var ( 0xff, // value = 255 } + customRecords = lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 1: []byte("custom data"), + } + + blobBytes = []byte{ + // Corresponds to the encoded version of the above custom + // records. + 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, + } + testHTLCEntry = HTLCEntry{ - RefundTimeout: 740_000, - OutputIndex: 10, - Incoming: true, - Amt: 1000_000, - amtTlv: 1000_000, - incomingTlv: 1, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + CustomBlob: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), + ), + HtlcIndex: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](0x33), + ), } testHTLCEntryBytes = []byte{ - // Body length 23. - 0x16, + // Body length 45. + 0x2d, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -54,6 +76,45 @@ var ( 0x3, 0x1, 0x1, // Amt tlv. 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x33, + } + + testHTLCEntryHash = HTLCEntry{ + RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash( + [32]byte{0x33, 0x44, 0x55}, + )), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + } + testHTLCEntryHashBytes = []byte{ + // Body length 54. + 0x36, + // Rhash tlv. + 0x0, 0x20, + 0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // RefundTimeout tlv. + 0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0, + // OutputIndex tlv. + 0x2, 0x2, 0x0, 0xa, + // Incoming tlv. + 0x3, 0x1, 0x1, + // Amt tlv. + 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, } localBalance = lnwire.MilliSatoshi(9000) @@ -68,24 +129,29 @@ var ( CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), Htlcs: []HTLC{{ - RefundTimeout: testHTLCEntry.RefundTimeout, - OutputIndex: int32(testHTLCEntry.OutputIndex), - Incoming: testHTLCEntry.Incoming, - Amt: lnwire.NewMSatFromSatoshis( - testHTLCEntry.Amt, + RefundTimeout: testHTLCEntry.RefundTimeout.Val, + OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + HtlcIndex: uint64( + testHTLCEntry.HtlcIndex.ValOpt(). + UnsafeFromSome(), ), + Incoming: testHTLCEntry.Incoming.Val, + Amt: lnwire.NewMSatFromSatoshis( + testHTLCEntry.Amt.Val.Int(), + ), + CustomRecords: customRecords, }}, + CustomBlob: fn.Some(blobBytes), } - testRevocationLogNoAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - } + testRevocationLogNoAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), + ) testRevocationLogNoAmtsBytes = []byte{ - // Body length 42. - 0x2a, + // Body length 61. + 0x3d, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -96,19 +162,19 @@ var ( 0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6, 0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff, 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, } - testRevocationLogWithAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - OurBalance: &localBalance, - TheirBalance: &remoteBalance, - } + testRevocationLogWithAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.Some(localBalance), fn.Some(remoteBalance), + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), + ) testRevocationLogWithAmtsBytes = []byte{ - // Body length 52. - 0x34, + // Body length 71. + 0x47, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -123,6 +189,9 @@ var ( 0x3, 0x3, 0xfd, 0x23, 0x28, // Remote Balance. 0x4, 0x3, 0xfd, 0x0b, 0xb8, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, } ) @@ -193,11 +262,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { // Copy the testHTLCEntry. entry := testHTLCEntry - // Set the internal fields to empty values so we can test the bytes are - // padded. - entry.incomingTlv = 0 - entry.amtTlv = 0 - // Write the tlv stream. buf := bytes.NewBuffer([]byte{}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) @@ -207,6 +271,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, testHTLCEntryBytes, buf.Bytes()) } +func TestSerializeHTLCEntriesWithRHash(t *testing.T) { + t.Parallel() + + // Copy the testHTLCEntry. + entry := testHTLCEntryHash + + // Write the tlv stream. + buf := bytes.NewBuffer([]byte{}) + err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) + require.NoError(t, err) + + // Check the bytes are read as expected. + require.Equal(t, testHTLCEntryHashBytes, buf.Bytes()) +} + func TestSerializeHTLCEntries(t *testing.T) { t.Parallel() @@ -215,7 +294,7 @@ func TestSerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -224,7 +303,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x36, 0x0, 0x20} + expectedBytes := []byte{0x4d, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -269,7 +348,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) { t, &test.revLog, test.revLogBytes, ) - testDerializeRevocationLog( + testDeserializeRevocationLog( t, &test.revLog, test.revLogBytes, ) }) @@ -293,7 +372,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog, require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex]) } -func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, +func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog, revLogBytes []byte) { // Construct the full bytes. @@ -309,7 +388,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, require.Equal(t, *revLog, rl) } -func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { +func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) { t.Parallel() // Read the tlv stream. @@ -322,7 +401,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, &testHTLCEntry, htlcs[0]) } -func TestDerializeHTLCEntries(t *testing.T) { +func TestDeserializeHTLCEntries(t *testing.T) { t.Parallel() // Copy the testHTLCEntry. @@ -330,7 +409,7 @@ func TestDerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -339,7 +418,7 @@ func TestDerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x36, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x4d, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...) @@ -398,11 +477,11 @@ func TestDeleteLogBucket(t *testing.T) { err = kvdb.Update(backend, func(tx kvdb.RwTx) error { // Create the buckets. - chanBucket, _, err := createTestRevocatoinLogBuckets(tx) + chanBucket, _, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Create the buckets again should give us an error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) require.ErrorIs(t, err, kvdb.ErrBucketExists) // Delete both buckets. @@ -410,7 +489,7 @@ func TestDeleteLogBucket(t *testing.T) { require.NoError(t, err) // Create the buckets again should give us NO error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) return err }, func() {}) require.NoError(t, err) @@ -516,7 +595,7 @@ func TestPutRevocationLog(t *testing.T) { // Construct the testing db transaction. dbTx := func(tx kvdb.RwTx) (RevocationLog, error) { // Create the buckets. - _, bucket, err := createTestRevocatoinLogBuckets(tx) + _, bucket, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Save the log. @@ -686,7 +765,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) { } } -func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, +func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, kvdb.RwBucket, error) { chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) diff --git a/config_builder.go b/config_builder.go index bef59b9a0..7c399297e 100644 --- a/config_builder.go +++ b/config_builder.go @@ -105,7 +105,7 @@ type DatabaseBuilder interface { type WalletConfigBuilder interface { // BuildWalletConfig is responsible for creating or unlocking and then // fully initializing a wallet. - BuildWalletConfig(context.Context, *DatabaseInstances, + BuildWalletConfig(context.Context, *DatabaseInstances, *AuxComponents, *rpcperms.InterceptorChain, []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) @@ -120,14 +120,6 @@ type ChainControlBuilder interface { *btcwallet.Config) (*chainreg.ChainControl, func(), error) } -// AuxComponents is a set of auxiliary components that can be used by lnd for -// certain custom channel types. -type AuxComponents struct { - // MsgRouter is an optional message router that if set will be used in - // place of a new blank default message router. - MsgRouter fn.Option[msgmux.Router] -} - // ImplementationCfg is a struct that holds all configuration items for // components that can be implemented outside lnd itself. type ImplementationCfg struct { @@ -160,6 +152,18 @@ type ImplementationCfg struct { AuxComponents } +// AuxComponents is a set of auxiliary components that can be used by lnd for +// certain custom channel types. +type AuxComponents struct { + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + + // MsgRouter is an optional message router that if set will be used in + // place of a new blank default message router. + MsgRouter fn.Option[msgmux.Router] +} + // DefaultWalletImpl is the default implementation of our normal, btcwallet // backed configuration. type DefaultWalletImpl struct { @@ -242,7 +246,8 @@ func (d *DefaultWalletImpl) Permissions() map[string][]bakery.Op { // // NOTE: This is part of the WalletConfigBuilder interface. func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, - dbs *DatabaseInstances, interceptorChain *rpcperms.InterceptorChain, + dbs *DatabaseInstances, aux *AuxComponents, + interceptorChain *rpcperms.InterceptorChain, grpcListeners []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) { @@ -562,6 +567,7 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, HeightHintDB: dbs.HeightHintDB, ChanStateDB: dbs.ChanStateDB.ChannelStateDB(), NeutrinoCS: neutrinoCS, + AuxLeafStore: aux.AuxLeafStore, ActiveNetParams: d.cfg.ActiveNetParams, FeeURL: d.cfg.FeeURL, Fee: &lncfg.Fee{ @@ -625,8 +631,9 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, // proxyBlockEpoch proxies a block epoch subsections to the underlying neutrino // rebroadcaster client. -func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, -) func() (*blockntfns.Subscription, error) { +func proxyBlockEpoch( + notifier chainntnfs.ChainNotifier) func() (*blockntfns.Subscription, + error) { return func() (*blockntfns.Subscription, error) { blockEpoch, err := notifier.RegisterBlockEpochNtfn( @@ -717,6 +724,7 @@ func (d *DefaultWalletImpl) BuildChainControl( ChainIO: walletController, NetParams: *walletConfig.NetParams, CoinSelectionStrategy: walletConfig.CoinSelectionStrategy, + AuxLeafStore: partialChainControl.Cfg.AuxLeafStore, } // The broadcast is already always active for neutrino nodes, so we @@ -899,6 +907,10 @@ type DatabaseInstances struct { // for native SQL queries for tables that already support it. This may // be nil if the use-native-sql flag was not set. NativeSQLStore *sqldb.BaseDB + + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultDatabaseBuilder is a type that builds the default database backends diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 6a1865444..5940ee25b 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,6 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" @@ -1590,6 +1591,7 @@ func testBreachSpends(t *testing.T, test breachTest) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, 1, forceCloseTx, + fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") @@ -1799,6 +1801,7 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, uint32(blockHeight), forceCloseTx, + fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0cc4b111a..dbc97939a 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -217,6 +217,10 @@ type ChainArbitratorConfig struct { // meanwhile, turn `PaymentCircuit` into an interface or bring it to a // lower package. QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -299,8 +303,13 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err @@ -344,10 +353,15 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // Finally, we'll force close the channel completing // the force close workflow. chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 3cbc7422d..b1e3fc1c2 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -193,6 +193,9 @@ type chainWatcherConfig struct { // obfuscater. This is used by the chain watcher to identify which // state was broadcast and confirmed on-chain. extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 + + // auxLeafStore can be used to fetch information for custom channels. + auxLeafStore fn.Option[lnwallet.AuxLeafStore] } // chainWatcher is a system that's assigned to every active channel. The duty @@ -308,7 +311,7 @@ func (c *chainWatcher) Start() error { ) if chanState.ChanType.IsTaproot() { c.fundingPkScript, _, err = input.GenTaprootFundingScript( - localKey, remoteKey, 0, + localKey, remoteKey, 0, chanState.TapscriptRoot, ) if err != nil { return err @@ -423,15 +426,36 @@ func (c *chainWatcher) handleUnknownLocalState( &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) + auxResult, err := fn.MapOptionZ( + c.cfg.auxLeafStore, + //nolint:lll + func(s lnwallet.AuxLeafStore) fn.Result[lnwallet.CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + lnwallet.NewAuxChanState(c.cfg.chanState), + c.cfg.chanState.LocalCommitment, *commitKeyRing, + ) + }, + ).Unpack() + if err != nil { + return false, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the keys derived, we'll construct the remote script that'll be // present if they have a non-dust balance on the commitment. var leaseExpiry uint32 if c.cfg.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = c.cfg.chanState.ThawHeight } + + remoteAuxLeaf := fn.ChainOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToRemoteKey, leaseExpiry, + remoteAuxLeaf, ) if err != nil { return false, err @@ -440,10 +464,16 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. + localAuxLeaf := fn.ChainOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) localScript, err := lnwallet.CommitScriptToSelf( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, + localAuxLeaf, ) if err != nil { return false, err @@ -866,7 +896,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, spendHeight := uint32(commitSpend.SpendingHeight) retribution, err := lnwallet.NewBreachRetribution( c.cfg.chanState, broadcastStateNum, spendHeight, - commitSpend.SpendingTx, + commitSpend.SpendingTx, c.cfg.auxLeafStore, ) switch { @@ -1116,8 +1146,8 @@ func (c *chainWatcher) dispatchLocalForceClose( "detected", c.cfg.chanState.FundingOutpoint) forceClose, err := lnwallet.NewLocalForceCloseSummary( - c.cfg.chanState, c.cfg.signer, - commitSpend.SpendingTx, stateNum, + c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum, + c.cfg.auxLeafStore, ) if err != nil { return err @@ -1210,7 +1240,7 @@ func (c *chainWatcher) dispatchRemoteForceClose( // channel on-chain. uniClose, err := lnwallet.NewUnilateralCloseSummary( c.cfg.chanState, c.cfg.signer, commitSpend, - remoteCommit, commitPoint, + remoteCommit, commitPoint, c.cfg.auxLeafStore, ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index 61b03200e..6bb027725 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -544,6 +545,10 @@ type Config struct { // backed funding flow to not use utxos still being swept by the sweeper // subsystem. IsSweeperOutpoint func(wire.OutPoint) bool + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // Manager acts as an orchestrator/bridge between the wallet's @@ -1069,9 +1074,14 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, } } + var chanOpts []lnwallet.ChannelOpt + f.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // We create the state-machine object which wraps the database state. lnChannel, err := lnwallet.NewLightningChannel( - nil, channel, nil, + nil, channel, nil, chanOpts..., ) if err != nil { log.Errorf("Unable to create LightningChannel(%v): %v", @@ -2899,6 +2909,7 @@ func makeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) { if channel.ChanType.IsTaproot() { pkScript, _, err := input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), + channel.TapscriptRoot, ) if err != nil { return nil, err diff --git a/funding/manager_test.go b/funding/manager_test.go index c4c8b4f36..471d0209f 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" @@ -563,6 +564,9 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, IsSweeperOutpoint: func(wire.OutPoint) bool { return false }, + AuxLeafStore: fn.Some[lnwallet.AuxLeafStore]( + &lnwallet.MockAuxLeafStore{}, + ), } for _, op := range options { @@ -672,6 +676,7 @@ func recreateAliceFundingManager(t *testing.T, alice *testNode) { OpenChannelPredicate: chainedAcceptor, DeleteAliasEdge: oldCfg.DeleteAliasEdge, AliasManager: oldCfg.AliasManager, + AuxLeafStore: oldCfg.AuxLeafStore, }) require.NoError(t, err, "failed recreating aliceFundingManager") diff --git a/graph/builder.go b/graph/builder.go index 82a36eb36..717d3e5ad 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -11,12 +11,14 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnutils" @@ -1138,12 +1140,14 @@ func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, } fundingScript, _, err := input.GenTaprootFundingScript( - pubKey1, pubKey2, 0, + pubKey1, pubKey2, 0, fn.None[chainhash.Hash](), ) if err != nil { return nil, err } + // TODO(roasbeef): add tapscript root to gossip v1.5 + return fundingScript, nil } diff --git a/input/script_utils.go b/input/script_utils.go index d801846f8..91ca55292 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -11,8 +11,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" @@ -199,26 +201,30 @@ func GenFundingPkScript(aPub, bPub []byte, amt int64) ([]byte, *wire.TxOut, erro } // GenTaprootFundingScript constructs the taproot-native funding output that -// uses musig2 to create a single aggregated key to anchor the channel. +// uses MuSig2 to create a single aggregated key to anchor the channel. func GenTaprootFundingScript(aPub, bPub *btcec.PublicKey, - amt int64) ([]byte, *wire.TxOut, error) { + amt int64, tapscriptRoot fn.Option[chainhash.Hash]) ([]byte, + *wire.TxOut, error) { + + muSig2Opt := musig2.WithBIP86KeyTweak() + tapscriptRoot.WhenSome(func(scriptRoot chainhash.Hash) { + muSig2Opt = musig2.WithTaprootKeyTweak(scriptRoot[:]) + }) // Similar to the existing p2wsh funding script, we'll always make sure // we sort the keys before any major operations. In order to ensure // that there's no other way this output can be spent, we'll use a BIP - // 86 tweak here during aggregation. - // - // TODO(roasbeef): revisit if BIP 86 is needed here? + // 86 tweak here during aggregation, unless the user has explicitly + // specified a tapscript root. combinedKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{aPub, bPub}, true, - musig2.WithBIP86KeyTweak(), + []*btcec.PublicKey{aPub, bPub}, true, muSig2Opt, ) if err != nil { return nil, nil, fmt.Errorf("unable to combine keys: %w", err) } // Now that we have the combined key, we can create a taproot pkScript - // from this, and then make the txout given the amount. + // from this, and then make the txOut given the amount. pkScript, err := PayToTaprootScript(combinedKey.FinalKey) if err != nil { return nil, nil, fmt.Errorf("unable to make taproot "+ @@ -228,7 +234,7 @@ func GenTaprootFundingScript(aPub, bPub *btcec.PublicKey, txOut := wire.NewTxOut(amt, pkScript) // For the "witness program" we just return the raw pkScript since the - // output we create can _only_ be spent with a musig2 signature. + // output we create can _only_ be spent with a MuSig2 signature. return pkScript, txOut, nil } @@ -640,6 +646,13 @@ type HtlcScriptTree struct { // TimeoutTapLeaf is the tapleaf for the timeout path. TimeoutTapLeaf txscript.TapLeaf + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // HTLC script tree with new spend paths, or just as extra commitment + // space. When present, this leaf will always be in the right-most area + // of the tapscript tree. + AuxLeaf AuxTapLeaf + + // htlcType is the type of HTLC script this is. htlcType htlcType } @@ -720,8 +733,8 @@ var _ TapscriptDescriptor = (*HtlcScriptTree)(nil) // senderHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the sender's commitment. func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, hType htlcType, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -738,11 +751,14 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{successTapLeaf, timeoutTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - successTapLeaf, timeoutTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -761,6 +777,7 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -795,7 +812,8 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, revokeKey *btcec.PublicKey, payHash []byte, - whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) { + whoseCommit lntypes.ChannelParty, auxLeaf AuxTapLeaf) (*HtlcScriptTree, + error) { var hType htlcType if whoseCommit.IsLocal() { @@ -808,8 +826,8 @@ func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, // tree that includes the top level output script, as well as the two // tap leaf paths. return senderHtlcTapScriptTree( - senderHtlcKey, receiverHtlcKey, revokeKey, payHash, - hType, + senderHtlcKey, receiverHtlcKey, revokeKey, payHash, hType, + auxLeaf, ) } @@ -1279,8 +1297,8 @@ func ReceiverHtlcTapLeafSuccess(receiverHtlcKey *btcec.PublicKey, // receiverHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the receiver's commitment. func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - cltvExpiry uint32, hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, cltvExpiry uint32, + hType htlcType, auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -1297,11 +1315,14 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{timeoutTapLeaf, successTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - timeoutTapLeaf, successTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -1320,6 +1341,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -1355,7 +1377,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, payHash []byte, whoseCommit lntypes.ChannelParty, -) (*HtlcScriptTree, error) { + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { var hType htlcType if whoseCommit.IsLocal() { @@ -1369,7 +1391,7 @@ func ReceiverHTLCScriptTaproot(cltvExpiry uint32, // tap leaf paths. return receiverHtlcTapScriptTree( senderHtlcKey, receiverHtlcKey, revocationKey, payHash, - cltvExpiry, hType, + cltvExpiry, hType, auxLeaf, ) } @@ -1598,9 +1620,9 @@ func TaprootSecondLevelTapLeaf(delayKey *btcec.PublicKey, } // SecondLevelHtlcTapscriptTree construct the indexed tapscript tree needed to -// generate the taptweak to create the final output and also control block. -func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, - csvDelay uint32) (*txscript.IndexedTapScriptTree, error) { +// generate the tap tweak to create the final output and also control block. +func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, csvDelay uint32, + auxLeaf AuxTapLeaf) (*txscript.IndexedTapScriptTree, error) { // First grab the second level leaf script we need to create the top // level output. @@ -1609,9 +1631,14 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, return nil, err } + tapLeaves := []txscript.TapLeaf{secondLevelTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // Now that we have the sole second level script, we can create the // tapscript tree that commits to both the leaves. - return txscript.AssembleTaprootScriptTree(secondLevelTapLeaf), nil + return txscript.AssembleTaprootScriptTree(tapLeaves...), nil } // TaprootSecondLevelHtlcScript is the uniform script that's used as the output @@ -1631,12 +1658,12 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, // // The keyspend path require knowledge of the top level revocation private key. func TaprootSecondLevelHtlcScript(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*btcec.PublicKey, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1661,17 +1688,21 @@ type SecondLevelScriptTree struct { // SuccessTapLeaf is the tapleaf for the redemption path. SuccessTapLeaf txscript.TapLeaf + + // AuxLeaf is an optional leaf that can be used to extend the script + // tree. + AuxLeaf AuxTapLeaf } // TaprootSecondLevelScriptTree constructs the tapscript tree used to spend the // second level HTLC output. func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*SecondLevelScriptTree, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*SecondLevelScriptTree, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1692,6 +1723,7 @@ func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, InternalKey: revokeKey, }, SuccessTapLeaf: tapScriptTree.LeafMerkleProofs[0].TapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2073,6 +2105,12 @@ type CommitScriptTree struct { // RevocationLeaf is the leaf used to spend the output with the // revocation key signature. RevocationLeaf txscript.TapLeaf + + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // commitment script tree with new spend paths, or just as extra + // commitment space. When present, this leaf will always be in the + // left-most or right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf } // A compile time check to ensure CommitScriptTree implements the @@ -2137,8 +2175,9 @@ func (c *CommitScriptTree) Tree() ScriptTree { // NewLocalCommitScriptTree returns a new CommitScript tree that can be used to // create and spend the commitment output for the local party. -func NewLocalCommitScriptTree(csvTimeout uint32, - selfKey, revokeKey *btcec.PublicKey) (*CommitScriptTree, error) { +func NewLocalCommitScriptTree(csvTimeout uint32, selfKey, + revokeKey *btcec.PublicKey, auxLeaf AuxTapLeaf) (*CommitScriptTree, + error) { // First, we'll need to construct the tapLeaf that'll be our delay CSV // clause. @@ -2158,9 +2197,13 @@ func NewLocalCommitScriptTree(csvTimeout uint32, // the two leaves, and then obtain a root from that. delayTapLeaf := txscript.NewBaseTapLeaf(delayScript) revokeTapLeaf := txscript.NewBaseTapLeaf(revokeScript) - tapScriptTree := txscript.AssembleTaprootScriptTree( - delayTapLeaf, revokeTapLeaf, - ) + + tapLeaves := []txscript.TapLeaf{delayTapLeaf, revokeTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2178,6 +2221,7 @@ func NewLocalCommitScriptTree(csvTimeout uint32, }, SettleLeaf: delayTapLeaf, RevocationLeaf: revokeTapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2247,7 +2291,7 @@ func TaprootCommitScriptToSelf(csvTimeout uint32, selfKey, revokeKey *btcec.PublicKey) (*btcec.PublicKey, error) { commitScriptTree, err := NewLocalCommitScriptTree( - csvTimeout, selfKey, revokeKey, + csvTimeout, selfKey, revokeKey, NoneTapLeaf(), ) if err != nil { return nil, err @@ -2573,7 +2617,7 @@ func CommitScriptToRemoteConfirmed(key *btcec.PublicKey) ([]byte, error) { // NewRemoteCommitScriptTree constructs a new script tree for the remote party // to sweep their funds after a hard coded 1 block delay. func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, -) (*CommitScriptTree, error) { + auxLeaf AuxTapLeaf) (*CommitScriptTree, error) { // First, construct the remote party's tapscript they'll use to sweep // their outputs. @@ -2589,10 +2633,16 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, return nil, err } + tapLeaf := txscript.NewBaseTapLeaf(remoteScript) + + tapLeaves := []txscript.TapLeaf{tapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With this script constructed, we'll map that into a tapLeaf, then // make a new tapscript root from that. - tapLeaf := txscript.NewBaseTapLeaf(remoteScript) - tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaf) + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2609,6 +2659,7 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, InternalKey: &TaprootNUMSKey, }, SettleLeaf: tapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2625,9 +2676,9 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, // OP_CHECKSIG // 1 OP_CHECKSEQUENCEVERIFY OP_DROP func TaprootCommitScriptToRemote(remoteKey *btcec.PublicKey, -) (*btcec.PublicKey, error) { + auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { - commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey) + commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey, auxLeaf) if err != nil { return nil, err } diff --git a/input/size_test.go b/input/size_test.go index daa7053cc..33f9ff539 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -853,7 +853,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -887,7 +887,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -921,7 +921,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} //nolint:lll commitScriptTree, err := input.NewRemoteCommitScriptTree( - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -988,6 +988,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1027,6 +1028,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1075,6 +1077,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1116,7 +1119,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1158,7 +1161,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1205,6 +1208,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1265,6 +1269,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1310,7 +1315,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1383,7 +1388,7 @@ func genTimeoutTx(t *testing.T, // Create the unsigned timeout tx. timeoutTx, err := lnwallet.CreateHtlcTimeoutTx( chanType, false, testOutPoint, testAmt, testCLTVExpiry, - testCSVDelay, 0, testPubkey, testPubkey, + testCSVDelay, 0, testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1396,7 +1401,7 @@ func genTimeoutTx(t *testing.T, if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( testPubkey, testPubkey, testPubkey, testHash160, - lntypes.Remote, + lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1452,7 +1457,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { // Create the unsigned success tx. successTx, err := lnwallet.CreateHtlcSuccessTx( chanType, false, testOutPoint, testAmt, testCSVDelay, 0, - testPubkey, testPubkey, + testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1465,7 +1470,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, lntypes.Remote, + testHash160, lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/input/taproot.go b/input/taproot.go index 34cdb974d..2ca6e9723 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/lightningnetwork/lnd/fn" ) const ( @@ -21,6 +22,33 @@ const ( PubKeyFormatCompressedOdd byte = 0x03 ) +// AuxTapLeaf is a type alias for an optional tapscript leaf that may be added +// to the tapscript tree of HTLC and commitment outputs. +type AuxTapLeaf = fn.Option[txscript.TapLeaf] + +// NoneTapLeaf returns an empty optional tapscript leaf. +func NoneTapLeaf() AuxTapLeaf { + return fn.None[txscript.TapLeaf]() +} + +// HtlcIndex represents the monotonically increasing counter that is used to +// identify HTLCs created by a peer. +type HtlcIndex = uint64 + +// HtlcAuxLeaf is a type that represents an auxiliary leaf for an HTLC output. +// An HTLC may have up to two aux leaves: one for the output on the commitment +// transaction, and one for the second level HTLC. +type HtlcAuxLeaf struct { + AuxTapLeaf + + // SecondLevelLeaf is the auxiliary leaf for the second level HTLC + // success or timeout transaction. + SecondLevelLeaf AuxTapLeaf +} + +// HtlcAuxLeaves is a type alias for a map of optional tapscript leaves. +type HtlcAuxLeaves = map[HtlcIndex]HtlcAuxLeaf + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 diff --git a/input/taproot_test.go b/input/taproot_test.go index 434be2dfd..a1259be19 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -1,13 +1,16 @@ package input import ( + "bytes" "crypto/rand" + "fmt" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -31,7 +34,9 @@ type testSenderHtlcScriptTree struct { htlcAmt int64 } -func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { +func newTestSenderHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSenderHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -48,7 +53,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, auxLeaf, ) require.NoError(t, err) @@ -207,13 +212,9 @@ func htlcSenderTimeoutWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootSenderHtlcSpend tests that all the positive and negative paths -// for the sender HTLC tapscript tree work as expected. -func TestTaprootSenderHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootSenderHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // First, create a new test script tree. - htlcScriptTree := newTestSenderHtlcScriptTree(t) + htlcScriptTree := newTestSenderHtlcScriptTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -432,6 +433,26 @@ func TestTaprootSenderHtlcSpend(t *testing.T) { } } +// TestTaprootSenderHtlcSpend tests that all the positive and negative paths +// for the sender HTLC tapscript tree work as expected. +func TestTaprootSenderHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootSenderHtlcSpend(t, auxLeaf) + }) + } +} + type testReceiverHtlcScriptTree struct { preImage lntypes.Preimage @@ -452,7 +473,9 @@ type testReceiverHtlcScriptTree struct { lockTime int32 } -func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { +func newTestReceiverHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testReceiverHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -471,7 +494,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], lntypes.Remote, + revokeKey.PubKey(), payHash[:], lntypes.Remote, auxLeaf, ) require.NoError(t, err) @@ -629,15 +652,11 @@ func htlcReceiverSuccessWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an -// accepted HTLC (on the commitment transaction) of the receiver work properly. -func TestTaprootReceiverHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootReceiverHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // We'll start by creating the HTLC script tree (contains all 3 valid // spend paths), and also a mock spend transaction that we'll be // signing below. - htlcScriptTree := newTestReceiverHtlcScriptTree(t) + htlcScriptTree := newTestReceiverHtlcScriptTree(t, auxLeaf) // TODO(roasbeef): issue with revoke key??? ctrl block even/odd @@ -891,6 +910,28 @@ func TestTaprootReceiverHtlcSpend(t *testing.T) { } } +// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an +// accepted HTLC (on the commitment transaction) of the receiver work properly. +func TestTaprootReceiverHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootReceiverHtlcSpend(t, auxLeaf) + }) + } +} + type testCommitScriptTree struct { csvDelay uint32 @@ -905,7 +946,9 @@ type testCommitScriptTree struct { *CommitScriptTree } -func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { +func newTestCommitScriptTree(local bool, + auxLeaf AuxTapLeaf) (*testCommitScriptTree, error) { + selfKey, err := btcec.NewPrivateKey() if err != nil { return nil, err @@ -925,10 +968,11 @@ func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { if local { commitScriptTree, err = NewLocalCommitScriptTree( csvDelay, selfKey.PubKey(), revokeKey.PubKey(), + auxLeaf, ) } else { commitScriptTree, err = NewRemoteCommitScriptTree( - selfKey.PubKey(), + selfKey.PubKey(), auxLeaf, ) } if err != nil { @@ -1020,12 +1064,8 @@ func localCommitRevokeWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming -// one's output after a force close behaves as expected. -func TestTaprootCommitScriptToSelf(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(true) +func testTaprootCommitScriptToSelf(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(true, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1187,6 +1227,26 @@ func TestTaprootCommitScriptToSelf(t *testing.T) { } } +// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming +// one's output after a force close behaves as expected. +func TestTaprootCommitScriptToSelf(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootCommitScriptToSelf(t, auxLeaf) + }) + } +} + func remoteCommitSweepWitGen(sigHash txscript.SigHashType, commitScriptTree *testCommitScriptTree) witnessGen { @@ -1220,12 +1280,8 @@ func remoteCommitSweepWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptRemote tests that the remote party can properly sweep -// their output after force close. -func TestTaprootCommitScriptRemote(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(false) +func testTaprootCommitScriptRemote(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(false, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1364,6 +1420,26 @@ func TestTaprootCommitScriptRemote(t *testing.T) { } } +// TestTaprootCommitScriptRemote tests that the remote party can properly sweep +// their output after force close. +func TestTaprootCommitScriptRemote(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootCommitScriptRemote(t, auxLeaf) + }) + } +} + type testAnchorScriptTree struct { sweepKey *btcec.PrivateKey @@ -1599,25 +1675,21 @@ type testSecondLevelHtlcTree struct { tapScriptRoot []byte } -func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { +func newTestSecondLevelHtlcTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSecondLevelHtlcTree { + delayKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) revokeKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) const csvDelay = 6 scriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey.PubKey(), csvDelay, + delayKey.PubKey(), csvDelay, auxLeaf, ) - if err != nil { - return nil, err - } + require.NoError(t, err) tapScriptRoot := scriptTree.RootNode.TapHash() @@ -1626,9 +1698,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { ) pkScript, err := PayToTaprootScript(htlcKey) - if err != nil { - return nil, err - } + require.NoError(t, err) const amt = 100 @@ -1643,7 +1713,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { amt: amt, scriptTree: scriptTree, tapScriptRoot: tapScriptRoot[:], - }, nil + } } func secondLevelHtlcSuccessWitGen(sigHash txscript.SigHashType, @@ -1713,13 +1783,8 @@ func secondLevelHtlcRevokeWitnessgen(sigHash txscript.SigHashType, } } -// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly -// spend the second level HTLC script to resolve HTLCs. -func TestTaprootSecondLevelHtlcScript(t *testing.T) { - t.Parallel() - - htlcScriptTree, err := newTestSecondLevelHtlcTree() - require.NoError(t, err) +func testTaprootSecondLevelHtlcScript(t *testing.T, auxLeaf AuxTapLeaf) { + htlcScriptTree := newTestSecondLevelHtlcTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -1879,3 +1944,23 @@ func TestTaprootSecondLevelHtlcScript(t *testing.T) { }) } } + +// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly +// spend the second level HTLC script to resolve HTLCs. +func TestTaprootSecondLevelHtlcScript(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootSecondLevelHtlcScript(t, auxLeaf) + }) + } +} diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index 6e2f0070c..a1c2e292d 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainreg" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" @@ -1192,6 +1193,7 @@ func deriveFundingShim(ht *lntest.HarnessTest, carol, dave *node.HarnessNode, _, fundingOutput, err = input.GenTaprootFundingScript( carolKey, daveKey, int64(chanSize), + fn.None[chainhash.Hash](), ) require.NoError(ht, err) diff --git a/lnd.go b/lnd.go index e483d5512..da7747e91 100644 --- a/lnd.go +++ b/lnd.go @@ -456,7 +456,8 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, defer cleanUp() partialChainControl, walletConfig, cleanUp, err := implCfg.BuildWalletConfig( - ctx, dbs, interceptorChain, grpcListeners, + ctx, dbs, &implCfg.AuxComponents, interceptorChain, + grpcListeners, ) if err != nil { return mkErr("error creating wallet config: %v", err) diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go new file mode 100644 index 000000000..4558c2f81 --- /dev/null +++ b/lnwallet/aux_leaf_store.go @@ -0,0 +1,239 @@ +package lnwallet + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// CommitSortFunc is a function type alias for a function that sorts the +// commitment transaction outputs. The second parameter is a list of CLTV +// timeouts that must correspond to the number of transaction outputs, with the +// value of 0 for non-HTLC outputs. The HTLC indexes are needed to have a +// deterministic sort value for HTLCs that have the identical amount, CLTV +// timeout and payment hash (e.g. multiple MPP shards of the same payment, where +// the on-chain script would be identical). +type CommitSortFunc func(tx *wire.MsgTx, cltvs []uint32, + indexes []input.HtlcIndex) error + +// DefaultCommitSort is the default commitment sort function that sorts the +// commitment transaction inputs and outputs according to BIP69. The second +// parameter is a list of CLTV timeouts that must correspond to the number of +// transaction outputs, with the value of 0 for non-HTLC outputs. The third +// parameter is unused for the default sort function. +func DefaultCommitSort(tx *wire.MsgTx, cltvs []uint32, + _ []input.HtlcIndex) error { + + InPlaceCommitSort(tx, cltvs) + return nil +} + +// CommitAuxLeaves stores two potential auxiliary leaves for the remote and +// local output that may be used to augment the final tapscript trees of the +// commitment transaction. +type CommitAuxLeaves struct { + // LocalAuxLeaf is the local party's auxiliary leaf. + LocalAuxLeaf input.AuxTapLeaf + + // RemoteAuxLeaf is the remote party's auxiliary leaf. + RemoteAuxLeaf input.AuxTapLeaf + + // OutgoingHTLCLeaves is the set of aux leaves for the outgoing HTLCs + // on this commitment transaction. + OutgoingHtlcLeaves input.HtlcAuxLeaves + + // IncomingHTLCLeaves is the set of aux leaves for the incoming HTLCs + // on this commitment transaction. + IncomingHtlcLeaves input.HtlcAuxLeaves +} + +// AuxChanState is a struct that holds certain fields of the +// channeldb.OpenChannel struct that are used by the aux components. The data +// is copied over to prevent accidental mutation of the original channel state. +type AuxChanState struct { + // ChanType denotes which type of channel this is. + ChanType channeldb.ChannelType + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identifies the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + // + // If IsZeroConf(), then this will the "base" (very first) ALIAS scid + // and the confirmed SCID will be stored in ConfirmedScid. + ShortChannelID lnwire.ShortChannelID + + // IsInitiator is a bool which indicates if we were the original + // initiator for the channel. This value may affect how higher levels + // negotiate fees, or close the channel. + IsInitiator bool + + // Capacity is the total capacity of this channel. + Capacity btcutil.Amount + + // LocalChanCfg is the channel configuration for the local node. + LocalChanCfg channeldb.ChannelConfig + + // RemoteChanCfg is the channel configuration for the remote node. + RemoteChanCfg channeldb.ChannelConfig + + // ThawHeight is the height when a frozen channel once again becomes a + // normal channel. If this is zero, then there're no restrictions on + // this channel. If the value is lower than 500,000, then it's + // interpreted as a relative height, or an absolute height otherwise. + ThawHeight uint32 + + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] +} + +// NewAuxChanState creates a new AuxChanState from the given channel state. +func NewAuxChanState(chanState *channeldb.OpenChannel) AuxChanState { + return AuxChanState{ + ChanType: chanState.ChanType, + FundingOutpoint: chanState.FundingOutpoint, + ShortChannelID: chanState.ShortChannelID, + IsInitiator: chanState.IsInitiator, + Capacity: chanState.Capacity, + LocalChanCfg: chanState.LocalChanCfg, + RemoteChanCfg: chanState.RemoteChanCfg, + ThawHeight: chanState.ThawHeight, + TapscriptRoot: chanState.TapscriptRoot, + CustomBlob: chanState.CustomBlob, + } +} + +// CommitDiffAuxInput is the input required to compute the diff of the auxiliary +// leaves for a commitment transaction. +type CommitDiffAuxInput struct { + // ChannelState is the static channel information of the channel this + // commitment transaction relates to. + ChannelState AuxChanState + + // PrevBlob is the blob of the previous commitment transaction. + PrevBlob tlv.Blob + + // UnfilteredView is the unfiltered, original HTLC view of the channel. + // Unfiltered in this context means that the view contains all HTLCs, + // including the canceled ones. + UnfilteredView *HtlcView + + // WhoseCommit denotes whose commitment transaction we are computing the + // diff for. + WhoseCommit lntypes.ChannelParty + + // OurBalance is the balance of the local party. + OurBalance lnwire.MilliSatoshi + + // TheirBalance is the balance of the remote party. + TheirBalance lnwire.MilliSatoshi + + // KeyRing is the key ring that can be used to derive keys for the + // commitment transaction. + KeyRing CommitmentKeyRing +} + +// CommitDiffAuxResult is the result of computing the diff of the auxiliary +// leaves for a commitment transaction. +type CommitDiffAuxResult struct { + // AuxLeaves are the auxiliary leaves for the new commitment + // transaction. + AuxLeaves fn.Option[CommitAuxLeaves] + + // CommitSortFunc is an optional function that sorts the commitment + // transaction inputs and outputs. + CommitSortFunc fn.Option[CommitSortFunc] +} + +// AuxLeafStore is used to optionally fetch auxiliary tapscript leaves for the +// commitment transaction given an opaque blob. This is also used to implement +// a state transition function for the blobs to allow them to be refreshed with +// each state. +type AuxLeafStore interface { + // FetchLeavesFromView attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and pending original (unfiltered) + // HTLC view. + FetchLeavesFromView( + in CommitDiffAuxInput) fn.Result[CommitDiffAuxResult] + + // FetchLeavesFromCommit attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and an existing channel + // commitment. + FetchLeavesFromCommit(chanState AuxChanState, + commit channeldb.ChannelCommitment, + keyRing CommitmentKeyRing) fn.Result[CommitDiffAuxResult] + + // FetchLeavesFromRevocation attempts to fetch the auxiliary leaves + // from a channel revocation that stores balance + blob information. + FetchLeavesFromRevocation( + r *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult] + + // ApplyHtlcView serves as the state transition function for the custom + // channel's blob. Given the old blob, and an HTLC view, then a new + // blob should be returned that reflects the pending updates. + ApplyHtlcView(in CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]] +} + +// auxLeavesFromView is used to derive the set of commit aux leaves (if any), +// that are needed to create a new commitment transaction using the original +// (unfiltered) htlc view. +func auxLeavesFromView(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], originalView *HtlcView, + whoseCommit lntypes.ChannelParty, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) fn.Result[CommitDiffAuxResult] { + + return fn.MapOptionZ( + prevBlob, func(blob tlv.Blob) fn.Result[CommitDiffAuxResult] { + return leafStore.FetchLeavesFromView(CommitDiffAuxInput{ + ChannelState: NewAuxChanState(chanState), + PrevBlob: blob, + UnfilteredView: originalView, + WhoseCommit: whoseCommit, + OurBalance: ourBalance, + TheirBalance: theirBalance, + KeyRing: keyRing, + }) + }, + ) +} + +// updateAuxBlob is a helper function that attempts to update the aux blob +// given the prior and current state information. +func updateAuxBlob(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView, + whoseCommit lntypes.ChannelParty, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) fn.Result[fn.Option[tlv.Blob]] { + + return fn.MapOptionZ( + prevBlob, func(blob tlv.Blob) fn.Result[fn.Option[tlv.Blob]] { + return leafStore.ApplyHtlcView(CommitDiffAuxInput{ + ChannelState: NewAuxChanState(chanState), + PrevBlob: blob, + UnfilteredView: nextViewUnfiltered, + WhoseCommit: whoseCommit, + OurBalance: ourBalance, + TheirBalance: theirBalance, + KeyRing: keyRing, + }) + }, + ) +} diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 9a90d0ab2..a6688ed39 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -178,8 +179,9 @@ func (m *mockChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress { } func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount, - localScript, remoteScript []byte, _ ...lnwallet.ChanCloseOpt, -) (input.Signature, *chainhash.Hash, btcutil.Amount, error) { + localScript, remoteScript []byte, + _ ...lnwallet.ChanCloseOpt) (input.Signature, *chainhash.Hash, + btcutil.Amount, error) { if m.chanType.IsTaproot() { return lnwallet.NewMusigPartialSig( @@ -188,6 +190,7 @@ func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount, R: new(btcec.PublicKey), }, lnwire.Musig2Nonce{}, lnwire.Musig2Nonce{}, nil, + fn.None[chainhash.Hash](), ), nil, 0, nil } diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go index 21dd47339..177cc35c3 100644 --- a/lnwallet/chanfunding/canned_assembler.go +++ b/lnwallet/chanfunding/canned_assembler.go @@ -5,7 +5,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) @@ -56,6 +58,14 @@ type ShimIntent struct { // generate an aggregate key to use as the taproot-native multi-sig // output. musig2 bool + + // tapscriptRoot is the root of the tapscript tree that will be used to + // create the funding output. This field will only be utilized if the + // MuSig2 flag above is set to true. + // + // TODO(roasbeef): fold above into new chan type? sum type like thing, + // includes the tapscript root, etc + tapscriptRoot fn.Option[chainhash.Hash] } // FundingOutput returns the witness script, and the output that creates the @@ -76,9 +86,8 @@ func (s *ShimIntent) FundingOutput() ([]byte, *wire.TxOut, error) { // Similar to the existing p2wsh script, we'll always ensure // the keys are sorted before use. return input.GenTaprootFundingScript( - s.localKey.PubKey, - s.remoteKey, - int64(totalAmt), + s.localKey.PubKey, s.remoteKey, int64(totalAmt), + s.tapscriptRoot, ) } diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go index 885fb7b46..10bcd7015 100644 --- a/lnwallet/chanfunding/psbt_assembler.go +++ b/lnwallet/chanfunding/psbt_assembler.go @@ -534,6 +534,7 @@ func (p *PsbtAssembler) ProvisionChannel(req *Request) (Intent, error) { ShimIntent: ShimIntent{ localFundingAmt: p.fundingAmt, musig2: req.Musig2, + tapscriptRoot: req.TapscriptRoot, }, State: PsbtShimRegistered, BasePsbt: p.basePsbt, diff --git a/lnwallet/chanfunding/wallet_assembler.go b/lnwallet/chanfunding/wallet_assembler.go index f824210e1..3d62649cb 100644 --- a/lnwallet/chanfunding/wallet_assembler.go +++ b/lnwallet/chanfunding/wallet_assembler.go @@ -394,7 +394,6 @@ func (w *WalletAssembler) ProvisionChannel(r *Request) (Intent, error) { // we will call the specialized coin selection function for // that. case r.FundUpToMaxAmt != 0 && r.MinFundAmt != 0: - // We need to ensure that manually selected coins, which // are spent entirely on the channel funding, leave // enough funds in the wallet to cover for a reserve. @@ -539,6 +538,7 @@ func (w *WalletAssembler) ProvisionChannel(r *Request) (Intent, error) { localFundingAmt: localContributionAmt, remoteFundingAmt: r.RemoteAmt, musig2: r.Musig2, + tapscriptRoot: r.TapscriptRoot, }, InputCoins: selectedCoins, coinLeaser: w.cfg.CoinLeaser, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7bd382d2c..e5cbe6b6a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -33,6 +33,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -333,6 +334,10 @@ type commitment struct { // on this commitment transaction. incomingHTLCs []PaymentDescriptor + // customBlob stores opaque bytes that may be used by custom channels + // to store extra data for a given commitment state. + customBlob fn.Option[tlv.Blob] + // [outgoing|incoming]HTLCIndex is an index that maps an output index // on the commitment transaction to the payment descriptor that // represents the HTLC output. @@ -506,6 +511,7 @@ func (c *commitment) toDiskCommit( CommitTx: c.txn, CommitSig: c.sig, Htlcs: make([]channeldb.HTLC, 0, numHtlcs), + CustomBlob: c.customBlob, } for _, htlc := range c.outgoingHTLCs { @@ -523,6 +529,7 @@ func (c *commitment) toDiskCommit( LogIndex: htlc.LogIndex, Incoming: false, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), } copy(h.OnionBlob[:], htlc.OnionBlob) @@ -548,8 +555,10 @@ func (c *commitment) toDiskCommit( LogIndex: htlc.LogIndex, Incoming: true, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), } copy(h.OnionBlob[:], htlc.OnionBlob) + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -566,9 +575,9 @@ func (c *commitment) toDiskCommit( // restore commitment state written to disk back into memory once we need to // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, - htlc *channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) (PaymentDescriptor, error) { + htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty, + auxLeaf input.AuxTapLeaf) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -589,10 +598,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Local, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit, ) + localCommitKeys := commitKeys.GetForParty(lntypes.Local) if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Local, htlc.RefundTimeout, htlc.RHash, localCommitKeys, + auxLeaf, ) if err != nil { return pd, err @@ -604,10 +615,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Remote, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit, ) + remoteCommitKeys := commitKeys.GetForParty(lntypes.Remote) if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Remote, htlc.RefundTimeout, htlc.RHash, remoteCommitKeys, + auxLeaf, ) if err != nil { return pd, err @@ -647,6 +660,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), }, nil } @@ -655,9 +669,10 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // these payment descriptors can be re-inserted into the in-memory updateLog // for each side. func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, - htlcs []channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) ([]PaymentDescriptor, []PaymentDescriptor, error) { + htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty, + auxLeaves fn.Option[CommitAuxLeaves]) ([]PaymentDescriptor, + []PaymentDescriptor, error) { var ( incomingHtlcs []PaymentDescriptor @@ -674,10 +689,19 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + if htlc.Incoming { + leaves = l.IncomingHtlcLeaves + } + + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + payDesc, err := lc.diskHtlcToPayDesc( - feeRate, &htlc, - localCommitKeys, remoteCommitKeys, - whoseCommit, + feeRate, &htlc, commitKeys, whoseCommit, auxLeaf, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -706,22 +730,35 @@ func (lc *LightningChannel) diskCommitToMemCommit( // (we extended but weren't able to complete the commitment dance // before shutdown), then the localCommitPoint won't be set as we // haven't yet received a responding commitment from the remote party. - var localCommitKeys, remoteCommitKeys *CommitmentKeyRing + var commitKeys lntypes.Dual[*CommitmentKeyRing] if localCommitPoint != nil { - localCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Local, DeriveCommitmentKeys( localCommitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) } if remoteCommitPoint != nil { - remoteCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Remote, DeriveCommitmentKeys( remoteCommitPoint, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) + } + + auxResult, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(lc.channelState), *diskCommit, + *commitKeys.GetForParty(whoseCommit), + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) } // With the key rings re-created, we'll now convert all the on-disk @@ -729,8 +766,7 @@ func (lc *LightningChannel) diskCommitToMemCommit( // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - whoseCommit, + diskCommit.Htlcs, commitKeys, whoseCommit, auxResult.AuxLeaves, ) if err != nil { return nil, err @@ -753,6 +789,7 @@ func (lc *LightningChannel) diskCommitToMemCommit( feePerKw: chainfee.SatPerKWeight(diskCommit.FeePerKw), incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, + customBlob: diskCommit.CustomBlob, } if whoseCommit.IsLocal() { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit @@ -797,6 +834,10 @@ type LightningChannel struct { // machine. Signer input.Signer + // leafStore is used to retrieve extra tapscript leaves for special + // custom channel types. + leafStore fn.Option[AuxLeafStore] + // signDesc is the primary sign descriptor that is capable of signing // the commitment transaction that spends the multi-sig output. signDesc *input.SignDescriptor @@ -872,6 +913,8 @@ type channelOpts struct { localNonce *musig2.Nonces remoteNonce *musig2.Nonces + leafStore fn.Option[AuxLeafStore] + skipNonceInit bool } @@ -902,6 +945,13 @@ func WithSkipNonceInit() ChannelOpt { } } +// WithLeafStore is used to specify a custom leaf store for the channel. +func WithLeafStore(store AuxLeafStore) ChannelOpt { + return func(o *channelOpts) { + o.leafStore = fn.Some[AuxLeafStore](store) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -943,13 +993,16 @@ func NewLightningChannel(signer input.Signer, } lc := &LightningChannel{ - Signer: signer, - sigPool: sigPool, - currentHeight: localCommit.CommitHeight, - remoteCommitChain: newCommitmentChain(), - localCommitChain: newCommitmentChain(), - channelState: state, - commitBuilder: NewCommitmentBuilder(state), + Signer: signer, + leafStore: opts.leafStore, + sigPool: sigPool, + currentHeight: localCommit.CommitHeight, + remoteCommitChain: newCommitmentChain(), + localCommitChain: newCommitmentChain(), + channelState: state, + commitBuilder: NewCommitmentBuilder( + state, opts.leafStore, + ), localUpdateLog: localUpdateLog, remoteUpdateLog: remoteUpdateLog, Capacity: state.Capacity, @@ -1013,6 +1066,7 @@ func (lc *LightningChannel) createSignDesc() error { if chanState.ChanType.IsTaproot() { fundingPkScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(lc.channelState.Capacity), + chanState.TapscriptRoot, ) if err != nil { return err @@ -1065,7 +1119,8 @@ func (lc *LightningChannel) ResetState() { func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, remoteUpdateLog *updateLog, commitHeight uint64, feeRate chainfee.SatPerKWeight, remoteCommitKeys *CommitmentKeyRing, - remoteDustLimit btcutil.Amount) (*PaymentDescriptor, error) { + remoteDustLimit btcutil.Amount, + auxLeaves fn.Option[CommitAuxLeaves]) (*PaymentDescriptor, error) { // Depending on the type of update message we'll map that to a distinct // PaymentDescriptor instance. @@ -1101,10 +1156,17 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[pd.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + scriptInfo, err := genHtlcScript( lc.channelState.ChanType, false, lntypes.Remote, wireMsg.Expiry, wireMsg.PaymentHash, - remoteCommitKeys, + remoteCommitKeys, auxLeaf, ) if err != nil { return nil, err @@ -1765,6 +1827,19 @@ func (lc *LightningChannel) restorePendingLocalUpdates( pendingCommit := pendingRemoteCommitDiff.Commitment pendingHeight := pendingCommit.CommitHeight + auxResult, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(lc.channelState), pendingCommit, + *pendingRemoteKeys, + ) + }, + ).Unpack() + if err != nil { + return fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // If we did have a dangling commit, then we'll examine which updates // we included in that state and re-insert them into our update log. for _, logUpdate := range pendingRemoteCommitDiff.LogUpdates { @@ -1775,6 +1850,7 @@ func (lc *LightningChannel) restorePendingLocalUpdates( chainfee.SatPerKWeight(pendingCommit.FeePerKw), pendingRemoteKeys, lc.channelState.RemoteChanCfg.DustLimit, + auxResult.AuxLeaves, ) if err != nil { return err @@ -1937,7 +2013,8 @@ type BreachRetribution struct { // required to construct the BreachRetribution. If the revocation log is missing // the required fields then ErrRevLogDataMissing will be returned. func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, - breachHeight uint32, spendTx *wire.MsgTx) (*BreachRetribution, error) { + breachHeight uint32, spendTx *wire.MsgTx, + leafStore fn.Option[AuxLeafStore]) (*BreachRetribution, error) { // Query the on-disk revocation log for the snapshot which was recorded // at this particular state num. Based on whether a legacy revocation @@ -1980,21 +2057,40 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromRevocation(revokedLog) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, remoteAuxLeaf, ) if err != nil { return nil, err } + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, theirDelay, leaseExpiry, + keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf, ) if err != nil { return nil, err @@ -2012,7 +2108,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, if revokedLog != nil { br, ourAmt, theirAmt, err = createBreachRetribution( revokedLog, spendTx, chanState, keyRing, - commitmentSecret, leaseExpiry, + commitmentSecret, leaseExpiry, auxResult.AuxLeaves, ) if err != nil { return nil, err @@ -2146,7 +2242,8 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, func createHtlcRetribution(chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitHash chainhash.Hash, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, - htlc *channeldb.HTLCEntry) (HtlcRetribution, error) { + htlc *channeldb.HTLCEntry, + auxLeaves fn.Option[CommitAuxLeaves]) (HtlcRetribution, error) { var emptyRetribution HtlcRetribution @@ -2156,10 +2253,24 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { + return fn.MapOption(func(val uint16) input.AuxTapLeaf { + idx := input.HtlcIndex(val) + + if htlc.Incoming.Val { + leaves := l.IncomingHtlcLeaves[idx] + return leaves.SecondLevelLeaf + } + + return l.OutgoingHtlcLeaves[idx].SecondLevelLeaf + })(htlc.HtlcIndex.ValOpt()) + }, + )(auxLeaves) secondLevelScript, err := SecondLevelHtlcScript( chanState.ChanType, isRemoteInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, theirDelay, - leaseExpiry, + leaseExpiry, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return emptyRetribution, err @@ -2170,9 +2281,24 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. + htlcLeaf := fn.ChainOption( + func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { + return fn.MapOption(func(val uint16) input.AuxTapLeaf { + idx := input.HtlcIndex(val) + + if htlc.Incoming.Val { + leaves := l.IncomingHtlcLeaves[idx] + return leaves.AuxTapLeaf + } + + return l.OutgoingHtlcLeaves[idx].AuxTapLeaf + })(htlc.HtlcIndex.ValOpt()) + }, + )(auxLeaves) scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, lntypes.Remote, - htlc.RefundTimeout, htlc.RHash, keyRing, + chanState.ChanType, htlc.Incoming.Val, lntypes.Remote, + htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, + fn.FlattenOption(htlcLeaf), ) if err != nil { return emptyRetribution, err @@ -2185,7 +2311,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, WitnessScript: scriptInfo.WitnessScriptToSign(), Output: &wire.TxOut{ PkScript: scriptInfo.PkScript(), - Value: int64(htlc.Amt), + Value: int64(htlc.Amt.Val.Int()), }, HashType: sweepSigHash(chanState.ChanType), } @@ -2218,10 +2344,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, SignDesc: signDesc, OutPoint: wire.OutPoint{ Hash: commitHash, - Index: uint32(htlc.OutputIndex), + Index: uint32(htlc.OutputIndex.Val), }, SecondLevelWitnessScript: secondLevelWitnessScript, - IsIncoming: htlc.Incoming, + IsIncoming: htlc.Incoming.Val, SecondLevelTapTweak: secondLevelTapTweak, }, nil } @@ -2236,7 +2362,9 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, func createBreachRetribution(revokedLog *channeldb.RevocationLog, spendTx *wire.MsgTx, chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, - leaseExpiry uint32) (*BreachRetribution, int64, int64, error) { + leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64, + error) { commitHash := revokedLog.CommitTxHash @@ -2244,8 +2372,8 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, htlcRetributions := make([]HtlcRetribution, len(revokedLog.HTLCEntries)) for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( - chanState, keyRing, commitHash, - commitmentSecret, leaseExpiry, htlc, + chanState, keyRing, commitHash.Val, + commitmentSecret, leaseExpiry, htlc, auxLeaves, ) if err != nil { return nil, 0, 0, err @@ -2257,10 +2385,10 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // Construct the our outpoint. ourOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.OurOutputIndex != channeldb.OutputIndexEmpty { - ourOutpoint.Index = uint32(revokedLog.OurOutputIndex) + if revokedLog.OurOutputIndex.Val != channeldb.OutputIndexEmpty { + ourOutpoint.Index = uint32(revokedLog.OurOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of our output. @@ -2283,26 +2411,29 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains our output amount. Due to a previous // migration, this field may be empty in which case an // error will be returned. - if revokedLog.OurBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.OurBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - ourAmt = int64(revokedLog.OurBalance.ToSatoshis()) + ourAmt = int64(b.Int().ToSatoshis()) } } // Construct the their outpoint. theirOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.TheirOutputIndex != channeldb.OutputIndexEmpty { - theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex) + if revokedLog.TheirOutputIndex.Val != channeldb.OutputIndexEmpty { + theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of the remote parties' output. if spendTx != nil { // Sanity check that TheirOutputIndex is within range. - if int(revokedLog.TheirOutputIndex) >= + if int(revokedLog.TheirOutputIndex.Val) >= len(spendTx.TxOut) { return nil, 0, 0, fmt.Errorf("%w: theirs=%v, "+ @@ -2320,16 +2451,19 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains remote parties' output amount. Due to a // previous migration, this field may be empty in which // case an error will be returned. - if revokedLog.TheirBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.TheirBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - theirAmt = int64(revokedLog.TheirBalance.ToSatoshis()) + theirAmt = int64(b.Int().ToSatoshis()) } } return &BreachRetribution{ - BreachTxHash: commitHash, + BreachTxHash: commitHash.Val, ChainHash: chanState.ChainHash, LocalOutpoint: ourOutpoint, RemoteOutpoint: theirOutpoint, @@ -2383,16 +2517,15 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := &channeldb.HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - OutputIndex: uint16(htlc.OutputIndex), - Incoming: htlc.Incoming, - Amt: htlc.Amt.ToSatoshis(), + entry, err := channeldb.NewHTLCEntryFromHTLC(htlc) + if err != nil { + return nil, 0, 0, err } + hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, + fn.None[CommitAuxLeaves](), ) if err != nil { return nil, 0, 0, err @@ -2458,18 +2591,29 @@ func HtlcIsDust(chanType channeldb.ChannelType, return (htlcAmt - htlcFee) < dustLimit } -// htlcView represents the "active" HTLCs at a particular point within the +// HtlcView represents the "active" HTLCs at a particular point within the // history of the HTLC update log. -type htlcView struct { - ourUpdates []*PaymentDescriptor - theirUpdates []*PaymentDescriptor - feePerKw chainfee.SatPerKWeight +type HtlcView struct { + // NextHeight is the height of the commitment transaction that will be + // created using this view. + NextHeight uint64 + + // OurUpdates are our outgoing HTLCs. + OurUpdates []*PaymentDescriptor + + // TheirUpdates are their incoming HTLCs. + TheirUpdates []*PaymentDescriptor + + // FeePerKw is the fee rate in sat/kw of the commitment transaction. + FeePerKw chainfee.SatPerKWeight } // fetchHTLCView returns all the candidate HTLC updates which should be // considered for inclusion within a commitment based on the passed HTLC log // indexes. -func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView { +func (lc *LightningChannel) fetchHTLCView(theirLogIndex, + ourLogIndex uint64) *HtlcView { + var ourHTLCs []*PaymentDescriptor for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { htlc := e.Value @@ -2494,9 +2638,9 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht } } - return &htlcView{ - ourUpdates: ourHTLCs, - theirUpdates: theirHTLCs, + return &HtlcView{ + OurUpdates: ourHTLCs, + TheirUpdates: theirHTLCs, } } @@ -2533,12 +2677,16 @@ func (lc *LightningChannel) fetchCommitmentView( if err != nil { return nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw + + htlcView.NextHeight = nextHeight + filteredHTLCView.NextHeight = nextHeight // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( ourBalance, theirBalance, whoseCommitChain, feePerKw, - nextHeight, filteredHTLCView, keyRing, + nextHeight, htlcView, filteredHTLCView, keyRing, + commitChain.tip(), ) if err != nil { return nil, err @@ -2573,6 +2721,23 @@ func (lc *LightningChannel) fetchCommitmentView( effFeeRate, spew.Sdump(commitTx)) } + // Given the custom blob of the past state, and this new HTLC view, + // we'll generate a new blob for the latest commitment. + newCommitBlob, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[fn.Option[tlv.Blob]] { + return updateAuxBlob( + s, lc.channelState, + commitChain.tip().customBlob, htlcView, + whoseCommitChain, ourBalance, theirBalance, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the commitment view created, store the resulting balances and // transaction with the other parameters for this height. c := &commitment{ @@ -2588,17 +2753,22 @@ func (lc *LightningChannel) fetchCommitmentView( feePerKw: feePerKw, dustLimit: dustLimit, whoseCommit: whoseCommitChain, + customBlob: newCommitBlob, } // In order to ensure _none_ of the HTLC's associated with this new // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. - c.outgoingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.ourUpdates)) - for i, htlc := range filteredHTLCView.ourUpdates { + c.outgoingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.OurUpdates), + ) + for i, htlc := range filteredHTLCView.OurUpdates { c.outgoingHTLCs[i] = *htlc } - c.incomingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.theirUpdates)) - for i, htlc := range filteredHTLCView.theirUpdates { + c.incomingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.TheirUpdates), + ) + for i, htlc := range filteredHTLCView.TheirUpdates { c.incomingHTLCs[i] = *htlc } @@ -2633,16 +2803,17 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // once for each height, and only in concert with signing a new commitment. // TODO(halseth): return htlcs to mutate instead of mutating inside // method. -func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, +func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool, -) (*htlcView, error) { + whoseCommitChain lntypes.ChannelParty, mutateState bool) (*HtlcView, + error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will // be updated. - newView := &htlcView{ - feePerKw: view.feePerKw, + newView := &HtlcView{ + FeePerKw: view.FeePerKw, + NextHeight: nextHeight, } // We use two maps, one for the local log and one for the remote log to @@ -2655,7 +2826,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets and mutating the current chain state (crediting balances, // etc) to reflect the settle/timeout entry encountered. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2676,6 +2847,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, if mutateState && entry.EntryType == Settle && whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatReceived += entry.Amount } @@ -2687,10 +2859,13 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipThem[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, whoseCommitChain, true, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, true, mutateState, + ) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2724,32 +2899,41 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipUs[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, whoseCommitChain, false, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, false, mutateState, + ) } // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { isAdd := entry.EntryType == Add if _, ok := skipUs[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, false, mutateState) - newView.ourUpdates = append(newView.ourUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, false, mutateState, + ) + + newView.OurUpdates = append(newView.OurUpdates, entry) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, true, mutateState) - newView.theirUpdates = append(newView.theirUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, true, mutateState, + ) + + newView.TheirUpdates = append(newView.TheirUpdates, entry) } return newView, nil @@ -2900,8 +3084,8 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool, view *htlcView, -) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, + view *HtlcView) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the @@ -2922,7 +3106,7 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // If the update wasn't already locked in, update the current fee rate // to reflect this update. - view.feePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) + view.FeePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) if mutateState { *addHeight = nextHeight @@ -2937,9 +3121,16 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // signature can be submitted to the sigPool to generate all the signatures // asynchronously and in parallel. func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, - chanType channeldb.ChannelType, isRemoteInitiator bool, - leaseExpiry uint32, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, - remoteCommitView *commitment) ([]SignJob, chan struct{}, error) { + chanState *channeldb.OpenChannel, leaseExpiry uint32, + remoteCommitView *commitment, + leafStore fn.Option[AuxLeafStore]) ([]SignJob, chan struct{}, error) { + + var ( + isRemoteInitiator = !chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + remoteChanCfg = chanState.RemoteChanCfg + chanType = chanState.ChanType + ) txHash := remoteCommitView.txn.TxHash() dustLimit := remoteChanCfg.DustLimit @@ -2949,13 +3140,27 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // With the keys generated, we'll make a slice with enough capacity to // hold potentially all the HTLCs. The actual slice may be a bit // smaller (than its total capacity) and some HTLCs may be dust. - numSigs := (len(remoteCommitView.incomingHTLCs) + - len(remoteCommitView.outgoingHTLCs)) + numSigs := len(remoteCommitView.incomingHTLCs) + + len(remoteCommitView.outgoingHTLCs) sigBatch := make([]SignJob, 0, numSigs) var err error cancelChan := make(chan struct{}) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + *remoteCommitView.toDiskCommit(lntypes.Remote), + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch aux leaves: %w", + err) + } + // For each outgoing and incoming HTLC, if the HTLC isn't considered a // dust output after taking into account second-level HTLC fees, then a // sigJob will be generated and appended to the current batch. @@ -2982,6 +3187,13 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxResult.AuxLeaves) + // With the fee calculate, we can properly create the HTLC // timeout transaction using the HTLC amount minus the fee. op := wire.OutPoint{ @@ -2992,11 +3204,15 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, htlc.Timeout, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + auxLeaf, ) if err != nil { return nil, nil, err } + // TODO(roasbeef): hook up signer interface here (later commit + // in this PR). + // Construct a full hash cache as we may be signing a segwit v1 // sighash. txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex] @@ -3023,7 +3239,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // If this is a taproot channel, then we'll need to set the // method type to ensure we generate a valid signature. if chanType.IsTaproot() { - sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod //nolint:lll + //nolint:lll + sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod } sigBatch = append(sigBatch, sigJob) @@ -3049,6 +3266,13 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxResult.AuxLeaves) + // With the proper output amount calculated, we can now // generate the success transaction using the remote party's // CSV delay. @@ -3060,6 +3284,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + auxLeaf, ) if err != nil { return nil, nil, err @@ -3105,9 +3330,9 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // validate this new state. This function is called right before sending the // new commitment to the remote party. The commit diff returned contains all // information necessary for retransmission. -func (lc *LightningChannel) createCommitDiff( - newCommit *commitment, commitSig lnwire.Sig, - htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) { +func (lc *LightningChannel) createCommitDiff(newCommit *commitment, + commitSig lnwire.Sig, htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, + error) { // First, we need to convert the funding outpoint into the ID that's // used on the wire to identify this channel. We'll use this shortly @@ -3498,10 +3723,10 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.ourUpdates = append(view.ourUpdates, predictOurAdd) + view.OurUpdates = append(view.OurUpdates, predictOurAdd) } if predictTheirAdd != nil { - view.theirUpdates = append(view.theirUpdates, predictTheirAdd) + view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3512,7 +3737,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, return err } - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw // Ensure that the fee being applied is enough to be relayed across the // network in a reasonable time frame. @@ -3656,7 +3881,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.theirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -3665,7 +3890,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.ourUpdates, &lc.channelState.LocalChanCfg, + filteredView.OurUpdates, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -3806,9 +4031,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { leaseExpiry = lc.channelState.ThawHeight } sigBatch, cancelChan, err := genRemoteHtlcSigJobs( - keyRing, lc.channelState.ChanType, !lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, newCommitView, + keyRing, lc.channelState, leaseExpiry, newCommitView, + lc.leafStore, ) if err != nil { return nil, err @@ -4265,7 +4489,7 @@ func (lc *LightningChannel) ProcessChanSyncMsg( return updates, openedCircuits, closedCircuits, nil } -// computeView takes the given htlcView, and calculates the balances, filtered +// computeView takes the given HtlcView, and calculates the balances, filtered // view (settling unsettled HTLCs), commitment weight and feePerKw, after // applying the HTLCs to the latest commitment. The returned balances are the // balances *before* subtracting the commitment fee from the initiator's @@ -4274,10 +4498,10 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, +func (lc *LightningChannel) computeView(view *HtlcView, whoseCommitChain lntypes.ChannelParty, updateState bool, dryRunFee fn.Option[chainfee.SatPerKWeight]) (lnwire.MilliSatoshi, - lnwire.MilliSatoshi, lntypes.WeightUnit, *htlcView, error) { + lnwire.MilliSatoshi, lntypes.WeightUnit, *HtlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit @@ -4308,7 +4532,7 @@ func (lc *LightningChannel) computeView(view *htlcView, // Initiate feePerKw to the last committed fee for this chain as we'll // need this to determine which HTLCs are dust, and also the final fee // rate. - view.feePerKw = commitChain.tip().feePerKw + view.FeePerKw = commitChain.tip().feePerKw // We evaluate the view at this stage, meaning settled and failed HTLCs // will remove their corresponding added HTLCs. The resulting filtered @@ -4316,12 +4540,14 @@ func (lc *LightningChannel) computeView(view *htlcView, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, whoseCommitChain, updateState) + filteredHTLCView, err := lc.evaluateHTLCView( + view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, + updateState, + ) if err != nil { return 0, 0, 0, nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw // Here we override the view's fee-rate if a dry-run fee-rate was // passed in. @@ -4345,7 +4571,7 @@ func (lc *LightningChannel) computeView(view *htlcView, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight lntypes.WeightUnit - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4356,7 +4582,7 @@ func (lc *LightningChannel) computeView(view *htlcView, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4377,10 +4603,18 @@ func (lc *LightningChannel) computeView(view *htlcView, // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent // directly into the pool of workers. -func genHtlcSigValidationJobs(localCommitmentView *commitment, - keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, - chanType channeldb.ChannelType, isLocalInitiator bool, leaseExpiry uint32, - localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]VerifyJob, error) { +// +//nolint:funlen +func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, + localCommitmentView *commitment, keyRing *CommitmentKeyRing, + htlcSigs []lnwire.Sig, leaseExpiry uint32, + leafStore fn.Option[AuxLeafStore]) ([]VerifyJob, error) { + + var ( + isLocalInitiator = chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + chanType = chanState.ChanType + ) txHash := localCommitmentView.txn.TxHash() feePerKw := localCommitmentView.feePerKw @@ -4390,10 +4624,24 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, // enough capacity to hold verification jobs for all HTLC's in this // view. In the case that we have some dust outputs, then the actual // length will be smaller than the total capacity. - numHtlcs := (len(localCommitmentView.incomingHTLCs) + - len(localCommitmentView.outgoingHTLCs)) + numHtlcs := len(localCommitmentView.incomingHTLCs) + + len(localCommitmentView.outgoingHTLCs) verifyJobs := make([]VerifyJob, 0, numHtlcs) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + *localCommitmentView.toDiskCommit( + lntypes.Local, + ), *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // We'll iterate through each output in the commitment transaction, // populating the sigHash closure function if it's detected to be an // HLTC output. Given the sighash, and the signing key, we'll be able @@ -4427,11 +4675,19 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxResult.AuxLeaves) + successTx, err := CreateHtlcSuccessTx( chanType, isLocalInitiator, op, outputAmt, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, auxLeaf, ) if err != nil { return nil, err @@ -4511,12 +4767,20 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxResult.AuxLeaves) + timeoutTx, err := CreateHtlcTimeoutTx( chanType, isLocalInitiator, op, outputAmt, htlc.Timeout, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, auxLeaf, ) if err != nil { return nil, err @@ -4770,10 +5034,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { leaseExpiry = lc.channelState.ThawHeight } verifyJobs, err := genHtlcSigValidationJobs( - localCommitmentView, keyRing, commitSigs.HtlcSigs, - lc.channelState.ChanType, lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, + lc.channelState, localCommitmentView, keyRing, + commitSigs.HtlcSigs, leaseExpiry, lc.leafStore, ) if err != nil { return err @@ -5357,7 +5619,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // before the change since the indexes are meant for the current, // revoked remote commitment. ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote( - revocation, lc.channelState, + revocation, lc.channelState, lc.leafStore, ) if err != nil { return nil, nil, nil, nil, err @@ -5647,6 +5909,9 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, + // TODO(guggero): Add custom records from HTLC here once we have + // the custom records in the HTLC struct (later commits in this + // PR). } } @@ -5687,7 +5952,9 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // ReceiveHTLC adds an HTLC to the state machine's remote update log. This // method should be called in response to receiving a new HTLC from the remote // party. -func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, error) { +func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, + error) { + lc.Lock() defer lc.Unlock() @@ -5705,6 +5972,9 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err HtlcIndex: lc.remoteUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], BlindingPoint: htlc.BlindingPoint, + // TODO(guggero): Add custom records from HTLC here once we have + // the custom records in the HTLC struct (later commits in this + // PR). } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex @@ -6056,11 +6326,15 @@ func (lc *LightningChannel) getSignedCommitTx() (*wire.MsgTx, error) { "verification nonce: %w", err) } + tapscriptTweak := fn.MapOption(TapscriptRootToTweak)( + lc.channelState.TapscriptRoot, + ) + // Now that we have the local nonce, we'll re-create the musig // session we had for this height. musigSession := NewPartialMusigSession( *localNonce, ourKey, theirKey, lc.Signer, - &lc.fundingOutput, LocalMusigCommit, + &lc.fundingOutput, LocalMusigCommit, tapscriptTweak, ) var remoteSig lnwire.PartialSigWithNonce @@ -6208,10 +6482,10 @@ type UnilateralCloseSummary struct { // happen in case we have lost state) it should be set to an empty struct, in // which case we will attempt to sweep the non-HTLC output using the passed // commitPoint. -func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, - commitSpend *chainntnfs.SpendDetail, - remoteCommit channeldb.ChannelCommitment, - commitPoint *btcec.PublicKey) (*UnilateralCloseSummary, error) { +func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, + signer input.Signer, commitSpend *chainntnfs.SpendDetail, + remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, + leafStore fn.Option[AuxLeafStore]) (*UnilateralCloseSummary, error) { // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. @@ -6221,6 +6495,18 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), remoteCommit, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Next, we'll obtain HTLC resolutions for all the outgoing HTLC's we // had on their commitment transaction. var leaseExpiry uint32 @@ -6233,6 +6519,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, chanState.ChanType, isRemoteInitiator, leaseExpiry, + auxResult.AuxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to create htlc "+ @@ -6244,9 +6531,14 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, remoteAuxLeaf, ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -6486,7 +6778,8 @@ func newOutgoingHtlcResolution(signer input.Signer, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { + chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -6495,9 +6788,12 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. + auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) htlcScriptInfo, err := genHtlcScript( chanType, false, whoseCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, auxLeaf, ) if err != nil { return nil, err @@ -6570,10 +6866,16 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) timeoutTx, err := CreateHtlcTimeoutTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, - htlc.RefundTimeout, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + htlc.RefundTimeout, csvDelay, leaseExpiry, + keyRing.RevocationKey, keyRing.ToLocalKey, secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6656,6 +6958,7 @@ func newOutgoingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6664,6 +6967,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6738,8 +7042,8 @@ func newIncomingHtlcResolution(signer input.Signer, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType) ( - *IncomingHtlcResolution, error) { + chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*IncomingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -6748,9 +7052,12 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. + auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( chanType, true, whoseCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, auxLeaf, ) if err != nil { return nil, err @@ -6810,6 +7117,13 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // Otherwise, we'll need to go to the second level to sweep this HTLC. // // First, we'll reconstruct the original HTLC success transaction, @@ -6819,7 +7133,7 @@ func newIncomingHtlcResolution(signer input.Signer, successTx, err := CreateHtlcSuccessTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6902,6 +7216,7 @@ func newIncomingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6910,6 +7225,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -7002,7 +7318,8 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, - isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { + isCommitFromInitiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*HtlcResolutions, error) { // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit @@ -7035,8 +7352,9 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, // as we can satisfy the contract. ihr, err := newIncomingHtlcResolution( signer, localChanCfg, commitTx, &htlc, - keyRing, feePerKw, uint32(csvDelay), leaseExpiry, - whoseCommit, isCommitFromInitiator, chanType, + keyRing, feePerKw, uint32(csvDelay), + leaseExpiry, whoseCommit, isCommitFromInitiator, + chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7050,7 +7368,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, whoseCommit, - isCommitFromInitiator, chanType, + isCommitFromInitiator, chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("outgoing resolution "+ @@ -7150,7 +7468,7 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { localCommitment := lc.channelState.LocalCommitment summary, err := NewLocalForceCloseSummary( lc.channelState, lc.Signer, commitTx, - localCommitment.CommitHeight, + localCommitment.CommitHeight, lc.leafStore, ) if err != nil { return nil, fmt.Errorf("unable to gen force close "+ @@ -7167,8 +7485,8 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { // channel state. The passed commitTx must be a fully signed commitment // transaction corresponding to localCommit. func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, - signer input.Signer, commitTx *wire.MsgTx, stateNum uint64) ( - *LocalForceCloseSummary, error) { + signer input.Signer, commitTx *wire.MsgTx, stateNum uint64, + leafStore fn.Option[AuxLeafStore]) (*LocalForceCloseSummary, error) { // Re-derive the original pkScript for to-self output within the // commitment transaction. We'll need this to find the corresponding @@ -7189,13 +7507,31 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + chanState.LocalCommitment, *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight } + + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, csvTimeout, leaseExpiry, + keyRing.RevocationKey, csvTimeout, leaseExpiry, localAuxLeaf, ) if err != nil { return nil, err @@ -7286,7 +7622,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, chainfee.SatPerKWeight(localCommit.FeePerKw), lntypes.Local, signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, - chanState.IsInitiator, leaseExpiry, + chanState.IsInitiator, leaseExpiry, auxResult.AuxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to gen htlc resolution: %w", err) @@ -7816,13 +8152,13 @@ func (lc *LightningChannel) availableBalance( } // availableCommitmentBalance attempts to calculate the balance we have -// available for HTLCs on the local/remote commitment given the htlcView. To +// available for HTLCs on the local/remote commitment given the HtlcView. To // account for sending HTLCs of different sizes, it will report the balance // available for sending non-dust HTLCs, which will be manifested on the // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, +func (lc *LightningChannel) availableCommitmentBalance(view *HtlcView, whoseCommitChain lntypes.ChannelParty, buffer BufferType) ( lnwire.MilliSatoshi, lntypes.WeightUnit) { @@ -7852,7 +8188,7 @@ func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, // Calculate the commitment fee in the case where we would add another // HTLC to the commitment, as only the balance remaining after this fee // has been paid is actually available for sending. - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw additionalHtlcFee := lnwire.NewMSatFromSatoshis( feePerKw.FeeForWeight(input.HTLCWeight), ) @@ -8671,12 +9007,13 @@ func (lc *LightningChannel) InitRemoteMusigNonces(remoteNonce *musig2.Nonces, // TODO(roasbeef): propagate rename of signing and verification nonces sessionCfg := &MusigSessionCfg{ - LocalKey: localChanCfg.MultiSigKey, - RemoteKey: remoteChanCfg.MultiSigKey, - LocalNonce: *localNonce, - RemoteNonce: *remoteNonce, - Signer: lc.Signer, - InputTxOut: &lc.fundingOutput, + LocalKey: localChanCfg.MultiSigKey, + RemoteKey: remoteChanCfg.MultiSigKey, + LocalNonce: *localNonce, + RemoteNonce: *remoteNonce, + Signer: lc.Signer, + InputTxOut: &lc.fundingOutput, + TapscriptTweak: lc.channelState.TapscriptRoot, } lc.musigSessions = NewMusigPairSession( sessionCfg, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 757a808fe..de6bb219c 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2,6 +2,7 @@ package lnwallet import ( "bytes" + crand "crypto/rand" "crypto/sha256" "fmt" "math/rand" @@ -386,6 +387,12 @@ func TestSimpleAddSettleWorkflow(t *testing.T) { ) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + flags := channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit + testAddSettleWorkflow(t, true, flags, false) + }) + t.Run("storeFinalHtlcResolutions=true", func(t *testing.T) { testAddSettleWorkflow(t, false, 0, true) }) @@ -828,6 +835,16 @@ func TestForceClose(t *testing.T) { anchorAmt: AnchorSize * 2, }) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + testForceClose(t, &forceCloseTestCase{ + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + expectedCommitWeight: input.TaprootCommitWeight, + anchorAmt: AnchorSize * 2, + }) + }) } type forceCloseTestCase struct { @@ -5678,6 +5695,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -5827,6 +5845,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -5844,6 +5863,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceRemoteChainTip.Commitment, aliceChannel.channelState.RemoteNextRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -6724,6 +6744,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { breachTx := aliceChannel.channelState.RemoteCommitment.CommitTx breachRet, err := NewBreachRetribution( aliceChannel.channelState, revokedStateNum, 100, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") @@ -8525,10 +8546,10 @@ func TestEvaluateView(t *testing.T) { } } - view := &htlcView{ - ourUpdates: test.ourHtlcs, - theirUpdates: test.theirHtlcs, - feePerKw: feePerKw, + view := &HtlcView{ + OurUpdates: test.ourHtlcs, + TheirUpdates: test.theirHtlcs, + FeePerKw: feePerKw, } var ( @@ -8549,17 +8570,17 @@ func TestEvaluateView(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if result.feePerKw != test.expectedFee { + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", - test.expectedFee, result.feePerKw) + test.expectedFee, result.FeePerKw) } checkExpectedHtlcs( - t, result.ourUpdates, test.ourExpectedHtlcs, + t, result.OurUpdates, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.theirUpdates, test.theirExpectedHtlcs, + t, result.TheirUpdates, test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { @@ -8782,15 +8803,15 @@ func TestProcessFeeUpdate(t *testing.T) { EntryType: FeeUpdate, } - view := &htlcView{ - feePerKw: chainfee.SatPerKWeight(feePerKw), + view := &HtlcView{ + FeePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( update, nextHeight, test.whoseCommitChain, test.mutate, view, ) - if view.feePerKw != test.expectedFee { + if view.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, feePerKw) } @@ -9940,15 +9961,17 @@ func TestCreateHtlcRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: testAmt, - Incoming: true, - OutputIndex: 1, + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(testAmt), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1), } // Create the htlc retribution. hr, err := createHtlcRetribution( aliceChannel.channelState, keyRing, commitHash, - dummyPrivate, leaseExpiry, htlc, + dummyPrivate, leaseExpiry, htlc, fn.None[CommitAuxLeaves](), ) // Expect no error. require.NoError(t, err) @@ -9956,8 +9979,8 @@ func TestCreateHtlcRetribution(t *testing.T) { // Check the fields have expected values. require.EqualValues(t, testAmt, hr.SignDesc.Output.Value) require.Equal(t, commitHash, hr.OutPoint.Hash) - require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } // TestCreateBreachRetribution checks that `createBreachRetribution` behaves as @@ -9997,30 +10020,31 @@ func TestCreateBreachRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: btcutil.Amount(testAmt), - Incoming: true, - OutputIndex: uint16(htlcIndex), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(testAmt)), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlcIndex), + ), } // Create a dummy revocation log. ourAmtMsat := lnwire.MilliSatoshi(ourAmt * 1000) theirAmtMsat := lnwire.MilliSatoshi(theirAmt * 1000) - revokedLog := channeldb.RevocationLog{ - CommitTxHash: commitHash, - OurOutputIndex: uint16(localIndex), - TheirOutputIndex: uint16(remoteIndex), - HTLCEntries: []*channeldb.HTLCEntry{htlc}, - TheirBalance: &theirAmtMsat, - OurBalance: &ourAmtMsat, - } + revokedLog := channeldb.NewRevocationLog( + uint16(localIndex), uint16(remoteIndex), commitHash, + fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), + []*channeldb.HTLCEntry{htlc}, fn.None[tlv.Blob](), + ) // Create a log with an empty local output index. revokedLogNoLocal := revokedLog - revokedLogNoLocal.OurOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoLocal.OurOutputIndex.Val = channeldb.OutputIndexEmpty // Create a log with an empty remote output index. revokedLogNoRemote := revokedLog - revokedLogNoRemote.TheirOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoRemote.TheirOutputIndex.Val = channeldb.OutputIndexEmpty testCases := []struct { name string @@ -10050,14 +10074,20 @@ func TestCreateBreachRetribution(t *testing.T) { { name: "fail due to our index too big", revocationLog: &channeldb.RevocationLog{ - OurOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, { name: "fail due to their index too big", revocationLog: &channeldb.RevocationLog{ - TheirOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, @@ -10126,11 +10156,12 @@ func TestCreateBreachRetribution(t *testing.T) { require.Equal(t, remote, br.RemoteOutpoint) for _, hr := range br.HtlcRetributions { - require.EqualValues(t, testAmt, - hr.SignDesc.Output.Value) + require.EqualValues( + t, testAmt, hr.SignDesc.Output.Value, + ) require.Equal(t, commitHash, hr.OutPoint.Hash) require.EqualValues(t, htlcIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } } @@ -10146,6 +10177,7 @@ func TestCreateBreachRetribution(t *testing.T) { tc.revocationLog, tx, aliceChannel.channelState, keyRing, dummyPrivate, leaseExpiry, + fn.None[CommitAuxLeaves](), ) // Check the error if expected. @@ -10264,6 +10296,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10271,6 +10304,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10316,6 +10350,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // successfully. br, err := NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err) @@ -10327,6 +10362,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // since the necessary info should now be found in the revocation log. br, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err) assertRetribution(br, 1, 0) @@ -10335,6 +10371,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) @@ -10342,6 +10379,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, nil, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) } @@ -10379,7 +10417,8 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, htlcs, nil, nil, lntypes.Local, + 0, htlcs, lntypes.Dual[*CommitmentKeyRing]{}, lntypes.Local, + fn.None[CommitAuxLeaves](), ) require.NoError(t, err) @@ -10415,19 +10454,26 @@ func assertPayDescMatchHTLC(t *testing.T, pd PaymentDescriptor, // the `Incoming`. func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { var onionBlob [lnwire.OnionPacketSize]byte - _, err := rand.Read(onionBlob[:]) + _, err := crand.Read(onionBlob[:]) require.NoError(t, err) var rHash [lntypes.HashSize]byte - _, err = rand.Read(rHash[:]) + _, err = crand.Read(rHash[:]) require.NoError(t, err) sig := make([]byte, 64) - _, err = rand.Read(sig) + _, err = crand.Read(sig) require.NoError(t, err) + randCustomData := make([]byte, 32) + _, err = crand.Read(randCustomData) + require.NoError(t, err) + + randCustomType := rand.Intn(255) + lnwire.MinCustomRecordsTlvType + blinding, err := pubkeyFromHex( - "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48" + + "236c39", ) require.NoError(t, err) @@ -10442,9 +10488,13 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { HtlcIndex: rand.Uint64(), LogIndex: rand.Uint64(), BlindingPoint: tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + blinding, + ), ), + CustomRecords: map[uint64][]byte{ + uint64(randCustomType): randCustomData, + }, } } diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index f32408814..2ff23ab63 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -226,8 +227,7 @@ func (w *WitnessScriptDesc) WitnessScriptForPath( // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, -) ( - input.ScriptDescriptor, error) { + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -237,7 +237,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, + csvDelay, selfKey, revokeKey, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -291,8 +291,8 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // script for. The second return value is the CSV delay of the output script, // what must be satisfied in order to spend the output. func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, - remoteKey *btcec.PublicKey, - leaseExpiry uint32) (input.ScriptDescriptor, uint32, error) { + remoteKey *btcec.PublicKey, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, uint32, error) { switch { // If we are not the initiator of a leased channel, then the remote @@ -321,7 +321,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, + remoteKey, auxLeaf, ) if err != nil { return nil, 0, err @@ -420,14 +420,14 @@ func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType { // argument should correspond to the owner of the commitment transaction which // we are generating the to_local script for. func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, - revocationKey, delayKey *btcec.PublicKey, - csvDelay, leaseExpiry uint32) (input.ScriptDescriptor, error) { + revocationKey, delayKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, + revocationKey, delayKey, csvDelay, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -613,7 +613,7 @@ func CommitScriptAnchors(chanType channeldb.ChannelType, // with, and abstracts the various ways of constructing commitment // transactions. type CommitmentBuilder struct { - // chanState is the underlying channels's state struct, used to + // chanState is the underlying channel's state struct, used to // determine the type of channel we are dealing with, and relevant // parameters. chanState *channeldb.OpenChannel @@ -621,18 +621,25 @@ type CommitmentBuilder struct { // obfuscator is a 48-bit state hint that's used to obfuscate the // current state number on the commitment transactions. obfuscator [StateHintSize]byte + + // auxLeafStore is an interface that allows us to fetch auxiliary + // tapscript leaves for the commitment output. + auxLeafStore fn.Option[AuxLeafStore] } // NewCommitmentBuilder creates a new CommitmentBuilder from chanState. -func NewCommitmentBuilder(chanState *channeldb.OpenChannel) *CommitmentBuilder { +func NewCommitmentBuilder(chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder { + // The anchor channel type MUST be tweakless. if chanState.ChanType.HasAnchors() && !chanState.ChanType.IsTweakless() { panic("invalid channel type combination") } return &CommitmentBuilder{ - chanState: chanState, - obfuscator: createStateHintObfuscator(chanState), + chanState: chanState, + obfuscator: createStateHintObfuscator(chanState), + auxLeafStore: leafStore, } } @@ -685,9 +692,9 @@ type unsignedCommitmentTx struct { // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty, - feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *htlcView, - keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { + feePerKw chainfee.SatPerKWeight, height uint64, originalHtlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing, + prevCommit *commitment) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit if whoseCommit.IsRemote() { @@ -695,7 +702,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -706,7 +713,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -748,10 +755,24 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance -= commitFeeMSat } - var ( - commitTx *wire.MsgTx - err error - ) + var commitTx *wire.MsgTx + + // Before we create the commitment transaction below, we'll try to see + // if there're any aux leaves that need to be a part of the tapscript + // tree. We'll only do this if we have a custom blob defined though. + auxResult, err := fn.MapOptionZ( + cb.auxLeafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return auxLeavesFromView( + s, cb.chanState, prevCommit.customBlob, + originalHtlcView, whoseCommit, ourBalance, + theirBalance, *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } // Depending on whether the transaction is ours or not, we call // CreateCommitTx with parameters matching the perspective, to generate @@ -767,6 +788,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, ourBalance.ToSatoshis(), theirBalance.ToSatoshis(), numHTLCs, cb.chanState.IsInitiator, leaseExpiry, + auxResult.AuxLeaves, ) } else { commitTx, err = CreateCommitTx( @@ -774,12 +796,26 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg, theirBalance.ToSatoshis(), ourBalance.ToSatoshis(), numHTLCs, !cb.chanState.IsInitiator, leaseExpiry, + auxResult.AuxLeaves, ) } if err != nil { return nil, err } + // Similarly, we'll now attempt to extract the set of aux leaves for + // the set of incoming and outgoing HTLCs. + incomingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.HtlcAuxLeaves { + return leaves.IncomingHtlcLeaves + }, + )(auxResult.AuxLeaves) + outgoingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.HtlcAuxLeaves { + return leaves.OutgoingHtlcLeaves + }, + )(auxResult.AuxLeaves) + // We'll now add all the HTLC outputs to the commitment transaction. // Each output includes an off-chain 2-of-2 covenant clause, so we'll // need the objective local/remote keys for this particular commitment @@ -790,7 +826,8 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // commitment outputs and should correspond to zero values for the // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.ourUpdates { + htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut)) + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -799,16 +836,26 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.ChainOption( + func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(outgoingAuxLeaves) + err := addHTLC( commitTx, whoseCommit, false, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, auxLeaf, ) if err != nil { return nil, err } - cltvs = append(cltvs, htlc.Timeout) // nolint:makezero + + // We want to add the CLTV and HTLC index to their respective + // slices, even if we already pre-allocated them. + cltvs = append(cltvs, htlc.Timeout) //nolint + htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -817,14 +864,24 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.ChainOption( + func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(incomingAuxLeaves) + err := addHTLC( commitTx, whoseCommit, true, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, auxLeaf, ) if err != nil { return nil, err } - cltvs = append(cltvs, htlc.Timeout) // nolint:makezero + + // We want to add the CLTV and HTLC index to their respective + // slices, even if we already pre-allocated them. + cltvs = append(cltvs, htlc.Timeout) //nolint + htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } // Set the state hint of the commitment transaction to facilitate @@ -836,9 +893,16 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } // Sort the transactions according to the agreed upon canonical - // ordering. This lets us skip sending the entire transaction over, - // instead we'll just send signatures. - InPlaceCommitSort(commitTx, cltvs) + // ordering (which might be customized for custom channel types, but + // deterministic and both parties will arrive at the same result). This + // lets us skip sending the entire transaction over, instead we'll just + // send signatures. + commitSort := auxResult.CommitSortFunc.UnwrapOr(DefaultCommitSort) + err = commitSort(commitTx, cltvs, htlcIndexes) + if err != nil { + return nil, fmt.Errorf("unable to sort commitment "+ + "transaction: %w", err) + } // Next, we'll ensure that we don't accidentally create a commitment // transaction which would be invalid by consensus. @@ -880,24 +944,33 @@ func CreateCommitTx(chanType channeldb.ChannelType, fundingOutput wire.TxIn, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, amountToLocal, amountToRemote btcutil.Amount, - numHTLCs int64, initiator bool, leaseExpiry uint32) (*wire.MsgTx, error) { + numHTLCs int64, initiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*wire.MsgTx, error) { // First, we create the script for the delayed "pay-to-self" output. // This output has 2 main redemption clauses: either we can redeem the // output after a relative block delay, or the remote node can claim // the funds with the revocation key if we broadcast a revoked // commitment transaction. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) toLocalScript, err := CommitScriptToSelf( chanType, initiator, keyRing.ToLocalKey, keyRing.RevocationKey, uint32(localChanCfg.CsvDelay), leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } // Next, we create the script paying to the remote. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) toRemoteScript, _, err := CommitScriptToRemote( chanType, initiator, keyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -1077,7 +1150,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // channel. func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, -) (*input.HtlcScriptTree, error) { + auxLeaf input.AuxTapLeaf) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1094,7 +1167,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1103,7 +1176,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // We're sending an HTLC which is being added to our commitment @@ -1112,7 +1185,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case !isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1121,7 +1194,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case !isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) } @@ -1136,7 +1209,8 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, // along side the multiplexer. func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, - keyRing *CommitmentKeyRing) (input.ScriptDescriptor, error) { + keyRing *CommitmentKeyRing, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( @@ -1146,7 +1220,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, } return GenTaprootHtlcScript( - isIncoming, whoseCommit, timeout, rHash, keyRing, + isIncoming, whoseCommit, timeout, rHash, keyRing, auxLeaf, ) } @@ -1159,13 +1233,15 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, // the descriptor itself. func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, isIncoming bool, paymentDesc *PaymentDescriptor, - keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { + keyRing *CommitmentKeyRing, chanType channeldb.ChannelType, + auxLeaf input.AuxTapLeaf) error { timeout := paymentDesc.Timeout rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, + auxLeaf, ) if err != nil { return err @@ -1198,7 +1274,8 @@ func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, // output scripts and compares them against the outputs inside the commitment // to find the match. func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, - chanState *channeldb.OpenChannel) (uint32, uint32, error) { + chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) { // Init the output indexes as empty. ourIndex := uint32(channeldb.OutputIndexEmpty) @@ -1228,26 +1305,51 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, leaseExpiry = chanState.ThawHeight } - // Map the scripts from our PoV. When facing a local commitment, the to - // local output belongs to us and the to remote output belongs to them. - // When facing a remote commitment, the to local output belongs to them - // and the to remote output belongs to us. + // If we have a custom blob, then we'll attempt to fetch the aux leaves + // for this state. + auxResult, err := fn.MapOptionZ( + leafStore, func(a AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return a.FetchLeavesFromCommit( + NewAuxChanState(chanState), chanCommit, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return ourIndex, theirIndex, fmt.Errorf("unable to fetch aux "+ + "leaves: %w", err) + } - // Compute the to local script. From our PoV, when facing a remote - // commitment, the to local output belongs to them. + // Map the scripts from our PoV. When facing a local commitment, the + // to_local output belongs to us and the to_remote output belongs to + // them. When facing a remote commitment, the to_local output belongs to + // them and the to_remote output belongs to us. + + // Compute the to_local script. From our PoV, when facing a remote + // commitment, the to_local output belongs to them. + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, theirDelay, leaseExpiry, + keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf, ) if err != nil { return ourIndex, theirIndex, err } - // Compute the to remote script. From our PoV, when facing a remote - // commitment, the to remote output belongs to us. + // Compute the to_remote script. From our PoV, when facing a remote + // commitment, the to_remote output belongs to us. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) ourScript, _, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, remoteAuxLeaf, ) if err != nil { return ourIndex, theirIndex, err diff --git a/lnwallet/config.go b/lnwallet/config.go index 7eeacb6ea..24961f38e 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -62,4 +63,8 @@ type Config struct { // CoinSelectionStrategy is the strategy that is used for selecting // coins when funding a transaction. CoinSelectionStrategy wallet.CoinSelectionStrategy + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[AuxLeafStore] } diff --git a/lnwallet/mock.go b/lnwallet/mock.go index faac5fa67..2afff4f21 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -17,8 +17,10 @@ import ( "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -397,3 +399,45 @@ func (*mockChainIO) GetBlockHeader( return nil, nil } + +type MockAuxLeafStore struct{} + +// A compile time check to ensure that MockAuxLeafStore implements the +// AuxLeafStore interface. +var _ AuxLeafStore = (*MockAuxLeafStore)(nil) + +// FetchLeavesFromView attempts to fetch the auxiliary leaves that +// correspond to the passed aux blob, and pending original (unfiltered) +// HTLC view. +func (*MockAuxLeafStore) FetchLeavesFromView( + _ CommitDiffAuxInput) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// FetchLeavesFromCommit attempts to fetch the auxiliary leaves that +// correspond to the passed aux blob, and an existing channel +// commitment. +func (*MockAuxLeafStore) FetchLeavesFromCommit(_ AuxChanState, + _ channeldb.ChannelCommitment, + _ CommitmentKeyRing) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// FetchLeavesFromRevocation attempts to fetch the auxiliary leaves +// from a channel revocation that stores balance + blob information. +func (*MockAuxLeafStore) FetchLeavesFromRevocation( + _ *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// ApplyHtlcView serves as the state transition function for the custom +// channel's blob. Given the old blob, and an HTLC view, then a new +// blob should be returned that reflects the pending updates. +func (*MockAuxLeafStore) ApplyHtlcView( + _ CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]] { + + return fn.Ok(fn.None[tlv.Blob]()) +} diff --git a/lnwallet/musig_session.go b/lnwallet/musig_session.go index ecc60d07f..c3214d3f2 100644 --- a/lnwallet/musig_session.go +++ b/lnwallet/musig_session.go @@ -8,8 +8,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" @@ -37,6 +39,20 @@ var ( ErrSessionNotFinalized = fmt.Errorf("musig2 session not finalized") ) +// tapscriptRootToSignOpt is a function that takes a tapscript root and returns +// a MuSig2 sign opt that'll apply the tweak when signing+verifying. +func tapscriptRootToSignOpt(root chainhash.Hash) musig2.SignOption { + return musig2.WithTaprootSignTweak(root[:]) +} + +// TapscriptRootToTweak is a helper function that converts a tapscript root +// into a tweak that can be used with the MuSig2 API. +func TapscriptRootToTweak(root chainhash.Hash) input.MuSig2Tweaks { + return input.MuSig2Tweaks{ + TaprootTweak: root[:], + } +} + // MusigPartialSig is a wrapper around the base musig2.PartialSignature type // that also includes information about the set of nonces used, and also the // signer. This allows us to implement the input.Signature interface, as that @@ -54,25 +70,30 @@ type MusigPartialSig struct { // signerKeys is the set of public keys of all signers. signerKeys []*btcec.PublicKey + + // tapscriptTweak is an optional tweak, that if specified, will be used + // instead of the normal BIP 86 tweak when validating the signature. + tapscriptTweak fn.Option[chainhash.Hash] } -// NewMusigPartialSig creates a new musig partial signature. -func NewMusigPartialSig(sig *musig2.PartialSignature, - signerNonce, combinedNonce lnwire.Musig2Nonce, - signerKeys []*btcec.PublicKey) *MusigPartialSig { +// NewMusigPartialSig creates a new MuSig2 partial signature. +func NewMusigPartialSig(sig *musig2.PartialSignature, signerNonce, + combinedNonce lnwire.Musig2Nonce, signerKeys []*btcec.PublicKey, + tapscriptTweak fn.Option[chainhash.Hash]) *MusigPartialSig { return &MusigPartialSig{ - sig: sig, - signerNonce: signerNonce, - combinedNonce: combinedNonce, - signerKeys: signerKeys, + sig: sig, + signerNonce: signerNonce, + combinedNonce: combinedNonce, + signerKeys: signerKeys, + tapscriptTweak: tapscriptTweak, } } // FromWireSig maps a wire partial sig to this internal type that we'll use to // perform signature validation. -func (p *MusigPartialSig) FromWireSig(sig *lnwire.PartialSigWithNonce, -) *MusigPartialSig { +func (p *MusigPartialSig) FromWireSig( + sig *lnwire.PartialSigWithNonce) *MusigPartialSig { p.sig = &musig2.PartialSignature{ S: &sig.Sig, @@ -135,9 +156,15 @@ func (p *MusigPartialSig) Verify(msg []byte, pub *btcec.PublicKey) bool { var m [32]byte copy(m[:], msg) + // If we have a tapscript tweak, then we'll use that as a tweak + // otherwise, we'll fall back to the normal BIP 86 sign tweak. + signOpts := fn.MapOption(tapscriptRootToSignOpt)( + p.tapscriptTweak, + ).UnwrapOr(musig2.WithBip86SignTweak()) + return p.sig.Verify( p.signerNonce, p.combinedNonce, p.signerKeys, pub, m, - musig2.WithSortedKeys(), musig2.WithBip86SignTweak(), + musig2.WithSortedKeys(), signOpts, ) } @@ -160,6 +187,14 @@ func (n *MusigNoncePair) String() string { n.SigningNonce.PubNonce[:]) } +// TapscriptRootToTweak is a function that takes a MuSig2 taproot tweak and +// returns the root hash of the tapscript tree. +func muSig2TweakToRoot(tweak input.MuSig2Tweaks) chainhash.Hash { + var root chainhash.Hash + copy(root[:], tweak.TaprootTweak) + return root +} + // MusigSession abstracts over the details of a logical musig session. A single // session is used for each commitment transactions. The sessions use a JIT // nonce style, wherein part of the session can be created using only the @@ -197,15 +232,20 @@ type MusigSession struct { // commitType tracks if this is the session for the local or remote // commitment. commitType MusigCommitType + + // tapscriptTweak is an optional tweak, that if specified, will be used + // instead of the normal BIP 86 tweak when creating the MuSig2 + // aggregate key and session. + tapscriptTweak fn.Option[input.MuSig2Tweaks] } // NewPartialMusigSession creates a new musig2 session given only the // verification nonce (local nonce), and the other information that has already // been bound to the session. func NewPartialMusigSession(verificationNonce musig2.Nonces, - localKey, remoteKey keychain.KeyDescriptor, - signer input.MuSig2Signer, inputTxOut *wire.TxOut, - commitType MusigCommitType) *MusigSession { + localKey, remoteKey keychain.KeyDescriptor, signer input.MuSig2Signer, + inputTxOut *wire.TxOut, commitType MusigCommitType, + tapscriptTweak fn.Option[input.MuSig2Tweaks]) *MusigSession { signerKeys := []*btcec.PublicKey{localKey.PubKey, remoteKey.PubKey} @@ -214,13 +254,14 @@ func NewPartialMusigSession(verificationNonce musig2.Nonces, } return &MusigSession{ - nonces: nonces, - remoteKey: remoteKey, - localKey: localKey, - inputTxOut: inputTxOut, - signerKeys: signerKeys, - signer: signer, - commitType: commitType, + nonces: nonces, + remoteKey: remoteKey, + localKey: localKey, + inputTxOut: inputTxOut, + signerKeys: signerKeys, + signer: signer, + commitType: commitType, + tapscriptTweak: tapscriptTweak, } } @@ -254,9 +295,9 @@ func (m *MusigSession) FinalizeSession(signingNonce musig2.Nonces) error { remoteNonce = m.nonces.SigningNonce } - tweakDesc := input.MuSig2Tweaks{ + tweakDesc := m.tapscriptTweak.UnwrapOr(input.MuSig2Tweaks{ TaprootBIP0086Tweak: true, - } + }) m.session, err = m.signer.MuSig2CreateSession( input.MuSig2Version100RC2, m.localKey.KeyLocator, m.signerKeys, &tweakDesc, [][musig2.PubNonceSize]byte{remoteNonce.PubNonce}, @@ -351,8 +392,11 @@ func (m *MusigSession) SignCommit(tx *wire.MsgTx) (*MusigPartialSig, error) { return nil, err } + tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak) + return NewMusigPartialSig( sig, m.session.PublicNonce, m.combinedNonce, m.signerKeys, + tapscriptRoot, ), nil } @@ -364,7 +408,7 @@ func (m *MusigSession) Refresh(verificationNonce *musig2.Nonces, return NewPartialMusigSession( *verificationNonce, m.localKey, m.remoteKey, m.signer, - m.inputTxOut, m.commitType, + m.inputTxOut, m.commitType, m.tapscriptTweak, ), nil } @@ -451,9 +495,11 @@ func (m *MusigSession) VerifyCommitSig(commitTx *wire.MsgTx, // When we verify a commitment signature, we always assume that we're // verifying a signature on our local commitment. Therefore, we'll use: // their remote nonce, and also public key. + tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak) partialSig := NewMusigPartialSig( &musig2.PartialSignature{S: &sig.Sig}, m.nonces.SigningNonce.PubNonce, m.combinedNonce, m.signerKeys, + tapscriptRoot, ) // With the partial sig loaded with the proper context, we'll now @@ -537,6 +583,10 @@ type MusigSessionCfg struct { // InputTxOut is the output that we're signing for. This will be the // funding input. InputTxOut *wire.TxOut + + // TapscriptRoot is an optional tweak that can be used to modify the + // MuSig2 public key used in the session. + TapscriptTweak fn.Option[chainhash.Hash] } // MusigPairSession houses the two musig2 sessions needed to do funding and @@ -561,13 +611,14 @@ func NewMusigPairSession(cfg *MusigSessionCfg) *MusigPairSession { // // Both sessions will be created using only the verification nonce for // the local+remote party. + tapscriptTweak := fn.MapOption(TapscriptRootToTweak)(cfg.TapscriptTweak) localSession := NewPartialMusigSession( - cfg.LocalNonce, cfg.LocalKey, cfg.RemoteKey, - cfg.Signer, cfg.InputTxOut, LocalMusigCommit, + cfg.LocalNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer, + cfg.InputTxOut, LocalMusigCommit, tapscriptTweak, ) remoteSession := NewPartialMusigSession( - cfg.RemoteNonce, cfg.LocalKey, cfg.RemoteKey, - cfg.Signer, cfg.InputTxOut, RemoteMusigCommit, + cfg.RemoteNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer, + cfg.InputTxOut, RemoteMusigCommit, tapscriptTweak, ) return &MusigPairSession{ diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index f0ac8b9f7..6fe74bb6f 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -221,4 +221,8 @@ type PaymentDescriptor struct { // blinded route (ie, not the introduction node) from update_add_htlc's // TLVs. BlindingPoint lnwire.BlindingPointRecord + + // CustomRecords also stores the set of optional custom records that + // may have been attached to a sent HTLC. + CustomRecords lnwire.CustomRecords } diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go index df6c8fd94..1f0000e8e 100644 --- a/lnwallet/reservation.go +++ b/lnwallet/reservation.go @@ -415,6 +415,10 @@ func NewChannelReservation(capacity, localFundingAmt btcutil.Amount, chanType |= channeldb.ScidAliasFeatureBit } + if req.TapscriptRoot.IsSome() { + chanType |= channeldb.TapscriptRootBit + } + return &ChannelReservation{ ourContribution: &ChannelContribution{ FundingAmount: ourBalance.ToSatoshis(), @@ -448,6 +452,7 @@ func NewChannelReservation(capacity, localFundingAmt btcutil.Amount, InitialLocalBalance: ourBalance, InitialRemoteBalance: theirBalance, Memo: req.Memo, + TapscriptRoot: req.TapscriptRoot, }, pushMSat: req.PushMSat, pendingChanID: req.PendingChanID, diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 126b76d6c..2c12c83ae 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -348,6 +349,21 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, Packager: channeldb.NewChannelPackager(shortChanID), } + // If the channel type has a tapscript root, then we'll also specify + // one here to apply to both the channels. + if chanType.HasTapscriptRoot() { + var tapscriptRoot chainhash.Hash + _, err := io.ReadFull(rand.Reader, tapscriptRoot[:]) + if err != nil { + return nil, nil, err + } + + someRoot := fn.Some(tapscriptRoot) + + aliceChannelState.TapscriptRoot = someRoot + bobChannelState.TapscriptRoot = someRoot + } + aliceSigner := input.NewMockSigner(aliceKeys, nil) bobSigner := input.NewMockSigner(bobKeys, nil) diff --git a/lnwallet/transactions.go b/lnwallet/transactions.go index 1cf954d3c..da86650bc 100644 --- a/lnwallet/transactions.go +++ b/lnwallet/transactions.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/input" ) const ( @@ -50,8 +51,8 @@ var ( // - func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, csvDelay, - leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey) ( - *wire.MsgTx, error) { + leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout). @@ -71,7 +72,7 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err @@ -110,7 +111,8 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, cltvExpiry, csvDelay, leaseExpiry uint32, - revocationKey, delayKey *btcec.PublicKey) (*wire.MsgTx, error) { + revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout), and set the lock-time to the @@ -134,7 +136,7 @@ func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index e9751c6b9..439e7ce95 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -631,6 +632,7 @@ func testSpendValidation(t *testing.T, tweakless bool) { commitmentTx, err := CreateCommitTx( channelType, *fakeFundingTxIn, keyRing, aliceChanCfg, bobChanCfg, channelBalance, channelBalance, 0, true, 0, + fn.None[CommitAuxLeaves](), ) if err != nil { t.Fatalf("unable to create commitment transaction: %v", nil) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 6925446a3..7e455ab48 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -210,6 +211,11 @@ type InitFundingReserveMsg struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is the root of the tapscript tree that will be used to + // create the funding output. This is an optional field that should + // only be set for taproot channels. + TapscriptRoot fn.Option[chainhash.Hash] + // err is a channel in which all errors will be sent across. Will be // nil if this initial set is successful. // @@ -1464,6 +1470,21 @@ func (l *LightningWallet) handleFundingCancelRequest(req *fundingReserveCancelMs req.err <- nil } +// createCommitOpts is a struct that holds the options for creating a new +// commitment transaction. +type createCommitOpts struct { + auxLeaves fn.Option[CommitAuxLeaves] +} + +// defaultCommitOpts returns a new createCommitOpts with default values. +func defaultCommitOpts() createCommitOpts { + return createCommitOpts{} +} + +// CreateCommitOpt is a functional option that can be used to modify the way a +// new commitment transaction is created. +type CreateCommitOpt func(*createCommitOpts) + // CreateCommitmentTxns is a helper function that creates the initial // commitment transaction for both parties. This function is used during the // initial funding workflow as both sides must generate a signature for the @@ -1473,7 +1494,13 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourChanCfg, theirChanCfg *channeldb.ChannelConfig, localCommitPoint, remoteCommitPoint *btcec.PublicKey, fundingTxIn wire.TxIn, chanType channeldb.ChannelType, initiator bool, - leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { + leaseExpiry uint32, opts ...CreateCommitOpt) (*wire.MsgTx, *wire.MsgTx, + error) { + + options := defaultCommitOpts() + for _, optFunc := range opts { + optFunc(&options) + } localCommitmentKeys := DeriveCommitmentKeys( localCommitPoint, lntypes.Local, chanType, ourChanCfg, @@ -1487,7 +1514,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourCommitTx, err := CreateCommitTx( chanType, fundingTxIn, localCommitmentKeys, ourChanCfg, theirChanCfg, localBalance, remoteBalance, 0, initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -1501,7 +1528,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, theirCommitTx, err := CreateCommitTx( chanType, fundingTxIn, remoteCommitmentKeys, theirChanCfg, ourChanCfg, remoteBalance, localBalance, 0, !initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -2102,6 +2129,7 @@ func (l *LightningWallet) verifyCommitSig(res *ChannelReservation, if res.musigSessions == nil { _, fundingOutput, err := input.GenTaprootFundingScript( localKey, remoteKey, channelValue, + res.partialState.TapscriptRoot, ) if err != nil { return err @@ -2341,11 +2369,14 @@ func (l *LightningWallet) handleSingleFunderSigs(req *addSingleFunderSigsMsg) { fundingTxOut *wire.TxOut ) if chanType.IsTaproot() { - fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript( //nolint:lll + //nolint:lll + fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript( ourKey.PubKey, theirKey.PubKey, channelValue, + pendingReservation.partialState.TapscriptRoot, ) } else { - fundingWitnessScript, fundingTxOut, err = input.GenFundingPkScript( //nolint:lll + //nolint:lll + fundingWitnessScript, fundingTxOut, err = input.GenFundingPkScript( ourKey.PubKey.SerializeCompressed(), theirKey.PubKey.SerializeCompressed(), channelValue, ) @@ -2465,9 +2496,16 @@ func initStateHints(commit1, commit2 *wire.MsgTx, func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, fundingTx *wire.MsgTx) error { + var chanOpts []ChannelOpt + l.Cfg.AuxLeafStore.WhenSome(func(s AuxLeafStore) { + chanOpts = append(chanOpts, WithLeafStore(s)) + }) + // First, we'll obtain a fully signed commitment transaction so we can // pass into it on the chanvalidate package for verification. - channel, err := NewLightningChannel(l.Cfg.Signer, channelState, nil) + channel, err := NewLightningChannel( + l.Cfg.Signer, channelState, nil, chanOpts..., + ) if err != nil { return err } @@ -2482,6 +2520,7 @@ func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, if channelState.ChanType.IsTaproot() { fundingScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), + channelState.TapscriptRoot, ) if err != nil { return err diff --git a/peer/brontide.go b/peer/brontide.go index 25c7cea6f..1324044da 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -372,6 +372,10 @@ type Config struct { AddLocalAlias func(alias, base lnwire.ShortChannelID, gossip bool) error + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -943,8 +947,12 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( } } + var chanOpts []lnwallet.ChannelOpt + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) lnChan, err := lnwallet.NewLightningChannel( - p.cfg.Signer, dbChan, p.cfg.SigPool, + p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) if err != nil { return nil, fmt.Errorf("unable to create channel "+ @@ -4151,6 +4159,10 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithSkipNonceInit()) } + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. diff --git a/peer/musig_chan_closer.go b/peer/musig_chan_closer.go index 6b05b1e62..6f69a8c5b 100644 --- a/peer/musig_chan_closer.go +++ b/peer/musig_chan_closer.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" @@ -43,10 +44,15 @@ func (m *MusigChanCloser) ProposalClosingOpts() ( } localKey, remoteKey := m.channel.MultiSigKeys() + + tapscriptTweak := fn.MapOption(lnwallet.TapscriptRootToTweak)( + m.channel.State().TapscriptRoot, + ) + m.musigSession = lnwallet.NewPartialMusigSession( *m.remoteNonce, localKey, remoteKey, m.channel.Signer, m.channel.FundingTxOut(), - lnwallet.RemoteMusigCommit, + lnwallet.RemoteMusigCommit, tapscriptTweak, ) err := m.musigSession.FinalizeSession(*m.localNonce) diff --git a/server.go b/server.go index 75fd02b97..1f8db90a6 100644 --- a/server.go +++ b/server.go @@ -1273,6 +1273,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return &pc.Incoming }, + AuxLeafStore: implCfg.AuxLeafStore, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin. @@ -1607,6 +1608,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, br, err := lnwallet.NewBreachRetribution( channel, commitHeight, 0, nil, + implCfg.AuxLeafStore, ) if err != nil { return nil, 0, err @@ -4073,6 +4075,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), MaxFeeExposure: thresholdMSats, Quit: s.quit, + AuxLeafStore: s.implCfg.AuxLeafStore, MsgRouter: s.implCfg.MsgRouter, } diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 8b6c20194..7780239f0 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -307,9 +308,11 @@ func newTaprootJusticeKit(sweepScript []byte, keyRing := breachInfo.KeyRing + // TODO(roasbeef): aux leaf tower updates needed + tree, err := input.NewLocalCommitScriptTree( breachInfo.RemoteDelay, keyRing.ToLocalKey, - keyRing.RevocationKey, + keyRing.RevocationKey, fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -416,7 +419,9 @@ func (t *taprootJusticeKit) ToRemoteOutputSpendInfo() (*txscript.PkScript, return nil, nil, 0, err } - scriptTree, err := input.NewRemoteCommitScriptTree(toRemotePk) + scriptTree, err := input.NewRemoteCommitScriptTree( + toRemotePk, fn.None[txscript.TapLeaf](), + ) if err != nil { return nil, nil, 0, err } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index fd12993a0..a1d6ec9f2 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -304,7 +305,9 @@ func TestJusticeKitRemoteWitnessConstruction(t *testing.T) { name: "taproot commitment", blobType: TypeAltruistTaprootCommit, expWitnessScript: func(pk *btcec.PublicKey) []byte { - tree, _ := input.NewRemoteCommitScriptTree(pk) + tree, _ := input.NewRemoteCommitScriptTree( + pk, fn.None[txscript.TapLeaf](), + ) return tree.SettleLeaf.Script }, @@ -461,6 +464,7 @@ func TestJusticeKitToLocalWitnessConstruction(t *testing.T) { script, _ := input.NewLocalCommitScriptTree( csvDelay, delay, rev, + fn.None[txscript.TapLeaf](), ) return script.RevocationLeaf.Script diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index a07b440ad..5045b4a0f 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -123,7 +124,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { if isTaprootChannel { toLocalCommitTree, err = input.NewLocalCommitScriptTree( - csvDelay, toLocalPK, revPK, + csvDelay, toLocalPK, revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) @@ -174,7 +175,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { toRemoteSequence = 1 commitScriptTree, err := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 695c4f9ec..7894631b8 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -136,6 +137,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewLocalCommitScriptTree( csvDelay, toLocalPK, revPK, + fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( @@ -189,7 +191,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 38d9acd9f..f3a4d5bf4 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -230,12 +231,14 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { // Construct the to-local witness script. toLocalScriptTree, err := input.NewLocalCommitScriptTree( - c.csvDelay, c.toLocalPK, c.revPK, + c.csvDelay, c.toLocalPK, c.revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err, "unable to create to-local script") // Construct the to-remote witness script. - toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK) + toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( + c.toRemotePK, fn.None[txscript.TapLeaf](), + ) require.NoError(t, err, "unable to create to-remote script") // Compute the to-local witness script hash.