Individually lock NetworkGraph fields

In preparation for giving NetworkGraph shared ownership, wrap individual
fields in RwLock. This allows removing the outer RwLock used in
NetGraphMsgHandler.
This commit is contained in:
Jeffrey Czyz 2021-08-09 22:24:41 -05:00
parent a6e650630d
commit 777661ae52
No known key found for this signature in database
GPG key ID: 3A4E08275D5E96D2
2 changed files with 84 additions and 55 deletions

View file

@ -51,11 +51,11 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
const MAX_SCIDS_PER_REPLY: usize = 8000;
/// Represents the network as nodes and channels between them
#[derive(Clone, PartialEq)]
pub struct NetworkGraph {
genesis_hash: BlockHash,
channels: BTreeMap<u64, ChannelInfo>,
nodes: BTreeMap<PublicKey, NodeInfo>,
// Lock order: channels -> nodes
channels: RwLock<BTreeMap<u64, ChannelInfo>>,
nodes: RwLock<BTreeMap<PublicKey, NodeInfo>>,
}
/// A simple newtype for RwLockReadGuard<'a, NetworkGraph>.
@ -193,7 +193,8 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
let network_graph = self.network_graph.read().unwrap();
let mut result = Vec::with_capacity(batch_amount as usize);
let mut iter = network_graph.get_channels().range(starting_point..);
let channels = network_graph.get_channels();
let mut iter = channels.range(starting_point..);
while result.len() < batch_amount as usize {
if let Some((_, ref chan)) = iter.next() {
if chan.announcement_message.is_some() {
@ -221,12 +222,13 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
fn get_next_node_announcements(&self, starting_point: Option<&PublicKey>, batch_amount: u8) -> Vec<NodeAnnouncement> {
let network_graph = self.network_graph.read().unwrap();
let mut result = Vec::with_capacity(batch_amount as usize);
let nodes = network_graph.get_nodes();
let mut iter = if let Some(pubkey) = starting_point {
let mut iter = network_graph.get_nodes().range((*pubkey)..);
let mut iter = nodes.range((*pubkey)..);
iter.next();
iter
} else {
network_graph.get_nodes().range(..)
nodes.range(..)
};
while result.len() < batch_amount as usize {
if let Some((_, ref node)) = iter.next() {
@ -616,13 +618,15 @@ impl Writeable for NetworkGraph {
write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
self.genesis_hash.write(writer)?;
(self.channels.len() as u64).write(writer)?;
for (ref chan_id, ref chan_info) in self.channels.iter() {
let channels = self.channels.read().unwrap();
(channels.len() as u64).write(writer)?;
for (ref chan_id, ref chan_info) in channels.iter() {
(*chan_id).write(writer)?;
chan_info.write(writer)?;
}
(self.nodes.len() as u64).write(writer)?;
for (ref node_id, ref node_info) in self.nodes.iter() {
let nodes = self.nodes.read().unwrap();
(nodes.len() as u64).write(writer)?;
for (ref node_id, ref node_info) in nodes.iter() {
node_id.write(writer)?;
node_info.write(writer)?;
}
@ -655,8 +659,8 @@ impl Readable for NetworkGraph {
Ok(NetworkGraph {
genesis_hash,
channels,
nodes,
channels: RwLock::new(channels),
nodes: RwLock::new(nodes),
})
}
}
@ -664,36 +668,49 @@ impl Readable for NetworkGraph {
impl fmt::Display for NetworkGraph {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
writeln!(f, "Network map\n[Channels]")?;
for (key, val) in self.channels.iter() {
for (key, val) in self.channels.read().unwrap().iter() {
writeln!(f, " {}: {}", key, val)?;
}
writeln!(f, "[Nodes]")?;
for (key, val) in self.nodes.iter() {
for (key, val) in self.nodes.read().unwrap().iter() {
writeln!(f, " {}: {}", log_pubkey!(key), val)?;
}
Ok(())
}
}
impl PartialEq for NetworkGraph {
fn eq(&self, other: &Self) -> bool {
self.genesis_hash == other.genesis_hash &&
*self.channels.read().unwrap() == *other.channels.read().unwrap() &&
*self.nodes.read().unwrap() == *other.nodes.read().unwrap()
}
}
impl NetworkGraph {
/// Returns all known valid channels' short ids along with announced channel info.
///
/// (C-not exported) because we have no mapping for `BTreeMap`s
pub fn get_channels<'a>(&'a self) -> &'a BTreeMap<u64, ChannelInfo> { &self.channels }
pub fn get_channels(&self) -> RwLockReadGuard<'_, BTreeMap<u64, ChannelInfo>> {
self.channels.read().unwrap()
}
/// Returns all known nodes' public keys along with announced node info.
///
/// (C-not exported) because we have no mapping for `BTreeMap`s
pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap<PublicKey, NodeInfo> { &self.nodes }
pub fn get_nodes(&self) -> RwLockReadGuard<'_, BTreeMap<PublicKey, NodeInfo>> {
self.nodes.read().unwrap()
}
/// Get network addresses by node id.
/// Returns None if the requested node is completely unknown,
/// or if node announcement for the node was never received.
///
/// (C-not exported) as there is no practical way to track lifetimes of returned values.
pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec<NetAddress>> {
if let Some(node) = self.nodes.get(pubkey) {
pub fn get_addresses(&self, pubkey: &PublicKey) -> Option<Vec<NetAddress>> {
if let Some(node) = self.nodes.read().unwrap().get(pubkey) {
if let Some(node_info) = node.announcement_info.as_ref() {
return Some(&node_info.addresses)
return Some(node_info.addresses.clone())
}
}
None
@ -703,8 +720,8 @@ impl NetworkGraph {
pub fn new(genesis_hash: BlockHash) -> NetworkGraph {
Self {
genesis_hash,
channels: BTreeMap::new(),
nodes: BTreeMap::new(),
channels: RwLock::new(BTreeMap::new()),
nodes: RwLock::new(BTreeMap::new()),
}
}
@ -729,7 +746,7 @@ impl NetworkGraph {
}
fn update_node_from_announcement_intern(&mut self, msg: &msgs::UnsignedNodeAnnouncement, full_msg: Option<&msgs::NodeAnnouncement>) -> Result<(), LightningError> {
match self.nodes.get_mut(&msg.node_id) {
match self.nodes.write().unwrap().get_mut(&msg.node_id) {
None => Err(LightningError{err: "No existing channels for node_announcement".to_owned(), action: ErrorAction::IgnoreError}),
Some(node) => {
if let Some(node_info) = node.announcement_info.as_ref() {
@ -838,7 +855,9 @@ impl NetworkGraph {
{ full_msg.cloned() } else { None },
};
match self.channels.entry(msg.short_channel_id) {
let mut channels = self.channels.write().unwrap();
let mut nodes = self.nodes.write().unwrap();
match channels.entry(msg.short_channel_id) {
BtreeEntry::Occupied(mut entry) => {
//TODO: because asking the blockchain if short_channel_id is valid is only optional
//in the blockchain API, we need to handle it smartly here, though it's unclear
@ -852,7 +871,7 @@ impl NetworkGraph {
// b) we don't track UTXOs of channels we know about and remove them if they
// get reorg'd out.
// c) it's unclear how to do so without exposing ourselves to massive DoS risk.
Self::remove_channel_in_nodes(&mut self.nodes, &entry.get(), msg.short_channel_id);
Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id);
*entry.get_mut() = chan_info;
} else {
return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreAndLog(Level::Trace)})
@ -865,7 +884,7 @@ impl NetworkGraph {
macro_rules! add_channel_to_node {
( $node_id: expr ) => {
match self.nodes.entry($node_id) {
match nodes.entry($node_id) {
BtreeEntry::Occupied(node_entry) => {
node_entry.into_mut().channels.push(msg.short_channel_id);
},
@ -891,12 +910,14 @@ impl NetworkGraph {
/// May cause the removal of nodes too, if this was their last channel.
/// If not permanent, makes channels unavailable for routing.
pub fn close_channel_from_update(&mut self, short_channel_id: u64, is_permanent: bool) {
let mut channels = self.channels.write().unwrap();
if is_permanent {
if let Some(chan) = self.channels.remove(&short_channel_id) {
Self::remove_channel_in_nodes(&mut self.nodes, &chan, short_channel_id);
if let Some(chan) = channels.remove(&short_channel_id) {
let mut nodes = self.nodes.write().unwrap();
Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
}
} else {
if let Some(chan) = self.channels.get_mut(&short_channel_id) {
if let Some(chan) = channels.get_mut(&short_channel_id) {
if let Some(one_to_two) = chan.one_to_two.as_mut() {
one_to_two.enabled = false;
}
@ -937,7 +958,8 @@ impl NetworkGraph {
let chan_enabled = msg.flags & (1 << 1) != (1 << 1);
let chan_was_enabled;
match self.channels.get_mut(&msg.short_channel_id) {
let mut channels = self.channels.write().unwrap();
match channels.get_mut(&msg.short_channel_id) {
None => return Err(LightningError{err: "Couldn't find channel for update".to_owned(), action: ErrorAction::IgnoreError}),
Some(channel) => {
if let OptionalField::Present(htlc_maximum_msat) = msg.htlc_maximum_msat {
@ -1000,8 +1022,9 @@ impl NetworkGraph {
}
}
let mut nodes = self.nodes.write().unwrap();
if chan_enabled {
let node = self.nodes.get_mut(&dest_node_id).unwrap();
let node = nodes.get_mut(&dest_node_id).unwrap();
let mut base_msat = msg.fee_base_msat;
let mut proportional_millionths = msg.fee_proportional_millionths;
if let Some(fees) = node.lowest_inbound_channel_fees {
@ -1013,11 +1036,11 @@ impl NetworkGraph {
proportional_millionths
});
} else if chan_was_enabled {
let node = self.nodes.get_mut(&dest_node_id).unwrap();
let node = nodes.get_mut(&dest_node_id).unwrap();
let mut lowest_inbound_channel_fees = None;
for chan_id in node.channels.iter() {
let chan = self.channels.get(chan_id).unwrap();
let chan = channels.get(chan_id).unwrap();
let chan_info_opt;
if chan.node_one == dest_node_id {
chan_info_opt = chan.two_to_one.as_ref();
@ -1268,7 +1291,7 @@ mod tests {
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
None => panic!(),
Some(_) => ()
}
};
}
// If we receive announcement for the same channel (with UTXO lookups disabled),
@ -1320,7 +1343,7 @@ mod tests {
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
None => panic!(),
Some(_) => ()
}
};
}
// If we receive announcement for the same channel (but TX is not confirmed),
@ -1353,7 +1376,7 @@ mod tests {
assert_eq!(channel_entry.features, ChannelFeatures::empty());
},
_ => panic!()
}
};
}
// Don't relay valid channels with excess data
@ -1484,7 +1507,7 @@ mod tests {
assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
assert!(channel_info.two_to_one.is_none());
}
}
};
}
unsigned_channel_update.timestamp += 100;
@ -1645,7 +1668,7 @@ mod tests {
Some(channel_info) => {
assert!(channel_info.one_to_two.is_some());
}
}
};
}
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {
@ -1663,7 +1686,7 @@ mod tests {
Some(channel_info) => {
assert!(!channel_info.one_to_two.as_ref().unwrap().enabled);
}
}
};
}
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {

View file

@ -443,6 +443,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// to use as the A* heuristic beyond just the cost to get one node further than the current
// one.
let network_channels = network.get_channels();
let network_nodes = network.get_nodes();
let dummy_directional_info = DummyDirectionalChannelInfo { // used for first_hops routes
cltv_expiry_delta: 0,
htlc_minimum_msat: 0,
@ -458,7 +460,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// work reliably.
let allow_mpp = if let Some(features) = &payee_features {
features.supports_basic_mpp()
} else if let Some(node) = network.get_nodes().get(&payee) {
} else if let Some(node) = network_nodes.get(&payee) {
if let Some(node_info) = node.announcement_info.as_ref() {
node_info.features.supports_basic_mpp()
} else { false }
@ -492,7 +494,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// Map from node_id to information about the best current path to that node, including feerate
// information.
let mut dist = HashMap::with_capacity(network.get_nodes().len());
let mut dist = HashMap::with_capacity(network_nodes.len());
// During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this,
// indicating that we may wish to try again with a higher value, potentially paying to meet an
@ -511,7 +513,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// This map allows paths to be aware of the channel use by other paths in the same call.
// This would help to make a better path finding decisions and not "overbook" channels.
// It is unaware of the directions (except for `outbound_capacity_msat` in `first_hops`).
let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network.get_nodes().len());
let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len());
// Keeping track of how much value we already collected across other paths. Helps to decide:
// - how much a new path should be transferring (upper bound);
@ -629,7 +631,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// as a way to reach the $dest_node_id.
let mut fee_base_msat = u32::max_value();
let mut fee_proportional_millionths = u32::max_value();
if let Some(Some(fees)) = network.get_nodes().get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
fee_base_msat = fees.base_msat;
fee_proportional_millionths = fees.proportional_millionths;
}
@ -814,7 +816,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
if !features.requires_unknown_bits() {
for chan_id in $node.channels.iter() {
let chan = network.get_channels().get(chan_id).unwrap();
let chan = network_channels.get(chan_id).unwrap();
if !chan.features.requires_unknown_bits() {
if chan.node_one == *$node_id {
// ie $node is one, ie next hop in A* is two, via the two_to_one channel
@ -862,7 +864,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// Add the payee as a target, so that the payee-to-payer
// search algorithm knows what to start with.
match network.get_nodes().get(payee) {
match network_nodes.get(payee) {
// The payee is not in our network graph, so nothing to add here.
// There is still a chance of reaching them via last_hops though,
// so don't yet fail the payment here.
@ -884,7 +886,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// we have a direct channel to the first hop or the first hop is
// in the regular network graph.
first_hop_targets.get(&first_hop_in_route.src_node_id).is_some() ||
network.get_nodes().get(&first_hop_in_route.src_node_id).is_some();
network_nodes.get(&first_hop_in_route.src_node_id).is_some();
if have_hop_src_in_graph {
// We start building the path from reverse, i.e., from payee
// to the first RouteHintHop in the path.
@ -991,7 +993,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
'path_walk: loop {
if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) {
ordered_hops.last_mut().unwrap().1 = features.clone();
} else if let Some(node) = network.get_nodes().get(&ordered_hops.last().unwrap().0.pubkey) {
} else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) {
if let Some(node_info) = node.announcement_info.as_ref() {
ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
} else {
@ -1093,7 +1095,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
// Otherwise, since the current target node is not us,
// keep "unrolling" the payment graph from payee to payer by
// finding a way to reach the current target from the payer side.
match network.get_nodes().get(&pubkey) {
match network_nodes.get(&pubkey) {
None => {},
Some(node) => {
add_entries_to_cheapest_to_target_node!(node, &pubkey, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat);
@ -4211,12 +4213,13 @@ mod tests {
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut seed = random_init_seed() as usize;
let nodes = graph.get_nodes();
'load_endpoints: for _ in 0..10 {
loop {
seed = seed.overflowing_mul(0xdeadbeef).0;
let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed = seed.overflowing_mul(0xdeadbeef).0;
let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 200_000_000;
if get_route(src, &graph, dst, None, None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
continue 'load_endpoints;
@ -4239,12 +4242,13 @@ mod tests {
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut seed = random_init_seed() as usize;
let nodes = graph.get_nodes();
'load_endpoints: for _ in 0..10 {
loop {
seed = seed.overflowing_mul(0xdeadbeef).0;
let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed = seed.overflowing_mul(0xdeadbeef).0;
let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 200_000_000;
if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
continue 'load_endpoints;
@ -4297,6 +4301,7 @@ mod benches {
fn generate_routes(bench: &mut Bencher) {
let mut d = test_utils::get_route_file().unwrap();
let graph = NetworkGraph::read(&mut d).unwrap();
let nodes = graph.get_nodes();
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut path_endpoints = Vec::new();
@ -4304,9 +4309,9 @@ mod benches {
'load_endpoints: for _ in 0..100 {
loop {
seed *= 0xdeadbeef;
let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed *= 0xdeadbeef;
let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 1_000_000;
if get_route(src, &graph, dst, None, None, &[], amt, 42, &DummyLogger{}).is_ok() {
path_endpoints.push((src, dst, amt));
@ -4328,6 +4333,7 @@ mod benches {
fn generate_mpp_routes(bench: &mut Bencher) {
let mut d = test_utils::get_route_file().unwrap();
let graph = NetworkGraph::read(&mut d).unwrap();
let nodes = graph.get_nodes();
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut path_endpoints = Vec::new();
@ -4335,9 +4341,9 @@ mod benches {
'load_endpoints: for _ in 0..100 {
loop {
seed *= 0xdeadbeef;
let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed *= 0xdeadbeef;
let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 1_000_000;
if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}).is_ok() {
path_endpoints.push((src, dst, amt));