multi: add blinding point to payment descriptor and persist

This commit adds an optional blinding point to payment descriptors and
persists them in our HTLC's extra data. A get/set pattern is used to
populate the ExtraData on our disk representation of the HTLC so that
callers do not need to worry about the underlying storage detail.
This commit is contained in:
Carla Kirk-Cohen 2023-11-06 15:36:31 -05:00
parent 7596e764ac
commit f090a64142
No known key found for this signature in database
GPG Key ID: 4CA7FE54A6213C91
5 changed files with 246 additions and 40 deletions

View File

@ -35,6 +35,10 @@ const (
// begins to be interpreted as an absolute block height, rather than a
// relative one.
AbsoluteThawHeightThreshold uint32 = 500000
// HTLCBlindingPointTLV is the tlv type used for storing blinding
// points with HTLCs.
HTLCBlindingPointTLV tlv.Type = 0
)
var (
@ -2316,7 +2320,56 @@ type HTLC struct {
// Note that this extra data is stored inline with the OnionBlob for
// legacy reasons, see serialization/deserialization functions for
// detail.
ExtraData []byte
ExtraData lnwire.ExtraOpaqueData
// BlindingPoint is an optional blinding point included with the HTLC.
//
// Note: this field is not a part of on-disk representation of the
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
BlindingPoint lnwire.BlindingPointRecord
}
// serializeExtraData encodes a TLV stream of extra data to be stored with a
// HTLC. It uses the update_add_htlc TLV types, because this is where extra
// data is passed with a HTLC. At present blinding points are the only extra
// data that we will store, and the function is a no-op if a nil blinding
// point is provided.
//
// This function MUST be called to persist all HTLC values when they are
// serialized.
func (h *HTLC) serializeExtraData() error {
var records []tlv.RecordProducer
h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {
records = append(records, &b)
})
return h.ExtraData.PackRecords(records...)
}
// deserializeExtraData extracts TLVs from the extra data persisted for the
// htlc and populates values in the struct accordingly.
//
// This function MUST be called to populate the struct properly when HTLCs
// are deserialized.
func (h *HTLC) deserializeExtraData() error {
if len(h.ExtraData) == 0 {
return nil
}
blindingPoint := h.BlindingPoint.Zero()
tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint)
if err != nil {
return err
}
if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)
}
return nil
}
// SerializeHtlcs writes out the passed set of HTLC's into the passed writer
@ -2340,6 +2393,12 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
}
for _, htlc := range htlcs {
// Populate TLV stream for any additional fields contained
// in the TLV.
if err := htlc.serializeExtraData(); err != nil {
return err
}
// The onion blob and hltc data are stored as a single var
// bytes blob.
onionAndExtraData := make(
@ -2425,6 +2484,12 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
onionAndExtraData[lnwire.OnionPacketSize:],
)
}
// Finally, deserialize any TLVs contained in that extra data
// if they are present.
if err := htlcs[i].deserializeExtraData(); err != nil {
return nil, err
}
}
return htlcs, nil

View File

@ -23,6 +23,7 @@ import (
"github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
@ -1606,9 +1607,25 @@ func TestHTLCsExtraData(t *testing.T) {
OnionBlob: lnmock.MockOnion(),
}
// Add a blinding point to a htlc.
blindingPointHTLC := HTLC{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
BlindingPoint: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pubKey,
),
),
}
testCases := []struct {
name string
htlcs []HTLC
name string
htlcs []HTLC
blindingIdx int
}{
{
// Serialize multiple HLTCs with no extra data to
@ -1620,30 +1637,12 @@ func TestHTLCsExtraData(t *testing.T) {
},
},
{
// Some HTLCs with extra data, some without.
name: "mixed extra data",
htlcs: []HTLC{
mockHtlc,
{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
ExtraData: []byte{1, 2, 3},
},
blindingPointHTLC,
mockHtlc,
{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
ExtraData: bytes.Repeat(
[]byte{9}, 999,
),
},
},
},
}
@ -1661,7 +1660,15 @@ func TestHTLCsExtraData(t *testing.T) {
r := bytes.NewReader(b.Bytes())
htlcs, err := DeserializeHtlcs(r)
require.NoError(t, err)
require.Equal(t, testCase.htlcs, htlcs)
require.EqualValues(t, len(testCase.htlcs), len(htlcs))
for i, htlc := range htlcs {
// We use the extra data field when we
// serialize, so we set to nil to be able to
// assert on equal for the test.
htlc.ExtraData = nil
require.Equal(t, testCase.htlcs[i], htlc)
}
})
}
}

View File

