Merge pull request #2090 from TheBlueMatt/2023-03-mon-wake-bp

Wake background-processor on async monitor update completion
Matt Corallo 2023-04-03 20:12:46 +00:00 committed by GitHub
commit 3b8bf93251
7 changed files with 292 additions and 212 deletions

lightning-background-processor/src/lib.rs

@@ -38,6 +38,8 @@ use lightning::routing::router::Router;
use lightning::routing::scoring::{Score, WriteableScore};
use lightning::util::logger::Logger;
use lightning::util::persist::Persister;
#[cfg(feature = "std")]
use lightning::util::wakers::Sleeper;
use lightning_rapid_gossip_sync::RapidGossipSync;
use core::ops::Deref;
@@ -114,6 +116,13 @@ const FIRST_NETWORK_PRUNE_TIMER: u64 = 60;
#[cfg(test)]
const FIRST_NETWORK_PRUNE_TIMER: u64 = 1;
#[cfg(feature = "futures")]
/// core::cmp::min is not currently const, so we define a trivial (and equivalent) replacement
const fn min_u64(a: u64, b: u64) -> u64 { if a < b { a } else { b } }
#[cfg(feature = "futures")]
const FASTEST_TIMER: u64 = min_u64(min_u64(FRESHNESS_TIMER, PING_TIMER),
min_u64(SCORER_PERSIST_TIMER, FIRST_NETWORK_PRUNE_TIMER));
/// Either [`P2PGossipSync`] or [`RapidGossipSync`].
pub enum GossipSync<
P: Deref<Target = P2PGossipSync<G, U, L>>,
@@ -256,7 +265,8 @@ macro_rules! define_run_body {
($persister: ident, $chain_monitor: ident, $process_chain_monitor_events: expr,
$channel_manager: ident, $process_channel_manager_events: expr,
$gossip_sync: ident, $peer_manager: ident, $logger: ident, $scorer: ident,
$loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr)
$loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr,
$check_slow_await: expr)
=> { {
log_trace!($logger, "Calling ChannelManager's timer_tick_occurred on startup");
$channel_manager.timer_tick_occurred();
@@ -286,9 +296,10 @@ macro_rules! define_run_body {
// We wait up to 100ms, but track how long it takes to detect being put to sleep,
// see `await_start`'s use below.
let mut await_start = $get_timer(1);
let mut await_start = None;
if $check_slow_await { await_start = Some($get_timer(1)); }
let updates_available = $await;
let await_slow = $timer_elapsed(&mut await_start, 1);
let await_slow = if $check_slow_await { $timer_elapsed(&mut await_start.unwrap(), 1) } else { false };
if updates_available {
log_trace!($logger, "Persisting ChannelManager...");
@@ -388,15 +399,20 @@ pub(crate) mod futures_util {
use core::task::{Poll, Waker, RawWaker, RawWakerVTable};
use core::pin::Pin;
use core::marker::Unpin;
pub(crate) struct Selector<A: Future<Output=()> + Unpin, B: Future<Output=bool> + Unpin> {
pub(crate) struct Selector<
A: Future<Output=()> + Unpin, B: Future<Output=()> + Unpin, C: Future<Output=bool> + Unpin
> {
pub a: A,
pub b: B,
pub c: C,
}
pub(crate) enum SelectorOutput {
A, B(bool),
A, B, C(bool),
}
impl<A: Future<Output=()> + Unpin, B: Future<Output=bool> + Unpin> Future for Selector<A, B> {
impl<
A: Future<Output=()> + Unpin, B: Future<Output=()> + Unpin, C: Future<Output=bool> + Unpin
> Future for Selector<A, B, C> {
type Output = SelectorOutput;
fn poll(mut self: Pin<&mut Self>, ctx: &mut core::task::Context<'_>) -> Poll<SelectorOutput> {
match Pin::new(&mut self.a).poll(ctx) {
@@ -404,7 +420,11 @@ pub(crate) mod futures_util {
Poll::Pending => {},
}
match Pin::new(&mut self.b).poll(ctx) {
Poll::Ready(res) => { return Poll::Ready(SelectorOutput::B(res)); },
Poll::Ready(()) => { return Poll::Ready(SelectorOutput::B); },
Poll::Pending => {},
}
match Pin::new(&mut self.c).poll(ctx) {
Poll::Ready(res) => { return Poll::Ready(SelectorOutput::C(res)); },
Poll::Pending => {},
}
Poll::Pending
@@ -438,6 +458,11 @@ use core::task;
/// feature, doing so will skip calling [`NetworkGraph::remove_stale_channels_and_tracking`],
/// you should call [`NetworkGraph::remove_stale_channels_and_tracking_with_time`] regularly
/// manually instead.
///
/// The `mobile_interruptable_platform` flag should be set if we're currently running on a
/// mobile device, where we may need to check for interruption of the application regularly. If you
/// are unsure, you should set the flag, as the performance impact of it is minimal unless there
/// are hundreds or thousands of simultaneous process calls running.
#[cfg(feature = "futures")]
pub async fn process_events_async<
'a,
@@ -473,7 +498,7 @@ pub async fn process_events_async<
>(
persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM,
gossip_sync: GossipSync<PGS, RGS, G, UL, L>, peer_manager: PM, logger: L, scorer: Option<S>,
sleeper: Sleeper,
sleeper: Sleeper, mobile_interruptable_platform: bool,
) -> Result<(), lightning::io::Error>
where
UL::Target: 'static + UtxoLookup,
@@ -514,11 +539,13 @@ where
gossip_sync, peer_manager, logger, scorer, should_break, {
let fut = Selector {
a: channel_manager.get_persistable_update_future(),
b: sleeper(Duration::from_millis(100)),
b: chain_monitor.get_update_future(),
c: sleeper(if mobile_interruptable_platform { Duration::from_millis(100) } else { Duration::from_secs(FASTEST_TIMER) }),
};
match fut.await {
SelectorOutput::A => true,
SelectorOutput::B(exit) => {
SelectorOutput::B => false,
SelectorOutput::C(exit) => {
should_break = exit;
false
}
@@ -528,7 +555,7 @@ where
let mut waker = dummy_waker();
let mut ctx = task::Context::from_waker(&mut waker);
core::pin::Pin::new(fut).poll(&mut ctx).is_ready()
})
}, mobile_interruptable_platform)
}
#[cfg(feature = "std")]
@@ -643,8 +670,11 @@ impl BackgroundProcessor {
define_run_body!(persister, chain_monitor, chain_monitor.process_pending_events(&event_handler),
channel_manager, channel_manager.process_pending_events(&event_handler),
gossip_sync, peer_manager, logger, scorer, stop_thread.load(Ordering::Acquire),
channel_manager.await_persistable_update_timeout(Duration::from_millis(100)),
|_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur)
Sleeper::from_two_futures(
channel_manager.get_persistable_update_future(),
chain_monitor.get_update_future()
).wait_timeout(Duration::from_millis(100)),
|_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur, false)
});
Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) }
}

lightning-net-tokio/src/lib.rs

@@ -8,64 +8,19 @@
// licenses.
//! A socket handling library for those running in Tokio environments who wish to use
//! rust-lightning with native TcpStreams.
//! rust-lightning with native [`TcpStream`]s.
//!
//! Designed to be as simple as possible, the high-level usage is almost as simple as "hand over a
//! TcpStream and a reference to a PeerManager and the rest is handled", except for the
//! [Event](../lightning/util/events/enum.Event.html) handling mechanism; see example below.
//! [`TcpStream`] and a reference to a [`PeerManager`] and the rest is handled".
//!
//! The PeerHandler, due to the fire-and-forget nature of this logic, must be an Arc, and must use
//! the SocketDescriptor provided here as the PeerHandler's SocketDescriptor.
//! The [`PeerManager`], due to the fire-and-forget nature of this logic, must be a reference
//! (e.g. an [`Arc`]) and must use the [`SocketDescriptor`] provided here as the [`PeerManager`]'s
//! `SocketDescriptor` implementation.
//!
//! Three methods are exposed to register a new connection for handling in tokio::spawn calls; see
//! their individual docs for details.
//! Three methods are exposed to register a new connection for handling in [`tokio::spawn`] calls;
//! see their individual docs for details.
//!
//! # Example
//! ```
//! use std::net::TcpStream;
//! use bitcoin::secp256k1::PublicKey;
//! use lightning::events::{Event, EventHandler, EventsProvider};
//! use std::net::SocketAddr;
//! use std::sync::Arc;
//!
//! // Define concrete types for our high-level objects:
//! type TxBroadcaster = dyn lightning::chain::chaininterface::BroadcasterInterface + Send + Sync;
//! type FeeEstimator = dyn lightning::chain::chaininterface::FeeEstimator + Send + Sync;
//! type Logger = dyn lightning::util::logger::Logger + Send + Sync;
//! type NodeSigner = dyn lightning::chain::keysinterface::NodeSigner + Send + Sync;
//! type UtxoLookup = dyn lightning::routing::utxo::UtxoLookup + Send + Sync;
//! type ChainFilter = dyn lightning::chain::Filter + Send + Sync;
//! type DataPersister = dyn lightning::chain::chainmonitor::Persist<lightning::chain::keysinterface::InMemorySigner> + Send + Sync;
//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::chain::keysinterface::InMemorySigner, Arc<ChainFilter>, Arc<TxBroadcaster>, Arc<FeeEstimator>, Arc<Logger>, Arc<DataPersister>>;
//! type ChannelManager = Arc<lightning::ln::channelmanager::SimpleArcChannelManager<ChainMonitor, TxBroadcaster, FeeEstimator, Logger>>;
//! type PeerManager = Arc<lightning::ln::peer_handler::SimpleArcPeerManager<lightning_net_tokio::SocketDescriptor, ChainMonitor, TxBroadcaster, FeeEstimator, UtxoLookup, Logger>>;
//!
//! // Connect to node with pubkey their_node_id at addr:
//! async fn connect_to_node(peer_manager: PeerManager, chain_monitor: Arc<ChainMonitor>, channel_manager: ChannelManager, their_node_id: PublicKey, addr: SocketAddr) {
//! lightning_net_tokio::connect_outbound(peer_manager, their_node_id, addr).await;
//! loop {
//! let event_handler = |event: Event| {
//! // Handle the event!
//! };
//! channel_manager.await_persistable_update();
//! channel_manager.process_pending_events(&event_handler);
//! chain_monitor.process_pending_events(&event_handler);
//! }
//! }
//!
//! // Begin reading from a newly accepted socket and talk to the peer:
//! async fn accept_socket(peer_manager: PeerManager, chain_monitor: Arc<ChainMonitor>, channel_manager: ChannelManager, socket: TcpStream) {
//! lightning_net_tokio::setup_inbound(peer_manager, socket);
//! loop {
//! let event_handler = |event: Event| {
//! // Handle the event!
//! };
//! channel_manager.await_persistable_update();
//! channel_manager.process_pending_events(&event_handler);
//! chain_monitor.process_pending_events(&event_handler);
//! }
//! }
//! ```
//! [`PeerManager`]: lightning::ln::peer_handler::PeerManager
// Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings.
#![deny(broken_intra_doc_links)]
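With the long doc example removed, a compressed sketch of the remaining contract may help (the `peer_manager`, `their_node_id`, and `addr` bindings are assumptions, not part of this diff):

```rust
// One of the three registration methods; if the TCP connection succeeds, the
// returned future resolves once the peer disconnects.
tokio::spawn(async move {
	if let Some(connection_closed) =
		lightning_net_tokio::connect_outbound(peer_manager, their_node_id, addr).await
	{
		connection_closed.await;
	}
});
```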

lightning/src/chain/chainmonitor.rs

@@ -37,6 +37,7 @@ use crate::events::{Event, EventHandler};
use crate::util::atomic_counter::AtomicCounter;
use crate::util::logger::Logger;
use crate::util::errors::APIError;
use crate::util::wakers::{Future, Notifier};
use crate::ln::channelmanager::ChannelDetails;
use crate::prelude::*;
@@ -240,6 +241,8 @@ pub struct ChainMonitor<ChannelSigner: WriteableEcdsaChannelSigner, C: Deref, T:
pending_monitor_events: Mutex<Vec<(OutPoint, Vec<MonitorEvent>, Option<PublicKey>)>>,
/// The best block height seen, used as a proxy for the passage of time.
highest_chain_height: AtomicUsize,
event_notifier: Notifier,
}
impl<ChannelSigner: WriteableEcdsaChannelSigner, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref> ChainMonitor<ChannelSigner, C, T, F, L, P>
@@ -300,6 +303,7 @@ where C::Target: chain::Filter,
ChannelMonitorUpdateStatus::PermanentFailure => {
monitor_state.channel_perm_failed.store(true, Ordering::Release);
self.pending_monitor_events.lock().unwrap().push((*funding_outpoint, vec![MonitorEvent::UpdateFailed(*funding_outpoint)], monitor.get_counterparty_node_id()));
self.event_notifier.notify();
},
ChannelMonitorUpdateStatus::InProgress => {
log_debug!(self.logger, "Channel Monitor sync for channel {} in progress, holding events until completion!", log_funding_info!(monitor));
@@ -345,6 +349,7 @@ where C::Target: chain::Filter,
persister,
pending_monitor_events: Mutex::new(Vec::new()),
highest_chain_height: AtomicUsize::new(0),
event_notifier: Notifier::new(),
}
}
@@ -472,6 +477,7 @@ where C::Target: chain::Filter,
}
},
}
self.event_notifier.notify();
Ok(())
}
@@ -486,6 +492,7 @@ where C::Target: chain::Filter,
funding_txo,
monitor_update_id,
}], counterparty_node_id));
self.event_notifier.notify();
}
#[cfg(any(test, fuzzing, feature = "_test_utils"))]
@@ -514,6 +521,18 @@ where C::Target: chain::Filter,
handler(event).await;
}
}
/// Gets a [`Future`] that completes when an event is available either via
/// [`chain::Watch::release_pending_monitor_events`] or
/// [`EventsProvider::process_pending_events`].
///
/// Note that callbacks registered on the [`Future`] MUST NOT call back into this
/// [`ChainMonitor`] and should instead register actions to be taken later.
///
/// [`EventsProvider::process_pending_events`]: crate::events::EventsProvider::process_pending_events
pub fn get_update_future(&self) -> Future {
self.event_notifier.get_future()
}
}
impl<ChannelSigner: WriteableEcdsaChannelSigner, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref>
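A sketch of how a caller might consume the new future without an async runtime; the `wake_tx` channel and the surrounding loop are assumptions, not part of this diff:

```rust
// Nudge our own processing loop whenever the ChainMonitor has pending work.
// Per the doc above, the callback must not call back into the ChainMonitor.
let (wake_tx, wake_rx) = std::sync::mpsc::channel::<()>();
chain_monitor.get_update_future().register_callback_fn(move || {
	let _ = wake_tx.send(()); // just schedule; do no real work here
});
// Elsewhere: block on `wake_rx.recv()`, process events, then re-arm by
// fetching a fresh future, since each `Future` completes at most once.
```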

lightning/src/ln/channelmanager.rs

@@ -6170,34 +6170,11 @@ where
}
}
/// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool
/// indicating whether persistence is necessary. Only one listener on
/// [`await_persistable_update`], [`await_persistable_update_timeout`], or a future returned by
/// [`get_persistable_update_future`] is guaranteed to be woken up.
/// Gets a [`Future`] that completes when this [`ChannelManager`] needs to be persisted.
///
/// Note that this method is not available with the `no-std` feature.
/// Note that callbacks registered on the [`Future`] MUST NOT call back into this
/// [`ChannelManager`] and should instead register actions to be taken later.
///
/// [`await_persistable_update`]: Self::await_persistable_update
/// [`await_persistable_update_timeout`]: Self::await_persistable_update_timeout
/// [`get_persistable_update_future`]: Self::get_persistable_update_future
#[cfg(any(test, feature = "std"))]
pub fn await_persistable_update_timeout(&self, max_wait: Duration) -> bool {
self.persistence_notifier.wait_timeout(max_wait)
}
/// Blocks until ChannelManager needs to be persisted. Only one listener on
/// [`await_persistable_update`], `await_persistable_update_timeout`, or a future returned by
/// [`get_persistable_update_future`] is guaranteed to be woken up.
///
/// [`await_persistable_update`]: Self::await_persistable_update
/// [`get_persistable_update_future`]: Self::get_persistable_update_future
pub fn await_persistable_update(&self) {
self.persistence_notifier.wait()
}
/// Gets a [`Future`] that completes when a persistable update is available. Note that
/// callbacks registered on the [`Future`] MUST NOT call back into this [`ChannelManager`] and
/// should instead register actions to be taken later.
pub fn get_persistable_update_future(&self) -> Future {
self.persistence_notifier.get_future()
}
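For async callers, the replacement for the removed blocking methods is to await the future directly; a minimal sketch, with `persist` standing in for caller-provided persistence logic:

```rust
// Wait for the next persistable update, then write the manager out.
loop {
	channel_manager.get_persistable_update_future().await;
	persist(&channel_manager).expect("failed to persist ChannelManager");
}
```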
@@ -7952,6 +7929,7 @@ mod tests {
use bitcoin::hashes::Hash;
use bitcoin::hashes::sha256::Hash as Sha256;
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
#[cfg(feature = "std")]
use core::time::Duration;
use core::sync::atomic::Ordering;
use crate::events::{Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, ClosureReason};
@@ -7977,9 +7955,9 @@ mod tests {
// All nodes start with a persistable update pending as `create_network` connects each node
// with all other nodes to make most tests simpler.
assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(nodes[1].node.get_persistable_update_future().poll_is_complete());
assert!(nodes[2].node.get_persistable_update_future().poll_is_complete());
let mut chan = create_announced_chan_between_nodes(&nodes, 0, 1);
@@ -7993,19 +7971,19 @@ mod tests {
&nodes[0].node.get_our_node_id()).pop().unwrap();
// The first two nodes (which opened a channel) should now require fresh persistence
assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(nodes[1].node.get_persistable_update_future().poll_is_complete());
// ... but the last node should not.
assert!(!nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[2].node.get_persistable_update_future().poll_is_complete());
// After persisting the first two nodes they should no longer need fresh persistence.
assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete());
// Node 3, unrelated to the only channel, shouldn't care if it receives a channel_update
// about the channel.
nodes[2].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &chan.0);
nodes[2].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &chan.1);
assert!(!nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[2].node.get_persistable_update_future().poll_is_complete());
// The nodes which are a party to the channel should also ignore messages from unrelated
// parties.
@@ -8013,8 +7991,8 @@ mod tests {
nodes[0].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.1);
nodes[1].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.0);
nodes[1].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.1);
assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete());
// At this point the channel info given by peers should still be the same.
assert_eq!(nodes[0].node.list_channels()[0], node_a_chan_info);
@@ -8031,8 +8009,8 @@ mod tests {
// persisted and that its channel info remains the same.
nodes[0].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &as_update);
nodes[1].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &bs_update);
assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete());
assert_eq!(nodes[0].node.list_channels()[0], node_a_chan_info);
assert_eq!(nodes[1].node.list_channels()[0], node_b_chan_info);
@@ -8040,8 +8018,8 @@ mod tests {
// the channel info has updated.
nodes[0].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &bs_update);
nodes[1].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &as_update);
assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
assert!(nodes[0].node.get_persistable_update_future().poll_is_complete());
assert!(nodes[1].node.get_persistable_update_future().poll_is_complete());
assert_ne!(nodes[0].node.list_channels()[0], node_a_chan_info);
assert_ne!(nodes[1].node.list_channels()[0], node_b_chan_info);
}

lightning/src/sync/debug_sync.rs

@@ -12,6 +12,8 @@ use std::sync::RwLockReadGuard as StdRwLockReadGuard;
use std::sync::RwLockWriteGuard as StdRwLockWriteGuard;
use std::sync::Condvar as StdCondvar;
pub use std::sync::WaitTimeoutResult;
use crate::prelude::HashMap;
use super::{LockTestExt, LockHeldState};
@@ -35,15 +37,19 @@ impl Condvar {
Condvar { inner: StdCondvar::new() }
}
pub fn wait<'a, T>(&'a self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
pub fn wait_while<'a, T, F: FnMut(&mut T) -> bool>(&'a self, guard: MutexGuard<'a, T>, condition: F)
-> LockResult<MutexGuard<'a, T>> {
let mutex: &'a Mutex<T> = guard.mutex;
self.inner.wait(guard.into_inner()).map(|lock| MutexGuard { mutex, lock }).map_err(|_| ())
self.inner.wait_while(guard.into_inner(), condition).map(|lock| MutexGuard { mutex, lock })
.map_err(|_| ())
}
#[allow(unused)]
pub fn wait_timeout<'a, T>(&'a self, guard: MutexGuard<'a, T>, dur: Duration) -> LockResult<(MutexGuard<'a, T>, ())> {
pub fn wait_timeout_while<'a, T, F: FnMut(&mut T) -> bool>(&'a self, guard: MutexGuard<'a, T>, dur: Duration, condition: F)
-> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> {
let mutex = guard.mutex;
self.inner.wait_timeout(guard.into_inner(), dur).map(|(lock, _)| (MutexGuard { mutex, lock }, ())).map_err(|_| ())
self.inner.wait_timeout_while(guard.into_inner(), dur, condition).map_err(|_| ())
.map(|(lock, e)| (MutexGuard { mutex, lock }, e))
}
pub fn notify_all(&self) { self.inner.notify_all(); }
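These wrappers now mirror `std::sync::Condvar`'s `_while` variants. For reference, a sketch of the std semantics being wrapped (illustrative values only):

```rust
use std::sync::{Condvar, Mutex};
use std::time::Duration;

// Sleep while the predicate holds, for at most 50ms; `timed_out()` reports
// whether we gave up with the predicate still true.
let ready = Mutex::new(false);
let cv = Condvar::new();
let guard = ready.lock().unwrap();
let (_guard, res) = cv
	.wait_timeout_while(guard, Duration::from_millis(50), |r| !*r)
	.unwrap();
assert!(res.timed_out()); // nothing ever set the flag
```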

lightning/src/sync/nostd_sync.rs

@@ -1,30 +1,10 @@
pub use ::alloc::sync::Arc;
use core::ops::{Deref, DerefMut};
use core::time::Duration;
use core::cell::{RefCell, Ref, RefMut};
use super::{LockTestExt, LockHeldState};
pub type LockResult<Guard> = Result<Guard, ()>;
pub struct Condvar {}
impl Condvar {
pub fn new() -> Condvar {
Condvar { }
}
pub fn wait<'a, T>(&'a self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
Ok(guard)
}
#[allow(unused)]
pub fn wait_timeout<'a, T>(&'a self, guard: MutexGuard<'a, T>, _dur: Duration) -> LockResult<(MutexGuard<'a, T>, ())> {
Ok((guard, ()))
}
pub fn notify_all(&self) {}
}
pub struct Mutex<T: ?Sized> {
inner: RefCell<T>
}

lightning/src/util/wakers.rs

@@ -15,12 +15,14 @@
use alloc::sync::Arc;
use core::mem;
use crate::sync::{Condvar, Mutex, MutexGuard};
use crate::sync::Mutex;
use crate::prelude::*;
#[cfg(any(test, feature = "std"))]
use std::time::{Duration, Instant};
#[cfg(feature = "std")]
use crate::sync::Condvar;
#[cfg(feature = "std")]
use std::time::Duration;
use core::future::Future as StdFuture;
use core::task::{Context, Poll};
@@ -30,74 +32,12 @@ use core::pin::Pin;
/// Used to signal to one of many waiters that the condition they're waiting on has happened.
pub(crate) struct Notifier {
notify_pending: Mutex<(bool, Option<Arc<Mutex<FutureState>>>)>,
condvar: Condvar,
}
macro_rules! check_woken {
($guard: expr, $retval: expr) => { {
if $guard.0 {
$guard.0 = false;
if $guard.1.as_ref().map(|l| l.lock().unwrap().complete).unwrap_or(false) {
// If we're about to return as woken, and the future state is marked complete, wipe
// the future state and let the next future wait until we get a new notify.
$guard.1.take();
}
return $retval;
}
} }
}
impl Notifier {
pub(crate) fn new() -> Self {
Self {
notify_pending: Mutex::new((false, None)),
condvar: Condvar::new(),
}
}
fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option<Arc<Mutex<FutureState>>>)> {
let mut lock = self.notify_pending.lock().unwrap();
if let Some(existing_state) = &lock.1 {
if existing_state.lock().unwrap().callbacks_made {
// If the existing `FutureState` has completed and actually made callbacks,
// consider the notification flag to have been cleared and reset the future state.
lock.1.take();
lock.0 = false;
}
}
lock
}
pub(crate) fn wait(&self) {
loop {
let mut guard = self.propagate_future_state_to_notify_flag();
check_woken!(guard, ());
guard = self.condvar.wait(guard).unwrap();
check_woken!(guard, ());
}
}
#[cfg(any(test, feature = "std"))]
pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
let current_time = Instant::now();
loop {
let mut guard = self.propagate_future_state_to_notify_flag();
check_woken!(guard, true);
guard = self.condvar.wait_timeout(guard, max_wait).unwrap().0;
check_woken!(guard, true);
// Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
// desired wait time has actually passed, and if not then restart the loop with a reduced wait
// time. Note that this logic can be highly simplified through the use of
// `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
// 1.42.0.
let elapsed = current_time.elapsed();
if elapsed >= max_wait {
return false;
}
match max_wait.checked_sub(elapsed) {
None => return false,
Some(_) => continue
}
}
}
@@ -111,13 +51,19 @@ impl Notifier {
}
}
lock.0 = true;
mem::drop(lock);
self.condvar.notify_all();
}
/// Gets a [`Future`] that will be woken up along with any blocked waiters
pub(crate) fn get_future(&self) -> Future {
let mut lock = self.propagate_future_state_to_notify_flag();
let mut lock = self.notify_pending.lock().unwrap();
if let Some(existing_state) = &lock.1 {
if existing_state.lock().unwrap().callbacks_made {
// If the existing `FutureState` has completed and actually made callbacks,
// consider the notification flag to have been cleared and reset the future state.
lock.1.take();
lock.0 = false;
}
}
if let Some(existing_state) = &lock.1 {
Future { state: Arc::clone(&existing_state) }
} else {
@@ -137,6 +83,7 @@ impl Notifier {
}
}
macro_rules! define_callback { ($($bounds: path),*) => {
/// A callback which is called when a [`Future`] completes.
///
/// Note that this MUST NOT call back into LDK directly, it must instead schedule actions to be
@@ -145,14 +92,20 @@ impl Notifier {
///
/// Note that the [`std::future::Future`] implementation may only work for runtimes which schedule
/// futures when they receive a wake, rather than immediately executing them.
pub trait FutureCallback : Send {
pub trait FutureCallback : $($bounds +)* {
/// The method which is called.
fn call(&self);
}
impl<F: Fn() + Send> FutureCallback for F {
impl<F: Fn() $(+ $bounds)*> FutureCallback for F {
fn call(&self) { (self)(); }
}
} }
#[cfg(feature = "std")]
define_callback!(Send);
#[cfg(not(feature = "std"))]
define_callback!();
pub(crate) struct FutureState {
// When we're tracking whether a callback counts as having woken the user's code, we check the
@@ -175,6 +128,9 @@ impl FutureState {
}
/// A simple future which can complete once, and calls some callback(s) when it does so.
///
/// Clones can be made and all futures cloned from the same source will complete at the same time.
#[derive(Clone)]
pub struct Future {
state: Arc<Mutex<FutureState>>,
}
@@ -204,6 +160,29 @@ impl Future {
pub fn register_callback_fn<F: 'static + FutureCallback>(&self, callback: F) {
self.register_callback(Box::new(callback));
}
/// Waits until this [`Future`] completes.
#[cfg(feature = "std")]
pub fn wait(self) {
Sleeper::from_single_future(self).wait();
}
/// Waits until this [`Future`] completes or the given amount of time has elapsed.
///
/// Returns true if the [`Future`] completed, false if the time elapsed.
#[cfg(feature = "std")]
pub fn wait_timeout(self, max_wait: Duration) -> bool {
Sleeper::from_single_future(self).wait_timeout(max_wait)
}
#[cfg(test)]
pub fn poll_is_complete(&self) -> bool {
let mut state = self.state.lock().unwrap();
if state.complete {
state.callbacks_made = true;
true
} else { false }
}
}
use core::task::Waker;
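For sync callers, the new std-only helpers stand in for the removed `await_persistable_update*` methods; a sketch assuming an in-scope `channel_manager`:

```rust
use core::time::Duration;

// Block this (non-async) thread until a persistable update arrives or five
// seconds pass; `true` means persistence is needed now.
if channel_manager.get_persistable_update_future().wait_timeout(Duration::from_secs(5)) {
	// persist the ChannelManager here
}
```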
@@ -229,6 +208,78 @@ impl<'a> StdFuture for Future {
}
}
/// A struct which can be used to select across many [`Future`]s at once without relying on a full
/// async context.
#[cfg(feature = "std")]
pub struct Sleeper {
notifiers: Vec<Arc<Mutex<FutureState>>>,
}
#[cfg(feature = "std")]
impl Sleeper {
/// Constructs a new sleeper from one future, allowing blocking on it.
pub fn from_single_future(future: Future) -> Self {
Self { notifiers: vec![future.state] }
}
/// Constructs a new sleeper from two futures, allowing blocking on both at once.
// Note that this is the common case - a ChannelManager and ChainMonitor.
pub fn from_two_futures(fut_a: Future, fut_b: Future) -> Self {
Self { notifiers: vec![fut_a.state, fut_b.state] }
}
/// Constructs a new sleeper from many futures, allowing blocking on all at once.
pub fn new(futures: Vec<Future>) -> Self {
Self { notifiers: futures.into_iter().map(|f| f.state).collect() }
}
/// Prepares to go into a wait loop body, creating a condition variable which we can block on
/// and an `Arc<Mutex<Option<_>>>` which gets set to the waking `Future`'s state prior to the
/// condition variable being woken.
fn setup_wait(&self) -> (Arc<Condvar>, Arc<Mutex<Option<Arc<Mutex<FutureState>>>>>) {
let cv = Arc::new(Condvar::new());
let notified_fut_mtx = Arc::new(Mutex::new(None));
{
for notifier_mtx in self.notifiers.iter() {
let cv_ref = Arc::clone(&cv);
let notified_fut_ref = Arc::clone(&notified_fut_mtx);
let notifier_ref = Arc::clone(&notifier_mtx);
let mut notifier = notifier_mtx.lock().unwrap();
if notifier.complete {
*notified_fut_mtx.lock().unwrap() = Some(notifier_ref);
break;
}
notifier.callbacks.push((false, Box::new(move || {
*notified_fut_ref.lock().unwrap() = Some(Arc::clone(&notifier_ref));
cv_ref.notify_all();
})));
}
}
(cv, notified_fut_mtx)
}
/// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed.
pub fn wait(&self) {
let (cv, notified_fut_mtx) = self.setup_wait();
let notified_fut = cv.wait_while(notified_fut_mtx.lock().unwrap(), |fut_opt| fut_opt.is_none())
.unwrap().take().expect("CV wait shouldn't have returned until the notifying future was set");
notified_fut.lock().unwrap().callbacks_made = true;
}
/// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed or the
/// given amount of time has elapsed. Returns true if a [`Future`] completed, false if the time
/// elapsed.
pub fn wait_timeout(&self, max_wait: Duration) -> bool {
let (cv, notified_fut_mtx) = self.setup_wait();
let notified_fut =
match cv.wait_timeout_while(notified_fut_mtx.lock().unwrap(), max_wait, |fut_opt| fut_opt.is_none()) {
Ok((_, e)) if e.timed_out() => return false,
Ok((mut notified_fut, _)) =>
notified_fut.take().expect("CV wait shouldn't have returned until the notifying future was set"),
Err(_) => panic!("Previous panic while a lock was held led to a lock panic"),
};
notified_fut.lock().unwrap().callbacks_made = true;
true
}
}
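This two-future combination is exactly what `BackgroundProcessor::start` now blocks on (see the `from_two_futures` call in the first file); as a standalone sketch with assumed handles:

```rust
use core::time::Duration;

// Sleep until either the ChannelManager needs persistence or the ChainMonitor
// has events to process, giving up after 100ms; `woken` is false on timeout.
let woken = Sleeper::from_two_futures(
	channel_manager.get_persistable_update_future(),
	chain_monitor.get_update_future(),
).wait_timeout(Duration::from_millis(100));
```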
#[cfg(test)]
mod tests {
use super::*;
@@ -327,10 +378,7 @@ mod tests {
let exit_thread_clone = exit_thread.clone();
thread::spawn(move || {
loop {
let mut lock = thread_notifier.notify_pending.lock().unwrap();
lock.0 = true;
thread_notifier.condvar.notify_all();
thread_notifier.notify();
if exit_thread_clone.load(Ordering::SeqCst) {
break
}
@@ -338,12 +386,12 @@
});
// Check that we can block indefinitely until updates are available.
let _ = persistence_notifier.wait();
let _ = persistence_notifier.get_future().wait();
// Check that the Notifier will return after the given duration if updates are
// available.
loop {
if persistence_notifier.wait_timeout(Duration::from_millis(100)) {
if persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) {
break
}
}
@@ -353,7 +401,7 @@
// Check that the Notifier will return after the given duration even if no updates
// are available.
loop {
if !persistence_notifier.wait_timeout(Duration::from_millis(100)) {
if !persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) {
break
}
}
@@ -443,6 +491,7 @@
}
#[test]
#[cfg(feature = "std")]
fn test_dropped_future_doesnt_count() {
// Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
// having been woken, leaving the notify-required flag set.
@@ -451,8 +500,8 @@
// If we get a future and don't touch it we're definitely still notify-required.
notifier.get_future();
assert!(notifier.wait_timeout(Duration::from_millis(1)));
assert!(!notifier.wait_timeout(Duration::from_millis(1)));
assert!(notifier.get_future().wait_timeout(Duration::from_millis(1)));
assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1)));
// Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
let mut future = notifier.get_future();
@@ -461,7 +510,7 @@
notifier.notify();
assert!(woken.load(Ordering::SeqCst));
assert!(notifier.wait_timeout(Duration::from_millis(1)));
assert!(notifier.get_future().wait_timeout(Duration::from_millis(1)));
// However, once we do poll `Ready` it should wipe the notify-required flag.
let mut future = notifier.get_future();
@@ -471,7 +520,7 @@
notifier.notify();
assert!(woken.load(Ordering::SeqCst));
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
assert!(!notifier.wait_timeout(Duration::from_millis(1)));
assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1)));
}
#[test]
@@ -532,4 +581,67 @@
assert!(woken.load(Ordering::SeqCst));
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
}
#[test]
#[cfg(feature = "std")]
fn test_multi_future_sleep() {
// Tests the `Sleeper` with multiple futures.
let notifier_a = Notifier::new();
let notifier_b = Notifier::new();
// Set both notifiers as woken without sleeping yet.
notifier_a.notify();
notifier_b.notify();
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
// One future has woken us up, but the other should still have a pending notification.
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
// However, once we've slept twice, we should no longer have any pending notifications.
assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future())
.wait_timeout(Duration::from_millis(10)));
// Test ordering somewhat more.
notifier_a.notify();
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
}
#[test]
#[cfg(feature = "std")]
fn sleeper_with_pending_callbacks() {
// This is similar to the above `test_multi_future_sleep` test, but in addition registers
// "normal" callbacks which will cause the futures to assume notification has occurred,
// rather than waiting for a woken sleeper.
let notifier_a = Notifier::new();
let notifier_b = Notifier::new();
// Set both notifiers as woken without sleeping yet.
notifier_a.notify();
notifier_b.notify();
// After sleeping, one future (not guaranteed which one) will have its notification
// bit cleared.
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
// By registering a callback on the futures for both notifiers, one will complete
// immediately, but one will remain tied to the notifier, and will complete once the
// notifier is next woken, which will be considered the completion of the notification.
let callback_a = Arc::new(AtomicBool::new(false));
let callback_b = Arc::new(AtomicBool::new(false));
let callback_a_ref = Arc::clone(&callback_a);
let callback_b_ref = Arc::clone(&callback_b);
notifier_a.get_future().register_callback(Box::new(move || assert!(!callback_a_ref.fetch_or(true, Ordering::SeqCst))));
notifier_b.get_future().register_callback(Box::new(move || assert!(!callback_b_ref.fetch_or(true, Ordering::SeqCst))));
assert!(callback_a.load(Ordering::SeqCst) ^ callback_b.load(Ordering::SeqCst));
// If we now notify both notifiers again, the other callback will fire, completing the
// notification, and we'll be back to one pending notification.
notifier_a.notify();
notifier_b.notify();
assert!(callback_a.load(Ordering::SeqCst) && callback_b.load(Ordering::SeqCst));
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future())
.wait_timeout(Duration::from_millis(10)));
}
}