diff --git a/src/System.Private.ServiceModel/src/Internals/System/Runtime/AsyncLock.cs b/src/System.Private.ServiceModel/src/Internals/System/Runtime/AsyncLock.cs index 5457d50d96c..063aa9cbf02 100644 --- a/src/System.Private.ServiceModel/src/Internals/System/Runtime/AsyncLock.cs +++ b/src/System.Private.ServiceModel/src/Internals/System/Runtime/AsyncLock.cs @@ -10,7 +10,7 @@ namespace System.Runtime { internal class AsyncLock : IAsyncDisposable { - private static ObjectPool s_semaphorePool = (new DefaultObjectPoolProvider { MaximumRetained = 100 }) + private static readonly ObjectPool s_semaphorePool = (new DefaultObjectPoolProvider { MaximumRetained = 100 }) .Create(new SemaphoreSlimPooledObjectPolicy()); private AsyncLocal _currentSemaphore; @@ -23,14 +23,12 @@ public AsyncLock() _currentSemaphore = new AsyncLocal(); } - public Task TakeLockAsync(CancellationToken cancellationToken = default) + public Task TakeLockAsync() { if (_isDisposed) throw new ObjectDisposedException(nameof(AsyncLock)); - if (_currentSemaphore.Value == null) - { - _currentSemaphore.Value = _topLevelSemaphore; - } + + _currentSemaphore.Value ??= _topLevelSemaphore; SemaphoreSlim currentSem = _currentSemaphore.Value; var nextSem = s_semaphorePool.Get(); _currentSemaphore.Value = nextSem; @@ -38,9 +36,9 @@ public Task TakeLockAsync(CancellationToken cancellationToken return TakeLockCoreAsync(currentSem, safeRelease); } - private async Task TakeLockCoreAsync(SemaphoreSlim currentSemaphore, SafeSemaphoreRelease safeSemaphoreRelease, CancellationToken cancellationToken = default) + private async Task TakeLockCoreAsync(SemaphoreSlim currentSemaphore, SafeSemaphoreRelease safeSemaphoreRelease) { - await currentSemaphore.WaitAsync(cancellationToken); + await currentSemaphore.WaitAsync(); return safeSemaphoreRelease; } @@ -48,12 +46,10 @@ public IDisposable TakeLock() { if (_isDisposed) throw new ObjectDisposedException(nameof(AsyncLock)); - if (_currentSemaphore.Value == null) - { - _currentSemaphore.Value = _topLevelSemaphore; - } + + _currentSemaphore.Value ??= _topLevelSemaphore; SemaphoreSlim currentSem = _currentSemaphore.Value; - currentSem.Wait(/*cancellationToken*/); + currentSem.Wait(); var nextSem = s_semaphorePool.Get(); _currentSemaphore.Value = nextSem; return new SafeSemaphoreRelease(currentSem, nextSem, this); @@ -63,8 +59,10 @@ public async ValueTask DisposeAsync() { if (_isDisposed) return; + _isDisposed = true; - // Ensure not in use + // Ensure the lock isn't held. If it is, wait for it to be released + // before completing the dispose. await _topLevelSemaphore.WaitAsync(); _topLevelSemaphore.Release(); s_semaphorePool.Return(_topLevelSemaphore); @@ -84,32 +82,49 @@ public SafeSemaphoreRelease(SemaphoreSlim currentSemaphore, SemaphoreSlim nextSe _asyncLock = asyncLock; } - public async ValueTask DisposeAsync() + public ValueTask DisposeAsync() { Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore"); + // Update _asyncLock._currentSemaphore in the calling ExecutionContext + // and defer any awaits to DisposeCoreAsync(). If this isn't done, the + // update will happen in a copy of the ExecutionContext and the caller + // won't see the changes. + if (_currentSemaphore == _asyncLock._topLevelSemaphore) + { + _asyncLock._currentSemaphore.Value = null; + } + else + { + _asyncLock._currentSemaphore.Value = _currentSemaphore; + } + + return DisposeCoreAsync(); + } + + private async ValueTask DisposeCoreAsync() + { await _nextSemaphore.WaitAsync(); - _asyncLock._currentSemaphore.Value = _currentSemaphore; _currentSemaphore.Release(); _nextSemaphore.Release(); s_semaphorePool.Return(_nextSemaphore); - if (_asyncLock._currentSemaphore.Value == _asyncLock._topLevelSemaphore) - { - _asyncLock._currentSemaphore.Value = null; - } } public void Dispose() { Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore"); + if (_currentSemaphore == _asyncLock._topLevelSemaphore) + { + _asyncLock._currentSemaphore.Value = null; + } + else + { + _asyncLock._currentSemaphore.Value = _currentSemaphore; + } + _nextSemaphore.Wait(); - _asyncLock._currentSemaphore.Value = _currentSemaphore; _currentSemaphore.Release(); _nextSemaphore.Release(); s_semaphorePool.Return(_nextSemaphore); - if (_asyncLock._currentSemaphore.Value == _asyncLock._topLevelSemaphore) - { - _asyncLock._currentSemaphore.Value = null; - } } }