mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-23 14:40:30 +01:00
routing: improve lasthoppaylaod size calculation
Fixes a bug and makes the function more robust. Before we would always return the encrypted data size of last hop of the last path. Now we return the greatest last hop payload not always the one of the last path.
This commit is contained in:
parent
eb93eb7ee9
commit
c579a6bf2f
3 changed files with 81 additions and 33 deletions
|
@ -235,21 +235,33 @@ func (s *BlindedPaymentPathSet) FinalCLTVDelta() uint16 {
|
|||
// LargestLastHopPayloadPath returns the BlindedPayment in the set that has the
|
||||
// largest last-hop payload. This is to be used for onion size estimation in
|
||||
// path finding.
|
||||
func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() *BlindedPayment {
|
||||
func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() (*BlindedPayment,
|
||||
error) {
|
||||
|
||||
var (
|
||||
largestPath *BlindedPayment
|
||||
currentMax int
|
||||
)
|
||||
|
||||
if len(s.paths) == 0 {
|
||||
return nil, fmt.Errorf("no blinded paths in the set")
|
||||
}
|
||||
|
||||
// We set the largest path to make sure we always return a path even
|
||||
// if the cipher text is empty.
|
||||
largestPath = s.paths[0]
|
||||
|
||||
for _, path := range s.paths {
|
||||
numHops := len(path.BlindedPath.BlindedHops)
|
||||
lastHop := path.BlindedPath.BlindedHops[numHops-1]
|
||||
|
||||
if len(lastHop.CipherText) > currentMax {
|
||||
largestPath = path
|
||||
currentMax = len(lastHop.CipherText)
|
||||
}
|
||||
}
|
||||
|
||||
return largestPath
|
||||
return largestPath, nil
|
||||
}
|
||||
|
||||
// ToRouteHints converts the blinded path payment set into a RouteHints map so
|
||||
|
|
|
@ -700,7 +700,10 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
|||
|
||||
// The payload size of the final hop differ from intermediate hops
|
||||
// and depends on whether the destination is blinded or not.
|
||||
lastHopPayloadSize := lastHopPayloadSize(r, finalHtlcExpiry, amt)
|
||||
lastHopPayloadSize, err := lastHopPayloadSize(r, finalHtlcExpiry, amt)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// We can't always assume that the end destination is publicly
|
||||
// advertised to the network so we'll manually include the target node.
|
||||
|
@ -1433,11 +1436,15 @@ func getProbabilityBasedDist(weight int64, probability float64,
|
|||
// It depends on the tlv types which are present and also whether the hop is
|
||||
// part of a blinded route or not.
|
||||
func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32,
|
||||
amount lnwire.MilliSatoshi) uint64 {
|
||||
amount lnwire.MilliSatoshi) (uint64, error) {
|
||||
|
||||
if r.BlindedPaymentPathSet != nil {
|
||||
paymentPath := r.BlindedPaymentPathSet.
|
||||
paymentPath, err := r.BlindedPaymentPathSet.
|
||||
LargestLastHopPayloadPath()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
blindedPath := paymentPath.BlindedPath.BlindedHops
|
||||
blindedPoint := paymentPath.BlindedPath.BlindingPoint
|
||||
|
||||
|
@ -1452,7 +1459,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32,
|
|||
}
|
||||
|
||||
// The final hop does not have a short chanID set.
|
||||
return finalHop.PayloadSize(0)
|
||||
return finalHop.PayloadSize(0), nil
|
||||
}
|
||||
|
||||
var mpp *record.MPP
|
||||
|
@ -1478,7 +1485,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32,
|
|||
}
|
||||
|
||||
// The final hop does not have a short chanID set.
|
||||
return finalHop.PayloadSize(0)
|
||||
return finalHop.PayloadSize(0), nil
|
||||
}
|
||||
|
||||
// overflowSafeAdd adds two MilliSatoshi values and returns the result. If an
|
||||
|
|
|
@ -3416,32 +3416,48 @@ func TestLastHopPayloadSize(t *testing.T) {
|
|||
customRecords = map[uint64][]byte{
|
||||
record.CustomTypeStart: {1, 2, 3},
|
||||
}
|
||||
sizeEncryptedData = 100
|
||||
encrypedData = bytes.Repeat(
|
||||
[]byte{1}, sizeEncryptedData,
|
||||
|
||||
encrypedDataSmall = bytes.Repeat(
|
||||
[]byte{1}, 5,
|
||||
)
|
||||
_, blindedPoint = btcec.PrivKeyFromBytes([]byte{5})
|
||||
paymentAddr = &[32]byte{1}
|
||||
ampOptions = &Options{}
|
||||
amtToForward = lnwire.MilliSatoshi(10000)
|
||||
finalHopExpiry int32 = 144
|
||||
encrypedDataLarge = bytes.Repeat(
|
||||
[]byte{1}, 100,
|
||||
)
|
||||
_, blindedPoint = btcec.PrivKeyFromBytes([]byte{5})
|
||||
paymentAddr = &[32]byte{1}
|
||||
ampOptions = &Options{}
|
||||
amtToForward = lnwire.MilliSatoshi(10000)
|
||||
emptyEncryptedData = []byte{}
|
||||
finalHopExpiry int32 = 144
|
||||
|
||||
oneHopPath = &sphinx.BlindedPath{
|
||||
BlindedHops: []*sphinx.BlindedHopInfo{
|
||||
{
|
||||
CipherText: encrypedData,
|
||||
CipherText: emptyEncryptedData,
|
||||
},
|
||||
},
|
||||
BlindingPoint: blindedPoint,
|
||||
}
|
||||
|
||||
twoHopPath = &sphinx.BlindedPath{
|
||||
twoHopPathSmallHopSize = &sphinx.BlindedPath{
|
||||
BlindedHops: []*sphinx.BlindedHopInfo{
|
||||
{
|
||||
CipherText: encrypedData,
|
||||
CipherText: encrypedDataLarge,
|
||||
},
|
||||
{
|
||||
CipherText: encrypedData,
|
||||
CipherText: encrypedDataLarge,
|
||||
},
|
||||
},
|
||||
BlindingPoint: blindedPoint,
|
||||
}
|
||||
|
||||
twoHopPathLargeHopSize = &sphinx.BlindedPath{
|
||||
BlindedHops: []*sphinx.BlindedHopInfo{
|
||||
{
|
||||
CipherText: encrypedDataSmall,
|
||||
},
|
||||
{
|
||||
CipherText: encrypedDataSmall,
|
||||
},
|
||||
},
|
||||
BlindingPoint: blindedPoint,
|
||||
|
@ -3454,15 +3470,19 @@ func TestLastHopPayloadSize(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
twoHopBlindedPayment, err := NewBlindedPaymentPathSet(
|
||||
[]*BlindedPayment{{BlindedPath: twoHopPath}},
|
||||
[]*BlindedPayment{
|
||||
{BlindedPath: twoHopPathLargeHopSize},
|
||||
{BlindedPath: twoHopPathSmallHopSize},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
restrictions *RestrictParams
|
||||
finalHopExpiry int32
|
||||
amount lnwire.MilliSatoshi
|
||||
name string
|
||||
restrictions *RestrictParams
|
||||
finalHopExpiry int32
|
||||
amount lnwire.MilliSatoshi
|
||||
expectedEncryptedData []byte
|
||||
}{
|
||||
{
|
||||
name: "Non blinded final hop",
|
||||
|
@ -3480,16 +3500,18 @@ func TestLastHopPayloadSize(t *testing.T) {
|
|||
restrictions: &RestrictParams{
|
||||
BlindedPaymentPathSet: oneHopBlindedPayment,
|
||||
},
|
||||
amount: amtToForward,
|
||||
finalHopExpiry: finalHopExpiry,
|
||||
amount: amtToForward,
|
||||
finalHopExpiry: finalHopExpiry,
|
||||
expectedEncryptedData: emptyEncryptedData,
|
||||
},
|
||||
{
|
||||
name: "Blinded final hop of a two hop payment",
|
||||
restrictions: &RestrictParams{
|
||||
BlindedPaymentPathSet: twoHopBlindedPayment,
|
||||
},
|
||||
amount: amtToForward,
|
||||
finalHopExpiry: finalHopExpiry,
|
||||
amount: amtToForward,
|
||||
finalHopExpiry: finalHopExpiry,
|
||||
expectedEncryptedData: encrypedDataLarge,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -3513,16 +3535,23 @@ func TestLastHopPayloadSize(t *testing.T) {
|
|||
|
||||
var finalHop route.Hop
|
||||
if tc.restrictions.BlindedPaymentPathSet != nil {
|
||||
path := tc.restrictions.BlindedPaymentPathSet.
|
||||
LargestLastHopPayloadPath()
|
||||
bPSet := tc.restrictions.BlindedPaymentPathSet
|
||||
path, err := bPSet.LargestLastHopPayloadPath()
|
||||
require.NotNil(t, path)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
blindedPath := path.BlindedPath.BlindedHops
|
||||
blindedPoint := path.BlindedPath.BlindingPoint
|
||||
lastHop := blindedPath[len(blindedPath)-1]
|
||||
require.Equal(t, lastHop.CipherText,
|
||||
tc.expectedEncryptedData)
|
||||
|
||||
//nolint:lll
|
||||
finalHop = route.Hop{
|
||||
AmtToForward: tc.amount,
|
||||
OutgoingTimeLock: uint32(tc.finalHopExpiry),
|
||||
EncryptedData: blindedPath[len(blindedPath)-1].CipherText,
|
||||
EncryptedData: lastHop.CipherText,
|
||||
}
|
||||
if len(blindedPath) == 1 {
|
||||
finalHop.BlindingPoint = blindedPoint
|
||||
|
@ -3542,11 +3571,11 @@ func TestLastHopPayloadSize(t *testing.T) {
|
|||
payLoad, err := createHopPayload(finalHop, 0, true)
|
||||
require.NoErrorf(t, err, "failed to create hop payload")
|
||||
|
||||
expectedPayloadSize := lastHopPayloadSize(
|
||||
expectedPayloadSize, err := lastHopPayloadSize(
|
||||
tc.restrictions, tc.finalHopExpiry,
|
||||
tc.amount,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(
|
||||
t, expectedPayloadSize,
|
||||
uint64(payLoad.NumBytes()),
|
||||
|
|
Loading…
Add table
Reference in a new issue