lnwallet: remove mutateState from evaluateHTLCView

In line with previous commits we are progressively removing the
mutateState argument from this call stack for a more principled
software design approach.

NOTE FOR REVIEWERS:
We take a naive approach to updating the tests here and simply
take the functionality we are removing from evaluateHTLCView and
run it directly after the function in the test suite.

It's possible that we should instead remove this from the test
suite altogether but I opted to take a more conservative approach
with respect to reducing the scope of tests. If you have opinions
here, please make them known.
This commit is contained in:
Keagan McClelland 2024-07-19 16:53:58 -07:00
parent 819239c5c8
commit d82d02831d
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
4 changed files with 146 additions and 97 deletions

View File

@ -117,3 +117,5 @@ func MapDual[A, B any](d Dual[A], f func(A) B) Dual[B] {
Remote: f(d.Remote), Remote: f(d.Remote),
} }
} }
var BothParties []ChannelParty = []ChannelParty{Local, Remote}

View File

@ -2890,16 +2890,14 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn {
// returned reflects the current state of HTLCs within the remote or local // returned reflects the current state of HTLCs within the remote or local
// commitment chain, and the current commitment fee rate. // commitment chain, and the current commitment fee rate.
// //
// If mutateState is set to true, then the add height of all added HTLCs // The return values of this function are as follows:
// will be set to nextHeight, and the remove height of all removed HTLCs // 1. The new htlcView reflecting the current channel state.
// will be set to nextHeight. This should therefore only be set to true // 2. A Dual of the updates which have not yet been committed in
// once for each height, and only in concert with signing a new commitment. // 'whoseCommitChain's commitment chain.
// TODO(halseth): return htlcs to mutate instead of mutating inside
// method.
func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
theirBalance *lnwire.MilliSatoshi, nextHeight uint64, theirBalance *lnwire.MilliSatoshi, nextHeight uint64,
whoseCommitChain lntypes.ChannelParty, mutateState bool) (*HtlcView, whoseCommitChain lntypes.ChannelParty) (*HtlcView,
error) { lntypes.Dual[[]*paymentDescriptor], error) {
// We initialize the view's fee rate to the fee rate of the unfiltered // We initialize the view's fee rate to the fee rate of the unfiltered
// view. If any fee updates are found when evaluating the view, it will // view. If any fee updates are found when evaluating the view, it will
@ -2917,8 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
skipThem := make(map[uint64]struct{}) skipThem := make(map[uint64]struct{})
// First we run through non-add entries in both logs, populating the // First we run through non-add entries in both logs, populating the
// skip sets and mutating the current chain state (crediting balances, // skip sets.
// etc) to reflect the settle/timeout entry encountered.
for _, entry := range view.OurUpdates { for _, entry := range view.OurUpdates {
switch entry.EntryType { switch entry.EntryType {
// Skip adds for now. They will be processed below. // Skip adds for now. They will be processed below.
@ -2938,53 +2935,31 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
newView.FeePerKw = chainfee.SatPerKWeight( newView.FeePerKw = chainfee.SatPerKWeight(
entry.Amount.ToSatoshis(), entry.Amount.ToSatoshis(),
) )
if mutateState {
entry.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
entry.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
} }
continue continue
} }
// If we're settling an inbound HTLC, and it hasn't been
// processed yet, then increment our state tracking the total
// number of satoshis we've received within the channel.
if mutateState && entry.EntryType == Settle &&
whoseCommitChain.IsLocal() &&
entry.removeCommitHeights.Local == 0 {
lc.channelState.TotalMSatReceived += entry.Amount
}
addEntry, err := lc.fetchParent( addEntry, err := lc.fetchParent(
entry, whoseCommitChain, lntypes.Remote, entry, whoseCommitChain, lntypes.Remote,
) )
if err != nil { if err != nil {
return nil, err return nil, lntypes.Dual[[]*paymentDescriptor]{}, err
} }
skipThem[addEntry.HtlcIndex] = struct{}{} skipThem[addEntry.HtlcIndex] = struct{}{}
rmvHeights := &entry.removeCommitHeights rmvHeight := entry.removeCommitHeights.GetForParty(
rmvHeight := rmvHeights.GetForParty(whoseCommitChain) whoseCommitChain,
)
if rmvHeight == 0 { if rmvHeight == 0 {
processRemoveEntry( processRemoveEntry(
entry, ourBalance, theirBalance, true, entry, ourBalance, theirBalance, true,
) )
}
}
if mutateState { // Do the same for our peer's updates.
rmvHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
}
}
for _, entry := range view.TheirUpdates { for _, entry := range view.TheirUpdates {
switch entry.EntryType { switch entry.EntryType {
// Skip adds for now. They will be processed below. // Skip adds for now. They will be processed below.
@ -3004,53 +2979,27 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
newView.FeePerKw = chainfee.SatPerKWeight( newView.FeePerKw = chainfee.SatPerKWeight(
entry.Amount.ToSatoshis(), entry.Amount.ToSatoshis(),
) )
if mutateState {
entry.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
entry.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
} }
continue continue
} }
// If the remote party is settling one of our outbound HTLC's,
// and it hasn't been processed, yet, the increment our state
// tracking the total number of satoshis we've sent within the
// channel.
if mutateState && entry.EntryType == Settle &&
whoseCommitChain.IsLocal() &&
entry.removeCommitHeights.Local == 0 {
lc.channelState.TotalMSatSent += entry.Amount
}
addEntry, err := lc.fetchParent( addEntry, err := lc.fetchParent(
entry, whoseCommitChain, lntypes.Local, entry, whoseCommitChain, lntypes.Local,
) )
if err != nil { if err != nil {
return nil, err return nil, lntypes.Dual[[]*paymentDescriptor]{}, err
} }
skipUs[addEntry.HtlcIndex] = struct{}{} skipUs[addEntry.HtlcIndex] = struct{}{}
rmvHeights := &entry.removeCommitHeights rmvHeight := entry.removeCommitHeights.GetForParty(
rmvHeight := rmvHeights.GetForParty(whoseCommitChain) whoseCommitChain,
)
if rmvHeight == 0 { if rmvHeight == 0 {
processRemoveEntry( processRemoveEntry(
entry, ourBalance, theirBalance, false, entry, ourBalance, theirBalance, false,
) )
if mutateState {
rmvHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
} }
} }
@ -3065,25 +3014,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
// Skip the entries that have already had their add commit // Skip the entries that have already had their add commit
// height set for this commit chain. // height set for this commit chain.
addHeights := &entry.addCommitHeights addHeight := entry.addCommitHeights.GetForParty(
addHeight := addHeights.GetForParty(whoseCommitChain) whoseCommitChain,
)
if addHeight == 0 { if addHeight == 0 {
processAddEntry( processAddEntry(
entry, ourBalance, theirBalance, false, entry, ourBalance, theirBalance, false,
) )
// If we are mutating the state, then set the add
// height for the appropriate commitment chain to the
// next height.
if mutateState {
addHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
} }
newView.OurUpdates = append(newView.OurUpdates, entry) newView.OurUpdates = append(newView.OurUpdates, entry)
} }
// Again, we do the same for our peer's updates.
for _, entry := range view.TheirUpdates { for _, entry := range view.TheirUpdates {
isAdd := entry.EntryType == Add isAdd := entry.EntryType == Add
if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok {
@ -3092,27 +3035,51 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
// Skip the entries that have already had their add commit // Skip the entries that have already had their add commit
// height set for this commit chain. // height set for this commit chain.
addHeights := &entry.addCommitHeights addHeight := entry.addCommitHeights.GetForParty(
addHeight := addHeights.GetForParty(whoseCommitChain) whoseCommitChain,
)
if addHeight == 0 { if addHeight == 0 {
processAddEntry( processAddEntry(
entry, ourBalance, theirBalance, true, entry, ourBalance, theirBalance, true,
) )
// If we are mutating the state, then set the add
// height for the appropriate commitment chain to the
// next height.
if mutateState {
addHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
} }
newView.TheirUpdates = append(newView.TheirUpdates, entry) newView.TheirUpdates = append(newView.TheirUpdates, entry)
} }
return newView, nil // Create a function that is capable of identifying whether or not the
// paymentDescriptor has been committed in the commitment chain
// corresponding to whoseCommitmentChain.
isUncommitted := func(update *paymentDescriptor) bool {
switch update.EntryType {
case Add:
return update.addCommitHeights.GetForParty(
whoseCommitChain,
) == 0
case FeeUpdate:
return update.addCommitHeights.GetForParty(
whoseCommitChain,
) == 0
case Settle, Fail, MalformedFail:
return update.removeCommitHeights.GetForParty(
whoseCommitChain,
) == 0
default:
panic("invalid paymentDescriptor EntryType")
}
}
// Collect all of the updates that haven't had their commit heights set
// for the commitment chain corresponding to whoseCommitmentChain.
uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{
Local: fn.Filter(isUncommitted, view.OurUpdates),
Remote: fn.Filter(isUncommitted, view.TheirUpdates),
}
return newView, uncommittedUpdates, nil
} }
// fetchParent is a helper that looks up update log parent entries in the // fetchParent is a helper that looks up update log parent entries in the
@ -4683,13 +4650,27 @@ func (lc *LightningChannel) computeView(view *HtlcView,
// channel constraints to the final commitment state. If any fee // channel constraints to the final commitment state. If any fee
// updates are found in the logs, the commitment fee rate should be // updates are found in the logs, the commitment fee rate should be
// changed, so we'll also set the feePerKw to this new value. // changed, so we'll also set the feePerKw to this new value.
filteredHTLCView, err := lc.evaluateHTLCView( filteredHTLCView, uncommitted, err := lc.evaluateHTLCView(
view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain,
updateState,
) )
if err != nil { if err != nil {
return 0, 0, 0, nil, err return 0, 0, 0, nil, err
} }
if updateState {
for _, party := range lntypes.BothParties {
for _, u := range uncommitted.GetForParty(party) {
u.setCommitHeight(whoseCommitChain, nextHeight)
if whoseCommitChain == lntypes.Local &&
u.EntryType == Settle {
lc.recordSettlement(party, u.Amount)
}
}
}
}
feePerKw := filteredHTLCView.FeePerKw feePerKw := filteredHTLCView.FeePerKw
// Here we override the view's fee-rate if a dry-run fee-rate was // Here we override the view's fee-rate if a dry-run fee-rate was
@ -4742,6 +4723,18 @@ func (lc *LightningChannel) computeView(view *HtlcView,
return ourBalance, theirBalance, totalCommitWeight, filteredHTLCView, nil return ourBalance, theirBalance, totalCommitWeight, filteredHTLCView, nil
} }
// recordSettlement updates the lifetime payment flow values in persistent state
// of the LightningChannel, adding amt to the total received by the redeemer.
func (lc *LightningChannel) recordSettlement(
redeemer lntypes.ChannelParty, amt lnwire.MilliSatoshi) {
if redeemer == lntypes.Local {
lc.channelState.TotalMSatReceived += amt
} else {
lc.channelState.TotalMSatSent += amt
}
}
// genHtlcSigValidationJobs generates a series of signatures verification jobs // genHtlcSigValidationJobs generates a series of signatures verification jobs
// meant to verify all the signatures for HTLC's attached to a newly created // meant to verify all the signatures for HTLC's attached to a newly created
// commitment state. The jobs generated are fully populated, and can be sent // commitment state. The jobs generated are fully populated, and can be sent

