diff --git a/dotnet/samples/KernelSyntaxExamples/Example62_CustomAIServiceSelector.cs b/dotnet/samples/KernelSyntaxExamples/Example62_CustomAIServiceSelector.cs index 82ce2a8e6527..c4f620366c77 100644 --- a/dotnet/samples/KernelSyntaxExamples/Example62_CustomAIServiceSelector.cs +++ b/dotnet/samples/KernelSyntaxExamples/Example62_CustomAIServiceSelector.cs @@ -3,11 +3,11 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Microsoft.ML.Tokenizers; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.Connectors.AI.OpenAI; using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.Services; using Microsoft.SemanticKernel.TemplateEngine; using RepoUtils; @@ -52,17 +52,16 @@ public static async Task RunAsync() modelId: openAIModelId, serviceId: "OpenAIChat", apiKey: openAIApiKey) - .WithAIServiceSelector(new MyAIServiceSelector()) + .WithAIServiceSelector(new ByModelIdAIServiceSelector(openAIModelId)) .Build(); var modelSettings = new List { - new OpenAIRequestSettings() { ServiceId = "AzureOpenAIChat", MaxTokens = 400 }, - new OpenAIRequestSettings() { ServiceId = "OpenAIChat", MaxTokens = 200 } + new OpenAIRequestSettings() { ServiceId = "AzureOpenAIChat", ModelId = "" }, + new OpenAIRequestSettings() { ServiceId = "OpenAIChat", ModelId = openAIModelId } }; await RunSemanticFunctionAsync(kernel, "Hello AI, what can you do for me?", modelSettings); - await RunSemanticFunctionAsync(kernel, "Hello AI, provide an indepth description of what can you do for me as a bulleted list?", modelSettings); } public static async Task RunSemanticFunctionAsync(IKernel kernel, string prompt, List modelSettings) @@ -81,83 +80,33 @@ public static async Task RunSemanticFunctionAsync(IKernel kernel, string prompt, } } -public class MyAIServiceSelector : IAIServiceSelector +public class ByModelIdAIServiceSelector : IAIServiceSelector { - private readonly int _defaultMaxTokens = 300; - private readonly int _minResponseTokens = 150; + private readonly string _openAIModelId; - public (T?, AIRequestSettings?) SelectAIService(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList? modelSettings) where T : IAIService + public ByModelIdAIServiceSelector(string openAIModelId) { - if (modelSettings is null || modelSettings.Count == 0) - { - var service = serviceProvider.GetService(null); - if (service is not null) - { - return (service, null); - } - } - else - { - var tokens = this.CountTokens(renderedPrompt); + this._openAIModelId = openAIModelId; + } - string? serviceId = null; - int fewestTokens = 0; - AIRequestSettings? requestSettings = null; - AIRequestSettings? defaultRequestSettings = null; - foreach (var model in modelSettings) + public (T?, AIRequestSettings?) SelectAIService(SKContext context, ISKFunction skfunction) where T : IAIService + { + foreach (var model in skfunction.ModelSettings) + { + if (model is OpenAIRequestSettings openAIModel) { - if (!string.IsNullOrEmpty(model.ServiceId)) + if (openAIModel.ModelId == this._openAIModelId) { - if (model is OpenAIRequestSettings openAIModel) + var service = context.ServiceProvider.GetService(openAIModel.ServiceId); + if (service is not null) { - var responseTokens = (openAIModel.MaxTokens ?? this._defaultMaxTokens) - tokens; - if (serviceId is null || (responseTokens > this._minResponseTokens && responseTokens < fewestTokens)) - { - fewestTokens = responseTokens; - serviceId = model.ServiceId; - requestSettings = model; - } + Console.WriteLine($"======== Selected service: {openAIModel.ServiceId} {openAIModel.ModelId} ========"); + return (service, model); } } - else - { - // First request settings with empty or null service id is the default - defaultRequestSettings ??= model; - } - } - Console.WriteLine($"Prompt tokens: {tokens}, Response tokens: {fewestTokens}"); - - if (serviceId is not null) - { - Console.WriteLine($"Selected service: {serviceId}"); - var service = serviceProvider.GetService(serviceId); - if (service is not null) - { - return (service, requestSettings); - } - } - - if (defaultRequestSettings is not null) - { - var service = serviceProvider.GetService(null); - if (service is not null) - { - return (service, defaultRequestSettings); - } } } throw new SKException("Unable to find AI service to handled request."); } - - /// - /// MicrosoftML token counter implementation. - /// - private int CountTokens(string input) - { - Tokenizer tokenizer = new(new Bpe()); - var tokens = tokenizer.Encode(input).Tokens; - - return tokens.Count; - } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/ISKFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/ISKFunction.cs index 02565b767265..cb33723d9539 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/ISKFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/ISKFunction.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.ComponentModel; using System.Threading; using System.Threading.Tasks; @@ -33,6 +34,11 @@ public interface ISKFunction /// string Description { get; } + /// + /// Model request settings. + /// + IEnumerable ModelSettings { get; } + /// /// Returns a description of the function, including parameters. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs b/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs index aa95511c3082..a1241b62b0ec 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs @@ -63,7 +63,7 @@ public CultureInfo Culture /// /// AI service provider /// - internal IAIServiceProvider ServiceProvider { get; } + public IAIServiceProvider ServiceProvider { get; } /// /// AIService selector implementation diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs index f35a52ffaf34..d544843447b5 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.Services; #pragma warning disable IDE0130 @@ -18,9 +18,8 @@ public interface IAIServiceSelector /// The returned value is a tuple containing instances of and /// /// Type of AI service to return - /// Rendered prompt - /// AI service provider - /// Collection of model settings + /// Semantic Kernel context + /// Semantic Kernel callable function interface /// - (T?, AIRequestSettings?) SelectAIService(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList? modelSettings) where T : IAIService; + (T?, AIRequestSettings?) SelectAIService(SKContext context, ISKFunction skfunction) where T : IAIService; } diff --git a/dotnet/src/SemanticKernel.Abstractions/TemplateEngine/PromptTemplateConfig.cs b/dotnet/src/SemanticKernel.Abstractions/TemplateEngine/PromptTemplateConfig.cs index af931093debb..95b54af48225 100644 --- a/dotnet/src/SemanticKernel.Abstractions/TemplateEngine/PromptTemplateConfig.cs +++ b/dotnet/src/SemanticKernel.Abstractions/TemplateEngine/PromptTemplateConfig.cs @@ -85,7 +85,6 @@ public class InputConfig /// /// Model request settings. - /// Initially only a single model request settings is supported. /// [JsonPropertyName("models")] [JsonPropertyOrder(4)] diff --git a/dotnet/src/SemanticKernel.Core/Functions/DelegatingAIServiceSelector.cs b/dotnet/src/SemanticKernel.Core/Functions/DelegatingAIServiceSelector.cs index 05a7a05e6411..52904833802a 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/DelegatingAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/DelegatingAIServiceSelector.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.Services; namespace Microsoft.SemanticKernel.Functions; @@ -18,8 +18,8 @@ internal class DelegatingAIServiceSelector : IAIServiceSelector internal AIRequestSettings? RequestSettings { get; set; } /// - public (T?, AIRequestSettings?) SelectAIService(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList? modelSettings) where T : IAIService + public (T?, AIRequestSettings?) SelectAIService(SKContext context, ISKFunction skfunction) where T : IAIService { - return ((T?)this.ServiceFactory?.Invoke() ?? serviceProvider.GetService(null), this.RequestSettings ?? modelSettings?[0]); + return ((T?)this.ServiceFactory?.Invoke() ?? context.ServiceProvider.GetService(null), this.RequestSettings ?? skfunction.RequestSettings); } } diff --git a/dotnet/src/SemanticKernel.Core/Functions/InstrumentedSKFunction.cs b/dotnet/src/SemanticKernel.Core/Functions/InstrumentedSKFunction.cs index cef67d4718c5..84c670376a64 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/InstrumentedSKFunction.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/InstrumentedSKFunction.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.Metrics; @@ -31,6 +32,9 @@ internal sealed class InstrumentedSKFunction : ISKFunction /// public string Description => this._function.Description; + /// + public IEnumerable ModelSettings => this._function.ModelSettings; + /// /// Initialize a new instance of the class. /// diff --git a/dotnet/src/SemanticKernel.Core/Functions/NativeFunction.cs b/dotnet/src/SemanticKernel.Core/Functions/NativeFunction.cs index 4c526a3612e1..f1b8dc4fd59b 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/NativeFunction.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/NativeFunction.cs @@ -45,6 +45,9 @@ internal sealed class NativeFunction : ISKFunction, IDisposable /// public string Description { get; } + /// + public IEnumerable ModelSettings => Enumerable.Empty(); + /// /// List of function parameters /// diff --git a/dotnet/src/SemanticKernel.Core/Functions/OrderedIAIServiceSelector.cs b/dotnet/src/SemanticKernel.Core/Functions/OrderedIAIServiceSelector.cs index 0ce1c06f67cb..db739fbdb607 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/OrderedIAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/OrderedIAIServiceSelector.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; using System.Linq; using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.Services; namespace Microsoft.SemanticKernel.Functions; @@ -15,9 +15,11 @@ namespace Microsoft.SemanticKernel.Functions; internal class OrderedIAIServiceSelector : IAIServiceSelector { /// - public (T?, AIRequestSettings?) SelectAIService(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList? modelSettings) where T : IAIService + public (T?, AIRequestSettings?) SelectAIService(SKContext context, ISKFunction skfunction) where T : IAIService { - if (modelSettings is null || modelSettings.Count == 0) + var serviceProvider = context.ServiceProvider; + var modelSettings = skfunction.ModelSettings; + if (modelSettings is null || !modelSettings.Any()) { var service = serviceProvider.GetService(null); if (service is not null) diff --git a/dotnet/src/SemanticKernel.Core/Functions/SemanticFunction.cs b/dotnet/src/SemanticKernel.Core/Functions/SemanticFunction.cs index 7a2e09d65ed1..52c05cbb434a 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/SemanticFunction.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/SemanticFunction.cs @@ -38,7 +38,10 @@ internal sealed class SemanticFunction : ISKFunction, IDisposable public string PluginName { get; } /// - public string Description { get; } + public string Description => this._promptTemplateConfig.Description; + + /// + public IEnumerable ModelSettings => this._promptTemplateConfig.ModelSettings.AsReadOnly(); /// /// List of function parameters @@ -66,18 +69,13 @@ public static ISKFunction FromSemanticConfig( Verify.NotNull(promptTemplateConfig); Verify.NotNull(promptTemplate); - var func = new SemanticFunction( + return new SemanticFunction( template: promptTemplate, - description: promptTemplateConfig.Description, + promptTemplateConfig: promptTemplateConfig, pluginName: pluginName, functionName: functionName, loggerFactory: loggerFactory - ) - { - _modelSettings = promptTemplateConfig.ModelSettings - }; - - return func; + ); } /// @@ -118,9 +116,9 @@ public string ToString(bool writeIndented) internal SemanticFunction( IPromptTemplate template, + PromptTemplateConfig promptTemplateConfig, string pluginName, string functionName, - string description, ILoggerFactory? loggerFactory = null) { Verify.NotNull(template); @@ -130,13 +128,13 @@ internal SemanticFunction( this._logger = loggerFactory is not null ? loggerFactory.CreateLogger(typeof(SemanticFunction)) : NullLogger.Instance; this._promptTemplate = template; + this._promptTemplateConfig = promptTemplateConfig; Verify.ParametersUniqueness(this.Parameters); this.Name = functionName; this.PluginName = pluginName; - this.Description = description; - this._view = new(() => new(functionName, pluginName, description, this.Parameters)); + this._view = new(() => new(functionName, pluginName, promptTemplateConfig.Description, this.Parameters)); } #region private @@ -145,7 +143,7 @@ internal SemanticFunction( private static readonly JsonSerializerOptions s_toStringIndentedSerialization = new() { WriteIndented = true }; private readonly ILogger _logger; private IAIServiceSelector? _serviceSelector; - private List? _modelSettings; + private readonly PromptTemplateConfig _promptTemplateConfig; private readonly Lazy _view; private readonly IPromptTemplate _promptTemplate; @@ -182,7 +180,7 @@ private async Task RunPromptAsync( string renderedPrompt = await this._promptTemplate.RenderAsync(context, cancellationToken).ConfigureAwait(false); var serviceSelector = this._serviceSelector ?? context.ServiceSelector; - (var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService(renderedPrompt, context.ServiceProvider, this._modelSettings); + (var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService(context, this); Verify.NotNull(textCompletion); this.CallFunctionInvoking(context, renderedPrompt); @@ -293,7 +291,7 @@ private string GetPromptFromEventArgsMetadataOrDefault(SKContext context, string /// [Obsolete("Use ISKFunction.ModelSettings instead. This will be removed in a future release.")] - public AIRequestSettings? RequestSettings => this._modelSettings?.FirstOrDefault(); + public AIRequestSettings? RequestSettings => this._promptTemplateConfig.ModelSettings?.FirstOrDefault(); /// [Obsolete("Use ISKFunction.SetAIServiceFactory instead. This will be removed in a future release.")] diff --git a/dotnet/src/SemanticKernel.Core/Planning/InstrumentedPlan.cs b/dotnet/src/SemanticKernel.Core/Planning/InstrumentedPlan.cs index 83c087f6f7aa..c04d0bf7a171 100644 --- a/dotnet/src/SemanticKernel.Core/Planning/InstrumentedPlan.cs +++ b/dotnet/src/SemanticKernel.Core/Planning/InstrumentedPlan.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.Metrics; @@ -28,6 +29,9 @@ internal sealed class InstrumentedPlan : ISKFunction /// public string Description => this._plan.Description; + /// + public IEnumerable ModelSettings => this._plan.ModelSettings; + /// /// Initialize a new instance of the class. /// diff --git a/dotnet/src/SemanticKernel.Core/Planning/Plan.cs b/dotnet/src/SemanticKernel.Core/Planning/Plan.cs index 324b080bf04b..37b5d75a1a54 100644 --- a/dotnet/src/SemanticKernel.Core/Planning/Plan.cs +++ b/dotnet/src/SemanticKernel.Core/Planning/Plan.cs @@ -79,7 +79,7 @@ public sealed class Plan : ISKFunction /// [JsonPropertyName("model_settings")] - public List? ModelSettings { get; private set; } + public IEnumerable ModelSettings => this.Function?.ModelSettings ?? Array.Empty(); #endregion ISKFunction implementation diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs index c230a440e367..4796660c87ed 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs @@ -4,11 +4,13 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.AI.TextCompletion; using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Functions; using Microsoft.SemanticKernel.Services; +using Microsoft.SemanticKernel.TemplateEngine; using Xunit; namespace SemanticKernel.UnitTests.Functions; @@ -18,30 +20,27 @@ public class OrderedIAIServiceConfigurationProviderTests public void ItThrowsAnSKExceptionForNoServices() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List(); - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder().Build(); + var context = kernel.CreateNewContext(); + var skfunction = kernel.CreateSemanticFunction("Hello AI"); + var serviceSelector = new OrderedIAIServiceSelector(); // Act // Assert - Assert.Throws(() => configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings)); + Assert.Throws(() => serviceSelector.SelectAIService(context, skfunction)); } [Fact] public void ItGetsAIServiceConfigurationForSingleAIService() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService(new AIService()); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List(); - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder().WithAIService("service1", new AIService()).Build(); + var context = kernel.CreateNewContext(); + var skfunction = kernel.CreateSemanticFunction("Hello AI"); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert Assert.NotNull(aiService); @@ -52,15 +51,13 @@ public void ItGetsAIServiceConfigurationForSingleAIService() public void ItGetsAIServiceConfigurationForSingleTextCompletion() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService(new TextCompletion()); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List(); - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder().WithAIService("service1", new TextCompletion()).Build(); + var context = kernel.CreateNewContext(); + var skfunction = kernel.CreateSemanticFunction("Hello AI"); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert Assert.NotNull(aiService); @@ -68,81 +65,61 @@ public void ItGetsAIServiceConfigurationForSingleTextCompletion() } [Fact] - public void ItAIServiceConfigurationForTextCompletionByServiceId() + public void ItGetsAIServiceConfigurationForTextCompletionByServiceId() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion()); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List(); - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion()) + .Build(); + var context = kernel.CreateNewContext(); + var requestSettings = new AIRequestSettings() { ServiceId = "service2" }; + var skfunction = kernel.CreateSemanticFunction("Hello AI", requestSettings: requestSettings); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert - Assert.NotNull(aiService); - Assert.Null(defaultRequestSettings); + Assert.Equal(context.ServiceProvider.GetService("service2"), aiService); + Assert.Equal(requestSettings, defaultRequestSettings); } [Fact] public void ItThrowsAnSKExceptionForNotFoundService() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion()); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List - { - new AIRequestSettings() { ServiceId = "service3" } - }; - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion()) + .Build(); + var context = kernel.CreateNewContext(); + var requestSettings = new AIRequestSettings() { ServiceId = "service3" }; + var skfunction = kernel.CreateSemanticFunction("Hello AI", requestSettings: requestSettings); + var serviceSelector = new OrderedIAIServiceSelector(); // Act // Assert - Assert.Throws(() => configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings)); - } - - [Fact] - public void ItUsesDefaultServiceForNullModelSettings() - { - // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion(), true); - var serviceProvider = serviceCollection.Build(); - var configurationProvider = new OrderedIAIServiceSelector(); - - // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, null); - - // Assert - Assert.Equal(serviceProvider.GetService("service2"), aiService); - Assert.Null(defaultRequestSettings); + Assert.Throws(() => serviceSelector.SelectAIService(context, skfunction)); } [Fact] public void ItUsesDefaultServiceForEmptyModelSettings() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion(), true); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List(); - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion(), true) + .Build(); + var context = kernel.CreateNewContext(); + var skfunction = kernel.CreateSemanticFunction("Hello AI"); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert - Assert.Equal(serviceProvider.GetService("service2"), aiService); + Assert.Equal(context.ServiceProvider.GetService("service2"), aiService); Assert.Null(defaultRequestSettings); } @@ -150,46 +127,43 @@ public void ItUsesDefaultServiceForEmptyModelSettings() public void ItUsesDefaultServiceAndSettings() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion(), true); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List - { - new AIRequestSettings() - }; - var configurationProvider = new OrderedIAIServiceSelector(); + // Arrange + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion(), true) + .Build(); + var context = kernel.CreateNewContext(); + var requestSettings = new AIRequestSettings(); + var skfunction = kernel.CreateSemanticFunction("Hello AI", requestSettings: requestSettings); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert - Assert.Equal(serviceProvider.GetService("service2"), aiService); - Assert.Equal(modelSettings[0], defaultRequestSettings); + Assert.Equal(context.ServiceProvider.GetService("service2"), aiService); + Assert.Equal(requestSettings, defaultRequestSettings); } [Fact] public void ItUsesDefaultServiceAndSettingsEmptyServiceId() { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion(), true); - var serviceProvider = serviceCollection.Build(); - var modelSettings = new List - { - new AIRequestSettings() { ServiceId = "" } - }; - var configurationProvider = new OrderedIAIServiceSelector(); + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion(), true) + .Build(); + var context = kernel.CreateNewContext(); + var requestSettings = new AIRequestSettings() { ServiceId = "" }; + var skfunction = kernel.CreateSemanticFunction("Hello AI", requestSettings: requestSettings); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert - Assert.Equal(serviceProvider.GetService("service2"), aiService); - Assert.Equal(modelSettings[0], defaultRequestSettings); + Assert.Equal(context.ServiceProvider.GetService("service2"), aiService); + Assert.Equal(requestSettings, defaultRequestSettings); } [Theory] @@ -200,24 +174,25 @@ public void ItUsesDefaultServiceAndSettingsEmptyServiceId() public void ItGetsAIServiceConfigurationByOrder(string[] serviceIds, string expectedServiceId) { // Arrange - var renderedPrompt = "Hello AI, what can you do for me?"; - var serviceCollection = new AIServiceCollection(); - serviceCollection.SetService("service1", new TextCompletion()); - serviceCollection.SetService("service2", new TextCompletion()); - serviceCollection.SetService("service3", new TextCompletion()); - var serviceProvider = serviceCollection.Build(); + var kernel = new KernelBuilder() + .WithAIService("service1", new TextCompletion()) + .WithAIService("service2", new TextCompletion()) + .WithAIService("service3", new TextCompletion()) + .Build(); + var context = kernel.CreateNewContext(); var modelSettings = new List(); foreach (var serviceId in serviceIds) { modelSettings.Add(new AIRequestSettings() { ServiceId = serviceId }); } - var configurationProvider = new OrderedIAIServiceSelector(); + var skfunction = kernel.CreateSemanticFunction("Hello AI", promptTemplateConfig: new PromptTemplateConfig() { ModelSettings = modelSettings }); + var serviceSelector = new OrderedIAIServiceSelector(); // Act - (var aiService, var defaultRequestSettings) = configurationProvider.SelectAIService(renderedPrompt, serviceProvider, modelSettings); + (var aiService, var defaultRequestSettings) = serviceSelector.SelectAIService(context, skfunction); // Assert - Assert.Equal(serviceProvider.GetService(expectedServiceId), aiService); + Assert.Equal(context.ServiceProvider.GetService(expectedServiceId), aiService); Assert.Equal(expectedServiceId, defaultRequestSettings!.ServiceId); }