lnwallet: refactor commit keys to use lntypes.Dual

This commit is contained in:
Oliver Gugger 2024-08-23 15:13:19 +02:00
parent b1c8a836e3
commit 860cacb70a
No known key found for this signature in database
GPG key ID: 8E4256593F177720
2 changed files with 15 additions and 17 deletions

View file

@ -575,9 +575,8 @@ func (c *commitment) toDiskCommit(
// restore commitment state written to disk back into memory once we need to
// restart a channel session.
func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
htlc *channeldb.HTLC, localCommitKeys *CommitmentKeyRing,
remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty,
) (PaymentDescriptor, error) {
htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing],
whoseCommit lntypes.ChannelParty) (PaymentDescriptor, error) {
// The proper pkScripts for this PaymentDescriptor must be
// generated so we can easily locate them within the commitment
@ -598,6 +597,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
chanType, htlc.Incoming, lntypes.Local, feeRate,
htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit,
)
localCommitKeys := commitKeys.GetForParty(lntypes.Local)
if !isDustLocal && localCommitKeys != nil {
scriptInfo, err := genHtlcScript(
chanType, htlc.Incoming, lntypes.Local,
@ -613,6 +613,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
chanType, htlc.Incoming, lntypes.Remote, feeRate,
htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit,
)
remoteCommitKeys := commitKeys.GetForParty(lntypes.Remote)
if !isDustRemote && remoteCommitKeys != nil {
scriptInfo, err := genHtlcScript(
chanType, htlc.Incoming, lntypes.Remote,
@ -665,9 +666,9 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
// these payment descriptors can be re-inserted into the in-memory updateLog
// for each side.
func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight,
htlcs []channeldb.HTLC, localCommitKeys *CommitmentKeyRing,
remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty,
) ([]PaymentDescriptor, []PaymentDescriptor, error) {
htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing],
whoseCommit lntypes.ChannelParty) ([]PaymentDescriptor,
[]PaymentDescriptor, error) {
var (
incomingHtlcs []PaymentDescriptor
@ -685,9 +686,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight,
htlc := htlc
payDesc, err := lc.diskHtlcToPayDesc(
feeRate, &htlc,
localCommitKeys, remoteCommitKeys,
whoseCommit,
feeRate, &htlc, commitKeys, whoseCommit,
)
if err != nil {
return incomingHtlcs, outgoingHtlcs, err
@ -716,22 +715,22 @@ func (lc *LightningChannel) diskCommitToMemCommit(
// (we extended but weren't able to complete the commitment dance
// before shutdown), then the localCommitPoint won't be set as we
// haven't yet received a responding commitment from the remote party.
var localCommitKeys, remoteCommitKeys *CommitmentKeyRing
var commitKeys lntypes.Dual[*CommitmentKeyRing]
if localCommitPoint != nil {
localCommitKeys = DeriveCommitmentKeys(
commitKeys.SetForParty(lntypes.Local, DeriveCommitmentKeys(
localCommitPoint, lntypes.Local,
lc.channelState.ChanType,
&lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
)
))
}
if remoteCommitPoint != nil {
remoteCommitKeys = DeriveCommitmentKeys(
commitKeys.SetForParty(lntypes.Remote, DeriveCommitmentKeys(
remoteCommitPoint, lntypes.Remote,
lc.channelState.ChanType,
&lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
)
))
}
// With the key rings re-created, we'll now convert all the on-disk
@ -739,8 +738,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(
// update log.
incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs(
chainfee.SatPerKWeight(diskCommit.FeePerKw),
diskCommit.Htlcs, localCommitKeys, remoteCommitKeys,
whoseCommit,
diskCommit.Htlcs, commitKeys, whoseCommit,
)
if err != nil {
return nil, err

View file

@ -10512,7 +10512,7 @@ func TestExtractPayDescs(t *testing.T) {
// NOTE: we use nil commitment key rings to avoid checking the htlc
// scripts(`genHtlcScript`) as it should be tested independently.
incomingPDs, outgoingPDs, err := lnChan.extractPayDescs(
0, htlcs, nil, nil, lntypes.Local,
0, htlcs, lntypes.Dual[*CommitmentKeyRing]{}, lntypes.Local,
)
require.NoError(t, err)