Skip to content

Commit

Permalink
Migrate to LLMService Context Management
Browse files Browse the repository at this point in the history
Removed dependency on legacy state management and introduced individual context management for each LLMService. This change allows each service to manage its own context separately.

There is no fundamental change to the external interface. However, if your implementation relied on legacy state deletions to clear context, you will need to modify your code to use the `ClearContext()` method of LLMService.
  • Loading branch information
uezo committed Sep 29, 2024
1 parent 9cfb6d8 commit 9402684
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 336 deletions.
1 change: 1 addition & 0 deletions Plugins/ClaudeServiceWebGL.jslib
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mergeInto(LibraryManager.library, {
"anthropic-version": "2023-06-01",
"anthropic-beta": "messages-2023-12-15",
"Content-Type": "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"x-api-key": `${apiKey}`
},
method: "POST",
Expand Down
18 changes: 4 additions & 14 deletions Scripts/Dialog/DialogProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public enum DialogStatus
public DialogStatus Status { get; private set; }
private string processingId { get; set; }
private ISkillRouter skillRouter { get; set; }
private IStateStore stateStore { get; set; }
private CancellationTokenSource dialogTokenSource { get; set; }

// Actions for each status
Expand All @@ -37,7 +36,6 @@ public enum DialogStatus
private void Awake()
{
// Get components
stateStore = gameObject.GetComponent<IStateStore>() ?? new MemoryStateStore();
skillRouter = gameObject.GetComponent<ISkillRouter>();
skillRouter.RegisterSkills();

Expand Down Expand Up @@ -84,14 +82,13 @@ public async UniTask StartDialogAsync(string text, Dictionary<string, object> pa
OnRequestRecievedTask = UniTask.Delay(1);
}

var state = await stateStore.GetStateAsync("_");
var request = new Request(RequestType.Voice, text);
request.Payloads = payloads ?? new Dictionary<string, object>();

Status = DialogStatus.Routing;

// Extract intent for routing
var intentExtractionResult = await skillRouter.ExtractIntentAsync(request, state, token);
var intentExtractionResult = await skillRouter.ExtractIntentAsync(request, null, token);
if (intentExtractionResult != null)
{
request.Intent = intentExtractionResult.Intent;
Expand All @@ -100,29 +97,27 @@ public async UniTask StartDialogAsync(string text, Dictionary<string, object> pa
if (token.IsCancellationRequested) { return; }

// Get skill to process intent / topic
var skill = skillRouter.Route(request, state, token);
var skill = skillRouter.Route(request, null, token);
if (token.IsCancellationRequested) { return; }

// Process skill
Status = DialogStatus.Processing;
var skillResponse = await skill.ProcessAsync(request, state, null, token);
var skillResponse = await skill.ProcessAsync(request, null, null, token);
if (token.IsCancellationRequested) { return; }

// Await before showing response
await OnRequestRecievedTask;

// Show response
Status = DialogStatus.Responding;
await skill.ShowResponseAsync(skillResponse, request, state, token);
await skill.ShowResponseAsync(skillResponse, request, null, token);
if (token.IsCancellationRequested) { return; }

if (OnResponseShownAsync != null)
{
await OnResponseShownAsync(skillResponse, token);
}

await stateStore.SaveStateAsync(state);

endConversation = skillResponse.EndConversation;
}
catch (Exception ex)
Expand Down Expand Up @@ -189,10 +184,5 @@ public CancellationToken GetDialogToken()
dialogTokenSource = new CancellationTokenSource();
return dialogTokenSource.Token;
}

// Deletes the legacy dialog state from the configured IStateStore.
// NOTE(review): the userId parameter is accepted but ignored — the hard-coded
// "_" key is always deleted (the default user id of legacy ChatdollKit).
// Confirm whether per-user clearing was ever intended before relying on userId.
// NOTE(review): this commit removes this method; callers should migrate to
// LLMService.ClearContext() per the commit message — TODO confirm against callers.
public async UniTask ClearStateAsync(string userId = null)
{
await stateStore.DeleteStateAsync("_"); // "_" is the default user id of legacy ChatdollKit
}
}
}
86 changes: 27 additions & 59 deletions Scripts/LLM/ChatGPT/ChatGPTService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
using UnityEngine.Networking;
using Cysharp.Threading.Tasks;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

