Update tokio to 1.0

This requires ensuring TcpStreams are set in nonblocking mode as
tokio doesn't handle this for us anymore, so we adapt the public
API to just accept std TcpStreams instead of an extra conversion
hop. Luckily converting them is cheap.
This commit is contained in:
Matt Corallo 2021-01-26 15:38:19 -05:00
parent f151c02975
commit 5a403bdb13
2 changed files with 23 additions and 14 deletions

View File

@ -12,7 +12,7 @@ For Rust-Lightning clients which wish to make direct connections to Lightning P2
[dependencies]
bitcoin = "0.24"
lightning = { version = "0.0.12", path = "../lightning" }
tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "sync", "tcp", "time" ] }
tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "sync", "net", "time" ] }
[dev-dependencies]
tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "rt-threaded", "sync", "tcp", "time" ] }
tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "rt-multi-thread", "sync", "net", "time" ] }

View File

@ -24,7 +24,7 @@
//! The call site should, thus, look something like this:
//! ```
//! use tokio::sync::mpsc;
//! use tokio::net::TcpStream;
//! use std::net::TcpStream;
//! use bitcoin::secp256k1::key::PublicKey;
//! use lightning::util::events::EventsProvider;
//! use std::net::SocketAddr;
@ -86,6 +86,7 @@ use lightning::util::logger::Logger;
use std::{task, thread};
use std::net::SocketAddr;
use std::net::TcpStream as StdTcpStream;
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
@ -218,7 +219,7 @@ impl Connection {
}
}
fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
fn new(event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
// We only ever need a channel of depth 1 here: if we returned a non-full write to the
// PeerManager, we will eventually get notified that there is room in the socket to write
// new bytes, which will generate an event. That event will be popped off the queue before
@ -229,7 +230,8 @@ impl Connection {
// we shove a value into the channel which comes after we've reset the read_paused bool to
// false.
let (read_waker, read_receiver) = mpsc::channel(1);
let (reader, writer) = io::split(stream);
stream.set_nonblocking(true).unwrap();
let (reader, writer) = io::split(TcpStream::from_std(stream).unwrap());
(reader, write_receiver, read_receiver,
Arc::new(Mutex::new(Self {
@ -248,7 +250,7 @@ impl Connection {
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> where
pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
@ -290,7 +292,7 @@ pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<So
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> where
pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
@ -366,7 +368,7 @@ pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerM
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
} else { None }
}
@ -388,7 +390,7 @@ fn wake_socket_waker(orig_ptr: *const ()) {
}
fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
let mut sender = unsafe { (*sender_ptr).clone() };
let sender = unsafe { (*sender_ptr).clone() };
let _ = sender.try_send(());
}
fn drop_socket_waker(orig_ptr: *const ()) {
@ -512,6 +514,7 @@ mod tests {
use tokio::sync::mpsc;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
@ -526,6 +529,7 @@ mod tests {
expected_pubkey: PublicKey,
pubkey_connected: mpsc::Sender<()>,
pubkey_disconnected: mpsc::Sender<()>,
disconnected_flag: AtomicBool,
msg_events: Mutex<Vec<MessageSendEvent>>,
}
impl RoutingMessageHandler for MsgHandler {
@ -559,6 +563,7 @@ mod tests {
fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
if *their_node_id == self.expected_pubkey {
self.disconnected_flag.store(true, Ordering::SeqCst);
self.pubkey_disconnected.clone().try_send(()).unwrap();
}
}
@ -591,6 +596,7 @@ mod tests {
expected_pubkey: b_pub,
pubkey_connected: a_connected_sender,
pubkey_disconnected: a_disconnected_sender,
disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let a_manager = Arc::new(PeerManager::new(MessageHandler {
@ -604,6 +610,7 @@ mod tests {
expected_pubkey: a_pub,
pubkey_connected: b_connected_sender,
pubkey_disconnected: b_disconnected_sender,
disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let b_manager = Arc::new(PeerManager::new(MessageHandler {
@ -624,8 +631,8 @@ mod tests {
} else { panic!("Failed to bind to v4 localhost on common ports"); };
let (sender, _receiver) = mpsc::channel(2);
let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, tokio::net::TcpStream::from_std(conn_a).unwrap());
let fut_b = super::setup_inbound(b_manager, sender, tokio::net::TcpStream::from_std(conn_b).unwrap());
let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, conn_a);
let fut_b = super::setup_inbound(b_manager, sender, conn_b);
tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();
@ -633,18 +640,20 @@ mod tests {
a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None }
});
assert!(a_disconnected.try_recv().is_err());
assert!(b_disconnected.try_recv().is_err());
assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));
a_manager.process_events();
tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));
fut_a.await;
fut_b.await;
}
#[tokio::test(threaded_scheduler)]
#[tokio::test(flavor = "multi_thread")]
async fn basic_threaded_connection_test() {
do_basic_connection_test().await;
}