From 4457ca2e66ec0f246fb5e75632cbf3292aa099ab Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 1 Jul 2024 15:27:41 +0200 Subject: [PATCH] record: stricter type for PaymentRelayInfo.BaseFee In this commit, we update the PaymentRelayInfo struct's BaseFee member to use a stricter type (lnwire.MilliSatoshi) instead of an ambigious uint32. --- htlcswitch/hop/iterator.go | 4 ++-- htlcswitch/hop/iterator_test.go | 2 +- itest/lnd_route_blinding_test.go | 4 +++- record/blinded_data.go | 20 +++++++++++++++----- record/blinded_data_test.go | 2 +- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index ddfbe5934..5ef870874 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -465,11 +465,11 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, // ceil(a/b) = (a + b - 1)/(b). // //nolint:lll,dupword -func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee, +func calculateForwardingAmount(incomingAmount, baseFee lnwire.MilliSatoshi, proportionalFee uint32) (lnwire.MilliSatoshi, error) { // Sanity check to prevent overflow. - if incomingAmount < lnwire.MilliSatoshi(baseFee) { + if incomingAmount < baseFee { return 0, fmt.Errorf("incoming amount: %v < base fee: %v", incomingAmount, baseFee) } diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index 9995c71ba..d99f1de24 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -111,7 +111,7 @@ func TestForwardingAmountCalc(t *testing.T) { tests := []struct { name string incomingAmount lnwire.MilliSatoshi - baseFee uint32 + baseFee lnwire.MilliSatoshi proportional uint32 forwardAmount lnwire.MilliSatoshi expectErr bool diff --git a/itest/lnd_route_blinding_test.go b/itest/lnd_route_blinding_test.go index a2899657b..7cddff70b 100644 --- a/itest/lnd_route_blinding_test.go +++ b/itest/lnd_route_blinding_test.go @@ -658,7 +658,9 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge, // Set the relay information for this edge based on its policy. delta := uint16(node.edge.TimeLockDelta) relayInfo := &record.PaymentRelayInfo{ - BaseFee: uint32(node.edge.FeeBaseMsat), + BaseFee: lnwire.MilliSatoshi( + node.edge.FeeBaseMsat, + ), FeeRate: uint32(node.edge.FeeRateMilliMsat), CltvExpiryDelta: delta, } diff --git a/record/blinded_data.go b/record/blinded_data.go index 2c3204d20..a62db00a7 100644 --- a/record/blinded_data.go +++ b/record/blinded_data.go @@ -268,8 +268,8 @@ type PaymentRelayInfo struct { // satoshi. FeeRate uint32 - // BaseFee is the per-htlc fee charged. - BaseFee uint32 + // BaseFee is the per-htlc fee charged in milli-satoshis. + BaseFee lnwire.MilliSatoshi } // Record creates a tlv.Record that encodes the payment relay (type 10) type for @@ -278,7 +278,7 @@ func (i *PaymentRelayInfo) Record() tlv.Record { return tlv.MakeDynamicRecord( 10, &i, func() uint64 { // uint16 + uint32 + tuint32 - return 2 + 4 + tlv.SizeTUint32(i.BaseFee) + return 2 + 4 + tlv.SizeTUint32(uint32(i.BaseFee)) }, encodePaymentRelay, decodePaymentRelay, ) } @@ -294,9 +294,11 @@ func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error { return err } + baseFee := uint32(relayInfo.BaseFee) + // We can safely reuse buf here because we overwrite its // contents. - return tlv.ETUint32(w, &relayInfo.BaseFee, buf) + return tlv.ETUint32(w, &baseFee, buf) } return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo") @@ -333,7 +335,15 @@ func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte, // is okay. b := bytes.NewBuffer(scratch[6:]) - return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6) + var baseFee uint32 + err = tlv.DTUint32(b, &baseFee, buf, l-6) + if err != nil { + return err + } + + relayInfo.BaseFee = lnwire.MilliSatoshi(baseFee) + + return nil } return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10) diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go index 6de97e497..604d5a7fb 100644 --- a/record/blinded_data_test.go +++ b/record/blinded_data_test.go @@ -37,7 +37,7 @@ func TestBlindedDataEncoding(t *testing.T) { tests := []struct { name string - baseFee uint32 + baseFee lnwire.MilliSatoshi htlcMin lnwire.MilliSatoshi features *lnwire.FeatureVector constraints bool