Merge pull request #8951 from ProofOfKeags/refactor/lnwallet-channel-channel-party

[MILLI]: Introduce and use ChannelParty
This commit is contained in:
Olaoluwa Osuntokun 2024-07-31 17:30:32 -07:00 committed by GitHub
commit 14ff12aa81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 577 additions and 421 deletions

View file

@ -25,6 +25,7 @@ import (
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
@ -1690,11 +1691,11 @@ func (c *OpenChannel) isBorked(chanBucket kvdb.RBucket) (bool, error) {
// republish this tx at startup to ensure propagation, and we should still // republish this tx at startup to ensure propagation, and we should still
// handle the case where a different tx actually hits the chain. // handle the case where a different tx actually hits the chain.
func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx,
locallyInitiated bool) error { closer lntypes.ChannelParty) error {
return c.markBroadcasted( return c.markBroadcasted(
ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx, ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx,
locallyInitiated, closer,
) )
} }
@ -1706,11 +1707,11 @@ func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx,
// ensure propagation, and we should still handle the case where a different tx // ensure propagation, and we should still handle the case where a different tx
// actually hits the chain. // actually hits the chain.
func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx,
locallyInitiated bool) error { closer lntypes.ChannelParty) error {
return c.markBroadcasted( return c.markBroadcasted(
ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx, ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx,
locallyInitiated, closer,
) )
} }
@ -1719,7 +1720,7 @@ func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx,
// which should specify either a coop or force close. It adds a status which // which should specify either a coop or force close. It adds a status which
// indicates the party that initiated the channel close. // indicates the party that initiated the channel close.
func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte,
closeTx *wire.MsgTx, locallyInitiated bool) error { closeTx *wire.MsgTx, closer lntypes.ChannelParty) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -1741,7 +1742,7 @@ func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte,
// Add the initiator status to the status provided. These statuses are // Add the initiator status to the status provided. These statuses are
// set in addition to the broadcast status so that we do not need to // set in addition to the broadcast status so that we do not need to
// migrate the original logic which does not store initiator. // migrate the original logic which does not store initiator.
if locallyInitiated { if closer.IsLocal() {
status |= ChanStatusLocalCloseInitiator status |= ChanStatusLocalCloseInitiator
} else { } else {
status |= ChanStatusRemoteCloseInitiator status |= ChanStatusRemoteCloseInitiator
@ -4486,6 +4487,15 @@ func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress,
} }
} }
// Closer identifies the ChannelParty that initiated the coop-closure process.
func (s ShutdownInfo) Closer() lntypes.ChannelParty {
if s.LocalInitiator.Val {
return lntypes.Local
}
return lntypes.Remote
}
// encode serialises the ShutdownInfo to the given io.Writer. // encode serialises the ShutdownInfo to the given io.Writer.
func (s *ShutdownInfo) encode(w io.Writer) error { func (s *ShutdownInfo) encode(w io.Writer) error {
records := []tlv.Record{ records := []tlv.Record{

View file

@ -21,6 +21,7 @@ import (
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnmock"
"github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
@ -1084,13 +1085,17 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
}, },
) )
if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { if err := channel.MarkCommitmentBroadcasted(
closeTx, lntypes.Local,
); err != nil {
t.Fatalf("unable to mark commitment broadcast: %v", err) t.Fatalf("unable to mark commitment broadcast: %v", err)
} }
// Now try to marking a coop close with a nil tx. This should // Now try to marking a coop close with a nil tx. This should
// succeed, but it shouldn't exit when queried. // succeed, but it shouldn't exit when queried.
if err = channel.MarkCoopBroadcasted(nil, true); err != nil { if err = channel.MarkCoopBroadcasted(
nil, lntypes.Local,
); err != nil {
t.Fatalf("unable to mark nil coop broadcast: %v", err) t.Fatalf("unable to mark nil coop broadcast: %v", err)
} }
_, err := channel.BroadcastedCooperative() _, err := channel.BroadcastedCooperative()
@ -1102,7 +1107,9 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// it as coop closed. Later we will test that distinct // it as coop closed. Later we will test that distinct
// transactions are returned for both coop and force closes. // transactions are returned for both coop and force closes.
closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 closeTx.TxIn[0].PreviousOutPoint.Index ^= 1
if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { if err := channel.MarkCoopBroadcasted(
closeTx, lntypes.Local,
); err != nil {
t.Fatalf("unable to mark coop broadcast: %v", err) t.Fatalf("unable to mark coop broadcast: %v", err)
} }
} }
@ -1324,7 +1331,7 @@ func TestCloseInitiator(t *testing.T) {
// by the local party. // by the local party.
updateChannel: func(c *OpenChannel) error { updateChannel: func(c *OpenChannel) error {
return c.MarkCoopBroadcasted( return c.MarkCoopBroadcasted(
&wire.MsgTx{}, true, &wire.MsgTx{}, lntypes.Local,
) )
}, },
expectedStatuses: []ChannelStatus{ expectedStatuses: []ChannelStatus{
@ -1338,7 +1345,7 @@ func TestCloseInitiator(t *testing.T) {
// by the remote party. // by the remote party.
updateChannel: func(c *OpenChannel) error { updateChannel: func(c *OpenChannel) error {
return c.MarkCoopBroadcasted( return c.MarkCoopBroadcasted(
&wire.MsgTx{}, false, &wire.MsgTx{}, lntypes.Remote,
) )
}, },
expectedStatuses: []ChannelStatus{ expectedStatuses: []ChannelStatus{
@ -1352,7 +1359,7 @@ func TestCloseInitiator(t *testing.T) {
// local initiator. // local initiator.
updateChannel: func(c *OpenChannel) error { updateChannel: func(c *OpenChannel) error {
return c.MarkCommitmentBroadcasted( return c.MarkCommitmentBroadcasted(
&wire.MsgTx{}, true, &wire.MsgTx{}, lntypes.Local,
) )
}, },
expectedStatuses: []ChannelStatus{ expectedStatuses: []ChannelStatus{

View file

@ -14,6 +14,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -606,7 +607,9 @@ func TestFetchChannels(t *testing.T) {
channelIDOption(pendingWaitingChan), channelIDOption(pendingWaitingChan),
) )
err = pendingClosing.MarkCoopBroadcasted(nil, true) err = pendingClosing.MarkCoopBroadcasted(
nil, lntypes.Local,
)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@ -626,7 +629,9 @@ func TestFetchChannels(t *testing.T) {
channelIDOption(openWaitingChan), channelIDOption(openWaitingChan),
openChannelOption(), openChannelOption(),
) )
err = openClosing.MarkCoopBroadcasted(nil, true) err = openClosing.MarkCoopBroadcasted(
nil, lntypes.Local,
)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -61,12 +62,14 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
for i := 0; i < numChans/2; i++ { for i := 0; i < numChans/2; i++ {
closeTx := channels[i].FundingTxn.Copy() closeTx := channels[i].FundingTxn.Copy()
closeTx.TxIn[0].PreviousOutPoint = channels[i].FundingOutpoint closeTx.TxIn[0].PreviousOutPoint = channels[i].FundingOutpoint
err := channels[i].MarkCommitmentBroadcasted(closeTx, true) err := channels[i].MarkCommitmentBroadcasted(
closeTx, lntypes.Local,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = channels[i].MarkCoopBroadcasted(closeTx, true) err = channels[i].MarkCoopBroadcasted(closeTx, lntypes.Local)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -418,7 +419,7 @@ func (c *chainWatcher) handleUnknownLocalState(
// and remote keys for this state. We use our point as only we can // and remote keys for this state. We use our point as only we can
// revoke our own commitment. // revoke our own commitment.
commitKeyRing := lnwallet.DeriveCommitmentKeys( commitKeyRing := lnwallet.DeriveCommitmentKeys(
commitPoint, true, c.cfg.chanState.ChanType, commitPoint, lntypes.Local, c.cfg.chanState.ChanType,
&c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg,
) )
@ -891,7 +892,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail,
// Create an AnchorResolution for the breached state. // Create an AnchorResolution for the breached state.
anchorRes, err := lnwallet.NewAnchorResolution( anchorRes, err := lnwallet.NewAnchorResolution(
c.cfg.chanState, commitSpend.SpendingTx, retribution.KeyRing, c.cfg.chanState, commitSpend.SpendingTx, retribution.KeyRing,
false, lntypes.Remote,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("unable to create anchor "+ return false, fmt.Errorf("unable to create anchor "+

View file

@ -129,7 +129,7 @@ type ChannelArbitratorConfig struct {
// MarkCommitmentBroadcasted should mark the channel as the commitment // MarkCommitmentBroadcasted should mark the channel as the commitment
// being broadcast, and we are waiting for the commitment to confirm. // being broadcast, and we are waiting for the commitment to confirm.
MarkCommitmentBroadcasted func(*wire.MsgTx, bool) error MarkCommitmentBroadcasted func(*wire.MsgTx, lntypes.ChannelParty) error
// MarkChannelClosed marks the channel closed in the database, with the // MarkChannelClosed marks the channel closed in the database, with the
// passed close summary. After this method successfully returns we can // passed close summary. After this method successfully returns we can
@ -1084,7 +1084,7 @@ func (c *ChannelArbitrator) stateStep(
// database, such that we can re-publish later in case it // database, such that we can re-publish later in case it
// didn't propagate. We initiated the force close, so we // didn't propagate. We initiated the force close, so we
// mark broadcast with local initiator set to true. // mark broadcast with local initiator set to true.
err = c.cfg.MarkCommitmentBroadcasted(closeTx, true) err = c.cfg.MarkCommitmentBroadcasted(closeTx, lntypes.Local)
if err != nil { if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+ log.Errorf("ChannelArbitrator(%v): unable to "+
"mark commitment broadcasted: %v", "mark commitment broadcasted: %v",

View file

@ -416,7 +416,9 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
resolvedChan <- struct{}{} resolvedChan <- struct{}{}
return nil return nil
}, },
MarkCommitmentBroadcasted: func(_ *wire.MsgTx, _ bool) error { MarkCommitmentBroadcasted: func(_ *wire.MsgTx,
_ lntypes.ChannelParty) error {
return nil return nil
}, },
MarkChannelClosed: func(*channeldb.ChannelCloseSummary, MarkChannelClosed: func(*channeldb.ChannelCloseSummary,

View file

@ -63,7 +63,7 @@ type dustHandler interface {
// getDustSum returns the dust sum on either the local or remote // getDustSum returns the dust sum on either the local or remote
// commitment. An optional fee parameter can be passed in which is used // commitment. An optional fee parameter can be passed in which is used
// to calculate the dust sum. // to calculate the dust sum.
getDustSum(remote bool, getDustSum(whoseCommit lntypes.ChannelParty,
fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi
// getFeeRate returns the current channel feerate. // getFeeRate returns the current channel feerate.

View file

@ -2727,10 +2727,10 @@ func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error {
// method. // method.
// //
// NOTE: Part of the dustHandler interface. // NOTE: Part of the dustHandler interface.
func (l *channelLink) getDustSum(remote bool, func (l *channelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
return l.channel.GetDustSum(remote, dryRunFee) return l.channel.GetDustSum(whoseCommit, dryRunFee)
} }
// getFeeRate is a wrapper method that retrieves the underlying channel's // getFeeRate is a wrapper method that retrieves the underlying channel's
@ -2784,8 +2784,8 @@ func (l *channelLink) exceedsFeeExposureLimit(
// Get the sum of dust for both the local and remote commitments using // Get the sum of dust for both the local and remote commitments using
// this "dry-run" fee. // this "dry-run" fee.
localDustSum := l.getDustSum(false, dryRunFee) localDustSum := l.getDustSum(lntypes.Local, dryRunFee)
remoteDustSum := l.getDustSum(true, dryRunFee) remoteDustSum := l.getDustSum(lntypes.Remote, dryRunFee)
// Calculate the local and remote commitment fees using this dry-run // Calculate the local and remote commitment fees using this dry-run
// fee. // fee.
@ -2826,12 +2826,16 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC,
amount := htlc.Amount.ToSatoshis() amount := htlc.Amount.ToSatoshis()
// See if this HTLC is dust on both the local and remote commitments. // See if this HTLC is dust on both the local and remote commitments.
isLocalDust := dustClosure(feeRate, incoming, true, amount) isLocalDust := dustClosure(feeRate, incoming, lntypes.Local, amount)
isRemoteDust := dustClosure(feeRate, incoming, false, amount) isRemoteDust := dustClosure(feeRate, incoming, lntypes.Remote, amount)
// Calculate the dust sum for the local and remote commitments. // Calculate the dust sum for the local and remote commitments.
localDustSum := l.getDustSum(false, fn.None[chainfee.SatPerKWeight]()) localDustSum := l.getDustSum(
remoteDustSum := l.getDustSum(true, fn.None[chainfee.SatPerKWeight]()) lntypes.Local, fn.None[chainfee.SatPerKWeight](),
)
remoteDustSum := l.getDustSum(
lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
)
// Grab the larger of the local and remote commitment fees w/o dust. // Grab the larger of the local and remote commitment fees w/o dust.
commitFee := l.getCommitFee(false) commitFee := l.getCommitFee(false)
@ -2882,25 +2886,26 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC,
// the HTLC is incoming (i.e. one that the remote sent), a boolean denoting // the HTLC is incoming (i.e. one that the remote sent), a boolean denoting
// whether to evaluate on the local or remote commit, and finally an HTLC // whether to evaluate on the local or remote commit, and finally an HTLC
// amount to test. // amount to test.
type dustClosure func(chainfee.SatPerKWeight, bool, bool, btcutil.Amount) bool type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool,
whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool
// dustHelper is used to construct the dustClosure. // dustHelper is used to construct the dustClosure.
func dustHelper(chantype channeldb.ChannelType, localDustLimit, func dustHelper(chantype channeldb.ChannelType, localDustLimit,
remoteDustLimit btcutil.Amount) dustClosure { remoteDustLimit btcutil.Amount) dustClosure {
isDust := func(feerate chainfee.SatPerKWeight, incoming, isDust := func(feerate chainfee.SatPerKWeight, incoming bool,
localCommit bool, amt btcutil.Amount) bool { whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool {
if localCommit { var dustLimit btcutil.Amount
return lnwallet.HtlcIsDust( if whoseCommit.IsLocal() {
chantype, incoming, true, feerate, amt, dustLimit = localDustLimit
localDustLimit, } else {
) dustLimit = remoteDustLimit
} }
return lnwallet.HtlcIsDust( return lnwallet.HtlcIsDust(
chantype, incoming, false, feerate, amt, chantype, incoming, whoseCommit, feerate, amt,
remoteDustLimit, dustLimit,
) )
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -660,7 +661,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
// Evaluate whether this HTLC is dust on the local commitment. // Evaluate whether this HTLC is dust on the local commitment.
if m.isDust( if m.isDust(
m.feeRate, false, true, addPkt.amount.ToSatoshis(), m.feeRate, false, lntypes.Local,
addPkt.amount.ToSatoshis(),
) { ) {
localDustSum += addPkt.amount localDustSum += addPkt.amount
@ -668,7 +670,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
// Evaluate whether this HTLC is dust on the remote commitment. // Evaluate whether this HTLC is dust on the remote commitment.
if m.isDust( if m.isDust(
m.feeRate, false, false, addPkt.amount.ToSatoshis(), m.feeRate, false, lntypes.Remote,
addPkt.amount.ToSatoshis(),
) { ) {
remoteDustSum += addPkt.amount remoteDustSum += addPkt.amount

View file

@ -814,7 +814,7 @@ func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error {
return nil return nil
} }
func (f *mockChannelLink) getDustSum(remote bool, func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
return 0 return 0

View file

@ -2788,8 +2788,12 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
isDust := link.getDustClosure() isDust := link.getDustClosure()
// Evaluate if the HTLC is dust on either sides' commitment. // Evaluate if the HTLC is dust on either sides' commitment.
isLocalDust := isDust(feeRate, incoming, true, amount.ToSatoshis()) isLocalDust := isDust(
isRemoteDust := isDust(feeRate, incoming, false, amount.ToSatoshis()) feeRate, incoming, lntypes.Local, amount.ToSatoshis(),
)
isRemoteDust := isDust(
feeRate, incoming, lntypes.Remote, amount.ToSatoshis(),
)
if !(isLocalDust || isRemoteDust) { if !(isLocalDust || isRemoteDust) {
// If the HTLC is not dust on either commitment, it's fine to // If the HTLC is not dust on either commitment, it's fine to
@ -2807,7 +2811,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
// sum for it. // sum for it.
if isLocalDust { if isLocalDust {
localSum := link.getDustSum( localSum := link.getDustSum(
false, fn.None[chainfee.SatPerKWeight](), lntypes.Local, fn.None[chainfee.SatPerKWeight](),
) )
localSum += localMailDust localSum += localMailDust
@ -2827,7 +2831,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
// reached this point. // reached this point.
if isRemoteDust { if isRemoteDust {
remoteSum := link.getDustSum( remoteSum := link.getDustSum(
true, fn.None[chainfee.SatPerKWeight](), lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
) )
remoteSum += remoteMailDust remoteSum += remoteMailDust

View file

@ -4319,7 +4319,7 @@ func TestSwitchDustForwarding(t *testing.T) {
} }
checkAlmostDust := func(link *channelLink, mbox MailBox, checkAlmostDust := func(link *channelLink, mbox MailBox,
remote bool) bool { whoseCommit lntypes.ChannelParty) bool {
timeout := time.After(15 * time.Second) timeout := time.After(15 * time.Second)
pollInterval := 300 * time.Millisecond pollInterval := 300 * time.Millisecond
@ -4335,12 +4335,12 @@ func TestSwitchDustForwarding(t *testing.T) {
} }
linkDust := link.getDustSum( linkDust := link.getDustSum(
remote, fn.None[chainfee.SatPerKWeight](), whoseCommit, fn.None[chainfee.SatPerKWeight](),
) )
localMailDust, remoteMailDust := mbox.DustPackets() localMailDust, remoteMailDust := mbox.DustPackets()
totalDust := linkDust totalDust := linkDust
if remote { if whoseCommit.IsRemote() {
totalDust += remoteMailDust totalDust += remoteMailDust
} else { } else {
totalDust += localMailDust totalDust += localMailDust
@ -4359,7 +4359,11 @@ func TestSwitchDustForwarding(t *testing.T) {
n.firstBobChannelLink.ChanID(), n.firstBobChannelLink.ChanID(),
n.firstBobChannelLink.ShortChanID(), n.firstBobChannelLink.ShortChanID(),
) )
require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) require.True(
t, checkAlmostDust(
n.firstBobChannelLink, bobMbox, lntypes.Local,
),
)
// Sending one more HTLC should fail. SendHTLC won't error, but the // Sending one more HTLC should fail. SendHTLC won't error, but the
// HTLC should be failed backwards. // HTLC should be failed backwards.
@ -4408,7 +4412,9 @@ func TestSwitchDustForwarding(t *testing.T) {
aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc, aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc,
) )
require.NoError(t, err) require.NoError(t, err)
require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) require.True(t, checkAlmostDust(
n.firstBobChannelLink, bobMbox, lntypes.Local,
))
// Check that the HTLC failed. // Check that the HTLC failed.
bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult( bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult(
@ -4486,7 +4492,11 @@ func TestSwitchDustForwarding(t *testing.T) {
aliceMbox := aliceOrch.GetOrCreateMailBox( aliceMbox := aliceOrch.GetOrCreateMailBox(
n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(), n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(),
) )
require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true)) require.True(
t, checkAlmostDust(
n.aliceChannelLink, aliceMbox, lntypes.Remote,
),
)
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID), n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID),

View file

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"golang.org/x/crypto/ripemd160" "golang.org/x/crypto/ripemd160"
) )
@ -789,10 +790,10 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
// unilaterally spend the created output. // unilaterally spend the created output.
func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey,
revokeKey *btcec.PublicKey, payHash []byte, revokeKey *btcec.PublicKey, payHash []byte,
localCommit bool) (*HtlcScriptTree, error) { whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) {
var hType htlcType var hType htlcType
if localCommit { if whoseCommit.IsLocal() {
hType = htlcLocalOutgoing hType = htlcLocalOutgoing
} else { } else {
hType = htlcRemoteIncoming hType = htlcRemoteIncoming
@ -1348,10 +1349,11 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
// the tap leaf are returned. // the tap leaf are returned.
func ReceiverHTLCScriptTaproot(cltvExpiry uint32, func ReceiverHTLCScriptTaproot(cltvExpiry uint32,
senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey,
payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { payHash []byte, whoseCommit lntypes.ChannelParty,
) (*HtlcScriptTree, error) {
var hType htlcType var hType htlcType
if ourCommit { if whoseCommit.IsLocal() {
hType = htlcLocalIncoming hType = htlcLocalIncoming
} else { } else {
hType = htlcRemoteOutgoing hType = htlcRemoteOutgoing

View file

@ -13,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -1073,7 +1074,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot( htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(), senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false, revokeKey.PubKey(), payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1115,7 +1116,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(), testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false, payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1157,7 +1158,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(), testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false, payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1203,7 +1204,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot( htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(), senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false, revokeKey.PubKey(), payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1263,7 +1264,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.SenderHTLCScriptTaproot( htlcScriptTree, err := input.SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(), senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false, revokeKey.PubKey(), payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1309,7 +1310,7 @@ var witnessSizeTests = []witnessSizeTest{
htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( htlcScriptTree, err := input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, senderKey.PubKey(), testCLTVExpiry, senderKey.PubKey(),
receiverKey.PubKey(), revokeKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false, payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1394,7 +1395,8 @@ func genTimeoutTx(t *testing.T,
) )
if chanType.IsTaproot() { if chanType.IsTaproot() {
tapscriptTree, err = input.SenderHTLCScriptTaproot( tapscriptTree, err = input.SenderHTLCScriptTaproot(
testPubkey, testPubkey, testPubkey, testHash160, false, testPubkey, testPubkey, testPubkey, testHash160,
lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -1463,7 +1465,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx {
if chanType.IsTaproot() { if chanType.IsTaproot() {
tapscriptTree, err = input.ReceiverHTLCScriptTaproot( tapscriptTree, err = input.ReceiverHTLCScriptTaproot(
testCLTVExpiry, testPubkey, testPubkey, testPubkey, testCLTVExpiry, testPubkey, testPubkey, testPubkey,
testHash160, false, testHash160, lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)

View file

@ -48,7 +48,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree {
payHash := preImage.Hash() payHash := preImage.Hash()
htlcScriptTree, err := SenderHTLCScriptTaproot( htlcScriptTree, err := SenderHTLCScriptTaproot(
senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(),
payHash[:], false, payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)
@ -471,7 +471,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree {
payHash := preImage.Hash() payHash := preImage.Hash()
htlcScriptTree, err := ReceiverHTLCScriptTaproot( htlcScriptTree, err := ReceiverHTLCScriptTaproot(
cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(),
revokeKey.PubKey(), payHash[:], false, revokeKey.PubKey(), payHash[:], lntypes.Remote,
) )
require.NoError(t, err) require.NoError(t, err)

52
lntypes/channel_party.go Normal file
View file

@ -0,0 +1,52 @@
package lntypes
import "fmt"
// ChannelParty is a type used to have an unambiguous description of which node
// is being referred to. This eliminates the need to describe as "local" or
// "remote" using bool.
type ChannelParty uint8
const (
// Local is a ChannelParty constructor that is used to refer to the
// node that is running.
Local ChannelParty = iota
// Remote is a ChannelParty constructor that is used to refer to the
// node on the other end of the peer connection.
Remote
)
// String provides a string representation of ChannelParty (useful for logging).
func (p ChannelParty) String() string {
switch p {
case Local:
return "Local"
case Remote:
return "Remote"
default:
panic(fmt.Sprintf("invalid ChannelParty value: %d", p))
}
}
// CounterParty inverts the role of the ChannelParty.
func (p ChannelParty) CounterParty() ChannelParty {
switch p {
case Local:
return Remote
case Remote:
return Local
default:
panic(fmt.Sprintf("invalid ChannelParty value: %v", p))
}
}
// IsLocal returns true if the ChannelParty is Local.
func (p ChannelParty) IsLocal() bool {
return p == Local
}
// IsRemote returns true if the ChannelParty is Remote.
func (p ChannelParty) IsRemote() bool {
return p == Remote
}

View file

@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
@ -207,8 +208,8 @@ type ChanCloser struct {
// settled channel funds to. // settled channel funds to.
remoteDeliveryScript []byte remoteDeliveryScript []byte
// locallyInitiated is true if we initiated the channel close. // closer is ChannelParty who initiated the coop close
locallyInitiated bool closer lntypes.ChannelParty
// cachedClosingSigned is a cached copy of a received ClosingSigned that // cachedClosingSigned is a cached copy of a received ClosingSigned that
// we use to handle a specific race condition caused by the independent // we use to handle a specific race condition caused by the independent
@ -267,7 +268,8 @@ func (d *SimpleCoopFeeEstimator) EstimateFee(chanType channeldb.ChannelType,
// be populated iff, we're the initiator of this closing request. // be populated iff, we're the initiator of this closing request.
func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte,
idealFeePerKw chainfee.SatPerKWeight, negotiationHeight uint32, idealFeePerKw chainfee.SatPerKWeight, negotiationHeight uint32,
closeReq *htlcswitch.ChanClose, locallyInitiated bool) *ChanCloser { closeReq *htlcswitch.ChanClose,
closer lntypes.ChannelParty) *ChanCloser {
chanPoint := cfg.Channel.ChannelPoint() chanPoint := cfg.Channel.ChannelPoint()
cid := lnwire.NewChanIDFromOutPoint(chanPoint) cid := lnwire.NewChanIDFromOutPoint(chanPoint)
@ -283,7 +285,7 @@ func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte,
priorFeeOffers: make( priorFeeOffers: make(
map[btcutil.Amount]*lnwire.ClosingSigned, map[btcutil.Amount]*lnwire.ClosingSigned,
), ),
locallyInitiated: locallyInitiated, closer: closer,
} }
} }
@ -366,7 +368,7 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) {
// message we are about to send in order to ensure that if a // message we are about to send in order to ensure that if a
// re-establish occurs then we will re-send the same Shutdown message. // re-establish occurs then we will re-send the same Shutdown message.
shutdownInfo := channeldb.NewShutdownInfo( shutdownInfo := channeldb.NewShutdownInfo(
c.localDeliveryScript, c.locallyInitiated, c.localDeliveryScript, c.closer.IsLocal(),
) )
err := c.cfg.Channel.MarkShutdownSent(shutdownInfo) err := c.cfg.Channel.MarkShutdownSent(shutdownInfo)
if err != nil { if err != nil {
@ -650,7 +652,7 @@ func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned],
// externally consistent, and reflect that the channel is being // externally consistent, and reflect that the channel is being
// shutdown by the time the closing request returns. // shutdown by the time the closing request returns.
err := c.cfg.Channel.MarkCoopBroadcasted( err := c.cfg.Channel.MarkCoopBroadcasted(
nil, c.locallyInitiated, nil, c.closer,
) )
if err != nil { if err != nil {
return noClosingSigned, err return noClosingSigned, err
@ -861,7 +863,7 @@ func (c *ChanCloser) ReceiveClosingSigned( //nolint:funlen
// database, such that it can be republished if something goes // database, such that it can be republished if something goes
// wrong. // wrong.
err = c.cfg.Channel.MarkCoopBroadcasted( err = c.cfg.Channel.MarkCoopBroadcasted(
closeTx, c.locallyInitiated, closeTx, c.closer,
) )
if err != nil { if err != nil {
return noClosing, err return noClosing, err

View file

@ -16,6 +16,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
@ -150,7 +151,9 @@ func (m *mockChannel) ChannelPoint() wire.OutPoint {
return m.chanPoint return m.chanPoint
} }
func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error { func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx,
lntypes.ChannelParty) error {
return nil return nil
} }
@ -338,7 +341,7 @@ func TestMaxFeeClamp(t *testing.T) {
Channel: &channel, Channel: &channel,
MaxFee: test.inputMaxFee, MaxFee: test.inputMaxFee,
FeeEstimator: &SimpleCoopFeeEstimator{}, FeeEstimator: &SimpleCoopFeeEstimator{},
}, nil, test.idealFee, 0, nil, false, }, nil, test.idealFee, 0, nil, lntypes.Remote,
) )
// We'll call initFeeBaseline early here since we need // We'll call initFeeBaseline early here since we need
@ -379,7 +382,7 @@ func TestMaxFeeBailOut(t *testing.T) {
MaxFee: idealFee * 2, MaxFee: idealFee * 2,
} }
chanCloser := NewChanCloser( chanCloser := NewChanCloser(
closeCfg, nil, idealFee, 0, nil, false, closeCfg, nil, idealFee, 0, nil, lntypes.Remote,
) )
// We'll now force the channel state into the // We'll now force the channel state into the
@ -503,7 +506,7 @@ func TestTaprootFastClose(t *testing.T) {
DisableChannel: func(wire.OutPoint) error { DisableChannel: func(wire.OutPoint) error {
return nil return nil
}, },
}, nil, idealFee, 0, nil, true, }, nil, idealFee, 0, nil, lntypes.Local,
) )
aliceCloser.initFeeBaseline() aliceCloser.initFeeBaseline()
@ -520,7 +523,7 @@ func TestTaprootFastClose(t *testing.T) {
DisableChannel: func(wire.OutPoint) error { DisableChannel: func(wire.OutPoint) error {
return nil return nil
}, },
}, nil, idealFee, 0, nil, false, }, nil, idealFee, 0, nil, lntypes.Remote,
) )
bobCloser.initFeeBaseline() bobCloser.initFeeBaseline()

View file

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -33,7 +34,7 @@ type Channel interface { //nolint:interfacebloat
// MarkCoopBroadcasted persistently marks that the channel close // MarkCoopBroadcasted persistently marks that the channel close
// transaction has been broadcast. // transaction has been broadcast.
MarkCoopBroadcasted(*wire.MsgTx, bool) error MarkCoopBroadcasted(*wire.MsgTx, lntypes.ChannelParty) error
// MarkShutdownSent persists the given ShutdownInfo. The existence of // MarkShutdownSent persists the given ShutdownInfo. The existence of
// the ShutdownInfo represents the fact that the Shutdown message has // the ShutdownInfo represents the fact that the Shutdown message has

File diff suppressed because it is too large Load diff

View file

@ -5196,7 +5196,7 @@ func TestChanCommitWeightDustHtlcs(t *testing.T) {
lc.localUpdateLog.logIndex) lc.localUpdateLog.logIndex)
_, w := lc.availableCommitmentBalance( _, w := lc.availableCommitmentBalance(
htlcView, true, FeeBuffer, htlcView, lntypes.Remote, FeeBuffer,
) )
return w return w
@ -7985,11 +7985,11 @@ func TestChannelFeeRateFloor(t *testing.T) {
// TestFetchParent tests lookup of an entry's parent in the appropriate log. // TestFetchParent tests lookup of an entry's parent in the appropriate log.
func TestFetchParent(t *testing.T) { func TestFetchParent(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
remoteChain bool whoseCommitChain lntypes.ChannelParty
remoteLog bool whoseUpdateLog lntypes.ChannelParty
localEntries []*PaymentDescriptor localEntries []*PaymentDescriptor
remoteEntries []*PaymentDescriptor remoteEntries []*PaymentDescriptor
// parentIndex is the parent index of the entry that we will // parentIndex is the parent index of the entry that we will
// lookup with fetch parent. // lookup with fetch parent.
@ -8003,22 +8003,22 @@ func TestFetchParent(t *testing.T) {
expectedIndex uint64 expectedIndex uint64
}{ }{
{ {
name: "not found in remote log", name: "not found in remote log",
localEntries: nil, localEntries: nil,
remoteEntries: nil, remoteEntries: nil,
remoteChain: true, whoseCommitChain: lntypes.Remote,
remoteLog: true, whoseUpdateLog: lntypes.Remote,
parentIndex: 0, parentIndex: 0,
expectErr: true, expectErr: true,
}, },
{ {
name: "not found in local log", name: "not found in local log",
localEntries: nil, localEntries: nil,
remoteEntries: nil, remoteEntries: nil,
remoteChain: false, whoseCommitChain: lntypes.Local,
remoteLog: false, whoseUpdateLog: lntypes.Local,
parentIndex: 0, parentIndex: 0,
expectErr: true, expectErr: true,
}, },
{ {
name: "remote log + chain, remote add height 0", name: "remote log + chain, remote add height 0",
@ -8038,10 +8038,10 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 0, addCommitHeightRemote: 0,
}, },
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
remoteLog: true, whoseUpdateLog: lntypes.Remote,
parentIndex: 1, parentIndex: 1,
expectErr: true, expectErr: true,
}, },
{ {
name: "remote log, local chain, local add height 0", name: "remote log, local chain, local add height 0",
@ -8060,11 +8060,11 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 100, addCommitHeightRemote: 100,
}, },
}, },
localEntries: nil, localEntries: nil,
remoteChain: false, whoseCommitChain: lntypes.Local,
remoteLog: true, whoseUpdateLog: lntypes.Remote,
parentIndex: 1, parentIndex: 1,
expectErr: true, expectErr: true,
}, },
{ {
name: "local log + chain, local add height 0", name: "local log + chain, local add height 0",
@ -8083,11 +8083,11 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 100, addCommitHeightRemote: 100,
}, },
}, },
remoteEntries: nil, remoteEntries: nil,
remoteChain: false, whoseCommitChain: lntypes.Local,
remoteLog: false, whoseUpdateLog: lntypes.Local,
parentIndex: 1, parentIndex: 1,
expectErr: true, expectErr: true,
}, },
{ {
@ -8107,11 +8107,11 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 0, addCommitHeightRemote: 0,
}, },
}, },
remoteEntries: nil, remoteEntries: nil,
remoteChain: true, whoseCommitChain: lntypes.Remote,
remoteLog: false, whoseUpdateLog: lntypes.Local,
parentIndex: 1, parentIndex: 1,
expectErr: true, expectErr: true,
}, },
{ {
name: "remote log found", name: "remote log found",
@ -8131,11 +8131,11 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 100, addCommitHeightRemote: 100,
}, },
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
remoteLog: true, whoseUpdateLog: lntypes.Remote,
parentIndex: 1, parentIndex: 1,
expectErr: false, expectErr: false,
expectedIndex: 2, expectedIndex: 2,
}, },
{ {
name: "local log found", name: "local log found",
@ -8154,12 +8154,12 @@ func TestFetchParent(t *testing.T) {
addCommitHeightRemote: 100, addCommitHeightRemote: 100,
}, },
}, },
remoteEntries: nil, remoteEntries: nil,
remoteChain: false, whoseCommitChain: lntypes.Local,
remoteLog: false, whoseUpdateLog: lntypes.Local,
parentIndex: 1, parentIndex: 1,
expectErr: false, expectErr: false,
expectedIndex: 2, expectedIndex: 2,
}, },
} }
@ -8186,8 +8186,8 @@ func TestFetchParent(t *testing.T) {
&PaymentDescriptor{ &PaymentDescriptor{
ParentIndex: test.parentIndex, ParentIndex: test.parentIndex,
}, },
test.remoteChain, test.whoseCommitChain,
test.remoteLog, test.whoseUpdateLog,
) )
gotErr := err != nil gotErr := err != nil
if test.expectErr != gotErr { if test.expectErr != gotErr {
@ -8245,11 +8245,11 @@ func TestEvaluateView(t *testing.T) {
) )
tests := []struct { tests := []struct {
name string name string
ourHtlcs []*PaymentDescriptor ourHtlcs []*PaymentDescriptor
theirHtlcs []*PaymentDescriptor theirHtlcs []*PaymentDescriptor
remoteChain bool whoseCommitChain lntypes.ChannelParty
mutateState bool mutateState bool
// ourExpectedHtlcs is the set of our htlcs that we expect in // ourExpectedHtlcs is the set of our htlcs that we expect in
// the htlc view once it has been evaluated. We just store // the htlc view once it has been evaluated. We just store
@ -8276,9 +8276,9 @@ func TestEvaluateView(t *testing.T) {
expectSent lnwire.MilliSatoshi expectSent lnwire.MilliSatoshi
}{ }{
{ {
name: "our fee update is applied", name: "our fee update is applied",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: false, mutateState: false,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
Amount: ourFeeUpdateAmt, Amount: ourFeeUpdateAmt,
@ -8293,10 +8293,10 @@ func TestEvaluateView(t *testing.T) {
expectSent: 0, expectSent: 0,
}, },
{ {
name: "their fee update is applied", name: "their fee update is applied",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: false, mutateState: false,
ourHtlcs: []*PaymentDescriptor{}, ourHtlcs: []*PaymentDescriptor{},
theirHtlcs: []*PaymentDescriptor{ theirHtlcs: []*PaymentDescriptor{
{ {
Amount: theirFeeUpdateAmt, Amount: theirFeeUpdateAmt,
@ -8311,9 +8311,9 @@ func TestEvaluateView(t *testing.T) {
}, },
{ {
// We expect unresolved htlcs to to remain in the view. // We expect unresolved htlcs to to remain in the view.
name: "htlcs adds without settles", name: "htlcs adds without settles",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: false, mutateState: false,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
HtlcIndex: 0, HtlcIndex: 0,
@ -8345,9 +8345,9 @@ func TestEvaluateView(t *testing.T) {
expectSent: 0, expectSent: 0,
}, },
{ {
name: "our htlc settled, state mutated", name: "our htlc settled, state mutated",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: true, mutateState: true,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
HtlcIndex: 0, HtlcIndex: 0,
@ -8380,9 +8380,9 @@ func TestEvaluateView(t *testing.T) {
expectSent: htlcAddAmount, expectSent: htlcAddAmount,
}, },
{ {
name: "our htlc settled, state not mutated", name: "our htlc settled, state not mutated",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: false, mutateState: false,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
HtlcIndex: 0, HtlcIndex: 0,
@ -8415,9 +8415,9 @@ func TestEvaluateView(t *testing.T) {
expectSent: 0, expectSent: 0,
}, },
{ {
name: "their htlc settled, state mutated", name: "their htlc settled, state mutated",
remoteChain: false, whoseCommitChain: lntypes.Local,
mutateState: true, mutateState: true,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
HtlcIndex: 0, HtlcIndex: 0,
@ -8458,9 +8458,10 @@ func TestEvaluateView(t *testing.T) {
expectSent: 0, expectSent: 0,
}, },
{ {
name: "their htlc settled, state not mutated", name: "their htlc settled, state not mutated",
remoteChain: false,
mutateState: false, whoseCommitChain: lntypes.Local,
mutateState: false,
ourHtlcs: []*PaymentDescriptor{ ourHtlcs: []*PaymentDescriptor{
{ {
HtlcIndex: 0, HtlcIndex: 0,
@ -8543,7 +8544,7 @@ func TestEvaluateView(t *testing.T) {
// Evaluate the htlc view, mutate as test expects. // Evaluate the htlc view, mutate as test expects.
result, err := lc.evaluateHTLCView( result, err := lc.evaluateHTLCView(
view, &ourBalance, &theirBalance, nextHeight, view, &ourBalance, &theirBalance, nextHeight,
test.remoteChain, test.mutateState, test.whoseCommitChain, test.mutateState,
) )
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
@ -8631,12 +8632,12 @@ func TestProcessFeeUpdate(t *testing.T) {
) )
tests := []struct { tests := []struct {
name string name string
startHeights heights startHeights heights
expectedHeights heights expectedHeights heights
remoteChain bool whoseCommitChain lntypes.ChannelParty
mutate bool mutate bool
expectedFee chainfee.SatPerKWeight expectedFee chainfee.SatPerKWeight
}{ }{
{ {
// Looking at local chain, local add is non-zero so // Looking at local chain, local add is non-zero so
@ -8654,9 +8655,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: 0, remoteAdd: 0,
remoteRemove: height, remoteRemove: height,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
mutate: false, mutate: false,
expectedFee: feePerKw, expectedFee: feePerKw,
}, },
{ {
// Looking at local chain, local add is zero so the // Looking at local chain, local add is zero so the
@ -8675,9 +8676,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: height, remoteAdd: height,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
mutate: false, mutate: false,
expectedFee: ourFeeUpdatePerSat, expectedFee: ourFeeUpdatePerSat,
}, },
{ {
// Looking at remote chain, the remote add height is // Looking at remote chain, the remote add height is
@ -8696,9 +8697,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: 0, remoteAdd: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
mutate: false, mutate: false,
expectedFee: ourFeeUpdatePerSat, expectedFee: ourFeeUpdatePerSat,
}, },
{ {
// Looking at remote chain, the remote add height is // Looking at remote chain, the remote add height is
@ -8717,9 +8718,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: height, remoteAdd: height,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
mutate: false, mutate: false,
expectedFee: feePerKw, expectedFee: feePerKw,
}, },
{ {
// Local add height is non-zero, so the update has // Local add height is non-zero, so the update has
@ -8738,9 +8739,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: 0, remoteAdd: 0,
remoteRemove: height, remoteRemove: height,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
mutate: true, mutate: true,
expectedFee: feePerKw, expectedFee: feePerKw,
}, },
{ {
// Local add is zero and we are looking at our local // Local add is zero and we are looking at our local
@ -8760,9 +8761,9 @@ func TestProcessFeeUpdate(t *testing.T) {
remoteAdd: 0, remoteAdd: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
mutate: true, mutate: true,
expectedFee: ourFeeUpdatePerSat, expectedFee: ourFeeUpdatePerSat,
}, },
} }
@ -8786,7 +8787,7 @@ func TestProcessFeeUpdate(t *testing.T) {
feePerKw: chainfee.SatPerKWeight(feePerKw), feePerKw: chainfee.SatPerKWeight(feePerKw),
} }
processFeeUpdate( processFeeUpdate(
update, nextHeight, test.remoteChain, update, nextHeight, test.whoseCommitChain,
test.mutate, view, test.mutate, view,
) )
@ -8841,7 +8842,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
startHeights heights startHeights heights
remoteChain bool whoseCommitChain lntypes.ChannelParty
isIncoming bool isIncoming bool
mutateState bool mutateState bool
ourExpectedBalance lnwire.MilliSatoshi ourExpectedBalance lnwire.MilliSatoshi
@ -8857,7 +8858,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -8878,7 +8879,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -8899,7 +8900,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
isIncoming: true, isIncoming: true,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -8920,7 +8921,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
isIncoming: true, isIncoming: true,
mutateState: true, mutateState: true,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -8942,7 +8943,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance - updateAmount, ourExpectedBalance: startBalance - updateAmount,
@ -8963,7 +8964,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: true, mutateState: true,
ourExpectedBalance: startBalance - updateAmount, ourExpectedBalance: startBalance - updateAmount,
@ -8984,7 +8985,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: removeHeight, remoteRemove: removeHeight,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -9005,7 +9006,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: removeHeight, localRemove: removeHeight,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -9028,7 +9029,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: true, isIncoming: true,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance + updateAmount, ourExpectedBalance: startBalance + updateAmount,
@ -9051,7 +9052,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -9074,7 +9075,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: true, isIncoming: true,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance, ourExpectedBalance: startBalance,
@ -9097,7 +9098,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: false, isIncoming: false,
mutateState: false, mutateState: false,
ourExpectedBalance: startBalance + updateAmount, ourExpectedBalance: startBalance + updateAmount,
@ -9122,7 +9123,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: false, whoseCommitChain: lntypes.Local,
isIncoming: true, isIncoming: true,
mutateState: true, mutateState: true,
ourExpectedBalance: startBalance + updateAmount, ourExpectedBalance: startBalance + updateAmount,
@ -9147,7 +9148,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
localRemove: 0, localRemove: 0,
remoteRemove: 0, remoteRemove: 0,
}, },
remoteChain: true, whoseCommitChain: lntypes.Remote,
isIncoming: true, isIncoming: true,
mutateState: true, mutateState: true,
ourExpectedBalance: startBalance + updateAmount, ourExpectedBalance: startBalance + updateAmount,
@ -9196,7 +9197,7 @@ func TestProcessAddRemoveEntry(t *testing.T) {
process( process(
update, &ourBalance, &theirBalance, nextHeight, update, &ourBalance, &theirBalance, nextHeight,
test.remoteChain, test.isIncoming, test.whoseCommitChain, test.isIncoming,
test.mutateState, test.mutateState,
) )
@ -9752,11 +9753,11 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) {
expRemote lnwire.MilliSatoshi) { expRemote lnwire.MilliSatoshi) {
localDustSum := c.GetDustSum( localDustSum := c.GetDustSum(
false, fn.None[chainfee.SatPerKWeight](), lntypes.Local, fn.None[chainfee.SatPerKWeight](),
) )
require.Equal(t, expLocal, localDustSum) require.Equal(t, expLocal, localDustSum)
remoteDustSum := c.GetDustSum( remoteDustSum := c.GetDustSum(
true, fn.None[chainfee.SatPerKWeight](), lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
) )
require.Equal(t, expRemote, remoteDustSum) require.Equal(t, expRemote, remoteDustSum)
} }
@ -9910,8 +9911,9 @@ func deriveDummyRetributionParams(chanState *channeldb.OpenChannel) (uint32,
config := chanState.RemoteChanCfg config := chanState.RemoteChanCfg
commitHash := chanState.RemoteCommitment.CommitTx.TxHash() commitHash := chanState.RemoteCommitment.CommitTx.TxHash()
keyRing := DeriveCommitmentKeys( keyRing := DeriveCommitmentKeys(
config.RevocationBasePoint.PubKey, false, chanState.ChanType, config.RevocationBasePoint.PubKey, lntypes.Remote,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg, chanState.ChanType, &chanState.LocalChanCfg,
&chanState.RemoteChanCfg,
) )
leaseExpiry := chanState.ThawHeight leaseExpiry := chanState.ThawHeight
return leaseExpiry, keyRing, commitHash return leaseExpiry, keyRing, commitHash
@ -10378,7 +10380,7 @@ func TestExtractPayDescs(t *testing.T) {
// NOTE: we use nil commitment key rings to avoid checking the htlc // NOTE: we use nil commitment key rings to avoid checking the htlc
// scripts(`genHtlcScript`) as it should be tested independently. // scripts(`genHtlcScript`) as it should be tested independently.
incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( incomingPDs, outgoingPDs, err := lnChan.extractPayDescs(
0, 0, htlcs, nil, nil, true, 0, 0, htlcs, nil, nil, lntypes.Local,
) )
require.NoError(t, err) require.NoError(t, err)

View file

@ -103,7 +103,7 @@ type CommitmentKeyRing struct {
// of channel, and whether the commitment transaction is ours or the remote // of channel, and whether the commitment transaction is ours or the remote
// peer's. // peer's.
func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
isOurCommit bool, chanType channeldb.ChannelType, whoseCommit lntypes.ChannelParty, chanType channeldb.ChannelType,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig) *CommitmentKeyRing { localChanCfg, remoteChanCfg *channeldb.ChannelConfig) *CommitmentKeyRing {
tweaklessCommit := chanType.IsTweakless() tweaklessCommit := chanType.IsTweakless()
@ -111,7 +111,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
// Depending on if this is our commit or not, we'll choose the correct // Depending on if this is our commit or not, we'll choose the correct
// base point. // base point.
localBasePoint := localChanCfg.PaymentBasePoint localBasePoint := localChanCfg.PaymentBasePoint
if isOurCommit { if whoseCommit.IsLocal() {
localBasePoint = localChanCfg.DelayBasePoint localBasePoint = localChanCfg.DelayBasePoint
} }
@ -144,7 +144,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
toRemoteBasePoint *btcec.PublicKey toRemoteBasePoint *btcec.PublicKey
revocationBasePoint *btcec.PublicKey revocationBasePoint *btcec.PublicKey
) )
if isOurCommit { if whoseCommit.IsLocal() {
toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey
toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey
revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey
@ -169,7 +169,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
// If this is not our commitment, the above ToRemoteKey will be // If this is not our commitment, the above ToRemoteKey will be
// ours, and we blank out the local commitment tweak to // ours, and we blank out the local commitment tweak to
// indicate that the key should not be tweaked when signing. // indicate that the key should not be tweaked when signing.
if !isOurCommit { if whoseCommit.IsRemote() {
keyRing.LocalCommitKeyTweak = nil keyRing.LocalCommitKeyTweak = nil
} }
} else { } else {
@ -686,20 +686,20 @@ type unsignedCommitmentTx struct {
// passed in balances should be balances *before* subtracting any commitment // passed in balances should be balances *before* subtracting any commitment
// fees, but after anchor outputs. // fees, but after anchor outputs.
func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
theirBalance lnwire.MilliSatoshi, isOurs bool, theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty,
feePerKw chainfee.SatPerKWeight, height uint64, feePerKw chainfee.SatPerKWeight, height uint64,
filteredHTLCView *htlcView, filteredHTLCView *htlcView,
keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) {
dustLimit := cb.chanState.LocalChanCfg.DustLimit dustLimit := cb.chanState.LocalChanCfg.DustLimit
if !isOurs { if whoseCommit.IsRemote() {
dustLimit = cb.chanState.RemoteChanCfg.DustLimit dustLimit = cb.chanState.RemoteChanCfg.DustLimit
} }
numHTLCs := int64(0) numHTLCs := int64(0)
for _, htlc := range filteredHTLCView.ourUpdates { for _, htlc := range filteredHTLCView.ourUpdates {
if HtlcIsDust( if HtlcIsDust(
cb.chanState.ChanType, false, isOurs, feePerKw, cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit, htlc.Amount.ToSatoshis(), dustLimit,
) { ) {
@ -710,7 +710,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
} }
for _, htlc := range filteredHTLCView.theirUpdates { for _, htlc := range filteredHTLCView.theirUpdates {
if HtlcIsDust( if HtlcIsDust(
cb.chanState.ChanType, true, isOurs, feePerKw, cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit, htlc.Amount.ToSatoshis(), dustLimit,
) { ) {
@ -763,7 +763,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
if cb.chanState.ChanType.HasLeaseExpiration() { if cb.chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = cb.chanState.ThawHeight leaseExpiry = cb.chanState.ThawHeight
} }
if isOurs { if whoseCommit.IsLocal() {
commitTx, err = CreateCommitTx( commitTx, err = CreateCommitTx(
cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing, cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing,
&cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg,
@ -794,7 +794,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
cltvs := make([]uint32, len(commitTx.TxOut)) cltvs := make([]uint32, len(commitTx.TxOut))
for _, htlc := range filteredHTLCView.ourUpdates { for _, htlc := range filteredHTLCView.ourUpdates {
if HtlcIsDust( if HtlcIsDust(
cb.chanState.ChanType, false, isOurs, feePerKw, cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit, htlc.Amount.ToSatoshis(), dustLimit,
) { ) {
@ -802,7 +802,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
} }
err := addHTLC( err := addHTLC(
commitTx, isOurs, false, htlc, keyRing, commitTx, whoseCommit, false, htlc, keyRing,
cb.chanState.ChanType, cb.chanState.ChanType,
) )
if err != nil { if err != nil {
@ -812,7 +812,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
} }
for _, htlc := range filteredHTLCView.theirUpdates { for _, htlc := range filteredHTLCView.theirUpdates {
if HtlcIsDust( if HtlcIsDust(
cb.chanState.ChanType, true, isOurs, feePerKw, cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit, htlc.Amount.ToSatoshis(), dustLimit,
) { ) {
@ -820,7 +820,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
} }
err := addHTLC( err := addHTLC(
commitTx, isOurs, true, htlc, keyRing, commitTx, whoseCommit, true, htlc, keyRing,
cb.chanState.ChanType, cb.chanState.ChanType,
) )
if err != nil { if err != nil {
@ -1003,8 +1003,9 @@ func CoopCloseBalance(chanType channeldb.ChannelType, isInitiator bool,
// genSegwitV0HtlcScript generates the HTLC scripts for a normal segwit v0 // genSegwitV0HtlcScript generates the HTLC scripts for a normal segwit v0
// channel. // channel.
func genSegwitV0HtlcScript(chanType channeldb.ChannelType, func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
isIncoming, ourCommit bool, timeout uint32, rHash [32]byte, isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32,
keyRing *CommitmentKeyRing) (*WitnessScriptDesc, error) { rHash [32]byte, keyRing *CommitmentKeyRing,
) (*WitnessScriptDesc, error) {
var ( var (
witnessScript []byte witnessScript []byte
@ -1024,7 +1025,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
// The HTLC is paying to us, and being applied to our commitment // The HTLC is paying to us, and being applied to our commitment
// transaction. So we need to use the receiver's version of the HTLC // transaction. So we need to use the receiver's version of the HTLC
// script. // script.
case isIncoming && ourCommit: case isIncoming && whoseCommit.IsLocal():
witnessScript, err = input.ReceiverHTLCScript( witnessScript, err = input.ReceiverHTLCScript(
timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
@ -1033,7 +1034,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
// We're being paid via an HTLC by the remote party, and the HTLC is // We're being paid via an HTLC by the remote party, and the HTLC is
// being added to their commitment transaction, so we use the sender's // being added to their commitment transaction, so we use the sender's
// version of the HTLC script. // version of the HTLC script.
case isIncoming && !ourCommit: case isIncoming && whoseCommit.IsRemote():
witnessScript, err = input.SenderHTLCScript( witnessScript, err = input.SenderHTLCScript(
keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
@ -1042,7 +1043,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
// We're sending an HTLC which is being added to our commitment // We're sending an HTLC which is being added to our commitment
// transaction. Therefore, we need to use the sender's version of the // transaction. Therefore, we need to use the sender's version of the
// HTLC script. // HTLC script.
case !isIncoming && ourCommit: case !isIncoming && whoseCommit.IsLocal():
witnessScript, err = input.SenderHTLCScript( witnessScript, err = input.SenderHTLCScript(
keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
@ -1051,7 +1052,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
// Finally, we're paying the remote party via an HTLC, which is being // Finally, we're paying the remote party via an HTLC, which is being
// added to their commitment transaction. Therefore, we use the // added to their commitment transaction. Therefore, we use the
// receiver's version of the HTLC script. // receiver's version of the HTLC script.
case !isIncoming && !ourCommit: case !isIncoming && whoseCommit.IsRemote():
witnessScript, err = input.ReceiverHTLCScript( witnessScript, err = input.ReceiverHTLCScript(
timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
@ -1076,9 +1077,9 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
// genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2 // genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2
// channel. // channel.
func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, func genTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty,
rHash [32]byte, timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing,
keyRing *CommitmentKeyRing) (*input.HtlcScriptTree, error) { ) (*input.HtlcScriptTree, error) {
var ( var (
htlcScriptTree *input.HtlcScriptTree htlcScriptTree *input.HtlcScriptTree
@ -1092,37 +1093,37 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
// The HTLC is paying to us, and being applied to our commitment // The HTLC is paying to us, and being applied to our commitment
// transaction. So we need to use the receiver's version of HTLC the // transaction. So we need to use the receiver's version of HTLC the
// script. // script.
case isIncoming && ourCommit: case isIncoming && whoseCommit.IsLocal():
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit, keyRing.RevocationKey, rHash[:], whoseCommit,
) )
// We're being paid via an HTLC by the remote party, and the HTLC is // We're being paid via an HTLC by the remote party, and the HTLC is
// being added to their commitment transaction, so we use the sender's // being added to their commitment transaction, so we use the sender's
// version of the HTLC script. // version of the HTLC script.
case isIncoming && !ourCommit: case isIncoming && whoseCommit.IsRemote():
htlcScriptTree, err = input.SenderHTLCScriptTaproot( htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit, keyRing.RevocationKey, rHash[:], whoseCommit,
) )
// We're sending an HTLC which is being added to our commitment // We're sending an HTLC which is being added to our commitment
// transaction. Therefore, we need to use the sender's version of the // transaction. Therefore, we need to use the sender's version of the
// HTLC script. // HTLC script.
case !isIncoming && ourCommit: case !isIncoming && whoseCommit.IsLocal():
htlcScriptTree, err = input.SenderHTLCScriptTaproot( htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit, keyRing.RevocationKey, rHash[:], whoseCommit,
) )
// Finally, we're paying the remote party via an HTLC, which is being // Finally, we're paying the remote party via an HTLC, which is being
// added to their commitment transaction. Therefore, we use the // added to their commitment transaction. Therefore, we use the
// receiver's version of the HTLC script. // receiver's version of the HTLC script.
case !isIncoming && !ourCommit: case !isIncoming && whoseCommit.IsRemote():
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], ourCommit, keyRing.RevocationKey, rHash[:], whoseCommit,
) )
} }
@ -1135,19 +1136,20 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32,
// multiplexer for the various spending paths is returned. The script path that // multiplexer for the various spending paths is returned. The script path that
// we need to sign for the remote party (2nd level HTLCs) is also returned // we need to sign for the remote party (2nd level HTLCs) is also returned
// along side the multiplexer. // along side the multiplexer.
func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool,
timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte,
keyRing *CommitmentKeyRing,
) (input.ScriptDescriptor, error) { ) (input.ScriptDescriptor, error) {
if !chanType.IsTaproot() { if !chanType.IsTaproot() {
return genSegwitV0HtlcScript( return genSegwitV0HtlcScript(
chanType, isIncoming, ourCommit, timeout, rHash, chanType, isIncoming, whoseCommit, timeout, rHash,
keyRing, keyRing,
) )
} }
return genTaprootHtlcScript( return genTaprootHtlcScript(
isIncoming, ourCommit, timeout, rHash, keyRing, isIncoming, whoseCommit, timeout, rHash, keyRing,
) )
} }
@ -1158,7 +1160,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool,
// locate the added HTLC on the commitment transaction from the // locate the added HTLC on the commitment transaction from the
// PaymentDescriptor that generated it, the generated script is stored within // PaymentDescriptor that generated it, the generated script is stored within
// the descriptor itself. // the descriptor itself.
func addHTLC(commitTx *wire.MsgTx, ourCommit bool, func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty,
isIncoming bool, paymentDesc *PaymentDescriptor, isIncoming bool, paymentDesc *PaymentDescriptor,
keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error {
@ -1166,7 +1168,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool,
rHash := paymentDesc.RHash rHash := paymentDesc.RHash
scriptInfo, err := genHtlcScript( scriptInfo, err := genHtlcScript(
chanType, isIncoming, ourCommit, timeout, rHash, keyRing, chanType, isIncoming, whoseCommit, timeout, rHash, keyRing,
) )
if err != nil { if err != nil {
return err return err
@ -1180,7 +1182,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool,
// Store the pkScript of this particular PaymentDescriptor so we can // Store the pkScript of this particular PaymentDescriptor so we can
// quickly locate it within the commitment transaction later. // quickly locate it within the commitment transaction later.
if ourCommit { if whoseCommit.IsLocal() {
paymentDesc.ourPkScript = pkScript paymentDesc.ourPkScript = pkScript
paymentDesc.ourWitnessScript = scriptInfo.WitnessScriptToSign() paymentDesc.ourWitnessScript = scriptInfo.WitnessScriptToSign()
@ -1211,7 +1213,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash,
// With the commitment point generated, we can now derive the king ring // With the commitment point generated, we can now derive the king ring
// which will be used to generate the output scripts. // which will be used to generate the output scripts.
keyRing := DeriveCommitmentKeys( keyRing := DeriveCommitmentKeys(
commitmentPoint, false, chanState.ChanType, commitmentPoint, lntypes.Remote, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg, &chanState.LocalChanCfg, &chanState.RemoteChanCfg,
) )

View file

@ -25,6 +25,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwallet/chanfunding"
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
@ -1475,10 +1476,12 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount,
leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) {
localCommitmentKeys := DeriveCommitmentKeys( localCommitmentKeys := DeriveCommitmentKeys(
localCommitPoint, true, chanType, ourChanCfg, theirChanCfg, localCommitPoint, lntypes.Local, chanType, ourChanCfg,
theirChanCfg,
) )
remoteCommitmentKeys := DeriveCommitmentKeys( remoteCommitmentKeys := DeriveCommitmentKeys(
remoteCommitPoint, false, chanType, ourChanCfg, theirChanCfg, remoteCommitPoint, lntypes.Remote, chanType, ourChanCfg,
theirChanCfg,
) )
ourCommitTx, err := CreateCommitTx( ourCommitTx, err := CreateCommitTx(

View file

@ -36,6 +36,7 @@ import (
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
@ -1069,7 +1070,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
chanCloser, err := p.createChanCloser( chanCloser, err := p.createChanCloser(
lnChan, info.DeliveryScript.Val, feePerKw, nil, lnChan, info.DeliveryScript.Val, feePerKw, nil,
info.LocalInitiator.Val, info.Closer(),
) )
if err != nil { if err != nil {
shutdownInfoErr = fmt.Errorf("unable to "+ shutdownInfoErr = fmt.Errorf("unable to "+
@ -2732,7 +2733,7 @@ func (p *Brontide) fetchActiveChanCloser(chanID lnwire.ChannelID) (
} }
chanCloser, err = p.createChanCloser( chanCloser, err = p.createChanCloser(
channel, deliveryScript, feePerKw, nil, false, channel, deliveryScript, feePerKw, nil, lntypes.Remote,
) )
if err != nil { if err != nil {
p.log.Errorf("unable to create chan closer: %v", err) p.log.Errorf("unable to create chan closer: %v", err)
@ -2969,12 +2970,13 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) (
// Determine whether we or the peer are the initiator of the coop // Determine whether we or the peer are the initiator of the coop
// close attempt by looking at the channel's status. // close attempt by looking at the channel's status.
locallyInitiated := c.HasChanStatus( closingParty := lntypes.Remote
channeldb.ChanStatusLocalCloseInitiator, if c.HasChanStatus(channeldb.ChanStatusLocalCloseInitiator) {
) closingParty = lntypes.Local
}
chanCloser, err := p.createChanCloser( chanCloser, err := p.createChanCloser(
lnChan, deliveryScript, feePerKw, nil, locallyInitiated, lnChan, deliveryScript, feePerKw, nil, closingParty,
) )
if err != nil { if err != nil {
p.log.Errorf("unable to create chan closer: %v", err) p.log.Errorf("unable to create chan closer: %v", err)
@ -3003,7 +3005,7 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) (
func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel,
deliveryScript lnwire.DeliveryAddress, fee chainfee.SatPerKWeight, deliveryScript lnwire.DeliveryAddress, fee chainfee.SatPerKWeight,
req *htlcswitch.ChanClose, req *htlcswitch.ChanClose,
locallyInitiated bool) (*chancloser.ChanCloser, error) { closer lntypes.ChannelParty) (*chancloser.ChanCloser, error) {
_, startingHeight, err := p.cfg.ChainIO.GetBestBlock() _, startingHeight, err := p.cfg.ChainIO.GetBestBlock()
if err != nil { if err != nil {
@ -3039,7 +3041,7 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel,
fee, fee,
uint32(startingHeight), uint32(startingHeight),
req, req,
locallyInitiated, closer,
) )
return chanCloser, nil return chanCloser, nil
@ -3096,7 +3098,8 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
} }
chanCloser, err := p.createChanCloser( chanCloser, err := p.createChanCloser(
channel, deliveryScript, req.TargetFeePerKw, req, true, channel, deliveryScript, req.TargetFeePerKw, req,
lntypes.Local,
) )
if err != nil { if err != nil {
p.log.Errorf(err.Error()) p.log.Errorf(err.Error())