From ecca095a9b26080421cb76c507d2dbace660d468 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 12 Mar 2024 12:25:53 -0400 Subject: [PATCH 01/22] channeldb: consolidate root bucket TLVs into new struct In this commit, we consolidate the root bucket TLVs into a new struct. This makes it easier to see all the new TLV fields at a glance. We also convert TLV usage to use the new type param based APis. --- channeldb/channel.go | 248 ++++++++++++++++++++++++++----------------- 1 file changed, 149 insertions(+), 99 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 4e3db2fc1..ec449613e 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -226,28 +226,83 @@ const ( // A tlv type definition used to serialize an outpoint's indexStatus // for use in the outpoint index. indexStatusType tlv.Type = 0 - - // A tlv type definition used to serialize and deserialize a KeyLocator - // from the database. - keyLocType tlv.Type = 1 - - // A tlv type used to serialize and deserialize the - // `InitialLocalBalance` field. - initialLocalBalanceType tlv.Type = 2 - - // A tlv type used to serialize and deserialize the - // `InitialRemoteBalance` field. - initialRemoteBalanceType tlv.Type = 3 - - // A tlv type definition used to serialize and deserialize the - // confirmed ShortChannelID for a zero-conf channel. - realScidType tlv.Type = 4 - - // A tlv type definition used to serialize and deserialize the - // Memo for the channel channel. - channelMemoType tlv.Type = 5 ) +// openChannelTlvData houses the new data fields that are stored for each +// channel in a TLV stream within the root bucket. This is stored as a TLV +// stream appended to the existing hard-coded fields in the channel's root +// bucket. New fields being added to the channel state should be added here. +// +// NOTE: This struct is used for serialization purposes only and its fields +// should be accessed via the OpenChannel struct while in memory. +type openChannelTlvData struct { + // revokeKeyLoc is the key locator for the revocation key. + revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] + + // initialLocalBalance is the initial local balance of the channel. + initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] + + // initialRemoteBalance is the initial remote balance of the channel. + initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] + + // realScid is the real short channel ID of the channel corresponding to + // the on-chain outpoint. + realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] + + // memo is an optional text field that gives context to the user about + // the channel. + memo tlv.OptionalRecordT[tlv.TlvType5, []byte] +} + +// encode serializes the openChannelTlvData to the given io.Writer. +func (c *openChannelTlvData) encode(w io.Writer) error { + tlvRecords := []tlv.Record{ + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + } + c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { + tlvRecords = append(tlvRecords, memo.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode deserializes the openChannelTlvData from the given io.Reader. +func (c *openChannelTlvData) decode(r io.Reader) error { + memo := c.memo.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + memo.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[memo.TlvType()]; ok { + c.memo = tlv.SomeRecordT(memo) + } + + return nil +} + // indexStatus is an enum-like type that describes what state the // outpoint is in. Currently only two possible values. type indexStatus uint8 @@ -867,6 +922,10 @@ type OpenChannel struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -1025,6 +1084,48 @@ func (c *OpenChannel) SetBroadcastHeight(height uint32) { c.FundingBroadcastHeight = height } +// amendTlvData updates the channel with the given auxiliary TLV data. +func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { + c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator + c.InitialLocalBalance = lnwire.MilliSatoshi( + auxData.initialLocalBalance.Val, + ) + c.InitialRemoteBalance = lnwire.MilliSatoshi( + auxData.initialRemoteBalance.Val, + ) + c.confirmedScid = auxData.realScid.Val + + auxData.memo.WhenSomeV(func(memo []byte) { + c.Memo = memo + }) +} + +// extractTlvData creates a new openChannelTlvData from the given channel. +func (c *OpenChannel) extractTlvData() openChannelTlvData { + auxData := openChannelTlvData{ + revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( + keyLocRecord{c.RevocationKeyLocator}, + ), + initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint64(c.InitialLocalBalance), + ), + initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( + uint64(c.InitialRemoteBalance), + ), + realScid: tlv.NewRecordT[tlv.TlvType4]( + c.confirmedScid, + ), + } + + if len(c.Memo) != 0 { + auxData.memo = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), + ) + } + + return auxData +} + // Refresh updates the in-memory channel state using the latest state observed // on disk. func (c *OpenChannel) Refresh() error { @@ -4030,32 +4131,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { return err } - // Convert balance fields into uint64. - localBalance := uint64(channel.InitialLocalBalance) - remoteBalance := uint64(channel.InitialRemoteBalance) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo), - ) - if err != nil { - return err - } - - if err := tlvStream.Encode(&w); err != nil { - return err + auxData := channel.extractTlvData() + if err := auxData.encode(&w); err != nil { + return fmt.Errorf("unable to encode aux data: %w", err) } if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { @@ -4244,45 +4322,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { } } - // Create balance fields in uint64, and Memo field as byte slice. - var ( - localBalance uint64 - remoteBalance uint64 - memo []byte - ) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - // Write the RevocationKeyLocator as the first entry in a tlv - // stream. - MakeKeyLocRecord( - keyLocType, &channel.RevocationKeyLocator, - ), - tlv.MakePrimitiveRecord( - initialLocalBalanceType, &localBalance, - ), - tlv.MakePrimitiveRecord( - initialRemoteBalanceType, &remoteBalance, - ), - MakeScidRecord(realScidType, &channel.confirmedScid), - tlv.MakePrimitiveRecord(channelMemoType, &memo), - ) - if err != nil { - return err + var auxData openChannelTlvData + if err := auxData.decode(r); err != nil { + return fmt.Errorf("unable to decode aux data: %w", err) } - if err := tlvStream.Decode(r); err != nil { - return err - } - - // Attach the balance fields. - channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance) - channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance) - - // Attach the memo field if non-empty. - if len(memo) > 0 { - channel.Memo = memo - } + // Assign all the relevant fields from the aux data into the actual + // open channel. + channel.amendTlvData(auxData) channel.Packager = NewChannelPackager(channel.ShortChannelID) @@ -4440,6 +4487,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error { return chanBucket.Delete(frozenChanKey) } +// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the +// tlv.RecordProducer interface. +type keyLocRecord struct { + keychain.KeyLocator +} + +// Record creates a Record out of a KeyLocator using the passed Type and the +// EKeyLocator and DKeyLocator functions. The size will always be 8 as +// KeyFamily is uint32 and the Index is uint32. +// +// NOTE: This is part of the tlv.RecordProducer interface. +func (k *keyLocRecord) Record() tlv.Record { + // Note that we set the type here as zero, as when used with a + // tlv.RecordT, the type param will be used as the type. + return tlv.MakeStaticRecord( + 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, + ) +} + // EKeyLocator is an encoder for keychain.KeyLocator. func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*keychain.KeyLocator); ok { @@ -4468,22 +4534,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) } -// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed -// Type and the EKeyLocator and DKeyLocator functions. The size will always be -// 8 as KeyFamily is uint32 and the Index is uint32. -func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record { - return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator) -} - -// MakeScidRecord creates a Record out of a ShortChannelID using the passed -// Type and the EShortChannelID and DShortChannelID functions. The size will -// always be 8 for the ShortChannelID. -func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record { - return tlv.MakeStaticRecord( - typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, - ) -} - // ShutdownInfo contains various info about the shutdown initiation of a // channel. type ShutdownInfo struct { From edf959d39f430858f974b5133805cb0d6e0bc533 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 09:15:54 -0400 Subject: [PATCH 02/22] channeldb: add optional TapscriptRoot field + feature bit --- channeldb/channel.go | 33 +++++++++++++++++++++++++++++++++ channeldb/channel_test.go | 8 +++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index ec449613e..f29a9ec19 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -252,6 +252,10 @@ type openChannelTlvData struct { // memo is an optional text field that gives context to the user about // the channel. memo tlv.OptionalRecordT[tlv.TlvType5, []byte] + + // tapscriptRoot is the optional Tapscript root the channel funding + // output commits to. + tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] } // encode serializes the openChannelTlvData to the given io.Writer. @@ -265,6 +269,11 @@ func (c *openChannelTlvData) encode(w io.Writer) error { c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { tlvRecords = append(tlvRecords, memo.Record()) }) + c.tapscriptRoot.WhenSome( + func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { + tlvRecords = append(tlvRecords, root.Record()) + }, + ) // Create the tlv stream. tlvStream, err := tlv.NewStream(tlvRecords...) @@ -278,6 +287,7 @@ func (c *openChannelTlvData) encode(w io.Writer) error { // decode deserializes the openChannelTlvData from the given io.Reader. func (c *openChannelTlvData) decode(r io.Reader) error { memo := c.memo.Zero() + tapscriptRoot := c.tapscriptRoot.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -286,6 +296,7 @@ func (c *openChannelTlvData) decode(r io.Reader) error { c.initialRemoteBalance.Record(), c.realScid.Record(), memo.Record(), + tapscriptRoot.Record(), ) if err != nil { return err @@ -299,6 +310,9 @@ func (c *openChannelTlvData) decode(r io.Reader) error { if _, ok := tlvs[memo.TlvType()]; ok { c.memo = tlv.SomeRecordT(memo) } + if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { + c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) + } return nil } @@ -380,6 +394,11 @@ const ( // SimpleTaprootFeatureBit indicates that the simple-taproot-chans // feature bit was negotiated during the lifetime of the channel. SimpleTaprootFeatureBit ChannelType = 1 << 10 + + // TapscriptRootBit indicates that this is a MuSig2 channel with a top + // level tapscript commitment. This MUST be set along with the + // SimpleTaprootFeatureBit. + TapscriptRootBit ChannelType = 1 << 11 ) // IsSingleFunder returns true if the channel type if one of the known single @@ -450,6 +469,12 @@ func (c ChannelType) IsTaproot() bool { return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit } +// HasTapscriptRoot returns true if the channel is using a top level tapscript +// root commitment. +func (c ChannelType) HasTapscriptRoot() bool { + return c&TapscriptRootBit == TapscriptRootBit +} + // ChannelStateBounds are the parameters from OpenChannel and AcceptChannel // that are responsible for providing bounds on the state space of the abstract // channel state. These values must be remembered for normal channel operation @@ -1098,6 +1123,9 @@ func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { auxData.memo.WhenSomeV(func(memo []byte) { c.Memo = memo }) + auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { + c.TapscriptRoot = fn.Some[chainhash.Hash](h) + }) } // extractTlvData creates a new openChannelTlvData from the given channel. @@ -1122,6 +1150,11 @@ func (c *OpenChannel) extractTlvData() openChannelTlvData { tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), ) } + c.TapscriptRoot.WhenSome(func(h chainhash.Hash) { + auxData.tapscriptRoot = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), + ) + }) return auxData } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index a7f3c1ebe..84aae9622 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -17,6 +17,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -173,7 +174,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { } // channelIDOption is an option which sets the short channel ID of the channel. -var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { +func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption { return func(params *testChannelParams) { params.channel.ShortChannelID = chanID } @@ -326,6 +327,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { uniqueOutputIndex.Add(1) op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()} + var tapscriptRoot chainhash.Hash + copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32)) + return &OpenChannel{ ChanType: SingleFunderBit | FrozenBit, ChainHash: key, @@ -368,6 +372,8 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { ThawHeight: uint32(defaultPendingHeight), InitialLocalBalance: lnwire.MilliSatoshi(9000), InitialRemoteBalance: lnwire.MilliSatoshi(3000), + Memo: []byte("test"), + TapscriptRoot: fn.Some(tapscriptRoot), } } From 2c56b3120a97324c138b844857436136d87f7eb5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:46:33 -0400 Subject: [PATCH 03/22] multi: add new tapscript root option to GenTaprootFundingScript This'll allow us to create a funding output that uses musig2, but uses a tapscript tweak rather than a normal BIP 86 tweak. --- contractcourt/chain_watcher.go | 2 +- funding/manager.go | 2 ++ graph/builder.go | 4 +++- input/script_utils.go | 24 +++++++++++++++--------- itest/lnd_funding_test.go | 2 ++ lnwallet/chanfunding/canned_assembler.go | 7 ++++--- lnwallet/channel.go | 1 + lnwallet/wallet.go | 10 ++++++++-- 8 files changed, 36 insertions(+), 16 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 3cbc7422d..155fca2a9 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -308,7 +308,7 @@ func (c *chainWatcher) Start() error { ) if chanState.ChanType.IsTaproot() { c.fundingPkScript, _, err = input.GenTaprootFundingScript( - localKey, remoteKey, 0, + localKey, remoteKey, 0, fn.None[chainhash.Hash](), ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index 61b03200e..ad6286090 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -2899,6 +2900,7 @@ func makeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) { if channel.ChanType.IsTaproot() { pkScript, _, err := input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), + fn.None[chainhash.Hash](), ) if err != nil { return nil, err diff --git a/graph/builder.go b/graph/builder.go index 82a36eb36..7f7823ab4 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -11,12 +11,14 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnutils" @@ -1138,7 +1140,7 @@ func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, } fundingScript, _, err := input.GenTaprootFundingScript( - pubKey1, pubKey2, 0, + pubKey1, pubKey2, 0, fn.None[chainhash.Hash](), ) if err != nil { return nil, err diff --git a/input/script_utils.go b/input/script_utils.go index d801846f8..d3a6dc9d3 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -11,8 +11,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" @@ -199,26 +201,30 @@ func GenFundingPkScript(aPub, bPub []byte, amt int64) ([]byte, *wire.TxOut, erro } // GenTaprootFundingScript constructs the taproot-native funding output that -// uses musig2 to create a single aggregated key to anchor the channel. +// uses MuSig2 to create a single aggregated key to anchor the channel. func GenTaprootFundingScript(aPub, bPub *btcec.PublicKey, - amt int64) ([]byte, *wire.TxOut, error) { + amt int64, tapscriptRoot fn.Option[chainhash.Hash]) ([]byte, + *wire.TxOut, error) { + + muSig2Opt := musig2.WithBIP86KeyTweak() + tapscriptRoot.WhenSome(func(scriptRoot chainhash.Hash) { + muSig2Opt = musig2.WithTaprootKeyTweak(scriptRoot[:]) + }) // Similar to the existing p2wsh funding script, we'll always make sure // we sort the keys before any major operations. In order to ensure // that there's no other way this output can be spent, we'll use a BIP - // 86 tweak here during aggregation. - // - // TODO(roasbeef): revisit if BIP 86 is needed here? + // 86 tweak here during aggregation, unless the user has explicitly + // specified a tapscript root. combinedKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{aPub, bPub}, true, - musig2.WithBIP86KeyTweak(), + []*btcec.PublicKey{aPub, bPub}, true, muSig2Opt, ) if err != nil { return nil, nil, fmt.Errorf("unable to combine keys: %w", err) } // Now that we have the combined key, we can create a taproot pkScript - // from this, and then make the txout given the amount. + // from this, and then make the txOut given the amount. pkScript, err := PayToTaprootScript(combinedKey.FinalKey) if err != nil { return nil, nil, fmt.Errorf("unable to make taproot "+ @@ -228,7 +234,7 @@ func GenTaprootFundingScript(aPub, bPub *btcec.PublicKey, txOut := wire.NewTxOut(amt, pkScript) // For the "witness program" we just return the raw pkScript since the - // output we create can _only_ be spent with a musig2 signature. + // output we create can _only_ be spent with a MuSig2 signature. return pkScript, txOut, nil } diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index 6e2f0070c..a1c2e292d 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainreg" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" @@ -1192,6 +1193,7 @@ func deriveFundingShim(ht *lntest.HarnessTest, carol, dave *node.HarnessNode, _, fundingOutput, err = input.GenTaprootFundingScript( carolKey, daveKey, int64(chanSize), + fn.None[chainhash.Hash](), ) require.NoError(ht, err) diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go index 21dd47339..02dd7169d 100644 --- a/lnwallet/chanfunding/canned_assembler.go +++ b/lnwallet/chanfunding/canned_assembler.go @@ -5,7 +5,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) @@ -76,9 +78,8 @@ func (s *ShimIntent) FundingOutput() ([]byte, *wire.TxOut, error) { // Similar to the existing p2wsh script, we'll always ensure // the keys are sorted before use. return input.GenTaprootFundingScript( - s.localKey.PubKey, - s.remoteKey, - int64(totalAmt), + s.localKey.PubKey, s.remoteKey, int64(totalAmt), + fn.None[chainhash.Hash](), ) } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7bd382d2c..cac399cf1 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1013,6 +1013,7 @@ func (lc *LightningChannel) createSignDesc() error { if chanState.ChanType.IsTaproot() { fundingPkScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(lc.channelState.Capacity), + fn.None[chainhash.Hash](), ) if err != nil { return err diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 6925446a3..be9046397 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -2102,6 +2103,7 @@ func (l *LightningWallet) verifyCommitSig(res *ChannelReservation, if res.musigSessions == nil { _, fundingOutput, err := input.GenTaprootFundingScript( localKey, remoteKey, channelValue, + fn.None[chainhash.Hash](), ) if err != nil { return err @@ -2341,11 +2343,14 @@ func (l *LightningWallet) handleSingleFunderSigs(req *addSingleFunderSigsMsg) { fundingTxOut *wire.TxOut ) if chanType.IsTaproot() { - fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript( //nolint:lll + //nolint:lll + fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript( ourKey.PubKey, theirKey.PubKey, channelValue, + fn.None[chainhash.Hash](), ) } else { - fundingWitnessScript, fundingTxOut, err = input.GenFundingPkScript( //nolint:lll + //nolint:lll + fundingWitnessScript, fundingTxOut, err = input.GenFundingPkScript( ourKey.PubKey.SerializeCompressed(), theirKey.PubKey.SerializeCompressed(), channelValue, ) @@ -2482,6 +2487,7 @@ func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, if channelState.ChanType.IsTaproot() { fundingScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), + fn.None[chainhash.Hash](), ) if err != nil { return err From eeee2979e59bbd1182529430fe81bec995730e4c Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:47:12 -0400 Subject: [PATCH 04/22] lnwallet/chanfunding: add optional tapscript root --- lnwallet/chanfunding/canned_assembler.go | 10 +++++++++- lnwallet/chanfunding/psbt_assembler.go | 1 + lnwallet/chanfunding/wallet_assembler.go | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go index 02dd7169d..177cc35c3 100644 --- a/lnwallet/chanfunding/canned_assembler.go +++ b/lnwallet/chanfunding/canned_assembler.go @@ -58,6 +58,14 @@ type ShimIntent struct { // generate an aggregate key to use as the taproot-native multi-sig // output. musig2 bool + + // tapscriptRoot is the root of the tapscript tree that will be used to + // create the funding output. This field will only be utilized if the + // MuSig2 flag above is set to true. + // + // TODO(roasbeef): fold above into new chan type? sum type like thing, + // includes the tapscript root, etc + tapscriptRoot fn.Option[chainhash.Hash] } // FundingOutput returns the witness script, and the output that creates the @@ -79,7 +87,7 @@ func (s *ShimIntent) FundingOutput() ([]byte, *wire.TxOut, error) { // the keys are sorted before use. return input.GenTaprootFundingScript( s.localKey.PubKey, s.remoteKey, int64(totalAmt), - fn.None[chainhash.Hash](), + s.tapscriptRoot, ) } diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go index 885fb7b46..10bcd7015 100644 --- a/lnwallet/chanfunding/psbt_assembler.go +++ b/lnwallet/chanfunding/psbt_assembler.go @@ -534,6 +534,7 @@ func (p *PsbtAssembler) ProvisionChannel(req *Request) (Intent, error) { ShimIntent: ShimIntent{ localFundingAmt: p.fundingAmt, musig2: req.Musig2, + tapscriptRoot: req.TapscriptRoot, }, State: PsbtShimRegistered, BasePsbt: p.basePsbt, diff --git a/lnwallet/chanfunding/wallet_assembler.go b/lnwallet/chanfunding/wallet_assembler.go index f824210e1..3d62649cb 100644 --- a/lnwallet/chanfunding/wallet_assembler.go +++ b/lnwallet/chanfunding/wallet_assembler.go @@ -394,7 +394,6 @@ func (w *WalletAssembler) ProvisionChannel(r *Request) (Intent, error) { // we will call the specialized coin selection function for // that. case r.FundUpToMaxAmt != 0 && r.MinFundAmt != 0: - // We need to ensure that manually selected coins, which // are spent entirely on the channel funding, leave // enough funds in the wallet to cover for a reserve. @@ -539,6 +538,7 @@ func (w *WalletAssembler) ProvisionChannel(r *Request) (Intent, error) { localFundingAmt: localContributionAmt, remoteFundingAmt: r.RemoteAmt, musig2: r.Musig2, + tapscriptRoot: r.TapscriptRoot, }, InputCoins: selectedCoins, coinLeaser: w.cfg.CoinLeaser, From 142b408be7baf13a25e612ed728d91e906bc4b27 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:52:07 -0400 Subject: [PATCH 05/22] multi: update GenTaprootFundingScript to pass tapscript root In most cases, we won't yet be passing a root. The option usage helps us keep the control flow mostly unchanged. --- contractcourt/chain_watcher.go | 2 +- funding/manager.go | 3 +-- graph/builder.go | 2 ++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 155fca2a9..8b2f75816 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -308,7 +308,7 @@ func (c *chainWatcher) Start() error { ) if chanState.ChanType.IsTaproot() { c.fundingPkScript, _, err = input.GenTaprootFundingScript( - localKey, remoteKey, 0, fn.None[chainhash.Hash](), + localKey, remoteKey, 0, chanState.TapscriptRoot, ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index ad6286090..b22f2e691 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -24,7 +24,6 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -2900,7 +2899,7 @@ func makeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) { if channel.ChanType.IsTaproot() { pkScript, _, err := input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), - fn.None[chainhash.Hash](), + channel.TapscriptRoot, ) if err != nil { return nil, err diff --git a/graph/builder.go b/graph/builder.go index 7f7823ab4..717d3e5ad 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1146,6 +1146,8 @@ func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, return nil, err } + // TODO(roasbeef): add tapscript root to gossip v1.5 + return fundingScript, nil } From c8b7987a39149e41bca9ba1c742b4a06ab05b468 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:53:41 -0400 Subject: [PATCH 06/22] lnwallet: update internal funding flow w/ tapscript root This isn't hooked up yet to the funding manager, but with this commit, we can now start to write internal unit tests that handle musig2 channels with a tapscript root. --- lnwallet/reservation.go | 5 +++++ lnwallet/wallet.go | 11 ++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go index df6c8fd94..1f0000e8e 100644 --- a/lnwallet/reservation.go +++ b/lnwallet/reservation.go @@ -415,6 +415,10 @@ func NewChannelReservation(capacity, localFundingAmt btcutil.Amount, chanType |= channeldb.ScidAliasFeatureBit } + if req.TapscriptRoot.IsSome() { + chanType |= channeldb.TapscriptRootBit + } + return &ChannelReservation{ ourContribution: &ChannelContribution{ FundingAmount: ourBalance.ToSatoshis(), @@ -448,6 +452,7 @@ func NewChannelReservation(capacity, localFundingAmt btcutil.Amount, InitialLocalBalance: ourBalance, InitialRemoteBalance: theirBalance, Memo: req.Memo, + TapscriptRoot: req.TapscriptRoot, }, pushMSat: req.PushMSat, pendingChanID: req.PendingChanID, diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index be9046397..4f51eaf99 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -211,6 +211,11 @@ type InitFundingReserveMsg struct { // channel that will be useful to our future selves. Memo []byte + // TapscriptRoot is the root of the tapscript tree that will be used to + // create the funding output. This is an optional field that should + // only be set for taproot channels. + TapscriptRoot fn.Option[chainhash.Hash] + // err is a channel in which all errors will be sent across. Will be // nil if this initial set is successful. // @@ -2103,7 +2108,7 @@ func (l *LightningWallet) verifyCommitSig(res *ChannelReservation, if res.musigSessions == nil { _, fundingOutput, err := input.GenTaprootFundingScript( localKey, remoteKey, channelValue, - fn.None[chainhash.Hash](), + res.partialState.TapscriptRoot, ) if err != nil { return err @@ -2346,7 +2351,7 @@ func (l *LightningWallet) handleSingleFunderSigs(req *addSingleFunderSigsMsg) { //nolint:lll fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript( ourKey.PubKey, theirKey.PubKey, channelValue, - fn.None[chainhash.Hash](), + pendingReservation.partialState.TapscriptRoot, ) } else { //nolint:lll @@ -2487,7 +2492,7 @@ func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, if channelState.ChanType.IsTaproot() { fundingScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(channel.Capacity), - fn.None[chainhash.Hash](), + channelState.TapscriptRoot, ) if err != nil { return err From 82ba5bf0bf6c61544e1becaa8ccde08d66e87f8a Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:54:49 -0400 Subject: [PATCH 07/22] lnwallet+peer: add tapscript root awareness to musig2 sessions With this commit, the channel is now aware of if it's a musig2 channel, that also has a tapscript root. We'll need to always pass in the tapscript root each time we: make the funding output, sign a new state, and also verify a new state. --- lnwallet/chancloser/chancloser_test.go | 7 +- lnwallet/channel.go | 21 +++-- lnwallet/musig_session.go | 107 ++++++++++++++++++------- peer/musig_chan_closer.go | 8 +- 4 files changed, 104 insertions(+), 39 deletions(-) diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 9a90d0ab2..a6688ed39 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -178,8 +179,9 @@ func (m *mockChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress { } func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount, - localScript, remoteScript []byte, _ ...lnwallet.ChanCloseOpt, -) (input.Signature, *chainhash.Hash, btcutil.Amount, error) { + localScript, remoteScript []byte, + _ ...lnwallet.ChanCloseOpt) (input.Signature, *chainhash.Hash, + btcutil.Amount, error) { if m.chanType.IsTaproot() { return lnwallet.NewMusigPartialSig( @@ -188,6 +190,7 @@ func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount, R: new(btcec.PublicKey), }, lnwire.Musig2Nonce{}, lnwire.Musig2Nonce{}, nil, + fn.None[chainhash.Hash](), ), nil, 0, nil } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index cac399cf1..b807bbee3 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1013,7 +1013,7 @@ func (lc *LightningChannel) createSignDesc() error { if chanState.ChanType.IsTaproot() { fundingPkScript, _, err = input.GenTaprootFundingScript( localKey, remoteKey, int64(lc.channelState.Capacity), - fn.None[chainhash.Hash](), + chanState.TapscriptRoot, ) if err != nil { return err @@ -6057,11 +6057,15 @@ func (lc *LightningChannel) getSignedCommitTx() (*wire.MsgTx, error) { "verification nonce: %w", err) } + tapscriptTweak := fn.MapOption(TapscriptRootToTweak)( + lc.channelState.TapscriptRoot, + ) + // Now that we have the local nonce, we'll re-create the musig // session we had for this height. musigSession := NewPartialMusigSession( *localNonce, ourKey, theirKey, lc.Signer, - &lc.fundingOutput, LocalMusigCommit, + &lc.fundingOutput, LocalMusigCommit, tapscriptTweak, ) var remoteSig lnwire.PartialSigWithNonce @@ -8672,12 +8676,13 @@ func (lc *LightningChannel) InitRemoteMusigNonces(remoteNonce *musig2.Nonces, // TODO(roasbeef): propagate rename of signing and verification nonces sessionCfg := &MusigSessionCfg{ - LocalKey: localChanCfg.MultiSigKey, - RemoteKey: remoteChanCfg.MultiSigKey, - LocalNonce: *localNonce, - RemoteNonce: *remoteNonce, - Signer: lc.Signer, - InputTxOut: &lc.fundingOutput, + LocalKey: localChanCfg.MultiSigKey, + RemoteKey: remoteChanCfg.MultiSigKey, + LocalNonce: *localNonce, + RemoteNonce: *remoteNonce, + Signer: lc.Signer, + InputTxOut: &lc.fundingOutput, + TapscriptTweak: lc.channelState.TapscriptRoot, } lc.musigSessions = NewMusigPairSession( sessionCfg, diff --git a/lnwallet/musig_session.go b/lnwallet/musig_session.go index ecc60d07f..c3214d3f2 100644 --- a/lnwallet/musig_session.go +++ b/lnwallet/musig_session.go @@ -8,8 +8,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" @@ -37,6 +39,20 @@ var ( ErrSessionNotFinalized = fmt.Errorf("musig2 session not finalized") ) +// tapscriptRootToSignOpt is a function that takes a tapscript root and returns +// a MuSig2 sign opt that'll apply the tweak when signing+verifying. +func tapscriptRootToSignOpt(root chainhash.Hash) musig2.SignOption { + return musig2.WithTaprootSignTweak(root[:]) +} + +// TapscriptRootToTweak is a helper function that converts a tapscript root +// into a tweak that can be used with the MuSig2 API. +func TapscriptRootToTweak(root chainhash.Hash) input.MuSig2Tweaks { + return input.MuSig2Tweaks{ + TaprootTweak: root[:], + } +} + // MusigPartialSig is a wrapper around the base musig2.PartialSignature type // that also includes information about the set of nonces used, and also the // signer. This allows us to implement the input.Signature interface, as that @@ -54,25 +70,30 @@ type MusigPartialSig struct { // signerKeys is the set of public keys of all signers. signerKeys []*btcec.PublicKey + + // tapscriptTweak is an optional tweak, that if specified, will be used + // instead of the normal BIP 86 tweak when validating the signature. + tapscriptTweak fn.Option[chainhash.Hash] } -// NewMusigPartialSig creates a new musig partial signature. -func NewMusigPartialSig(sig *musig2.PartialSignature, - signerNonce, combinedNonce lnwire.Musig2Nonce, - signerKeys []*btcec.PublicKey) *MusigPartialSig { +// NewMusigPartialSig creates a new MuSig2 partial signature. +func NewMusigPartialSig(sig *musig2.PartialSignature, signerNonce, + combinedNonce lnwire.Musig2Nonce, signerKeys []*btcec.PublicKey, + tapscriptTweak fn.Option[chainhash.Hash]) *MusigPartialSig { return &MusigPartialSig{ - sig: sig, - signerNonce: signerNonce, - combinedNonce: combinedNonce, - signerKeys: signerKeys, + sig: sig, + signerNonce: signerNonce, + combinedNonce: combinedNonce, + signerKeys: signerKeys, + tapscriptTweak: tapscriptTweak, } } // FromWireSig maps a wire partial sig to this internal type that we'll use to // perform signature validation. -func (p *MusigPartialSig) FromWireSig(sig *lnwire.PartialSigWithNonce, -) *MusigPartialSig { +func (p *MusigPartialSig) FromWireSig( + sig *lnwire.PartialSigWithNonce) *MusigPartialSig { p.sig = &musig2.PartialSignature{ S: &sig.Sig, @@ -135,9 +156,15 @@ func (p *MusigPartialSig) Verify(msg []byte, pub *btcec.PublicKey) bool { var m [32]byte copy(m[:], msg) + // If we have a tapscript tweak, then we'll use that as a tweak + // otherwise, we'll fall back to the normal BIP 86 sign tweak. + signOpts := fn.MapOption(tapscriptRootToSignOpt)( + p.tapscriptTweak, + ).UnwrapOr(musig2.WithBip86SignTweak()) + return p.sig.Verify( p.signerNonce, p.combinedNonce, p.signerKeys, pub, m, - musig2.WithSortedKeys(), musig2.WithBip86SignTweak(), + musig2.WithSortedKeys(), signOpts, ) } @@ -160,6 +187,14 @@ func (n *MusigNoncePair) String() string { n.SigningNonce.PubNonce[:]) } +// TapscriptRootToTweak is a function that takes a MuSig2 taproot tweak and +// returns the root hash of the tapscript tree. +func muSig2TweakToRoot(tweak input.MuSig2Tweaks) chainhash.Hash { + var root chainhash.Hash + copy(root[:], tweak.TaprootTweak) + return root +} + // MusigSession abstracts over the details of a logical musig session. A single // session is used for each commitment transactions. The sessions use a JIT // nonce style, wherein part of the session can be created using only the @@ -197,15 +232,20 @@ type MusigSession struct { // commitType tracks if this is the session for the local or remote // commitment. commitType MusigCommitType + + // tapscriptTweak is an optional tweak, that if specified, will be used + // instead of the normal BIP 86 tweak when creating the MuSig2 + // aggregate key and session. + tapscriptTweak fn.Option[input.MuSig2Tweaks] } // NewPartialMusigSession creates a new musig2 session given only the // verification nonce (local nonce), and the other information that has already // been bound to the session. func NewPartialMusigSession(verificationNonce musig2.Nonces, - localKey, remoteKey keychain.KeyDescriptor, - signer input.MuSig2Signer, inputTxOut *wire.TxOut, - commitType MusigCommitType) *MusigSession { + localKey, remoteKey keychain.KeyDescriptor, signer input.MuSig2Signer, + inputTxOut *wire.TxOut, commitType MusigCommitType, + tapscriptTweak fn.Option[input.MuSig2Tweaks]) *MusigSession { signerKeys := []*btcec.PublicKey{localKey.PubKey, remoteKey.PubKey} @@ -214,13 +254,14 @@ func NewPartialMusigSession(verificationNonce musig2.Nonces, } return &MusigSession{ - nonces: nonces, - remoteKey: remoteKey, - localKey: localKey, - inputTxOut: inputTxOut, - signerKeys: signerKeys, - signer: signer, - commitType: commitType, + nonces: nonces, + remoteKey: remoteKey, + localKey: localKey, + inputTxOut: inputTxOut, + signerKeys: signerKeys, + signer: signer, + commitType: commitType, + tapscriptTweak: tapscriptTweak, } } @@ -254,9 +295,9 @@ func (m *MusigSession) FinalizeSession(signingNonce musig2.Nonces) error { remoteNonce = m.nonces.SigningNonce } - tweakDesc := input.MuSig2Tweaks{ + tweakDesc := m.tapscriptTweak.UnwrapOr(input.MuSig2Tweaks{ TaprootBIP0086Tweak: true, - } + }) m.session, err = m.signer.MuSig2CreateSession( input.MuSig2Version100RC2, m.localKey.KeyLocator, m.signerKeys, &tweakDesc, [][musig2.PubNonceSize]byte{remoteNonce.PubNonce}, @@ -351,8 +392,11 @@ func (m *MusigSession) SignCommit(tx *wire.MsgTx) (*MusigPartialSig, error) { return nil, err } + tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak) + return NewMusigPartialSig( sig, m.session.PublicNonce, m.combinedNonce, m.signerKeys, + tapscriptRoot, ), nil } @@ -364,7 +408,7 @@ func (m *MusigSession) Refresh(verificationNonce *musig2.Nonces, return NewPartialMusigSession( *verificationNonce, m.localKey, m.remoteKey, m.signer, - m.inputTxOut, m.commitType, + m.inputTxOut, m.commitType, m.tapscriptTweak, ), nil } @@ -451,9 +495,11 @@ func (m *MusigSession) VerifyCommitSig(commitTx *wire.MsgTx, // When we verify a commitment signature, we always assume that we're // verifying a signature on our local commitment. Therefore, we'll use: // their remote nonce, and also public key. + tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak) partialSig := NewMusigPartialSig( &musig2.PartialSignature{S: &sig.Sig}, m.nonces.SigningNonce.PubNonce, m.combinedNonce, m.signerKeys, + tapscriptRoot, ) // With the partial sig loaded with the proper context, we'll now @@ -537,6 +583,10 @@ type MusigSessionCfg struct { // InputTxOut is the output that we're signing for. This will be the // funding input. InputTxOut *wire.TxOut + + // TapscriptRoot is an optional tweak that can be used to modify the + // MuSig2 public key used in the session. + TapscriptTweak fn.Option[chainhash.Hash] } // MusigPairSession houses the two musig2 sessions needed to do funding and @@ -561,13 +611,14 @@ func NewMusigPairSession(cfg *MusigSessionCfg) *MusigPairSession { // // Both sessions will be created using only the verification nonce for // the local+remote party. + tapscriptTweak := fn.MapOption(TapscriptRootToTweak)(cfg.TapscriptTweak) localSession := NewPartialMusigSession( - cfg.LocalNonce, cfg.LocalKey, cfg.RemoteKey, - cfg.Signer, cfg.InputTxOut, LocalMusigCommit, + cfg.LocalNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer, + cfg.InputTxOut, LocalMusigCommit, tapscriptTweak, ) remoteSession := NewPartialMusigSession( - cfg.RemoteNonce, cfg.LocalKey, cfg.RemoteKey, - cfg.Signer, cfg.InputTxOut, RemoteMusigCommit, + cfg.RemoteNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer, + cfg.InputTxOut, RemoteMusigCommit, tapscriptTweak, ) return &MusigPairSession{ diff --git a/peer/musig_chan_closer.go b/peer/musig_chan_closer.go index 6b05b1e62..6f69a8c5b 100644 --- a/peer/musig_chan_closer.go +++ b/peer/musig_chan_closer.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" @@ -43,10 +44,15 @@ func (m *MusigChanCloser) ProposalClosingOpts() ( } localKey, remoteKey := m.channel.MultiSigKeys() + + tapscriptTweak := fn.MapOption(lnwallet.TapscriptRootToTweak)( + m.channel.State().TapscriptRoot, + ) + m.musigSession = lnwallet.NewPartialMusigSession( *m.remoteNonce, localKey, remoteKey, m.channel.Signer, m.channel.FundingTxOut(), - lnwallet.RemoteMusigCommit, + lnwallet.RemoteMusigCommit, tapscriptTweak, ) err := m.musigSession.FinalizeSession(*m.localNonce) From 8588c9bfd7f521867b253568f33f281a9a2859f5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 10:55:27 -0400 Subject: [PATCH 08/22] lnwallet: add initial unit tests for musig2+tapscript root chans --- lnwallet/channel_test.go | 16 ++++++++++++++++ lnwallet/test_utils.go | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 757a808fe..ac72a6ba6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -386,6 +386,12 @@ func TestSimpleAddSettleWorkflow(t *testing.T) { ) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + flags := channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit + testAddSettleWorkflow(t, true, flags, false) + }) + t.Run("storeFinalHtlcResolutions=true", func(t *testing.T) { testAddSettleWorkflow(t, false, 0, true) }) @@ -828,6 +834,16 @@ func TestForceClose(t *testing.T) { anchorAmt: AnchorSize * 2, }) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + testForceClose(t, &forceCloseTestCase{ + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + expectedCommitWeight: input.TaprootCommitWeight, + anchorAmt: AnchorSize * 2, + }) + }) } type forceCloseTestCase struct { diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 126b76d6c..2c12c83ae 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -348,6 +349,21 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, Packager: channeldb.NewChannelPackager(shortChanID), } + // If the channel type has a tapscript root, then we'll also specify + // one here to apply to both the channels. + if chanType.HasTapscriptRoot() { + var tapscriptRoot chainhash.Hash + _, err := io.ReadFull(rand.Reader, tapscriptRoot[:]) + if err != nil { + return nil, nil, err + } + + someRoot := fn.Some(tapscriptRoot) + + aliceChannelState.TapscriptRoot = someRoot + bobChannelState.TapscriptRoot = someRoot + } + aliceSigner := input.NewMockSigner(aliceKeys, nil) bobSigner := input.NewMockSigner(bobKeys, nil) From 1aae47fd71f5d038c69be6891b677c648f77193a Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 15 Mar 2024 11:43:21 -0400 Subject: [PATCH 09/22] input+lnwallet: update taproot scripts to accept optional aux leaf In this commit, we update all the taproot scripts to also accept an optional aux leaf. This aux leaf can be used to add more redemption paths for advanced channels, or just as an extra commitment space. --- input/script_utils.go | 113 +++++++---- input/size_test.go | 21 +- input/taproot.go | 10 + input/taproot_test.go | 189 +++++++++++++----- lnwallet/channel.go | 2 + lnwallet/commitment.go | 15 +- watchtower/blob/justice_kit.go | 9 +- watchtower/blob/justice_kit_test.go | 6 +- watchtower/lookout/justice_descriptor_test.go | 5 +- .../wtclient/backup_task_internal_test.go | 4 +- watchtower/wtclient/client_test.go | 7 +- 11 files changed, 273 insertions(+), 108 deletions(-) diff --git a/input/script_utils.go b/input/script_utils.go index d3a6dc9d3..91ca55292 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -646,6 +646,13 @@ type HtlcScriptTree struct { // TimeoutTapLeaf is the tapleaf for the timeout path. TimeoutTapLeaf txscript.TapLeaf + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // HTLC script tree with new spend paths, or just as extra commitment + // space. When present, this leaf will always be in the right-most area + // of the tapscript tree. + AuxLeaf AuxTapLeaf + + // htlcType is the type of HTLC script this is. htlcType htlcType } @@ -726,8 +733,8 @@ var _ TapscriptDescriptor = (*HtlcScriptTree)(nil) // senderHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the sender's commitment. func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, hType htlcType, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -744,11 +751,14 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{successTapLeaf, timeoutTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - successTapLeaf, timeoutTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -767,6 +777,7 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -801,7 +812,8 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, revokeKey *btcec.PublicKey, payHash []byte, - whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) { + whoseCommit lntypes.ChannelParty, auxLeaf AuxTapLeaf) (*HtlcScriptTree, + error) { var hType htlcType if whoseCommit.IsLocal() { @@ -814,8 +826,8 @@ func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, // tree that includes the top level output script, as well as the two // tap leaf paths. return senderHtlcTapScriptTree( - senderHtlcKey, receiverHtlcKey, revokeKey, payHash, - hType, + senderHtlcKey, receiverHtlcKey, revokeKey, payHash, hType, + auxLeaf, ) } @@ -1285,8 +1297,8 @@ func ReceiverHtlcTapLeafSuccess(receiverHtlcKey *btcec.PublicKey, // receiverHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the receiver's commitment. func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - cltvExpiry uint32, hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, cltvExpiry uint32, + hType htlcType, auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -1303,11 +1315,14 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{timeoutTapLeaf, successTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - timeoutTapLeaf, successTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -1326,6 +1341,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -1361,7 +1377,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, payHash []byte, whoseCommit lntypes.ChannelParty, -) (*HtlcScriptTree, error) { + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { var hType htlcType if whoseCommit.IsLocal() { @@ -1375,7 +1391,7 @@ func ReceiverHTLCScriptTaproot(cltvExpiry uint32, // tap leaf paths. return receiverHtlcTapScriptTree( senderHtlcKey, receiverHtlcKey, revocationKey, payHash, - cltvExpiry, hType, + cltvExpiry, hType, auxLeaf, ) } @@ -1604,9 +1620,9 @@ func TaprootSecondLevelTapLeaf(delayKey *btcec.PublicKey, } // SecondLevelHtlcTapscriptTree construct the indexed tapscript tree needed to -// generate the taptweak to create the final output and also control block. -func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, - csvDelay uint32) (*txscript.IndexedTapScriptTree, error) { +// generate the tap tweak to create the final output and also control block. +func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, csvDelay uint32, + auxLeaf AuxTapLeaf) (*txscript.IndexedTapScriptTree, error) { // First grab the second level leaf script we need to create the top // level output. @@ -1615,9 +1631,14 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, return nil, err } + tapLeaves := []txscript.TapLeaf{secondLevelTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // Now that we have the sole second level script, we can create the // tapscript tree that commits to both the leaves. - return txscript.AssembleTaprootScriptTree(secondLevelTapLeaf), nil + return txscript.AssembleTaprootScriptTree(tapLeaves...), nil } // TaprootSecondLevelHtlcScript is the uniform script that's used as the output @@ -1637,12 +1658,12 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, // // The keyspend path require knowledge of the top level revocation private key. func TaprootSecondLevelHtlcScript(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*btcec.PublicKey, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1667,17 +1688,21 @@ type SecondLevelScriptTree struct { // SuccessTapLeaf is the tapleaf for the redemption path. SuccessTapLeaf txscript.TapLeaf + + // AuxLeaf is an optional leaf that can be used to extend the script + // tree. + AuxLeaf AuxTapLeaf } // TaprootSecondLevelScriptTree constructs the tapscript tree used to spend the // second level HTLC output. func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*SecondLevelScriptTree, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*SecondLevelScriptTree, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1698,6 +1723,7 @@ func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, InternalKey: revokeKey, }, SuccessTapLeaf: tapScriptTree.LeafMerkleProofs[0].TapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2079,6 +2105,12 @@ type CommitScriptTree struct { // RevocationLeaf is the leaf used to spend the output with the // revocation key signature. RevocationLeaf txscript.TapLeaf + + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // commitment script tree with new spend paths, or just as extra + // commitment space. When present, this leaf will always be in the + // left-most or right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf } // A compile time check to ensure CommitScriptTree implements the @@ -2143,8 +2175,9 @@ func (c *CommitScriptTree) Tree() ScriptTree { // NewLocalCommitScriptTree returns a new CommitScript tree that can be used to // create and spend the commitment output for the local party. -func NewLocalCommitScriptTree(csvTimeout uint32, - selfKey, revokeKey *btcec.PublicKey) (*CommitScriptTree, error) { +func NewLocalCommitScriptTree(csvTimeout uint32, selfKey, + revokeKey *btcec.PublicKey, auxLeaf AuxTapLeaf) (*CommitScriptTree, + error) { // First, we'll need to construct the tapLeaf that'll be our delay CSV // clause. @@ -2164,9 +2197,13 @@ func NewLocalCommitScriptTree(csvTimeout uint32, // the two leaves, and then obtain a root from that. delayTapLeaf := txscript.NewBaseTapLeaf(delayScript) revokeTapLeaf := txscript.NewBaseTapLeaf(revokeScript) - tapScriptTree := txscript.AssembleTaprootScriptTree( - delayTapLeaf, revokeTapLeaf, - ) + + tapLeaves := []txscript.TapLeaf{delayTapLeaf, revokeTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2184,6 +2221,7 @@ func NewLocalCommitScriptTree(csvTimeout uint32, }, SettleLeaf: delayTapLeaf, RevocationLeaf: revokeTapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2253,7 +2291,7 @@ func TaprootCommitScriptToSelf(csvTimeout uint32, selfKey, revokeKey *btcec.PublicKey) (*btcec.PublicKey, error) { commitScriptTree, err := NewLocalCommitScriptTree( - csvTimeout, selfKey, revokeKey, + csvTimeout, selfKey, revokeKey, NoneTapLeaf(), ) if err != nil { return nil, err @@ -2579,7 +2617,7 @@ func CommitScriptToRemoteConfirmed(key *btcec.PublicKey) ([]byte, error) { // NewRemoteCommitScriptTree constructs a new script tree for the remote party // to sweep their funds after a hard coded 1 block delay. func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, -) (*CommitScriptTree, error) { + auxLeaf AuxTapLeaf) (*CommitScriptTree, error) { // First, construct the remote party's tapscript they'll use to sweep // their outputs. @@ -2595,10 +2633,16 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, return nil, err } + tapLeaf := txscript.NewBaseTapLeaf(remoteScript) + + tapLeaves := []txscript.TapLeaf{tapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With this script constructed, we'll map that into a tapLeaf, then // make a new tapscript root from that. - tapLeaf := txscript.NewBaseTapLeaf(remoteScript) - tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaf) + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2615,6 +2659,7 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, InternalKey: &TaprootNUMSKey, }, SettleLeaf: tapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2631,9 +2676,9 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, // OP_CHECKSIG // 1 OP_CHECKSEQUENCEVERIFY OP_DROP func TaprootCommitScriptToRemote(remoteKey *btcec.PublicKey, -) (*btcec.PublicKey, error) { + auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { - commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey) + commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey, auxLeaf) if err != nil { return nil, err } diff --git a/input/size_test.go b/input/size_test.go index daa7053cc..21e2c06ae 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -853,7 +853,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -887,7 +887,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -921,7 +921,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} //nolint:lll commitScriptTree, err := input.NewRemoteCommitScriptTree( - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -988,6 +988,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1027,6 +1028,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1075,6 +1077,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1116,7 +1119,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1158,7 +1161,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1205,6 +1208,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1265,6 +1269,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], lntypes.Remote, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1310,7 +1315,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1396,7 +1401,7 @@ func genTimeoutTx(t *testing.T, if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( testPubkey, testPubkey, testPubkey, testHash160, - lntypes.Remote, + lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1465,7 +1470,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, lntypes.Remote, + testHash160, lntypes.Remote, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/input/taproot.go b/input/taproot.go index 34cdb974d..0fcd1df4e 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/lightningnetwork/lnd/fn" ) const ( @@ -21,6 +22,15 @@ const ( PubKeyFormatCompressedOdd byte = 0x03 ) +// AuxTapLeaf is a type alias for an optional tapscript leaf that may be added +// to the tapscript tree of HTLC and commitment outputs. +type AuxTapLeaf = fn.Option[txscript.TapLeaf] + +// NoneTapLeaf returns an empty optional tapscript leaf. +func NoneTapLeaf() AuxTapLeaf { + return fn.None[txscript.TapLeaf]() +} + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 diff --git a/input/taproot_test.go b/input/taproot_test.go index 434be2dfd..a1259be19 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -1,13 +1,16 @@ package input import ( + "bytes" "crypto/rand" + "fmt" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -31,7 +34,9 @@ type testSenderHtlcScriptTree struct { htlcAmt int64 } -func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { +func newTestSenderHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSenderHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -48,7 +53,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], lntypes.Remote, + payHash[:], lntypes.Remote, auxLeaf, ) require.NoError(t, err) @@ -207,13 +212,9 @@ func htlcSenderTimeoutWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootSenderHtlcSpend tests that all the positive and negative paths -// for the sender HTLC tapscript tree work as expected. -func TestTaprootSenderHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootSenderHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // First, create a new test script tree. - htlcScriptTree := newTestSenderHtlcScriptTree(t) + htlcScriptTree := newTestSenderHtlcScriptTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -432,6 +433,26 @@ func TestTaprootSenderHtlcSpend(t *testing.T) { } } +// TestTaprootSenderHtlcSpend tests that all the positive and negative paths +// for the sender HTLC tapscript tree work as expected. +func TestTaprootSenderHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootSenderHtlcSpend(t, auxLeaf) + }) + } +} + type testReceiverHtlcScriptTree struct { preImage lntypes.Preimage @@ -452,7 +473,9 @@ type testReceiverHtlcScriptTree struct { lockTime int32 } -func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { +func newTestReceiverHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testReceiverHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -471,7 +494,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], lntypes.Remote, + revokeKey.PubKey(), payHash[:], lntypes.Remote, auxLeaf, ) require.NoError(t, err) @@ -629,15 +652,11 @@ func htlcReceiverSuccessWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an -// accepted HTLC (on the commitment transaction) of the receiver work properly. -func TestTaprootReceiverHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootReceiverHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // We'll start by creating the HTLC script tree (contains all 3 valid // spend paths), and also a mock spend transaction that we'll be // signing below. - htlcScriptTree := newTestReceiverHtlcScriptTree(t) + htlcScriptTree := newTestReceiverHtlcScriptTree(t, auxLeaf) // TODO(roasbeef): issue with revoke key??? ctrl block even/odd @@ -891,6 +910,28 @@ func TestTaprootReceiverHtlcSpend(t *testing.T) { } } +// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an +// accepted HTLC (on the commitment transaction) of the receiver work properly. +func TestTaprootReceiverHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootReceiverHtlcSpend(t, auxLeaf) + }) + } +} + type testCommitScriptTree struct { csvDelay uint32 @@ -905,7 +946,9 @@ type testCommitScriptTree struct { *CommitScriptTree } -func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { +func newTestCommitScriptTree(local bool, + auxLeaf AuxTapLeaf) (*testCommitScriptTree, error) { + selfKey, err := btcec.NewPrivateKey() if err != nil { return nil, err @@ -925,10 +968,11 @@ func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { if local { commitScriptTree, err = NewLocalCommitScriptTree( csvDelay, selfKey.PubKey(), revokeKey.PubKey(), + auxLeaf, ) } else { commitScriptTree, err = NewRemoteCommitScriptTree( - selfKey.PubKey(), + selfKey.PubKey(), auxLeaf, ) } if err != nil { @@ -1020,12 +1064,8 @@ func localCommitRevokeWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming -// one's output after a force close behaves as expected. -func TestTaprootCommitScriptToSelf(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(true) +func testTaprootCommitScriptToSelf(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(true, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1187,6 +1227,26 @@ func TestTaprootCommitScriptToSelf(t *testing.T) { } } +// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming +// one's output after a force close behaves as expected. +func TestTaprootCommitScriptToSelf(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootCommitScriptToSelf(t, auxLeaf) + }) + } +} + func remoteCommitSweepWitGen(sigHash txscript.SigHashType, commitScriptTree *testCommitScriptTree) witnessGen { @@ -1220,12 +1280,8 @@ func remoteCommitSweepWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptRemote tests that the remote party can properly sweep -// their output after force close. -func TestTaprootCommitScriptRemote(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(false) +func testTaprootCommitScriptRemote(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(false, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1364,6 +1420,26 @@ func TestTaprootCommitScriptRemote(t *testing.T) { } } +// TestTaprootCommitScriptRemote tests that the remote party can properly sweep +// their output after force close. +func TestTaprootCommitScriptRemote(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootCommitScriptRemote(t, auxLeaf) + }) + } +} + type testAnchorScriptTree struct { sweepKey *btcec.PrivateKey @@ -1599,25 +1675,21 @@ type testSecondLevelHtlcTree struct { tapScriptRoot []byte } -func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { +func newTestSecondLevelHtlcTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSecondLevelHtlcTree { + delayKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) revokeKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) const csvDelay = 6 scriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey.PubKey(), csvDelay, + delayKey.PubKey(), csvDelay, auxLeaf, ) - if err != nil { - return nil, err - } + require.NoError(t, err) tapScriptRoot := scriptTree.RootNode.TapHash() @@ -1626,9 +1698,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { ) pkScript, err := PayToTaprootScript(htlcKey) - if err != nil { - return nil, err - } + require.NoError(t, err) const amt = 100 @@ -1643,7 +1713,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { amt: amt, scriptTree: scriptTree, tapScriptRoot: tapScriptRoot[:], - }, nil + } } func secondLevelHtlcSuccessWitGen(sigHash txscript.SigHashType, @@ -1713,13 +1783,8 @@ func secondLevelHtlcRevokeWitnessgen(sigHash txscript.SigHashType, } } -// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly -// spend the second level HTLC script to resolve HTLCs. -func TestTaprootSecondLevelHtlcScript(t *testing.T) { - t.Parallel() - - htlcScriptTree, err := newTestSecondLevelHtlcTree() - require.NoError(t, err) +func testTaprootSecondLevelHtlcScript(t *testing.T, auxLeaf AuxTapLeaf) { + htlcScriptTree := newTestSecondLevelHtlcTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -1879,3 +1944,23 @@ func TestTaprootSecondLevelHtlcScript(t *testing.T) { }) } } + +// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly +// spend the second level HTLC script to resolve HTLCs. +func TestTaprootSecondLevelHtlcScript(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some(txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + )) + } + + testTaprootSecondLevelHtlcScript(t, auxLeaf) + }) + } +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b807bbee3..4e934b5d6 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -6669,6 +6669,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -6915,6 +6916,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index f32408814..5af48eefa 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -225,9 +225,8 @@ func (w *WitnessScriptDesc) WitnessScriptForPath( // party learns of the preimage to the revocation hash, then they can claim all // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, - selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, -) ( - input.ScriptDescriptor, error) { + selfKey, revokeKey *btcec.PublicKey, csvDelay, + leaseExpiry uint32) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -237,7 +236,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, + csvDelay, selfKey, revokeKey, input.NoneTapLeaf(), ) // If we are the initiator of a leased channel, then we have an @@ -321,7 +320,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, + remoteKey, input.NoneTapLeaf(), ) if err != nil { return nil, 0, err @@ -427,7 +426,7 @@ func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, + revocationKey, delayKey, csvDelay, input.NoneTapLeaf(), ) // If we are the initiator of a leased channel, then we have an @@ -1095,6 +1094,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], whoseCommit, + input.NoneTapLeaf(), ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1104,6 +1104,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], whoseCommit, + input.NoneTapLeaf(), ) // We're sending an HTLC which is being added to our commitment @@ -1113,6 +1114,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], whoseCommit, + input.NoneTapLeaf(), ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1122,6 +1124,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], whoseCommit, + input.NoneTapLeaf(), ) } diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 8b6c20194..7780239f0 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -307,9 +308,11 @@ func newTaprootJusticeKit(sweepScript []byte, keyRing := breachInfo.KeyRing + // TODO(roasbeef): aux leaf tower updates needed + tree, err := input.NewLocalCommitScriptTree( breachInfo.RemoteDelay, keyRing.ToLocalKey, - keyRing.RevocationKey, + keyRing.RevocationKey, fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -416,7 +419,9 @@ func (t *taprootJusticeKit) ToRemoteOutputSpendInfo() (*txscript.PkScript, return nil, nil, 0, err } - scriptTree, err := input.NewRemoteCommitScriptTree(toRemotePk) + scriptTree, err := input.NewRemoteCommitScriptTree( + toRemotePk, fn.None[txscript.TapLeaf](), + ) if err != nil { return nil, nil, 0, err } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index fd12993a0..a1d6ec9f2 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -304,7 +305,9 @@ func TestJusticeKitRemoteWitnessConstruction(t *testing.T) { name: "taproot commitment", blobType: TypeAltruistTaprootCommit, expWitnessScript: func(pk *btcec.PublicKey) []byte { - tree, _ := input.NewRemoteCommitScriptTree(pk) + tree, _ := input.NewRemoteCommitScriptTree( + pk, fn.None[txscript.TapLeaf](), + ) return tree.SettleLeaf.Script }, @@ -461,6 +464,7 @@ func TestJusticeKitToLocalWitnessConstruction(t *testing.T) { script, _ := input.NewLocalCommitScriptTree( csvDelay, delay, rev, + fn.None[txscript.TapLeaf](), ) return script.RevocationLeaf.Script diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index a07b440ad..5045b4a0f 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -123,7 +124,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { if isTaprootChannel { toLocalCommitTree, err = input.NewLocalCommitScriptTree( - csvDelay, toLocalPK, revPK, + csvDelay, toLocalPK, revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) @@ -174,7 +175,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { toRemoteSequence = 1 commitScriptTree, err := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 695c4f9ec..7894631b8 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -136,6 +137,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewLocalCommitScriptTree( csvDelay, toLocalPK, revPK, + fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( @@ -189,7 +191,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 38d9acd9f..f3a4d5bf4 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -230,12 +231,14 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { // Construct the to-local witness script. toLocalScriptTree, err := input.NewLocalCommitScriptTree( - c.csvDelay, c.toLocalPK, c.revPK, + c.csvDelay, c.toLocalPK, c.revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err, "unable to create to-local script") // Construct the to-remote witness script. - toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK) + toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( + c.toRemotePK, fn.None[txscript.TapLeaf](), + ) require.NoError(t, err, "unable to create to-remote script") // Compute the to-local witness script hash. From 5c4428c3cf96f46212c4b853fe0d4f8de320a704 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sun, 17 Mar 2024 15:51:50 -0400 Subject: [PATCH 10/22] channeldb: new custom blob nested TLV In this commit, for each channel, we'll now start to store an optional custom blob. This can be used to store extra information for custom channels in an opauqe manner. --- channeldb/channel.go | 153 +++++++++++++++++++++++++++++++++++++- channeldb/channel_test.go | 11 ++- 2 files changed, 159 insertions(+), 5 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index f29a9ec19..9fed8e8ac 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -256,6 +256,10 @@ type openChannelTlvData struct { // tapscriptRoot is the optional Tapscript root the channel funding // output commits to. tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] + + // customBlob is an optional TLV encoded blob of data representing + // custom channel funding information. + customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] } // encode serializes the openChannelTlvData to the given io.Writer. @@ -274,6 +278,9 @@ func (c *openChannelTlvData) encode(w io.Writer) error { tlvRecords = append(tlvRecords, root.Record()) }, ) + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) // Create the tlv stream. tlvStream, err := tlv.NewStream(tlvRecords...) @@ -288,6 +295,7 @@ func (c *openChannelTlvData) encode(w io.Writer) error { func (c *openChannelTlvData) decode(r io.Reader) error { memo := c.memo.Zero() tapscriptRoot := c.tapscriptRoot.Zero() + blob := c.customBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -297,6 +305,7 @@ func (c *openChannelTlvData) decode(r io.Reader) error { c.realScid.Record(), memo.Record(), tapscriptRoot.Record(), + blob.Record(), ) if err != nil { return err @@ -313,6 +322,9 @@ func (c *openChannelTlvData) decode(r io.Reader) error { if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) } + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } return nil } @@ -576,6 +588,53 @@ type ChannelConfig struct { HtlcBasePoint keychain.KeyDescriptor } +// commitTlvData stores all the optional data that may be stored as a TLV stream +// at the _end_ of the normal serialized commit on disk. +type commitTlvData struct { + // customBlob is a custom blob that may store extra data for custom + // channels. + customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] +} + +// encode encodes the aux data into the passed io.Writer. +func (c *commitTlvData) encode(w io.Writer) error { + var tlvRecords []tlv.Record + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode attempts to decode the aux data from the passed io.Reader. +func (c *commitTlvData) decode(r io.Reader) error { + blob := c.customBlob.Zero() + + tlvStream, err := tlv.NewStream( + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + // ChannelCommitment is a snapshot of the commitment state at a particular // point in the commitment chain. With each state transition, a snapshot of the // current state along with all non-settled HTLCs are recorded. These snapshots @@ -642,6 +701,11 @@ type ChannelCommitment struct { // able by us. CommitTx *wire.MsgTx + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This may track some custom + // specific state for this given commitment. + CustomBlob fn.Option[tlv.Blob] + // CommitSig is one half of the signature required to fully complete // the script for the commitment transaction above. This is the // signature signed by the remote party for our version of the @@ -651,9 +715,26 @@ type ChannelCommitment struct { // Htlcs is the set of HTLC's that are pending at this particular // commitment height. Htlcs []HTLC +} - // TODO(roasbeef): pending commit pointer? - // * lets just walk through +// amendTlvData updates the channel with the given auxiliary TLV data. +func (c *ChannelCommitment) amendTlvData(auxData commitTlvData) { + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + c.CustomBlob = fn.Some(blob) + }) +} + +// extractTlvData creates a new commitTlvData from the given commitment. +func (c *ChannelCommitment) extractTlvData() commitTlvData { + var auxData commitTlvData + + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](blob), + ) + }) + + return auxData } // ChannelStatus is a bit vector used to indicate whether an OpenChannel is in @@ -951,6 +1032,12 @@ type OpenChannel struct { // funding output. TapscriptRoot fn.Option[chainhash.Hash] + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -1126,6 +1213,9 @@ func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { c.TapscriptRoot = fn.Some[chainhash.Hash](h) }) + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + c.CustomBlob = fn.Some(blob) + }) } // extractTlvData creates a new openChannelTlvData from the given channel. @@ -1155,6 +1245,11 @@ func (c *OpenChannel) extractTlvData() openChannelTlvData { tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), ) }) + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType7](blob), + ) + }) return auxData } @@ -2824,6 +2919,14 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl } } + // We'll also encode the commit aux data stream here. We do this here + // rather than above (at the call to serializeChanCommit), to ensure + // backwards compat for reads to existing non-custom channels. + auxData := diff.Commitment.extractTlvData() + if err := auxData.encode(w); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return nil } @@ -2884,6 +2987,17 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } } + // As a final step, we'll read out any aux commit data that we have at + // the end of this byte stream. We do this here to ensure backward + // compatibility, as otherwise we risk erroneously reading into the + // wrong field. + var auxData commitTlvData + if err := auxData.decode(r); err != nil { + return nil, fmt.Errorf("unable to decode aux data: %w", err) + } + + d.Commitment.amendTlvData(auxData) + return &d, nil } @@ -3862,6 +3976,13 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { }, } + localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + + snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) + }) + // Copy over the current set of HTLCs to ensure the caller can't mutate // our internal state. snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) @@ -4253,6 +4374,12 @@ func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, return err } + // Before we write to disk, we'll also write our aux data as well. + auxData := c.extractTlvData() + if err := auxData.encode(&b); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return chanBucket.Put(commitKey, b.Bytes()) } @@ -4398,7 +4525,9 @@ func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { return c, nil } -func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment, error) { +func fetchChanCommitment(chanBucket kvdb.RBucket, + local bool) (ChannelCommitment, error) { + var commitKey []byte if local { commitKey = append(chanCommitmentKey, byte(0x00)) @@ -4412,7 +4541,23 @@ func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment } r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) + chanCommit, err := deserializeChanCommit(r) + if err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan commit: %w", err) + } + + // We'll also check to see if we have any aux data stored as the end of + // the stream. + var auxData commitTlvData + if err := auxData.decode(r); err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan aux data: %w", err) + } + + chanCommit.amendTlvData(auxData) + + return chanCommit, nil } func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 84aae9622..13d77276c 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -351,6 +351,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{1, 2, 3}), }, RemoteCommitment: ChannelCommitment{ CommitHeight: 0, @@ -360,6 +361,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{4, 5, 6}), }, NumConfsRequired: 4, RemoteCurrentRevocation: privKey.PubKey(), @@ -374,6 +376,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { InitialRemoteBalance: lnwire.MilliSatoshi(3000), Memo: []byte("test"), TapscriptRoot: fn.Some(tapscriptRoot), + CustomBlob: fn.Some([]byte{1, 2, 3}), } } @@ -663,6 +666,7 @@ func TestChannelStateTransition(t *testing.T) { CommitTx: newTx, CommitSig: newSig, Htlcs: htlcs, + CustomBlob: fn.Some([]byte{4, 5, 6}), } // First update the local node's broadcastable state and also add a @@ -700,9 +704,14 @@ func TestChannelStateTransition(t *testing.T) { // have been updated. updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + + assertCommitmentEqual( + t, &commitment, &updatedChannel[0].LocalCommitment, + ) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() require.NoError(t, err, "unable to read commitment height from disk") + if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", numDiskUpdates, commitment.CommitHeight) From 12352d9e6eb4d8e76ba1a7f63f05a1a5559c8e17 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 20:00:29 -0700 Subject: [PATCH 11/22] lnwallet: export the HtlcView struct We'll need this later on to ensure we can always interact with the new aux blobs at all stages of commitment transaction construction. --- lnwallet/channel.go | 139 ++++++++++++++++++++++++--------------- lnwallet/channel_test.go | 22 +++---- lnwallet/commitment.go | 10 +-- 3 files changed, 102 insertions(+), 69 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 4e934b5d6..073f57d3a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2459,18 +2459,29 @@ func HtlcIsDust(chanType channeldb.ChannelType, return (htlcAmt - htlcFee) < dustLimit } -// htlcView represents the "active" HTLCs at a particular point within the +// HtlcView represents the "active" HTLCs at a particular point within the // history of the HTLC update log. -type htlcView struct { - ourUpdates []*PaymentDescriptor - theirUpdates []*PaymentDescriptor - feePerKw chainfee.SatPerKWeight +type HtlcView struct { + // NextHeight is the height of the commitment transaction that will be + // created using this view. + NextHeight uint64 + + // OurUpdates are our outgoing HTLCs. + OurUpdates []*PaymentDescriptor + + // TheirUpdates are their incoming HTLCs. + TheirUpdates []*PaymentDescriptor + + // FeePerKw is the fee rate in sat/kw of the commitment transaction. + FeePerKw chainfee.SatPerKWeight } // fetchHTLCView returns all the candidate HTLC updates which should be // considered for inclusion within a commitment based on the passed HTLC log // indexes. -func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView { +func (lc *LightningChannel) fetchHTLCView(theirLogIndex, + ourLogIndex uint64) *HtlcView { + var ourHTLCs []*PaymentDescriptor for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { htlc := e.Value @@ -2495,9 +2506,9 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht } } - return &htlcView{ - ourUpdates: ourHTLCs, - theirUpdates: theirHTLCs, + return &HtlcView{ + OurUpdates: ourHTLCs, + TheirUpdates: theirHTLCs, } } @@ -2534,7 +2545,10 @@ func (lc *LightningChannel) fetchCommitmentView( if err != nil { return nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw + + htlcView.NextHeight = nextHeight + filteredHTLCView.NextHeight = nextHeight // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( @@ -2594,12 +2608,16 @@ func (lc *LightningChannel) fetchCommitmentView( // In order to ensure _none_ of the HTLC's associated with this new // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. - c.outgoingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.ourUpdates)) - for i, htlc := range filteredHTLCView.ourUpdates { + c.outgoingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.OurUpdates), + ) + for i, htlc := range filteredHTLCView.OurUpdates { c.outgoingHTLCs[i] = *htlc } - c.incomingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.theirUpdates)) - for i, htlc := range filteredHTLCView.theirUpdates { + c.incomingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.TheirUpdates), + ) + for i, htlc := range filteredHTLCView.TheirUpdates { c.incomingHTLCs[i] = *htlc } @@ -2634,16 +2652,17 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // once for each height, and only in concert with signing a new commitment. // TODO(halseth): return htlcs to mutate instead of mutating inside // method. -func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, +func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool, -) (*htlcView, error) { + whoseCommitChain lntypes.ChannelParty, mutateState bool) (*HtlcView, + error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will // be updated. - newView := &htlcView{ - feePerKw: view.feePerKw, + newView := &HtlcView{ + FeePerKw: view.FeePerKw, + NextHeight: nextHeight, } // We use two maps, one for the local log and one for the remote log to @@ -2656,7 +2675,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets and mutating the current chain state (crediting balances, // etc) to reflect the settle/timeout entry encountered. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2688,10 +2707,13 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipThem[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, whoseCommitChain, true, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, true, mutateState, + ) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2725,32 +2747,41 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipUs[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, whoseCommitChain, false, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, false, mutateState, + ) } // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { isAdd := entry.EntryType == Add if _, ok := skipUs[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, false, mutateState) - newView.ourUpdates = append(newView.ourUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, false, mutateState, + ) + + newView.OurUpdates = append(newView.OurUpdates, entry) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, true, mutateState) - newView.theirUpdates = append(newView.theirUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + whoseCommitChain, true, mutateState, + ) + + newView.TheirUpdates = append(newView.TheirUpdates, entry) } return newView, nil @@ -2901,8 +2932,8 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool, view *htlcView, -) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, + view *HtlcView) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the @@ -2923,7 +2954,7 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // If the update wasn't already locked in, update the current fee rate // to reflect this update. - view.feePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) + view.FeePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) if mutateState { *addHeight = nextHeight @@ -3499,10 +3530,10 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.ourUpdates = append(view.ourUpdates, predictOurAdd) + view.OurUpdates = append(view.OurUpdates, predictOurAdd) } if predictTheirAdd != nil { - view.theirUpdates = append(view.theirUpdates, predictTheirAdd) + view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3513,7 +3544,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, return err } - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw // Ensure that the fee being applied is enough to be relayed across the // network in a reasonable time frame. @@ -3657,7 +3688,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.theirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -3666,7 +3697,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.ourUpdates, &lc.channelState.LocalChanCfg, + filteredView.OurUpdates, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -4266,7 +4297,7 @@ func (lc *LightningChannel) ProcessChanSyncMsg( return updates, openedCircuits, closedCircuits, nil } -// computeView takes the given htlcView, and calculates the balances, filtered +// computeView takes the given HtlcView, and calculates the balances, filtered // view (settling unsettled HTLCs), commitment weight and feePerKw, after // applying the HTLCs to the latest commitment. The returned balances are the // balances *before* subtracting the commitment fee from the initiator's @@ -4275,10 +4306,10 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, +func (lc *LightningChannel) computeView(view *HtlcView, whoseCommitChain lntypes.ChannelParty, updateState bool, dryRunFee fn.Option[chainfee.SatPerKWeight]) (lnwire.MilliSatoshi, - lnwire.MilliSatoshi, lntypes.WeightUnit, *htlcView, error) { + lnwire.MilliSatoshi, lntypes.WeightUnit, *HtlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit @@ -4309,7 +4340,7 @@ func (lc *LightningChannel) computeView(view *htlcView, // Initiate feePerKw to the last committed fee for this chain as we'll // need this to determine which HTLCs are dust, and also the final fee // rate. - view.feePerKw = commitChain.tip().feePerKw + view.FeePerKw = commitChain.tip().feePerKw // We evaluate the view at this stage, meaning settled and failed HTLCs // will remove their corresponding added HTLCs. The resulting filtered @@ -4317,12 +4348,14 @@ func (lc *LightningChannel) computeView(view *htlcView, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, whoseCommitChain, updateState) + filteredHTLCView, err := lc.evaluateHTLCView( + view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, + updateState, + ) if err != nil { return 0, 0, 0, nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw // Here we override the view's fee-rate if a dry-run fee-rate was // passed in. @@ -4346,7 +4379,7 @@ func (lc *LightningChannel) computeView(view *htlcView, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight lntypes.WeightUnit - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4357,7 +4390,7 @@ func (lc *LightningChannel) computeView(view *htlcView, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -7823,13 +7856,13 @@ func (lc *LightningChannel) availableBalance( } // availableCommitmentBalance attempts to calculate the balance we have -// available for HTLCs on the local/remote commitment given the htlcView. To +// available for HTLCs on the local/remote commitment given the HtlcView. To // account for sending HTLCs of different sizes, it will report the balance // available for sending non-dust HTLCs, which will be manifested on the // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, +func (lc *LightningChannel) availableCommitmentBalance(view *HtlcView, whoseCommitChain lntypes.ChannelParty, buffer BufferType) ( lnwire.MilliSatoshi, lntypes.WeightUnit) { @@ -7859,7 +7892,7 @@ func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, // Calculate the commitment fee in the case where we would add another // HTLC to the commitment, as only the balance remaining after this fee // has been paid is actually available for sending. - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw additionalHtlcFee := lnwire.NewMSatFromSatoshis( feePerKw.FeeForWeight(input.HTLCWeight), ) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index ac72a6ba6..e0e402484 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8541,10 +8541,10 @@ func TestEvaluateView(t *testing.T) { } } - view := &htlcView{ - ourUpdates: test.ourHtlcs, - theirUpdates: test.theirHtlcs, - feePerKw: feePerKw, + view := &HtlcView{ + OurUpdates: test.ourHtlcs, + TheirUpdates: test.theirHtlcs, + FeePerKw: feePerKw, } var ( @@ -8565,17 +8565,17 @@ func TestEvaluateView(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if result.feePerKw != test.expectedFee { + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", - test.expectedFee, result.feePerKw) + test.expectedFee, result.FeePerKw) } checkExpectedHtlcs( - t, result.ourUpdates, test.ourExpectedHtlcs, + t, result.OurUpdates, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.theirUpdates, test.theirExpectedHtlcs, + t, result.TheirUpdates, test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { @@ -8798,15 +8798,15 @@ func TestProcessFeeUpdate(t *testing.T) { EntryType: FeeUpdate, } - view := &htlcView{ - feePerKw: chainfee.SatPerKWeight(feePerKw), + view := &HtlcView{ + FeePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( update, nextHeight, test.whoseCommitChain, test.mutate, view, ) - if view.feePerKw != test.expectedFee { + if view.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, feePerKw) } diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 5af48eefa..2a127e25c 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -685,7 +685,7 @@ type unsignedCommitmentTx struct { func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty, feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *htlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit @@ -694,7 +694,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -705,7 +705,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -789,7 +789,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // commitment outputs and should correspond to zero values for the // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -807,7 +807,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } cltvs = append(cltvs, htlc.Timeout) // nolint:makezero } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, From d25f881433aa5cca512a31a81b98295325ec1c10 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 20:07:15 -0700 Subject: [PATCH 12/22] lnwallet: add custom tlv blob to internal commitment struct In this commit, we also add the custom TLV blob to the internal commitment struct that we use within the in-memory commitment linked list. This'll be useful to ensure that we're tracking the current blob for our in memory commitment for when we need to write it to disk. --- lnwallet/channel.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 073f57d3a..60026f99c 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -33,6 +33,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -333,6 +334,10 @@ type commitment struct { // on this commitment transaction. incomingHTLCs []PaymentDescriptor + // customBlob stores opaque bytes that may be used by custom channels + // to store extra data for a given commitment state. + customBlob fn.Option[tlv.Blob] + // [outgoing|incoming]HTLCIndex is an index that maps an output index // on the commitment transaction to the payment descriptor that // represents the HTLC output. @@ -506,6 +511,7 @@ func (c *commitment) toDiskCommit( CommitTx: c.txn, CommitSig: c.sig, Htlcs: make([]channeldb.HTLC, 0, numHtlcs), + CustomBlob: c.customBlob, } for _, htlc := range c.outgoingHTLCs { @@ -753,6 +759,7 @@ func (lc *LightningChannel) diskCommitToMemCommit( feePerKw: chainfee.SatPerKWeight(diskCommit.FeePerKw), incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, + customBlob: diskCommit.CustomBlob, } if whoseCommit.IsLocal() { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit From c36cd9298f7b80476d200375b0b3226fae7847c6 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 19:52:21 -0700 Subject: [PATCH 13/22] input: add some utility type definitions for aux leaves In this commit, we add some useful type definitions for the aux leaf. --- input/taproot.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/input/taproot.go b/input/taproot.go index 0fcd1df4e..2ca6e9723 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -31,6 +31,24 @@ func NoneTapLeaf() AuxTapLeaf { return fn.None[txscript.TapLeaf]() } +// HtlcIndex represents the monotonically increasing counter that is used to +// identify HTLCs created by a peer. +type HtlcIndex = uint64 + +// HtlcAuxLeaf is a type that represents an auxiliary leaf for an HTLC output. +// An HTLC may have up to two aux leaves: one for the output on the commitment +// transaction, and one for the second level HTLC. +type HtlcAuxLeaf struct { + AuxTapLeaf + + // SecondLevelLeaf is the auxiliary leaf for the second level HTLC + // success or timeout transaction. + SecondLevelLeaf AuxTapLeaf +} + +// HtlcAuxLeaves is a type alias for a map of optional tapscript leaves. +type HtlcAuxLeaves = map[HtlcIndex]HtlcAuxLeaf + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 From d95c1f93f340ffd6435c07231d96d38256cbb129 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sun, 17 Mar 2024 16:53:38 -0400 Subject: [PATCH 14/22] lnwallet+channeldb: add new AuxLeafStore for dynamic aux leaves In this commit, we add a new AuxLeafStore which can be used to dynamically fetch the latest aux leaves for a given state. This is useful for custom channel types that will store some extra information in the form of a custom blob, then will use that information to derive the new leaf tapscript leaves that may be attached to reach state. --- contractcourt/chain_watcher.go | 3 +- lnwallet/aux_leaf_store.go | 239 +++++++++++++++++++++++++++++++++ lnwallet/channel.go | 66 +++++++-- lnwallet/commitment.go | 138 ++++++++++++++----- lnwallet/transactions_test.go | 2 + lnwallet/wallet.go | 27 +++- 6 files changed, 430 insertions(+), 45 deletions(-) create mode 100644 lnwallet/aux_leaf_store.go diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 8b2f75816..66a61fe3b 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -431,7 +431,7 @@ func (c *chainWatcher) handleUnknownLocalState( } remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, - commitKeyRing.ToRemoteKey, leaseExpiry, + commitKeyRing.ToRemoteKey, leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return false, err @@ -444,6 +444,7 @@ func (c *chainWatcher) handleUnknownLocalState( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return false, err diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go new file mode 100644 index 000000000..4558c2f81 --- /dev/null +++ b/lnwallet/aux_leaf_store.go @@ -0,0 +1,239 @@ +package lnwallet + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// CommitSortFunc is a function type alias for a function that sorts the +// commitment transaction outputs. The second parameter is a list of CLTV +// timeouts that must correspond to the number of transaction outputs, with the +// value of 0 for non-HTLC outputs. The HTLC indexes are needed to have a +// deterministic sort value for HTLCs that have the identical amount, CLTV +// timeout and payment hash (e.g. multiple MPP shards of the same payment, where +// the on-chain script would be identical). +type CommitSortFunc func(tx *wire.MsgTx, cltvs []uint32, + indexes []input.HtlcIndex) error + +// DefaultCommitSort is the default commitment sort function that sorts the +// commitment transaction inputs and outputs according to BIP69. The second +// parameter is a list of CLTV timeouts that must correspond to the number of +// transaction outputs, with the value of 0 for non-HTLC outputs. The third +// parameter is unused for the default sort function. +func DefaultCommitSort(tx *wire.MsgTx, cltvs []uint32, + _ []input.HtlcIndex) error { + + InPlaceCommitSort(tx, cltvs) + return nil +} + +// CommitAuxLeaves stores two potential auxiliary leaves for the remote and +// local output that may be used to augment the final tapscript trees of the +// commitment transaction. +type CommitAuxLeaves struct { + // LocalAuxLeaf is the local party's auxiliary leaf. + LocalAuxLeaf input.AuxTapLeaf + + // RemoteAuxLeaf is the remote party's auxiliary leaf. + RemoteAuxLeaf input.AuxTapLeaf + + // OutgoingHTLCLeaves is the set of aux leaves for the outgoing HTLCs + // on this commitment transaction. + OutgoingHtlcLeaves input.HtlcAuxLeaves + + // IncomingHTLCLeaves is the set of aux leaves for the incoming HTLCs + // on this commitment transaction. + IncomingHtlcLeaves input.HtlcAuxLeaves +} + +// AuxChanState is a struct that holds certain fields of the +// channeldb.OpenChannel struct that are used by the aux components. The data +// is copied over to prevent accidental mutation of the original channel state. +type AuxChanState struct { + // ChanType denotes which type of channel this is. + ChanType channeldb.ChannelType + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identifies the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + // + // If IsZeroConf(), then this will the "base" (very first) ALIAS scid + // and the confirmed SCID will be stored in ConfirmedScid. + ShortChannelID lnwire.ShortChannelID + + // IsInitiator is a bool which indicates if we were the original + // initiator for the channel. This value may affect how higher levels + // negotiate fees, or close the channel. + IsInitiator bool + + // Capacity is the total capacity of this channel. + Capacity btcutil.Amount + + // LocalChanCfg is the channel configuration for the local node. + LocalChanCfg channeldb.ChannelConfig + + // RemoteChanCfg is the channel configuration for the remote node. + RemoteChanCfg channeldb.ChannelConfig + + // ThawHeight is the height when a frozen channel once again becomes a + // normal channel. If this is zero, then there're no restrictions on + // this channel. If the value is lower than 500,000, then it's + // interpreted as a relative height, or an absolute height otherwise. + ThawHeight uint32 + + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] +} + +// NewAuxChanState creates a new AuxChanState from the given channel state. +func NewAuxChanState(chanState *channeldb.OpenChannel) AuxChanState { + return AuxChanState{ + ChanType: chanState.ChanType, + FundingOutpoint: chanState.FundingOutpoint, + ShortChannelID: chanState.ShortChannelID, + IsInitiator: chanState.IsInitiator, + Capacity: chanState.Capacity, + LocalChanCfg: chanState.LocalChanCfg, + RemoteChanCfg: chanState.RemoteChanCfg, + ThawHeight: chanState.ThawHeight, + TapscriptRoot: chanState.TapscriptRoot, + CustomBlob: chanState.CustomBlob, + } +} + +// CommitDiffAuxInput is the input required to compute the diff of the auxiliary +// leaves for a commitment transaction. +type CommitDiffAuxInput struct { + // ChannelState is the static channel information of the channel this + // commitment transaction relates to. + ChannelState AuxChanState + + // PrevBlob is the blob of the previous commitment transaction. + PrevBlob tlv.Blob + + // UnfilteredView is the unfiltered, original HTLC view of the channel. + // Unfiltered in this context means that the view contains all HTLCs, + // including the canceled ones. + UnfilteredView *HtlcView + + // WhoseCommit denotes whose commitment transaction we are computing the + // diff for. + WhoseCommit lntypes.ChannelParty + + // OurBalance is the balance of the local party. + OurBalance lnwire.MilliSatoshi + + // TheirBalance is the balance of the remote party. + TheirBalance lnwire.MilliSatoshi + + // KeyRing is the key ring that can be used to derive keys for the + // commitment transaction. + KeyRing CommitmentKeyRing +} + +// CommitDiffAuxResult is the result of computing the diff of the auxiliary +// leaves for a commitment transaction. +type CommitDiffAuxResult struct { + // AuxLeaves are the auxiliary leaves for the new commitment + // transaction. + AuxLeaves fn.Option[CommitAuxLeaves] + + // CommitSortFunc is an optional function that sorts the commitment + // transaction inputs and outputs. + CommitSortFunc fn.Option[CommitSortFunc] +} + +// AuxLeafStore is used to optionally fetch auxiliary tapscript leaves for the +// commitment transaction given an opaque blob. This is also used to implement +// a state transition function for the blobs to allow them to be refreshed with +// each state. +type AuxLeafStore interface { + // FetchLeavesFromView attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and pending original (unfiltered) + // HTLC view. + FetchLeavesFromView( + in CommitDiffAuxInput) fn.Result[CommitDiffAuxResult] + + // FetchLeavesFromCommit attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and an existing channel + // commitment. + FetchLeavesFromCommit(chanState AuxChanState, + commit channeldb.ChannelCommitment, + keyRing CommitmentKeyRing) fn.Result[CommitDiffAuxResult] + + // FetchLeavesFromRevocation attempts to fetch the auxiliary leaves + // from a channel revocation that stores balance + blob information. + FetchLeavesFromRevocation( + r *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult] + + // ApplyHtlcView serves as the state transition function for the custom + // channel's blob. Given the old blob, and an HTLC view, then a new + // blob should be returned that reflects the pending updates. + ApplyHtlcView(in CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]] +} + +// auxLeavesFromView is used to derive the set of commit aux leaves (if any), +// that are needed to create a new commitment transaction using the original +// (unfiltered) htlc view. +func auxLeavesFromView(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], originalView *HtlcView, + whoseCommit lntypes.ChannelParty, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) fn.Result[CommitDiffAuxResult] { + + return fn.MapOptionZ( + prevBlob, func(blob tlv.Blob) fn.Result[CommitDiffAuxResult] { + return leafStore.FetchLeavesFromView(CommitDiffAuxInput{ + ChannelState: NewAuxChanState(chanState), + PrevBlob: blob, + UnfilteredView: originalView, + WhoseCommit: whoseCommit, + OurBalance: ourBalance, + TheirBalance: theirBalance, + KeyRing: keyRing, + }) + }, + ) +} + +// updateAuxBlob is a helper function that attempts to update the aux blob +// given the prior and current state information. +func updateAuxBlob(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView, + whoseCommit lntypes.ChannelParty, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) fn.Result[fn.Option[tlv.Blob]] { + + return fn.MapOptionZ( + prevBlob, func(blob tlv.Blob) fn.Result[fn.Option[tlv.Blob]] { + return leafStore.ApplyHtlcView(CommitDiffAuxInput{ + ChannelState: NewAuxChanState(chanState), + PrevBlob: blob, + UnfilteredView: nextViewUnfiltered, + WhoseCommit: whoseCommit, + OurBalance: ourBalance, + TheirBalance: theirBalance, + KeyRing: keyRing, + }) + }, + ) +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 60026f99c..036c96eda 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -804,6 +804,10 @@ type LightningChannel struct { // machine. Signer input.Signer + // leafStore is used to retrieve extra tapscript leaves for special + // custom channel types. + leafStore fn.Option[AuxLeafStore] + // signDesc is the primary sign descriptor that is capable of signing // the commitment transaction that spends the multi-sig output. signDesc *input.SignDescriptor @@ -879,6 +883,8 @@ type channelOpts struct { localNonce *musig2.Nonces remoteNonce *musig2.Nonces + leafStore fn.Option[AuxLeafStore] + skipNonceInit bool } @@ -909,6 +915,13 @@ func WithSkipNonceInit() ChannelOpt { } } +// WithLeafStore is used to specify a custom leaf store for the channel. +func WithLeafStore(store AuxLeafStore) ChannelOpt { + return func(o *channelOpts) { + o.leafStore = fn.Some[AuxLeafStore](store) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -950,13 +963,16 @@ func NewLightningChannel(signer input.Signer, } lc := &LightningChannel{ - Signer: signer, - sigPool: sigPool, - currentHeight: localCommit.CommitHeight, - remoteCommitChain: newCommitmentChain(), - localCommitChain: newCommitmentChain(), - channelState: state, - commitBuilder: NewCommitmentBuilder(state), + Signer: signer, + leafStore: opts.leafStore, + sigPool: sigPool, + currentHeight: localCommit.CommitHeight, + remoteCommitChain: newCommitmentChain(), + localCommitChain: newCommitmentChain(), + channelState: state, + commitBuilder: NewCommitmentBuilder( + state, opts.leafStore, + ), localUpdateLog: localUpdateLog, remoteUpdateLog: remoteUpdateLog, Capacity: state.Capacity, @@ -1988,12 +2004,14 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } + // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). + // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return nil, err @@ -2003,6 +2021,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return nil, err @@ -2560,7 +2579,8 @@ func (lc *LightningChannel) fetchCommitmentView( // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( ourBalance, theirBalance, whoseCommitChain, feePerKw, - nextHeight, filteredHTLCView, keyRing, + nextHeight, htlcView, filteredHTLCView, keyRing, + commitChain.tip(), ) if err != nil { return nil, err @@ -2595,6 +2615,23 @@ func (lc *LightningChannel) fetchCommitmentView( effFeeRate, spew.Sdump(commitTx)) } + // Given the custom blob of the past state, and this new HTLC view, + // we'll generate a new blob for the latest commitment. + newCommitBlob, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[fn.Option[tlv.Blob]] { + return updateAuxBlob( + s, lc.channelState, + commitChain.tip().customBlob, htlcView, + whoseCommitChain, ourBalance, theirBalance, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the commitment view created, store the resulting balances and // transaction with the other parameters for this height. c := &commitment{ @@ -2610,6 +2647,7 @@ func (lc *LightningChannel) fetchCommitmentView( feePerKw: feePerKw, dustLimit: dustLimit, whoseCommit: whoseCommitChain, + customBlob: newCommitBlob, } // In order to ensure _none_ of the HTLC's associated with this new @@ -2703,6 +2741,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, if mutateState && entry.EntryType == Settle && whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatReceived += entry.Amount } @@ -5398,7 +5437,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // before the change since the indexes are meant for the current, // revoked remote commitment. ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote( - revocation, lc.channelState, + revocation, lc.channelState, lc.leafStore, ) if err != nil { return nil, nil, nil, nil, err @@ -6286,12 +6325,14 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si commitTxBroadcast := commitSpend.SpendingTx + // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). + // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -7236,6 +7277,8 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). + var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight @@ -7243,6 +7286,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, csvTimeout, leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return nil, err diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 2a127e25c..7f2a9ce33 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -225,8 +226,8 @@ func (w *WitnessScriptDesc) WitnessScriptForPath( // party learns of the preimage to the revocation hash, then they can claim all // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, - selfKey, revokeKey *btcec.PublicKey, csvDelay, - leaseExpiry uint32) (input.ScriptDescriptor, error) { + selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -236,7 +237,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, input.NoneTapLeaf(), + csvDelay, selfKey, revokeKey, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -290,8 +291,8 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // script for. The second return value is the CSV delay of the output script, // what must be satisfied in order to spend the output. func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, - remoteKey *btcec.PublicKey, - leaseExpiry uint32) (input.ScriptDescriptor, uint32, error) { + remoteKey *btcec.PublicKey, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, uint32, error) { switch { // If we are not the initiator of a leased channel, then the remote @@ -320,7 +321,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, input.NoneTapLeaf(), + remoteKey, auxLeaf, ) if err != nil { return nil, 0, err @@ -612,7 +613,7 @@ func CommitScriptAnchors(chanType channeldb.ChannelType, // with, and abstracts the various ways of constructing commitment // transactions. type CommitmentBuilder struct { - // chanState is the underlying channels's state struct, used to + // chanState is the underlying channel's state struct, used to // determine the type of channel we are dealing with, and relevant // parameters. chanState *channeldb.OpenChannel @@ -620,18 +621,25 @@ type CommitmentBuilder struct { // obfuscator is a 48-bit state hint that's used to obfuscate the // current state number on the commitment transactions. obfuscator [StateHintSize]byte + + // auxLeafStore is an interface that allows us to fetch auxiliary + // tapscript leaves for the commitment output. + auxLeafStore fn.Option[AuxLeafStore] } // NewCommitmentBuilder creates a new CommitmentBuilder from chanState. -func NewCommitmentBuilder(chanState *channeldb.OpenChannel) *CommitmentBuilder { +func NewCommitmentBuilder(chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder { + // The anchor channel type MUST be tweakless. if chanState.ChanType.HasAnchors() && !chanState.ChanType.IsTweakless() { panic("invalid channel type combination") } return &CommitmentBuilder{ - chanState: chanState, - obfuscator: createStateHintObfuscator(chanState), + chanState: chanState, + obfuscator: createStateHintObfuscator(chanState), + auxLeafStore: leafStore, } } @@ -684,9 +692,9 @@ type unsignedCommitmentTx struct { // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty, - feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *HtlcView, - keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { + feePerKw chainfee.SatPerKWeight, height uint64, originalHtlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing, + prevCommit *commitment) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit if whoseCommit.IsRemote() { @@ -752,6 +760,23 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, err error ) + // Before we create the commitment transaction below, we'll try to see + // if there're any aux leaves that need to be a part of the tapscript + // tree. We'll only do this if we have a custom blob defined though. + auxResult, err := fn.MapOptionZ( + cb.auxLeafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return auxLeavesFromView( + s, cb.chanState, prevCommit.customBlob, + originalHtlcView, whoseCommit, ourBalance, + theirBalance, *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Depending on whether the transaction is ours or not, we call // CreateCommitTx with parameters matching the perspective, to generate // a new commitment transaction with all the latest unsettled/un-timed @@ -766,6 +791,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, ourBalance.ToSatoshis(), theirBalance.ToSatoshis(), numHTLCs, cb.chanState.IsInitiator, leaseExpiry, + auxResult.AuxLeaves, ) } else { commitTx, err = CreateCommitTx( @@ -773,6 +799,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg, theirBalance.ToSatoshis(), ourBalance.ToSatoshis(), numHTLCs, !cb.chanState.IsInitiator, leaseExpiry, + auxResult.AuxLeaves, ) } if err != nil { @@ -789,6 +816,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // commitment outputs and should correspond to zero values for the // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) + htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut)) for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, @@ -805,7 +833,11 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, if err != nil { return nil, err } - cltvs = append(cltvs, htlc.Timeout) // nolint:makezero + + // We want to add the CLTV and HTLC index to their respective + // slices, even if we already pre-allocated them. + cltvs = append(cltvs, htlc.Timeout) //nolint + htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( @@ -823,7 +855,11 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, if err != nil { return nil, err } - cltvs = append(cltvs, htlc.Timeout) // nolint:makezero + + // We want to add the CLTV and HTLC index to their respective + // slices, even if we already pre-allocated them. + cltvs = append(cltvs, htlc.Timeout) //nolint + htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } // Set the state hint of the commitment transaction to facilitate @@ -835,9 +871,16 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } // Sort the transactions according to the agreed upon canonical - // ordering. This lets us skip sending the entire transaction over, - // instead we'll just send signatures. - InPlaceCommitSort(commitTx, cltvs) + // ordering (which might be customized for custom channel types, but + // deterministic and both parties will arrive at the same result). This + // lets us skip sending the entire transaction over, instead we'll just + // send signatures. + commitSort := auxResult.CommitSortFunc.UnwrapOr(DefaultCommitSort) + err = commitSort(commitTx, cltvs, htlcIndexes) + if err != nil { + return nil, fmt.Errorf("unable to sort commitment "+ + "transaction: %w", err) + } // Next, we'll ensure that we don't accidentally create a commitment // transaction which would be invalid by consensus. @@ -879,24 +922,33 @@ func CreateCommitTx(chanType channeldb.ChannelType, fundingOutput wire.TxIn, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, amountToLocal, amountToRemote btcutil.Amount, - numHTLCs int64, initiator bool, leaseExpiry uint32) (*wire.MsgTx, error) { + numHTLCs int64, initiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*wire.MsgTx, error) { // First, we create the script for the delayed "pay-to-self" output. // This output has 2 main redemption clauses: either we can redeem the // output after a relative block delay, or the remote node can claim // the funds with the revocation key if we broadcast a revoked // commitment transaction. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) toLocalScript, err := CommitScriptToSelf( chanType, initiator, keyRing.ToLocalKey, keyRing.RevocationKey, uint32(localChanCfg.CsvDelay), leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } // Next, we create the script paying to the remote. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) toRemoteScript, _, err := CommitScriptToRemote( chanType, initiator, keyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -1201,7 +1253,8 @@ func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, // output scripts and compares them against the outputs inside the commitment // to find the match. func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, - chanState *channeldb.OpenChannel) (uint32, uint32, error) { + chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) { // Init the output indexes as empty. ourIndex := uint32(channeldb.OutputIndexEmpty) @@ -1231,26 +1284,51 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, leaseExpiry = chanState.ThawHeight } - // Map the scripts from our PoV. When facing a local commitment, the to - // local output belongs to us and the to remote output belongs to them. - // When facing a remote commitment, the to local output belongs to them - // and the to remote output belongs to us. + // If we have a custom blob, then we'll attempt to fetch the aux leaves + // for this state. + auxResult, err := fn.MapOptionZ( + leafStore, func(a AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return a.FetchLeavesFromCommit( + NewAuxChanState(chanState), chanCommit, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return ourIndex, theirIndex, fmt.Errorf("unable to fetch aux "+ + "leaves: %w", err) + } - // Compute the to local script. From our PoV, when facing a remote - // commitment, the to local output belongs to them. + // Map the scripts from our PoV. When facing a local commitment, the + // to_local output belongs to us and the to_remote output belongs to + // them. When facing a remote commitment, the to_local output belongs to + // them and the to_remote output belongs to us. + + // Compute the to_local script. From our PoV, when facing a remote + // commitment, the to_local output belongs to them. + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, theirDelay, leaseExpiry, + keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf, ) if err != nil { return ourIndex, theirIndex, err } - // Compute the to remote script. From our PoV, when facing a remote - // commitment, the to remote output belongs to us. + // Compute the to_remote script. From our PoV, when facing a remote + // commitment, the to_remote output belongs to us. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) ourScript, _, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, remoteAuxLeaf, ) if err != nil { return ourIndex, theirIndex, err diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index e9751c6b9..439e7ce95 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -631,6 +632,7 @@ func testSpendValidation(t *testing.T, tweakless bool) { commitmentTx, err := CreateCommitTx( channelType, *fakeFundingTxIn, keyRing, aliceChanCfg, bobChanCfg, channelBalance, channelBalance, 0, true, 0, + fn.None[CommitAuxLeaves](), ) if err != nil { t.Fatalf("unable to create commitment transaction: %v", nil) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 4f51eaf99..077e17a3a 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -1470,6 +1470,21 @@ func (l *LightningWallet) handleFundingCancelRequest(req *fundingReserveCancelMs req.err <- nil } +// createCommitOpts is a struct that holds the options for creating a new +// commitment transaction. +type createCommitOpts struct { + auxLeaves fn.Option[CommitAuxLeaves] +} + +// defaultCommitOpts returns a new createCommitOpts with default values. +func defaultCommitOpts() createCommitOpts { + return createCommitOpts{} +} + +// CreateCommitOpt is a functional option that can be used to modify the way a +// new commitment transaction is created. +type CreateCommitOpt func(*createCommitOpts) + // CreateCommitmentTxns is a helper function that creates the initial // commitment transaction for both parties. This function is used during the // initial funding workflow as both sides must generate a signature for the @@ -1479,7 +1494,13 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourChanCfg, theirChanCfg *channeldb.ChannelConfig, localCommitPoint, remoteCommitPoint *btcec.PublicKey, fundingTxIn wire.TxIn, chanType channeldb.ChannelType, initiator bool, - leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { + leaseExpiry uint32, opts ...CreateCommitOpt) (*wire.MsgTx, *wire.MsgTx, + error) { + + options := defaultCommitOpts() + for _, optFunc := range opts { + optFunc(&options) + } localCommitmentKeys := DeriveCommitmentKeys( localCommitPoint, lntypes.Local, chanType, ourChanCfg, @@ -1493,7 +1514,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourCommitTx, err := CreateCommitTx( chanType, fundingTxIn, localCommitmentKeys, ourChanCfg, theirChanCfg, localBalance, remoteBalance, 0, initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -1507,7 +1528,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, theirCommitTx, err := CreateCommitTx( chanType, fundingTxIn, remoteCommitmentKeys, theirChanCfg, ourChanCfg, remoteBalance, localBalance, 0, !initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err From c1e641e9d9150f7494b0070cd1964f4b4267fd51 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 19:56:50 -0700 Subject: [PATCH 15/22] lnwallet: add TLV blob to PaymentDescriptor + htlc add In this commit, we add a TLV blob to the PaymentDescriptor struct. We also now thread through this value from the UpdateAddHTLC message to the PaymentDescriptor mapping, and the other way around. --- channeldb/channel.go | 24 ++++++++++++++++++++++++ channeldb/channel_test.go | 19 +++++++++++++++++++ lnwallet/channel.go | 14 +++++++++++++- lnwallet/channel_test.go | 24 ++++++++++++++++++------ lnwallet/payment_descriptor.go | 4 ++++ 5 files changed, 78 insertions(+), 7 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 9fed8e8ac..c21716a45 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -2580,6 +2580,12 @@ type HTLC struct { // HTLC. It is stored in the ExtraData field, which is used to store // a TLV stream of additional information associated with the HTLC. BlindingPoint lnwire.BlindingPointRecord + + // CustomRecords is a set of custom TLV records that are associated with + // this HTLC. These records are used to store additional information + // about the HTLC that is not part of the standard HTLC fields. This + // field is encoded within the ExtraData field. + CustomRecords lnwire.CustomRecords } // serializeExtraData encodes a TLV stream of extra data to be stored with a @@ -2598,6 +2604,11 @@ func (h *HTLC) serializeExtraData() error { records = append(records, &b) }) + records, err := h.CustomRecords.ExtendRecordProducers(records) + if err != nil { + return err + } + return h.ExtraData.PackRecords(records...) } @@ -2619,8 +2630,19 @@ func (h *HTLC) deserializeExtraData() error { if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { h.BlindingPoint = tlv.SomeRecordT(blindingPoint) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(tlvMap, h.BlindingPoint.TlvType()) } + // Set the custom records field to the remaining TLV records. + customRecords, err := lnwire.NewCustomRecords(tlvMap) + if err != nil { + return err + } + h.CustomRecords = customRecords + return nil } @@ -2758,6 +2780,8 @@ func (h *HTLC) Copy() HTLC { copy(clone.Signature[:], h.Signature) copy(clone.RHash[:], h.RHash[:]) copy(clone.ExtraData, h.ExtraData) + clone.BlindingPoint = h.BlindingPoint + clone.CustomRecords = h.CustomRecords.Copy() return clone } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 13d77276c..fa647f158 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -1657,6 +1657,24 @@ func TestHTLCsExtraData(t *testing.T) { ), } + // Custom channel data htlc with a blinding point. + customDataHTLC := HTLC{ + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: lnmock.MockOnion(), + BlindingPoint: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pubKey, + ), + ), + CustomRecords: map[uint64][]byte{ + uint64(lnwire.MinCustomRecordsTlvType + 3): {1, 2, 3}, + }, + } + testCases := []struct { name string htlcs []HTLC @@ -1678,6 +1696,7 @@ func TestHTLCsExtraData(t *testing.T) { mockHtlc, blindingPointHTLC, mockHtlc, + customDataHTLC, }, }, } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 036c96eda..3cd2e7f23 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -529,6 +529,7 @@ func (c *commitment) toDiskCommit( LogIndex: htlc.LogIndex, Incoming: false, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), } copy(h.OnionBlob[:], htlc.OnionBlob) @@ -554,8 +555,10 @@ func (c *commitment) toDiskCommit( LogIndex: htlc.LogIndex, Incoming: true, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), } copy(h.OnionBlob[:], htlc.OnionBlob) + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -653,6 +656,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, BlindingPoint: htlc.BlindingPoint, + CustomRecords: htlc.CustomRecords.Copy(), }, nil } @@ -5727,6 +5731,9 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, + // TODO(guggero): Add custom records from HTLC here once we have + // the custom records in the HTLC struct (later commits in this + // PR). } } @@ -5767,7 +5774,9 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // ReceiveHTLC adds an HTLC to the state machine's remote update log. This // method should be called in response to receiving a new HTLC from the remote // party. -func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, error) { +func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, + error) { + lc.Lock() defer lc.Unlock() @@ -5785,6 +5794,9 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err HtlcIndex: lc.remoteUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], BlindingPoint: htlc.BlindingPoint, + // TODO(guggero): Add custom records from HTLC here once we have + // the custom records in the HTLC struct (later commits in this + // PR). } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index e0e402484..61fa5ed09 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2,6 +2,7 @@ package lnwallet import ( "bytes" + crand "crypto/rand" "crypto/sha256" "fmt" "math/rand" @@ -10431,19 +10432,26 @@ func assertPayDescMatchHTLC(t *testing.T, pd PaymentDescriptor, // the `Incoming`. func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { var onionBlob [lnwire.OnionPacketSize]byte - _, err := rand.Read(onionBlob[:]) + _, err := crand.Read(onionBlob[:]) require.NoError(t, err) var rHash [lntypes.HashSize]byte - _, err = rand.Read(rHash[:]) + _, err = crand.Read(rHash[:]) require.NoError(t, err) sig := make([]byte, 64) - _, err = rand.Read(sig) + _, err = crand.Read(sig) require.NoError(t, err) + randCustomData := make([]byte, 32) + _, err = crand.Read(randCustomData) + require.NoError(t, err) + + randCustomType := rand.Intn(255) + lnwire.MinCustomRecordsTlvType + blinding, err := pubkeyFromHex( - "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48" + + "236c39", ) require.NoError(t, err) @@ -10458,9 +10466,13 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { HtlcIndex: rand.Uint64(), LogIndex: rand.Uint64(), BlindingPoint: tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + blinding, + ), ), + CustomRecords: map[uint64][]byte{ + uint64(randCustomType): randCustomData, + }, } } diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index f0ac8b9f7..6fe74bb6f 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -221,4 +221,8 @@ type PaymentDescriptor struct { // blinded route (ie, not the introduction node) from update_add_htlc's // TLVs. BlindingPoint lnwire.BlindingPointRecord + + // CustomRecords also stores the set of optional custom records that + // may have been attached to a sent HTLC. + CustomRecords lnwire.CustomRecords } From 1b1e7a61681dbc7f354013a3ab0e49e1eb1c501f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sat, 30 Mar 2024 17:28:35 -0700 Subject: [PATCH 16/22] channeldb: convert HTLCEntry to use tlv.RecordT --- channeldb/channel_test.go | 24 ++-- channeldb/revocation_log.go | 207 +++++++++++++++---------------- channeldb/revocation_log_test.go | 100 +++++++++++---- lnwallet/channel.go | 18 +-- lnwallet/channel_test.go | 29 +++-- 5 files changed, 212 insertions(+), 166 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index fa647f158..4427579a7 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -593,15 +593,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] - require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") - require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), - "Amt mismatch") - require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, - "RefundTimeout mismatch") - require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, - "OutputIndex mismatch") - require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, - "Incoming mismatch") + require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash") + require.Equal( + t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt", + ) + require.Equal( + t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout, + "RefundTimeout", + ) + require.EqualValues( + t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex, + "OutputIndex", + ) + require.Equal( + t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming", + ) } } diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index f062ac086..26e4b9644 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -54,6 +54,74 @@ var ( ErrOutputIndexTooBig = errors.New("output index is over uint16") ) +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.DBytes32(r, vArray, buf, 32) +} + // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // historical HTLCs, which is useful for constructing RevocationLog when a // breach is detected. @@ -72,118 +140,62 @@ var ( // made into tlv records without further conversion. type HTLCEntry struct { // RHash is the payment hash of the HTLC. - RHash [32]byte + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] // RefundTimeout is the absolute timeout on the HTLC that the sender // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] // OutputIndex is the output index for this particular HTLC output // within the commitment transaction. // // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // gives us a max number of HTLCs of 65K. - OutputIndex uint16 + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] // Incoming denotes whether we're the receiver or the sender of this // HTLC. // // NOTE: this field is the memory representation of the field // incomingUint. - Incoming bool + Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. // // NOTE: this field is the memory representation of the field amtUint. - Amt btcutil.Amount - - // amtTlv is the uint64 format of Amt. This field is created so we can - // easily make it into a tlv record and save it to disk. - // - // NOTE: we keep this field for accounting purpose only. If the disk - // space becomes an issue, we could delete this field to save us extra - // 8 bytes. - amtTlv uint64 - - // incomingTlv is the uint8 format of Incoming. This field is created - // so we can easily make it into a tlv record and save it to disk. - incomingTlv uint8 -} - -// RHashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (h *HTLCEntry) RHashLen() uint64 { - if h.RHash == lntypes.ZeroHash { - return 0 - } - return 32 -} - -// RHashEncoder is the customized encoder which skips encoding the empty hash. -func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the value is an empty hash, we will skip encoding it. - if *v == lntypes.ZeroHash { - return nil - } - - return tlv.EBytes32(w, v, buf) -} - -// RHashDecoder is the customized decoder which skips decoding the empty hash. -func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } - - return tlv.DBytes32(r, v, buf, 32) + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize htlc entries - // to the database. We define it here instead of the head of - // the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - rHashType tlv.Type = 0 - refundTimeoutType tlv.Type = 1 - outputIndexType tlv.Type = 2 - incomingType tlv.Type = 3 - amtType tlv.Type = 4 - ) - return tlv.NewStream( - tlv.MakeDynamicRecord( - rHashType, &h.RHash, h.RHashLen, - RHashEncoder, RHashDecoder, - ), - tlv.MakePrimitiveRecord( - refundTimeoutType, &h.RefundTimeout, - ), - tlv.MakePrimitiveRecord( - outputIndexType, &h.OutputIndex, - ), - tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), - // We will save 3 bytes if the amount is less or equal to - // 4,294,967,295 msat, or roughly 0.043 bitcoin. - tlv.MakeBigSizeRecord(amtType, &h.amtTlv), + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), ) } +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { + return &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), + ), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + } +} + // RevocationLog stores the info needed to construct a breach retribution. Its // fields can be viewed as a subset of a ChannelCommitment's. In the database, // all historical versions of the RevocationLog are saved using the @@ -265,13 +277,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := &HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - Incoming: htlc.Incoming, - OutputIndex: uint16(htlc.OutputIndex), - Amt: htlc.Amt.ToSatoshis(), - } + entry := NewHTLCEntryFromHTLC(htlc) rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -351,14 +357,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { for _, htlc := range htlcs { - // Patch the incomingTlv field. - if htlc.Incoming { - htlc.incomingTlv = 1 - } - - // Patch the amtTlv field. - htlc.amtTlv = uint64(htlc.Amt) - // Create the tlv stream. tlvStream, err := htlc.toTlvStream() if err != nil { @@ -447,14 +445,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } - // Patch the Incoming field. - if htlc.incomingTlv == 1 { - htlc.Incoming = true - } - - // Patch the Amt field. - htlc.Amt = btcutil.Amount(htlc.amtTlv) - // Append the entry. htlcs = append(htlcs, &htlc) } @@ -469,6 +459,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { if err := s.Encode(&b); err != nil { return err } + // Write the stream's length as a varint. err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) if err != nil { diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index fc5303a48..c9d2a16f7 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -34,12 +34,16 @@ var ( } testHTLCEntry = HTLCEntry{ - RefundTimeout: 740_000, - OutputIndex: 10, - Incoming: true, - Amt: 1000_000, - amtTlv: 1000_000, - incomingTlv: 1, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), } testHTLCEntryBytes = []byte{ // Body length 23. @@ -56,6 +60,40 @@ var ( 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, } + testHTLCEntryHash = HTLCEntry{ + RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash( + [32]byte{0x33, 0x44, 0x55}, + )), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + } + testHTLCEntryHashBytes = []byte{ + // Body length 54. + 0x36, + // Rhash tlv. + 0x0, 0x20, + 0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // RefundTimeout tlv. + 0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0, + // OutputIndex tlv. + 0x2, 0x2, 0x0, 0xa, + // Incoming tlv. + 0x3, 0x1, 0x1, + // Amt tlv. + 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + } + localBalance = lnwire.MilliSatoshi(9000) remoteBalance = lnwire.MilliSatoshi(3000) @@ -68,11 +106,11 @@ var ( CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), Htlcs: []HTLC{{ - RefundTimeout: testHTLCEntry.RefundTimeout, - OutputIndex: int32(testHTLCEntry.OutputIndex), - Incoming: testHTLCEntry.Incoming, + RefundTimeout: testHTLCEntry.RefundTimeout.Val, + OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( - testHTLCEntry.Amt, + testHTLCEntry.Amt.Val.Int(), ), }}, } @@ -193,11 +231,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { // Copy the testHTLCEntry. entry := testHTLCEntry - // Set the internal fields to empty values so we can test the bytes are - // padded. - entry.incomingTlv = 0 - entry.amtTlv = 0 - // Write the tlv stream. buf := bytes.NewBuffer([]byte{}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) @@ -207,6 +240,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, testHTLCEntryBytes, buf.Bytes()) } +func TestSerializeHTLCEntriesWithRHash(t *testing.T) { + t.Parallel() + + // Copy the testHTLCEntry. + entry := testHTLCEntryHash + + // Write the tlv stream. + buf := bytes.NewBuffer([]byte{}) + err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) + require.NoError(t, err) + + // Check the bytes are read as expected. + require.Equal(t, testHTLCEntryHashBytes, buf.Bytes()) +} + func TestSerializeHTLCEntries(t *testing.T) { t.Parallel() @@ -215,7 +263,7 @@ func TestSerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -269,7 +317,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) { t, &test.revLog, test.revLogBytes, ) - testDerializeRevocationLog( + testDeserializeRevocationLog( t, &test.revLog, test.revLogBytes, ) }) @@ -293,7 +341,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog, require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex]) } -func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, +func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog, revLogBytes []byte) { // Construct the full bytes. @@ -309,7 +357,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, require.Equal(t, *revLog, rl) } -func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { +func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) { t.Parallel() // Read the tlv stream. @@ -322,7 +370,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, &testHTLCEntry, htlcs[0]) } -func TestDerializeHTLCEntries(t *testing.T) { +func TestDeserializeHTLCEntries(t *testing.T) { t.Parallel() // Copy the testHTLCEntry. @@ -330,7 +378,7 @@ func TestDerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -398,11 +446,11 @@ func TestDeleteLogBucket(t *testing.T) { err = kvdb.Update(backend, func(tx kvdb.RwTx) error { // Create the buckets. - chanBucket, _, err := createTestRevocatoinLogBuckets(tx) + chanBucket, _, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Create the buckets again should give us an error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) require.ErrorIs(t, err, kvdb.ErrBucketExists) // Delete both buckets. @@ -410,7 +458,7 @@ func TestDeleteLogBucket(t *testing.T) { require.NoError(t, err) // Create the buckets again should give us NO error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) return err }, func() {}) require.NoError(t, err) @@ -516,7 +564,7 @@ func TestPutRevocationLog(t *testing.T) { // Construct the testing db transaction. dbTx := func(tx kvdb.RwTx) (RevocationLog, error) { // Create the buckets. - _, bucket, err := createTestRevocatoinLogBuckets(tx) + _, bucket, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Save the log. @@ -686,7 +734,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) { } } -func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, +func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, kvdb.RwBucket, error) { chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 3cd2e7f23..3aa3a6bbd 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2202,8 +2202,8 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, lntypes.Remote, - htlc.RefundTimeout, htlc.RHash, keyRing, + chanState.ChanType, htlc.Incoming.Val, lntypes.Remote, + htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, ) if err != nil { return emptyRetribution, err @@ -2216,7 +2216,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, WitnessScript: scriptInfo.WitnessScriptToSign(), Output: &wire.TxOut{ PkScript: scriptInfo.PkScript(), - Value: int64(htlc.Amt), + Value: int64(htlc.Amt.Val.Int()), }, HashType: sweepSigHash(chanState.ChanType), } @@ -2249,10 +2249,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, SignDesc: signDesc, OutPoint: wire.OutPoint{ Hash: commitHash, - Index: uint32(htlc.OutputIndex), + Index: uint32(htlc.OutputIndex.Val), }, SecondLevelWitnessScript: secondLevelWitnessScript, - IsIncoming: htlc.Incoming, + IsIncoming: htlc.Incoming.Val, SecondLevelTapTweak: secondLevelTapTweak, }, nil } @@ -2414,13 +2414,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := &channeldb.HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - OutputIndex: uint16(htlc.OutputIndex), - Incoming: htlc.Incoming, - Amt: htlc.Amt.ToSatoshis(), - } + entry := channeldb.NewHTLCEntryFromHTLC(htlc) hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 61fa5ed09..7d5078783 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9957,9 +9957,11 @@ func TestCreateHtlcRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: testAmt, - Incoming: true, - OutputIndex: 1, + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(testAmt), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1), } // Create the htlc retribution. @@ -9973,8 +9975,8 @@ func TestCreateHtlcRetribution(t *testing.T) { // Check the fields have expected values. require.EqualValues(t, testAmt, hr.SignDesc.Output.Value) require.Equal(t, commitHash, hr.OutPoint.Hash) - require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } // TestCreateBreachRetribution checks that `createBreachRetribution` behaves as @@ -10014,9 +10016,13 @@ func TestCreateBreachRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: btcutil.Amount(testAmt), - Incoming: true, - OutputIndex: uint16(htlcIndex), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(testAmt)), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlcIndex), + ), } // Create a dummy revocation log. @@ -10143,11 +10149,12 @@ func TestCreateBreachRetribution(t *testing.T) { require.Equal(t, remote, br.RemoteOutpoint) for _, hr := range br.HtlcRetributions { - require.EqualValues(t, testAmt, - hr.SignDesc.Output.Value) + require.EqualValues( + t, testAmt, hr.SignDesc.Output.Value, + ) require.Equal(t, commitHash, hr.OutPoint.Hash) require.EqualValues(t, htlcIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } } From 61f276856a23006466fd8bf5682709d075fa0ca5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sat, 30 Mar 2024 18:13:57 -0700 Subject: [PATCH 17/22] channeldb: convert RevocationLog to use RecordT --- channeldb/channel_test.go | 12 ++- channeldb/revocation_log.go | 158 ++++++++++++++++++------------- channeldb/revocation_log_test.go | 25 +++-- lnwallet/channel.go | 36 ++++--- lnwallet/channel_test.go | 27 +++--- 5 files changed, 145 insertions(+), 113 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 4427579a7..e92692201 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -584,9 +584,11 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, r *RevocationLog) { + t.Helper() + // Check the common fields. require.EqualValues( - t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch", + t, r.CommitTxHash.Val, c.CommitTx.TxHash(), "CommitTx mismatch", ) // Now check the common fields from the HTLCs. @@ -820,10 +822,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) // The two deltas (the original vs the on-disk version) should @@ -865,10 +867,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 26e4b9644..3fc4163c1 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,6 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -16,16 +17,15 @@ import ( const ( // OutputIndexEmpty is used when the output index doesn't exist. OutputIndexEmpty = math.MaxUint16 +) - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. - // - // NOTE: A migration should be added whenever this list changes. - revLogOurOutputIndexType tlv.Type = 0 - revLogTheirOutputIndexType tlv.Type = 1 - revLogCommitTxHashType tlv.Type = 2 - revLogOurBalanceType tlv.Type = 3 - revLogTheirBalanceType tlv.Type = 4 +type ( + // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. + BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + + // BigSizeMilliSatoshi is a type alias for a TLV record of a + // lnwire.MilliSatoshi. + BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] ) var ( @@ -203,15 +203,15 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { type RevocationLog struct { // OurOutputIndex specifies our output index in this commitment. In a // remote commitment transaction, this is the to remote output index. - OurOutputIndex uint16 + OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] // TheirOutputIndex specifies their output index in this commitment. In // a remote commitment transaction, this is the to local output index. - TheirOutputIndex uint16 + TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] // CommitTxHash is the hash of the latest version of the commitment // state, broadcast able by us. - CommitTxHash [32]byte + CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] // HTLCEntries is the set of HTLCEntry's that are pending at this // particular commitment height. @@ -221,21 +221,53 @@ type RevocationLog struct { // directly spendable by us. In other words, it is the value of the // to_remote output on the remote parties' commitment transaction. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - OurBalance *lnwire.MilliSatoshi + OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] // TheirBalance is the current available balance within the channel // directly spendable by the remote node. In other words, it is the // value of the to_local output on the remote parties' commitment. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - TheirBalance *lnwire.MilliSatoshi + TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] +} + +// NewRevocationLog creates a new RevocationLog from the given parameters. +func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, + commitHash [32]byte, ourBalance, + theirBalance fn.Option[lnwire.MilliSatoshi], + htlcs []*HTLCEntry) RevocationLog { + + rl := RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + ourOutputIndex, + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + theirOutputIndex, + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), + HTLCEntries: htlcs, + } + + ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(balance), + )) + }) + + theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(balance), + )) + }) + + return rl } // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a @@ -254,15 +286,26 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, } rl := &RevocationLog{ - OurOutputIndex: uint16(ourOutputIndex), - TheirOutputIndex: uint16(theirOutputIndex), - CommitTxHash: commit.CommitTx.TxHash(), - HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(ourOutputIndex), + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(theirOutputIndex), + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( + commit.CommitTx.TxHash(), + ), + HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } if !noAmtData { - rl.OurBalance = &commit.LocalBalance - rl.TheirBalance = &commit.RemoteBalance + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(commit.LocalBalance), + )) + + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(commit.RemoteBalance), + )) } for _, htlc := range commit.Htlcs { @@ -312,31 +355,23 @@ func fetchRevocationLog(log kvdb.RBucket, func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // Add the tlv records for all non-optional fields. records := []tlv.Record{ - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), } // Now we add any optional fields that are non-nil. - if rl.OurBalance != nil { - lb := uint64(*rl.OurBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogOurBalanceType, &lb, - )) - } + rl.OurBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) - if rl.TheirBalance != nil { - rb := uint64(*rl.TheirBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &rb, - )) - } + rl.TheirBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) @@ -374,27 +409,18 @@ func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { // deserializeRevocationLog deserializes a RevocationLog based on tlv format. func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { - var ( - rl RevocationLog - ourBalance uint64 - theirBalance uint64 - ) + var rl RevocationLog + + ourBalance := rl.OurBalance.Zero() + theirBalance := rl.TheirBalance.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), - tlv.MakeBigSizeRecord(revLogOurBalanceType, &ourBalance), - tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &theirBalance, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + ourBalance.Record(), + theirBalance.Record(), ) if err != nil { return rl, err @@ -406,14 +432,12 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { return rl, err } - if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil { - lb := lnwire.MilliSatoshi(ourBalance) - rl.OurBalance = &lb + if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { + rl.OurBalance = tlv.SomeRecordT(ourBalance) } - if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil { - rb := lnwire.MilliSatoshi(theirBalance) - rl.TheirBalance = &rb + if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { + rl.TheirBalance = tlv.SomeRecordT(theirBalance) } // Read the HTLC entries. diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index c9d2a16f7..813f70ac9 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" @@ -115,12 +116,11 @@ var ( }}, } - testRevocationLogNoAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - } + testRevocationLogNoAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), + []*HTLCEntry{&testHTLCEntry}, + ) testRevocationLogNoAmtsBytes = []byte{ // Body length 42. 0x2a, @@ -136,14 +136,11 @@ var ( 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, } - testRevocationLogWithAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - OurBalance: &localBalance, - TheirBalance: &remoteBalance, - } + testRevocationLogWithAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.Some(localBalance), fn.Some(remoteBalance), + []*HTLCEntry{&testHTLCEntry}, + ) testRevocationLogWithAmtsBytes = []byte{ // Body length 52. 0x34, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 3aa3a6bbd..02b1ea459 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2275,7 +2275,7 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, htlcRetributions := make([]HtlcRetribution, len(revokedLog.HTLCEntries)) for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( - chanState, keyRing, commitHash, + chanState, keyRing, commitHash.Val, commitmentSecret, leaseExpiry, htlc, ) if err != nil { @@ -2288,10 +2288,10 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // Construct the our outpoint. ourOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.OurOutputIndex != channeldb.OutputIndexEmpty { - ourOutpoint.Index = uint32(revokedLog.OurOutputIndex) + if revokedLog.OurOutputIndex.Val != channeldb.OutputIndexEmpty { + ourOutpoint.Index = uint32(revokedLog.OurOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of our output. @@ -2314,26 +2314,29 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains our output amount. Due to a previous // migration, this field may be empty in which case an // error will be returned. - if revokedLog.OurBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.OurBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - ourAmt = int64(revokedLog.OurBalance.ToSatoshis()) + ourAmt = int64(b.Int().ToSatoshis()) } } // Construct the their outpoint. theirOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.TheirOutputIndex != channeldb.OutputIndexEmpty { - theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex) + if revokedLog.TheirOutputIndex.Val != channeldb.OutputIndexEmpty { + theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of the remote parties' output. if spendTx != nil { // Sanity check that TheirOutputIndex is within range. - if int(revokedLog.TheirOutputIndex) >= + if int(revokedLog.TheirOutputIndex.Val) >= len(spendTx.TxOut) { return nil, 0, 0, fmt.Errorf("%w: theirs=%v, "+ @@ -2351,16 +2354,19 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains remote parties' output amount. Due to a // previous migration, this field may be empty in which // case an error will be returned. - if revokedLog.TheirBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.TheirBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - theirAmt = int64(revokedLog.TheirBalance.ToSatoshis()) + theirAmt = int64(b.Int().ToSatoshis()) } } return &BreachRetribution{ - BreachTxHash: commitHash, + BreachTxHash: commitHash.Val, ChainHash: chanState.ChainHash, LocalOutpoint: ourOutpoint, RemoteOutpoint: theirOutpoint, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 7d5078783..3034ccd5b 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -10028,22 +10028,19 @@ func TestCreateBreachRetribution(t *testing.T) { // Create a dummy revocation log. ourAmtMsat := lnwire.MilliSatoshi(ourAmt * 1000) theirAmtMsat := lnwire.MilliSatoshi(theirAmt * 1000) - revokedLog := channeldb.RevocationLog{ - CommitTxHash: commitHash, - OurOutputIndex: uint16(localIndex), - TheirOutputIndex: uint16(remoteIndex), - HTLCEntries: []*channeldb.HTLCEntry{htlc}, - TheirBalance: &theirAmtMsat, - OurBalance: &ourAmtMsat, - } + revokedLog := channeldb.NewRevocationLog( + uint16(localIndex), uint16(remoteIndex), commitHash, + fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), + []*channeldb.HTLCEntry{htlc}, + ) // Create a log with an empty local output index. revokedLogNoLocal := revokedLog - revokedLogNoLocal.OurOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoLocal.OurOutputIndex.Val = channeldb.OutputIndexEmpty // Create a log with an empty remote output index. revokedLogNoRemote := revokedLog - revokedLogNoRemote.TheirOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoRemote.TheirOutputIndex.Val = channeldb.OutputIndexEmpty testCases := []struct { name string @@ -10073,14 +10070,20 @@ func TestCreateBreachRetribution(t *testing.T) { { name: "fail due to our index too big", revocationLog: &channeldb.RevocationLog{ - OurOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, { name: "fail due to their index too big", revocationLog: &channeldb.RevocationLog{ - TheirOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, From 669740c84e5b328485dffdb7e3040461c70c1ed5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 17:00:55 -0700 Subject: [PATCH 18/22] channeldb: add custom blobs to RevocationLog+HTLCEntry This'll be useful for custom channel types that want to store extra information that'll be useful to help handle channel revocation cases. --- channeldb/revocation_log.go | 88 ++++++++++++++++++++++++++++---- channeldb/revocation_log_test.go | 45 ++++++++++++---- lnwallet/channel.go | 6 ++- lnwallet/channel_test.go | 2 +- 4 files changed, 120 insertions(+), 21 deletions(-) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 3fc4163c1..b7b73a35a 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -164,22 +164,32 @@ type HTLCEntry struct { // // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] + + // CustomBlob is an optional blob that can be used to store information + // specific to revocation handling for a custom channel type. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - return tlv.NewStream( + records := []tlv.Record{ h.RHash.Record(), h.RefundTimeout.Record(), h.OutputIndex.Record(), h.Incoming.Record(), h.Amt.Record(), - ) + } + + h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + + return tlv.NewStream(records...) } // NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. -func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { - return &HTLCEntry{ +func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { + h := &HTLCEntry{ RHash: tlv.NewRecordT[tlv.TlvType0]( NewSparsePayHash(htlc.RHash), ), @@ -194,6 +204,19 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), } + + if len(htlc.CustomRecords) != 0 { + blob, err := htlc.CustomRecords.Serialize() + if err != nil { + return nil, err + } + + h.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + } + + return h, nil } // RevocationLog stores the info needed to construct a breach retribution. Its @@ -236,13 +259,19 @@ type RevocationLog struct { // this field, it could be the case that the field is not present for // all revocation logs. TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] } // NewRevocationLog creates a new RevocationLog from the given parameters. func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, commitHash [32]byte, ourBalance, - theirBalance fn.Option[lnwire.MilliSatoshi], - htlcs []*HTLCEntry) RevocationLog { + theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, + customBlob fn.Option[tlv.Blob]) RevocationLog { rl := RevocationLog{ OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( @@ -267,6 +296,12 @@ func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, )) }) + customBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + return rl } @@ -298,6 +333,12 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + if !noAmtData { rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( tlv.NewBigSizeT(commit.LocalBalance), @@ -320,7 +361,10 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := NewHTLCEntryFromHTLC(htlc) + entry, err := NewHTLCEntryFromHTLC(htlc) + if err != nil { + return err + } rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -373,6 +417,10 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { }, ) + rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) if err != nil { @@ -413,6 +461,7 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { ourBalance := rl.OurBalance.Zero() theirBalance := rl.TheirBalance.Zero() + customBlob := rl.CustomBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -421,6 +470,7 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { rl.CommitTxHash.Record(), ourBalance.Record(), theirBalance.Record(), + customBlob.Record(), ) if err != nil { return rl, err @@ -440,6 +490,10 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { rl.TheirBalance = tlv.SomeRecordT(theirBalance) } + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + rl.CustomBlob = tlv.SomeRecordT(customBlob) + } + // Read the HTLC entries. rl.HTLCEntries, err = deserializeHTLCEntries(r) @@ -454,14 +508,26 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { for { var htlc HTLCEntry + customBlob := htlc.CustomBlob.Zero() + // Create the tlv stream. - tlvStream, err := htlc.toTlvStream() + records := []tlv.Record{ + htlc.RHash.Record(), + htlc.RefundTimeout.Record(), + htlc.OutputIndex.Record(), + htlc.Incoming.Record(), + htlc.Amt.Record(), + customBlob.Record(), + } + + tlvStream, err := tlv.NewStream(records...) if err != nil { return nil, err } // Read the HTLC entry. - if _, err := readTlvStream(r, tlvStream); err != nil { + parsedTypes, err := readTlvStream(r, tlvStream) + if err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -469,6 +535,10 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + htlc.CustomBlob = tlv.SomeRecordT(customBlob) + } + // Append the entry. htlcs = append(htlcs, &htlc) } diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 813f70ac9..139a02d52 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -34,6 +34,17 @@ var ( 0xff, // value = 255 } + customRecords = lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 1: []byte("custom data"), + } + + blobBytes = []byte{ + // Corresponds to the encoded version of the above custom + // records. + 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, + } + testHTLCEntry = HTLCEntry{ RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( 740_000, @@ -45,10 +56,13 @@ var ( Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(btcutil.Amount(1_000_000)), ), + CustomBlob: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), + ), } testHTLCEntryBytes = []byte{ - // Body length 23. - 0x16, + // Body length 41. + 0x29, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -59,6 +73,9 @@ var ( 0x3, 0x1, 0x1, // Amt tlv. 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, } testHTLCEntryHash = HTLCEntry{ @@ -113,17 +130,19 @@ var ( Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), ), + CustomRecords: customRecords, }}, + CustomBlob: fn.Some(blobBytes), } testRevocationLogNoAmts = NewRevocationLog( 0, 1, testChannelCommit.CommitTx.TxHash(), fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), - []*HTLCEntry{&testHTLCEntry}, + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), ) testRevocationLogNoAmtsBytes = []byte{ - // Body length 42. - 0x2a, + // Body length 61. + 0x3d, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -134,16 +153,19 @@ var ( 0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6, 0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff, 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, } testRevocationLogWithAmts = NewRevocationLog( 0, 1, testChannelCommit.CommitTx.TxHash(), fn.Some(localBalance), fn.Some(remoteBalance), - []*HTLCEntry{&testHTLCEntry}, + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), ) testRevocationLogWithAmtsBytes = []byte{ - // Body length 52. - 0x34, + // Body length 71. + 0x47, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -158,6 +180,9 @@ var ( 0x3, 0x3, 0xfd, 0x23, 0x28, // Remote Balance. 0x4, 0x3, 0xfd, 0x0b, 0xb8, + // Custom blob tlv. + 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, } ) @@ -269,7 +294,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x36, 0x0, 0x20} + expectedBytes := []byte{0x49, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -384,7 +409,7 @@ func TestDeserializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x36, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x4d, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 02b1ea459..7c0a367da 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2420,7 +2420,11 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := channeldb.NewHTLCEntryFromHTLC(htlc) + entry, err := channeldb.NewHTLCEntryFromHTLC(htlc) + if err != nil { + return nil, 0, 0, err + } + hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 3034ccd5b..1bfd4ad88 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -10031,7 +10031,7 @@ func TestCreateBreachRetribution(t *testing.T) { revokedLog := channeldb.NewRevocationLog( uint16(localIndex), uint16(remoteIndex), commitHash, fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), - []*channeldb.HTLCEntry{htlc}, + []*channeldb.HTLCEntry{htlc}, fn.None[tlv.Blob](), ) // Create a log with an empty local output index. From 2510c190241df65d0efcbf52c783b3c7f2362d61 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 17:44:14 -0700 Subject: [PATCH 19/22] channeldb: add HtlcIndex to HTLCEntry This may be useful for custom channel types that base everything off the index (a global value) rather than the output index (can change with each state). --- channeldb/revocation_log.go | 23 ++++++++++++++++++----- channeldb/revocation_log_test.go | 17 +++++++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index b7b73a35a..3abc73f81 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -155,19 +155,17 @@ type HTLCEntry struct { // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] // CustomBlob is an optional blob that can be used to store information // specific to revocation handling for a custom channel type. CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] + + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, uint16] } // toTlvStream converts an HTLCEntry record into a tlv representation. @@ -184,6 +182,12 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { records = append(records, r.Record()) }) + h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, uint16]) { + records = append(records, r.Record()) + }) + + tlv.SortRecords(records) + return tlv.NewStream(records...) } @@ -203,6 +207,9 @@ func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), + HtlcIndex: tlv.SomeRecordT(tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + )), } if len(htlc.CustomRecords) != 0 { @@ -509,6 +516,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { var htlc HTLCEntry customBlob := htlc.CustomBlob.Zero() + htlcIndex := htlc.HtlcIndex.Zero() // Create the tlv stream. records := []tlv.Record{ @@ -518,6 +526,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.Incoming.Record(), htlc.Amt.Record(), customBlob.Record(), + htlcIndex.Record(), } tlvStream, err := tlv.NewStream(records...) @@ -539,6 +548,10 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.CustomBlob = tlv.SomeRecordT(customBlob) } + if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil { + htlc.HtlcIndex = tlv.SomeRecordT(htlcIndex) + } + // Append the entry. htlcs = append(htlcs, &htlc) } diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 139a02d52..4290552ee 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -59,10 +59,13 @@ var ( CustomBlob: tlv.SomeRecordT( tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), ), + HtlcIndex: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](0x33), + ), } testHTLCEntryBytes = []byte{ - // Body length 41. - 0x29, + // Body length 45. + 0x2d, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -76,6 +79,8 @@ var ( // Custom blob tlv. 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x33, } testHTLCEntryHash = HTLCEntry{ @@ -126,7 +131,11 @@ var ( Htlcs: []HTLC{{ RefundTimeout: testHTLCEntry.RefundTimeout.Val, OutputIndex: int32(testHTLCEntry.OutputIndex.Val), - Incoming: testHTLCEntry.Incoming.Val, + HtlcIndex: uint64( + testHTLCEntry.HtlcIndex.ValOpt(). + UnsafeFromSome(), + ), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), ), @@ -294,7 +303,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x49, 0x0, 0x20} + expectedBytes := []byte{0x4d, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. From b45d72fe59077c0eeace630747ba1b518b2b0018 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 25 Apr 2024 19:00:42 +0200 Subject: [PATCH 20/22] multi: thread thru the AuxLeafStore everywhere --- chainreg/chainregistry.go | 5 ++ config_builder.go | 36 +++++++++----- contractcourt/breach_arbitrator_test.go | 3 ++ contractcourt/chain_arbitrator.go | 18 ++++++- contractcourt/chain_watcher.go | 11 +++-- funding/manager.go | 12 ++++- funding/manager_test.go | 5 ++ lnd.go | 3 +- lnwallet/channel.go | 63 +++++++++++++++---------- lnwallet/channel_test.go | 10 ++++ lnwallet/config.go | 5 ++ lnwallet/mock.go | 44 +++++++++++++++++ lnwallet/wallet.go | 9 +++- peer/brontide.go | 14 +++++- server.go | 3 ++ 15 files changed, 194 insertions(+), 47 deletions(-) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 37e72fdee..41a2fcbb7 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -63,6 +64,10 @@ type Config struct { // state. ChanStateDB *channeldb.ChannelStateDB + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // BlockCache is the main cache for storing block information. BlockCache *blockcache.BlockCache diff --git a/config_builder.go b/config_builder.go index bef59b9a0..7c399297e 100644 --- a/config_builder.go +++ b/config_builder.go @@ -105,7 +105,7 @@ type DatabaseBuilder interface { type WalletConfigBuilder interface { // BuildWalletConfig is responsible for creating or unlocking and then // fully initializing a wallet. - BuildWalletConfig(context.Context, *DatabaseInstances, + BuildWalletConfig(context.Context, *DatabaseInstances, *AuxComponents, *rpcperms.InterceptorChain, []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) @@ -120,14 +120,6 @@ type ChainControlBuilder interface { *btcwallet.Config) (*chainreg.ChainControl, func(), error) } -// AuxComponents is a set of auxiliary components that can be used by lnd for -// certain custom channel types. -type AuxComponents struct { - // MsgRouter is an optional message router that if set will be used in - // place of a new blank default message router. - MsgRouter fn.Option[msgmux.Router] -} - // ImplementationCfg is a struct that holds all configuration items for // components that can be implemented outside lnd itself. type ImplementationCfg struct { @@ -160,6 +152,18 @@ type ImplementationCfg struct { AuxComponents } +// AuxComponents is a set of auxiliary components that can be used by lnd for +// certain custom channel types. +type AuxComponents struct { + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + + // MsgRouter is an optional message router that if set will be used in + // place of a new blank default message router. + MsgRouter fn.Option[msgmux.Router] +} + // DefaultWalletImpl is the default implementation of our normal, btcwallet // backed configuration. type DefaultWalletImpl struct { @@ -242,7 +246,8 @@ func (d *DefaultWalletImpl) Permissions() map[string][]bakery.Op { // // NOTE: This is part of the WalletConfigBuilder interface. func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, - dbs *DatabaseInstances, interceptorChain *rpcperms.InterceptorChain, + dbs *DatabaseInstances, aux *AuxComponents, + interceptorChain *rpcperms.InterceptorChain, grpcListeners []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) { @@ -562,6 +567,7 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, HeightHintDB: dbs.HeightHintDB, ChanStateDB: dbs.ChanStateDB.ChannelStateDB(), NeutrinoCS: neutrinoCS, + AuxLeafStore: aux.AuxLeafStore, ActiveNetParams: d.cfg.ActiveNetParams, FeeURL: d.cfg.FeeURL, Fee: &lncfg.Fee{ @@ -625,8 +631,9 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, // proxyBlockEpoch proxies a block epoch subsections to the underlying neutrino // rebroadcaster client. -func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, -) func() (*blockntfns.Subscription, error) { +func proxyBlockEpoch( + notifier chainntnfs.ChainNotifier) func() (*blockntfns.Subscription, + error) { return func() (*blockntfns.Subscription, error) { blockEpoch, err := notifier.RegisterBlockEpochNtfn( @@ -717,6 +724,7 @@ func (d *DefaultWalletImpl) BuildChainControl( ChainIO: walletController, NetParams: *walletConfig.NetParams, CoinSelectionStrategy: walletConfig.CoinSelectionStrategy, + AuxLeafStore: partialChainControl.Cfg.AuxLeafStore, } // The broadcast is already always active for neutrino nodes, so we @@ -899,6 +907,10 @@ type DatabaseInstances struct { // for native SQL queries for tables that already support it. This may // be nil if the use-native-sql flag was not set. NativeSQLStore *sqldb.BaseDB + + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultDatabaseBuilder is a type that builds the default database backends diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 6a1865444..5940ee25b 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,6 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" @@ -1590,6 +1591,7 @@ func testBreachSpends(t *testing.T, test breachTest) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, 1, forceCloseTx, + fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") @@ -1799,6 +1801,7 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, uint32(blockHeight), forceCloseTx, + fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0cc4b111a..dbc97939a 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -217,6 +217,10 @@ type ChainArbitratorConfig struct { // meanwhile, turn `PaymentCircuit` into an interface or bring it to a // lower package. QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -299,8 +303,13 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err @@ -344,10 +353,15 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // Finally, we'll force close the channel completing // the force close workflow. chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 66a61fe3b..4012f031d 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -193,6 +193,9 @@ type chainWatcherConfig struct { // obfuscater. This is used by the chain watcher to identify which // state was broadcast and confirmed on-chain. extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 + + // auxLeafStore can be used to fetch information for custom channels. + auxLeafStore fn.Option[lnwallet.AuxLeafStore] } // chainWatcher is a system that's assigned to every active channel. The duty @@ -867,7 +870,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, spendHeight := uint32(commitSpend.SpendingHeight) retribution, err := lnwallet.NewBreachRetribution( c.cfg.chanState, broadcastStateNum, spendHeight, - commitSpend.SpendingTx, + commitSpend.SpendingTx, c.cfg.auxLeafStore, ) switch { @@ -1117,8 +1120,8 @@ func (c *chainWatcher) dispatchLocalForceClose( "detected", c.cfg.chanState.FundingOutpoint) forceClose, err := lnwallet.NewLocalForceCloseSummary( - c.cfg.chanState, c.cfg.signer, - commitSpend.SpendingTx, stateNum, + c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum, + c.cfg.auxLeafStore, ) if err != nil { return err @@ -1211,7 +1214,7 @@ func (c *chainWatcher) dispatchRemoteForceClose( // channel on-chain. uniClose, err := lnwallet.NewUnilateralCloseSummary( c.cfg.chanState, c.cfg.signer, commitSpend, - remoteCommit, commitPoint, + remoteCommit, commitPoint, c.cfg.auxLeafStore, ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index b22f2e691..6bb027725 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -544,6 +545,10 @@ type Config struct { // backed funding flow to not use utxos still being swept by the sweeper // subsystem. IsSweeperOutpoint func(wire.OutPoint) bool + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // Manager acts as an orchestrator/bridge between the wallet's @@ -1069,9 +1074,14 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, } } + var chanOpts []lnwallet.ChannelOpt + f.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // We create the state-machine object which wraps the database state. lnChannel, err := lnwallet.NewLightningChannel( - nil, channel, nil, + nil, channel, nil, chanOpts..., ) if err != nil { log.Errorf("Unable to create LightningChannel(%v): %v", diff --git a/funding/manager_test.go b/funding/manager_test.go index c4c8b4f36..471d0209f 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" @@ -563,6 +564,9 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, IsSweeperOutpoint: func(wire.OutPoint) bool { return false }, + AuxLeafStore: fn.Some[lnwallet.AuxLeafStore]( + &lnwallet.MockAuxLeafStore{}, + ), } for _, op := range options { @@ -672,6 +676,7 @@ func recreateAliceFundingManager(t *testing.T, alice *testNode) { OpenChannelPredicate: chainedAcceptor, DeleteAliasEdge: oldCfg.DeleteAliasEdge, AliasManager: oldCfg.AliasManager, + AuxLeafStore: oldCfg.AuxLeafStore, }) require.NoError(t, err, "failed recreating aliceFundingManager") diff --git a/lnd.go b/lnd.go index e483d5512..da7747e91 100644 --- a/lnd.go +++ b/lnd.go @@ -456,7 +456,8 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, defer cleanUp() partialChainControl, walletConfig, cleanUp, err := implCfg.BuildWalletConfig( - ctx, dbs, interceptorChain, grpcListeners, + ctx, dbs, &implCfg.AuxComponents, interceptorChain, + grpcListeners, ) if err != nil { return mkErr("error creating wallet config: %v", err) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7c0a367da..e794da471 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1965,7 +1965,8 @@ type BreachRetribution struct { // required to construct the BreachRetribution. If the revocation log is missing // the required fields then ErrRevLogDataMissing will be returned. func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, - breachHeight uint32, spendTx *wire.MsgTx) (*BreachRetribution, error) { + breachHeight uint32, spendTx *wire.MsgTx, + leafStore fn.Option[AuxLeafStore]) (*BreachRetribution, error) { // Query the on-disk revocation log for the snapshot which was recorded // at this particular state num. Based on whether a legacy revocation @@ -3023,9 +3024,16 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // signature can be submitted to the sigPool to generate all the signatures // asynchronously and in parallel. func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, - chanType channeldb.ChannelType, isRemoteInitiator bool, - leaseExpiry uint32, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, - remoteCommitView *commitment) ([]SignJob, chan struct{}, error) { + chanState *channeldb.OpenChannel, leaseExpiry uint32, + remoteCommitView *commitment, + leafStore fn.Option[AuxLeafStore]) ([]SignJob, chan struct{}, error) { + + var ( + isRemoteInitiator = !chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + remoteChanCfg = chanState.RemoteChanCfg + chanType = chanState.ChanType + ) txHash := remoteCommitView.txn.TxHash() dustLimit := remoteChanCfg.DustLimit @@ -3191,9 +3199,9 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // validate this new state. This function is called right before sending the // new commitment to the remote party. The commit diff returned contains all // information necessary for retransmission. -func (lc *LightningChannel) createCommitDiff( - newCommit *commitment, commitSig lnwire.Sig, - htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) { +func (lc *LightningChannel) createCommitDiff(newCommit *commitment, + commitSig lnwire.Sig, htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, + error) { // First, we need to convert the funding outpoint into the ID that's // used on the wire to identify this channel. We'll use this shortly @@ -3892,9 +3900,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { leaseExpiry = lc.channelState.ThawHeight } sigBatch, cancelChan, err := genRemoteHtlcSigJobs( - keyRing, lc.channelState.ChanType, !lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, newCommitView, + keyRing, lc.channelState, leaseExpiry, newCommitView, + lc.leafStore, ) if err != nil { return nil, err @@ -4465,10 +4472,18 @@ func (lc *LightningChannel) computeView(view *HtlcView, // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent // directly into the pool of workers. -func genHtlcSigValidationJobs(localCommitmentView *commitment, - keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, - chanType channeldb.ChannelType, isLocalInitiator bool, leaseExpiry uint32, - localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]VerifyJob, error) { +// +//nolint:funlen +func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, + localCommitmentView *commitment, keyRing *CommitmentKeyRing, + htlcSigs []lnwire.Sig, leaseExpiry uint32, + leafStore fn.Option[AuxLeafStore]) ([]VerifyJob, error) { + + var ( + isLocalInitiator = chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + chanType = chanState.ChanType + ) txHash := localCommitmentView.txn.TxHash() feePerKw := localCommitmentView.feePerKw @@ -4858,10 +4873,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { leaseExpiry = lc.channelState.ThawHeight } verifyJobs, err := genHtlcSigValidationJobs( - localCommitmentView, keyRing, commitSigs.HtlcSigs, - lc.channelState.ChanType, lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, + lc.channelState, localCommitmentView, keyRing, + commitSigs.HtlcSigs, leaseExpiry, lc.leafStore, ) if err != nil { return err @@ -6308,10 +6321,10 @@ type UnilateralCloseSummary struct { // happen in case we have lost state) it should be set to an empty struct, in // which case we will attempt to sweep the non-HTLC output using the passed // commitPoint. -func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, - commitSpend *chainntnfs.SpendDetail, - remoteCommit channeldb.ChannelCommitment, - commitPoint *btcec.PublicKey) (*UnilateralCloseSummary, error) { +func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, + signer input.Signer, commitSpend *chainntnfs.SpendDetail, + remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, + leafStore fn.Option[AuxLeafStore]) (*UnilateralCloseSummary, error) { // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. @@ -7254,7 +7267,7 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { localCommitment := lc.channelState.LocalCommitment summary, err := NewLocalForceCloseSummary( lc.channelState, lc.Signer, commitTx, - localCommitment.CommitHeight, + localCommitment.CommitHeight, lc.leafStore, ) if err != nil { return nil, fmt.Errorf("unable to gen force close "+ @@ -7271,8 +7284,8 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { // channel state. The passed commitTx must be a fully signed commitment // transaction corresponding to localCommit. func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, - signer input.Signer, commitTx *wire.MsgTx, stateNum uint64) ( - *LocalForceCloseSummary, error) { + signer input.Signer, commitTx *wire.MsgTx, stateNum uint64, + leafStore fn.Option[AuxLeafStore]) (*LocalForceCloseSummary, error) { // Re-derive the original pkScript for to-self output within the // commitment transaction. We'll need this to find the corresponding diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 1bfd4ad88..110d323f6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5695,6 +5695,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -5844,6 +5845,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -5861,6 +5863,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceRemoteChainTip.Commitment, aliceChannel.channelState.RemoteNextRevocation, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create alice close summary") @@ -6741,6 +6744,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { breachTx := aliceChannel.channelState.RemoteCommitment.CommitTx breachRet, err := NewBreachRetribution( aliceChannel.channelState, revokedStateNum, 100, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err, "unable to create breach retribution") @@ -10291,6 +10295,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10298,6 +10303,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10343,6 +10349,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // successfully. br, err := NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err) @@ -10354,6 +10361,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // since the necessary info should now be found in the revocation log. br, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.NoError(t, err) assertRetribution(br, 1, 0) @@ -10362,6 +10370,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, breachTx, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) @@ -10369,6 +10378,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, nil, + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) } diff --git a/lnwallet/config.go b/lnwallet/config.go index 7eeacb6ea..24961f38e 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -62,4 +63,8 @@ type Config struct { // CoinSelectionStrategy is the strategy that is used for selecting // coins when funding a transaction. CoinSelectionStrategy wallet.CoinSelectionStrategy + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[AuxLeafStore] } diff --git a/lnwallet/mock.go b/lnwallet/mock.go index faac5fa67..2afff4f21 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -17,8 +17,10 @@ import ( "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -397,3 +399,45 @@ func (*mockChainIO) GetBlockHeader( return nil, nil } + +type MockAuxLeafStore struct{} + +// A compile time check to ensure that MockAuxLeafStore implements the +// AuxLeafStore interface. +var _ AuxLeafStore = (*MockAuxLeafStore)(nil) + +// FetchLeavesFromView attempts to fetch the auxiliary leaves that +// correspond to the passed aux blob, and pending original (unfiltered) +// HTLC view. +func (*MockAuxLeafStore) FetchLeavesFromView( + _ CommitDiffAuxInput) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// FetchLeavesFromCommit attempts to fetch the auxiliary leaves that +// correspond to the passed aux blob, and an existing channel +// commitment. +func (*MockAuxLeafStore) FetchLeavesFromCommit(_ AuxChanState, + _ channeldb.ChannelCommitment, + _ CommitmentKeyRing) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// FetchLeavesFromRevocation attempts to fetch the auxiliary leaves +// from a channel revocation that stores balance + blob information. +func (*MockAuxLeafStore) FetchLeavesFromRevocation( + _ *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult] { + + return fn.Ok(CommitDiffAuxResult{}) +} + +// ApplyHtlcView serves as the state transition function for the custom +// channel's blob. Given the old blob, and an HTLC view, then a new +// blob should be returned that reflects the pending updates. +func (*MockAuxLeafStore) ApplyHtlcView( + _ CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]] { + + return fn.Ok(fn.None[tlv.Blob]()) +} diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 077e17a3a..7e455ab48 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -2496,9 +2496,16 @@ func initStateHints(commit1, commit2 *wire.MsgTx, func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, fundingTx *wire.MsgTx) error { + var chanOpts []ChannelOpt + l.Cfg.AuxLeafStore.WhenSome(func(s AuxLeafStore) { + chanOpts = append(chanOpts, WithLeafStore(s)) + }) + // First, we'll obtain a fully signed commitment transaction so we can // pass into it on the chanvalidate package for verification. - channel, err := NewLightningChannel(l.Cfg.Signer, channelState, nil) + channel, err := NewLightningChannel( + l.Cfg.Signer, channelState, nil, chanOpts..., + ) if err != nil { return err } diff --git a/peer/brontide.go b/peer/brontide.go index 25c7cea6f..1324044da 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -372,6 +372,10 @@ type Config struct { AddLocalAlias func(alias, base lnwire.ShortChannelID, gossip bool) error + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -943,8 +947,12 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( } } + var chanOpts []lnwallet.ChannelOpt + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) lnChan, err := lnwallet.NewLightningChannel( - p.cfg.Signer, dbChan, p.cfg.SigPool, + p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) if err != nil { return nil, fmt.Errorf("unable to create channel "+ @@ -4151,6 +4159,10 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithSkipNonceInit()) } + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. diff --git a/server.go b/server.go index 75fd02b97..1f8db90a6 100644 --- a/server.go +++ b/server.go @@ -1273,6 +1273,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return &pc.Incoming }, + AuxLeafStore: implCfg.AuxLeafStore, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin. @@ -1607,6 +1608,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, br, err := lnwallet.NewBreachRetribution( channel, commitHeight, 0, nil, + implCfg.AuxLeafStore, ) if err != nil { return nil, 0, err @@ -4073,6 +4075,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), MaxFeeExposure: thresholdMSats, Quit: s.quit, + AuxLeafStore: s.implCfg.AuxLeafStore, MsgRouter: s.implCfg.MsgRouter, } From 2ab22b0f0b06cabe54d3ce6a79c498b3753c06a2 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 23 Aug 2024 15:13:19 +0200 Subject: [PATCH 21/22] lnwallet: refactor commit keys to use lntypes.Dual --- lnwallet/channel.go | 30 ++++++++++++++---------------- lnwallet/channel_test.go | 2 +- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index e794da471..c382b9dc2 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -575,9 +575,8 @@ func (c *commitment) toDiskCommit( // restore commitment state written to disk back into memory once we need to // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, - htlc *channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) (PaymentDescriptor, error) { + htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -598,6 +597,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Local, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit, ) + localCommitKeys := commitKeys.GetForParty(lntypes.Local) if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Local, @@ -613,6 +613,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Remote, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit, ) + remoteCommitKeys := commitKeys.GetForParty(lntypes.Remote) if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Remote, @@ -665,9 +666,9 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // these payment descriptors can be re-inserted into the in-memory updateLog // for each side. func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, - htlcs []channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) ([]PaymentDescriptor, []PaymentDescriptor, error) { + htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty) ([]PaymentDescriptor, + []PaymentDescriptor, error) { var ( incomingHtlcs []PaymentDescriptor @@ -685,9 +686,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc payDesc, err := lc.diskHtlcToPayDesc( - feeRate, &htlc, - localCommitKeys, remoteCommitKeys, - whoseCommit, + feeRate, &htlc, commitKeys, whoseCommit, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -716,22 +715,22 @@ func (lc *LightningChannel) diskCommitToMemCommit( // (we extended but weren't able to complete the commitment dance // before shutdown), then the localCommitPoint won't be set as we // haven't yet received a responding commitment from the remote party. - var localCommitKeys, remoteCommitKeys *CommitmentKeyRing + var commitKeys lntypes.Dual[*CommitmentKeyRing] if localCommitPoint != nil { - localCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Local, DeriveCommitmentKeys( localCommitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) } if remoteCommitPoint != nil { - remoteCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Remote, DeriveCommitmentKeys( remoteCommitPoint, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) } // With the key rings re-created, we'll now convert all the on-disk @@ -739,8 +738,7 @@ func (lc *LightningChannel) diskCommitToMemCommit( // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - whoseCommit, + diskCommit.Htlcs, commitKeys, whoseCommit, ) if err != nil { return nil, err diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 110d323f6..4a1cb0321 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -10416,7 +10416,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, htlcs, nil, nil, lntypes.Local, + 0, htlcs, lntypes.Dual[*CommitmentKeyRing]{}, lntypes.Local, ) require.NoError(t, err) From 9dfbde701332e1220d4fb56ae102d110dc660b55 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 25 Apr 2024 19:01:37 +0200 Subject: [PATCH 22/22] lnwallet: thread thru input.AuxTapleaf to all relevant areas In this commit, we start to thread thru the new aux tap leaf structures to all relevant areas. This includes: commitment outputs, resolution creation, breach handling, and also HTLC scripts. --- contractcourt/chain_watcher.go | 30 +++- input/size_test.go | 4 +- lnwallet/channel.go | 304 ++++++++++++++++++++++++++++----- lnwallet/channel_test.go | 8 +- lnwallet/commitment.go | 63 ++++--- lnwallet/transactions.go | 12 +- 6 files changed, 345 insertions(+), 76 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 4012f031d..b1e3fc1c2 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -426,15 +426,36 @@ func (c *chainWatcher) handleUnknownLocalState( &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) + auxResult, err := fn.MapOptionZ( + c.cfg.auxLeafStore, + //nolint:lll + func(s lnwallet.AuxLeafStore) fn.Result[lnwallet.CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + lnwallet.NewAuxChanState(c.cfg.chanState), + c.cfg.chanState.LocalCommitment, *commitKeyRing, + ) + }, + ).Unpack() + if err != nil { + return false, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the keys derived, we'll construct the remote script that'll be // present if they have a non-dust balance on the commitment. var leaseExpiry uint32 if c.cfg.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = c.cfg.chanState.ThawHeight } + + remoteAuxLeaf := fn.ChainOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, - commitKeyRing.ToRemoteKey, leaseExpiry, input.NoneTapLeaf(), + commitKeyRing.ToRemoteKey, leaseExpiry, + remoteAuxLeaf, ) if err != nil { return false, err @@ -443,11 +464,16 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. + localAuxLeaf := fn.ChainOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) localScript, err := lnwallet.CommitScriptToSelf( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, - input.NoneTapLeaf(), + localAuxLeaf, ) if err != nil { return false, err diff --git a/input/size_test.go b/input/size_test.go index 21e2c06ae..33f9ff539 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -1388,7 +1388,7 @@ func genTimeoutTx(t *testing.T, // Create the unsigned timeout tx. timeoutTx, err := lnwallet.CreateHtlcTimeoutTx( chanType, false, testOutPoint, testAmt, testCLTVExpiry, - testCSVDelay, 0, testPubkey, testPubkey, + testCSVDelay, 0, testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1457,7 +1457,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { // Create the unsigned success tx. successTx, err := lnwallet.CreateHtlcSuccessTx( chanType, false, testOutPoint, testAmt, testCSVDelay, 0, - testPubkey, testPubkey, + testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index c382b9dc2..e5cbe6b6a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -576,7 +576,8 @@ func (c *commitment) toDiskCommit( // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], - whoseCommit lntypes.ChannelParty) (PaymentDescriptor, error) { + whoseCommit lntypes.ChannelParty, + auxLeaf input.AuxTapLeaf) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -602,6 +603,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Local, htlc.RefundTimeout, htlc.RHash, localCommitKeys, + auxLeaf, ) if err != nil { return pd, err @@ -618,6 +620,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Remote, htlc.RefundTimeout, htlc.RHash, remoteCommitKeys, + auxLeaf, ) if err != nil { return pd, err @@ -667,7 +670,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // for each side. func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], - whoseCommit lntypes.ChannelParty) ([]PaymentDescriptor, + whoseCommit lntypes.ChannelParty, + auxLeaves fn.Option[CommitAuxLeaves]) ([]PaymentDescriptor, []PaymentDescriptor, error) { var ( @@ -685,8 +689,19 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + if htlc.Incoming { + leaves = l.IncomingHtlcLeaves + } + + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + payDesc, err := lc.diskHtlcToPayDesc( - feeRate, &htlc, commitKeys, whoseCommit, + feeRate, &htlc, commitKeys, whoseCommit, auxLeaf, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -733,12 +748,25 @@ func (lc *LightningChannel) diskCommitToMemCommit( )) } + auxResult, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(lc.channelState), *diskCommit, + *commitKeys.GetForParty(whoseCommit), + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the key rings re-created, we'll now convert all the on-disk // HTLC"s into PaymentDescriptor's so we can re-insert them into our // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, commitKeys, whoseCommit, + diskCommit.Htlcs, commitKeys, whoseCommit, auxResult.AuxLeaves, ) if err != nil { return nil, err @@ -1091,7 +1119,8 @@ func (lc *LightningChannel) ResetState() { func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, remoteUpdateLog *updateLog, commitHeight uint64, feeRate chainfee.SatPerKWeight, remoteCommitKeys *CommitmentKeyRing, - remoteDustLimit btcutil.Amount) (*PaymentDescriptor, error) { + remoteDustLimit btcutil.Amount, + auxLeaves fn.Option[CommitAuxLeaves]) (*PaymentDescriptor, error) { // Depending on the type of update message we'll map that to a distinct // PaymentDescriptor instance. @@ -1127,10 +1156,17 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[pd.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + scriptInfo, err := genHtlcScript( lc.channelState.ChanType, false, lntypes.Remote, wireMsg.Expiry, wireMsg.PaymentHash, - remoteCommitKeys, + remoteCommitKeys, auxLeaf, ) if err != nil { return nil, err @@ -1791,6 +1827,19 @@ func (lc *LightningChannel) restorePendingLocalUpdates( pendingCommit := pendingRemoteCommitDiff.Commitment pendingHeight := pendingCommit.CommitHeight + auxResult, err := fn.MapOptionZ( + lc.leafStore, + func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(lc.channelState), pendingCommit, + *pendingRemoteKeys, + ) + }, + ).Unpack() + if err != nil { + return fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // If we did have a dangling commit, then we'll examine which updates // we included in that state and re-insert them into our update log. for _, logUpdate := range pendingRemoteCommitDiff.LogUpdates { @@ -1801,6 +1850,7 @@ func (lc *LightningChannel) restorePendingLocalUpdates( chainfee.SatPerKWeight(pendingCommit.FeePerKw), pendingRemoteKeys, lc.channelState.RemoteChanCfg.DustLimit, + auxResult.AuxLeaves, ) if err != nil { return err @@ -2007,24 +2057,40 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } - // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromRevocation(revokedLog) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, input.NoneTapLeaf(), + leaseExpiry, remoteAuxLeaf, ) if err != nil { return nil, err } + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, theirDelay, leaseExpiry, - input.NoneTapLeaf(), + keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf, ) if err != nil { return nil, err @@ -2042,7 +2108,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, if revokedLog != nil { br, ourAmt, theirAmt, err = createBreachRetribution( revokedLog, spendTx, chanState, keyRing, - commitmentSecret, leaseExpiry, + commitmentSecret, leaseExpiry, auxResult.AuxLeaves, ) if err != nil { return nil, err @@ -2176,7 +2242,8 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, func createHtlcRetribution(chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitHash chainhash.Hash, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, - htlc *channeldb.HTLCEntry) (HtlcRetribution, error) { + htlc *channeldb.HTLCEntry, + auxLeaves fn.Option[CommitAuxLeaves]) (HtlcRetribution, error) { var emptyRetribution HtlcRetribution @@ -2186,10 +2253,24 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { + return fn.MapOption(func(val uint16) input.AuxTapLeaf { + idx := input.HtlcIndex(val) + + if htlc.Incoming.Val { + leaves := l.IncomingHtlcLeaves[idx] + return leaves.SecondLevelLeaf + } + + return l.OutgoingHtlcLeaves[idx].SecondLevelLeaf + })(htlc.HtlcIndex.ValOpt()) + }, + )(auxLeaves) secondLevelScript, err := SecondLevelHtlcScript( chanState.ChanType, isRemoteInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, theirDelay, - leaseExpiry, + leaseExpiry, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return emptyRetribution, err @@ -2200,9 +2281,24 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. + htlcLeaf := fn.ChainOption( + func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { + return fn.MapOption(func(val uint16) input.AuxTapLeaf { + idx := input.HtlcIndex(val) + + if htlc.Incoming.Val { + leaves := l.IncomingHtlcLeaves[idx] + return leaves.AuxTapLeaf + } + + return l.OutgoingHtlcLeaves[idx].AuxTapLeaf + })(htlc.HtlcIndex.ValOpt()) + }, + )(auxLeaves) scriptInfo, err := genHtlcScript( chanState.ChanType, htlc.Incoming.Val, lntypes.Remote, htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, + fn.FlattenOption(htlcLeaf), ) if err != nil { return emptyRetribution, err @@ -2266,7 +2362,9 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, func createBreachRetribution(revokedLog *channeldb.RevocationLog, spendTx *wire.MsgTx, chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, - leaseExpiry uint32) (*BreachRetribution, int64, int64, error) { + leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64, + error) { commitHash := revokedLog.CommitTxHash @@ -2275,7 +2373,7 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( chanState, keyRing, commitHash.Val, - commitmentSecret, leaseExpiry, htlc, + commitmentSecret, leaseExpiry, htlc, auxLeaves, ) if err != nil { return nil, 0, 0, err @@ -2427,6 +2525,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, + fn.None[CommitAuxLeaves](), ) if err != nil { return nil, 0, 0, err @@ -3041,13 +3140,27 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // With the keys generated, we'll make a slice with enough capacity to // hold potentially all the HTLCs. The actual slice may be a bit // smaller (than its total capacity) and some HTLCs may be dust. - numSigs := (len(remoteCommitView.incomingHTLCs) + - len(remoteCommitView.outgoingHTLCs)) + numSigs := len(remoteCommitView.incomingHTLCs) + + len(remoteCommitView.outgoingHTLCs) sigBatch := make([]SignJob, 0, numSigs) var err error cancelChan := make(chan struct{}) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + *remoteCommitView.toDiskCommit(lntypes.Remote), + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch aux leaves: %w", + err) + } + // For each outgoing and incoming HTLC, if the HTLC isn't considered a // dust output after taking into account second-level HTLC fees, then a // sigJob will be generated and appended to the current batch. @@ -3074,6 +3187,13 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxResult.AuxLeaves) + // With the fee calculate, we can properly create the HTLC // timeout transaction using the HTLC amount minus the fee. op := wire.OutPoint{ @@ -3084,11 +3204,15 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, htlc.Timeout, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + auxLeaf, ) if err != nil { return nil, nil, err } + // TODO(roasbeef): hook up signer interface here (later commit + // in this PR). + // Construct a full hash cache as we may be signing a segwit v1 // sighash. txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex] @@ -3115,7 +3239,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // If this is a taproot channel, then we'll need to set the // method type to ensure we generate a valid signature. if chanType.IsTaproot() { - sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod //nolint:lll + //nolint:lll + sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod } sigBatch = append(sigBatch, sigJob) @@ -3141,6 +3266,13 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxResult.AuxLeaves) + // With the proper output amount calculated, we can now // generate the success transaction using the remote party's // CSV delay. @@ -3152,6 +3284,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + auxLeaf, ) if err != nil { return nil, nil, err @@ -4491,10 +4624,24 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, // enough capacity to hold verification jobs for all HTLC's in this // view. In the case that we have some dust outputs, then the actual // length will be smaller than the total capacity. - numHtlcs := (len(localCommitmentView.incomingHTLCs) + - len(localCommitmentView.outgoingHTLCs)) + numHtlcs := len(localCommitmentView.incomingHTLCs) + + len(localCommitmentView.outgoingHTLCs) verifyJobs := make([]VerifyJob, 0, numHtlcs) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + *localCommitmentView.toDiskCommit( + lntypes.Local, + ), *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // We'll iterate through each output in the commitment transaction, // populating the sigHash closure function if it's detected to be an // HLTC output. Given the sighash, and the signing key, we'll be able @@ -4528,11 +4675,19 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxResult.AuxLeaves) + successTx, err := CreateHtlcSuccessTx( chanType, isLocalInitiator, op, outputAmt, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, auxLeaf, ) if err != nil { return nil, err @@ -4612,12 +4767,20 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.ChainOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxResult.AuxLeaves) + timeoutTx, err := CreateHtlcTimeoutTx( chanType, isLocalInitiator, op, outputAmt, htlc.Timeout, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, auxLeaf, ) if err != nil { return nil, err @@ -6332,6 +6495,18 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), remoteCommit, + *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Next, we'll obtain HTLC resolutions for all the outgoing HTLC's we // had on their commitment transaction. var leaseExpiry uint32 @@ -6344,6 +6519,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, chanState.ChanType, isRemoteInitiator, leaseExpiry, + auxResult.AuxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to create htlc "+ @@ -6352,14 +6528,17 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, commitTxBroadcast := commitSpend.SpendingTx - // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). - // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. + remoteAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxResult.AuxLeaves) selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, input.NoneTapLeaf(), + leaseExpiry, remoteAuxLeaf, ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -6599,7 +6778,8 @@ func newOutgoingHtlcResolution(signer input.Signer, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { + chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -6608,9 +6788,12 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. + auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) htlcScriptInfo, err := genHtlcScript( chanType, false, whoseCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, auxLeaf, ) if err != nil { return nil, err @@ -6683,10 +6866,16 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) timeoutTx, err := CreateHtlcTimeoutTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, - htlc.RefundTimeout, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + htlc.RefundTimeout, csvDelay, leaseExpiry, + keyRing.RevocationKey, keyRing.ToLocalKey, secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6769,6 +6958,7 @@ func newOutgoingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6777,7 +6967,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, - fn.None[txscript.TapLeaf](), + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -6852,8 +7042,8 @@ func newIncomingHtlcResolution(signer input.Signer, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType) ( - *IncomingHtlcResolution, error) { + chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*IncomingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -6862,9 +7052,12 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. + auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( chanType, true, whoseCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, auxLeaf, ) if err != nil { return nil, err @@ -6924,6 +7117,13 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } + secondLevelAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // Otherwise, we'll need to go to the second level to sweep this HTLC. // // First, we'll reconstruct the original HTLC success transaction, @@ -6933,7 +7133,7 @@ func newIncomingHtlcResolution(signer input.Signer, successTx, err := CreateHtlcSuccessTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -7016,6 +7216,7 @@ func newIncomingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -7024,7 +7225,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, - fn.None[txscript.TapLeaf](), + secondLevelAuxLeaf, ) if err != nil { return nil, err @@ -7117,7 +7318,8 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, - isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { + isCommitFromInitiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*HtlcResolutions, error) { // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit @@ -7150,8 +7352,9 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, // as we can satisfy the contract. ihr, err := newIncomingHtlcResolution( signer, localChanCfg, commitTx, &htlc, - keyRing, feePerKw, uint32(csvDelay), leaseExpiry, - whoseCommit, isCommitFromInitiator, chanType, + keyRing, feePerKw, uint32(csvDelay), + leaseExpiry, whoseCommit, isCommitFromInitiator, + chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7165,7 +7368,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, whoseCommit, - isCommitFromInitiator, chanType, + isCommitFromInitiator, chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("outgoing resolution "+ @@ -7304,16 +7507,31 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) - // TODO(roasbeef): Actually fetch aux leaves (later commits in this PR). + auxResult, err := fn.MapOptionZ( + leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] { + return s.FetchLeavesFromCommit( + NewAuxChanState(chanState), + chanState.LocalCommitment, *keyRing, + ) + }, + ).Unpack() + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight } + + localAuxLeaf := fn.ChainOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxResult.AuxLeaves) toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, - keyRing.RevocationKey, csvTimeout, leaseExpiry, - input.NoneTapLeaf(), + keyRing.RevocationKey, csvTimeout, leaseExpiry, localAuxLeaf, ) if err != nil { return nil, err @@ -7404,7 +7622,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, chainfee.SatPerKWeight(localCommit.FeePerKw), lntypes.Local, signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, - chanState.IsInitiator, leaseExpiry, + chanState.IsInitiator, leaseExpiry, auxResult.AuxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to gen htlc resolution: %w", err) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 4a1cb0321..de6bb219c 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9971,7 +9971,7 @@ func TestCreateHtlcRetribution(t *testing.T) { // Create the htlc retribution. hr, err := createHtlcRetribution( aliceChannel.channelState, keyRing, commitHash, - dummyPrivate, leaseExpiry, htlc, + dummyPrivate, leaseExpiry, htlc, fn.None[CommitAuxLeaves](), ) // Expect no error. require.NoError(t, err) @@ -10177,6 +10177,7 @@ func TestCreateBreachRetribution(t *testing.T) { tc.revocationLog, tx, aliceChannel.channelState, keyRing, dummyPrivate, leaseExpiry, + fn.None[CommitAuxLeaves](), ) // Check the error if expected. @@ -10295,7 +10296,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, - fn.None[AuxLeafStore](), + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10303,7 +10304,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, - fn.None[AuxLeafStore](), + fn.Some[AuxLeafStore](&MockAuxLeafStore{}), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10417,6 +10418,7 @@ func TestExtractPayDescs(t *testing.T) { // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( 0, htlcs, lntypes.Dual[*CommitmentKeyRing]{}, lntypes.Local, + fn.None[CommitAuxLeaves](), ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 7f2a9ce33..2ff23ab63 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -420,14 +420,14 @@ func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType { // argument should correspond to the owner of the commitment transaction which // we are generating the to_local script for. func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, - revocationKey, delayKey *btcec.PublicKey, - csvDelay, leaseExpiry uint32) (input.ScriptDescriptor, error) { + revocationKey, delayKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, input.NoneTapLeaf(), + revocationKey, delayKey, csvDelay, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -755,10 +755,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance -= commitFeeMSat } - var ( - commitTx *wire.MsgTx - err error - ) + var commitTx *wire.MsgTx // Before we create the commitment transaction below, we'll try to see // if there're any aux leaves that need to be a part of the tapscript @@ -806,6 +803,19 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, return nil, err } + // Similarly, we'll now attempt to extract the set of aux leaves for + // the set of incoming and outgoing HTLCs. + incomingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.HtlcAuxLeaves { + return leaves.IncomingHtlcLeaves + }, + )(auxResult.AuxLeaves) + outgoingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.HtlcAuxLeaves { + return leaves.OutgoingHtlcLeaves + }, + )(auxResult.AuxLeaves) + // We'll now add all the HTLC outputs to the commitment transaction. // Each output includes an off-chain 2-of-2 covenant clause, so we'll // need the objective local/remote keys for this particular commitment @@ -826,9 +836,15 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.ChainOption( + func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(outgoingAuxLeaves) + err := addHTLC( commitTx, whoseCommit, false, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, auxLeaf, ) if err != nil { return nil, err @@ -848,9 +864,15 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.ChainOption( + func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(incomingAuxLeaves) + err := addHTLC( commitTx, whoseCommit, true, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, auxLeaf, ) if err != nil { return nil, err @@ -1128,7 +1150,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // channel. func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, -) (*input.HtlcScriptTree, error) { + auxLeaf input.AuxTapLeaf) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1145,8 +1167,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1155,8 +1176,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // We're sending an HTLC which is being added to our commitment @@ -1165,8 +1185,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case !isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1175,8 +1194,7 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, case !isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], whoseCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf, ) } @@ -1191,7 +1209,8 @@ func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, // along side the multiplexer. func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, - keyRing *CommitmentKeyRing) (input.ScriptDescriptor, error) { + keyRing *CommitmentKeyRing, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( @@ -1201,7 +1220,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, } return GenTaprootHtlcScript( - isIncoming, whoseCommit, timeout, rHash, keyRing, + isIncoming, whoseCommit, timeout, rHash, keyRing, auxLeaf, ) } @@ -1214,13 +1233,15 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, // the descriptor itself. func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, isIncoming bool, paymentDesc *PaymentDescriptor, - keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { + keyRing *CommitmentKeyRing, chanType channeldb.ChannelType, + auxLeaf input.AuxTapLeaf) error { timeout := paymentDesc.Timeout rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, + auxLeaf, ) if err != nil { return err diff --git a/lnwallet/transactions.go b/lnwallet/transactions.go index 1cf954d3c..da86650bc 100644 --- a/lnwallet/transactions.go +++ b/lnwallet/transactions.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/input" ) const ( @@ -50,8 +51,8 @@ var ( // - func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, csvDelay, - leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey) ( - *wire.MsgTx, error) { + leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout). @@ -71,7 +72,7 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err @@ -110,7 +111,8 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, cltvExpiry, csvDelay, leaseExpiry uint32, - revocationKey, delayKey *btcec.PublicKey) (*wire.MsgTx, error) { + revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout), and set the lock-time to the @@ -134,7 +136,7 @@ func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err