@@ -9,7 +9,7 @@ use std::pin::Pin;
9
9
use std:: ptr;
10
10
use std:: sync:: atomic:: AtomicUsize ;
11
11
use std:: sync:: atomic:: Ordering :: { Acquire , SeqCst } ;
12
- use std:: sync:: { Arc , Mutex , Weak } ;
12
+ use std:: sync:: { Arc , Mutex , MutexGuard , Weak } ;
13
13
14
14
/// Future for the [`shared`](super::FutureExt::shared) method.
15
15
#[ must_use = "futures do nothing unless you `.await` or poll them" ]
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
81
81
const POLLING : usize = 1 ;
82
82
const COMPLETE : usize = 2 ;
83
83
const POISONED : usize = 3 ;
84
+ const WOKEN_DURING_POLLING : usize = 4 ;
84
85
85
86
const NULL_WAKER_KEY : usize = usize:: MAX ;
86
87
@@ -197,35 +198,43 @@ where
197
198
}
198
199
}
199
200
201
+ /// Registers the current task to receive a wakeup when we are awoken.
202
+ fn record_waker ( wakers_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
203
+ let wakers = match wakers_guard. as_mut ( ) {
204
+ Some ( wakers) => wakers,
205
+ None => return ,
206
+ } ;
207
+
208
+ let new_waker = cx. waker ( ) ;
209
+
210
+ if * waker_key == NULL_WAKER_KEY {
211
+ * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
212
+ } else {
213
+ match wakers[ * waker_key] {
214
+ Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
215
+ // Could use clone_from here, but Waker doesn't specialize it.
216
+ ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
217
+ }
218
+ }
219
+ debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
220
+ }
221
+
222
+ /// Wakes all tasks that are registered to be woken.
223
+ fn wake_all ( waker_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ) {
224
+ if let Some ( wakers) = waker_guard. as_mut ( ) {
225
+ for ( _key, opt_waker) in wakers {
226
+ if let Some ( waker) = opt_waker. take ( ) {
227
+ waker. wake ( ) ;
228
+ }
229
+ }
230
+ }
231
+ }
232
+
200
233
impl < Fut > Inner < Fut >
201
234
where
202
235
Fut : Future ,
203
236
Fut :: Output : Clone ,
204
237
{
205
- /// Registers the current task to receive a wakeup when we are awoken.
206
- fn record_waker ( & self , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
207
- let mut wakers_guard = self . notifier . wakers . lock ( ) . unwrap ( ) ;
208
-
209
- let wakers_mut = wakers_guard. as_mut ( ) ;
210
-
211
- let wakers = match wakers_mut {
212
- Some ( wakers) => wakers,
213
- None => return ,
214
- } ;
215
-
216
- let new_waker = cx. waker ( ) ;
217
-
218
- if * waker_key == NULL_WAKER_KEY {
219
- * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
220
- } else {
221
- match wakers[ * waker_key] {
222
- Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
223
- // Could use clone_from here, but Waker doesn't specialize it.
224
- ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
225
- }
226
- }
227
- debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
228
- }
229
238
230
239
/// Safety: callers must first ensure that `inner.state`
231
240
/// is `COMPLETE`
@@ -268,18 +277,18 @@ where
268
277
return unsafe { Poll :: Ready ( inner. take_or_clone_output ( ) ) } ;
269
278
}
270
279
271
- inner. record_waker ( & mut this. waker_key , cx) ;
280
+ // Guard the state transition with mutex too
281
+ let mut wakers_guard = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
282
+ record_waker ( & mut wakers_guard, & mut this. waker_key , cx) ;
272
283
273
- match inner
274
- . notifier
275
- . state
276
- . compare_exchange ( IDLE , POLLING , SeqCst , SeqCst )
277
- . unwrap_or_else ( |x| x)
278
- {
284
+ let prev = inner. notifier . state . compare_exchange ( IDLE , POLLING , SeqCst , SeqCst ) . unwrap_or_else ( |x| x) ;
285
+ drop ( wakers_guard) ;
286
+
287
+ match prev {
279
288
IDLE => {
280
289
// Lock acquired, fall through
281
290
}
282
- POLLING => {
291
+ POLLING | WOKEN_DURING_POLLING => {
283
292
// Another task is currently polling, at this point we just want
284
293
// to ensure that the waker for this task is registered
285
294
this. inner = Some ( inner) ;
@@ -324,15 +333,22 @@ where
324
333
325
334
match poll_result {
326
335
Poll :: Pending => {
327
- if inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) . is_ok ( )
336
+ match inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst )
328
337
{
329
- // Success
330
- drop ( reset) ;
331
- this. inner = Some ( inner) ;
332
- return Poll :: Pending ;
333
- } else {
334
- unreachable ! ( )
338
+ Ok ( POLLING ) => { } // success
339
+ Err ( WOKEN_DURING_POLLING ) => {
340
+ // waker has been called inside future.poll, need to wake any new wakers registered
341
+ let mut wakers = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
342
+ wake_all ( & mut wakers) ;
343
+ let prev = inner. notifier . state . swap ( IDLE , SeqCst ) ;
344
+ assert_eq ! ( prev, WOKEN_DURING_POLLING ) ;
345
+ drop ( wakers) ;
346
+ }
347
+ _ => unreachable ! ( ) ,
335
348
}
349
+ drop ( reset) ;
350
+ this. inner = Some ( inner) ;
351
+ return Poll :: Pending ;
336
352
}
337
353
Poll :: Ready ( output) => output,
338
354
}
@@ -387,14 +403,9 @@ where
387
403
388
404
impl ArcWake for Notifier {
389
405
fn wake_by_ref ( arc_self : & Arc < Self > ) {
390
- let wakers = & mut * arc_self. wakers . lock ( ) . unwrap ( ) ;
391
- if let Some ( wakers) = wakers. as_mut ( ) {
392
- for ( _key, opt_waker) in wakers {
393
- if let Some ( waker) = opt_waker. take ( ) {
394
- waker. wake ( ) ;
395
- }
396
- }
397
- }
406
+ let mut wakers = arc_self. wakers . lock ( ) . unwrap ( ) ;
407
+ let _ = arc_self. state . compare_exchange ( POLLING , WOKEN_DURING_POLLING , SeqCst , SeqCst ) ;
408
+ wake_all ( & mut wakers) ;
398
409
}
399
410
}
400
411
0 commit comments