mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-02-20 13:54:36 +01:00
common/random_select: central place for reservoir sampling.
Turns out we can make quite a simple API out of it. Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
parent
12d0d5c185
commit
496c0dd1e6
11 changed files with 76 additions and 59 deletions
|
@ -57,6 +57,7 @@ COMMON_SRC_NOGEN := \
|
|||
common/ping.c \
|
||||
common/psbt_open.c \
|
||||
common/pseudorand.c \
|
||||
common/random_select.c \
|
||||
common/read_peer_msg.c \
|
||||
common/setup.c \
|
||||
common/socket_close.c \
|
||||
|
|
11
common/random_select.c
Normal file
11
common/random_select.c
Normal file
|
@ -0,0 +1,11 @@
|
|||
#include <common/pseudorand.h>
|
||||
#include <common/random_select.h>
|
||||
|
||||
bool random_select(double weight, double *tot_weight)
|
||||
{
|
||||
*tot_weight += weight;
|
||||
if (weight == 0)
|
||||
return false;
|
||||
|
||||
return pseudorand_double() <= weight / *tot_weight;
|
||||
}
|
20
common/random_select.h
Normal file
20
common/random_select.h
Normal file
|
@ -0,0 +1,20 @@
|
|||
#ifndef LIGHTNING_COMMON_RANDOM_SELECT_H
|
||||
#define LIGHTNING_COMMON_RANDOM_SELECT_H
|
||||
#include "config.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Use weighted reservoir sampling, see:
|
||||
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
|
||||
* But (currently) the result will consist of only one sample (k=1)
|
||||
*/
|
||||
|
||||
/**
|
||||
* random_select: return true if we should select this one.
|
||||
* @weight: weight for this option (use 1.0 if all the same)
|
||||
* @tot_wieght: returns with sum of weights (must be initialized to zero)
|
||||
*
|
||||
* This always returns true on the first non-zero weight, and weighted
|
||||
* randomly from then on.
|
||||
*/
|
||||
bool random_select(double weight, double *tot_weight);
|
||||
#endif /* LIGHTNING_COMMON_RANDOM_SELECT_H */
|
|
@ -65,6 +65,7 @@ GOSSIPD_COMMON_OBJS := \
|
|||
common/per_peer_state.o \
|
||||
common/ping.o \
|
||||
common/pseudorand.o \
|
||||
common/random_select.o \
|
||||
common/setup.o \
|
||||
common/status.o \
|
||||
common/status_wire.o \
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include <ccan/tal/tal.h>
|
||||
#include <common/decode_array.h>
|
||||
#include <common/pseudorand.h>
|
||||
#include <common/random_select.h>
|
||||
#include <common/status.h>
|
||||
#include <common/timeout.h>
|
||||
#include <common/type_to_string.h>
|
||||
|
@ -454,7 +455,7 @@ static bool get_unannounced_nodes(const tal_t *ctx,
|
|||
{
|
||||
size_t num = 0;
|
||||
u64 offset;
|
||||
u64 threshold = pseudorand_u64();
|
||||
double total_weight = 0.0;
|
||||
|
||||
/* Pick an example short_channel_id at random to query. As a
|
||||
* side-effect this gets the node. */
|
||||
|
@ -475,11 +476,8 @@ static bool get_unannounced_nodes(const tal_t *ctx,
|
|||
(*scids)[num++] = c->scid;
|
||||
} else {
|
||||
/* Maybe replace one: approx. reservoir sampling */
|
||||
u64 p = pseudorand_u64();
|
||||
if (p > threshold) {
|
||||
if (random_select(1.0, &total_weight))
|
||||
(*scids)[pseudorand(max)] = c->scid;
|
||||
threshold = p;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ void queue_peer_msg(struct peer *peer UNNEEDED, const u8 *msg TAKES UNNEEDED)
|
|||
struct peer *random_peer(struct daemon *daemon UNNEEDED,
|
||||
bool (*check_peer)(const struct peer *peer))
|
||||
{ fprintf(stderr, "random_peer called!\n"); abort(); }
|
||||
/* Generated stub for random_select */
|
||||
bool random_select(double weight UNNEEDED, double *tot_weight UNNEEDED)
|
||||
{ fprintf(stderr, "random_select called!\n"); abort(); }
|
||||
/* Generated stub for status_failed */
|
||||
void status_failed(enum status_failreason code UNNEEDED,
|
||||
const char *fmt UNNEEDED, ...)
|
||||
|
|
|
@ -60,6 +60,7 @@ LIGHTNINGD_COMMON_OBJS := \
|
|||
common/per_peer_state.o \
|
||||
common/permute_tx.o \
|
||||
common/pseudorand.o \
|
||||
common/random_select.o \
|
||||
common/setup.o \
|
||||
common/sphinx.o \
|
||||
common/status_wire.o \
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <common/jsonrpc_errors.h>
|
||||
#include <common/overflows.h>
|
||||
#include <common/param.h>
|
||||
#include <common/pseudorand.h>
|
||||
#include <common/random_select.h>
|
||||
#include <common/timeout.h>
|
||||
#include <common/utils.h>
|
||||
#include <errno.h>
|
||||
|
@ -489,15 +489,8 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||
bool *any_offline)
|
||||
{
|
||||
/* BOLT11 struct wants an array of arrays (can provide multiple routes) */
|
||||
struct route_info **R;
|
||||
double wsum, p;
|
||||
|
||||
struct sample {
|
||||
const struct route_info *route;
|
||||
double weight;
|
||||
};
|
||||
|
||||
struct sample *S = tal_arr(tmpctx, struct sample, 0);
|
||||
struct route_info **r = NULL;
|
||||
double total_weight = 0.0;
|
||||
|
||||
*any_offline = false;
|
||||
|
||||
|
@ -505,7 +498,6 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||
for (size_t i = 0; i < tal_count(inchans); i++) {
|
||||
struct peer *peer;
|
||||
struct channel *c;
|
||||
struct sample sample;
|
||||
struct amount_msat capacity_to_pay_us, excess, capacity;
|
||||
struct amount_sat cumulative_reserve;
|
||||
double excess_frac;
|
||||
|
@ -564,33 +556,23 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||
continue;
|
||||
}
|
||||
|
||||
/* We don't want a 0 probability if 0 excess; it might be the
|
||||
* only one! So bump it by 1 msat */
|
||||
if (!amount_msat_add(&excess, excess, AMOUNT_MSAT(1))) {
|
||||
log_broken(ld->log, "Channel %s excess overflow!",
|
||||
type_to_string(tmpctx, struct short_channel_id, c->scid));
|
||||
continue;
|
||||
}
|
||||
excess_frac = amount_msat_ratio(excess, capacity);
|
||||
|
||||
sample.route = &inchans[i];
|
||||
sample.weight = excess_frac;
|
||||
tal_arr_expand(&S, sample);
|
||||
if (random_select(excess_frac, &total_weight)) {
|
||||
tal_free(r);
|
||||
r = tal_arr(ctx, struct route_info *, 1);
|
||||
r[0] = tal_dup(r, struct route_info, &inchans[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!tal_count(S))
|
||||
return NULL;
|
||||
|
||||
/* Use weighted reservoir sampling, see:
|
||||
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
|
||||
* But (currently) the result will consist of only one sample (k=1) */
|
||||
R = tal_arr(ctx, struct route_info *, 1);
|
||||
R[0] = tal_dup(R, struct route_info, S[0].route);
|
||||
wsum = S[0].weight;
|
||||
|
||||
for (size_t i = 1; i < tal_count(S); i++) {
|
||||
wsum += S[i].weight;
|
||||
p = S[i].weight / wsum;
|
||||
double random_1 = pseudorand_double(); /* range [0,1) */
|
||||
|
||||
if (random_1 <= p)
|
||||
R[0] = tal_dup(R, struct route_info, S[i].route);
|
||||
}
|
||||
|
||||
return R;
|
||||
return r;
|
||||
}
|
||||
|
||||
/** select_inchan_mpp
|
||||
|
@ -1414,6 +1396,7 @@ static struct command_result *json_waitanyinvoice(struct command *cmd,
|
|||
" is non-trivial.");
|
||||
}
|
||||
|
||||
|
||||
static const struct json_command waitanyinvoice_command = {
|
||||
"waitanyinvoice",
|
||||
"payment",
|
||||
|
@ -1423,7 +1406,6 @@ static const struct json_command waitanyinvoice_command = {
|
|||
};
|
||||
AUTODATA(json_command, &waitanyinvoice_command);
|
||||
|
||||
|
||||
/* Wait for an incoming payment matching the `label` in the JSON
|
||||
* command. This will either return immediately if the payment has
|
||||
* already been received or it may add the `cmd` to the list of
|
||||
|
|
|
@ -16,6 +16,7 @@ LIGHTNINGD_TEST_COMMON_OBJS := \
|
|||
common/json.o \
|
||||
common/key_derive.o \
|
||||
common/pseudorand.o \
|
||||
common/random_select.o \
|
||||
common/memleak.o \
|
||||
common/msg_queue.o \
|
||||
common/utils.o \
|
||||
|
|
|
@ -70,6 +70,7 @@ PLUGIN_COMMON_OBJS := \
|
|||
common/node_id.o \
|
||||
common/param.o \
|
||||
common/pseudorand.o \
|
||||
common/random_select.o \
|
||||
common/setup.o \
|
||||
common/type_to_string.o \
|
||||
common/utils.o \
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <ccan/tal/str/str.h>
|
||||
#include <common/json_stream.h>
|
||||
#include <common/pseudorand.h>
|
||||
#include <common/random_select.h>
|
||||
#include <common/type_to_string.h>
|
||||
#include <plugins/libplugin-pay.h>
|
||||
|
||||
|
@ -2421,12 +2422,11 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||
const jsmntok_t *result,
|
||||
struct payment *p)
|
||||
{
|
||||
/* Use reservoir sampling across the capable channels. */
|
||||
struct shadow_route_data *d = payment_mod_shadowroute_get_data(p);
|
||||
struct payment_constraints *cons = &d->constraints;
|
||||
struct route_info *best = NULL;
|
||||
double total_weight = 0.0;
|
||||
size_t i;
|
||||
u64 sample = 0;
|
||||
struct amount_msat best_fee;
|
||||
const jsmntok_t *sattok, *delaytok, *basefeetok, *propfeetok, *desttok,
|
||||
*channelstok, *chan, *scidtok;
|
||||
|
@ -2438,7 +2438,6 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||
|
||||
channelstok = json_get_member(buf, result, "channels");
|
||||
json_for_each_arr(i, chan, channelstok) {
|
||||
u64 v = pseudorand(UINT64_MAX);
|
||||
struct route_info curr;
|
||||
struct amount_sat capacity;
|
||||
struct amount_msat fee;
|
||||
|
@ -2465,28 +2464,27 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||
json_to_sat(buf, sattok, &capacity);
|
||||
json_to_node_id(buf, desttok, &curr.pubkey);
|
||||
|
||||
if (!best || v > sample) {
|
||||
/* If the capacity is insufficient to pass the amount
|
||||
* it's not a plausible extension. */
|
||||
if (amount_msat_greater_sat(p->amount, capacity))
|
||||
continue;
|
||||
/* If the capacity is insufficient to pass the amount
|
||||
* it's not a plausible extension. */
|
||||
if (amount_msat_greater_sat(p->amount, capacity))
|
||||
continue;
|
||||
|
||||
if (curr.cltv_expiry_delta > cons->cltv_budget)
|
||||
continue;
|
||||
if (curr.cltv_expiry_delta > cons->cltv_budget)
|
||||
continue;
|
||||
|
||||
if (!amount_msat_fee(
|
||||
&fee, p->amount, curr.fee_base_msat,
|
||||
curr.fee_proportional_millionths)) {
|
||||
/* Fee computation failed... */
|
||||
continue;
|
||||
}
|
||||
if (!amount_msat_fee(
|
||||
&fee, p->amount, curr.fee_base_msat,
|
||||
curr.fee_proportional_millionths)) {
|
||||
/* Fee computation failed... */
|
||||
continue;
|
||||
}
|
||||
|
||||
if (amount_msat_greater_eq(fee, cons->fee_budget))
|
||||
continue;
|
||||
if (amount_msat_greater_eq(fee, cons->fee_budget))
|
||||
continue;
|
||||
|
||||
if (random_select(1.0, &total_weight)) {
|
||||
best = tal_dup(tmpctx, struct route_info, &curr);
|
||||
best_fee = fee;
|
||||
sample = v;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue