From 860cacb70a8faf0bee6e05f43937cee229c07653 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 23 Aug 2024 15:13:19 +0200 Subject: [PATCH] lnwallet: refactor commit keys to use lntypes.Dual --- lnwallet/channel.go | 30 ++++++++++++++---------------- lnwallet/channel_test.go | 2 +- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b5ef71c15..875903e92 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -575,9 +575,8 @@ func (c *commitment) toDiskCommit( // restore commitment state written to disk back into memory once we need to // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, - htlc *channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) (PaymentDescriptor, error) { + htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -598,6 +597,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Local, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit, ) + localCommitKeys := commitKeys.GetForParty(lntypes.Local) if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Local, @@ -613,6 +613,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, chanType, htlc.Incoming, lntypes.Remote, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit, ) + remoteCommitKeys := commitKeys.GetForParty(lntypes.Remote) if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, lntypes.Remote, @@ -665,9 +666,9 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // these payment descriptors can be re-inserted into the in-memory updateLog // for each side. func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, - htlcs []channeldb.HTLC, localCommitKeys *CommitmentKeyRing, - remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, -) ([]PaymentDescriptor, []PaymentDescriptor, error) { + htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing], + whoseCommit lntypes.ChannelParty) ([]PaymentDescriptor, + []PaymentDescriptor, error) { var ( incomingHtlcs []PaymentDescriptor @@ -685,9 +686,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc payDesc, err := lc.diskHtlcToPayDesc( - feeRate, &htlc, - localCommitKeys, remoteCommitKeys, - whoseCommit, + feeRate, &htlc, commitKeys, whoseCommit, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -716,22 +715,22 @@ func (lc *LightningChannel) diskCommitToMemCommit( // (we extended but weren't able to complete the commitment dance // before shutdown), then the localCommitPoint won't be set as we // haven't yet received a responding commitment from the remote party. - var localCommitKeys, remoteCommitKeys *CommitmentKeyRing + var commitKeys lntypes.Dual[*CommitmentKeyRing] if localCommitPoint != nil { - localCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Local, DeriveCommitmentKeys( localCommitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) } if remoteCommitPoint != nil { - remoteCommitKeys = DeriveCommitmentKeys( + commitKeys.SetForParty(lntypes.Remote, DeriveCommitmentKeys( remoteCommitPoint, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, - ) + )) } // With the key rings re-created, we'll now convert all the on-disk @@ -739,8 +738,7 @@ func (lc *LightningChannel) diskCommitToMemCommit( // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - whoseCommit, + diskCommit.Htlcs, commitKeys, whoseCommit, ) if err != nil { return nil, err diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 8bdf45aa8..a09f97bfe 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -10512,7 +10512,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, htlcs, nil, nil, lntypes.Local, + 0, htlcs, lntypes.Dual[*CommitmentKeyRing]{}, lntypes.Local, ) require.NoError(t, err)