namespace ChatdollKit.LLM.ChatGPT
{
public class ChatGPTService : LLMServiceBase
{
public string HistoryKey = "ChatGPTHistories";
public string CustomParameterKey = "ChatGPTParameters";
public string CustomHeaderKey = "ChatGPTHeaders";

[Header("API configuration")]
public string ApiKey;
public string Model = "gpt-4o-mini";
Expand Down Expand Up @@ -54,40 +49,18 @@ public override ILLMMessage CreateMessageAfterFunction(string role = null, strin
}
}

protected List<ILLMMessage> GetHistoriesFromStateData(Dictionary<string, object> stateData, int count)
protected override void UpdateContext(LLMSession llmSession)
{
var messages = new List<ILLMMessage>();

// Add histories to state if not exists
if (!stateData.ContainsKey(HistoryKey) || stateData[HistoryKey] == null)
{
stateData[HistoryKey] = new JArray();
return messages;
}

// Get JToken array from state
var serializedMessagesAll = (JArray)stateData[HistoryKey];
var serializedMessages = serializedMessagesAll.Skip(serializedMessagesAll.Count - count * 2).ToList();
for (var i = 0; i < serializedMessages.Count; i++)
// User message
var lastUserMessage = llmSession.Contexts.Last();
if (lastUserMessage is ChatGPTUserMessage chatGPTUserMessage)
{
// JToken -> string -> Restore object
messages.Add(JsonConvert.DeserializeObject<ILLMMessage>(serializedMessages[i].ToString(), messageSerializationSettings));
// Remove non-text content to keep context light
chatGPTUserMessage.content.RemoveAll(part => !(part is TextContentPart));
}
context.Add(lastUserMessage);

return messages;
}

#pragma warning disable CS1998
public async UniTask AddHistoriesAsync(ILLMSession llmSession, Dictionary<string, object> dataStore, CancellationToken token = default)
{
// Prepare state store
var serializedMessages = (JArray)dataStore[HistoryKey];

// Add user message
var serializedUserMessage = JsonConvert.SerializeObject(llmSession.Contexts.Last(), messageSerializationSettings);
serializedMessages.Add(serializedUserMessage);

// Add assistant message
// Assistant message
if (llmSession.ResponseType == ResponseType.FunctionCalling)
{
var functionCallMessage = new ChatGPTAssistantMessage(tool_calls: new List<Dictionary<string, object>>() {
Expand All @@ -101,18 +74,21 @@ public async UniTask AddHistoriesAsync(ILLMSession llmSession, Dictionary<string
}},
}
});
serializedMessages.Add(JsonConvert.SerializeObject(functionCallMessage, messageSerializationSettings));
context.Add(functionCallMessage);

// Add also to contexts for using this message in this turn
llmSession.Contexts.Add(functionCallMessage);
}
else
{
var assistantMessage = new ChatGPTAssistantMessage(llmSession.StreamBuffer);
serializedMessages.Add(JsonConvert.SerializeObject(assistantMessage, messageSerializationSettings));
context.Add(assistantMessage);
}

contextUpdatedAt = Time.time;
}

#pragma warning disable CS1998
public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId, string inputText, Dictionary<string, object> payloads, CancellationToken token = default)
{
var messages = new List<ILLMMessage>();
Expand All @@ -124,8 +100,7 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,
}

// Histories
var histories = GetHistoriesFromStateData((Dictionary<string, object>)payloads["StateData"], HistoryTurns);
messages.AddRange(histories);
messages.AddRange(GetContext(historyTurns));

// User (current input)
if (((Dictionary<string, object>)payloads["RequestPayloads"]).ContainsKey("imageBytes"))
Expand Down Expand Up @@ -153,14 +128,14 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,
public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();
var requestPayloads = (Dictionary<string, object>)payloads["RequestPayloads"];
var customParameters = requestPayloads.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)requestPayloads[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = requestPayloads.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)requestPayloads[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var chatGPTSession = new ChatGPTSession();
chatGPTSession.Contexts = messages;
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, stateData, customParameters, customHeaders, useFunctions, token);
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token);
await WaitForFunctionInfo(chatGPTSession, token);

// Retry
Expand All @@ -182,7 +157,7 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return chatGPTSession;
}

public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, object> stateData, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
chatGPTSession.CurrentStreamBuffer = string.Empty;

Expand Down Expand Up @@ -323,21 +298,14 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
await UniTask.Delay(10);
}

// Remove non-text content to keep context light
var lastUserMessage = chatGPTSession.Contexts.Last() as ChatGPTUserMessage;
if (lastUserMessage != null)
{
lastUserMessage.content.RemoveAll(part => !(part is TextContentPart));
}

// Update histories
// Update context
if (chatGPTSession.ResponseType != ResponseType.Error && chatGPTSession.ResponseType != ResponseType.Timeout)
{
await AddHistoriesAsync(chatGPTSession, stateData, token);
UpdateContext(chatGPTSession);
}
else
{
Debug.LogWarning($"Messages are not added to histories for response type is not success: {chatGPTSession.ResponseType}");
Debug.LogWarning($"Messages are not added to context for response type is not success: {chatGPTSession.ResponseType}");
}

