Skip to content

Commit

Permalink
Prevent EC capture causing AsyncLocal leak
Browse files Browse the repository at this point in the history
  • Loading branch information
mconnew committed Sep 20, 2021
1 parent 8e18e27 commit 5c2f981
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,139 +4,131 @@

using System.Threading;
using System.Threading.Tasks;
#if DEBUG
using System.Diagnostics;
#endif
using Microsoft.Extensions.ObjectPool;

namespace System.Runtime
{
internal class AsyncLock
internal class AsyncLock : IAsyncDisposable
{
#if DEBUG
private StackTrace _lockTakenCallStack;
private string _lockTakenCallStackString;
#endif
private readonly SemaphoreSlim _semaphore;
private readonly SafeSemaphoreRelease _semaphoreRelease;
private AsyncLocal<bool> _lockTaken;
private static ObjectPool<SemaphoreSlim> s_semaphorePool = (new DefaultObjectPoolProvider { MaximumRetained = 100 })
.Create(new SemaphoreSlimPooledObjectPolicy());

private AsyncLocal<SemaphoreSlim> _currentSemaphore;
private SemaphoreSlim _topLevelSemaphore;
private bool _isDisposed;

public AsyncLock()
{
_semaphore = new SemaphoreSlim(1);
_semaphoreRelease = new SafeSemaphoreRelease(this);
_lockTaken = new AsyncLocal<bool>(LockTakenValueChanged);
_lockTaken.Value = false;
_topLevelSemaphore = s_semaphorePool.Get();
_currentSemaphore = new AsyncLocal<SemaphoreSlim>();
}

private void LockTakenValueChanged(AsyncLocalValueChangedArgs<bool> obj)
public Task<IAsyncDisposable> TakeLockAsync(CancellationToken cancellationToken = default)
{
// Without this fixup, when completing the call to await TakeLockAsync there is
// a switch of Context and _localTaken will be reset to false. This is because
// of leaving the task.

if (obj.ThreadContextChanged)
if (_isDisposed)
throw new ObjectDisposedException(nameof(AsyncLock));
if (_currentSemaphore.Value == null)
{
_lockTaken.Value = obj.PreviousValue;
_currentSemaphore.Value = _topLevelSemaphore;
}
SemaphoreSlim currentSem = _currentSemaphore.Value;
var nextSem = s_semaphorePool.Get();
_currentSemaphore.Value = nextSem;
var safeRelease = new SafeSemaphoreRelease(currentSem, nextSem, this);
return TakeLockCoreAsync(currentSem, safeRelease);
}

public async Task<IDisposable> TakeLockAsync()
private async Task<IAsyncDisposable> TakeLockCoreAsync(SemaphoreSlim currentSemaphore, SafeSemaphoreRelease safeSemaphoreRelease, CancellationToken cancellationToken = default)
{
if (_lockTaken.Value)
{
return null;
}

await _semaphore.WaitAsync();
_lockTaken.Value = true;
#if DEBUG
_lockTakenCallStack = new StackTrace();
_lockTakenCallStackString = _lockTakenCallStack.ToString();
#endif
return _semaphoreRelease;
await currentSemaphore.WaitAsync(cancellationToken);
return safeSemaphoreRelease;
}

public async Task<IDisposable> TakeLockAsync(CancellationToken token)
public IDisposable TakeLock()
{
if (_lockTaken.Value)
if (_isDisposed)
throw new ObjectDisposedException(nameof(AsyncLock));
if (_currentSemaphore.Value == null)
{
return null;
_currentSemaphore.Value = _topLevelSemaphore;
}

await _semaphore.WaitAsync(token);
_lockTaken.Value = true;
#if DEBUG
_lockTakenCallStack = new StackTrace();
_lockTakenCallStackString = _lockTakenCallStack.ToString();
#endif
return _semaphoreRelease;
SemaphoreSlim currentSem = _currentSemaphore.Value;
currentSem.Wait(/*cancellationToken*/);
var nextSem = s_semaphorePool.Get();
_currentSemaphore.Value = nextSem;
return new SafeSemaphoreRelease(currentSem, nextSem, this);
}

public IDisposable TakeLock()
public async ValueTask DisposeAsync()
{
if (_lockTaken.Value)
{
return null;
}

_semaphore.Wait();
_lockTaken.Value = true;
#if DEBUG
_lockTakenCallStack = new StackTrace();
_lockTakenCallStackString = _lockTakenCallStack.ToString();
#endif
return _semaphoreRelease;
if (_isDisposed)
return;
_isDisposed = true;
// Ensure not in use
await _topLevelSemaphore.WaitAsync();
_topLevelSemaphore.Release();
s_semaphorePool.Return(_topLevelSemaphore);
_topLevelSemaphore = null;
}

public IDisposable TakeLock(TimeSpan timeout)
private struct SafeSemaphoreRelease : IAsyncDisposable, IDisposable
{
if (_lockTaken.Value)
private SemaphoreSlim _currentSemaphore;
private SemaphoreSlim _nextSemaphore;
private AsyncLock _asyncLock;

public SafeSemaphoreRelease(SemaphoreSlim currentSemaphore, SemaphoreSlim nextSemaphore, AsyncLock asyncLock)
{
return null;
_currentSemaphore = currentSemaphore;
_nextSemaphore = nextSemaphore;
_asyncLock = asyncLock;
}

_semaphore.Wait(timeout);
_lockTaken.Value = true;
#if DEBUG
_lockTakenCallStack = new StackTrace();
_lockTakenCallStackString = _lockTakenCallStack.ToString();
#endif
return _semaphoreRelease;
}

public IDisposable TakeLock(int timeout)
{
if (_lockTaken.Value)
public async ValueTask DisposeAsync()
{
return null;
Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore");
await _nextSemaphore.WaitAsync();
_asyncLock._currentSemaphore.Value = _currentSemaphore;
_currentSemaphore.Release();
_nextSemaphore.Release();
s_semaphorePool.Return(_nextSemaphore);
if (_asyncLock._currentSemaphore.Value == _asyncLock._topLevelSemaphore)
{
_asyncLock._currentSemaphore.Value = null;
}
}

_semaphore.Wait(timeout);
_lockTaken.Value = true;
#if DEBUG
_lockTakenCallStack = new StackTrace();
_lockTakenCallStackString = _lockTakenCallStack.ToString();
#endif
return _semaphoreRelease;
public void Dispose()
{
Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore");
_nextSemaphore.Wait();
_asyncLock._currentSemaphore.Value = _currentSemaphore;
_currentSemaphore.Release();
_nextSemaphore.Release();
s_semaphorePool.Return(_nextSemaphore);
if (_asyncLock._currentSemaphore.Value == _asyncLock._topLevelSemaphore)
{
_asyncLock._currentSemaphore.Value = null;
}
}
}

public struct SafeSemaphoreRelease : IDisposable
private class SemaphoreSlimPooledObjectPolicy : PooledObjectPolicy<SemaphoreSlim>
{
private readonly AsyncLock _asyncLock;

public SafeSemaphoreRelease(AsyncLock asyncLock)
public override SemaphoreSlim Create()
{
_asyncLock = asyncLock;
return new SemaphoreSlim(1);
}

public void Dispose()
public override bool Return(SemaphoreSlim obj)
{
#if DEBUG
_asyncLock._lockTakenCallStack = null;
_asyncLock._lockTakenCallStackString = null;
#endif
_asyncLock._lockTaken.Value = false;
_asyncLock._semaphore.Release();
if (obj.CurrentCount != 1)
{
Fx.Assert("Shouldn't be returning semaphore with a count != 1");
return false;
}

return true;
}
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<IsImplementationAssembly>true</IsImplementationAssembly>
<IsPackable>true</IsPackable>
<IsShipping>$(Ship_WcfPackages)</IsShipping>
<LangVersion>9.0</LangVersion>
</PropertyGroup>

<!-- [todo:arcade] Added this because our released S.P.SM package includes the "_BlockReflectionAttribute" but it is included not only in the UAP* assembly but also the other shipped assembly, should this only be in the binary built for uap? -->
Expand Down Expand Up @@ -54,6 +55,8 @@
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindowsPackageVersion)" />
<PackageReference Include="System.Security.Cryptography.Xml" Version="$(SystemSecurityCryptographyXmlPackageVersion)" />
<PackageReference Include="System.Numerics.Vectors" Version="$(SystemNumericsVectorsPackageVersion)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="5.0.0" />
<PackageReference Include="Microsoft.Extensions.ObjectPool" Version="5.0.9" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework.StartsWith('uap'))">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ protected override void OnAbort()

private async Task OnAcknowledgementTimeoutElapsedAsync(object state)
{
using (await ThisAsyncLock.TakeLockAsync())
await using (await ThisAsyncLock.TakeLockAsync())
{
_acknowledgementScheduled = false;
_pendingAcknowledgements = 0;
Expand Down Expand Up @@ -1125,7 +1125,7 @@ protected override void OnClose(TimeSpan timeout)

private async void OnConnectionLost(object sender, EventArgs args)
{
using (await ThisAsyncLock.TakeLockAsync())
await using (await ThisAsyncLock.TakeLockAsync())
{
if ((State == CommunicationState.Opened || State == CommunicationState.Closing) &&
!Binder.Connected && _clientSession.StopPolling())
Expand Down Expand Up @@ -1181,7 +1181,7 @@ protected override void OnOpened()
private static async Task OnReconnectTimerElapsed(object state)
{
ClientReliableDuplexSessionChannel channel = (ClientReliableDuplexSessionChannel)state;
using (await channel.ThisAsyncLock.TakeLockAsync())
await using (await channel.ThisAsyncLock.TakeLockAsync())
{
if ((channel.State == CommunicationState.Opened || channel.State == CommunicationState.Closing) &&
!channel.Binder.Connected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ public async Task<bool> EnsureChannelAsync()
{
bool fault = false;

using (await ThisLock.TakeLockAsync())
await using (await ThisLock.TakeLockAsync())
{
if (ValidateOpened())
{
Expand Down Expand Up @@ -1448,7 +1448,7 @@ private void SetWaiters(Queue<IWaiter> waiters, TChannel channel)

public async Task StartSynchronizingAsync()
{
using (await ThisLock.TakeLockAsync())
await using (await ThisLock.TakeLockAsync())
{
if (_state == State.Created)
{
Expand Down
Loading

0 comments on commit 5c2f981

Please sign in to comment.