View File

@ -8956,14 +8956,40 @@ func TestEvaluateView(t *testing.T) {
) )
// Evaluate the htlc view, mutate as test expects. // Evaluate the htlc view, mutate as test expects.
result, err := lc.evaluateHTLCView( result, uncommitted, err := lc.evaluateHTLCView(
view, &ourBalance, &theirBalance, nextHeight, view, &ourBalance, &theirBalance, nextHeight,
test.whoseCommitChain, test.mutateState, test.whoseCommitChain,
) )
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
// TODO(proofofkeags): This block is here because we
// extracted this code from a previous implementation
// of evaluateHTLCView, due to a reduced scope of
// responsibility of that function. Consider removing
// it from the test altogether.
if test.mutateState {
for _, party := range lntypes.BothParties {
us := uncommitted.GetForParty(party)
for _, u := range us {
u.setCommitHeight(
test.whoseCommitChain,
nextHeight,
)
if test.whoseCommitChain ==
lntypes.Local &&
u.EntryType == Settle {
lc.recordSettlement(
party, u.Amount,
)
}
}
}
}
if result.FeePerKw != test.expectedFee { if result.FeePerKw != test.expectedFee {
t.Fatalf("expected fee: %v, got: %v", t.Fatalf("expected fee: %v, got: %v",
test.expectedFee, result.FeePerKw) test.expectedFee, result.FeePerKw)

View File

@ -283,3 +283,31 @@ func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate {
UpdateMsg: msg, UpdateMsg: msg,
} }
} }
// setCommitHeight updates the appropriate addCommitHeight and/or
// removeCommitHeight for whoseCommitChain and locks it in at nextHeight.
func (pd *paymentDescriptor) setCommitHeight(
whoseCommitChain lntypes.ChannelParty, nextHeight uint64) {
switch pd.EntryType {
case Add:
pd.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
case Settle, Fail, MalformedFail:
pd.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
case FeeUpdate:
// Fee updates are applied for all commitments
// after they are sent/received, so we consider
// them being added and removed at the same
// height.
pd.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
pd.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
}