Give Futures for a FutureState an idx and track StdWaker idxn

When an `std::future::Future` is `poll()`ed, we're only supposed to
use the latest `Waker` provided. However, we currently push an
`StdWaker` onto our callback list every time `poll` is called,
waking every `Waker` but also using more and more memory until the
`Future` itself is woken.

Here we take a step towards fixing this by giving each `Future` a
unique index and storing which `Future` an `StdWaker` came from in
the callback list. This sets us up to deduplicate `StdWaker`s by
`Future`s in the next commit.
This commit is contained in:
Matt Corallo 2024-02-13 22:08:55 +00:00
parent 2c987209f9
commit 5f404b9d0a

View file

@ -56,16 +56,22 @@ impl Notifier {
/// Gets a [`Future`] that will get woken up with any waiters /// Gets a [`Future`] that will get woken up with any waiters
pub(crate) fn get_future(&self) -> Future { pub(crate) fn get_future(&self) -> Future {
let mut lock = self.notify_pending.lock().unwrap(); let mut lock = self.notify_pending.lock().unwrap();
let mut self_idx = 0;
if let Some(existing_state) = &lock.1 { if let Some(existing_state) = &lock.1 {
if existing_state.lock().unwrap().callbacks_made { let mut locked = existing_state.lock().unwrap();
if locked.callbacks_made {
// If the existing `FutureState` has completed and actually made callbacks, // If the existing `FutureState` has completed and actually made callbacks,
// consider the notification flag to have been cleared and reset the future state. // consider the notification flag to have been cleared and reset the future state.
mem::drop(locked);
lock.1.take(); lock.1.take();
lock.0 = false; lock.0 = false;
} else {
self_idx = locked.next_idx;
locked.next_idx += 1;
} }
} }
if let Some(existing_state) = &lock.1 { if let Some(existing_state) = &lock.1 {
Future { state: Arc::clone(&existing_state) } Future { state: Arc::clone(&existing_state), self_idx }
} else { } else {
let state = Arc::new(Mutex::new(FutureState { let state = Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(), callbacks: Vec::new(),
@ -73,9 +79,10 @@ impl Notifier {
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: lock.0, complete: lock.0,
callbacks_made: false, callbacks_made: false,
next_idx: 1,
})); }));
lock.1 = Some(Arc::clone(&state)); lock.1 = Some(Arc::clone(&state));
Future { state } Future { state, self_idx: 0 }
} }
} }
@ -115,10 +122,11 @@ pub(crate) struct FutureState {
// we only count it after another `poll()` and the second wakes a `Sleeper` which handles // we only count it after another `poll()` and the second wakes a `Sleeper` which handles
// setting `callbacks_made` itself). // setting `callbacks_made` itself).
callbacks: Vec<Box<dyn FutureCallback>>, callbacks: Vec<Box<dyn FutureCallback>>,
std_future_callbacks: Vec<StdWaker>, std_future_callbacks: Vec<(usize, StdWaker)>,
callbacks_with_state: Vec<Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>>, callbacks_with_state: Vec<Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>>,
complete: bool, complete: bool,
callbacks_made: bool, callbacks_made: bool,
next_idx: usize,
} }
fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool { fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool {
@ -128,7 +136,7 @@ fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool {
callback.call(); callback.call();
state.callbacks_made = true; state.callbacks_made = true;
} }
for waker in state.std_future_callbacks.drain(..) { for (_, waker) in state.std_future_callbacks.drain(..) {
waker.0.wake_by_ref(); waker.0.wake_by_ref();
} }
for callback in state.callbacks_with_state.drain(..) { for callback in state.callbacks_with_state.drain(..) {
@ -139,11 +147,9 @@ fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool {
} }
/// A simple future which can complete once, and calls some callback(s) when it does so. /// A simple future which can complete once, and calls some callback(s) when it does so.
///
/// Clones can be made and all futures cloned from the same source will complete at the same time.
#[derive(Clone)]
pub struct Future { pub struct Future {
state: Arc<Mutex<FutureState>>, state: Arc<Mutex<FutureState>>,
self_idx: usize,
} }
impl Future { impl Future {
@ -210,7 +216,7 @@ impl<'a> StdFuture for Future {
Poll::Ready(()) Poll::Ready(())
} else { } else {
let waker = cx.waker().clone(); let waker = cx.waker().clone();
state.std_future_callbacks.push(StdWaker(waker)); state.std_future_callbacks.push((self.self_idx, StdWaker(waker)));
Poll::Pending Poll::Pending
} }
} }
@ -461,7 +467,9 @@ mod tests {
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: false, complete: false,
callbacks_made: false, callbacks_made: false,
})) next_idx: 1,
})),
self_idx: 0,
}; };
let callback = Arc::new(AtomicBool::new(false)); let callback = Arc::new(AtomicBool::new(false));
let callback_ref = Arc::clone(&callback); let callback_ref = Arc::clone(&callback);
@ -478,10 +486,13 @@ mod tests {
let future = Future { let future = Future {
state: Arc::new(Mutex::new(FutureState { state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(), callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: false, complete: false,
callbacks_made: false, callbacks_made: false,
})) next_idx: 1,
})),
self_idx: 0,
}; };
complete_future(&future.state); complete_future(&future.state);
@ -521,9 +532,11 @@ mod tests {
callbacks_with_state: Vec::new(), callbacks_with_state: Vec::new(),
complete: false, complete: false,
callbacks_made: false, callbacks_made: false,
})) next_idx: 2,
})),
self_idx: 0,
}; };
let mut second_future = Future { state: Arc::clone(&future.state) }; let mut second_future = Future { state: Arc::clone(&future.state), self_idx: 1 };
let (woken, waker) = create_waker(); let (woken, waker) = create_waker();
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);