Skip to content

Commit

Permalink
.Net: Refactor IAIServiceSelector to allow for additional parameters …
Browse files Browse the repository at this point in the history
…in future (microsoft#3447)

### Motivation and Context

If we need to pass additional information to `IAIServiceSelector` in
future we want to avoid this being a breaking change.

Changing the signature to pass:

1. `SKContext` - This contains all information about the Kernel
execution context
2. `ISKFunction` - This has been extended to provide access to the model
settings

### Description

Example `IAIServiceSelector` implementation.

```csharp
public class ByModelIdAIServiceSelector : IAIServiceSelector
{
    private readonly string _openAIModelId;

    public ByModelIdAIServiceSelector(string openAIModelId)
    {
        this._openAIModelId = openAIModelId;
    }

    public (T?, AIRequestSettings?) SelectAIService<T>(SKContext context, ISKFunction skfunction) where T : IAIService
    {
        foreach (var model in skfunction.ModelSettings)
        {
            if (model is OpenAIRequestSettings openAIModel)
            {
                if (openAIModel.ModelId == this._openAIModelId)
                {
                    var service = context.ServiceProvider.GetService<T>(openAIModel.ServiceId);
                    if (service is not null)
                    {
                        return (service, model);
                    }
                }
            }
        }

        throw new SKException("Unable to find AI service to handled request.");
    }
}
```

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Nov 10, 2023
1 parent a5b106e commit 446480b
Show file tree
Hide file tree
Showing 13 changed files with 143 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TemplateEngine;
using RepoUtils;
Expand Down Expand Up @@ -52,17 +52,16 @@ public static async Task RunAsync()
modelId: openAIModelId,
serviceId: "OpenAIChat",
apiKey: openAIApiKey)
.WithAIServiceSelector(new MyAIServiceSelector())
.WithAIServiceSelector(new ByModelIdAIServiceSelector(openAIModelId))
.Build();

var modelSettings = new List<AIRequestSettings>
{
new OpenAIRequestSettings() { ServiceId = "AzureOpenAIChat", MaxTokens = 400 },
new OpenAIRequestSettings() { ServiceId = "OpenAIChat", MaxTokens = 200 }
new OpenAIRequestSettings() { ServiceId = "AzureOpenAIChat", ModelId = "" },
new OpenAIRequestSettings() { ServiceId = "OpenAIChat", ModelId = openAIModelId }
};

await RunSemanticFunctionAsync(kernel, "Hello AI, what can you do for me?", modelSettings);
await RunSemanticFunctionAsync(kernel, "Hello AI, provide an indepth description of what can you do for me as a bulleted list?", modelSettings);
}

public static async Task RunSemanticFunctionAsync(IKernel kernel, string prompt, List<AIRequestSettings> modelSettings)
Expand All @@ -81,83 +80,33 @@ public static async Task RunSemanticFunctionAsync(IKernel kernel, string prompt,
}
}

