Split lists of Waker and directly-registered Future callbacks

In the next commit we'll fix a memory leak due to keeping too many
`std::task::Waker` callbacks in `FutureState` from redundant `poll`
calls, but first we need to split handling of `StdWaker`-based
future wake callbacks from normal ones, which we do here.
This commit is contained in:
Matt Corallo 2024-02-13 21:58:46 +00:00
parent 73da722d18
commit 2c987209f9

View file

@ -69,6 +69,7 @@ impl Notifier {
} else { } else {
let state = Arc::new(Mutex::new(FutureState { let state = Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(), callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: lock.0, complete: lock.0,
callbacks_made: false, callbacks_made: false,
@ -109,11 +110,13 @@ define_callback!(Send);
define_callback!(); define_callback!();
pub(crate) struct FutureState { pub(crate) struct FutureState {
// When we're tracking whether a callback counts as having woken the user's code, we check the // `callbacks` count as having woken the users' code (as they go direct to the user), but
// first bool - set to false if we're just calling a Waker, and true if we're calling an actual // `std_future_callbacks` and `callbacks_with_state` do not (as the first just wakes a future,
// user-provided function. // we only count it after another `poll()` and the second wakes a `Sleeper` which handles
callbacks: Vec<(bool, Box<dyn FutureCallback>)>, // setting `callbacks_made` itself).
callbacks_with_state: Vec<(bool, Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>)>, callbacks: Vec<Box<dyn FutureCallback>>,
std_future_callbacks: Vec<StdWaker>,
callbacks_with_state: Vec<Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>>,
complete: bool, complete: bool,
callbacks_made: bool, callbacks_made: bool,
} }
@ -121,13 +124,15 @@ pub(crate) struct FutureState {
fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool { fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool {
let mut state_lock = this.lock().unwrap(); let mut state_lock = this.lock().unwrap();
let state = &mut *state_lock; let state = &mut *state_lock;
for (counts_as_call, callback) in state.callbacks.drain(..) { for callback in state.callbacks.drain(..) {
callback.call(); callback.call();
state.callbacks_made |= counts_as_call; state.callbacks_made = true;
} }
for (counts_as_call, callback) in state.callbacks_with_state.drain(..) { for waker in state.std_future_callbacks.drain(..) {
waker.0.wake_by_ref();
}
for callback in state.callbacks_with_state.drain(..) {
(callback)(this); (callback)(this);
state.callbacks_made |= counts_as_call;
} }
state.complete = true; state.complete = true;
state.callbacks_made state.callbacks_made
@ -153,7 +158,7 @@ impl Future {
mem::drop(state); mem::drop(state);
callback.call(); callback.call();
} else { } else {
state.callbacks.push((true, callback)); state.callbacks.push(callback);
} }
} }
@ -193,9 +198,6 @@ impl Future {
use core::task::Waker; use core::task::Waker;
struct StdWaker(pub Waker); struct StdWaker(pub Waker);
impl FutureCallback for StdWaker {
fn call(&self) { self.0.wake_by_ref() }
}
/// This is not exported to bindings users as Rust Futures aren't usable in language bindings. /// This is not exported to bindings users as Rust Futures aren't usable in language bindings.
impl<'a> StdFuture for Future { impl<'a> StdFuture for Future {
@ -208,7 +210,7 @@ impl<'a> StdFuture for Future {
Poll::Ready(()) Poll::Ready(())
} else { } else {
let waker = cx.waker().clone(); let waker = cx.waker().clone();
state.callbacks.push((false, Box::new(StdWaker(waker)))); state.std_future_callbacks.push(StdWaker(waker));
Poll::Pending Poll::Pending
} }
} }
@ -251,10 +253,10 @@ impl Sleeper {
*notified_fut_mtx.lock().unwrap() = Some(Arc::clone(&notifier_mtx)); *notified_fut_mtx.lock().unwrap() = Some(Arc::clone(&notifier_mtx));
break; break;
} }
notifier.callbacks_with_state.push((false, Box::new(move |notifier_ref| { notifier.callbacks_with_state.push(Box::new(move |notifier_ref| {
*notified_fut_ref.lock().unwrap() = Some(Arc::clone(notifier_ref)); *notified_fut_ref.lock().unwrap() = Some(Arc::clone(notifier_ref));
cv_ref.notify_all(); cv_ref.notify_all();
}))); }));
} }
} }
(cv, notified_fut_mtx) (cv, notified_fut_mtx)
@ -455,6 +457,7 @@ mod tests {
let future = Future { let future = Future {
state: Arc::new(Mutex::new(FutureState { state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(), callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: false, complete: false,
callbacks_made: false, callbacks_made: false,
@ -514,6 +517,7 @@ mod tests {
let mut future = Future { let mut future = Future {
state: Arc::new(Mutex::new(FutureState { state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(), callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: false, complete: false,
callbacks_made: false, callbacks_made: false,