record: stricter type for PaymentRelayInfo.BaseFee

In this commit, we update the PaymentRelayInfo struct's BaseFee member
to use a stricter type (lnwire.MilliSatoshi) instead of an ambigious
uint32.
This commit is contained in:
Elle Mouton 2024-07-01 15:27:41 +02:00
parent 62a97f86dd
commit 4457ca2e66
No known key found for this signature in database
GPG key ID: D7D916376026F177
5 changed files with 22 additions and 10 deletions

View file

@ -465,11 +465,11 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
// ceil(a/b) = (a + b - 1)/(b).
//
//nolint:lll,dupword
func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee,
func calculateForwardingAmount(incomingAmount, baseFee lnwire.MilliSatoshi,
proportionalFee uint32) (lnwire.MilliSatoshi, error) {
// Sanity check to prevent overflow.
if incomingAmount < lnwire.MilliSatoshi(baseFee) {
if incomingAmount < baseFee {
return 0, fmt.Errorf("incoming amount: %v < base fee: %v",
incomingAmount, baseFee)
}

View file

@ -111,7 +111,7 @@ func TestForwardingAmountCalc(t *testing.T) {
tests := []struct {
name string
incomingAmount lnwire.MilliSatoshi
baseFee uint32
baseFee lnwire.MilliSatoshi
proportional uint32
forwardAmount lnwire.MilliSatoshi
expectErr bool

View file

@ -658,7 +658,9 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
// Set the relay information for this edge based on its policy.
delta := uint16(node.edge.TimeLockDelta)
relayInfo := &record.PaymentRelayInfo{
BaseFee: uint32(node.edge.FeeBaseMsat),
BaseFee: lnwire.MilliSatoshi(
node.edge.FeeBaseMsat,
),
FeeRate: uint32(node.edge.FeeRateMilliMsat),
CltvExpiryDelta: delta,
}

View file

@ -268,8 +268,8 @@ type PaymentRelayInfo struct {
// satoshi.
FeeRate uint32
// BaseFee is the per-htlc fee charged.
BaseFee uint32
// BaseFee is the per-htlc fee charged in milli-satoshis.
BaseFee lnwire.MilliSatoshi
}
// Record creates a tlv.Record that encodes the payment relay (type 10) type for
@ -278,7 +278,7 @@ func (i *PaymentRelayInfo) Record() tlv.Record {
return tlv.MakeDynamicRecord(
10, &i, func() uint64 {
// uint16 + uint32 + tuint32
return 2 + 4 + tlv.SizeTUint32(i.BaseFee)
return 2 + 4 + tlv.SizeTUint32(uint32(i.BaseFee))
}, encodePaymentRelay, decodePaymentRelay,
)
}
@ -294,9 +294,11 @@ func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error {
return err
}
baseFee := uint32(relayInfo.BaseFee)
// We can safely reuse buf here because we overwrite its
// contents.
return tlv.ETUint32(w, &relayInfo.BaseFee, buf)
return tlv.ETUint32(w, &baseFee, buf)
}
return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo")
@ -333,7 +335,15 @@ func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte,
// is okay.
b := bytes.NewBuffer(scratch[6:])
return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6)
var baseFee uint32
err = tlv.DTUint32(b, &baseFee, buf, l-6)
if err != nil {
return err
}
relayInfo.BaseFee = lnwire.MilliSatoshi(baseFee)
return nil
}
return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10)

View file

@ -37,7 +37,7 @@ func TestBlindedDataEncoding(t *testing.T) {
tests := []struct {
name string
baseFee uint32
baseFee lnwire.MilliSatoshi
htlcMin lnwire.MilliSatoshi
features *lnwire.FeatureVector
constraints bool