record: stricter type for PaymentRelayInfo.BaseFee

In this commit, we update the PaymentRelayInfo struct's BaseFee member
to use a stricter type (lnwire.MilliSatoshi) instead of an ambigious
uint32.
This commit is contained in:
Elle Mouton 2024-07-01 15:27:41 +02:00
parent 62a97f86dd
commit 4457ca2e66
No known key found for this signature in database
GPG key ID: D7D916376026F177
5 changed files with 22 additions and 10 deletions

View file

@ -465,11 +465,11 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
// ceil(a/b) = (a + b - 1)/(b). // ceil(a/b) = (a + b - 1)/(b).
// //
//nolint:lll,dupword //nolint:lll,dupword
func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee, func calculateForwardingAmount(incomingAmount, baseFee lnwire.MilliSatoshi,
proportionalFee uint32) (lnwire.MilliSatoshi, error) { proportionalFee uint32) (lnwire.MilliSatoshi, error) {
// Sanity check to prevent overflow. // Sanity check to prevent overflow.
if incomingAmount < lnwire.MilliSatoshi(baseFee) { if incomingAmount < baseFee {
return 0, fmt.Errorf("incoming amount: %v < base fee: %v", return 0, fmt.Errorf("incoming amount: %v < base fee: %v",
incomingAmount, baseFee) incomingAmount, baseFee)
} }

View file

@ -111,7 +111,7 @@ func TestForwardingAmountCalc(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
incomingAmount lnwire.MilliSatoshi incomingAmount lnwire.MilliSatoshi
baseFee uint32 baseFee lnwire.MilliSatoshi
proportional uint32 proportional uint32
forwardAmount lnwire.MilliSatoshi forwardAmount lnwire.MilliSatoshi
expectErr bool expectErr bool

View file

@ -658,7 +658,9 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
// Set the relay information for this edge based on its policy. // Set the relay information for this edge based on its policy.
delta := uint16(node.edge.TimeLockDelta) delta := uint16(node.edge.TimeLockDelta)
relayInfo := &record.PaymentRelayInfo{ relayInfo := &record.PaymentRelayInfo{
BaseFee: uint32(node.edge.FeeBaseMsat), BaseFee: lnwire.MilliSatoshi(
node.edge.FeeBaseMsat,
),
FeeRate: uint32(node.edge.FeeRateMilliMsat), FeeRate: uint32(node.edge.FeeRateMilliMsat),
CltvExpiryDelta: delta, CltvExpiryDelta: delta,
} }

View file

@ -268,8 +268,8 @@ type PaymentRelayInfo struct {
// satoshi. // satoshi.
FeeRate uint32 FeeRate uint32
// BaseFee is the per-htlc fee charged. // BaseFee is the per-htlc fee charged in milli-satoshis.
BaseFee uint32 BaseFee lnwire.MilliSatoshi
} }
// Record creates a tlv.Record that encodes the payment relay (type 10) type for // Record creates a tlv.Record that encodes the payment relay (type 10) type for
@ -278,7 +278,7 @@ func (i *PaymentRelayInfo) Record() tlv.Record {
return tlv.MakeDynamicRecord( return tlv.MakeDynamicRecord(
10, &i, func() uint64 { 10, &i, func() uint64 {
// uint16 + uint32 + tuint32 // uint16 + uint32 + tuint32
return 2 + 4 + tlv.SizeTUint32(i.BaseFee) return 2 + 4 + tlv.SizeTUint32(uint32(i.BaseFee))
}, encodePaymentRelay, decodePaymentRelay, }, encodePaymentRelay, decodePaymentRelay,
) )
} }
@ -294,9 +294,11 @@ func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error {
return err return err
} }
baseFee := uint32(relayInfo.BaseFee)
// We can safely reuse buf here because we overwrite its // We can safely reuse buf here because we overwrite its
// contents. // contents.
return tlv.ETUint32(w, &relayInfo.BaseFee, buf) return tlv.ETUint32(w, &baseFee, buf)
} }
return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo") return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo")
@ -333,7 +335,15 @@ func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte,
// is okay. // is okay.
b := bytes.NewBuffer(scratch[6:]) b := bytes.NewBuffer(scratch[6:])
return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6) var baseFee uint32
err = tlv.DTUint32(b, &baseFee, buf, l-6)
if err != nil {
return err
}
relayInfo.BaseFee = lnwire.MilliSatoshi(baseFee)
return nil
} }
return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10) return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10)

View file

@ -37,7 +37,7 @@ func TestBlindedDataEncoding(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
baseFee uint32 baseFee lnwire.MilliSatoshi
htlcMin lnwire.MilliSatoshi htlcMin lnwire.MilliSatoshi
features *lnwire.FeatureVector features *lnwire.FeatureVector
constraints bool constraints bool