@ -31,6 +31,7 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
var (
@ -371,6 +372,12 @@ type PaymentDescriptor struct {
// isForwarded denotes if an incoming HTLC has been forwarded to any
// possible upstream peers in the route.
isForwarded bool
// BlindingPoint is an optional ephemeral key used in route blinding.
// This value is set for nodes that are relaying payments inside of a
// blinded route (ie, not the introduction node) from update_add_htlc's
// TLVs.
BlindingPoint *btcec.PublicKey
}
// PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the
@ -411,6 +418,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64,
Height: height,
Index: uint16(i),
},
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -736,6 +744,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
Incoming: false,
}
copy(h.OnionBlob[:], htlc.OnionBlob)
if htlc.BlindingPoint != nil {
h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}
if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
@ -760,7 +776,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
Incoming: true,
}
copy(h.OnionBlob[:], htlc.OnionBlob)
if htlc.BlindingPoint != nil {
h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}
if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
}
@ -859,6 +882,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
theirWitnessScript: theirWitnessScript,
}
htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[
lnwire.BlindingPointTlvType, *btcec.PublicKey]) {
pd.BlindingPoint = b.Val
})
return pd, nil
}
@ -1548,6 +1577,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightRemote: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -1745,6 +1775,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightLocal: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob, wireMsg.OnionBlob[:])
@ -3607,6 +3638,14 @@ func (lc *LightningChannel) createCommitDiff(
PaymentHash: pd.RHash,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
if pd.BlindingPoint != nil {
htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
}
logUpdate.UpdateMsg = htlc
// Gather any references for circuits opened by this Add
@ -3736,12 +3775,21 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
// four messages that it corresponds to.
switch pd.EntryType {
case Add:
var b lnwire.BlindingPointRecord
if pd.BlindingPoint != nil {
tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint),
)
}
htlc := &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
BlindingPoint: b,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
@ -5742,6 +5790,14 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
}
if pd.BlindingPoint != nil {
htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
addUpdates = append(addUpdates, logUpdate)
@ -6079,6 +6135,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
HtlcIndex: lc.localUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
OpenCircuitKey: openKey,
BlindingPoint: htlc.BlingingPointOrNil(),
}
}
@ -6129,13 +6186,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err
}
pd := &PaymentDescriptor{
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.remoteUpdateLog.logIndex,
HtlcIndex: lc.remoteUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.remoteUpdateLog.logIndex,
HtlcIndex: lc.remoteUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
BlindingPoint: htlc.BlingingPointOrNil(),
}
localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex

View File

@ -25,6 +25,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
@ -10419,8 +10420,9 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC {
_, err = rand.Read(sig)
require.NoError(t, err)
extra := make([]byte, 1000)
_, err = rand.Read(extra)
blinding, err := pubkeyFromHex(
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll
)
require.NoError(t, err)
return channeldb.HTLC{
@ -10433,7 +10435,10 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC {
OnionBlob: onionBlob,
HtlcIndex: rand.Uint64(),
LogIndex: rand.Uint64(),
ExtraData: extra,
BlindingPoint: tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding),
),
}
}
@ -11000,3 +11005,61 @@ func TestEnforceFeeBuffer(t *testing.T) {
require.Equal(t, aliceBalance, expectedAmt)
}
// TestBlindingPointPersistence tests persistence of blinding points attached
// to htlcs across restarts.
func TestBlindingPointPersistence(t *testing.T) {
// Create a test channel which will be used for the duration of this
// test. The channel will be funded evenly with Alice having 5 BTC,
// and Bob having 5 BTC.
aliceChannel, bobChannel, err := CreateTestChannels(
t, channeldb.SingleFunderTweaklessBit,
)
require.NoError(t, err, "unable to create test channels")
// Send a HTLC from Alice to Bob that has a blinding point populated.
htlc, _ := createHTLC(0, 100_000_000)
blinding, err := pubkeyFromHex(
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll
)
require.NoError(t, err)
htlc.BlindingPoint = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding),
)
_, err = aliceChannel.AddHTLC(htlc, nil)
require.NoError(t, err)
_, err = bobChannel.ReceiveHTLC(htlc)
require.NoError(t, err)
// Now, Alice will send a new commitment to Bob, which will persist our
// pending HTLC to disk.
aliceCommit, err := aliceChannel.SignNextCommitment()
require.NoError(t, err, "unable to sign commitment")
// Restart alice to force fetching state from disk.
aliceChannel, err = restartChannel(aliceChannel)
require.NoError(t, err, "unable to restart alice")
// Assert that the blinding point is restored from disk.
remoteCommit := aliceChannel.remoteCommitChain.tip()
require.Len(t, remoteCommit.outgoingHTLCs, 1)
require.Equal(t, blinding, remoteCommit.outgoingHTLCs[0].BlindingPoint)
// Next, update bob's commitment and assert that we can still retrieve
// his incoming blinding point after restart.
err = bobChannel.ReceiveNewCommitment(aliceCommit.CommitSigs)
require.NoError(t, err, "bob unable to receive new commitment")
_, _, _, err = bobChannel.RevokeCurrentCommitment()
require.NoError(t, err, "bob unable to revoke current commitment")
bobChannel, err = restartChannel(bobChannel)
require.NoError(t, err, "unable to restart bob's channel")
// Assert that Bob is able to recover the blinding point from disk.
bobCommit := bobChannel.localCommitChain.tip()
require.Len(t, bobCommit.incomingHTLCs, 1)
require.Equal(t, blinding, bobCommit.incomingHTLCs[0].BlindingPoint)
}

View File

@ -78,6 +78,19 @@ type UpdateAddHTLC struct {
ExtraData ExtraOpaqueData
}
// BlingingPointOrNil returns the blinding point associated with the update, or
// nil.
func (c *UpdateAddHTLC) BlingingPointOrNil() *btcec.PublicKey {
var blindingPoint *btcec.PublicKey
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
*btcec.PublicKey]) {
blindingPoint = b.Val
})
return blindingPoint
}
// NewUpdateAddHTLC returns a new empty UpdateAddHTLC message.
func NewUpdateAddHTLC() *UpdateAddHTLC {
return &UpdateAddHTLC{}