Skip to content

Commit

Permalink
Address comments 3rd batch
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoChenOSU committed Feb 19, 2025
1 parent dd90d28 commit 4bc6141
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 160 deletions.
109 changes: 0 additions & 109 deletions dotnet/src/Agents/Bedrock/BedrockAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Threading;
using System.Threading.Tasks;
using Amazon.BedrockAgent;
using Amazon.BedrockAgent.Model;
using Amazon.BedrockAgentRuntime;
using Amazon.BedrockAgentRuntime.Model;
using Microsoft.SemanticKernel.Agents.Bedrock.Extensions;
Expand Down Expand Up @@ -231,114 +230,6 @@ async IAsyncEnumerable<StreamingChatMessageContent> InvokeInternal()
}
}

/// <summary>
/// Associate the agent with a knowledge base.
/// </summary>
/// <param name="knowledgeBaseId">The id of the knowledge base to associate with the agent.</param>
/// <param name="description">A description of what the agent should use the knowledge base for.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
public async Task AssociateAgentKnowledgeBaseAsync(string knowledgeBaseId, string description, CancellationToken cancellationToken = default)
{
await this.Client.AssociateAgentKnowledgeBaseAsync(new()
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
KnowledgeBaseId = knowledgeBaseId,
Description = description,
}, cancellationToken).ConfigureAwait(false);

