diff --git a/common/amount.c b/common/amount.c index 965f23384..03da1de2d 100644 --- a/common/amount.c +++ b/common/amount.c @@ -524,7 +524,7 @@ struct amount_sat amount_sat_div(struct amount_sat sat, u64 div) bool amount_sat_mul(struct amount_sat *res, struct amount_sat sat, u64 mul) { - if ( mul_overflows_u64(sat.satoshis, mul)) + if (mul_overflows_u64(sat.satoshis, mul)) return false; res->satoshis = sat.satoshis * mul; return true; @@ -532,7 +532,7 @@ bool amount_sat_mul(struct amount_sat *res, struct amount_sat sat, u64 mul) bool amount_msat_mul(struct amount_msat *res, struct amount_msat msat, u64 mul) { - if ( mul_overflows_u64(msat.millisatoshis, mul)) + if (mul_overflows_u64(msat.millisatoshis, mul)) return false; res->millisatoshis = msat.millisatoshis * mul; return true; @@ -560,6 +560,50 @@ bool amount_msat_fee(struct amount_msat *fee, return amount_msat_add(fee, fee_base, fee_prop); } +/* Does this input give enough to provide fee for output? */ +static bool within_fee(struct amount_msat in, + struct amount_msat out, + u32 fee_base_msat, + u32 fee_proportional_millionths) +{ + struct amount_msat with_fee = out; + if (!amount_msat_add_fee(&with_fee, + fee_base_msat, + fee_proportional_millionths)) + return false; + return amount_msat_less_eq(with_fee, in); +} + +struct amount_msat amount_msat_sub_fee(struct amount_msat in, + u32 fee_base_msat, + u32 fee_proportional_millionths) +{ + struct amount_msat out, out_plus_one; + + /* out = in - base - (out * prop / 1000000) + * Thus: out * (1 + prop / 1000000) = in - base + * out = (in - base) / (1 + prop / 1000000) + * out = 1000000 * (in - base) / (1000000 + prop) + * + * Since we round the fee down, out can be a bit bigger than + * expected, so we iterate upwards. + */ + if (!amount_msat_sub(&out, in, amount_msat(fee_base_msat))) + return AMOUNT_MSAT(0); + if (!amount_msat_mul(&out, out, 1000000)) + return AMOUNT_MSAT(0); + out = amount_msat_div(out, 1000000ULL + fee_proportional_millionths); + + /* If we calc reverse, it must work! */ + assert(within_fee(in, out, fee_base_msat, fee_proportional_millionths)); + + /* We can be out-by-one */ + if (amount_msat_add(&out_plus_one, out, AMOUNT_MSAT(1)) + && within_fee(in, out_plus_one, fee_base_msat, fee_proportional_millionths)) + return out_plus_one; + return out; +} + bool amount_msat_add_fee(struct amount_msat *amt, u32 fee_base_msat, u32 fee_proportional_millionths) diff --git a/common/amount.h b/common/amount.h index e94e1d2d8..ded89bbcf 100644 --- a/common/amount.h +++ b/common/amount.h @@ -188,6 +188,11 @@ WARN_UNUSED_RESULT bool amount_msat_add_fee(struct amount_msat *amt, u32 fee_base_msat, u32 fee_proportional_millionths); +/* Reversed: what is the largest possible output for a given input and fee? */ +struct amount_msat amount_msat_sub_fee(struct amount_msat input, + u32 fee_base_msat, + u32 fee_proportional_millionths); + /* What is the fee for this tx weight? */ struct amount_sat amount_tx_fee(u32 fee_per_kw, size_t weight); diff --git a/common/test/run-amount.c b/common/test/run-amount.c index 923f0a9b7..7b408959d 100644 --- a/common/test/run-amount.c +++ b/common/test/run-amount.c @@ -63,7 +63,46 @@ void towire_u8_array(u8 **pptr UNNEEDED, const u8 *arr UNNEEDED, size_t num UNNE { fprintf(stderr, "towire_u8_array called!\n"); abort(); } /* AUTOGENERATED MOCKS END */ -#define FAIL_MSAT(msatp, str) \ +/* Note u32 truncation tests 0 values! */ +static void test_amount_sub_fee(struct amount_msat msat, + u32 base, u32 prop) +{ + struct amount_msat in, in2, out; + + /* If we get msat out, how much do we put in? */ + in = msat; + assert(amount_msat_add_fee(&in, base, prop)); + /* Fee only increases amount */ + assert(amount_msat_greater_eq(in, msat)); + + /* If we put that much in, how much do we get out? */ + out = amount_msat_sub_fee(in, base, prop); + assert(amount_msat_eq(out, msat)); + + /* If we asked for one more out, we'd have to put more in */ + assert(amount_msat_add(&in2, out, AMOUNT_MSAT(1))); + assert(amount_msat_add_fee(&in2, base, prop)); + assert(amount_msat_greater(in2, in)); +} + +static void test_amount_with_fee(void) +{ + for (int basebits = 0; basebits < 33; basebits++) { + u32 base = (1ULL << basebits); + + for (int propbits = 0; propbits < 20; propbits++) { + u32 prop = (1ULL << propbits); + + for (int amtbits1 = 0; amtbits1 < 42; amtbits1++) { + for (int amtbits2 = 0; amtbits2 < 42; amtbits2++) { + test_amount_sub_fee(amount_msat((1ULL << amtbits1) | (1ULL << amtbits2)), base, prop); + } + } + } + } +} + +#define FAIL_MSAT(msatp, str) \ assert(!parse_amount_msat((msatp), (str), strlen(str))) #define PASS_MSAT(msatp, str, val) \ do { \ @@ -229,5 +268,6 @@ int main(int argc, char *argv[]) assert(sat.satoshis == i); } + test_amount_with_fee(); common_shutdown(); }