diff --git a/src/thread.rs b/src/thread.rs index a0604e8..e84478f 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -205,7 +205,7 @@ impl Builder { f, self.name, self.stack_size, - Some(scope.data.clone()), + Some(&scope.data), location!(), ) }, @@ -309,7 +309,7 @@ impl fmt::Debug for LocalKey { /// See [`scope`] for more details. #[derive(Debug)] pub struct Scope<'scope, 'env: 'scope> { - data: Arc, + data: ScopeData, scope: PhantomData<&'scope mut &'scope ()>, env: PhantomData<&'env mut &'env ()>, } @@ -329,10 +329,10 @@ where F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T, { let scope = Scope { - data: Arc::new(ScopeData { + data: ScopeData { running_threads: Mutex::default(), main_thread: current(), - }), + }, env: PhantomData, scope: PhantomData, }; @@ -394,7 +394,6 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> { #[derive(Debug)] struct JoinHandleInner<'scope, T> { data: Arc>, - notify: rt::Notify, thread: Thread, } @@ -423,7 +422,7 @@ unsafe fn spawn_internal<'scope, F, T>( f: F, name: Option, stack_size: Option, - scope: Option>, + scope: Option<&'scope ScopeData>, location: Location, ) -> JoinHandleInner<'scope, T> where @@ -435,18 +434,28 @@ where .clone() .map(|scope| (scope.add_running_thread(), scope)); let thread_data = Arc::new(ThreadData::new()); - let notify = rt::Notify::new(true, false); let id = { let name = name.clone(); - let thread_data = thread_data.clone(); + // Hold a weak reference so that if the thread handle gets dropped, we + // don't try to store the result or notify anybody unnecessarily. + let weak_data = Arc::downgrade(&thread_data); + let body: Box = Box::new(move || { rt::execution(|execution| { init_current(execution, name); }); - *thread_data.result.lock().unwrap() = Some(Ok(f())); - notify.notify(location); + // Ensure everything from the spawned thread's execution either gets + // stored in the thread handle or dropped before notifying that the + // thread has completed. + { + let result = f(); + if let Some(thread_data) = weak_data.upgrade() { + *thread_data.result.lock().unwrap() = Some(Ok(result)); + thread_data.notification.notify(location); + } + } if let Some((notifier, scope)) = scope_notify { notifier.notify(location!()); @@ -461,7 +470,6 @@ where JoinHandleInner { data: thread_data, - notify, thread: Thread { id: ThreadId { id }, name, @@ -473,6 +481,7 @@ where #[derive(Debug)] struct ThreadData<'scope, T> { result: Mutex>>, + notification: rt::Notify, _marker: PhantomData>, } @@ -480,6 +489,7 @@ impl<'scope, T> ThreadData<'scope, T> { fn new() -> Self { Self { result: Mutex::new(None), + notification: rt::Notify::new(true, false), _marker: PhantomData, } } @@ -487,7 +497,7 @@ impl<'scope, T> ThreadData<'scope, T> { impl<'scope, T> JoinHandleInner<'scope, T> { fn join(self) -> std::thread::Result { - self.notify.wait(location!()); + self.data.notification.wait(location!()); self.data.result.lock().unwrap().take().unwrap() } diff --git a/tests/thread_api.rs b/tests/thread_api.rs index 720c027..39a6c34 100644 --- a/tests/thread_api.rs +++ b/tests/thread_api.rs @@ -241,3 +241,45 @@ fn scoped_and_unscoped_threads() { assert_eq!(v, 2); }) } + +struct YieldAndIncrementOnDrop<'a>(&'a std::sync::atomic::AtomicUsize); + +impl Drop for YieldAndIncrementOnDrop<'_> { + fn drop(&mut self) { + thread::yield_now(); + self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } +} + +#[test] +fn scoped_thread_wait_until_finished() { + loom::model(|| { + let a = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let r: &std::sync::atomic::AtomicUsize = &a; + thread::scope(|s| { + s.spawn(move || { + r.fetch_add(2, std::sync::atomic::Ordering::SeqCst); + YieldAndIncrementOnDrop(r) + }); + }); + assert_eq!(a.load(std::sync::atomic::Ordering::SeqCst), 3); + }); +} + +#[test] +fn scoped_thread_join_handle_forgotten() { + loom::model(|| { + let a = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let r: &std::sync::atomic::AtomicUsize = &a; + thread::scope(|s| { + let handle = s.spawn(move || { + r.fetch_add(2, std::sync::atomic::Ordering::SeqCst); + YieldAndIncrementOnDrop(r) + }); + std::mem::forget(handle) + }); + // Expect only 2 since the spawned thread will complete but its result + // will be leaked and so never dropped. + assert_eq!(a.load(std::sync::atomic::Ordering::SeqCst), 2); + }); +}