Don't take `per_peer_state` read locks recursively in monitor updating

When handling a `ChannelMonitor` update via the new
`handle_new_monitor_update` macro, we always call the macro with
the `per_peer_state` read lock held and have the macro drop the
per-peer state lock. Then, when handling the resulting updates, we
may take the `per_peer_state` read lock again in another function.

In a coming commit, recursive read locks will be disallowed, so we
have to drop the `per_peer_state` read lock before calling
additional functions in `handle_new_monitor_update`, which we do
here.
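
For background, Rust's `std::sync::RwLock` makes no reentrancy guarantee: a `read()` on a lock the current thread already holds may panic or deadlock, typically when a writer queues up between the two acquisitions. A minimal sketch of the hazard (illustrative only, not code from this patch):

    use std::sync::RwLock;

    // Illustrative only. If another thread calls `lock.write()` while
    // `outer` is held, the writer blocks on `outer`; on writer-priority
    // implementations the second `read()` below then blocks behind the
    // queued writer, deadlocking a single thread via a recursive read.
    fn recursive_read(lock: &RwLock<u32>) {
        let outer = lock.read().unwrap();
        let inner = lock.read().unwrap(); // may deadlock; std makes no guarantee
        println!("{} {}", *outer, *inner);
    }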
Matt Corallo 2023-02-22 22:10:46 +00:00
parent 38374dde42
commit 2c8a26c6d2

@@ -1417,7 +1417,7 @@ macro_rules! emit_channel_ready_event {
}
macro_rules! handle_monitor_update_completion {
-($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr) => { {
+($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr) => { {
let mut updates = $chan.monitor_updating_restored(&$self.logger,
&$self.node_signer, $self.genesis_hash, &$self.default_configuration,
$self.best_block.read().unwrap().height());
@@ -1450,6 +1450,7 @@ macro_rules! handle_monitor_update_completion {
let channel_id = $chan.channel_id();
core::mem::drop($peer_state_lock);
+core::mem::drop($per_peer_state_lock);
$self.handle_monitor_update_completion_actions(update_actions);
@@ -1465,7 +1466,7 @@ macro_rules! handle_monitor_update_completion {
}
macro_rules! handle_new_monitor_update {
-($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { {
+($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { {
// update_maps_on_chan_removal needs to be able to take id_to_peer, so make sure we can in
// any case so that it won't deadlock.
debug_assert!($self.id_to_peer.try_lock().is_ok());
@@ -1492,14 +1493,14 @@ macro_rules! handle_new_monitor_update {
.update_id == $update_id) &&
$chan.get_latest_monitor_update_id() == $update_id
{
-handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $chan);
+handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan);
}
Ok(())
},
}
} };
-($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan_entry: expr) => {
-handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry())
+($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan_entry: expr) => {
+handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry())
}
}
@@ -1835,7 +1836,7 @@ where
if let Some(monitor_update) = monitor_update_opt.take() {
let update_id = monitor_update.update_id;
let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update);
-break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry);
+break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry);
}
if chan_entry.get().is_shutdown() {
@@ -2464,7 +2465,7 @@ where
Some(monitor_update) => {
let update_id = monitor_update.update_id;
let update_res = self.chain_monitor.update_channel(funding_txo, monitor_update);
-if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan) {
+if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan) {
break Err(e);
}
if update_res == ChannelMonitorUpdateStatus::InProgress {
@@ -3991,7 +3992,8 @@ where
)
).unwrap_or(None);
-if let Some(mut peer_state_lock) = peer_state_opt.take() {
+if peer_state_opt.is_some() {
+let mut peer_state_lock = peer_state_opt.unwrap();
let peer_state = &mut *peer_state_lock;
if let hash_map::Entry::Occupied(mut chan) = peer_state.channel_by_id.entry(chan_id) {
let counterparty_node_id = chan.get().get_counterparty_node_id();
@@ -4006,7 +4008,7 @@ where
let update_id = monitor_update.update_id;
let update_res = self.chain_monitor.update_channel(prev_hop.outpoint, monitor_update);
let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-peer_state, chan);
+peer_state, per_peer_state, chan);
if let Err(e) = res {
// TODO: This is a *critical* error - we probably updated the outbound edge
// of the HTLC's monitor with a preimage. We should retry this monitor
@@ -4207,7 +4209,7 @@ where
if !channel.get().is_awaiting_monitor_update() || channel.get().get_latest_monitor_update_id() != highest_applied_update_id {
return;
}
-handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, channel.get_mut());
+handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, per_peer_state, channel.get_mut());
}
/// Accepts a request to open a channel after a [`Event::OpenChannelRequest`].
@@ -4513,7 +4515,8 @@ where
let monitor_res = self.chain_monitor.watch_channel(monitor.get_funding_txo().0, monitor);
let chan = e.insert(chan);
-let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) });
+let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state,
+per_peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) });
// Note that we reply with the new channel_id in error messages if we gave up on the
// channel, not the temporary_channel_id. This is compatible with ourselves, but the
@@ -4546,7 +4549,7 @@ where
let monitor = try_chan_entry!(self,
chan.get_mut().funding_signed(&msg, best_block, &self.signer_provider, &self.logger), chan);
let update_res = self.chain_monitor.watch_channel(chan.get().get_funding_txo().unwrap(), monitor);
-let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, chan);
+let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, per_peer_state, chan);
if let Err(MsgHandleErrInternal { ref mut shutdown_finish, .. }) = res {
// We weren't able to watch the channel to begin with, so no updates should be made on
// it. Previously, full_stack_target found an (unreachable) panic when the
@@ -4642,7 +4645,7 @@ where
if let Some(monitor_update) = monitor_update_opt {
let update_id = monitor_update.update_id;
let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update);
-break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry);
+break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry);
}
break Ok(());
},
@@ -4834,7 +4837,7 @@ where
let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update);
let update_id = monitor_update.update_id;
handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-peer_state, chan)
+peer_state, per_peer_state, chan)
},
hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id))
}
@@ -4940,12 +4943,11 @@ where
fn internal_revoke_and_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::RevokeAndACK) -> Result<(), MsgHandleErrInternal> {
let (htlcs_to_fail, res) = {
let per_peer_state = self.per_peer_state.read().unwrap();
-let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+let mut peer_state_lock = per_peer_state.get(counterparty_node_id)
.ok_or_else(|| {
debug_assert!(false);
MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
-})?;
-let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+}).map(|mtx| mtx.lock().unwrap())?;
let peer_state = &mut *peer_state_lock;
match peer_state.channel_by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
@@ -4953,8 +4955,8 @@ where
let (htlcs_to_fail, monitor_update) = try_chan_entry!(self, chan.get_mut().revoke_and_ack(&msg, &self.logger), chan);
let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update);
let update_id = monitor_update.update_id;
-let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-peer_state, chan);
+let res = handle_new_monitor_update!(self, update_res, update_id,
+peer_state_lock, peer_state, per_peer_state, chan);
(htlcs_to_fail, res)
},
hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id))
@@ -5211,38 +5213,45 @@ where
let mut has_monitor_update = false;
let mut failed_htlcs = Vec::new();
let mut handle_errors = Vec::new();
-let per_peer_state = self.per_peer_state.read().unwrap();
-for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-'chan_loop: loop {
-let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-let peer_state: &mut PeerState<_> = &mut *peer_state_lock;
-for (channel_id, chan) in peer_state.channel_by_id.iter_mut() {
-let counterparty_node_id = chan.get_counterparty_node_id();
-let funding_txo = chan.get_funding_txo();
-let (monitor_opt, holding_cell_failed_htlcs) =
-chan.maybe_free_holding_cell_htlcs(&self.logger);
-if !holding_cell_failed_htlcs.is_empty() {
-failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id));
-}
-if let Some(monitor_update) = monitor_opt {
-has_monitor_update = true;
-let update_res = self.chain_monitor.update_channel(
-funding_txo.expect("channel is live"), monitor_update);
-let update_id = monitor_update.update_id;
-let channel_id: [u8; 32] = *channel_id;
-let res = handle_new_monitor_update!(self, update_res, update_id,
-peer_state_lock, peer_state, chan, MANUALLY_REMOVING,
-peer_state.channel_by_id.remove(&channel_id));
-if res.is_err() {
-handle_errors.push((counterparty_node_id, res));
+// Walk our list of channels and find any that need to update. Note that when we do find an
+// update, if it includes actions that must be taken afterwards, we have to drop the
+// per-peer state lock as well as the top level per_peer_state lock. Thus, we loop until we
+// manage to go through all our peers without finding a single channel to update.
+'peer_loop: loop {
+let per_peer_state = self.per_peer_state.read().unwrap();
+for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
+'chan_loop: loop {
+let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+let peer_state: &mut PeerState<_> = &mut *peer_state_lock;
+for (channel_id, chan) in peer_state.channel_by_id.iter_mut() {
+let counterparty_node_id = chan.get_counterparty_node_id();
+let funding_txo = chan.get_funding_txo();
+let (monitor_opt, holding_cell_failed_htlcs) =
+chan.maybe_free_holding_cell_htlcs(&self.logger);
+if !holding_cell_failed_htlcs.is_empty() {
+failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id));
+}
+if let Some(monitor_update) = monitor_opt {
+has_monitor_update = true;
+let update_res = self.chain_monitor.update_channel(
+funding_txo.expect("channel is live"), monitor_update);
+let update_id = monitor_update.update_id;
+let channel_id: [u8; 32] = *channel_id;
+let res = handle_new_monitor_update!(self, update_res, update_id,
+peer_state_lock, peer_state, per_peer_state, chan, MANUALLY_REMOVING,
+peer_state.channel_by_id.remove(&channel_id));
+if res.is_err() {
+handle_errors.push((counterparty_node_id, res));
+}
+continue 'peer_loop;
}
-continue 'chan_loop;
}
+break 'chan_loop;
}
-break 'chan_loop;
}
+break 'peer_loop;
}
let has_update = has_monitor_update || !failed_htlcs.is_empty() || !handle_errors.is_empty();
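
The `'peer_loop` above follows a reusable shape: re-acquire the locks fresh on each pass, hand off at most one piece of work per pass with every guard released, and stop once a full pass finds nothing left to do. A standalone sketch of that shape, with hypothetical names (`drain_pending`, `handle_item`) and simplified types rather than LDK's actual structures:

    use std::sync::{Mutex, RwLock};

    // Re-acquire the top-level read lock on every pass, release all guards
    // before running follow-up actions, and stop once a full pass over all
    // peers finds no pending work.
    fn drain_pending(state: &RwLock<Vec<Mutex<Vec<u32>>>>) {
        loop {
            let mut found = None;
            {
                let peers = state.read().unwrap(); // top-level read lock
                for peer in peers.iter() {
                    let mut queue = peer.lock().unwrap(); // per-peer lock
                    if let Some(item) = queue.pop() {
                        found = Some(item);
                        break;
                    }
                }
            } // both guards dropped here, before any follow-up work
            match found {
                Some(item) => handle_item(item), // runs with no locks held
                None => break,
            }
        }
    }

    fn handle_item(item: u32) {
        println!("handled {item}");
    }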