From d6f7ee360876bb0fd52c51f88c699f35704dabc8 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:40:38 +0100 Subject: [PATCH] .Net: Merge feature/oobabooga branch to main (#2016) ### Motivation and Context Merge [feature/oobabooga](https://github.com/microsoft/semantic-kernel/tree/feature/oobabooga) branch to `main` with [Oobabooga](https://github.com/oobabooga/text-generation-webui) AI Connector functionality. Functionality verified with unit and integration testing. ### Description From original PR (https://github.com/microsoft/semantic-kernel/pull/1357): > This PR adds to the solution a project similar to HuggingFace connectors project, and an additional integration test also similar to HuggingFace connector's The code for the connector was based on the existing HuggingFace's, with a couple improvements (e.g. using web sockets for streaming API) ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#dev-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: Co-authored-by: Jean-Sylvain Boige --------- Signed-off-by: dependabot[bot] Co-authored-by: Jean-Sylvain Boige Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com> Co-authored-by: Gina Triolo <51341242+gitri-ms@users.noreply.github.com> Co-authored-by: Devis Lucato Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Craig Presti <146438+craigomatic@users.noreply.github.com> Co-authored-by: Craig Presti Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Co-authored-by: Teresa Hoang <125500434+teresaqhoang@users.noreply.github.com> Co-authored-by: Abby Harrison <54643756+awharrison-28@users.noreply.github.com> Co-authored-by: Tao Chen Co-authored-by: Aman Sachan <51973971+amsacha@users.noreply.github.com> Co-authored-by: cschadewitz Co-authored-by: Abby Harrison --- dotnet/Directory.Packages.props | 2 + dotnet/SK-dotnet.sln | 9 + dotnet/SK-dotnet.sln.DotSettings | 2 + .../Connectors.AI.Oobabooga.csproj | 28 ++ .../OobaboogaInvalidResponseException.cs | 16 + .../TextCompletion/OobaboogaTextCompletion.cs | 475 ++++++++++++++++++ .../TextCompletion/TextCompletionRequest.cs | 177 +++++++ .../TextCompletion/TextCompletionResponse.cs | 30 ++ .../TextCompletion/TextCompletionResult.cs | 28 ++ .../TextCompletionStreamingResponse.cs | 32 ++ .../TextCompletionStreamingResult.cs | 66 +++ .../Connectors.UnitTests/ConnectedClient.cs | 25 + .../Connectors.UnitTests.csproj | 7 + .../Oobabooga/OobaboogaTestHelper.cs | 44 ++ .../Oobabooga/OobaboogaWebSocketTestServer.cs | 62 +++ .../TestData/completion_test_response.json | 9 + .../completion_test_streaming_response.json | 5 + .../OobaboogaTextCompletionTests.cs | 405 +++++++++++++++ .../WebSocketTestServer.cs | 223 ++++++++ .../Connectors.UnitTests/XunitLogger.cs | 40 ++ .../Oobabooga/OobaboogaTextCompletionTests.cs | 110 ++++ .../IntegrationTests/IntegrationTests.csproj | 1 + dotnet/src/IntegrationTests/README.md | 1 + .../TextCompletionExtensions.cs | 1 + 24 files changed, 1798 insertions(+) create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/Connectors.AI.Oobabooga.csproj create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaInvalidResponseException.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaTextCompletion.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionRequest.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResponse.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResult.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResponse.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResult.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/ConnectedClient.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaTestHelper.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaWebSocketTestServer.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_response.json create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_streaming_response.json create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TextCompletion/OobaboogaTextCompletionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/WebSocketTestServer.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/XunitLogger.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Oobabooga/OobaboogaTextCompletionTests.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index e42dcfee2e1c..f45e3291ae70 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -29,6 +29,8 @@ + + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 453e32f33399..7207d47c875a 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -141,6 +141,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Skills.Core", "src\Skills\S EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NCalcSkills", "samples\NCalcSkills\NCalcSkills.csproj", "{E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.AI.Oobabooga", "src\Connectors\Connectors.AI.Oobabooga\Connectors.AI.Oobabooga.csproj", "{677F1381-7830-4115-9C1A-58B282629DC6}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Planning.StepwisePlanner", "src\Extensions\Planning.StepwisePlanner\Planning.StepwisePlanner.csproj", "{4762BCAF-E1C5-4714-B88D-E50FA333C50E}" EndProject Global @@ -342,6 +344,12 @@ Global {E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA}.Publish|Any CPU.ActiveCfg = Release|Any CPU {E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA}.Release|Any CPU.ActiveCfg = Release|Any CPU {E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA}.Release|Any CPU.Build.0 = Release|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Publish|Any CPU.Build.0 = Publish|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {677F1381-7830-4115-9C1A-58B282629DC6}.Release|Any CPU.Build.0 = Release|Any CPU {4762BCAF-E1C5-4714-B88D-E50FA333C50E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4762BCAF-E1C5-4714-B88D-E50FA333C50E}.Debug|Any CPU.Build.0 = Debug|Any CPU {4762BCAF-E1C5-4714-B88D-E50FA333C50E}.Publish|Any CPU.ActiveCfg = Publish|Any CPU @@ -397,6 +405,7 @@ Global {1C19D805-3573-4477-BF07-40180FCDE1BD} = {958AD708-F048-4FAF-94ED-D2F2B92748B9} {0D0C4DAD-E6BC-4504-AE3A-EEA4E35920C1} = {9ECD1AA0-75B3-4E25-B0B5-9F0945B64974} {E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA} = {FA3720F1-C99A-49B2-9577-A940257098BF} + {677F1381-7830-4115-9C1A-58B282629DC6} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {4762BCAF-E1C5-4714-B88D-E50FA333C50E} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/dotnet/SK-dotnet.sln.DotSettings b/dotnet/SK-dotnet.sln.DotSettings index 94c269cd2a4a..4d5e6137e95a 100644 --- a/dotnet/SK-dotnet.sln.DotSettings +++ b/dotnet/SK-dotnet.sln.DotSettings @@ -202,8 +202,10 @@ public void It$SOMENAME$() True True True + True True True + True True True True diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/Connectors.AI.Oobabooga.csproj b/dotnet/src/Connectors/Connectors.AI.Oobabooga/Connectors.AI.Oobabooga.csproj new file mode 100644 index 000000000000..6daa5aaab4c1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/Connectors.AI.Oobabooga.csproj @@ -0,0 +1,28 @@ + + + + + Microsoft.SemanticKernel.Connectors.AI.Oobabooga + $(AssemblyName) + netstandard2.0 + + + + + + + + + Semantic Kernel - Oobabooga Connector + Semantic Kernel connector for the oobabooga text-generation-webui open source project. Contains a client for text completion. + + + + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaInvalidResponseException.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaInvalidResponseException.cs new file mode 100644 index 000000000000..a2e8e51d2a57 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaInvalidResponseException.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.AI; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +#pragma warning disable RCS1194 // Implement exception constructors. +internal sealed class OobaboogaInvalidResponseException : AIException +{ + public T? ResponseData { get; } + + public OobaboogaInvalidResponseException(T? responseData, string? message = null) : base(ErrorCodes.InvalidResponseContent, message) + { + this.ResponseData = responseData; + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaTextCompletion.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaTextCompletion.cs new file mode 100644 index 000000000000..e8d41d7b9411 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/OobaboogaTextCompletion.cs @@ -0,0 +1,475 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Diagnostics; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +/// +/// Oobabooga text completion service API. +/// Adapted from +/// +public sealed class OobaboogaTextCompletion : ITextCompletion +{ + public const string HttpUserAgent = "Microsoft-Semantic-Kernel"; + public const string BlockingUriPath = "/api/v1/generate"; + private const string StreamingUriPath = "/api/v1/stream"; + + private readonly UriBuilder _blockingUri; + private readonly UriBuilder _streamingUri; + private readonly HttpClient _httpClient; + private readonly Func _webSocketFactory; + private readonly bool _useWebSocketsPooling; + private readonly int _maxNbConcurrentWebSockets; + private readonly SemaphoreSlim? _concurrentSemaphore; + private readonly ConcurrentBag? _activeConnections; + private readonly ConcurrentBag _webSocketPool = new(); + private readonly int _keepAliveWebSocketsDuration; + private readonly ILogger? _logger; + private long _lastCallTicks = long.MaxValue; + + /// + /// Controls the size of the buffer used to received websocket packets + /// + public int WebSocketBufferSize { get; set; } = 2048; + + /// + /// Initializes a new instance of the class. + /// + /// The service API endpoint to which requests should be sent. + /// The port used for handling blocking requests. Default value is 5000 + /// The port used for handling streaming requests. Default value is 5005 + /// You can optionally set a hard limit on the max number of concurrent calls to the either of the completion methods by providing a . Calls in excess will wait for existing consumers to release the semaphore + /// Optional. The HTTP client used for making blocking API requests. If not specified, a default client will be used. + /// If true, websocket clients will be recycled in a reusable pool as long as concurrent calls are detected + /// if websocket pooling is enabled, you can provide an optional CancellationToken to properly dispose of the clean up tasks when disposing of the connector + /// When pooling is enabled, pooled websockets are flushed on a regular basis when no more connections are made. This is the time to keep them in pool before flushing + /// The WebSocket factory used for making streaming API requests. Note that only when pooling is enabled will websocket be recycled and reused for the specified duration. Otherwise, a new websocket is created for each call and closed and disposed afterwards, to prevent data corruption from concurrent calls. + /// Application logger + public OobaboogaTextCompletion(Uri endpoint, + int blockingPort = 5000, + int streamingPort = 5005, + SemaphoreSlim? concurrentSemaphore = null, + HttpClient? httpClient = null, + bool useWebSocketsPooling = true, + CancellationToken? webSocketsCleanUpCancellationToken = default, + int keepAliveWebSocketsDuration = 100, + Func? webSocketFactory = null, + ILogger? logger = null) + { + Verify.NotNull(endpoint); + this._blockingUri = new UriBuilder(endpoint) + { + Port = blockingPort, + Path = BlockingUriPath + }; + this._streamingUri = new(endpoint) + { + Port = streamingPort, + Path = StreamingUriPath + }; + if (this._streamingUri.Uri.Scheme.StartsWith("http", StringComparison.OrdinalIgnoreCase)) + { + this._streamingUri.Scheme = (this._streamingUri.Scheme == "https") ? "wss" : "ws"; + } + + this._httpClient = httpClient ?? new HttpClient(NonDisposableHttpClientHandler.Instance, disposeHandler: false); + this._useWebSocketsPooling = useWebSocketsPooling; + this._keepAliveWebSocketsDuration = keepAliveWebSocketsDuration; + this._logger = logger; + if (webSocketFactory != null) + { + this._webSocketFactory = () => + { + var webSocket = webSocketFactory(); + this.SetWebSocketOptions(webSocket); + return webSocket; + }; + } + else + { + this._webSocketFactory = () => + { + ClientWebSocket webSocket = new(); + this.SetWebSocketOptions(webSocket); + return webSocket; + }; + } + + // if a hard limit is defined, we use a semaphore to limit the number of concurrent calls, otherwise, we use a stack to track active connections + if (concurrentSemaphore != null) + { + this._concurrentSemaphore = concurrentSemaphore; + this._maxNbConcurrentWebSockets = concurrentSemaphore.CurrentCount; + } + else + { + this._activeConnections = new(); + this._maxNbConcurrentWebSockets = 0; + } + + if (this._useWebSocketsPooling) + { + this.StartCleanupTask(webSocketsCleanUpCancellationToken ?? CancellationToken.None); + } + } + + /// + public async IAsyncEnumerable GetStreamingCompletionsAsync( + string text, + CompleteRequestSettings requestSettings, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await this.StartConcurrentCallAsync(cancellationToken).ConfigureAwait(false); + + var completionRequest = this.CreateOobaboogaRequest(text, requestSettings); + + var requestJson = JsonSerializer.Serialize(completionRequest); + + var requestBytes = Encoding.UTF8.GetBytes(requestJson); + + ClientWebSocket? clientWebSocket = null; + try + { + // if pooling is enabled, web socket is going to be recycled for reuse, if not it will be properly disposed of after the call +#pragma warning disable CA2000 // Dispose objects before losing scope + if (!this._useWebSocketsPooling || !this._webSocketPool.TryTake(out clientWebSocket)) + { + clientWebSocket = this._webSocketFactory(); + } +#pragma warning restore CA2000 // Dispose objects before losing scope + if (clientWebSocket.State == WebSocketState.None) + { + await clientWebSocket.ConnectAsync(this._streamingUri.Uri, cancellationToken).ConfigureAwait(false); + } + + var sendSegment = new ArraySegment(requestBytes); + await clientWebSocket.SendAsync(sendSegment, WebSocketMessageType.Text, true, cancellationToken).ConfigureAwait(false); + + TextCompletionStreamingResult streamingResult = new(); + + var processingTask = this.ProcessWebSocketMessagesAsync(clientWebSocket, streamingResult, cancellationToken); + + yield return streamingResult; + + // Await the processing task to make sure it's finished before continuing + await processingTask.ConfigureAwait(false); + } + finally + { + if (clientWebSocket != null) + { + if (this._useWebSocketsPooling && clientWebSocket.State == WebSocketState.Open) + { + this._webSocketPool.Add(clientWebSocket); + } + else + { + await this.DisposeClientGracefullyAsync(clientWebSocket).ConfigureAwait(false); + } + } + + this.FinishConcurrentCall(); + } + } + + /// + public async Task> GetCompletionsAsync( + string text, + CompleteRequestSettings requestSettings, + CancellationToken cancellationToken = default) + { + try + { + await this.StartConcurrentCallAsync(cancellationToken).ConfigureAwait(false); + + var completionRequest = this.CreateOobaboogaRequest(text, requestSettings); + + using var stringContent = new StringContent( + JsonSerializer.Serialize(completionRequest), + Encoding.UTF8, + "application/json"); + + using var httpRequestMessage = new HttpRequestMessage() + { + Method = HttpMethod.Post, + RequestUri = this._blockingUri.Uri, + Content = stringContent + }; + httpRequestMessage.Headers.Add("User-Agent", HttpUserAgent); + + using var response = await this._httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + + var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + + TextCompletionResponse? completionResponse = JsonSerializer.Deserialize(body); + + if (completionResponse is null) + { + throw new OobaboogaInvalidResponseException(body, "Unexpected response from Oobabooga API"); + } + + return completionResponse.Results.Select(completionText => new TextCompletionResult(completionText)).ToList(); + } + catch (Exception e) when (e is not AIException && !e.IsCriticalException()) + { + throw new AIException( + AIException.ErrorCodes.UnknownError, + $"Something went wrong: {e.Message}", e); + } + finally + { + this.FinishConcurrentCall(); + } + } + + #region private ================================================================================ + + /// + /// Creates an Oobabooga request, mapping CompleteRequestSettings fields to their Oobabooga API counter parts + /// + /// The text to complete. + /// The request settings. + /// An Oobabooga TextCompletionRequest object with the text and completion parameters. + private TextCompletionRequest CreateOobaboogaRequest(string text, CompleteRequestSettings requestSettings) + { + if (string.IsNullOrWhiteSpace(text)) + { + throw new ArgumentNullException(nameof(text)); + } + + // Prepare the request using the provided parameters. + return new TextCompletionRequest() + { + Prompt = text, + MaxNewTokens = requestSettings.MaxTokens, + Temperature = requestSettings.Temperature, + TopP = requestSettings.TopP, + RepetitionPenalty = GetRepetitionPenalty(requestSettings), + StoppingStrings = requestSettings.StopSequences.ToList() + }; + } + + /// + /// Sets the options for the , either persistent and provided by the ctor, or transient if none provided. + /// + private void SetWebSocketOptions(ClientWebSocket clientWebSocket) + { + clientWebSocket.Options.SetRequestHeader("User-Agent", HttpUserAgent); + } + + /// + /// Converts the semantic-kernel presence penalty, scaled -2:+2 with default 0 for no penalty to the Oobabooga repetition penalty, strictly positive with default 1 for no penalty. See and subsequent links for more details. + /// + private static double GetRepetitionPenalty(CompleteRequestSettings requestSettings) + { + return 1 + requestSettings.PresencePenalty / 2; + } + + /// + /// That method is responsible for processing the websocket messages that build a streaming response object. It is crucial that it is run asynchronously to prevent a deadlock with results iteration + /// + private async Task ProcessWebSocketMessagesAsync(ClientWebSocket clientWebSocket, TextCompletionStreamingResult streamingResult, CancellationToken cancellationToken) + { + var buffer = new byte[this.WebSocketBufferSize]; + var finishedProcessing = false; + while (!finishedProcessing && !cancellationToken.IsCancellationRequested) + { + MemoryStream messageStream = new(); + WebSocketReceiveResult result; + do + { + var segment = new ArraySegment(buffer); + result = await clientWebSocket.ReceiveAsync(segment, cancellationToken).ConfigureAwait(false); + await messageStream.WriteAsync(buffer, 0, result.Count, cancellationToken).ConfigureAwait(false); + } while (!result.EndOfMessage); + + messageStream.Seek(0, SeekOrigin.Begin); + + if (result.MessageType == WebSocketMessageType.Text) + { + string messageText; + using (var reader = new StreamReader(messageStream, Encoding.UTF8)) + { + messageText = await reader.ReadToEndAsync().ConfigureAwait(false); + } + + var responseObject = JsonSerializer.Deserialize(messageText); + + if (responseObject is null) + { + throw new OobaboogaInvalidResponseException(messageText, "Unexpected response from Oobabooga API"); + } + + switch (responseObject.Event) + { + case TextCompletionStreamingResponse.ResponseObjectTextStreamEvent: + streamingResult.AppendResponse(responseObject); + break; + case TextCompletionStreamingResponse.ResponseObjectStreamEndEvent: + streamingResult.SignalStreamEnd(); + if (!this._useWebSocketsPooling) + { + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Acknowledge stream-end oobabooga message", CancellationToken.None).ConfigureAwait(false); + } + + finishedProcessing = true; + break; + default: + break; + } + } + else if (result.MessageType == WebSocketMessageType.Close) + { + await clientWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Acknowledge Close frame", CancellationToken.None).ConfigureAwait(false); + finishedProcessing = true; + } + + if (clientWebSocket.State != WebSocketState.Open) + { + finishedProcessing = true; + } + } + } + + /// + /// Starts a concurrent call, either by taking a semaphore slot or by pushing a value on the active connections stack + /// + /// + private async Task StartConcurrentCallAsync(CancellationToken cancellationToken) + { + if (this._concurrentSemaphore != null) + { + await this._concurrentSemaphore!.WaitAsync(cancellationToken).ConfigureAwait(false); + } + else + { + this._activeConnections!.Add(true); + } + } + + /// + /// Gets the number of concurrent calls, either by reading the semaphore count or by reading the active connections stack count + /// + /// + private int GetCurrentConcurrentCallsNb() + { + if (this._concurrentSemaphore != null) + { + return this._maxNbConcurrentWebSockets - this._concurrentSemaphore!.CurrentCount; + } + + return this._activeConnections!.Count; + } + + /// + /// Ends a concurrent call, either by releasing a semaphore slot or by popping a value from the active connections stack + /// + private void FinishConcurrentCall() + { + if (this._concurrentSemaphore != null) + { + this._concurrentSemaphore!.Release(); + } + else + { + this._activeConnections!.TryTake(out _); + } + + Interlocked.Exchange(ref this._lastCallTicks, DateTime.UtcNow.Ticks); + } + + private void StartCleanupTask(CancellationToken cancellationToken) + { + Task.Factory.StartNew( + async () => + { + while (!cancellationToken.IsCancellationRequested) + { + await this.FlushWebSocketClientsAsync(cancellationToken).ConfigureAwait(false); + } + }, + cancellationToken, + TaskCreationOptions.LongRunning, + TaskScheduler.Default); + } + + /// + /// Flushes the web socket clients that have been idle for too long + /// + /// + private async Task FlushWebSocketClientsAsync(CancellationToken cancellationToken) + { + // In the cleanup task, make sure you handle OperationCanceledException appropriately + // and make frequent checks on whether cancellation is requested. + try + { + if (!cancellationToken.IsCancellationRequested) + { + await Task.Delay(this._keepAliveWebSocketsDuration, cancellationToken).ConfigureAwait(false); + + // If another call was made during the delay, do not proceed with flushing + if (DateTime.UtcNow.Ticks - Interlocked.Read(ref this._lastCallTicks) < TimeSpan.FromMilliseconds(this._keepAliveWebSocketsDuration).Ticks) + { + return; + } + + while (this.GetCurrentConcurrentCallsNb() == 0 && this._webSocketPool.TryTake(out ClientWebSocket clientToDispose)) + { + await this.DisposeClientGracefullyAsync(clientToDispose).ConfigureAwait(false); + } + } + } + catch (OperationCanceledException exception) + { + this._logger?.LogTrace(message: "FlushWebSocketClientsAsync cleaning task was cancelled", exception: exception); + while (this._webSocketPool.TryTake(out ClientWebSocket clientToDispose)) + { + await this.DisposeClientGracefullyAsync(clientToDispose).ConfigureAwait(false); + } + } + } + + /// + /// Closes and disposes of a client web socket after use + /// + private async Task DisposeClientGracefullyAsync(ClientWebSocket clientWebSocket) + { + try + { + if (clientWebSocket.State == WebSocketState.Open) + { + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing client before disposal", CancellationToken.None).ConfigureAwait(false); + } + } + catch (OperationCanceledException exception) + { + this._logger?.LogTrace(message: "Closing client web socket before disposal was cancelled", exception: exception); + } + catch (WebSocketException exception) + { + this._logger?.LogTrace(message: "Closing client web socket before disposal raised web socket exception", exception: exception); + } + finally + { + clientWebSocket.Dispose(); + } + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionRequest.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionRequest.cs new file mode 100644 index 000000000000..8adcc088187a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionRequest.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +/// +/// HTTP schema to perform oobabooga completion request. Contains many parameters, some of which are specific to certain kinds of models. +/// See and subsequent links for additional information. +/// +[Serializable] +public sealed class TextCompletionRequest +{ + /// + /// The prompt text to complete. + /// + [JsonPropertyName("prompt")] + public string Prompt { get; set; } = string.Empty; + + /// + /// The maximum number of tokens to generate, ignoring the number of tokens in the prompt. + /// + [JsonPropertyName("max_new_tokens")] + public int? MaxNewTokens { get; set; } + + /// + /// Determines whether or not to use sampling; use greedy decoding if false. + /// + [JsonPropertyName("do_sample")] + public bool DoSample { get; set; } = true; + + /// + /// Modulates the next token probabilities. A value of 0 implies deterministic output (only the most likely token is used). Higher values increase randomness. + /// + [JsonPropertyName("temperature")] + public double Temperature { get; set; } + + /// + /// If set to a value less than 1, only the most probable tokens with cumulative probability less than this value are kept for generation. + /// + [JsonPropertyName("top_p")] + public double TopP { get; set; } + + /// + /// Measures how similar the conditional probability of predicting a target token is to the expected conditional probability of predicting a random token, given the generated text. + /// + [JsonPropertyName("typical_p")] + public double TypicalP { get; set; } = 1; + + /// + /// Sets a probability floor below which tokens are excluded from being sampled. + /// + [JsonPropertyName("epsilon_cutoff")] + public double EpsilonCutoff { get; set; } + + /// + /// Used with top_p, top_k, and epsilon_cutoff set to 0. This parameter hybridizes locally typical sampling and epsilon sampling. + /// + [JsonPropertyName("eta_cutoff")] + public double EtaCutoff { get; set; } + + /// + /// Controls Tail Free Sampling (value between 0 and 1) + /// + [JsonPropertyName("tfs")] + public double Tfs { get; set; } = 1; + + /// + /// Top A Sampling is a way to pick the next word in a sentence based on how important it is in the context. Top-A considers the probability of the most likely token, and sets a limit based on its percentage. After this, remaining tokens are compared to this limit. If their probability is too low, they are removed from the pool​. + /// + [JsonPropertyName("top_a")] + public double TopA { get; set; } + + /// + /// Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition. + /// + [JsonPropertyName("repetition_penalty")] + public double RepetitionPenalty { get; set; } = 1.18; + + /// + ///When using "top k", you select the top k most likely words to come next based on their probability of occurring, where k is a fixed number that you specify. You can use Top_K to control the amount of diversity in the model output​ + /// + [JsonPropertyName("top_k")] + public int TopK { get; set; } + + /// + /// Minimum length of the sequence to be generated. + /// + [JsonPropertyName("min_length")] + public int MinLength { get; set; } + + /// + /// If set to a value greater than 0, all ngrams of that size can only occur once. + /// + [JsonPropertyName("no_repeat_ngram_size")] + public int NoRepeatNgramSize { get; set; } + + /// + /// Number of beams for beam search. 1 means no beam search. + /// + [JsonPropertyName("num_beams")] + public int NumBeams { get; set; } = 1; + + /// + /// The values balance the model confidence and the degeneration penalty in contrastive search decoding. + /// + [JsonPropertyName("penalty_alpha")] + public int PenaltyAlpha { get; set; } + + /// + /// Exponential penalty to the length that is used with beam-based generation + /// + [JsonPropertyName("length_penalty")] + public double LengthPenalty { get; set; } = 1; + + /// + /// Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: True, where the generation stops as soon as there are num_beams complete candidates; False, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates. + /// + [JsonPropertyName("early_stopping")] + public bool EarlyStopping { get; set; } + + /// + /// Parameter used for mirostat sampling in Llama.cpp, controlling perplexity during text (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + /// + [JsonPropertyName("mirostat_mode")] + public int MirostatMode { get; set; } + + /// + /// Set the Mirostat target entropy, parameter tau (default: 5.0) + /// + [JsonPropertyName("mirostat_tau")] + public int MirostatTau { get; set; } = 5; + + /// + /// Set the Mirostat learning rate, parameter eta (default: 0.1) + /// + [JsonPropertyName("mirostat_eta")] + public double MirostatEta { get; set; } = 0.1; + + /// + /// Random seed to control sampling, used when DoSample is True. + /// + [JsonPropertyName("seed")] + public int Seed { get; set; } = -1; + + /// + /// Controls whether to add beginning of a sentence token + /// + [JsonPropertyName("add_bos_token")] + public bool AddBosToken { get; set; } = true; + + /// + /// The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048. + /// + [JsonPropertyName("truncation_length")] + public int TruncationLength { get; set; } = 2048; + + /// + /// Forces the model to never end the generation prematurely. + /// + [JsonPropertyName("ban_eos_token")] + public bool BanEosToken { get; set; } = true; + + /// + /// Some specific models need this unset. + /// + [JsonPropertyName("skip_special_tokens")] + public bool SkipSpecialTokens { get; set; } = true; + + /// + /// In addition to the defaults. Written between "" and separated by commas. For instance: "\nYour Assistant:", "\nThe assistant:" + /// + [JsonPropertyName("stopping_strings")] + public List StoppingStrings { get; set; } = new List(); +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResponse.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResponse.cs new file mode 100644 index 000000000000..e5058fe77cb2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResponse.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +/// +/// HTTP Schema for Oobabooga completion response. Contains a list of results. Adapted from +/// +public sealed class TextCompletionResponse +{ + /// + /// A field used by Oobabooga to return results from the blocking API. + /// + [JsonPropertyName("results")] + public List Results { get; set; } = new(); +} + +/// +/// HTTP Schema for an single Oobabooga result as part of a completion response. +/// +public sealed class TextCompletionResponseText +{ + /// + /// Completed text. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } = string.Empty; +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResult.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResult.cs new file mode 100644 index 000000000000..95097f9736ec --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionResult.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Orchestration; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +/// +/// Oobabooga implementation of . Actual response object is stored in a ModelResult instance, and completion text is simply passed forward. +/// +internal sealed class TextCompletionResult : ITextResult +{ + private readonly ModelResult _responseData; + + public TextCompletionResult(TextCompletionResponseText responseData) + { + this._responseData = new ModelResult(responseData); + } + + public ModelResult ModelResult => this._responseData; + + public Task GetCompletionAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult(this._responseData.GetResult().Text ?? string.Empty); + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResponse.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResponse.cs new file mode 100644 index 000000000000..33d9abf68401 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResponse.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +/// +/// HTTP Schema for streaming completion response. Adapted from +/// +public sealed class TextCompletionStreamingResponse +{ + public const string ResponseObjectTextStreamEvent = "text_stream"; + public const string ResponseObjectStreamEndEvent = "stream_end"; + + /// + /// A field used by Oobabooga to signal the type of websocket message sent, e.g. "text_stream" or "stream_end". + /// + [JsonPropertyName("event")] + public string Event { get; set; } = string.Empty; + + /// + /// A field used by Oobabooga to signal the number of messages sent, starting with 0 and incremented on each message. + /// + [JsonPropertyName("message_num")] + public int MessageNum { get; set; } + + /// + /// A field used by Oobabooga with the text chunk sent in the websocket message. + /// + [JsonPropertyName("text")] + public string Text { get; set; } = string.Empty; +} diff --git a/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResult.cs b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResult.cs new file mode 100644 index 000000000000..0575e6434cc2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.Oobabooga/TextCompletion/TextCompletionStreamingResult.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Orchestration; + +namespace Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +internal sealed class TextCompletionStreamingResult : ITextStreamingResult +{ + private readonly List _modelResponses; + private readonly Channel _responseChannel; + + public ModelResult ModelResult { get; } + + public TextCompletionStreamingResult() + { + this._modelResponses = new(); + this.ModelResult = new ModelResult(this._modelResponses); + this._responseChannel = Channel.CreateUnbounded(new UnboundedChannelOptions() + { + SingleReader = true, + SingleWriter = true, + AllowSynchronousContinuations = false + }); + } + + public void AppendResponse(TextCompletionStreamingResponse response) + { + this._modelResponses.Add(response); + this._responseChannel.Writer.TryWrite(response.Text); + } + + public void SignalStreamEnd() + { + this._responseChannel.Writer.Complete(); + } + + public async Task GetCompletionAsync(CancellationToken cancellationToken = default) + { + StringBuilder resultBuilder = new(); + + await foreach (var chunk in this.GetCompletionStreamingAsync(cancellationToken)) + { + resultBuilder.Append(chunk); + } + + return resultBuilder.ToString(); + } + + public async IAsyncEnumerable GetCompletionStreamingAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (await this._responseChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (this._responseChannel.Reader.TryRead(out string? chunk)) + { + yield return chunk; + } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/ConnectedClient.cs b/dotnet/src/Connectors/Connectors.UnitTests/ConnectedClient.cs new file mode 100644 index 000000000000..b47c192dbd61 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/ConnectedClient.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net; +using System.Net.WebSockets; + +namespace SemanticKernel.Connectors.UnitTests; + +internal sealed class ConnectedClient +{ + public Guid Id { get; } + public HttpListenerContext Context { get; } + public WebSocket? Socket { get; private set; } + + public ConnectedClient(Guid id, HttpListenerContext context) + { + this.Id = id; + this.Context = context; + } + + public void SetSocket(WebSocket socket) + { + this.Socket = socket; + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj index 0fc43760fd5b..eeeedeee5625 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj @@ -31,6 +31,7 @@ + @@ -56,6 +57,12 @@ Always + + Always + + + Always + diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaTestHelper.cs b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaTestHelper.cs new file mode 100644 index 000000000000..0df5eda9dd19 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaTestHelper.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.IO; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using Moq.Protected; + +namespace SemanticKernel.Connectors.UnitTests.Oobabooga; + +/// +/// Helper for Oobabooga test purposes. +/// +internal static class OobaboogaTestHelper +{ + /// + /// Reads test response from file for mocking purposes. + /// + /// Name of the file with test response. + internal static string GetTestResponse(string fileName) + { + return File.ReadAllText($"./Oobabooga/TestData/{fileName}"); + } + + /// + /// Returns mocked instance of . + /// + /// Message to return for mocked . + internal static HttpClientHandler GetHttpClientHandlerMock(HttpResponseMessage httpResponseMessage) + { + var httpClientHandler = new Mock(); + + httpClientHandler + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(httpResponseMessage); + + return httpClientHandler.Object; + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaWebSocketTestServer.cs b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaWebSocketTestServer.cs new file mode 100644 index 000000000000..d9210603a8fd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/OobaboogaWebSocketTestServer.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; + +namespace SemanticKernel.Connectors.UnitTests.Oobabooga; + +/// +/// Represents a WebSocket test server specifically designed for the Oobabooga text completion service. +/// It inherits from the base WebSocketTestServer class and handles Oobabooga-specific request and response classes. +/// The server accepts WebSocket connections, receives requests, and generates responses based on the Oobabooga text completion logic. +/// The OobaboogaWebSocketTestServer class uses a delegate to handle the request and response logic, allowing customization of the behavior. +/// +internal sealed class OobaboogaWebSocketTestServer : WebSocketTestServer +{ + public OobaboogaWebSocketTestServer(string url, Func> stringHandler, ILogger? logger = null) + : base(url, bytes => HandleRequest(bytes, stringHandler), logger: logger) + { + } + + private static List> HandleRequest(ArraySegment request, Func> stringHandler) + { + var requestString = Encoding.UTF8.GetString(request.ToArray()); + var requestObj = JsonSerializer.Deserialize(requestString); + + var responseList = stringHandler(requestObj?.Prompt ?? string.Empty); + + var responseSegments = new List>(); + int messageNum = 0; + foreach (var responseChunk in responseList) + { + var responseObj = new TextCompletionStreamingResponse + { + Event = "text_stream", + MessageNum = messageNum, + Text = responseChunk + }; + + var responseJson = JsonSerializer.Serialize(responseObj); + var responseBytes = Encoding.UTF8.GetBytes(responseJson); + responseSegments.Add(new ArraySegment(responseBytes)); + + messageNum++; + } + + var streamEndObj = new TextCompletionStreamingResponse + { + Event = "stream_end", + MessageNum = messageNum + }; + + var streamEndJson = JsonSerializer.Serialize(streamEndObj); + var streamEndBytes = Encoding.UTF8.GetBytes(streamEndJson); + responseSegments.Add(new ArraySegment(streamEndBytes)); + + return responseSegments; + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_response.json b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_response.json new file mode 100644 index 000000000000..397ee62436d5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_response.json @@ -0,0 +1,9 @@ +{ + "results": [ + { + "text": "This is test completion response" + + } + ] + +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_streaming_response.json b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_streaming_response.json new file mode 100644 index 000000000000..bf731d314094 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TestData/completion_test_streaming_response.json @@ -0,0 +1,5 @@ +{ + "event": "text_stream", + "message_num": 0, + "text": "This is test completion response" +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TextCompletion/OobaboogaTextCompletionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TextCompletion/OobaboogaTextCompletionTests.cs new file mode 100644 index 000000000000..65810789802d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Oobabooga/TextCompletion/OobaboogaTextCompletionTests.cs @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Linq; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.UnitTests.Oobabooga.TextCompletion; + +/// +/// Unit tests for class. +/// +public sealed class OobaboogaTextCompletionTests : IDisposable +{ + private readonly XunitLogger _logger; + private const string EndPoint = "https://fake-random-test-host"; + private const int BlockingPort = 1234; + private const int StreamingPort = 2345; + private const string CompletionText = "fake-test"; + private const string CompletionMultiText = "Hello, my name is"; + + private HttpMessageHandlerStub _messageHandlerStub; + private HttpClient _httpClient; + private Uri _endPointUri; + private string _streamCompletionResponseStub; + + public OobaboogaTextCompletionTests(ITestOutputHelper output) + { + this._logger = new XunitLogger(output); + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._messageHandlerStub.ResponseToReturn.Content = new StringContent(OobaboogaTestHelper.GetTestResponse("completion_test_response.json")); + this._streamCompletionResponseStub = OobaboogaTestHelper.GetTestResponse("completion_test_streaming_response.json"); + + this._httpClient = new HttpClient(this._messageHandlerStub, false); + this._endPointUri = new Uri(EndPoint); + } + + [Fact] + public async Task UserAgentHeaderShouldBeUsedAsync() + { + //Arrange + var sut = new OobaboogaTextCompletion(endpoint: this._endPointUri, + blockingPort: BlockingPort, + httpClient: this._httpClient, + logger: this._logger); + + //Act + await sut.GetCompletionsAsync(CompletionText, new CompleteRequestSettings()); + + //Assert + Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("User-Agent")); + + var values = this._messageHandlerStub.RequestHeaders!.GetValues("User-Agent"); + + var value = values.SingleOrDefault(); + Assert.Equal(OobaboogaTextCompletion.HttpUserAgent, value); + } + + [Fact] + public async Task ProvidedEndpointShouldBeUsedAsync() + { + //Arrange + var sut = new OobaboogaTextCompletion(endpoint: this._endPointUri, + blockingPort: BlockingPort, + httpClient: this._httpClient, + logger: this._logger); + + //Act + await sut.GetCompletionsAsync(CompletionText, new CompleteRequestSettings()); + + //Assert + Assert.StartsWith(EndPoint, this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task BlockingUrlShouldBeBuiltSuccessfullyAsync() + { + //Arrange + var sut = new OobaboogaTextCompletion(endpoint: this._endPointUri, + blockingPort: BlockingPort, + httpClient: this._httpClient, + logger: this._logger); + + //Act + await sut.GetCompletionsAsync(CompletionText, new CompleteRequestSettings()); + var expectedUri = new UriBuilder(this._endPointUri) + { + Path = OobaboogaTextCompletion.BlockingUriPath, + Port = BlockingPort + }; + + //Assert + Assert.Equal(expectedUri.Uri, this._messageHandlerStub.RequestUri); + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OobaboogaTextCompletion(endpoint: this._endPointUri, + blockingPort: BlockingPort, + httpClient: this._httpClient, + logger: this._logger); + + //Act + await sut.GetCompletionsAsync(CompletionText, new CompleteRequestSettings()); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + + Assert.Equal(CompletionText, requestPayload.Prompt); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OobaboogaTextCompletion(endpoint: this._endPointUri, + blockingPort: BlockingPort, + httpClient: this._httpClient, + logger: this._logger); + + //Act + var result = await sut.GetCompletionsAsync(CompletionText, new CompleteRequestSettings()); + + //Assert + Assert.NotNull(result); + + var completions = result.SingleOrDefault(); + Assert.NotNull(completions); + + var completion = await completions.GetCompletionAsync(); + Assert.Equal("This is test completion response", completion); + } + + [Fact] + public async Task ShouldHandleStreamingServicePersistentWebSocketResponseAsync() + { + var requestMessage = CompletionText; + var expectedResponse = new List { this._streamCompletionResponseStub }; + await this.RunWebSocketMultiPacketStreamingTestAsync( + requestMessage: requestMessage, + expectedResponse: expectedResponse, + isPersistent: true).ConfigureAwait(false); + } + + [Fact] + public async Task ShouldHandleStreamingServiceTransientWebSocketResponseAsync() + { + var requestMessage = CompletionText; + var expectedResponse = new List { this._streamCompletionResponseStub }; + await this.RunWebSocketMultiPacketStreamingTestAsync( + requestMessage: requestMessage, + expectedResponse: expectedResponse).ConfigureAwait(false); + } + + [Fact] + public async Task ShouldHandleConcurrentWebSocketConnectionsAsync() + { + var serverUrl = $"http://localhost:{StreamingPort}/"; + var clientUrl = $"ws://localhost:{StreamingPort}/"; + var expectedResponses = new List + { + "Response 1", + "Response 2", + "Response 3", + "Response 4", + "Response 5" + }; + + await using var server = new WebSocketTestServer(serverUrl, request => + { + // Simulate different responses for each request + var responseIndex = int.Parse(Encoding.UTF8.GetString(request.ToArray()), CultureInfo.InvariantCulture); + byte[] bytes = Encoding.UTF8.GetBytes(expectedResponses[responseIndex]); + var toReturn = new List> { new ArraySegment(bytes) }; + return toReturn; + }); + + var tasks = new List>(); + + // Simulate multiple concurrent WebSocket connections + for (int i = 0; i < expectedResponses.Count; i++) + { + var currentIndex = i; + tasks.Add(Task.Run(async () => + { + using var client = new ClientWebSocket(); + await client.ConnectAsync(new Uri(clientUrl), CancellationToken.None); + + // Send a request to the server + var requestBytes = Encoding.UTF8.GetBytes(currentIndex.ToString(CultureInfo.InvariantCulture)); + await client.SendAsync(new ArraySegment(requestBytes), WebSocketMessageType.Text, true, CancellationToken.None); + + // Receive the response from the server + var responseBytes = new byte[1024]; + var responseResult = await client.ReceiveAsync(new ArraySegment(responseBytes), CancellationToken.None); + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Close connection after message received", CancellationToken.None).ConfigureAwait(false); + + var response = Encoding.UTF8.GetString(responseBytes, 0, responseResult.Count); + + return response; + })); + } + + // Assert + for (int i = 0; i < expectedResponses.Count; i++) + { + var response = await tasks[i].ConfigureAwait(false); + Assert.Equal(expectedResponses[i], response); + } + } + + [Fact] + public async Task ShouldHandleMultiPacketStreamingServiceTransientWebSocketResponseAsync() + { + await this.RunWebSocketMultiPacketStreamingTestAsync().ConfigureAwait(false); + } + + [Fact] + public async Task ShouldHandleMultiPacketStreamingServicePersistentWebSocketResponseBroadcastBlockAsync() + { + await this.RunWebSocketMultiPacketStreamingTestAsync(isPersistent: true).ConfigureAwait(false); + } + + [Fact] + public async Task ShouldHandleConcurrentMultiPacketStreamingServiceTransientWebSocketResponseAsync() + { + await this.RunWebSocketMultiPacketStreamingTestAsync(nbConcurrentCalls: 10).ConfigureAwait(false); + } + + [Fact] + public async Task ShouldHandleConcurrentMultiPacketStreamingServicePersistentWebSocketResponseAsync() + { + await this.RunWebSocketMultiPacketStreamingTestAsync(nbConcurrentCalls: 10, isPersistent: true).ConfigureAwait(false); + } + + /// + /// This test will assess concurrent enumeration of the same long multi message (500 websocket messages) streaming result. + /// + [Fact] + public async Task ShouldHandleConcurrentEnumerationOfLongStreamingServiceResponseAsync() + { + var expectedResponse = Enumerable.Range(0, 500).Select(i => i.ToString(CultureInfo.InvariantCulture)).ToList(); + using SemaphoreSlim enforcedConcurrentCallSemaphore = new(20); + await this.RunWebSocketMultiPacketStreamingTestAsync( + expectedResponse: expectedResponse, + nbConcurrentCalls: 1, + nbConcurrentEnumeration: 100, + isPersistent: true, + keepAliveWebSocketsDuration: 100, + concurrentCallsTicksDelay: 0, + enforcedConcurrentCallSemaphore: enforcedConcurrentCallSemaphore, + maxExpectedNbClients: 20).ConfigureAwait(false); + } + + private async Task RunWebSocketMultiPacketStreamingTestAsync( + string requestMessage = CompletionMultiText, + List? expectedResponse = null, + int nbConcurrentCalls = 1, + int nbConcurrentEnumeration = 1, + bool isPersistent = false, + int requestProcessingDuration = 0, + int segmentMessageDelay = 0, + int keepAliveWebSocketsDuration = 100, + int concurrentCallsTicksDelay = 0, + SemaphoreSlim? enforcedConcurrentCallSemaphore = null, + int maxExpectedNbClients = 0, + int maxTestDuration = 0) + { + if (expectedResponse == null) + { + expectedResponse = new List { " John", ". I", "'m a", " writer" }; + } + + Func? webSocketFactory = null; + // Counter to track the number of WebSocket clients created + int clientCount = 0; + var delayTimeSpan = new TimeSpan(concurrentCallsTicksDelay); + if (isPersistent) + { + ClientWebSocket ExternalWebSocketFactory() + { + this._logger?.LogInformation(message: "Creating new client web socket"); + var toReturn = new ClientWebSocket(); + return toReturn; + } + + if (maxExpectedNbClients > 0) + { + ClientWebSocket IncrementFactory() + { + var toReturn = ExternalWebSocketFactory(); + Interlocked.Increment(ref clientCount); + return toReturn; + } + + webSocketFactory = IncrementFactory; + } + else + { + webSocketFactory = ExternalWebSocketFactory; + } + } + + using var cleanupToken = new CancellationTokenSource(); + + var sut = new OobaboogaTextCompletion( + endpoint: new Uri("http://localhost/"), + streamingPort: StreamingPort, + httpClient: this._httpClient, + webSocketsCleanUpCancellationToken: cleanupToken.Token, + webSocketFactory: webSocketFactory, + keepAliveWebSocketsDuration: keepAliveWebSocketsDuration, + concurrentSemaphore: enforcedConcurrentCallSemaphore, + logger: this._logger); + + await using var server = new OobaboogaWebSocketTestServer($"http://localhost:{StreamingPort}/", request => expectedResponse, logger: this._logger) + { + RequestProcessingDelay = TimeSpan.FromMilliseconds(requestProcessingDuration), + SegmentMessageDelay = TimeSpan.FromMilliseconds(segmentMessageDelay) + }; + + var sw = Stopwatch.StartNew(); + var tasks = new List>>(); + + for (int i = 0; i < nbConcurrentCalls; i++) + { + tasks.Add(Task.Run(() => + { + var localResponse = sut.CompleteStreamAsync(requestMessage, new CompleteRequestSettings() + { + Temperature = 0.01, + MaxTokens = 7, + TopP = 0.1, + }, cancellationToken: cleanupToken.Token); + return localResponse; + })); + } + + var callEnumerationTasks = new List>>(); + await Task.WhenAll(tasks).ConfigureAwait(false); + + foreach (var callTask in tasks) + { + callEnumerationTasks.AddRange(Enumerable.Range(0, nbConcurrentEnumeration).Select(_ => Task.Run(async () => + { + var completion = await callTask.ConfigureAwait(false); + var result = new List(); + await foreach (var chunk in completion) + { + result.Add(chunk); + } + + return result; + }))); + + // Introduce a delay between creating each WebSocket client + await Task.Delay(delayTimeSpan).ConfigureAwait(false); + } + + var allResults = await Task.WhenAll(callEnumerationTasks).ConfigureAwait(false); + + var elapsed = sw.ElapsedMilliseconds; + if (maxExpectedNbClients > 0) + { + Assert.InRange(clientCount, 1, maxExpectedNbClients); + } + + // Validate all results + foreach (var result in allResults) + { + Assert.Equal(expectedResponse.Count, result.Count); + for (int i = 0; i < expectedResponse.Count; i++) + { + Assert.Equal(expectedResponse[i], result[i]); + } + } + + if (maxTestDuration > 0) + { + Assert.InRange(elapsed, 0, maxTestDuration); + } + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + this._logger.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/WebSocketTestServer.cs b/dotnet/src/Connectors/Connectors.UnitTests/WebSocketTestServer.cs new file mode 100644 index 000000000000..11eafcb24ef2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/WebSocketTestServer.cs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace SemanticKernel.Connectors.UnitTests; + +internal class WebSocketTestServer : IDisposable +{ + private readonly ILogger? _logger; + + private readonly HttpListener _httpListener; + private readonly CancellationTokenSource _mainCancellationTokenSource; + private readonly CancellationTokenSource _socketCancellationTokenSource; + private bool _serverIsRunning; + + private Func, List>> _arraySegmentHandler; + private readonly ConcurrentDictionary> _requestContentQueues; + private readonly ConcurrentBag _runningTasks = new(); + + private readonly ConcurrentDictionary _clients = new(); + + public TimeSpan RequestProcessingDelay { get; set; } = TimeSpan.Zero; + public TimeSpan SegmentMessageDelay { get; set; } = TimeSpan.Zero; + + public ConcurrentDictionary RequestContents + { + get + { + return new ConcurrentDictionary( + this._requestContentQueues + .ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToList().SelectMany(bytes => bytes).ToArray())); + } + } + + public WebSocketTestServer(string url, Func, List>> arraySegmentHandler, ILogger? logger = null) + { + this._logger = logger; + + this._arraySegmentHandler = arraySegmentHandler; + this._requestContentQueues = new ConcurrentDictionary>(); + + this._mainCancellationTokenSource = new(); + this._socketCancellationTokenSource = new(); + + this._httpListener = new HttpListener(); + this._httpListener.Prefixes.Add(url); + this._httpListener.Start(); + this._serverIsRunning = true; + + Task.Run((Func)this.HandleRequestsAsync, this._mainCancellationTokenSource.Token); + } + + private async Task HandleRequestsAsync() + { + while (!this._mainCancellationTokenSource.IsCancellationRequested) + { + var context = await this._httpListener.GetContextAsync().ConfigureAwait(false); + + if (this._serverIsRunning) + { + if (context.Request.IsWebSocketRequest) + { + var connectedClient = new ConnectedClient(Guid.NewGuid(), context); + this._clients[connectedClient.Id] = connectedClient; + try + { + var socketContext = await context.AcceptWebSocketAsync(subProtocol: null); + connectedClient.SetSocket(socketContext.WebSocket); + this._runningTasks.Add(this.HandleSingleWebSocketRequestAsync(connectedClient)); + } + catch + { + // server error if upgrade from HTTP to WebSocket fails + context.Response.StatusCode = 500; + context.Response.StatusDescription = "WebSocket upgrade failed"; + context.Response.Close(); + throw; + } + } + } + else + { + // HTTP 409 Conflict (with server's current state) + context.Response.StatusCode = 409; + context.Response.StatusDescription = "Server is shutting down"; + context.Response.Close(); + return; + } + } + + await Task.WhenAll(this._runningTasks).ConfigureAwait(false); + } + + private async Task HandleSingleWebSocketRequestAsync(ConnectedClient connectedClient) + { + var buffer = WebSocket.CreateServerBuffer(4096); + + Guid requestId = connectedClient.Id; + this._requestContentQueues[requestId] = new ConcurrentQueue(); + + try + { + while (!this._socketCancellationTokenSource.IsCancellationRequested && connectedClient.Socket != null && connectedClient.Socket.State != WebSocketState.Closed && connectedClient.Socket.State != WebSocketState.Aborted) + { + WebSocketReceiveResult result = await connectedClient.Socket.ReceiveAsync(buffer, this._socketCancellationTokenSource.Token).ConfigureAwait(false); + if (!this._socketCancellationTokenSource.IsCancellationRequested && connectedClient.Socket.State != WebSocketState.Closed && connectedClient.Socket.State != WebSocketState.Aborted) + { + if (connectedClient.Socket.State == WebSocketState.CloseReceived && result.MessageType == WebSocketMessageType.Close) + { + await connectedClient.Socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Acknowledge Close frame", CancellationToken.None); + + break; + } + + var receivedBytes = buffer.Slice(0, result.Count); + this._requestContentQueues[requestId].Enqueue(receivedBytes.ToArray()); + + if (result.EndOfMessage) + { + var responseSegments = this._arraySegmentHandler(receivedBytes); + + if (this.RequestProcessingDelay.Ticks > 0) + { + await Task.Delay(this.RequestProcessingDelay).ConfigureAwait(false); + } + + foreach (var responseSegment in responseSegments) + { + if (connectedClient.Socket.State != WebSocketState.Open) + { + break; + } + + if (this.SegmentMessageDelay.Ticks > 0) + { + await Task.Delay(this.SegmentMessageDelay).ConfigureAwait(false); + } + + await connectedClient.Socket.SendAsync(responseSegment, WebSocketMessageType.Text, true, this._socketCancellationTokenSource.Token).ConfigureAwait(false); + } + } + } + } + + if (connectedClient.Socket?.State == WebSocketState.Open) + { + await connectedClient.Socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing waiting for acknowledgement", CancellationToken.None).ConfigureAwait(false); + } + else if (connectedClient.Socket?.State == WebSocketState.CloseReceived) + { + await connectedClient.Socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Closing without waiting for acknowledgment", CancellationToken.None).ConfigureAwait(false); + } + } + catch (OperationCanceledException exception) + { + this._logger?.LogTrace(message: "Closing server web socket before disposal was cancelled", exception: exception); + } + catch (WebSocketException exception) + { + this._logger?.LogTrace(message: "Closing server web socket before disposal raised web socket exception", exception: exception); + } + finally + { + if (connectedClient.Socket?.State != WebSocketState.Closed) + { + connectedClient.Socket?.Abort(); + } + + connectedClient.Socket?.Dispose(); + + // Remove client from dictionary when done + this._clients.TryRemove(requestId, out _); + } + } + + private async Task CloseAllSocketsAsync() + { + // Close all active sockets before disposing + foreach (var client in this._clients.Values) + { + if (client.Socket?.State == WebSocketState.Open) + { + await client.Socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", this._mainCancellationTokenSource.Token); + } + } + } + + public async ValueTask DisposeAsync() + { + try + { + this._serverIsRunning = false; + await this.CloseAllSocketsAsync(); // Close all sockets before finishing the tasks + await Task.WhenAll(this._runningTasks).ConfigureAwait(false); + this._socketCancellationTokenSource.Cancel(); + this._mainCancellationTokenSource.Cancel(); + } + catch (OperationCanceledException exception) + { + this._logger?.LogTrace(message: "\"Disposing web socket test server raised operation cancel exception", exception: exception); + } + finally + { + this._httpListener.Stop(); + this._httpListener.Close(); + this._socketCancellationTokenSource.Dispose(); + this._mainCancellationTokenSource.Dispose(); + } + } + + public void Dispose() + { + this.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/XunitLogger.cs b/dotnet/src/Connectors/Connectors.UnitTests/XunitLogger.cs new file mode 100644 index 000000000000..1521dac75bed --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/XunitLogger.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.UnitTests; + +/// +/// A logger that writes to the Xunit test output +/// +internal sealed class XunitLogger : ILogger, IDisposable +{ + private readonly ITestOutputHelper _output; + + public XunitLogger(ITestOutputHelper output) + { + this._output = output; + } + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + this._output.WriteLine(state?.ToString()); + } + + /// + public bool IsEnabled(LogLevel logLevel) => true; + + /// + public IDisposable BeginScope(TState state) + => this; + + /// + public void Dispose() + { + // This class is marked as disposable to support the BeginScope method. + // However, there is no need to dispose anything. + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Oobabooga/OobaboogaTextCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Oobabooga/OobaboogaTextCompletionTests.cs new file mode 100644 index 000000000000..78d98dafc1ba --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Oobabooga/OobaboogaTextCompletionTests.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.WebSockets; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Connectors.AI.Oobabooga.TextCompletion; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Oobabooga; + +/// +/// Integration tests for . +/// +public sealed class OobaboogaTextCompletionTests : IDisposable +{ + private const string Endpoint = "http://localhost"; + private const int BlockingPort = 5000; + private const int StreamingPort = 5005; + + private readonly IConfigurationRoot _configuration; + private List _webSockets = new(); + private Func _webSocketFactory; + + public OobaboogaTextCompletionTests() + { + // Load configuration + this._configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .Build(); + this._webSocketFactory = () => + { + var toReturn = new ClientWebSocket(); + this._webSockets.Add(toReturn); + return toReturn; + }; + } + + private const string Input = " My name is"; + + [Fact(Skip = "This test is for manual verification.")] + public async Task OobaboogaLocalTextCompletionAsync() + { + var oobaboogaLocal = new OobaboogaTextCompletion( + endpoint: new Uri(Endpoint), + blockingPort: BlockingPort); + + // Act + var localResponse = await oobaboogaLocal.CompleteAsync(Input, new CompleteRequestSettings() + { + Temperature = 0.01, + MaxTokens = 7, + TopP = 0.1, + }); + + AssertAcceptableResponse(localResponse); + } + + [Fact(Skip = "This test is for manual verification.")] + public async Task OobaboogaLocalTextCompletionStreamingAsync() + { + var oobaboogaLocal = new OobaboogaTextCompletion( + endpoint: new Uri(Endpoint), + streamingPort: StreamingPort, + webSocketFactory: this._webSocketFactory); + + // Act + var localResponse = oobaboogaLocal.CompleteStreamAsync(Input, new CompleteRequestSettings() + { + Temperature = 0.01, + MaxTokens = 7, + TopP = 0.1, + }); + + StringBuilder stringBuilder = new(); + await foreach (var result in localResponse) + { + stringBuilder.Append(result); + } + + var resultsMerged = stringBuilder.ToString(); + AssertAcceptableResponse(resultsMerged); + } + + private static void AssertAcceptableResponse(string localResponse) + { + // Assert + Assert.NotNull(localResponse); + // Depends on the target LLM obviously, but most LLMs should propose an arbitrary surname preceded by a white space, including the start prompt or not + // ie " My name is" => " John (...)" or " My name is" => " My name is John (...)". + // Here are a couple LLMs that were tested successfully: gpt2, aisquared_dlite-v1-355m, bigscience_bloomz-560m, eachadea_vicuna-7b-1.1, TheBloke_WizardLM-30B-GPTQ etc. + // A few will return an empty string, but well those shouldn't be used for integration tests. + var expectedRegex = new Regex(@"\s\w+.*"); + Assert.Matches(expectedRegex, localResponse); + } + + public void Dispose() + { + foreach (ClientWebSocket clientWebSocket in this._webSockets) + { + clientWebSocket.Dispose(); + } + } +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 28efab76da42..7443e4100df9 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -34,6 +34,7 @@ + diff --git a/dotnet/src/IntegrationTests/README.md b/dotnet/src/IntegrationTests/README.md index 00186f6309f6..9edb16e85896 100644 --- a/dotnet/src/IntegrationTests/README.md +++ b/dotnet/src/IntegrationTests/README.md @@ -8,6 +8,7 @@ 3. **HuggingFace API key**: see https://huggingface.co/docs/huggingface_hub/guides/inference for details. 4. **Azure Bing Web Search API**: go to [Bing Web Search API](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api) and select `Try Now` to get started. +5. **Oobabooga Text generation web UI**: Follow the [installation instructions](https://github.com/oobabooga/text-generation-webui#installation) to get a local Oobabooga instance running. Follow the [download instructions](https://github.com/oobabooga/text-generation-webui#downloading-models) to install a test model e.g. `python download-model.py gpt2`. Follow the [starting instructions](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui) to start your local instance, enabling API, e.g. `python server.py --model gpt2 --listen --api --api-blocking-port "5000" --api-streaming-port "5005"`. Note that `--model` parameter is optional and models can be downloaded and hot swapped using exclusively the web UI, making it easy to test various models. 5. **Postgres**: start a postgres with the [pgvector](https://github.com/pgvector/pgvector) extension installed. You can easily do it using the docker image [ankane/pgvector](https://hub.docker.com/r/ankane/pgvector). ## Setup diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs index 3172ee86fd38..31d468bfe647 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs @@ -54,6 +54,7 @@ public static async IAsyncEnumerable CompleteStreamAsync(this ITextCompl { yield return word; } + yield break; } }