Use CandidateRouteHop as input for channel_penalty_msat

We remove `source`, `target` and `scid` from
  `channel_penalty_msat` inputs to consume them from
  `candidate` of type `CandidateRouteHop`
This commit is contained in:
jbesraa 2023-09-06 16:04:08 +03:00
parent 04e93fc887
commit f0ecc3ec73
No known key found for this signature in database
GPG Key ID: 3297752B76B93547
4 changed files with 399 additions and 194 deletions

View File

@ -863,8 +863,8 @@ mod tests {
use lightning::ln::msgs::{ChannelMessageHandler, Init}; use lightning::ln::msgs::{ChannelMessageHandler, Init};
use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler}; use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync}; use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync};
use lightning::routing::router::{DefaultRouter, Path, RouteHop};
use lightning::routing::scoring::{ChannelUsage, ScoreUpdate, ScoreLookUp, LockableScore}; use lightning::routing::scoring::{ChannelUsage, ScoreUpdate, ScoreLookUp, LockableScore};
use lightning::routing::router::{DefaultRouter, Path, RouteHop, CandidateRouteHop};
use lightning::util::config::UserConfig; use lightning::util::config::UserConfig;
use lightning::util::ser::Writeable; use lightning::util::ser::Writeable;
use lightning::util::test_utils; use lightning::util::test_utils;
@ -1071,7 +1071,7 @@ mod tests {
impl ScoreLookUp for TestScorer { impl ScoreLookUp for TestScorer {
type ScoreParams = (); type ScoreParams = ();
fn channel_penalty_msat( fn channel_penalty_msat(
&self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage, _score_params: &Self::ScoreParams &self, _candidate: &CandidateRouteHop, _usage: ChannelUsage, _score_params: &Self::ScoreParams
) -> u64 { unimplemented!(); } ) -> u64 { unimplemented!(); }
} }

View File

@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
impl<'a, S: Deref> ScoreLookUp for ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: ScoreLookUp { impl<'a, S: Deref> ScoreLookUp for ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: ScoreLookUp {
type ScoreParams = <S::Target as ScoreLookUp>::ScoreParams; type ScoreParams = <S::Target as ScoreLookUp>::ScoreParams;
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 { fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
let target = match candidate.target() {
Some(target) => target,
None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
};
let short_channel_id = match candidate.short_channel_id() {
Some(short_channel_id) => short_channel_id,
None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
};
let source = candidate.source();
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat( if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
source, target, short_channel_id &source, &target, short_channel_id
) { ) {
let usage = ChannelUsage { let usage = ChannelUsage {
inflight_htlc_msat: usage.inflight_htlc_msat.saturating_add(used_liquidity), inflight_htlc_msat: usage.inflight_htlc_msat.saturating_add(used_liquidity),
..usage ..usage
}; };
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params) self.scorer.channel_penalty_msat(candidate, usage, score_params)
} else { } else {
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params) self.scorer.channel_penalty_msat(candidate, usage, score_params)
} }
} }
} }
@ -1068,7 +1077,7 @@ impl<'a> CandidateRouteHop<'a> {
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known. /// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
pub fn short_channel_id(&self) -> Option<u64> { pub fn short_channel_id(&self) -> Option<u64> {
match self { match self {
CandidateRouteHop::FirstHop { details, .. } => Some(details.get_outbound_payment_scid().unwrap()), CandidateRouteHop::FirstHop { details, .. } => details.get_outbound_payment_scid(),
CandidateRouteHop::PublicHop { short_channel_id, .. } => Some(*short_channel_id), CandidateRouteHop::PublicHop { short_channel_id, .. } => Some(*short_channel_id),
CandidateRouteHop::PrivateHop { hint, .. } => Some(hint.short_channel_id), CandidateRouteHop::PrivateHop { hint, .. } => Some(hint.short_channel_id),
CandidateRouteHop::Blinded { .. } => None, CandidateRouteHop::Blinded { .. } => None,
@ -1173,7 +1182,7 @@ impl<'a> CandidateRouteHop<'a> {
CandidateRouteHop::PublicHop { info, .. } => *info.source(), CandidateRouteHop::PublicHop { info, .. } => *info.source(),
CandidateRouteHop::PrivateHop { hint, .. } => hint.src_node_id.into(), CandidateRouteHop::PrivateHop { hint, .. } => hint.src_node_id.into(),
CandidateRouteHop::Blinded { hint, .. } => hint.1.introduction_node_id.into(), CandidateRouteHop::Blinded { hint, .. } => hint.1.introduction_node_id.into(),
CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into() CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into(),
} }
} }
/// Returns the target node id of this hop, if known. /// Returns the target node id of this hop, if known.
@ -2011,9 +2020,10 @@ where L::Target: Logger {
inflight_htlc_msat: used_liquidity_msat, inflight_htlc_msat: used_liquidity_msat,
effective_capacity, effective_capacity,
}; };
let channel_penalty_msat = scid_opt.map_or(0, let channel_penalty_msat =
|scid| scorer.channel_penalty_msat(scid, &src_node_id, &dest_node_id, scorer.channel_penalty_msat($candidate,
channel_usage, score_params)); channel_usage,
score_params);
let path_penalty_msat = $next_hops_path_penalty_msat let path_penalty_msat = $next_hops_path_penalty_msat
.saturating_add(channel_penalty_msat); .saturating_add(channel_penalty_msat);
let new_graph_node = RouteGraphNode { let new_graph_node = RouteGraphNode {
@ -2324,7 +2334,7 @@ where L::Target: Logger {
effective_capacity: candidate.effective_capacity(), effective_capacity: candidate.effective_capacity(),
}; };
let channel_penalty_msat = scorer.channel_penalty_msat( let channel_penalty_msat = scorer.channel_penalty_msat(
hop.short_channel_id, &source, &target, channel_usage, score_params &candidate, channel_usage, score_params
); );
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
.saturating_add(channel_penalty_msat); .saturating_add(channel_penalty_msat);
@ -2879,13 +2889,13 @@ fn build_route_from_hops_internal<L: Deref>(
impl ScoreLookUp for HopScorer { impl ScoreLookUp for HopScorer {
type ScoreParams = (); type ScoreParams = ();
fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId, fn channel_penalty_msat(&self, candidate: &CandidateRouteHop,
_usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64 _usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64
{ {
let mut cur_id = self.our_node_id; let mut cur_id = self.our_node_id;
for i in 0..self.hop_ids.len() { for i in 0..self.hop_ids.len() {
if let Some(next_id) = self.hop_ids[i] { if let Some(next_id) = self.hop_ids[i] {
if cur_id == *source && next_id == *target { if cur_id == candidate.source() && Some(next_id) == candidate.target() {
return 0; return 0;
} }
cur_id = next_id; cur_id = next_id;
@ -2926,7 +2936,7 @@ mod tests {
use crate::routing::utxo::UtxoResult; use crate::routing::utxo::UtxoResult;
use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features, use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters}; DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters, CandidateRouteHop};
use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, ScoreLookUp, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters}; use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, ScoreLookUp, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel}; use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
use crate::chain::transaction::OutPoint; use crate::chain::transaction::OutPoint;
@ -6231,8 +6241,8 @@ mod tests {
} }
impl ScoreLookUp for BadChannelScorer { impl ScoreLookUp for BadChannelScorer {
type ScoreParams = (); type ScoreParams = ();
fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 { fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } if candidate.short_channel_id() == Some(self.short_channel_id) { u64::max_value() } else { 0 }
} }
} }
@ -6247,8 +6257,8 @@ mod tests {
impl ScoreLookUp for BadNodeScorer { impl ScoreLookUp for BadNodeScorer {
type ScoreParams = (); type ScoreParams = ();
fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 { fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
if *target == self.node_id { u64::max_value() } else { 0 } if candidate.target() == Some(self.node_id) { u64::max_value() } else { 0 }
} }
} }
@ -6736,26 +6746,32 @@ mod tests {
}; };
scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123); scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123);
scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456); scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456);
assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage, &scorer_params), 456); let network_graph = network_graph.read_only();
let channels = network_graph.channels();
let channel = channels.get(&5).unwrap();
let info = channel.as_directed_from(&NodeId::from_pubkey(&nodes[3])).unwrap();
let candidate: CandidateRouteHop = CandidateRouteHop::PublicHop {
info: info.0,
short_channel_id: 5,
};
assert_eq!(scorer.channel_penalty_msat(&candidate, usage, &scorer_params), 456);
// Then check we can get a normal route // Then check we can get a normal route
let payment_params = PaymentParameters::from_node_id(nodes[10], 42); let payment_params = PaymentParameters::from_node_id(nodes[10], 42);
let route_params = RouteParameters::from_payment_params_and_value( let route_params = RouteParameters::from_payment_params_and_value(
payment_params, 100); payment_params, 100);
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None, let route = get_route(&our_id, &route_params, &network_graph, None,
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes); Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
assert!(route.is_ok()); assert!(route.is_ok());
// Then check that we can't get a route if we ban an intermediate node. // Then check that we can't get a route if we ban an intermediate node.
scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3])); scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3]));
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None, let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
assert!(route.is_err()); assert!(route.is_err());
// Finally make sure we can route again, when we remove the ban. // Finally make sure we can route again, when we remove the ban.
scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3])); scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3]));
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None, let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
assert!(route.is_ok()); assert!(route.is_ok());
} }

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ use crate::chain::chainmonitor::{MonitorUpdateId, UpdateOrigin};
use crate::chain::channelmonitor; use crate::chain::channelmonitor;
use crate::chain::channelmonitor::MonitorEvent; use crate::chain::channelmonitor::MonitorEvent;
use crate::chain::transaction::OutPoint; use crate::chain::transaction::OutPoint;
use crate::routing::router::CandidateRouteHop;
use crate::sign; use crate::sign;
use crate::events; use crate::events;
use crate::events::bump_transaction::{WalletSource, Utxo}; use crate::events::bump_transaction::{WalletSource, Utxo};
@ -139,10 +140,34 @@ impl<'a> Router for TestRouter<'a> {
// Since the path is reversed, the last element in our iteration is the first // Since the path is reversed, the last element in our iteration is the first
// hop. // hop.
if idx == path.hops.len() - 1 { if idx == path.hops.len() - 1 {
scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); let first_hops = match first_hops {
Some(hops) => hops,
None => continue,
};
if first_hops.len() == 0 {
continue;
}
let idx = if first_hops.len() > 1 { route.paths.iter().position(|p| p == path).unwrap_or(0) } else { 0 };
let candidate = CandidateRouteHop::FirstHop {
details: first_hops[idx],
node_id: NodeId::from_pubkey(payer)
};
scorer.channel_penalty_msat(&candidate, usage, &());
} else { } else {
let curr_hop_path_idx = path.hops.len() - 1 - idx; let network_graph = self.network_graph.read_only();
scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); let channel = match network_graph.channel(hop.short_channel_id) {
Some(channel) => channel,
None => continue,
};
let channel = match channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)) {
Some(channel) => channel,
None => panic!("Channel directed to {} was not found", hop.pubkey),
};
let candidate = CandidateRouteHop::PublicHop {
info: channel.0,
short_channel_id: hop.short_channel_id,
};
scorer.channel_penalty_msat(&candidate, usage, &());
} }
} }
} }
@ -1297,8 +1322,12 @@ impl crate::util::ser::Writeable for TestScorer {
impl ScoreLookUp for TestScorer { impl ScoreLookUp for TestScorer {
type ScoreParams = (); type ScoreParams = ();
fn channel_penalty_msat( fn channel_penalty_msat(
&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage, _score_params: &Self::ScoreParams &self, candidate: &CandidateRouteHop, usage: ChannelUsage, _score_params: &Self::ScoreParams
) -> u64 { ) -> u64 {
let short_channel_id = match candidate.short_channel_id() {
Some(scid) => scid,
None => return 0,
};
if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() {
match scorer_expectations.pop_front() { match scorer_expectations.pop_front() {
Some((scid, expectation)) => { Some((scid, expectation)) => {