From 496c0dd1e674085d375df5dda8e48e8f1b9d874c Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Tue, 11 Aug 2020 14:35:56 +0930 Subject: [PATCH] common/random_select: central place for reservoir sampling. Turns out we can make quite a simple API out of it. Signed-off-by: Rusty Russell --- common/Makefile | 1 + common/random_select.c | 11 ++++++ common/random_select.h | 20 +++++++++++ gossipd/Makefile | 1 + gossipd/seeker.c | 8 ++--- gossipd/test/run-next_block_range.c | 3 ++ lightningd/Makefile | 1 + lightningd/invoice.c | 52 ++++++++++------------------- lightningd/test/Makefile | 1 + plugins/Makefile | 1 + plugins/libplugin-pay.c | 36 ++++++++++---------- 11 files changed, 76 insertions(+), 59 deletions(-) create mode 100644 common/random_select.c create mode 100644 common/random_select.h diff --git a/common/Makefile b/common/Makefile index 75fbb44b2..f3a729cc4 100644 --- a/common/Makefile +++ b/common/Makefile @@ -57,6 +57,7 @@ COMMON_SRC_NOGEN := \ common/ping.c \ common/psbt_open.c \ common/pseudorand.c \ + common/random_select.c \ common/read_peer_msg.c \ common/setup.c \ common/socket_close.c \ diff --git a/common/random_select.c b/common/random_select.c new file mode 100644 index 000000000..1acaeb005 --- /dev/null +++ b/common/random_select.c @@ -0,0 +1,11 @@ +#include <common/pseudorand.h> +#include <common/random_select.h> + +bool random_select(double weight, double *tot_weight) +{ + *tot_weight += weight; + if (weight == 0) + return false; + + return pseudorand_double() <= weight / *tot_weight; +} diff --git a/common/random_select.h b/common/random_select.h new file mode 100644 index 000000000..02bfbba2f --- /dev/null +++ b/common/random_select.h @@ -0,0 +1,20 @@ +#ifndef LIGHTNING_COMMON_RANDOM_SELECT_H +#define LIGHTNING_COMMON_RANDOM_SELECT_H +#include "config.h" +#include <stdbool.h> + +/* Use weighted reservoir sampling, see: + * https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao + * But (currently) the result will consist of only one sample (k=1) + */ + +/** + * random_select: return true 
if we should select this one. + * @weight: weight for this option (use 1.0 if all the same) + * @tot_weight: returns with sum of weights (must be initialized to zero) + * + * This always returns true on the first non-zero weight, and weighted + * randomly from then on. + */ +bool random_select(double weight, double *tot_weight); +#endif /* LIGHTNING_COMMON_RANDOM_SELECT_H */ diff --git a/gossipd/Makefile b/gossipd/Makefile index 5a18f1d56..2b328b2e8 100644 --- a/gossipd/Makefile +++ b/gossipd/Makefile @@ -65,6 +65,7 @@ GOSSIPD_COMMON_OBJS := \ common/per_peer_state.o \ common/ping.o \ common/pseudorand.o \ + common/random_select.o \ common/setup.o \ common/status.o \ common/status_wire.o \ diff --git a/gossipd/seeker.c b/gossipd/seeker.c index bf9be2488..6428ad029 100644 --- a/gossipd/seeker.c +++ b/gossipd/seeker.c @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -454,7 +455,7 @@ static bool get_unannounced_nodes(const tal_t *ctx, { size_t num = 0; u64 offset; - u64 threshold = pseudorand_u64(); + double total_weight = 0.0; /* Pick an example short_channel_id at random to query. As a * side-effect this gets the node. */ @@ -475,11 +476,8 @@ static bool get_unannounced_nodes(const tal_t *ctx, (*scids)[num++] = c->scid; } else { /* Maybe replace one: approx. 
reservoir sampling */ - u64 p = pseudorand_u64(); - if (p > threshold) { + if (random_select(1.0, &total_weight)) (*scids)[pseudorand(max)] = c->scid; - threshold = p; - } } } diff --git a/gossipd/test/run-next_block_range.c b/gossipd/test/run-next_block_range.c index 0395c292c..fb7e46b9f 100644 --- a/gossipd/test/run-next_block_range.c +++ b/gossipd/test/run-next_block_range.c @@ -63,6 +63,9 @@ void queue_peer_msg(struct peer *peer UNNEEDED, const u8 *msg TAKES UNNEEDED) struct peer *random_peer(struct daemon *daemon UNNEEDED, bool (*check_peer)(const struct peer *peer)) { fprintf(stderr, "random_peer called!\n"); abort(); } +/* Generated stub for random_select */ +bool random_select(double weight UNNEEDED, double *tot_weight UNNEEDED) +{ fprintf(stderr, "random_select called!\n"); abort(); } /* Generated stub for status_failed */ void status_failed(enum status_failreason code UNNEEDED, const char *fmt UNNEEDED, ...) diff --git a/lightningd/Makefile b/lightningd/Makefile index 40aae47ee..644e34899 100644 --- a/lightningd/Makefile +++ b/lightningd/Makefile @@ -60,6 +60,7 @@ LIGHTNINGD_COMMON_OBJS := \ common/per_peer_state.o \ common/permute_tx.o \ common/pseudorand.o \ + common/random_select.o \ common/setup.o \ common/sphinx.o \ common/status_wire.o \ diff --git a/lightningd/invoice.c b/lightningd/invoice.c index 1528da75b..e3c884994 100644 --- a/lightningd/invoice.c +++ b/lightningd/invoice.c @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include #include #include @@ -489,15 +489,8 @@ static struct route_info **select_inchan(const tal_t *ctx, bool *any_offline) { /* BOLT11 struct wants an array of arrays (can provide multiple routes) */ - struct route_info **R; - double wsum, p; - - struct sample { - const struct route_info *route; - double weight; - }; - - struct sample *S = tal_arr(tmpctx, struct sample, 0); + struct route_info **r = NULL; + double total_weight = 0.0; *any_offline = false; @@ -505,7 +498,6 @@ static struct route_info 
**select_inchan(const tal_t *ctx, for (size_t i = 0; i < tal_count(inchans); i++) { struct peer *peer; struct channel *c; - struct sample sample; struct amount_msat capacity_to_pay_us, excess, capacity; struct amount_sat cumulative_reserve; double excess_frac; @@ -564,33 +556,23 @@ static struct route_info **select_inchan(const tal_t *ctx, continue; } + /* We don't want a 0 probability if 0 excess; it might be the + * only one! So bump it by 1 msat */ + if (!amount_msat_add(&excess, excess, AMOUNT_MSAT(1))) { + log_broken(ld->log, "Channel %s excess overflow!", + type_to_string(tmpctx, struct short_channel_id, c->scid)); + continue; + } excess_frac = amount_msat_ratio(excess, capacity); - sample.route = &inchans[i]; - sample.weight = excess_frac; - tal_arr_expand(&S, sample); + if (random_select(excess_frac, &total_weight)) { + tal_free(r); + r = tal_arr(ctx, struct route_info *, 1); + r[0] = tal_dup(r, struct route_info, &inchans[i]); + } } - if (!tal_count(S)) - return NULL; - - /* Use weighted reservoir sampling, see: - * https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao - * But (currently) the result will consist of only one sample (k=1) */ - R = tal_arr(ctx, struct route_info *, 1); - R[0] = tal_dup(R, struct route_info, S[0].route); - wsum = S[0].weight; - - for (size_t i = 1; i < tal_count(S); i++) { - wsum += S[i].weight; - p = S[i].weight / wsum; - double random_1 = pseudorand_double(); /* range [0,1) */ - - if (random_1 <= p) - R[0] = tal_dup(R, struct route_info, S[i].route); - } - - return R; + return r; } /** select_inchan_mpp @@ -1414,6 +1396,7 @@ static struct command_result *json_waitanyinvoice(struct command *cmd, " is non-trivial."); } + static const struct json_command waitanyinvoice_command = { "waitanyinvoice", "payment", @@ -1423,7 +1406,6 @@ static const struct json_command waitanyinvoice_command = { }; AUTODATA(json_command, &waitanyinvoice_command); - /* Wait for an incoming payment matching the `label` in the JSON * 
command. This will either return immediately if the payment has * already been received or it may add the `cmd` to the list of diff --git a/lightningd/test/Makefile b/lightningd/test/Makefile index bc415583d..03fbd90e6 100644 --- a/lightningd/test/Makefile +++ b/lightningd/test/Makefile @@ -16,6 +16,7 @@ LIGHTNINGD_TEST_COMMON_OBJS := \ common/json.o \ common/key_derive.o \ common/pseudorand.o \ + common/random_select.o \ common/memleak.o \ common/msg_queue.o \ common/utils.o \ diff --git a/plugins/Makefile b/plugins/Makefile index 562e61eb8..e89d6ee02 100644 --- a/plugins/Makefile +++ b/plugins/Makefile @@ -70,6 +70,7 @@ PLUGIN_COMMON_OBJS := \ common/node_id.o \ common/param.o \ common/pseudorand.o \ + common/random_select.o \ common/setup.o \ common/type_to_string.o \ common/utils.o \ diff --git a/plugins/libplugin-pay.c b/plugins/libplugin-pay.c index efadc1a2f..30035c6f7 100644 --- a/plugins/libplugin-pay.c +++ b/plugins/libplugin-pay.c @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -2421,12 +2422,11 @@ static struct command_result *shadow_route_listchannels(struct command *cmd, const jsmntok_t *result, struct payment *p) { - /* Use reservoir sampling across the capable channels. 
*/ struct shadow_route_data *d = payment_mod_shadowroute_get_data(p); struct payment_constraints *cons = &d->constraints; struct route_info *best = NULL; + double total_weight = 0.0; size_t i; - u64 sample = 0; struct amount_msat best_fee; const jsmntok_t *sattok, *delaytok, *basefeetok, *propfeetok, *desttok, *channelstok, *chan, *scidtok; @@ -2438,7 +2438,6 @@ static struct command_result *shadow_route_listchannels(struct command *cmd, channelstok = json_get_member(buf, result, "channels"); json_for_each_arr(i, chan, channelstok) { - u64 v = pseudorand(UINT64_MAX); struct route_info curr; struct amount_sat capacity; struct amount_msat fee; @@ -2465,28 +2464,27 @@ static struct command_result *shadow_route_listchannels(struct command *cmd, json_to_sat(buf, sattok, &capacity); json_to_node_id(buf, desttok, &curr.pubkey); - if (!best || v > sample) { - /* If the capacity is insufficient to pass the amount - * it's not a plausible extension. */ - if (amount_msat_greater_sat(p->amount, capacity)) - continue; + /* If the capacity is insufficient to pass the amount + * it's not a plausible extension. */ + if (amount_msat_greater_sat(p->amount, capacity)) + continue; - if (curr.cltv_expiry_delta > cons->cltv_budget) - continue; + if (curr.cltv_expiry_delta > cons->cltv_budget) + continue; - if (!amount_msat_fee( - &fee, p->amount, curr.fee_base_msat, - curr.fee_proportional_millionths)) { - /* Fee computation failed... */ - continue; - } + if (!amount_msat_fee( + &fee, p->amount, curr.fee_base_msat, + curr.fee_proportional_millionths)) { + /* Fee computation failed... */ + continue; + } - if (amount_msat_greater_eq(fee, cons->fee_budget)) - continue; + if (amount_msat_greater_eq(fee, cons->fee_budget)) + continue; + if (random_select(1.0, &total_weight)) { best = tal_dup(tmpctx, struct route_info, &curr); best_fee = fee; - sample = v; } }