common/amount: add routine to calculate fees backwards.

If I put in X, how much can I get out after fees are subtracted?

This was inspired by Eduardo's channel_maximum_forward in renepay, which
is basically the same thing.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2024-09-18 17:01:27 +09:30
parent db74ca7cbe
commit 379a4ee16a
3 changed files with 92 additions and 3 deletions

View file

@ -524,7 +524,7 @@ struct amount_sat amount_sat_div(struct amount_sat sat, u64 div)
bool amount_sat_mul(struct amount_sat *res, struct amount_sat sat, u64 mul)
{
if ( mul_overflows_u64(sat.satoshis, mul))
if (mul_overflows_u64(sat.satoshis, mul))
return false;
res->satoshis = sat.satoshis * mul;
return true;
@ -532,7 +532,7 @@ bool amount_sat_mul(struct amount_sat *res, struct amount_sat sat, u64 mul)
bool amount_msat_mul(struct amount_msat *res, struct amount_msat msat, u64 mul)
{
if ( mul_overflows_u64(msat.millisatoshis, mul))
if (mul_overflows_u64(msat.millisatoshis, mul))
return false;
res->millisatoshis = msat.millisatoshis * mul;
return true;
@ -560,6 +560,50 @@ bool amount_msat_fee(struct amount_msat *fee,
return amount_msat_add(fee, fee_base, fee_prop);
}
/* Does this input give enough to provide fee for output? */
static bool within_fee(struct amount_msat in,
struct amount_msat out,
u32 fee_base_msat,
u32 fee_proportional_millionths)
{
struct amount_msat with_fee = out;
if (!amount_msat_add_fee(&with_fee,
fee_base_msat,
fee_proportional_millionths))
return false;
return amount_msat_less_eq(with_fee, in);
}
struct amount_msat amount_msat_sub_fee(struct amount_msat in,
u32 fee_base_msat,
u32 fee_proportional_millionths)
{
struct amount_msat out, out_plus_one;
/* out = in - base - (out * prop / 1000000)
* Thus: out * (1 + prop / 1000000) = in - base
* out = (in - base) / (1 + prop / 1000000)
* out = 1000000 * (in - base) / (1000000 + prop)
*
* Since we round the fee down, out can be a bit bigger than
* expected, so we iterate upwards.
*/
if (!amount_msat_sub(&out, in, amount_msat(fee_base_msat)))
return AMOUNT_MSAT(0);
if (!amount_msat_mul(&out, out, 1000000))
return AMOUNT_MSAT(0);
out = amount_msat_div(out, 1000000ULL + fee_proportional_millionths);
/* If we calc reverse, it must work! */
assert(within_fee(in, out, fee_base_msat, fee_proportional_millionths));
/* We can be out-by-one */
if (amount_msat_add(&out_plus_one, out, AMOUNT_MSAT(1))
&& within_fee(in, out_plus_one, fee_base_msat, fee_proportional_millionths))
return out_plus_one;
return out;
}
bool amount_msat_add_fee(struct amount_msat *amt,
u32 fee_base_msat,
u32 fee_proportional_millionths)

View file

@ -188,6 +188,11 @@ WARN_UNUSED_RESULT bool amount_msat_add_fee(struct amount_msat *amt,
u32 fee_base_msat,
u32 fee_proportional_millionths);
/* Reversed: what is the largest possible output for a given input and fee? */
struct amount_msat amount_msat_sub_fee(struct amount_msat input,
u32 fee_base_msat,
u32 fee_proportional_millionths);
/* What is the fee for this tx weight? */
struct amount_sat amount_tx_fee(u32 fee_per_kw, size_t weight);

View file

@ -63,7 +63,46 @@ void towire_u8_array(u8 **pptr UNNEEDED, const u8 *arr UNNEEDED, size_t num UNNE
{ fprintf(stderr, "towire_u8_array called!\n"); abort(); }
/* AUTOGENERATED MOCKS END */
#define FAIL_MSAT(msatp, str) \
/* Note u32 truncation tests 0 values! */
static void test_amount_sub_fee(struct amount_msat msat,
u32 base, u32 prop)
{
struct amount_msat in, in2, out;
/* If we get msat out, how much do we put in? */
in = msat;
assert(amount_msat_add_fee(&in, base, prop));
/* Fee only increases amount */
assert(amount_msat_greater_eq(in, msat));
/* If we put that much in, how much do we get out? */
out = amount_msat_sub_fee(in, base, prop);
assert(amount_msat_eq(out, msat));
/* If we asked for one more out, we'd have to put more in */
assert(amount_msat_add(&in2, out, AMOUNT_MSAT(1)));
assert(amount_msat_add_fee(&in2, base, prop));
assert(amount_msat_greater(in2, in));
}
static void test_amount_with_fee(void)
{
for (int basebits = 0; basebits < 33; basebits++) {
u32 base = (1ULL << basebits);
for (int propbits = 0; propbits < 20; propbits++) {
u32 prop = (1ULL << propbits);
for (int amtbits1 = 0; amtbits1 < 42; amtbits1++) {
for (int amtbits2 = 0; amtbits2 < 42; amtbits2++) {
test_amount_sub_fee(amount_msat((1ULL << amtbits1) | (1ULL << amtbits2)), base, prop);
}
}
}
}
}
#define FAIL_MSAT(msatp, str) \
assert(!parse_amount_msat((msatp), (str), strlen(str)))
#define PASS_MSAT(msatp, str, val) \
do { \
@ -229,5 +268,6 @@ int main(int argc, char *argv[])
assert(sat.satoshis == i);
}
test_amount_with_fee();
common_shutdown();
}