await this.Client.PrepareAgentAsync(new() { AgentId = this.Id }, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Disassociate the agent with a knowledge base.
/// </summary>
/// <param name="knowledgeBaseId">The id of the knowledge base to disassociate with the agent.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
public async Task DisassociateAgentKnowledgeBaseAsync(string knowledgeBaseId, CancellationToken cancellationToken = default)
{
await this.Client.DisassociateAgentKnowledgeBaseAsync(new()
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
KnowledgeBaseId = knowledgeBaseId,
}, cancellationToken).ConfigureAwait(false);

await this.Client.PrepareAgentAsync(new() { AgentId = this.Id }, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// List the knowledge bases associated with the agent.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="ListAgentKnowledgeBasesResponse"/> containing the knowledge bases associated with the agent.</returns>
public async Task<ListAgentKnowledgeBasesResponse> ListAssociatedKnowledgeBasesAsync(CancellationToken cancellationToken = default)
{
return await this.Client.ListAgentKnowledgeBasesAsync(new()
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
}, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Create a code interpreter action group for the agent and prepare the agent.
/// </summary>
public async Task CreateCodeInterpreterActionGroupAsync(CancellationToken cancellationToken = default)
{
var createAgentActionGroupRequest = new CreateAgentActionGroupRequest
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
ActionGroupName = this.CodeInterpreterActionGroupSignature,
ActionGroupState = ActionGroupState.ENABLED,
ParentActionGroupSignature = new(Amazon.BedrockAgent.ActionGroupSignature.AMAZONCodeInterpreter),
};

await this.Client.CreateAgentActionGroupAsync(createAgentActionGroupRequest, cancellationToken).ConfigureAwait(false);
await this.Client.PrepareAgentAsync(new() { AgentId = this.Id }, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Create a kernel function action group for the agent and prepare the agent.
/// </summary>
public async Task CreateKernelFunctionActionGroupAsync(CancellationToken cancellationToken = default)
{
var createAgentActionGroupRequest = new CreateAgentActionGroupRequest
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
ActionGroupName = this.KernelFunctionActionGroupSignature,
ActionGroupState = ActionGroupState.ENABLED,
ActionGroupExecutor = new()
{
CustomControl = Amazon.BedrockAgent.CustomControlMethod.RETURN_CONTROL,
},
FunctionSchema = this.Kernel.ToFunctionSchema(),
};

await this.Client.CreateAgentActionGroupAsync(createAgentActionGroupRequest, cancellationToken).ConfigureAwait(false);
await this.Client.PrepareAgentAsync(new() { AgentId = this.Id }, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Enable user input for the agent and prepare the agent.
/// </summary>
public async Task EnableUserInputActionGroupAsync(CancellationToken cancellationToken = default)
{
var createAgentActionGroupRequest = new CreateAgentActionGroupRequest
{
AgentId = this.Id,
AgentVersion = this.AgentModel.AgentVersion ?? "DRAFT",
ActionGroupName = this.UseInputActionGroupSignature,
ActionGroupState = ActionGroupState.ENABLED,
ParentActionGroupSignature = new(Amazon.BedrockAgent.ActionGroupSignature.AMAZONUserInput),
};

await this.Client.CreateAgentActionGroupAsync(createAgentActionGroupRequest, cancellationToken).ConfigureAwait(false);
await this.Client.PrepareAgentAsync(new() { AgentId = this.Id }, cancellationToken).ConfigureAwait(false);
}

#endregion

/// <inheritdoc/>
Expand Down
112 changes: 64 additions & 48 deletions dotnet/src/Agents/Bedrock/BedrockAgentChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -63,71 +64,67 @@ protected override Task ReceiveAsync(IEnumerable<ChatMessageContent> history, Ca
}

/// <inheritdoc/>
protected override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
protected override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
BedrockAgent agent,
CancellationToken cancellationToken)
[EnumeratorCancellation] CancellationToken cancellationToken)
{
return this._history.Count == 0 ? throw new InvalidOperationException("No messages to send.") : InvokeInternalAsync();
if (!this.PrepareAndValidateHistory())
{
yield break;
}

async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeInternalAsync()
InvokeAgentRequest invokeAgentRequest = new()
{
this.EnsureHistoryAlternates();
this.EnsureLastMessageIsUser();
InvokeAgentRequest invokeAgentRequest = new()
{
AgentAliasId = BedrockAgent.WorkingDraftAgentAlias,
AgentId = agent.Id,
SessionId = BedrockAgent.CreateSessionId(),
InputText = this._history.Last().Content ?? throw new InvalidOperationException("Message content cannot be null."),
SessionState = this.ParseHistoryToSessionState(),
};
await foreach (var message in agent.InvokeAsync(invokeAgentRequest, null, cancellationToken).ConfigureAwait(false))
AgentAliasId = BedrockAgent.WorkingDraftAgentAlias,
AgentId = agent.Id,
SessionId = BedrockAgent.CreateSessionId(),
InputText = this._history.Last().Content,
SessionState = this.ParseHistoryToSessionState(),
};
await foreach (var message in agent.InvokeAsync(invokeAgentRequest, null, cancellationToken).ConfigureAwait(false))
{
if (message.Content is not null)
{
if (message.Content is not null)
{
this._history.Add(message);
// All messages from Bedrock agents are user facing, i.e., function calls are not returned as messages
yield return (true, message);
}
this._history.Add(message);
// All messages from Bedrock agents are user facing, i.e., function calls are not returned as messages
yield return (true, message);
}
}
}

/// <inheritdoc/>
protected override IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
protected override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
BedrockAgent agent,
IList<ChatMessageContent> messages,
CancellationToken cancellationToken)
[EnumeratorCancellation] CancellationToken cancellationToken)
{
return this._history.Count == 0 ? throw new InvalidOperationException("No messages to send.") : InvokeInternalAsync();
if (!this.PrepareAndValidateHistory())
{
yield break;
}

async IAsyncEnumerable<StreamingChatMessageContent> InvokeInternalAsync()
InvokeAgentRequest invokeAgentRequest = new()
{
this.EnsureHistoryAlternates();
this.EnsureLastMessageIsUser();
InvokeAgentRequest invokeAgentRequest = new()
{
AgentAliasId = BedrockAgent.WorkingDraftAgentAlias,
AgentId = agent.Id,
SessionId = BedrockAgent.CreateSessionId(),
InputText = this._history.Last().Content ?? throw new InvalidOperationException("Message content cannot be null."),
SessionState = this.ParseHistoryToSessionState(),
};
await foreach (var message in agent.InvokeStreamingAsync(invokeAgentRequest, null, cancellationToken).ConfigureAwait(false))
AgentAliasId = BedrockAgent.WorkingDraftAgentAlias,
AgentId = agent.Id,
SessionId = BedrockAgent.CreateSessionId(),
InputText = this._history.Last().Content,
SessionState = this.ParseHistoryToSessionState(),
};
await foreach (var message in agent.InvokeStreamingAsync(invokeAgentRequest, null, cancellationToken).ConfigureAwait(false))
{
if (message.Content is not null)
{
if (message.Content is not null)
this._history.Add(new()
{
this._history.Add(new()
{
Role = AuthorRole.Assistant,
Content = message.Content,
AuthorName = message.AuthorName,
InnerContent = message.InnerContent,
ModelId = message.ModelId,
});
// All messages from Bedrock agents are user facing, i.e., function calls are not returned as messages
yield return message;
}
Role = AuthorRole.Assistant,
Content = message.Content,
AuthorName = message.AuthorName,
InnerContent = message.InnerContent,
ModelId = message.ModelId,
});
// All messages from Bedrock agents are user facing, i.e., function calls are not returned as messages
yield return message;
}
}
}
Expand All @@ -152,6 +149,25 @@ protected override string Serialize()

#region private methods

private bool PrepareAndValidateHistory()
{
if (this._history.Count == 0)
{
this.Logger.LogWarning("No messages to send. Bedrock requires at least one message to start a conversation.");
return false;
}

this.EnsureHistoryAlternates();
this.EnsureLastMessageIsUser();
if (string.IsNullOrEmpty(this._history.Last().Content))
{
this.Logger.LogWarning("Last message has no content. Bedrock doesn't support empty messages.");
return false;
}

return true;
}

private void EnsureHistoryAlternates()
{
if (this._history.Count <= 1)
Expand Down
Loading

0 comments on commit 4bc6141

Please sign in to comment.