From 0159481178d8e9090975d1351c3dce75e521885e Mon Sep 17 00:00:00 2001 From: David Justo Date: Fri, 2 Jun 2023 16:24:00 -0700 Subject: [PATCH] Validate that timeout is positive in WaitForOrchestrationAsync, and enable nullable checks (#910) --- .../AzureStorageOrchestrationService.cs | 60 ++++++++++++------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs index 38603c985..3bdb99d3a 100644 --- a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs +++ b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs @@ -10,7 +10,6 @@ // See the License for the specific language governing permissions and // limitations under the License. // ---------------------------------------------------------------------------------- - namespace DurableTask.AzureStorage { using System; @@ -1838,15 +1837,15 @@ async Task IOrchestrationServicePurgeClient.PurgeInstanceStateAsync purgeInstanceFilter.RuntimeStatus); return storagePurgeHistoryResult.ToCorePurgeHistoryResult(); } - +#nullable enable /// /// Wait for an orchestration to reach any terminal state within the given timeout /// /// The orchestration instance to wait for. /// The execution ID (generation) of the specified instance. - /// Max timeout to wait. + /// Max timeout to wait. Only positive values, , or are allowed. /// Task cancellation token. - public async Task WaitForOrchestrationAsync( + public async Task WaitForOrchestrationAsync( string instanceId, string executionId, TimeSpan timeout, @@ -1857,20 +1856,24 @@ public async Task WaitForOrchestrationAsync( throw new ArgumentException(nameof(instanceId)); } + bool isInfiniteTimeSpan = timeout == Timeout.InfiniteTimeSpan; + if (timeout < TimeSpan.Zero && !isInfiniteTimeSpan) + { + throw new ArgumentException($"The parameter {nameof(timeout)} cannot be negative." + + $" The value for {nameof(timeout)} was '{timeout}'." + + $" Please provide either a positive timeout value or Timeout.InfiniteTimeSpan."); + } + TimeSpan statusPollingInterval = TimeSpan.FromSeconds(2); - while (!cancellationToken.IsCancellationRequested && timeout > TimeSpan.Zero) - { - OrchestrationState state = await this.GetOrchestrationStateAsync(instanceId, executionId); - if (state == null || - state.OrchestrationStatus == OrchestrationStatus.Running || - state.OrchestrationStatus == OrchestrationStatus.Suspended || - state.OrchestrationStatus == OrchestrationStatus.Pending || - state.OrchestrationStatus == OrchestrationStatus.ContinuedAsNew) - { - await Task.Delay(statusPollingInterval, cancellationToken); - timeout -= statusPollingInterval; - } - else + while (!cancellationToken.IsCancellationRequested) + { + OrchestrationState? state = await this.GetOrchestrationStateAsync(instanceId, executionId); + + if (state != null && + state.OrchestrationStatus != OrchestrationStatus.Running && + state.OrchestrationStatus != OrchestrationStatus.Suspended && + state.OrchestrationStatus != OrchestrationStatus.Pending && + state.OrchestrationStatus != OrchestrationStatus.ContinuedAsNew) { if (this.settings.FetchLargeMessageDataEnabled) { @@ -1879,6 +1882,17 @@ public async Task WaitForOrchestrationAsync( } return state; } + + timeout -= statusPollingInterval; + + // For a user-provided timeout of `TimeSpan.Zero`, + // we want to check the status of the orchestration once and then return. + // Therefore, we check the timeout condition after the status check. + if (!isInfiniteTimeSpan && (timeout <= TimeSpan.Zero)) + { + break; + } + await Task.Delay(statusPollingInterval, cancellationToken); } return null; @@ -1909,7 +1923,7 @@ public Task DownloadBlobAsync(string blobUri) // TODO: Change this to a sticky assignment so that partition count changes can // be supported: https://github.com/Azure/azure-functions-durable-extension/issues/1 - async Task GetControlQueueAsync(string instanceId) + async Task GetControlQueueAsync(string instanceId) { uint partitionIndex = Fnv1aHashHelper.ComputeHash(instanceId) % (uint)this.settings.PartitionCount; string queueName = GetControlQueueName(this.settings.TaskHubName, (int)partitionIndex); @@ -1980,12 +1994,12 @@ private static OrchestrationQueryResult ConvertFrom(DurableStatusQueryResult sta class PendingMessageBatch { - public string OrchestrationInstanceId { get; set; } - public string OrchestrationExecutionId { get; set; } + public string? OrchestrationInstanceId { get; set; } + public string? OrchestrationExecutionId { get; set; } public List Messages { get; set; } = new List(); - public OrchestrationRuntimeState Orchestrationstate { get; set; } + public OrchestrationRuntimeState? Orchestrationstate { get; set; } } class ResettableLazy @@ -1995,7 +2009,10 @@ class ResettableLazy Lazy lazy; + // Supress warning because it's incorrect: the lazy variable is initialized in the constructor, in the `Reset()` method +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. public ResettableLazy(Func valueFactory, LazyThreadSafetyMode mode) +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. { this.valueFactory = valueFactory; this.threadSafetyMode = mode; @@ -2025,3 +2042,4 @@ public TaskHubQueueMessage(TaskHubQueue queue, TaskMessage message) } } } +#nullable disable