public class MyAIServiceSelector : IAIServiceSelector
public class ByModelIdAIServiceSelector : IAIServiceSelector
{
private readonly int _defaultMaxTokens = 300;
private readonly int _minResponseTokens = 150;
private readonly string _openAIModelId;

public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
public ByModelIdAIServiceSelector(string openAIModelId)
{
if (modelSettings is null || modelSettings.Count == 0)
{
var service = serviceProvider.GetService<T>(null);
if (service is not null)
{
return (service, null);
}
}
else
{
var tokens = this.CountTokens(renderedPrompt);
this._openAIModelId = openAIModelId;
}

string? serviceId = null;
int fewestTokens = 0;
AIRequestSettings? requestSettings = null;
AIRequestSettings? defaultRequestSettings = null;
foreach (var model in modelSettings)
public (T?, AIRequestSettings?) SelectAIService<T>(SKContext context, ISKFunction skfunction) where T : IAIService
{
foreach (var model in skfunction.ModelSettings)
{
if (model is OpenAIRequestSettings openAIModel)
{
if (!string.IsNullOrEmpty(model.ServiceId))
if (openAIModel.ModelId == this._openAIModelId)
{
if (model is OpenAIRequestSettings openAIModel)
var service = context.ServiceProvider.GetService<T>(openAIModel.ServiceId);
if (service is not null)
{
var responseTokens = (openAIModel.MaxTokens ?? this._defaultMaxTokens) - tokens;
if (serviceId is null || (responseTokens > this._minResponseTokens && responseTokens < fewestTokens))
{
fewestTokens = responseTokens;
serviceId = model.ServiceId;
requestSettings = model;
}
Console.WriteLine($"======== Selected service: {openAIModel.ServiceId} {openAIModel.ModelId} ========");
return (service, model);
}
}
else
{
// First request settings with empty or null service id is the default
defaultRequestSettings ??= model;
}
}
Console.WriteLine($"Prompt tokens: {tokens}, Response tokens: {fewestTokens}");

if (serviceId is not null)
{
Console.WriteLine($"Selected service: {serviceId}");
var service = serviceProvider.GetService<T>(serviceId);
if (service is not null)
{
return (service, requestSettings);
}
}

if (defaultRequestSettings is not null)
{
var service = serviceProvider.GetService<T>(null);
if (service is not null)
{
return (service, defaultRequestSettings);
}
}
}

throw new SKException("Unable to find AI service to handled request.");
}

/// <summary>
/// MicrosoftML token counter implementation.
/// </summary>
private int CountTokens(string input)
{
Tokenizer tokenizer = new(new Bpe());
var tokens = tokenizer.Encode(input).Tokens;

return tokens.Count;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -33,6 +34,11 @@ public interface ISKFunction
/// </summary>
string Description { get; }

/// <summary>
/// Model request settings.
/// </summary>
IEnumerable<AIRequestSettings> ModelSettings { get; }

/// <summary>
/// Returns a description of the function, including parameters.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public CultureInfo Culture
/// <summary>
/// AI service provider
/// </summary>
internal IAIServiceProvider ServiceProvider { get; }
public IAIServiceProvider ServiceProvider { get; }

/// <summary>
/// AIService selector implementation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;

#pragma warning disable IDE0130
Expand All @@ -18,9 +18,8 @@ public interface IAIServiceSelector
/// The returned value is a tuple containing instances of <see cref="IAIService"/> and <see cref="AIRequestSettings"/>
/// </summary>
/// <typeparam name="T">Type of AI service to return</typeparam>
/// <param name="renderedPrompt">Rendered prompt</param>
/// <param name="serviceProvider">AI service provider</param>
/// <param name="modelSettings">Collection of model settings</param>
/// <param name="context">Semantic Kernel context</param>
/// <param name="skfunction">Semantic Kernel callable function interface</param>
/// <returns></returns>
(T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService;
(T?, AIRequestSettings?) SelectAIService<T>(SKContext context, ISKFunction skfunction) where T : IAIService;
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ public class InputConfig

/// <summary>
/// Model request settings.
/// Initially only a single model request settings is supported.
/// </summary>
[JsonPropertyName("models")]
[JsonPropertyOrder(4)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;

namespace Microsoft.SemanticKernel.Functions;
Expand All @@ -18,8 +18,8 @@ internal class DelegatingAIServiceSelector : IAIServiceSelector
internal AIRequestSettings? RequestSettings { get; set; }

/// <inheritdoc/>
public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
public (T?, AIRequestSettings?) SelectAIService<T>(SKContext context, ISKFunction skfunction) where T : IAIService
{
return ((T?)this.ServiceFactory?.Invoke() ?? serviceProvider.GetService<T>(null), this.RequestSettings ?? modelSettings?[0]);
return ((T?)this.ServiceFactory?.Invoke() ?? context.ServiceProvider.GetService<T>(null), this.RequestSettings ?? skfunction.RequestSettings);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.Metrics;
Expand Down Expand Up @@ -31,6 +32,9 @@ internal sealed class InstrumentedSKFunction : ISKFunction
/// <inheritdoc/>
public string Description => this._function.Description;

/// <inheritdoc/>
public IEnumerable<AIRequestSettings> ModelSettings => this._function.ModelSettings;

/// <summary>
/// Initialize a new instance of the <see cref="InstrumentedSKFunction"/> class.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions dotnet/src/SemanticKernel.Core/Functions/NativeFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ internal sealed class NativeFunction : ISKFunction, IDisposable
/// <inheritdoc/>
public string Description { get; }

/// <inheritdoc/>
public IEnumerable<AIRequestSettings> ModelSettings => Enumerable.Empty<AIRequestSettings>();

/// <summary>
/// List of function parameters
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;

namespace Microsoft.SemanticKernel.Functions;
Expand All @@ -15,9 +15,11 @@ namespace Microsoft.SemanticKernel.Functions;
internal class OrderedIAIServiceSelector : IAIServiceSelector
{
/// <inheritdoc/>
public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
public (T?, AIRequestSettings?) SelectAIService<T>(SKContext context, ISKFunction skfunction) where T : IAIService
{
if (modelSettings is null || modelSettings.Count == 0)
var serviceProvider = context.ServiceProvider;
var modelSettings = skfunction.ModelSettings;
if (modelSettings is null || !modelSettings.Any())
{
var service = serviceProvider.GetService<T>(null);
if (service is not null)
Expand Down
28 changes: 13 additions & 15 deletions dotnet/src/SemanticKernel.Core/Functions/SemanticFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ internal sealed class SemanticFunction : ISKFunction, IDisposable
public string PluginName { get; }

/// <inheritdoc/>
public string Description { get; }
public string Description => this._promptTemplateConfig.Description;

/// <inheritdoc/>
public IEnumerable<AIRequestSettings> ModelSettings => this._promptTemplateConfig.ModelSettings.AsReadOnly();

/// <summary>
/// List of function parameters
Expand Down Expand Up @@ -66,18 +69,13 @@ public static ISKFunction FromSemanticConfig(
Verify.NotNull(promptTemplateConfig);
Verify.NotNull(promptTemplate);

var func = new SemanticFunction(
return new SemanticFunction(
template: promptTemplate,
description: promptTemplateConfig.Description,
promptTemplateConfig: promptTemplateConfig,
pluginName: pluginName,
functionName: functionName,
loggerFactory: loggerFactory
)
{
_modelSettings = promptTemplateConfig.ModelSettings
};

return func;
);
}

/// <inheritdoc/>
Expand Down Expand Up @@ -118,9 +116,9 @@ public string ToString(bool writeIndented)

internal SemanticFunction(
IPromptTemplate template,
PromptTemplateConfig promptTemplateConfig,
string pluginName,
string functionName,
string description,
ILoggerFactory? loggerFactory = null)
{
Verify.NotNull(template);
Expand All @@ -130,13 +128,13 @@ internal SemanticFunction(
this._logger = loggerFactory is not null ? loggerFactory.CreateLogger(typeof(SemanticFunction)) : NullLogger.Instance;

this._promptTemplate = template;
this._promptTemplateConfig = promptTemplateConfig;
Verify.ParametersUniqueness(this.Parameters);

this.Name = functionName;
this.PluginName = pluginName;
this.Description = description;

this._view = new(() => new(functionName, pluginName, description, this.Parameters));
this._view = new(() => new(functionName, pluginName, promptTemplateConfig.Description, this.Parameters));
}

#region private
Expand All @@ -145,7 +143,7 @@ internal SemanticFunction(
private static readonly JsonSerializerOptions s_toStringIndentedSerialization = new() { WriteIndented = true };
private readonly ILogger _logger;
private IAIServiceSelector? _serviceSelector;
private List<AIRequestSettings>? _modelSettings;
private readonly PromptTemplateConfig _promptTemplateConfig;
private readonly Lazy<FunctionView> _view;
private readonly IPromptTemplate _promptTemplate;

Expand Down Expand Up @@ -182,7 +180,7 @@ private async Task<FunctionResult> RunPromptAsync(
string renderedPrompt = await this._promptTemplate.RenderAsync(context, cancellationToken).ConfigureAwait(false);

var serviceSelector = this._serviceSelector ?? context.ServiceSelector;
(var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService<ITextCompletion>(renderedPrompt, context.ServiceProvider, this._modelSettings);
(var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService<ITextCompletion>(context, this);
Verify.NotNull(textCompletion);

this.CallFunctionInvoking(context, renderedPrompt);
Expand Down Expand Up @@ -293,7 +291,7 @@ private string GetPromptFromEventArgsMetadataOrDefault(SKContext context, string

/// <inheritdoc/>
[Obsolete("Use ISKFunction.ModelSettings instead. This will be removed in a future release.")]
public AIRequestSettings? RequestSettings => this._modelSettings?.FirstOrDefault<AIRequestSettings>();
public AIRequestSettings? RequestSettings => this._promptTemplateConfig.ModelSettings?.FirstOrDefault<AIRequestSettings>();

/// <inheritdoc/>
[Obsolete("Use ISKFunction.SetAIServiceFactory instead. This will be removed in a future release.")]
Expand Down
4 changes: 4 additions & 0 deletions dotnet/src/SemanticKernel.Core/Planning/InstrumentedPlan.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.Metrics;
Expand Down Expand Up @@ -28,6 +29,9 @@ internal sealed class InstrumentedPlan : ISKFunction
/// <inheritdoc/>
public string Description => this._plan.Description;

/// <inheritdoc/>
public IEnumerable<AIRequestSettings> ModelSettings => this._plan.ModelSettings;

/// <summary>
/// Initialize a new instance of the <see cref="InstrumentedPlan"/> class.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/SemanticKernel.Core/Planning/Plan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public sealed class Plan : ISKFunction

/// <inheritdoc/>
[JsonPropertyName("model_settings")]
public List<AIRequestSettings>? ModelSettings { get; private set; }
public IEnumerable<AIRequestSettings> ModelSettings => this.Function?.ModelSettings ?? Array.Empty<AIRequestSettings>();

#endregion ISKFunction implementation

Expand Down
Loading

0 comments on commit 446480b

Please sign in to comment.