renepay: switch from arc_t to struct arc.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2023-08-02 07:00:09 +09:30
parent b793dc9224
commit 15c8a6f6fe
2 changed files with 47 additions and 76 deletions

View file

@ -305,17 +305,6 @@ static inline struct arc arc_from_parts(u32 chanidx, int chandir, u32 part, bool
return arc;
}
typedef union
{
struct{
u32 dual: 1;
u32 part: PARTS_BITS;
u32 chandir: 1;
u32 chanidx: (32-1-PARTS_BITS-1);
};
u32 idx;
} arc_t;
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
@ -356,8 +345,8 @@ struct linear_network
// notice that a tail node is not needed,
// because the tail of arc is the head of dual(arc)
arc_t *node_adjacency_next_arc;
arc_t *node_adjacency_first_arc;
struct arc *node_adjacency_next_arc;
struct arc *node_adjacency_first_arc;
// probability and fee cost associated to an arc
s64 *arc_prob_cost, *arc_fee_cost;
@ -381,22 +370,24 @@ struct residual_network {
/* Helper function.
* Given an arc idx, return the dual's idx in the residual network. */
static arc_t arc_dual(arc_t arc)
static struct arc arc_dual(struct arc arc)
{
arc.dual ^= 1;
arc.idx ^= (1U << ARC_DUAL_BITOFF);
return arc;
}
/* Helper function. */
static bool arc_is_dual(const arc_t arc)
static bool arc_is_dual(struct arc arc)
{
return arc.dual == 1;
bool dual;
arc_to_parts(arc, NULL, NULL, NULL, &dual);
return dual;
}
/* Helper function.
* Given an arc of the network (not residual) give me the flow. */
static s64 get_arc_flow(
const struct residual_network *network,
const arc_t arc)
const struct arc arc)
{
assert(!arc_is_dual(arc));
assert(arc_dual(arc).idx < tal_count(network->cap));
@ -406,7 +397,7 @@ static s64 get_arc_flow(
/* Helper function.
* Given an arc idx, return the node from which this arc emanates in the residual network. */
static u32 arc_tail(const struct linear_network *linear_network,
const arc_t arc)
const struct arc arc)
{
assert(arc.idx < tal_count(linear_network->arc_tail_node));
return linear_network->arc_tail_node[ arc.idx ];
@ -414,9 +405,9 @@ static u32 arc_tail(const struct linear_network *linear_network,
/* Helper function.
* Given an arc idx, return the node that this arc is pointing to in the residual network. */
static u32 arc_head(const struct linear_network *linear_network,
const arc_t arc)
const struct arc arc)
{
const arc_t dual = arc_dual(arc);
const struct arc dual = arc_dual(arc);
assert(dual.idx < tal_count(linear_network->arc_tail_node));
return linear_network->arc_tail_node[dual.idx];
}
@ -424,7 +415,7 @@ static u32 arc_head(const struct linear_network *linear_network,
/* Helper function.
* Given node idx `node`, return the idx of the first arc whose tail is `node`.
* */
static arc_t node_adjacency_begin(
static struct arc node_adjacency_begin(
const struct linear_network * linear_network,
const u32 node)
{
@ -434,39 +425,21 @@ static arc_t node_adjacency_begin(
/* Helper function.
* Is this the end of the adjacency list. */
static bool node_adjacency_end(const arc_t arc)
static bool node_adjacency_end(const struct arc arc)
{
return arc.idx == INVALID_INDEX;
}
/* Helper function.
* Given node idx `node` and `arc`, returns the idx of the next arc whose tail is `node`. */
static arc_t node_adjacency_next(
static struct arc node_adjacency_next(
const struct linear_network *linear_network,
const arc_t arc)
const struct arc arc)
{
assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc));
return linear_network->node_adjacency_next_arc[arc.idx];
}
/* Helper function.
* Given a channel index, we should be able to deduce the arc id. */
static arc_t channel_idx_to_arc(
const u32 chan_idx,
int half,
int part,
int dual)
{
arc_t arc;
arc.dual=dual;
arc.part=part;
arc.chandir=half;
arc.chanidx = chan_idx;
/* check that it doesn't overflow */
assert(arc.chanidx == chan_idx);
return arc;
}
// TODO(eduardo): unit test this
/* Split a directed channel into parts with linear cost function. */
static void linearize_channel(
@ -538,14 +511,13 @@ static void init_residual_network(
{
const size_t max_num_arcs = linear_network->max_num_arcs;
const size_t max_num_nodes = linear_network->max_num_nodes;
for(u32 idx=0;idx<max_num_arcs;++idx)
{
arc_t arc = (arc_t){.idx=idx};
for(struct arc arc = {0};arc.idx < max_num_arcs; ++arc.idx)
{
if(arc_is_dual(arc))
continue;
arc_t dual = arc_dual(arc);
struct arc dual = arc_dual(arc);
residual_network->cap[arc.idx]=linear_network->capacity[arc.idx];
residual_network->cap[dual.idx]=0;
@ -562,19 +534,18 @@ static void combine_cost_function(
struct residual_network *residual_network,
s64 mu)
{
for(u32 arc_idx=0;arc_idx<linear_network->max_num_arcs;++arc_idx)
for(struct arc arc = {0};arc.idx < linear_network->max_num_arcs; ++arc.idx)
{
arc_t arc = (arc_t){.idx=arc_idx};
if(arc_tail(linear_network,arc)==INVALID_INDEX)
continue;
const s64 pcost = linear_network->arc_prob_cost[arc_idx],
fcost = linear_network->arc_fee_cost[arc_idx];
const s64 pcost = linear_network->arc_prob_cost[arc.idx],
fcost = linear_network->arc_fee_cost[arc.idx];
const s64 combined = pcost==INFINITE || fcost==INFINITE ? INFINITE :
mu*fcost + (MU_MAX-1-mu)*pcost;
residual_network->cost[arc_idx]
residual_network->cost[arc.idx]
= mu==0 ? pcost :
(mu==(MU_MAX-1) ? fcost : combined);
}
@ -583,13 +554,13 @@ static void combine_cost_function(
static void linear_network_add_adjacenct_arc(
struct linear_network *linear_network,
const u32 node_idx,
const arc_t arc)
const struct arc arc)
{
assert(arc.idx < tal_count(linear_network->arc_tail_node));
linear_network->arc_tail_node[arc.idx] = node_idx;
assert(node_idx < tal_count(linear_network->node_adjacency_first_arc));
const arc_t first_arc = linear_network->node_adjacency_first_arc[node_idx];
const struct arc first_arc = linear_network->node_adjacency_first_arc[node_idx];
assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc));
linear_network->node_adjacency_next_arc[arc.idx]=first_arc;
@ -614,11 +585,11 @@ static void init_linear_network(
for(size_t i=0;i<tal_count(linear_network->arc_tail_node);++i)
linear_network->arc_tail_node[i]=INVALID_INDEX;
linear_network->node_adjacency_next_arc = tal_arr(linear_network,arc_t,max_num_arcs);
linear_network->node_adjacency_next_arc = tal_arr(linear_network,struct arc,max_num_arcs);
for(size_t i=0;i<tal_count(linear_network->node_adjacency_next_arc);++i)
linear_network->node_adjacency_next_arc[i].idx=INVALID_INDEX;
linear_network->node_adjacency_first_arc = tal_arr(linear_network,arc_t,max_num_nodes);
linear_network->node_adjacency_first_arc = tal_arr(linear_network,struct arc,max_num_nodes);
for(size_t i=0;i<tal_count(linear_network->node_adjacency_first_arc);++i)
linear_network->node_adjacency_first_arc[i].idx=INVALID_INDEX;
@ -682,7 +653,7 @@ static void init_linear_network(
{
// if(capacity[k]==0)continue;
arc_t arc = channel_idx_to_arc(chan_id,half,k,0);
struct arc arc = arc_from_parts(chan_id, half, k, false);
linear_network_add_adjacenct_arc(linear_network,node_id,arc);
@ -692,7 +663,7 @@ static void init_linear_network(
linear_network->arc_fee_cost[arc.idx] = fee_cost;
// + the respective dual
arc_t dual = arc_dual(arc);
struct arc dual = arc_dual(arc);
linear_network_add_adjacenct_arc(linear_network,next_id,dual);
@ -723,7 +694,7 @@ static int find_admissible_path(
const struct residual_network *residual_network,
const u32 source,
const u32 target,
arc_t *prev)
struct arc *prev)
{
tal_t *this_ctx = tal(tmpctx,tal_t);
@ -753,7 +724,7 @@ static int find_admissible_path(
break;
}
for(arc_t arc = node_adjacency_begin(linear_network,cur);
for(struct arc arc = node_adjacency_begin(linear_network,cur);
!node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc))
{
@ -787,7 +758,7 @@ static s64 get_augmenting_flow(
const struct residual_network *residual_network,
const u32 source,
const u32 target,
const arc_t *prev)
const struct arc *prev)
{
s64 flow = INFINITE;
@ -795,7 +766,7 @@ static s64 get_augmenting_flow(
while(cur!=source)
{
assert(cur<tal_count(prev));
const arc_t arc = prev[cur];
const struct arc arc = prev[cur];
flow = MIN(flow , residual_network->cap[arc.idx]);
// we are traversing in the opposite direction to the flow,
@ -813,7 +784,7 @@ static void augment_flow(
struct residual_network *residual_network,
const u32 source,
const u32 target,
const arc_t *prev,
const struct arc *prev,
s64 flow)
{
u32 cur = target;
@ -821,8 +792,8 @@ static void augment_flow(
while(cur!=source)
{
assert(cur < tal_count(prev));
const arc_t arc = prev[cur];
const arc_t dual = arc_dual(arc);
const struct arc arc = prev[cur];
const struct arc dual = arc_dual(arc);
assert(arc.idx < tal_count(residual_network->cap));
assert(dual.idx < tal_count(residual_network->cap));
@ -862,7 +833,7 @@ static int find_feasible_flow(
/* path information
* prev: is the id of the arc that lead to the node. */
arc_t *prev = tal_arr(this_ctx,arc_t,linear_network->max_num_nodes);
struct arc *prev = tal_arr(this_ctx,struct arc,linear_network->max_num_nodes);
while(amount>0)
{
@ -903,7 +874,7 @@ static int find_optimal_path(
const struct residual_network* residual_network,
const u32 source,
const u32 target,
arc_t *prev)
struct arc *prev)
{
tal_t *this_ctx = tal(tmpctx,tal_t);
int ret = RENEPAY_ERR_NOFEASIBLEFLOW;
@ -935,7 +906,7 @@ static int find_optimal_path(
break;
}
for(arc_t arc = node_adjacency_begin(linear_network,cur);
for(struct arc arc = node_adjacency_begin(linear_network,cur);
!node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc))
{
@ -971,13 +942,13 @@ static void zero_flow(
for(u32 node=0;node<linear_network->max_num_nodes;++node)
{
residual_network->potential[node]=0;
for(arc_t arc=node_adjacency_begin(linear_network,node);
for(struct arc arc=node_adjacency_begin(linear_network,node);
!node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc))
{
if(arc_is_dual(arc))continue;
arc_t dual = arc_dual(arc);
struct arc dual = arc_dual(arc);
residual_network->cap[arc.idx] = linear_network->capacity[arc.idx];
residual_network->cap[dual.idx] = 0;
@ -1008,7 +979,7 @@ static int optimize_mcf(
int ret = RENEPAY_ERR_OK;
zero_flow(linear_network,residual_network);
arc_t *prev = tal_arr(this_ctx,arc_t,linear_network->max_num_nodes);
struct arc *prev = tal_arr(this_ctx,struct arc,linear_network->max_num_nodes);
const s64 *const distance = dijkstra_distance_data(dijkstra);
@ -1172,7 +1143,7 @@ static struct flow **
// Compute balance on the nodes.
for(u32 n = 0;n<max_num_nodes;++n)
{
for(arc_t arc = node_adjacency_begin(linear_network,n);
for(struct arc arc = node_adjacency_begin(linear_network,n);
!node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc))
{
@ -1180,11 +1151,14 @@ static struct flow **
continue;
u32 m = arc_head(linear_network,arc);
s64 flow = get_arc_flow(residual_network,arc);
u32 chanidx;
int chandir;
balance[n] -= flow;
balance[m] += flow;
chan_flow[arc.chanidx].half[arc.chandir] +=flow;
arc_to_parts(arc, &chanidx, &chandir, NULL, NULL);
chan_flow[chanidx].half[chandir] +=flow;
}
}

View file

@ -98,15 +98,12 @@ int main(int argc, char *argv[])
arc_to_parts(a, NULL, NULL, NULL, &dual);
assert(dual == i);
/* This code not converted yet! */
#if 0
assert(arc_is_dual(a) == dual);
a = arc_dual(a);
arc_to_parts(a, NULL, NULL, NULL, &dual);
assert(dual == !i);
assert(arc_is_dual(a) == dual);
#endif
}
common_shutdown();