multi: use some record for payment descriptor blinding point

This commit is contained in:
Carla Kirk-Cohen 2024-04-02 08:46:14 -04:00
parent b1175514f9
commit 7fd9c2a7f8
No known key found for this signature in database
GPG key ID: 4CA7FE54A6213C91
4 changed files with 37 additions and 89 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// Iterator is an interface that abstracts away the routing information
@ -186,7 +187,7 @@ type DecodeHopIteratorRequest struct {
RHash []byte
IncomingCltv uint32
IncomingAmount lnwire.MilliSatoshi
BlindingPoint *btcec.PublicKey
BlindingPoint lnwire.BlindingPointRecord
}
// DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion
@ -243,12 +244,14 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte,
}
var opts []sphinx.ProcessOnionOpt
if req.BlindingPoint != nil {
opts = append(opts, sphinx.WithBlindingPoint(
req.BlindingPoint,
))
}
req.BlindingPoint.WhenSome(func(
b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {
opts = append(opts, sphinx.WithBlindingPoint(
b.Val,
))
})
err = tx.ProcessOnionPacket(
seqNum, onionPkt, req.RHash, req.IncomingCltv, opts...,
)

View file

@ -31,7 +31,6 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
var (
@ -377,7 +376,7 @@ type PaymentDescriptor struct {
// This value is set for nodes that are relaying payments inside of a
// blinded route (ie, not the introduction node) from update_add_htlc's
// TLVs.
BlindingPoint *btcec.PublicKey
BlindingPoint lnwire.BlindingPointRecord
}
// PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the
@ -418,7 +417,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64,
Height: height,
Index: uint16(i),
},
BlindingPoint: wireMsg.BlingingPointOrNil(),
BlindingPoint: pd.BlindingPoint,
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -742,16 +741,9 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
HtlcIndex: htlc.HtlcIndex,
LogIndex: htlc.LogIndex,
Incoming: false,
BlindingPoint: htlc.BlindingPoint,
}
copy(h.OnionBlob[:], htlc.OnionBlob)
if htlc.BlindingPoint != nil {
h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}
if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
@ -774,16 +766,9 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
HtlcIndex: htlc.HtlcIndex,
LogIndex: htlc.LogIndex,
Incoming: true,
BlindingPoint: htlc.BlindingPoint,
}
copy(h.OnionBlob[:], htlc.OnionBlob)
if htlc.BlindingPoint != nil {
h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}
if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
}
@ -866,7 +851,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
// With the scripts reconstructed (depending on if this is our commit
// vs theirs or a pending commit for the remote party), we can now
// re-create the original payment descriptor.
pd = PaymentDescriptor{
return PaymentDescriptor{
RHash: htlc.RHash,
Timeout: htlc.RefundTimeout,
Amount: htlc.Amt,
@ -880,15 +865,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
ourWitnessScript: ourWitnessScript,
theirPkScript: theirP2WSH,
theirWitnessScript: theirWitnessScript,
}
htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[
lnwire.BlindingPointTlvType, *btcec.PublicKey]) {
pd.BlindingPoint = b.Val
})
return pd, nil
BlindingPoint: htlc.BlindingPoint,
}, nil
}
// extractPayDescs will convert all HTLC's present within a disk commit state
@ -1577,7 +1555,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightRemote: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
BlindingPoint: wireMsg.BlindingPoint,
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -1775,7 +1753,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightLocal: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
BlindingPoint: wireMsg.BlindingPoint,
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob, wireMsg.OnionBlob[:])
@ -3631,21 +3609,14 @@ func (lc *LightningChannel) createCommitDiff(
switch pd.EntryType {
case Add:
htlc := &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
BlindingPoint: pd.BlindingPoint,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
if pd.BlindingPoint != nil {
htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
}
logUpdate.UpdateMsg = htlc
// Gather any references for circuits opened by this Add
@ -3775,21 +3746,13 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
// four messages that it corresponds to.
switch pd.EntryType {
case Add:
var b lnwire.BlindingPointRecord
if pd.BlindingPoint != nil {
tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint),
)
}
htlc := &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
BlindingPoint: b,
BlindingPoint: pd.BlindingPoint,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
@ -5784,19 +5747,12 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
switch pd.EntryType {
case Add:
htlc := &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
}
if pd.BlindingPoint != nil {
htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
BlindingPoint: pd.BlindingPoint,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
@ -6135,7 +6091,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
HtlcIndex: lc.localUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
OpenCircuitKey: openKey,
BlindingPoint: htlc.BlingingPointOrNil(),
BlindingPoint: htlc.BlindingPoint,
}
}
@ -6193,7 +6149,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err
LogIndex: lc.remoteUpdateLog.logIndex,
HtlcIndex: lc.remoteUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
BlindingPoint: htlc.BlingingPointOrNil(),
BlindingPoint: htlc.BlindingPoint,
}
localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex

View file

@ -11045,7 +11045,8 @@ func TestBlindingPointPersistence(t *testing.T) {
// Assert that the blinding point is restored from disk.
remoteCommit := aliceChannel.remoteCommitChain.tip()
require.Len(t, remoteCommit.outgoingHTLCs, 1)
require.Equal(t, blinding, remoteCommit.outgoingHTLCs[0].BlindingPoint)
require.Equal(t, blinding,
remoteCommit.outgoingHTLCs[0].BlindingPoint.UnwrapOrFailV(t))
// Next, update bob's commitment and assert that we can still retrieve
// his incoming blinding point after restart.
@ -11061,5 +11062,6 @@ func TestBlindingPointPersistence(t *testing.T) {
// Assert that Bob is able to recover the blinding point from disk.
bobCommit := bobChannel.localCommitChain.tip()
require.Len(t, bobCommit.incomingHTLCs, 1)
require.Equal(t, blinding, bobCommit.incomingHTLCs[0].BlindingPoint)
require.Equal(t, blinding,
bobCommit.incomingHTLCs[0].BlindingPoint.UnwrapOrFailV(t))
}

View file

@ -78,19 +78,6 @@ type UpdateAddHTLC struct {
ExtraData ExtraOpaqueData
}
// BlingingPointOrNil returns the blinding point associated with the update, or
// nil.
func (c *UpdateAddHTLC) BlingingPointOrNil() *btcec.PublicKey {
var blindingPoint *btcec.PublicKey
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
*btcec.PublicKey]) {
blindingPoint = b.Val
})
return blindingPoint
}
// NewUpdateAddHTLC returns a new empty UpdateAddHTLC message.
func NewUpdateAddHTLC() *UpdateAddHTLC {
return &UpdateAddHTLC{}