From 04cf1fe2589bb1457945e394ee8cdfcfa9316ccd Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Wed, 27 Nov 2024 15:44:38 +0000 Subject: [PATCH] Integrate OpenAI/AzureOpenAI support with Microsoft.Extensions.AI (#6225) --- Directory.Packages.props | 6 +- eng/Versions.props | 1 + .../Components/App.razor | 8 +- .../Components/Pages/Home.razor | 1 - .../Components/Pages/UseIChatClient.razor | 37 +++ .../OpenAIEndToEnd.WebStory.csproj | 2 + .../OpenAIEndToEnd.WebStory/Program.cs | 5 +- .../Aspire.Azure.AI.OpenAI.csproj | 2 + .../AspireAzureOpenAIClientBuilder.cs | 29 +++ .../AspireAzureOpenAIExtensions.cs | 16 +- .../PublicAPI.Unshipped.txt | 7 +- .../Aspire.OpenAI/Aspire.OpenAI.csproj | 5 +- .../AspireOpenAIClientBuilder.cs | 98 ++++++++ ...OpenAIClientBuilderChatClientExtensions.cs | 61 +++++ ...ientBuilderEmbeddingGeneratorExtensions.cs | 63 +++++ .../Aspire.OpenAI/AspireOpenAIExtensions.cs | 16 +- .../MEAIPackageOverrides.targets | 13 + .../Aspire.OpenAI/PublicAPI.Unshipped.txt | 17 +- src/Components/Common/AzureComponent.cs | 4 +- .../Aspire.Azure.AI.OpenAI.Tests.csproj | 2 + ...IClientBuilderChatClientExtensionsTests.cs | 223 +++++++++++++++++ ...uilderEmbeddingGeneratorExtensionsTests.cs | 226 +++++++++++++++++ .../Aspire.OpenAI.Tests.csproj | 2 + ...IClientBuilderChatClientExtensionsTests.cs | 224 +++++++++++++++++ ...uilderEmbeddingGeneratorExtensionsTests.cs | 227 ++++++++++++++++++ 25 files changed, 1272 insertions(+), 23 deletions(-) create mode 100644 playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/UseIChatClient.razor create mode 100644 src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIClientBuilder.cs create mode 100644 src/Components/Aspire.OpenAI/AspireOpenAIClientBuilder.cs create mode 100644 src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderChatClientExtensions.cs create mode 100644 src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.cs create mode 100644 src/Components/Aspire.OpenAI/MEAIPackageOverrides.targets create mode 100644 tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderChatClientExtensionsTests.cs create mode 100644 tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs create mode 100644 tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderChatClientExtensionsTests.cs create mode 100644 tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 88a7cfa94d..9f1570a44c 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -11,7 +11,7 @@ - + @@ -80,6 +80,8 @@ + + @@ -116,7 +118,7 @@ - + diff --git a/eng/Versions.props b/eng/Versions.props index 62ef9cb728..c528688df5 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -33,6 +33,7 @@ 9.0.0-beta.24572.2 9.0.0-beta.24516.2 9.0.0-beta.24516.2 + 9.0.1-preview.1.24570.5 9.0.0 9.0.0 8.0.0 diff --git a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/App.razor b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/App.razor index 0f9b81fecd..7ff7bd32e3 100644 --- a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/App.razor +++ b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/App.razor @@ -7,12 +7,16 @@ - + - + + +@code { + IComponentRenderMode renderMode = new InteractiveServerRenderMode(prerender: false); +} diff --git a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/Home.razor b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/Home.razor index f989e760ce..543b9804ae 100644 --- a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/Home.razor +++ b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/Home.razor @@ -1,5 +1,4 @@ @page "/" -@rendermode @(new InteractiveServerRenderMode(prerender: false)) @using OpenAI @using OpenAI.Chat @inject OpenAIClient aiClient diff --git a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/UseIChatClient.razor b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/UseIChatClient.razor new file mode 100644 index 0000000000..69e3e68db2 --- /dev/null +++ b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Components/Pages/UseIChatClient.razor @@ -0,0 +1,37 @@ +@page "/useichatclient" +@using Microsoft.Extensions.AI +@inject IChatClient aiClient +@inject ILogger logger +@inject IConfiguration configuration + +
+ @foreach (var message in chatMessages.Where(m => m.Role == ChatRole.Assistant)) + { +

@message.Text

+ } + + +
+ +@code { + private List chatMessages = new List + { + new(ChatRole.System, "Pick a random topic and write a sentence of a fictional story about it.") + }; + + private async Task GenerateNextParagraph() + { + if (chatMessages.Count > 1) + { + chatMessages.Add(new (ChatRole.User, "Write the next sentence in the story.")); + } + + var response = await aiClient.CompleteAsync(chatMessages); + chatMessages.Add(response.Message); + } + + protected override async Task OnInitializedAsync() + { + await GenerateNextParagraph(); + } +} diff --git a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/OpenAIEndToEnd.WebStory.csproj b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/OpenAIEndToEnd.WebStory.csproj index ede52a8521..e9ab77317f 100644 --- a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/OpenAIEndToEnd.WebStory.csproj +++ b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/OpenAIEndToEnd.WebStory.csproj @@ -11,4 +11,6 @@
+ + diff --git a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Program.cs b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Program.cs index 5b98bcffa3..2d2e2953f3 100644 --- a/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Program.cs +++ b/playground/OpenAIEndToEnd/OpenAIEndToEnd.WebStory/Program.cs @@ -7,7 +7,10 @@ builder.AddServiceDefaults(); -builder.AddAzureOpenAIClient("openai"); +// Instead of passing this manually, it can also be read from the connection string +var openAiDeploymentName = builder.Configuration["OpenAI:DeploymentName"]; + +builder.AddAzureOpenAIClient("openai").AddChatClient(openAiDeploymentName); // Add services to the container. builder.Services.AddRazorComponents() diff --git a/src/Components/Aspire.Azure.AI.OpenAI/Aspire.Azure.AI.OpenAI.csproj b/src/Components/Aspire.Azure.AI.OpenAI/Aspire.Azure.AI.OpenAI.csproj index 5af5dcb3e7..6a22fde1c9 100644 --- a/src/Components/Aspire.Azure.AI.OpenAI/Aspire.Azure.AI.OpenAI.csproj +++ b/src/Components/Aspire.Azure.AI.OpenAI/Aspire.Azure.AI.OpenAI.csproj @@ -34,4 +34,6 @@ + + diff --git a/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIClientBuilder.cs b/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIClientBuilder.cs new file mode 100644 index 0000000000..a26cdea37e --- /dev/null +++ b/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIClientBuilder.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Aspire.OpenAI; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Hosting; + +namespace Aspire.Azure.AI.OpenAI; + +/// +/// A builder for configuring an service registration. +/// +public class AspireAzureOpenAIClientBuilder : AspireOpenAIClientBuilder +{ + /// + /// Constructs a new instance of . + /// + /// The with which services are being registered. + /// The name used to retrieve the connection string from the ConnectionStrings configuration section. + /// The service key used to register the service, if any. + /// A flag to indicate whether tracing should be disabled. + public AspireAzureOpenAIClientBuilder(IHostApplicationBuilder hostBuilder, string connectionName, string? serviceKey, bool disableTracing) + : base(hostBuilder, connectionName, serviceKey, disableTracing) + { + } + + /// + public override string ConfigurationSectionName => AspireAzureOpenAIExtensions.DefaultConfigSectionName; +} diff --git a/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIExtensions.cs b/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIExtensions.cs index 32e05e062a..f7683acd7f 100644 --- a/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIExtensions.cs +++ b/src/Components/Aspire.Azure.AI.OpenAI/AspireAzureOpenAIExtensions.cs @@ -22,7 +22,7 @@ namespace Microsoft.Extensions.Hosting; /// public static class AspireAzureOpenAIExtensions { - private const string DefaultConfigSectionName = "Aspire:Azure:AI:OpenAI"; + internal const string DefaultConfigSectionName = "Aspire:Azure:AI:OpenAI"; /// /// Registers as a singleton in the services provided by the . @@ -33,17 +33,20 @@ public static class AspireAzureOpenAIExtensions /// A name used to retrieve the connection string from the ConnectionStrings configuration section. /// An optional method that can be used for customizing the . It's invoked after the settings are read from the configuration. /// An optional method that can be used for customizing the . + /// An that can be used to register additional services. /// Reads the configuration from "Aspire.Azure.AI.OpenAI" section. - public static void AddAzureOpenAIClient( + public static AspireAzureOpenAIClientBuilder AddAzureOpenAIClient( this IHostApplicationBuilder builder, string connectionName, Action? configureSettings = null, Action>? configureClientBuilder = null) { - new OpenAIComponent().AddClient(builder, DefaultConfigSectionName, configureSettings, configureClientBuilder, connectionName, serviceKey: null); + var settings = new OpenAIComponent().AddClient(builder, DefaultConfigSectionName, configureSettings, configureClientBuilder, connectionName, serviceKey: null); // Add the AzureOpenAIClient service as OpenAIClient. That way the service can be resolved by both service Types. builder.Services.TryAddSingleton(typeof(OpenAIClient), static provider => provider.GetRequiredService()); + + return new AspireAzureOpenAIClientBuilder(builder, connectionName, serviceKey: null, disableTracing: settings.DisableTracing); } /// @@ -55,8 +58,9 @@ public static void AddAzureOpenAIClient( /// The name of the component, which is used as the of the service and also to retrieve the connection string from the ConnectionStrings configuration section. /// An optional method that can be used for customizing the . It's invoked after the settings are read from the configuration. /// An optional method that can be used for customizing the . + /// An that can be used to register additional services. /// Reads the configuration from "Aspire.Azure.AI.OpenAI:{name}" section. - public static void AddKeyedAzureOpenAIClient( + public static AspireAzureOpenAIClientBuilder AddKeyedAzureOpenAIClient( this IHostApplicationBuilder builder, string name, Action? configureSettings = null, @@ -64,10 +68,12 @@ public static void AddKeyedAzureOpenAIClient( { ArgumentException.ThrowIfNullOrEmpty(name); - new OpenAIComponent().AddClient(builder, DefaultConfigSectionName, configureSettings, configureClientBuilder, connectionName: name, serviceKey: name); + var settings = new OpenAIComponent().AddClient(builder, DefaultConfigSectionName, configureSettings, configureClientBuilder, connectionName: name, serviceKey: name); // Add the AzureOpenAIClient service as OpenAIClient. That way the service can be resolved by both service Types. builder.Services.TryAddKeyedSingleton(typeof(OpenAIClient), serviceKey: name, static (provider, key) => provider.GetRequiredKeyedService(key)); + + return new AspireAzureOpenAIClientBuilder(builder, name, name, settings.DisableTracing); } private sealed class OpenAIComponent : AzureComponent diff --git a/src/Components/Aspire.Azure.AI.OpenAI/PublicAPI.Unshipped.txt b/src/Components/Aspire.Azure.AI.OpenAI/PublicAPI.Unshipped.txt index 1f2eb92b6e..a205352267 100644 --- a/src/Components/Aspire.Azure.AI.OpenAI/PublicAPI.Unshipped.txt +++ b/src/Components/Aspire.Azure.AI.OpenAI/PublicAPI.Unshipped.txt @@ -1,4 +1,6 @@ #nullable enable +Aspire.Azure.AI.OpenAI.AspireAzureOpenAIClientBuilder +Aspire.Azure.AI.OpenAI.AspireAzureOpenAIClientBuilder.AspireAzureOpenAIClientBuilder(Microsoft.Extensions.Hosting.IHostApplicationBuilder! hostBuilder, string! connectionName, string? serviceKey, bool disableTracing) -> void Aspire.Azure.AI.OpenAI.AzureOpenAISettings Aspire.Azure.AI.OpenAI.AzureOpenAISettings.AzureOpenAISettings() -> void Aspire.Azure.AI.OpenAI.AzureOpenAISettings.Credential.get -> Azure.Core.TokenCredential? @@ -13,7 +15,8 @@ Aspire.Azure.AI.OpenAI.AzureOpenAISettings.Key.get -> string? Aspire.Azure.AI.OpenAI.AzureOpenAISettings.Key.set -> void Microsoft.Extensions.Hosting.AspireAzureOpenAIExtensions Microsoft.Extensions.Hosting.AspireConfigurableOpenAIExtensions -static Microsoft.Extensions.Hosting.AspireAzureOpenAIExtensions.AddAzureOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! connectionName, System.Action? configureSettings = null, System.Action!>? configureClientBuilder = null) -> void -static Microsoft.Extensions.Hosting.AspireAzureOpenAIExtensions.AddKeyedAzureOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! name, System.Action? configureSettings = null, System.Action!>? configureClientBuilder = null) -> void +override Aspire.Azure.AI.OpenAI.AspireAzureOpenAIClientBuilder.ConfigurationSectionName.get -> string! +static Microsoft.Extensions.Hosting.AspireAzureOpenAIExtensions.AddAzureOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! connectionName, System.Action? configureSettings = null, System.Action!>? configureClientBuilder = null) -> Aspire.Azure.AI.OpenAI.AspireAzureOpenAIClientBuilder! +static Microsoft.Extensions.Hosting.AspireAzureOpenAIExtensions.AddKeyedAzureOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! name, System.Action? configureSettings = null, System.Action!>? configureClientBuilder = null) -> Aspire.Azure.AI.OpenAI.AspireAzureOpenAIClientBuilder! static Microsoft.Extensions.Hosting.AspireConfigurableOpenAIExtensions.AddKeyedOpenAIClientFromConfiguration(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! name) -> void static Microsoft.Extensions.Hosting.AspireConfigurableOpenAIExtensions.AddOpenAIClientFromConfiguration(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! connectionName) -> void diff --git a/src/Components/Aspire.OpenAI/Aspire.OpenAI.csproj b/src/Components/Aspire.OpenAI/Aspire.OpenAI.csproj index f4a0c637d3..288b16552a 100644 --- a/src/Components/Aspire.OpenAI/Aspire.OpenAI.csproj +++ b/src/Components/Aspire.OpenAI/Aspire.OpenAI.csproj @@ -19,9 +19,12 @@ - + + + + diff --git a/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilder.cs b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilder.cs new file mode 100644 index 0000000000..25b10c84ee --- /dev/null +++ b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilder.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Configuration; +using System.Data.Common; +using Microsoft.Extensions.Hosting; +using OpenAI; + +namespace Aspire.OpenAI; + +/// +/// A builder for configuring an service registration. +/// +public class AspireOpenAIClientBuilder +{ + private const string DeploymentKey = "Deployment"; + private const string ModelKey = "Model"; + + /// + /// Constructs a new instance of . + /// + /// The with which services are being registered. + /// The name used to retrieve the connection string from the ConnectionStrings configuration section. + /// The service key used to register the service, if any. + /// A flag to indicate whether tracing should be disabled. + public AspireOpenAIClientBuilder(IHostApplicationBuilder hostBuilder, string connectionName, string? serviceKey, bool disableTracing) + { + HostBuilder = hostBuilder; + ConnectionName = connectionName; + ServiceKey = serviceKey; + DisableTracing = disableTracing; + } + + /// + /// Gets the with which services are being registered. + /// + public IHostApplicationBuilder HostBuilder { get; } + + /// + /// Gets the name used to retrieve the connection string from the ConnectionStrings configuration section. + /// + public string ConnectionName { get; } + + /// + /// Gets the service key used to register the service, if any. + /// + public string? ServiceKey { get; } + + /// + /// Gets a flag indicating whether tracing should be disabled. + /// + public bool DisableTracing { get; } + + /// + /// Gets the name of the configuration section for this component type. + /// + public virtual string ConfigurationSectionName => AspireOpenAIExtensions.DefaultConfigSectionName; + + internal string GetRequiredDeploymentName() + { + string? deploymentName = null; + + var configuration = HostBuilder.Configuration; + if (configuration.GetConnectionString(ConnectionName) is string connectionString) + { + // The reason we accept either 'Deployment' or 'Model' as the key is because OpenAI's terminology + // is 'Model' and Azure OpenAI's terminology is 'Deployment'. It may seem awkward if we picked just + // one of these, as it might not match the usage scenario. We could restrict it based on which backend + // you're using, but that adds an unnecessary failure case for no clear benefit. + var connectionBuilder = new DbConnectionStringBuilder { ConnectionString = connectionString }; + var deploymentValue = ConnectionStringValue(connectionBuilder, DeploymentKey); + var modelValue = ConnectionStringValue(connectionBuilder, ModelKey); + if (deploymentValue is not null && modelValue is not null) + { + throw new InvalidOperationException( + $"The connection string '{ConnectionName}' contains both '{DeploymentKey}' and '{ModelKey}' keys. Either of these may be specified, but not both."); + } + + deploymentName = deploymentValue ?? modelValue; + } + + if (string.IsNullOrEmpty(deploymentName)) + { + var configSection = configuration.GetSection(ConfigurationSectionName); + deploymentName = configSection[DeploymentKey]; + } + + if (string.IsNullOrEmpty(deploymentName)) + { + throw new InvalidOperationException($"The deployment could not be determined. Ensure a '{DeploymentKey}' or '{ModelKey}' value is provided in 'ConnectionStrings:{ConnectionName}', or specify a '{DeploymentKey}' in the '{ConfigurationSectionName}' configuration section, or specify a '{nameof(deploymentName)}' in the call."); + } + + return deploymentName; + } + + private static string? ConnectionStringValue(DbConnectionStringBuilder connectionString, string key) + => connectionString.TryGetValue(key, out var value) ? value as string : null; +} diff --git a/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderChatClientExtensions.cs b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderChatClientExtensions.cs new file mode 100644 index 0000000000..3fe8426127 --- /dev/null +++ b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderChatClientExtensions.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Aspire.OpenAI; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using OpenAI; + +namespace Microsoft.Extensions.Hosting; + +/// +/// Provides extension methods for registering as a singleton in the services provided by the . +/// +public static class AspireOpenAIClientBuilderChatClientExtensions +{ + /// + /// Registers a singleton in the services provided by the . + /// + /// An . + /// Optionally specifies which model deployment to use. If not specified, a value will be taken from the connection string. + /// A that can be used to build a pipeline around the inner . + public static ChatClientBuilder AddChatClient( + this AspireOpenAIClientBuilder builder, + string? deploymentName = null) + { + return builder.HostBuilder.Services.AddChatClient( + services => CreateInnerChatClient(services, builder, deploymentName)); + } + + /// + /// Registers a keyed singleton in the services provided by the . + /// + /// An . + /// The service key with which the will be registered. + /// Optionally specifies which model deployment to use. If not specified, a value will be taken from the connection string. + /// A that can be used to build a pipeline around the inner . + public static ChatClientBuilder AddKeyedChatClient( + this AspireOpenAIClientBuilder builder, + string serviceKey, + string? deploymentName = null) + { + return builder.HostBuilder.Services.AddKeyedChatClient( + serviceKey, + services => CreateInnerChatClient(services, builder, deploymentName)); + } + + private static IChatClient CreateInnerChatClient( + IServiceProvider services, + AspireOpenAIClientBuilder builder, + string? deploymentName) + { + var openAiClient = builder.ServiceKey is null + ? services.GetRequiredService() + : services.GetRequiredKeyedService(builder.ServiceKey); + + deploymentName ??= builder.GetRequiredDeploymentName(); + var result = openAiClient.AsChatClient(deploymentName); + + return builder.DisableTracing ? result : new OpenTelemetryChatClient(result); + } +} diff --git a/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.cs b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.cs new file mode 100644 index 0000000000..12d06e7a44 --- /dev/null +++ b/src/Components/Aspire.OpenAI/AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Aspire.OpenAI; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using OpenAI; + +namespace Microsoft.Extensions.Hosting; + +/// +/// Provides extension methods for registering as a singleton in the services provided by the . +/// +public static class AspireOpenAIClientBuilderEmbeddingGeneratorExtensions +{ + /// + /// Registers a singleton in the services provided by the . + /// + /// An . + /// Optionally specifies which model deployment to use. If not specified, a value will be taken from the connection string. + /// A that can be used to build a pipeline around the inner . + public static EmbeddingGeneratorBuilder> AddEmbeddingGenerator( + this AspireOpenAIClientBuilder builder, + string? deploymentName = null) + { + return builder.HostBuilder.Services.AddEmbeddingGenerator( + services => CreateInnerEmbeddingGenerator(services, builder, deploymentName)); + } + + /// + /// Registers a keyed singleton in the services provided by the . + /// + /// An . + /// The service key with which the will be registered. + /// Optionally specifies which model deployment to use. If not specified, a value will be taken from the connection string. + /// A that can be used to build a pipeline around the inner . + public static EmbeddingGeneratorBuilder> AddKeyedEmbeddingGenerator( + this AspireOpenAIClientBuilder builder, + string serviceKey, + string? deploymentName = null) + { + return builder.HostBuilder.Services.AddKeyedEmbeddingGenerator( + serviceKey, + services => CreateInnerEmbeddingGenerator(services, builder, deploymentName)); + } + + private static IEmbeddingGenerator> CreateInnerEmbeddingGenerator( + IServiceProvider services, + AspireOpenAIClientBuilder builder, + string? deploymentName) + { + var openAiClient = builder.ServiceKey is null + ? services.GetRequiredService() + : services.GetRequiredKeyedService(builder.ServiceKey); + + deploymentName ??= builder.GetRequiredDeploymentName(); + var result = openAiClient.AsEmbeddingGenerator(deploymentName); + + return builder.DisableTracing + ? result + : new OpenTelemetryEmbeddingGenerator>(result); + } +} diff --git a/src/Components/Aspire.OpenAI/AspireOpenAIExtensions.cs b/src/Components/Aspire.OpenAI/AspireOpenAIExtensions.cs index 991c470dbf..edd5f2bb54 100644 --- a/src/Components/Aspire.OpenAI/AspireOpenAIExtensions.cs +++ b/src/Components/Aspire.OpenAI/AspireOpenAIExtensions.cs @@ -16,7 +16,7 @@ namespace Microsoft.Extensions.Hosting; /// public static class AspireOpenAIExtensions { - private const string DefaultConfigSectionName = "Aspire:OpenAI"; + internal const string DefaultConfigSectionName = "Aspire:OpenAI"; /// /// Registers as a singleton in the services provided by the . @@ -25,8 +25,9 @@ public static class AspireOpenAIExtensions /// A name used to retrieve the connection string from the ConnectionStrings configuration section. /// An optional method that can be used for customizing the . It's invoked after the settings are read from the configuration. /// An optional method that can be used for customizing the . + /// An that can be used to register additional services. /// Reads the configuration from "Aspire.OpenAI" section. - public static void AddOpenAIClient( + public static AspireOpenAIClientBuilder AddOpenAIClient( this IHostApplicationBuilder builder, string connectionName, Action? configureSettings = null, @@ -35,7 +36,7 @@ public static void AddOpenAIClient( ArgumentNullException.ThrowIfNull(builder); ArgumentNullException.ThrowIfNull(connectionName); - AddOpenAIClient(builder, DefaultConfigSectionName, configureSettings, configureOptions, connectionName, serviceKey: null); + return AddOpenAIClient(builder, DefaultConfigSectionName, configureSettings, configureOptions, connectionName, serviceKey: null); } /// @@ -45,8 +46,9 @@ public static void AddOpenAIClient( /// The name of the component, which is used as the of the service and also to retrieve the connection string from the ConnectionStrings configuration section. /// An optional method that can be used for customizing the . It's invoked after the settings are read from the configuration. /// An optional method that can be used for customizing the . + /// An that can be used to register additional services. /// Reads the configuration from "Aspire.OpenAI:{name}" section. - public static void AddKeyedOpenAIClient( + public static AspireOpenAIClientBuilder AddKeyedOpenAIClient( this IHostApplicationBuilder builder, string name, Action? configureSettings = null, @@ -55,10 +57,10 @@ public static void AddKeyedOpenAIClient( ArgumentNullException.ThrowIfNull(builder); ArgumentException.ThrowIfNullOrEmpty(name); - AddOpenAIClient(builder, DefaultConfigSectionName, configureSettings, configureOptions, connectionName: name, serviceKey: name); + return AddOpenAIClient(builder, DefaultConfigSectionName, configureSettings, configureOptions, connectionName: name, serviceKey: name); } - private static void AddOpenAIClient( + private static AspireOpenAIClientBuilder AddOpenAIClient( this IHostApplicationBuilder builder, string configurationSectionName, Action? configureSettings, @@ -120,6 +122,8 @@ private static void AddOpenAIClient( .WithMetrics(b => b.AddMeter("OpenAI.*")); } + return new AspireOpenAIClientBuilder(builder, connectionName, serviceKey, settings.DisableTracing); + OpenAIClient ConfigureOpenAI(IServiceProvider serviceProvider) { if (settings.Key is not null) diff --git a/src/Components/Aspire.OpenAI/MEAIPackageOverrides.targets b/src/Components/Aspire.OpenAI/MEAIPackageOverrides.targets new file mode 100644 index 0000000000..259a41acd1 --- /dev/null +++ b/src/Components/Aspire.OpenAI/MEAIPackageOverrides.targets @@ -0,0 +1,13 @@ + + + + + + + + + diff --git a/src/Components/Aspire.OpenAI/PublicAPI.Unshipped.txt b/src/Components/Aspire.OpenAI/PublicAPI.Unshipped.txt index 12d79016b9..678b950614 100644 --- a/src/Components/Aspire.OpenAI/PublicAPI.Unshipped.txt +++ b/src/Components/Aspire.OpenAI/PublicAPI.Unshipped.txt @@ -1,4 +1,10 @@ #nullable enable +Aspire.OpenAI.AspireOpenAIClientBuilder +Aspire.OpenAI.AspireOpenAIClientBuilder.AspireOpenAIClientBuilder(Microsoft.Extensions.Hosting.IHostApplicationBuilder! hostBuilder, string! connectionName, string? serviceKey, bool disableTracing) -> void +Aspire.OpenAI.AspireOpenAIClientBuilder.ConnectionName.get -> string! +Aspire.OpenAI.AspireOpenAIClientBuilder.DisableTracing.get -> bool +Aspire.OpenAI.AspireOpenAIClientBuilder.HostBuilder.get -> Microsoft.Extensions.Hosting.IHostApplicationBuilder! +Aspire.OpenAI.AspireOpenAIClientBuilder.ServiceKey.get -> string? Aspire.OpenAI.OpenAISettings Aspire.OpenAI.OpenAISettings.DisableMetrics.get -> bool Aspire.OpenAI.OpenAISettings.DisableMetrics.set -> void @@ -9,6 +15,13 @@ Aspire.OpenAI.OpenAISettings.Endpoint.set -> void Aspire.OpenAI.OpenAISettings.Key.get -> string? Aspire.OpenAI.OpenAISettings.Key.set -> void Aspire.OpenAI.OpenAISettings.OpenAISettings() -> void +Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderChatClientExtensions +Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderEmbeddingGeneratorExtensions Microsoft.Extensions.Hosting.AspireOpenAIExtensions -static Microsoft.Extensions.Hosting.AspireOpenAIExtensions.AddKeyedOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! name, System.Action? configureSettings = null, System.Action? configureOptions = null) -> void -static Microsoft.Extensions.Hosting.AspireOpenAIExtensions.AddOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! connectionName, System.Action? configureSettings = null, System.Action? configureOptions = null) -> void +static Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderChatClientExtensions.AddChatClient(this Aspire.OpenAI.AspireOpenAIClientBuilder! builder, string? deploymentName = null) -> Microsoft.Extensions.AI.ChatClientBuilder! +static Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderChatClientExtensions.AddKeyedChatClient(this Aspire.OpenAI.AspireOpenAIClientBuilder! builder, string! serviceKey, string? deploymentName = null) -> Microsoft.Extensions.AI.ChatClientBuilder! +static Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.AddEmbeddingGenerator(this Aspire.OpenAI.AspireOpenAIClientBuilder! builder, string? deploymentName = null) -> Microsoft.Extensions.AI.EmbeddingGeneratorBuilder!>! +static Microsoft.Extensions.Hosting.AspireOpenAIClientBuilderEmbeddingGeneratorExtensions.AddKeyedEmbeddingGenerator(this Aspire.OpenAI.AspireOpenAIClientBuilder! builder, string! serviceKey, string? deploymentName = null) -> Microsoft.Extensions.AI.EmbeddingGeneratorBuilder!>! +static Microsoft.Extensions.Hosting.AspireOpenAIExtensions.AddKeyedOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! name, System.Action? configureSettings = null, System.Action? configureOptions = null) -> Aspire.OpenAI.AspireOpenAIClientBuilder! +static Microsoft.Extensions.Hosting.AspireOpenAIExtensions.AddOpenAIClient(this Microsoft.Extensions.Hosting.IHostApplicationBuilder! builder, string! connectionName, System.Action? configureSettings = null, System.Action? configureOptions = null) -> Aspire.OpenAI.AspireOpenAIClientBuilder! +virtual Aspire.OpenAI.AspireOpenAIClientBuilder.ConfigurationSectionName.get -> string! diff --git a/src/Components/Common/AzureComponent.cs b/src/Components/Common/AzureComponent.cs index fb0f88f874..699e2cffc3 100644 --- a/src/Components/Common/AzureComponent.cs +++ b/src/Components/Common/AzureComponent.cs @@ -41,7 +41,7 @@ protected abstract IAzureClientBuilder AddClient( protected abstract IHealthCheck CreateHealthCheck(TClient client, TSettings settings); - internal void AddClient( + internal TSettings AddClient( IHostApplicationBuilder builder, string configurationSectionName, Action? configureSettings, @@ -144,6 +144,8 @@ internal void AddClient( builder.Services.AddOpenTelemetry() .WithTracing(traceBuilder => traceBuilder.AddSource(ActivitySourceNames)); } + + return settings; } } diff --git a/tests/Aspire.Azure.AI.OpenAI.Tests/Aspire.Azure.AI.OpenAI.Tests.csproj b/tests/Aspire.Azure.AI.OpenAI.Tests/Aspire.Azure.AI.OpenAI.Tests.csproj index 608087628b..71240d0e63 100644 --- a/tests/Aspire.Azure.AI.OpenAI.Tests/Aspire.Azure.AI.OpenAI.Tests.csproj +++ b/tests/Aspire.Azure.AI.OpenAI.Tests/Aspire.Azure.AI.OpenAI.Tests.csproj @@ -11,4 +11,6 @@ + + diff --git a/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderChatClientExtensionsTests.cs b/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderChatClientExtensionsTests.cs new file mode 100644 index 0000000000..26705f4015 --- /dev/null +++ b/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderChatClientExtensionsTests.cs @@ -0,0 +1,223 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Aspire.Azure.AI.OpenAI.Tests; + +public class AspireAzureOpenAIClientBuilderChatClientExtensionsTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadDeploymentNameFromConfig(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("Aspire:Azure:AI:OpenAI:Endpoint", "https://aspireopenaitests.openai.azure.com/"), + new("Aspire:Azure:AI:OpenAI:Deployment", "testdeployment1") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true, "Model")] + [InlineData(false, "Model")] + [InlineData(true, "Deployment")] + [InlineData(false, "Deployment")] + public void CanReadDeploymentNameFromConnectionString(bool useKeyed, string connectionStringKey) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;{connectionStringKey}=testdeployment1") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanAcceptDeploymentNameAsArgument(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient("testdeployment1"); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsConnectionStringWithBothModelAndDeployment(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;Deployment=testdeployment1;Model=something") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + }); + + Assert.StartsWith("The connection string 'openai' contains both 'Deployment' and 'Model' keys.", ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsDeploymentNameNotSpecified(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + }); + + Assert.StartsWith("The deployment could not be determined", ex.Message); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void AddsOpenTelemetry(bool useKeyed, bool disableOpenTelemetry) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"), + new("Aspire:Azure:AI:OpenAI:DisableTracing", disableOpenTelemetry.ToString()), + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient("testdeployment1"); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.Equal(disableOpenTelemetry, client.GetService() is null); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanConfigurePipelineAsync(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1").Use(TestMiddleware, null); + } + else + { + builder.AddAzureOpenAIClient("openai").AddChatClient("testdeployment1").Use(TestMiddleware, null); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + var completion = await client.CompleteAsync("Whatever"); + Assert.Equal("Hello from middleware", completion.Message.Text); + + static Task TestMiddleware(IList list, ChatOptions? options, IChatClient client, CancellationToken token) + => Task.FromResult(new ChatCompletion(new ChatMessage(ChatRole.Assistant, "Hello from middleware"))); + } +} diff --git a/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs b/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs new file mode 100644 index 0000000000..1f600a76e6 --- /dev/null +++ b/tests/Aspire.Azure.AI.OpenAI.Tests/AspireAzureOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs @@ -0,0 +1,226 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Aspire.Azure.AI.OpenAI.Tests; + +public class AspireAzureOpenAIClientBuilderEmbeddingGeneratorExtensionsTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadDeploymentNameFromConfig(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("Aspire:Azure:AI:OpenAI:Endpoint", "https://aspireopenaitests.openai.azure.com/"), + new("Aspire:Azure:AI:OpenAI:Deployment", "testdeployment1") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true, "Model")] + [InlineData(false, "Model")] + [InlineData(true, "Deployment")] + [InlineData(false, "Deployment")] + public void CanReadDeploymentNameFromConnectionString(bool useKeyed, string connectionStringKey) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;{connectionStringKey}=testdeployment1") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanAcceptDeploymentNameAsArgument(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1"); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsConnectionStringWithBothModelAndDeployment(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;Deployment=testdeployment1;Model=something") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + }); + + Assert.StartsWith("The connection string 'openai' contains both 'Deployment' and 'Model' keys.", ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsDeploymentNameNotSpecified(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + }); + + Assert.StartsWith("The deployment could not be determined", ex.Message); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void AddsOpenTelemetry(bool useKeyed, bool disableOpenTelemetry) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"), + new("Aspire:Azure:AI:OpenAI:DisableTracing", disableOpenTelemetry.ToString()), + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1"); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1"); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.Equal(disableOpenTelemetry, generator.GetService>>() is null); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanConfigurePipelineAsync(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddAzureOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1").Use(TestMiddleware); + } + else + { + builder.AddAzureOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1").Use(TestMiddleware); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + var vector = await generator.GenerateEmbeddingVectorAsync("Hello"); + Assert.Equal(1.23f, vector.ToArray().Single()); + } + + private Task>> TestMiddleware(IEnumerable inputs, EmbeddingGenerationOptions? options, IEmbeddingGenerator> nextAsync, CancellationToken cancellationToken) + { + float[] floats = [1.23f]; + return Task.FromResult(new GeneratedEmbeddings>(inputs.Select(i => new Embedding(floats)))); + } +} diff --git a/tests/Aspire.OpenAI.Tests/Aspire.OpenAI.Tests.csproj b/tests/Aspire.OpenAI.Tests/Aspire.OpenAI.Tests.csproj index f364f3bf98..a8578688eb 100644 --- a/tests/Aspire.OpenAI.Tests/Aspire.OpenAI.Tests.csproj +++ b/tests/Aspire.OpenAI.Tests/Aspire.OpenAI.Tests.csproj @@ -11,4 +11,6 @@ + + diff --git a/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderChatClientExtensionsTests.cs b/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderChatClientExtensionsTests.cs new file mode 100644 index 0000000000..53bc8e94c2 --- /dev/null +++ b/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderChatClientExtensionsTests.cs @@ -0,0 +1,224 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Aspire.OpenAI.Tests; + +public class AspireOpenAIClientBuilderChatClientExtensionsTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadDeploymentNameFromConfig(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new ("Aspire:OpenAI:Endpoint", "https://aspireopenaitests.openai.azure.com/"), + new ("Aspire:OpenAI:Deployment", "testdeployment1"), + new ("Aspire:OpenAI:Key", "fake"), + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true, "Model")] + [InlineData(false, "Model")] + [InlineData(true, "Deployment")] + [InlineData(false, "Deployment")] + public void CanReadDeploymentNameFromConnectionString(bool useKeyed, string connectionStringKey) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;{connectionStringKey}=testdeployment1") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanAcceptDeploymentNameAsArgument(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient("testdeployment1"); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.NotNull(client); + Assert.Equal("testdeployment1", client.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsConnectionStringWithBothModelAndDeployment(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;Deployment=testdeployment1;Model=something") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + }); + + Assert.StartsWith("The connection string 'openai' contains both 'Deployment' and 'Model' keys.", ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsDeploymentNameNotSpecified(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + }); + + Assert.StartsWith("The deployment could not be determined", ex.Message); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void AddsOpenTelemetry(bool useKeyed, bool disableOpenTelemetry) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"), + new("Aspire:OpenAI:DisableTracing", disableOpenTelemetry.ToString()), + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1"); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient("testdeployment1"); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + Assert.Equal(disableOpenTelemetry, client.GetService() is null); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanConfigurePipelineAsync(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1").Use(TestMiddleware, null); + } + else + { + builder.AddOpenAIClient("openai").AddChatClient("testdeployment1").Use(TestMiddleware, null); + } + + using var host = builder.Build(); + var client = useKeyed ? + host.Services.GetRequiredKeyedService("openai_chatclient") : + host.Services.GetRequiredService(); + + var completion = await client.CompleteAsync("Whatever"); + Assert.Equal("Hello from middleware", completion.Message.Text); + + static Task TestMiddleware(IList list, ChatOptions? options, IChatClient client, CancellationToken token) + => Task.FromResult(new ChatCompletion(new ChatMessage(ChatRole.Assistant, "Hello from middleware"))); + } +} diff --git a/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs b/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs new file mode 100644 index 0000000000..5529008876 --- /dev/null +++ b/tests/Aspire.OpenAI.Tests/AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Aspire.OpenAI.Tests; + +public class AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadDeploymentNameFromConfig(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("Aspire:OpenAI:Endpoint", "https://aspireopenaitests.openai.azure.com/"), + new("Aspire:OpenAI:Deployment", "testdeployment1"), + new("Aspire:OpenAI:Key", "fake"), + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true, "Model")] + [InlineData(false, "Model")] + [InlineData(true, "Deployment")] + [InlineData(false, "Deployment")] + public void CanReadDeploymentNameFromConnectionString(bool useKeyed, string connectionStringKey) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;{connectionStringKey}=testdeployment1") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanAcceptDeploymentNameAsArgument(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1"); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.NotNull(generator); + Assert.Equal("testdeployment1", generator.Metadata.ModelId); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsConnectionStringWithBothModelAndDeployment(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;Deployment=testdeployment1;Model=something") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + }); + + Assert.StartsWith("The connection string 'openai' contains both 'Deployment' and 'Model' keys.", ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RejectsDeploymentNameNotSpecified(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator(); + } + + using var host = builder.Build(); + + var ex = Assert.Throws(() => + { + _ = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + }); + + Assert.StartsWith("The deployment could not be determined", ex.Message); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void AddsOpenTelemetry(bool useKeyed, bool disableOpenTelemetry) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"), + new("Aspire:OpenAI:DisableTracing", disableOpenTelemetry.ToString()), + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1"); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1"); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + Assert.Equal(disableOpenTelemetry, generator.GetService>>() is null); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanConfigurePipelineAsync(bool useKeyed) + { + var builder = Host.CreateEmptyApplicationBuilder(null); + builder.Configuration.AddInMemoryCollection([ + new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake") + ]); + + if (useKeyed) + { + builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1").Use(TestMiddleware); + } + else + { + builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1").Use(TestMiddleware); + } + + using var host = builder.Build(); + var generator = useKeyed ? + host.Services.GetRequiredKeyedService>>("openai_embeddinggenerator") : + host.Services.GetRequiredService>>(); + + var vector = await generator.GenerateEmbeddingVectorAsync("Hello"); + Assert.Equal(1.23f, vector.ToArray().Single()); + } + + private Task>> TestMiddleware(IEnumerable inputs, EmbeddingGenerationOptions? options, IEmbeddingGenerator> nextAsync, CancellationToken cancellationToken) + { + float[] floats = [1.23f]; + return Task.FromResult(new GeneratedEmbeddings>(inputs.Select(i => new Embedding(floats)))); + } +}