From f17c5f5a6b89a13e6feb8e48aa87b58951b1f21a Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Fri, 11 Oct 2024 21:30:39 +1030 Subject: [PATCH] askrene: don't use tmpctx in minflow() I tested with a really large gossmap (hacked to be 4GB), and when we keep retrying to minimize cost (calling minflow 11 times), and we don't free tmpctx. Due to an issue with how gossmap estimates the index sizes, we ended up running out of memory. This fixes it. Signed-off-by: Rusty Russell --- plugins/askrene/mcf.c | 63 ++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/plugins/askrene/mcf.c b/plugins/askrene/mcf.c index 7e803470d..e60d8316e 100644 --- a/plugins/askrene/mcf.c +++ b/plugins/askrene/mcf.c @@ -688,14 +688,17 @@ init_linear_network(const tal_t *ctx, const struct pay_parameters *params) * residual network with capacity greater than 0. * The path is encoded into prev, which contains the idx of the arcs that are * traversed. */ + +/* Note we eschew tmpctx here, as this can be called multiple times! */ static bool -find_admissible_path(const struct linear_network *linear_network, +find_admissible_path(const tal_t *working_ctx, + const struct linear_network *linear_network, const struct residual_network *residual_network, const u32 source, const u32 target, struct arc *prev) { bool target_found = false; /* Simple linear queue of node indexes */ - u32 *queue = tal_arr(tmpctx, u32, linear_network->max_num_arcs); + u32 *queue = tal_arr(working_ctx, u32, linear_network->max_num_arcs); size_t qstart, qend, prev_len = tal_count(prev); for(size_t i=0;imax_num_nodes); + struct arc *prev = tal_arr(working_ctx,struct arc,linear_network->max_num_nodes); while(amount>0) { // find a path from source to target - if (!find_admissible_path(linear_network, + if (!find_admissible_path(working_ctx, + linear_network, residual_network, source, target, prev)) { return false; @@ -846,7 +851,8 @@ static bool find_feasible_flow(const struct linear_network *linear_network, // TODO(eduardo): unit test this /* Similar to `find_admissible_path` but use Dijkstra to optimize the distance * label. Stops when the target is hit. */ -static bool find_optimal_path(struct dijkstra *dijkstra, +static bool find_optimal_path(const tal_t *working_ctx, + struct dijkstra *dijkstra, const struct linear_network *linear_network, const struct residual_network *residual_network, const u32 source, const u32 target, @@ -854,7 +860,7 @@ static bool find_optimal_path(struct dijkstra *dijkstra, { bool target_found = false; - bitmap *visited = tal_arrz(tmpctx, bitmap, + bitmap *visited = tal_arrz(working_ctx, bitmap, BITMAP_NWORDS(linear_network->max_num_nodes)); for(size_t i=0;i=0); zero_flow(linear_network,residual_network); - struct arc *prev = tal_arr(tmpctx,struct arc,linear_network->max_num_nodes); + struct arc *prev = tal_arr(working_ctx,struct arc,linear_network->max_num_nodes); const s64 *const distance = dijkstra_distance_data(dijkstra); @@ -955,7 +962,7 @@ static bool optimize_mcf(struct dijkstra *dijkstra, while(remaining_amount>0) { - if (!find_optimal_path(dijkstra, linear_network, + if (!find_optimal_path(working_ctx, dijkstra, linear_network, residual_network, source, target, prev)) { return false; } @@ -1002,6 +1009,7 @@ struct chan_flow * or we discover a cycle (returns a node idx with 0 balance). * */ static u32 find_path_or_cycle( + const tal_t *working_ctx, const struct gossmap *gossmap, const struct chan_flow *chan_flow, const u32 start_idx, @@ -1013,7 +1021,7 @@ static u32 find_path_or_cycle( { const size_t max_num_nodes = gossmap_max_node_idx(gossmap); bitmap *visited = - tal_arrz(tmpctx, bitmap, BITMAP_NWORDS(max_num_nodes)); + tal_arrz(working_ctx, bitmap, BITMAP_NWORDS(max_num_nodes)); u32 final_idx = start_idx; bitmap_set_bit(visited, start_idx); @@ -1176,6 +1184,7 @@ static void substract_cycle(const struct gossmap *gossmap, const u32 final_idx, * gossmap that corresponds to this flow. */ static struct flow ** get_flow_paths(const tal_t *ctx, + const tal_t *working_ctx, const struct route_query *rq, const struct linear_network *linear_network, const struct residual_network *residual_network) @@ -1183,17 +1192,17 @@ get_flow_paths(const tal_t *ctx, struct flow **flows = tal_arr(ctx,struct flow*,0); const size_t max_num_chans = gossmap_max_chan_idx(rq->gossmap); - struct chan_flow *chan_flow = tal_arrz(tmpctx,struct chan_flow,max_num_chans); + struct chan_flow *chan_flow = tal_arrz(working_ctx,struct chan_flow,max_num_chans); const size_t max_num_nodes = gossmap_max_node_idx(rq->gossmap); - s64 *balance = tal_arrz(tmpctx,s64,max_num_nodes); + s64 *balance = tal_arrz(working_ctx,s64,max_num_nodes); const struct gossmap_chan **prev_chan - = tal_arr(tmpctx,const struct gossmap_chan *,max_num_nodes); + = tal_arr(working_ctx,const struct gossmap_chan *,max_num_nodes); - int *prev_dir = tal_arr(tmpctx,int,max_num_nodes); - u32 *prev_idx = tal_arr(tmpctx,u32,max_num_nodes); + int *prev_dir = tal_arr(working_ctx,int,max_num_nodes); + u32 *prev_idx = tal_arr(working_ctx,u32,max_num_nodes); for (u32 node_idx = 0; node_idx < max_num_nodes; node_idx++) prev_idx[node_idx] = INVALID_INDEX; @@ -1232,6 +1241,7 @@ get_flow_paths(const tal_t *ctx, { prev_chan[node_idx] = NULL; u32 final_idx = find_path_or_cycle( + working_ctx, rq->gossmap, chan_flow, node_idx, balance, prev_chan, prev_dir, prev_idx); @@ -1278,8 +1288,10 @@ struct flow **minflow(const tal_t *ctx, u32 prob_cost_factor) { struct flow **flow_paths; - - struct pay_parameters *params = tal(tmpctx,struct pay_parameters); + /* We allocate everything off this, and free it at the end, + * as we can be called multiple times without cleaning tmpctx! */ + tal_t *working_ctx = tal(NULL, char); + struct pay_parameters *params = tal(working_ctx, struct pay_parameters); struct dijkstra *dijkstra; params->rq = rq; @@ -1306,11 +1318,11 @@ struct flow **minflow(const tal_t *ctx, params->prob_cost_factor = prob_cost_factor; // build the uncertainty network with linearization and residual arcs - struct linear_network *linear_network= init_linear_network(tmpctx, params); + struct linear_network *linear_network= init_linear_network(working_ctx, params); struct residual_network *residual_network = - alloc_residual_network(tmpctx, linear_network->max_num_nodes, + alloc_residual_network(working_ctx, linear_network->max_num_nodes, linear_network->max_num_arcs); - dijkstra = dijkstra_new(tmpctx, gossmap_max_node_idx(rq->gossmap)); + dijkstra = dijkstra_new(working_ctx, gossmap_max_node_idx(rq->gossmap)); const u32 target_idx = gossmap_node_idx(rq->gossmap,target); const u32 source_idx = gossmap_node_idx(rq->gossmap,source); @@ -1333,23 +1345,26 @@ struct flow **minflow(const tal_t *ctx, * flow units. */ const u64 pay_amount_sats = (params->amount.millisatoshis + 999)/1000; /* Raw: minflow */ - if (!find_feasible_flow(linear_network, residual_network, + if (!find_feasible_flow(working_ctx, linear_network, residual_network, source_idx, target_idx, pay_amount_sats)) { + tal_free(working_ctx); return NULL; } combine_cost_function(linear_network, residual_network, mu); /* We solve a linear MCF problem. */ - if(!optimize_mcf(dijkstra,linear_network,residual_network, + if(!optimize_mcf(working_ctx, dijkstra,linear_network,residual_network, source_idx,target_idx,pay_amount_sats)) { + tal_free(working_ctx); return NULL; } /* We dissect the solution of the MCF into payment routes. * Actual amounts considering fees are computed for every * channel in the routes. */ - flow_paths = get_flow_paths(tmpctx, rq, + flow_paths = get_flow_paths(ctx, working_ctx, rq, linear_network, residual_network); + tal_free(working_ctx); return flow_paths; }