// Ends with error
Expand All @@ -346,8 +314,8 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
throw new Exception($"ChatGPT ends with error ({streamRequest.result}): {streamRequest.error}");
}

// Process tags
var extractedTags = ExtractTags(chatGPTSession.CurrentStreamBuffer);

if (extractedTags.Count > 0 && HandleExtractedTags != null)
{
HandleExtractedTags(extractedTags, chatGPTSession);
Expand All @@ -359,17 +327,17 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
chatGPTSession.IsVisionAvailable = false;

// Get image
var imageBytes = await CaptureImage(extractedTags["vision"]);
var imageSource = extractedTags["vision"];
var imageBytes = await CaptureImage(imageSource);

// Make contexts
var lastUserContentText = ((TextContentPart)lastUserMessage.content[0]).text;
if (imageBytes != null)
{
chatGPTSession.Contexts.Add(new ChatGPTAssistantMessage(chatGPTSession.StreamBuffer));
// Image -> Text to get the better accuracy
chatGPTSession.Contexts.Add(new ChatGPTUserMessage(new List<IContentPart>() {
new ImageUrlContentPart("data:image/jpeg;base64," + Convert.ToBase64String(imageBytes)),
new TextContentPart(lastUserContentText)
new TextContentPart($"This is the image you captured. (source: {imageSource})")
}));
}
else
Expand All @@ -378,7 +346,7 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
}

// Call recursively with image
await StartStreamingAsync(chatGPTSession, stateData, customParameters, customHeaders, useFunctions, token);
await StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token);
}
else
{
Expand Down
38 changes: 11 additions & 27 deletions Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,13 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, ChatGPTSession> sessions { get; set; } = new Dictionary<string, ChatGPTSession>();

public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, object> stateData, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
chatGPTSession.CurrentStreamBuffer = string.Empty;

string sessionId;
if (stateData.ContainsKey("chatGPTSessionId"))
{
// Use existing session id for callback
sessionId = (string)stateData["chatGPTSessionId"];
}
else
{
// Add session for callback
sessionId = Guid.NewGuid().ToString();
sessions.Add(sessionId, chatGPTSession);
}
// Store session with id to receive streaming data from JavaScript
var sessionId = Guid.NewGuid().ToString();
sessions.Add(sessionId, chatGPTSession);

// Make request data
var data = new Dictionary<string, object>()
Expand Down Expand Up @@ -91,7 +82,7 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
}

// TODO: Support custom headers later...
if (customHeaders.Count >= 0)
if (customHeaders.Count > 0)
{
Debug.LogWarning("Custom headers for ChatGPT on WebGL is not supported for now.");
}
Expand Down Expand Up @@ -153,17 +144,10 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
await UniTask.Delay(10);
}

// Remove non-text content to keep context light
var lastUserMessage = chatGPTSession.Contexts.Last() as ChatGPTUserMessage;
if (lastUserMessage != null)
{
lastUserMessage.content.RemoveAll(part => !(part is TextContentPart));
}

// Update histories
if (chatGPTSession.ResponseType != ResponseType.Error && chatGPTSession.ResponseType != ResponseType.Timeout)
{
await AddHistoriesAsync(chatGPTSession, stateData, token);
UpdateContext(chatGPTSession);
}
else
{
Expand All @@ -176,8 +160,8 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
throw new Exception($"ChatGPT ends with error");
}

// Process tags
var extractedTags = ExtractTags(chatGPTSession.CurrentStreamBuffer);

if (extractedTags.Count > 0 && HandleExtractedTags != null)
{
HandleExtractedTags(extractedTags, chatGPTSession);
Expand All @@ -189,17 +173,17 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
chatGPTSession.IsVisionAvailable = false;

// Get image
var imageBytes = await CaptureImage(extractedTags["vision"]);
var imageSource = extractedTags["vision"];
var imageBytes = await CaptureImage(imageSource);

// Make contexts
var lastUserContentText = ((TextContentPart)lastUserMessage.content[0]).text;
if (imageBytes != null)
{
chatGPTSession.Contexts.Add(new ChatGPTAssistantMessage(chatGPTSession.StreamBuffer));
// Image -> Text to get the better accuracy
chatGPTSession.Contexts.Add(new ChatGPTUserMessage(new List<IContentPart>() {
new ImageUrlContentPart("data:image/jpeg;base64," + Convert.ToBase64String(imageBytes)),
new TextContentPart(lastUserContentText)
new TextContentPart($"This is the image you captured. (source: {imageSource})")
}));
}
else
Expand All @@ -208,7 +192,7 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
}

// Call recursively with image
await StartStreamingAsync(chatGPTSession, stateData, customParameters, customHeaders, useFunctions, token);
await StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token);
}
else
{
Expand Down
Loading

0 comments on commit 9402684

Please sign in to comment.