Fix (and DRY) the conditionals before calling peer_disconnected

If we have a peer that sends a non-`Init` first message, we'll call
`peer_disconnected` without ever having called `peer_connected`
(which has to wait until we receive an `Init` message). This
violates our API guarantees, though it should generally not be an
issue in practice.

Because this bug was repeated in a few places, we also take this
opportunity to DRY up the logic which checks the peer state before
calling `peer_disconnected`.
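
As a rough illustration of the pattern the patch applies — a minimal sketch
only, where `InitFeatures`, `Peer`, and `Handler` are hypothetical stand-ins
rather than the real `PeerManager` internals — the idea is to record whether
the `Init` handshake finished and gate `peer_disconnected` on that:

```rust
// Minimal sketch; these are simplified stand-ins for types in peer_handler.rs.
struct InitFeatures;

struct Peer {
	// Set only once the peer's `Init` message has been processed and the
	// handlers' `peer_connected` methods have been called.
	their_features: Option<InitFeatures>,
}

impl Peer {
	// Mirrors the new helper: the handshake is complete iff features are set.
	fn handshake_complete(&self) -> bool {
		self.their_features.is_some()
	}
}

trait Handler {
	fn peer_disconnected(&self);
}

// Only tell handlers about a disconnection if they were told about the
// connection in the first place; otherwise just close the socket.
fn do_disconnect<H: Handler>(handler: &H, peer: &Peer) {
	if !peer.handshake_complete() {
		// `peer_connected` was never called, so calling `peer_disconnected`
		// would violate the API guarantee described above.
		return;
	}
	handler.peer_disconnected();
}
```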

Found by the new `ChannelManager` assertions and the
`full_stack_target` fuzzer.
Author: Matt Corallo
Date:   2023-02-15 01:13:57 +00:00
Parent: 2edb3f1983
Commit: 3554678e9c

@@ -393,6 +393,12 @@ struct Peer {
 	/// We cache a `NodeId` here to avoid serializing peers' keys every time we forward gossip
 	/// messages in `PeerManager`. Use `Peer::set_their_node_id` to modify this field.
 	their_node_id: Option<(PublicKey, NodeId)>,
+	/// The features provided in the peer's [`msgs::Init`] message.
+	///
+	/// This is set only after we've processed the [`msgs::Init`] message and called relevant
+	/// `peer_connected` handler methods. Thus, this field is set *iff* we've finished our
+	/// handshake and can talk to this peer normally (though use [`Peer::handshake_complete`] to
+	/// check this).
 	their_features: Option<InitFeatures>,
 	their_net_address: Option<NetAddress>,
@@ -424,6 +430,13 @@ struct Peer {
 }
 impl Peer {
+	/// True after we've processed the [`msgs::Init`] message and called relevant `peer_connected`
+	/// handler methods. Thus, this implies we've finished our handshake and can talk to this peer
+	/// normally.
+	fn handshake_complete(&self) -> bool {
+		self.their_features.is_some()
+	}
+
 	/// Returns true if the channel announcements/updates for the given channel should be
 	/// forwarded to this peer.
 	/// If we are sending our routing table to this peer and we have not yet sent channel
@@ -1877,24 +1890,21 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 				// thread can be holding the peer lock if we have the global write
 				// lock).
-				if let Some(mut descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
+				let descriptor_opt = self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+				if let Some(mut descriptor) = descriptor_opt {
 					if let Some(peer_mutex) = peers.remove(&descriptor) {
+						let mut peer = peer_mutex.lock().unwrap();
 						if let Some(msg) = msg {
 							log_trace!(self.logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with message {}",
 									log_pubkey!(node_id),
 									msg.data);
-							let mut peer = peer_mutex.lock().unwrap();
 							self.enqueue_message(&mut *peer, &msg);
 							// This isn't guaranteed to work, but if there is enough free
 							// room in the send buffer, put the error message there...
 							self.do_attempt_write_data(&mut descriptor, &mut *peer, false);
 						} else {
 							log_trace!(self.logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with no message", log_pubkey!(node_id));
 						}
+						self.do_disconnect(descriptor, &*peer, "DisconnectPeer HandleError");
 					}
-					descriptor.disconnect_socket();
-					self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-					self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
 				}
 			}
 		}
@@ -1905,6 +1915,22 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 		self.disconnect_event_internal(descriptor, false);
 	}
+	fn do_disconnect(&self, mut descriptor: Descriptor, peer: &Peer, reason: &'static str) {
+		if !peer.handshake_complete() {
+			log_trace!(self.logger, "Disconnecting peer which hasn't completed handshake due to {}", reason);
+			descriptor.disconnect_socket();
+			return;
+		}
+
+		debug_assert!(peer.their_node_id.is_some());
+		if let Some((node_id, _)) = peer.their_node_id {
+			log_trace!(self.logger, "Disconnecting peer with id {} due to {}", node_id, reason);
+			self.message_handler.chan_handler.peer_disconnected(&node_id, false);
+			self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
+		}
+		descriptor.disconnect_socket();
+	}
+
 	fn disconnect_event_internal(&self, descriptor: &Descriptor, no_connection_possible: bool) {
 		let mut peers = self.peers.write().unwrap();
 		let peer_option = peers.remove(descriptor);
@@ -1916,6 +1942,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 			},
 			Some(peer_lock) => {
 				let peer = peer_lock.lock().unwrap();
+				if !peer.handshake_complete() { return; }
+				debug_assert!(peer.their_node_id.is_some());
 				if let Some((node_id, _)) = peer.their_node_id {
 					log_trace!(self.logger,
 						"Handling disconnection of peer {}, with {}future connection to the peer possible.",
@@ -1937,14 +1965,13 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 	/// peer. Thus, be very careful about reentrancy issues.
 	///
 	/// [`disconnect_socket`]: SocketDescriptor::disconnect_socket
-	pub fn disconnect_by_node_id(&self, node_id: PublicKey, no_connection_possible: bool) {
+	pub fn disconnect_by_node_id(&self, node_id: PublicKey, _no_connection_possible: bool) {
 		let mut peers_lock = self.peers.write().unwrap();
-		if let Some(mut descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
-			log_trace!(self.logger, "Disconnecting peer with id {} due to client request", node_id);
-			peers_lock.remove(&descriptor);
-			self.message_handler.chan_handler.peer_disconnected(&node_id, no_connection_possible);
-			self.message_handler.onion_message_handler.peer_disconnected(&node_id, no_connection_possible);
-			descriptor.disconnect_socket();
+		if let Some(descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
+			let peer_opt = peers_lock.remove(&descriptor);
+			if let Some(peer_mutex) = peer_opt {
+				self.do_disconnect(descriptor, &*peer_mutex.lock().unwrap(), "client request");
+			} else { debug_assert!(false, "node_id_to_descriptor thought we had a peer"); }
 		}
 	}
@@ -1955,13 +1982,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 		let mut peers_lock = self.peers.write().unwrap();
 		self.node_id_to_descriptor.lock().unwrap().clear();
 		let peers = &mut *peers_lock;
-		for (mut descriptor, peer) in peers.drain() {
-			if let Some((node_id, _)) = peer.lock().unwrap().their_node_id {
-				log_trace!(self.logger, "Disconnecting peer with id {} due to client request to disconnect all peers", node_id);
-				self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-				self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
-			}
-			descriptor.disconnect_socket();
+		for (descriptor, peer_mutex) in peers.drain() {
+			self.do_disconnect(descriptor, &*peer_mutex.lock().unwrap(), "client request to disconnect all peers");
 		}
 	}
@@ -2052,21 +2074,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 		if !descriptors_needing_disconnect.is_empty() {
 			{
 				let mut peers_lock = self.peers.write().unwrap();
-				for descriptor in descriptors_needing_disconnect.iter() {
-					if let Some(peer) = peers_lock.remove(descriptor) {
-						if let Some((node_id, _)) = peer.lock().unwrap().their_node_id {
-							log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", node_id);
+				for descriptor in descriptors_needing_disconnect {
+					if let Some(peer_mutex) = peers_lock.remove(&descriptor) {
+						let peer = peer_mutex.lock().unwrap();
+						if let Some((node_id, _)) = peer.their_node_id {
 							self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
-							self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-							self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
 						}
+						self.do_disconnect(descriptor, &*peer, "ping timeout");
 					}
 				}
 			}
-			for mut descriptor in descriptors_needing_disconnect.drain(..) {
-				descriptor.disconnect_socket();
-			}
 		}
 	}