lnwallet: refactor channel to use new typed List

This commit is contained in:
Keagan McClelland 2024-04-30 15:23:50 -07:00
parent 16d80f5b5b
commit 04c37344ae
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
4 changed files with 37 additions and 38 deletions

View File

@ -2474,7 +2474,7 @@ type htlcView struct {
func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView {
var ourHTLCs []*PaymentDescriptor
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
htlc := e.Value.(*PaymentDescriptor)
htlc := e.Value
// This HTLC is active from this point-of-view iff the log
// index of the state update is below the specified index in
@ -2486,7 +2486,7 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht
var theirHTLCs []*PaymentDescriptor
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
htlc := e.Value.(*PaymentDescriptor)
htlc := e.Value
// If this is an incoming HTLC, then it is only active from
// this point-of-view if the index of the HTLC addition in
@ -3112,7 +3112,7 @@ func (lc *LightningChannel) createCommitDiff(
// set of items we need to retransmit if we reconnect and find that
// they didn't process this new state fully.
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
// If this entry wasn't committed at the exact height of this
// remote commitment, then we'll skip it as it was already
@ -3250,7 +3250,7 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
// remote party expects.
var logUpdates []channeldb.LogUpdate
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
// Skip all remote updates that we have already included in our
// commit chain.
@ -5195,7 +5195,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
var addIndex, settleFailIndex uint16
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
// Fee updates are local to this particular channel, and should
// never be forwarded.
@ -5525,7 +5525,7 @@ func (lc *LightningChannel) GetDustSum(remote bool,
// Grab all of our HTLCs and evaluate against the dust limit.
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
if pd.EntryType != Add {
continue
}
@ -5544,7 +5544,7 @@ func (lc *LightningChannel) GetDustSum(remote bool,
// Grab all of their HTLCs and evaluate against the dust limit.
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
if pd.EntryType != Add {
continue
}
@ -8545,7 +8545,7 @@ func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex,
var localPeerUpdates []channeldb.LogUpdate
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
pd := e.Value.(*PaymentDescriptor)
pd := e.Value
// We don't save add updates as they are restored from the
// remote commitment in restoreStateLogs.

View File

@ -2,7 +2,6 @@ package lnwallet
import (
"bytes"
"container/list"
"crypto/sha256"
"fmt"
"math/rand"
@ -1906,7 +1905,7 @@ func TestStateUpdatePersistence(t *testing.T) {
// Newly generated pkScripts for HTLCs should be the same as in the old channel.
for _, entry := range aliceChannel.localUpdateLog.htlcIndex {
htlc := entry.Value.(*PaymentDescriptor)
htlc := entry.Value
restoredHtlc := aliceChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex)
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
t.Fatalf("alice ourPkScript in ourLog: expected %X, got %X",
@ -1918,7 +1917,7 @@ func TestStateUpdatePersistence(t *testing.T) {
}
}
for _, entry := range aliceChannel.remoteUpdateLog.htlcIndex {
htlc := entry.Value.(*PaymentDescriptor)
htlc := entry.Value
restoredHtlc := aliceChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex)
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
t.Fatalf("alice ourPkScript in theirLog: expected %X, got %X",
@ -1930,7 +1929,7 @@ func TestStateUpdatePersistence(t *testing.T) {
}
}
for _, entry := range bobChannel.localUpdateLog.htlcIndex {
htlc := entry.Value.(*PaymentDescriptor)
htlc := entry.Value
restoredHtlc := bobChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex)
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
t.Fatalf("bob ourPkScript in ourLog: expected %X, got %X",
@ -1942,7 +1941,7 @@ func TestStateUpdatePersistence(t *testing.T) {
}
}
for _, entry := range bobChannel.remoteUpdateLog.htlcIndex {
htlc := entry.Value.(*PaymentDescriptor)
htlc := entry.Value
restoredHtlc := bobChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex)
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
t.Fatalf("bob ourPkScript in theirLog: expected %X, got %X",
@ -4472,7 +4471,7 @@ func TestFeeUpdateOldDiskFormat(t *testing.T) {
countLog := func(log *updateLog) (int, int) {
var numUpdates, numFee int
for e := log.Front(); e != nil; e = e.Next() {
htlc := e.Value.(*PaymentDescriptor)
htlc := e.Value
if htlc.EntryType == FeeUpdate {
numFee++
}
@ -6755,14 +6754,14 @@ func compareHtlcs(htlc1, htlc2 *PaymentDescriptor) error {
}
// compareIndexes is a helper method to compare two index maps.
func compareIndexes(a, b map[uint64]*list.Element) error {
func compareIndexes(a, b map[uint64]*fn.Node[*PaymentDescriptor]) error {
for k1, e1 := range a {
e2, ok := b[k1]
if !ok {
return fmt.Errorf("element with key %d "+
"not found in b", k1)
}
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
htlc1, htlc2 := e1.Value, e2.Value
if err := compareHtlcs(htlc1, htlc2); err != nil {
return err
}
@ -6774,7 +6773,7 @@ func compareIndexes(a, b map[uint64]*list.Element) error {
return fmt.Errorf("element with key %d not "+
"found in a", k1)
}
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
htlc1, htlc2 := e1.Value, e2.Value
if err := compareHtlcs(htlc1, htlc2); err != nil {
return err
}
@ -6809,7 +6808,7 @@ func compareLogs(a, b *updateLog) error {
e1, e2 := a.Front(), b.Front()
for ; e1 != nil; e1, e2 = e1.Next(), e2.Next() {
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
htlc1, htlc2 := e1.Value, e2.Value
if err := compareHtlcs(htlc1, htlc2); err != nil {
return err
}
@ -6917,7 +6916,7 @@ func TestChannelRestoreUpdateLogs(t *testing.T) {
func fetchNumUpdates(t updateType, log *updateLog) int {
num := 0
for e := log.Front(); e != nil; e = e.Next() {
htlc := e.Value.(*PaymentDescriptor)
htlc := e.Value
if htlc.EntryType == t {
num++
}

View File

@ -1,6 +1,8 @@
package lnwallet
import "container/list"
import (
"github.com/lightningnetwork/lnd/fn"
)
// commitmentChain represents a chain of unrevoked commitments. The tail of the
// chain is the latest fully signed, yet unrevoked commitment. Two chains are
@ -15,13 +17,13 @@ type commitmentChain struct {
// commitments are added to the end of the chain with increase height.
// Once a commitment transaction is revoked, the tail is incremented,
// freeing up the revocation window for new commitments.
commitments *list.List
commitments *fn.List[*commitment]
}
// newCommitmentChain creates a new commitment chain.
func newCommitmentChain() *commitmentChain {
return &commitmentChain{
commitments: list.New(),
commitments: fn.NewList[*commitment](),
}
}
@ -42,14 +44,12 @@ func (s *commitmentChain) advanceTail() {
// tip returns the latest commitment added to the chain.
func (s *commitmentChain) tip() *commitment {
//nolint:forcetypeassert
return s.commitments.Back().Value.(*commitment)
return s.commitments.Back().Value
}
// tail returns the lowest unrevoked commitment transaction in the chain.
func (s *commitmentChain) tail() *commitment {
//nolint:forcetypeassert
return s.commitments.Front().Value.(*commitment)
return s.commitments.Front().Value
}
// hasUnackedCommitment returns true if the commitment chain has more than one

View File

@ -1,6 +1,8 @@
package lnwallet
import "container/list"
import (
"github.com/lightningnetwork/lnd/fn"
)
// updateLog is an append-only log that stores updates to a node's commitment
// chain. This structure can be seen as the "mempool" within Lightning where
@ -27,16 +29,16 @@ type updateLog struct {
// List is the updatelog itself, we embed this value so updateLog has
// access to all the method of a list.List.
*list.List
*fn.List[*PaymentDescriptor]
// updateIndex maps a `logIndex` to a particular update entry. It
// deals with the four update types:
// `Fail|MalformedFail|Settle|FeeUpdate`
updateIndex map[uint64]*list.Element
updateIndex map[uint64]*fn.Node[*PaymentDescriptor]
// htlcIndex maps a `htlcCounter` to an offered HTLC entry, hence the
// `Add` update.
htlcIndex map[uint64]*list.Element
htlcIndex map[uint64]*fn.Node[*PaymentDescriptor]
// modifiedHtlcs is a set that keeps track of all the current modified
// htlcs, hence update types `Fail|MalformedFail|Settle`. A modified
@ -48,9 +50,9 @@ type updateLog struct {
// newUpdateLog creates a new updateLog instance.
func newUpdateLog(logIndex, htlcCounter uint64) *updateLog {
return &updateLog{
List: list.New(),
updateIndex: make(map[uint64]*list.Element),
htlcIndex: make(map[uint64]*list.Element),
List: fn.NewList[*PaymentDescriptor](),
updateIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]),
htlcIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]),
logIndex: logIndex,
htlcCounter: htlcCounter,
modifiedHtlcs: make(map[uint64]struct{}),
@ -101,8 +103,7 @@ func (u *updateLog) lookupHtlc(i uint64) *PaymentDescriptor {
return nil
}
//nolint:forcetypeassert
return htlc.Value.(*PaymentDescriptor)
return htlc.Value
}
// remove attempts to remove an entry from the update log. If the entry is
@ -145,15 +146,14 @@ func compactLogs(ourLog, theirLog *updateLog,
localChainTail, remoteChainTail uint64) {
compactLog := func(logA, logB *updateLog) {
var nextA *list.Element
var nextA *fn.Node[*PaymentDescriptor]
for e := logA.Front(); e != nil; e = nextA {
// Assign next iteration element at top of loop because
// we may remove the current element from the list,
// which can change the iterated sequence.
nextA = e.Next()
//nolint:forcetypeassert
htlc := e.Value.(*PaymentDescriptor)
htlc := e.Value
// We skip Adds, as they will be removed along with the
// fail/settles below.