diff --git a/tools/server/webui/src/Config.ts b/tools/server/webui/src/Config.ts index c03ac287f3484..d557eaeb1bfab 100644 --- a/tools/server/webui/src/Config.ts +++ b/tools/server/webui/src/Config.ts @@ -1,5 +1,7 @@ import daisyuiThemes from 'daisyui/theme/object'; import { isNumeric } from './utils/misc'; +import { AVAILABLE_TOOLS } from './utils/tool_calling/register_tools'; +import { AgentTool } from './utils/tool_calling/agent_tool'; export const isDev = import.meta.env.MODE === 'development'; @@ -41,6 +43,14 @@ export const CONFIG_DEFAULT = { custom: '', // custom json-stringified object // experimental features pyIntepreterEnabled: false, + // Fields for tool calling + streamResponse: true, + ...Object.fromEntries( + Array.from(AVAILABLE_TOOLS.values()).map((tool: AgentTool) => [ + `tool_${tool.id}_enabled`, + false, + ]) + ), }; export const CONFIG_INFO: Record = { apiKey: 'Set the API Key if you are using --api-key option for the server.', diff --git a/tools/server/webui/src/assets/iframe_sandbox.html b/tools/server/webui/src/assets/iframe_sandbox.html new file mode 100644 index 0000000000000..20867132b5785 --- /dev/null +++ b/tools/server/webui/src/assets/iframe_sandbox.html @@ -0,0 +1,77 @@ + + + + JS Sandbox + + + +

JavaScript Execution Sandbox

+ + diff --git a/tools/server/webui/src/components/ChatMessage.tsx b/tools/server/webui/src/components/ChatMessage.tsx index ee59de450d1ff..f6705c15bedd9 100644 --- a/tools/server/webui/src/components/ChatMessage.tsx +++ b/tools/server/webui/src/components/ChatMessage.tsx @@ -1,8 +1,10 @@ -import { useMemo, useState } from 'react'; +import { useMemo, useState, Fragment } from 'react'; import { useAppContext } from '../utils/app.context'; import { Message, PendingMessage } from '../utils/types'; import { classNames } from '../utils/misc'; import MarkdownDisplay, { CopyButton } from './MarkdownDisplay'; +import { ToolCallArgsDisplay } from './tool_calling/ToolCallArgsDisplay'; +import { ToolCallResultDisplay } from './tool_calling/ToolCallResultDisplay'; import { ArrowPathIcon, ChevronLeftIcon, @@ -20,6 +22,7 @@ interface SplitMessage { export default function ChatMessage({ msg, + chainedParts, siblingLeafNodeIds, siblingCurrIdx, id, @@ -29,6 +32,7 @@ export default function ChatMessage({ isPending, }: { msg: Message | PendingMessage; + chainedParts?: (Message | PendingMessage)[]; siblingLeafNodeIds: Message['id'][]; siblingCurrIdx: number; id?: string; @@ -57,8 +61,15 @@ export default function ChatMessage({ // for reasoning model, we split the message into content and thought // TODO: implement this as remark/rehype plugin in the future - const { content, thought, isThinking }: SplitMessage = useMemo(() => { - if (msg.content === null || msg.role !== 'assistant') { + const { + content: mainDisplayableContent, + thought, + isThinking, + }: SplitMessage = useMemo(() => { + if ( + msg.content === null || + (msg.role !== 'assistant' && msg.role !== 'tool') + ) { return { content: msg.content }; } let actualContent = ''; @@ -78,11 +89,21 @@ export default function ChatMessage({ actualContent += thinkSplit[0]; } } + return { content: actualContent, thought, isThinking }; }, [msg]); if (!viewingChat) return null; + const toolCalls = msg.tool_calls ?? null; + + const hasContentInMainMsg = + mainDisplayableContent && mainDisplayableContent.trim() !== ''; + const hasContentInChainedParts = chainedParts?.some( + (part) => part.content && part.content.trim() !== '' + ); + const entireTurnHasSomeDisplayableContent = + hasContentInMainMsg || hasContentInChainedParts; const isUser = msg.role === 'user'; return ( @@ -141,7 +162,9 @@ export default function ChatMessage({ {/* not editing content, render message */} {editingContent === null && ( <> - {content === null ? ( + {mainDisplayableContent === null && + !toolCalls && + !chainedParts?.length ? ( <> {/* show loading dots for pending message */} @@ -158,13 +181,53 @@ export default function ChatMessage({ /> )} - + {msg.role === 'tool' && mainDisplayableContent ? ( + + ) : ( + mainDisplayableContent && + mainDisplayableContent.trim() !== '' && ( + + ) + )} )} + {toolCalls && + toolCalls.map((toolCall) => ( + + ))} + + {chainedParts?.map((part) => ( + + {part.role === 'tool' && part.content && ( + + )} + + {part.role === 'assistant' && part.content && ( +
+ +
+ )} + + {part.tool_calls && + part.tool_calls.map((toolCall) => ( + + ))} +
+ ))} {/* render timings if enabled */} {timings && config.showTokensPerSecond && (
@@ -195,7 +258,7 @@ export default function ChatMessage({
{/* actions for each message */} - {msg.content !== null && ( + {(entireTurnHasSomeDisplayableContent || msg.role === 'user') && (
{ - if (msg.content !== null) { + if (entireTurnHasSomeDisplayableContent) { onRegenerateMessage(msg as Message); } }} - disabled={msg.content === null} + disabled={ + !entireTurnHasSomeDisplayableContent || msg.content === null + } tooltipsContent="Regenerate response" > @@ -263,7 +328,17 @@ export default function ChatMessage({ )} )} - + {entireTurnHasSomeDisplayableContent && ( + p.role === 'assistant' && p.content) + ?.content ?? + '' + } + /> + )}
)} diff --git a/tools/server/webui/src/components/ChatScreen.tsx b/tools/server/webui/src/components/ChatScreen.tsx index 09c601ef2366a..fc339ed8d5aff 100644 --- a/tools/server/webui/src/components/ChatScreen.tsx +++ b/tools/server/webui/src/components/ChatScreen.tsx @@ -27,6 +27,7 @@ import { scrollToBottom, useChatScroll } from './useChatScroll.tsx'; */ export interface MessageDisplay { msg: Message | PendingMessage; + chainedParts?: (Message | PendingMessage)[]; // For merging consecutive assistant/tool messages siblingLeafNodeIds: Message['id'][]; siblingCurrIdx: number; isPending?: boolean; @@ -69,18 +70,72 @@ function getListMessageDisplay( } return currNode?.id ?? -1; }; + const processedIds = new Set(); // traverse the current nodes - for (const msg of currNodes) { - const parentNode = nodeMap.get(msg.parent ?? -1); - if (!parentNode) continue; - const siblings = parentNode.children; - if (msg.type !== 'root') { - res.push({ - msg, - siblingLeafNodeIds: siblings.map(findLeafNode), - siblingCurrIdx: siblings.indexOf(msg.id), - }); + for (const currentMessage of currNodes) { + if (processedIds.has(currentMessage.id) || currentMessage.type === 'root') { + continue; } + + const displayMsg = currentMessage; + const chainedParts: (Message | PendingMessage)[] = []; + processedIds.add(displayMsg.id); + + if (displayMsg.role === 'assistant') { + let currentLinkInChain = displayMsg; // Start with the initial assistant message + + // Loop to chain subsequent tool calls and their assistant responses + while (true) { + if (currentLinkInChain.children.length !== 1) { + // Stop if there isn't a single, clear next step in the chain + // or if the current link has no children. + break; + } + + const childId = currentLinkInChain.children[0]; + const childNode = nodeMap.get(childId); + + if (!childNode || processedIds.has(childNode.id)) { + // Child not found or already processed, end of chain + break; + } + + // Scenario 1: Current is Assistant, next is Tool + if ( + currentLinkInChain.role === 'assistant' && + childNode.role === 'tool' + ) { + chainedParts.push(childNode); + processedIds.add(childNode.id); + currentLinkInChain = childNode; // Continue chain from the tool message + } + // Scenario 2: Current is Tool, next is Assistant + else if ( + currentLinkInChain.role === 'tool' && + childNode.role === 'assistant' + ) { + chainedParts.push(childNode); + processedIds.add(childNode.id); + currentLinkInChain = childNode; // Continue chain from the assistant message + // This assistant message might make further tool calls + } + // Scenario 3: Pattern broken (e.g., Assistant -> Assistant, or Tool -> Tool) + else { + break; // Pattern broken, end of this specific tool-use chain + } + } + } + + const parentNode = nodeMap.get(displayMsg.parent ?? -1); + if (!parentNode && displayMsg.type !== 'root') continue; // Skip if parent not found for non-root + + const siblings = parentNode ? parentNode.children : []; + res.push({ + msg: displayMsg, + chainedParts: chainedParts.length > 0 ? chainedParts : undefined, + siblingLeafNodeIds: siblings.map(findLeafNode), + siblingCurrIdx: siblings.indexOf(displayMsg.id), + }); } return res; } @@ -136,13 +191,33 @@ export default function ChatScreen() { } textarea.setValue(''); scrollToBottom(false); + + // Determine the ID of the actual last message to use as parent + let parentMessageId: Message['id'] | null = null; + const lastMessageDisplayItem = messages.at(-1); + + if (lastMessageDisplayItem) { + if ( + lastMessageDisplayItem.chainedParts && + lastMessageDisplayItem.chainedParts.length > 0 + ) { + // If the last display item has chained parts, the true last message is the last part of that chain + parentMessageId = + lastMessageDisplayItem.chainedParts.at(-1)?.id ?? + lastMessageDisplayItem.msg.id; + } else { + // Otherwise, it's the main message of the last display item + parentMessageId = lastMessageDisplayItem.msg.id; + } + } + // If messages is empty (e.g., new chat), parentMessageId will remain null. + // sendMessage handles parentId = null correctly for starting new conversations. setCurrNodeId(-1); - // get the last message node - const lastMsgNodeId = messages.at(-1)?.msg.id ?? null; + if ( !(await sendMessage( currConvId, - lastMsgNodeId, + parentMessageId, lastInpMsg, extraContext.items, onChunk @@ -248,6 +323,7 @@ export default function ChatScreen() { ; - label: string | React.ReactElement; - help?: string | React.ReactElement; + label: string | ReactElement; + help?: string | ReactElement; key: SettKey; } @@ -100,6 +103,11 @@ const SETTING_SECTIONS: SettingSection[] = [ key, }) as SettingFieldInput ), + { + type: SettingInputType.CHECKBOX, + label: 'Enable response streaming', + key: 'streamResponse', + }, { type: SettingInputType.SHORT_INPUT, label: 'Paste length to file', @@ -169,6 +177,48 @@ const SETTING_SECTIONS: SettingSection[] = [ }, ], }, + { + title: ( + <> + + Tool Calling + + ), + fields: [ + { + type: SettingInputType.CUSTOM, + key: 'custom', + component: () => ( +
+

Important Note:

+

+ Response streaming must be disabled to use tool + calling. Individual tools (listed below) will be automatically + disabled if streaming is enabled. +

+
+ ), + }, + ...Array.from(AVAILABLE_TOOLS.values()).map( + (tool: AgentTool) => + ({ + type: SettingInputType.CHECKBOX, + label: ( + <> + {tool.name || tool.id} + {tool.toolDescription && ( + + Agent tool description: + {tool.toolDescription} + + )} + + ), + key: `tool_${tool.id}_enabled` as SettKey, + }) as SettingFieldInput + ), + ], + }, { title: ( <> @@ -417,6 +467,11 @@ export default function SettingDialog({ /> ); } else if (field.type === SettingInputType.CHECKBOX) { + const isToolToggle = + typeof field.key === 'string' && + field.key.startsWith('tool_') && + field.key.endsWith('_enabled'); + const isDisabled = isToolToggle && localConfig.streamResponse; return ( ); } else if (field.type === SettingInputType.CUSTOM) { @@ -531,11 +587,13 @@ function SettingsModalCheckbox({ value, onChange, label, + disabled, }: { configKey: SettKey; value: boolean; onChange: (value: boolean) => void; - label: string; + label: React.ReactElement | string; + disabled?: boolean; }) { return (
@@ -544,6 +602,7 @@ function SettingsModalCheckbox({ className="toggle" checked={value} onChange={(e) => onChange(e.target.checked)} + disabled={disabled} /> {label || configKey}
diff --git a/tools/server/webui/src/components/tool_calling/ToolCallArgsDisplay.tsx b/tools/server/webui/src/components/tool_calling/ToolCallArgsDisplay.tsx new file mode 100644 index 0000000000000..c76401d18da69 --- /dev/null +++ b/tools/server/webui/src/components/tool_calling/ToolCallArgsDisplay.tsx @@ -0,0 +1,23 @@ +import { ToolCallRequest } from '../../utils/types'; + +export const ToolCallArgsDisplay = ({ + toolCall, + baseClassName = 'collapse bg-base-200 collapse-arrow mb-4', +}: { + toolCall: ToolCallRequest; + baseClassName?: string; +}) => { + return ( +
+ + Tool call: {toolCall.function.name} + +
+
Arguments:
+
+          {JSON.stringify(JSON.parse(toolCall.function.arguments), null, 2)}
+        
+
+
+ ); +}; diff --git a/tools/server/webui/src/components/tool_calling/ToolCallResultDisplay.tsx b/tools/server/webui/src/components/tool_calling/ToolCallResultDisplay.tsx new file mode 100644 index 0000000000000..88b7a68cd7c1b --- /dev/null +++ b/tools/server/webui/src/components/tool_calling/ToolCallResultDisplay.tsx @@ -0,0 +1,20 @@ +export const ToolCallResultDisplay = ({ + content, + baseClassName = 'collapse bg-base-200 collapse-arrow mb-4', +}: { + content: string; + baseClassName?: string; +}) => { + return ( +
+ + Tool call result + +
+
+          {content}
+        
+
+
+ ); +}; diff --git a/tools/server/webui/src/utils/app.context.tsx b/tools/server/webui/src/utils/app.context.tsx index 96cffd95aba7c..3f653b4cceb42 100644 --- a/tools/server/webui/src/utils/app.context.tsx +++ b/tools/server/webui/src/utils/app.context.tsx @@ -6,6 +6,7 @@ import { LlamaCppServerProps, Message, PendingMessage, + ToolCallRequest, ViewingChat, } from './types'; import StorageUtils from './storage'; @@ -17,6 +18,7 @@ import { } from './misc'; import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config'; import { matchPath, useLocation, useNavigate } from 'react-router'; +import { AVAILABLE_TOOLS } from './tool_calling/register_tools'; import toast from 'react-hot-toast'; interface AppContextValue { @@ -157,8 +159,8 @@ export const AppContextProvider = ({ convId: string, leafNodeId: Message['id'], onChunk: CallbackGeneratedChunk - ) => { - if (isGenerating(convId)) return; + ): Promise => { + if (isGenerating(convId)) return leafNodeId; const config = StorageUtils.getConfig(); const currConversation = await StorageUtils.getOneConversation(convId); @@ -204,10 +206,18 @@ export const AppContextProvider = ({ } if (isDev) console.log({ messages }); + // tool calling from clientside + const enabledTools = Array.from( + AVAILABLE_TOOLS, + ([_name, tool], _index) => tool + ) + .filter((tool) => tool.enabled) + .map((tool) => tool.specs); + // prepare params const params = { messages, - stream: true, + stream: config.streamResponse, cache_prompt: true, samplers: config.samplers, temperature: config.temperature, @@ -229,6 +239,7 @@ export const AppContextProvider = ({ dry_penalty_last_n: config.dry_penalty_last_n, max_tokens: config.max_tokens, timings_per_token: !!config.showTokensPerSecond, + tools: enabledTools.length > 0 ? enabledTools : undefined, ...(config.custom.length ? JSON.parse(config.custom) : {}), }; @@ -244,37 +255,147 @@ export const AppContextProvider = ({ body: JSON.stringify(params), signal: abortController.signal, }); + if (fetchResponse.status !== 200) { const body = await fetchResponse.json(); throw new Error(body?.error?.message || 'Unknown error'); } - const chunks = getSSEStreamAsync(fetchResponse); - for await (const chunk of chunks) { - // const stop = chunk.stop; - if (chunk.error) { - throw new Error(chunk.error?.message || 'Unknown error'); + + // Tool calls results we will process later + const pendingMessages: PendingMessage[] = []; + let lastMsgId = pendingMsg.id; + let shouldContinueChain = false; + + if (params.stream) { + const chunks = getSSEStreamAsync(fetchResponse); + for await (const chunk of chunks) { + // const stop = chunk.stop; + if (chunk.error) { + throw new Error(chunk.error?.message || 'Unknown error'); + } + const addedContent = chunk.choices[0].delta.content; + const lastContent = pendingMsg.content || ''; + if (addedContent) { + pendingMsg = { + ...pendingMsg, + content: lastContent + addedContent, + }; + } + const timings = chunk.timings; + if (timings && config.showTokensPerSecond) { + // only extract what's really needed, to save some space + pendingMsg.timings = { + prompt_n: timings.prompt_n, + prompt_ms: timings.prompt_ms, + predicted_n: timings.predicted_n, + predicted_ms: timings.predicted_ms, + }; + } + setPending(convId, pendingMsg); + onChunk(); // don't need to switch node for pending message + } + } else { + const responseData = await fetchResponse.json(); + if (responseData.error) { + throw new Error(responseData.error?.message || 'Unknown error'); + } + + const choice = responseData.choices[0]; + const messageFromAPI = choice.message; + let newContent = ''; + + if (messageFromAPI.content) { + newContent = messageFromAPI.content; } - const addedContent = chunk.choices[0].delta.content; - const lastContent = pendingMsg.content || ''; - if (addedContent) { + + // Process tool calls + if (messageFromAPI.tool_calls && messageFromAPI.tool_calls.length > 0) { + // Store the raw tool calls in the pendingMsg pendingMsg = { ...pendingMsg, - content: lastContent + addedContent, + tool_calls: messageFromAPI.tool_calls as ToolCallRequest[], }; + + for (let i = 0; i < messageFromAPI.tool_calls.length; i++) { + const toolCall = messageFromAPI.tool_calls[i] as ToolCallRequest; + if (toolCall) { + // Set up call id + toolCall.call_id ??= `call_${i}`; + + if (isDev) console.log({ tc: toolCall }); + + // Process tool call + const toolResult = await AVAILABLE_TOOLS.get( + toolCall.function.name + )?.processCall(toolCall); + + const toolMsg: PendingMessage = { + id: lastMsgId + 1, + type: 'text', + convId: convId, + content: toolResult?.output ?? 'Error: invalid tool call!', + timestamp: Date.now(), + role: 'tool', + parent: lastMsgId, + children: [], + }; + pendingMessages.push(toolMsg); + lastMsgId += 1; + } + } } - const timings = chunk.timings; - if (timings && config.showTokensPerSecond) { - // only extract what's really needed, to save some space + + if (newContent !== '') { + pendingMsg = { + ...pendingMsg, + content: newContent, + }; + } + + // Handle timings from the non-streaming response + const apiTimings = responseData.timings; + if (apiTimings && config.showTokensPerSecond) { pendingMsg.timings = { - prompt_n: timings.prompt_n, - prompt_ms: timings.prompt_ms, - predicted_n: timings.predicted_n, - predicted_ms: timings.predicted_ms, + prompt_n: apiTimings.prompt_n, + prompt_ms: apiTimings.prompt_ms, + predicted_n: apiTimings.predicted_n, + predicted_ms: apiTimings.predicted_ms, }; } - setPending(convId, pendingMsg); - onChunk(); // don't need to switch node for pending message + + for (const pendMsg of pendingMessages) { + setPending(convId, pendMsg); + onChunk(pendMsg.id); // Update UI to show the processed message + } + + shouldContinueChain = choice.finish_reason === 'tool_calls'; + } + + pendingMessages.unshift(pendingMsg); + if ( + pendingMsg.content !== null || + (pendingMsg.tool_calls?.length ?? 0) > 0 + ) { + await StorageUtils.appendMsgChain( + pendingMessages as Message[], + leafNodeId + ); } + + // if message ended due to "finish_reason": "tool_calls" + // resend it to assistant to process the result. + if (shouldContinueChain) { + lastMsgId = await generateMessage(convId, lastMsgId, onChunk); + } + + setPending(convId, null); + onChunk(lastMsgId); // trigger scroll to bottom and switch to the last node + + // Fetch messages from DB + const savedMsgs = await StorageUtils.getMessages(convId); + console.log({ savedMsgs }); + + return lastMsgId; } catch (err) { setPending(convId, null); if ((err as Error).name === 'AbortError') { @@ -288,11 +409,7 @@ export const AppContextProvider = ({ } } - if (pendingMsg.content !== null) { - await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId); - } - setPending(convId, null); - onChunk(pendingId); // trigger scroll to bottom and switch to the last node + return pendingId; }; const sendMessage = async ( @@ -316,7 +433,7 @@ export const AppContextProvider = ({ const now = Date.now(); const currMsgId = now; - StorageUtils.appendMsg( + await StorageUtils.appendMsg( { id: currMsgId, timestamp: now, diff --git a/tools/server/webui/src/utils/misc.ts b/tools/server/webui/src/utils/misc.ts index ba760e83bb282..e91de6703d453 100644 --- a/tools/server/webui/src/utils/misc.ts +++ b/tools/server/webui/src/utils/misc.ts @@ -27,7 +27,6 @@ export async function* getSSEStreamAsync(fetchResponse: Response) { .pipeThrough(new TextLineStream()); // @ts-expect-error asyncIterator complains about type, but it should work for await (const line of asyncIterator(lines)) { - //if (isDev) console.log({ line }); if (line.startsWith('data:') && !line.endsWith('[DONE]')) { const data = JSON.parse(line.slice(5)); yield data; @@ -63,10 +62,15 @@ export const copyStr = (textToCopy: string) => { export function normalizeMsgsForAPI(messages: Readonly) { return messages.map((msg) => { if (msg.role !== 'user' || !msg.extra) { - return { + const apiMessage = { role: msg.role, content: msg.content, } as APIMessage; + + if (msg.tool_calls && msg.tool_calls.length > 0) { + apiMessage.tool_calls = msg.tool_calls; + } + return apiMessage; } // extra content first, then user text message in the end @@ -117,14 +121,26 @@ export function filterThoughtFromMsgs(messages: APIMessage[]) { return msg; } // assistant message is always a string - const contentStr = msg.content as string; - return { + // except when tool_calls is present - it can be null then + const contentStr = msg.content as string | null; + let content; + if (msg.role === 'assistant' && contentStr !== null) { + content = contentStr?.split('').at(-1)!.trim(); + } else { + content = contentStr; + } + + const filteredMessage = { role: msg.role, - content: - msg.role === 'assistant' - ? contentStr.split('').at(-1)!.trim() - : contentStr, + content: content, + tool_calls: msg.tool_calls, } as APIMessage; + + if (msg.tool_calls && msg.tool_calls.length > 0) { + filteredMessage.tool_calls = msg.tool_calls; + } + + return filteredMessage; }); } diff --git a/tools/server/webui/src/utils/storage.ts b/tools/server/webui/src/utils/storage.ts index 505693e9272ac..46aa0665fe7bb 100644 --- a/tools/server/webui/src/utils/storage.ts +++ b/tools/server/webui/src/utils/storage.ts @@ -133,39 +133,85 @@ const StorageUtils = { msg: Exclude, parentNodeId: Message['id'] ): Promise { - if (msg.content === null) return; - const { convId } = msg; - await db.transaction('rw', db.conversations, db.messages, async () => { - const conv = await StorageUtils.getOneConversation(convId); - const parentMsg = await db.messages - .where({ convId, id: parentNodeId }) - .first(); - // update the currNode of conversation - if (!conv) { - throw new Error(`Conversation ${convId} does not exist`); - } - if (!parentMsg) { - throw new Error( - `Parent message ID ${parentNodeId} does not exist in conversation ${convId}` - ); - } - await db.conversations.update(convId, { - lastModified: Date.now(), - currNode: msg.id, - }); - // update parent - await db.messages.update(parentNodeId, { - children: [...parentMsg.children, msg.id], - }); - // create message - await db.messages.add({ - ...msg, - parent: parentNodeId, - children: [], + await this.appendMsgChain([msg], parentNodeId); + }, + + /** + * Adds chain of messages to the DB, usually + * produced by tool calling. + */ + async appendMsgChain( + messages: Exclude[], + parentNodeId: Message['id'] + ): Promise { + if (messages.length === 0) return; + + const { convId } = messages[0]; + + // Verify conversation exists + const conv = await this.getOneConversation(convId); + if (!conv) { + throw new Error(`Conversation ${convId} does not exist`); + } + + // Verify starting parent exists + const startParent = await db.messages + .where({ convId, id: parentNodeId }) + .first(); + if (!startParent) { + throw new Error( + `Starting parent message ${parentNodeId} does not exist in conversation ${convId}` + ); + } + + // Get the last message ID for updating the conversation + const lastMsgId = messages[messages.length - 1].id; + + try { + // Process all messages in a single transaction + await db.transaction('rw', db.messages, db.conversations, () => { + // First message connects to startParentId + let parentId = parentNodeId; + const parentChildren = [...startParent.children]; + + for (let i = 0; i < messages.length; i++) { + const msg = messages[i]; + + // Add this message to its parent's children + if (i === 0) { + // First message - update the starting parent + parentChildren.push(msg.id); + db.messages.update(parentId, { children: parentChildren }); + } else { + // Other messages - previous message is the parent + db.messages.update(parentId, { children: [msg.id] }); + } + + // Add the message + db.messages.add({ + ...msg, + parent: parentId, + children: [], // Will be updated if this message has children + }); + + // Next message's parent is this message + parentId = msg.id; + } + + // Update the conversation + db.conversations.update(convId, { + lastModified: Date.now(), + currNode: lastMsgId, + }); }); - }); + } catch (error) { + console.error('Error saving message chain:', error); + throw error; + } + dispatchConversationChange(convId); }, + /** * remove conversation by id */ diff --git a/tools/server/webui/src/utils/tool_calling/agent_tool.ts b/tools/server/webui/src/utils/tool_calling/agent_tool.ts new file mode 100644 index 0000000000000..5b60543215cef --- /dev/null +++ b/tools/server/webui/src/utils/tool_calling/agent_tool.ts @@ -0,0 +1,70 @@ +import StorageUtils from '../storage'; +import { + ToolCallRequest, + ToolCallOutput, + ToolCallParameters, + ToolCallSpec, +} from '../types'; + +export abstract class AgentTool { + constructor( + public readonly id: string, + public readonly name: string, + public readonly toolDescription: string, + public readonly parameters: ToolCallParameters + ) {} + + /** + * "Public" wrapper for the tool call processing logic. + * @param call The tool call object from the API response. + * @returns The tool call output or undefined if the tool is not enabled. + */ + public async processCall( + call: ToolCallRequest + ): Promise { + if (this.enabled) { + try { + return await this._process(call); + } catch (error) { + console.error(`Error processing tool call for ${this.id}:`, error); + return { + type: 'function_call_output', + call_id: call.call_id, + output: `Error during tool execution: ${(error as Error).message}`, + } as ToolCallOutput; + } + } + return undefined; + } + + /** + * Whether calling this tool is enabled. + * User can toggle the status from the settings panel. + * @returns enabled status. + */ + public get enabled(): boolean { + return StorageUtils.getConfig()[`tool_${this.id}_enabled`] ?? false; + } + + /** + * Specifications for the tool call. + * https://github.com/ggml-org/llama.cpp/blob/master/docs/function-calling.md + * https://platform.openai.com/docs/guides/function-calling?api-mode=responses#defining-functions + */ + public get specs(): ToolCallSpec { + return { + type: 'function', + function: { + name: this.id, + description: this.toolDescription, + parameters: this.parameters, + }, + }; + } + + /** + * The actual tool call processing logic. + * @param call: The tool call object from the API response. + */ + protected abstract _process(call: ToolCallRequest): Promise; +} diff --git a/tools/server/webui/src/utils/tool_calling/js_repl_tool.ts b/tools/server/webui/src/utils/tool_calling/js_repl_tool.ts new file mode 100644 index 0000000000000..7f718e6661557 --- /dev/null +++ b/tools/server/webui/src/utils/tool_calling/js_repl_tool.ts @@ -0,0 +1,159 @@ +import { ToolCallRequest, ToolCallOutput, ToolCallParameters } from '../types'; +import { AgentTool } from './agent_tool'; + +// Import the HTML content as a raw string +import iframeHTMLContent from '../../assets/iframe_sandbox.html?raw'; + +interface IframeMessage { + call_id: string; + output?: string; + error?: string; + command?: 'executeCode' | 'iframeReady'; + code?: string; +} + +export class JSReplAgentTool extends AgentTool { + private static readonly ID = 'javascript_interpreter'; + private iframe: HTMLIFrameElement | null = null; + private iframeReadyPromise: Promise | null = null; + private resolveIframeReady: (() => void) | null = null; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + private rejectIframeReady: ((reason?: any) => void) | null = null; + private pendingCalls = new Map void>(); + private messageHandler: + | ((event: MessageEvent) => void) + | null = null; + + constructor() { + super( + JSReplAgentTool.ID, + 'Javascript interpreter', + 'Executes JavaScript code in a sandboxed iframe. The code should be self-contained valid javascript. You can use console.log(variable) to print out intermediate values, which will be captured.', + { + type: 'object', + properties: { + code: { + type: 'string', + description: 'Valid JavaScript code to execute.', + }, + }, + required: ['code'], + } as ToolCallParameters + ); + this.initIframe(); + } + + private initIframe(): void { + if (typeof window === 'undefined' || typeof document === 'undefined') { + console.warn( + 'JSReplAgentTool: Not in a browser environment, iframe will not be created.' + ); + return; + } + + this.iframeReadyPromise = new Promise((resolve, reject) => { + this.resolveIframeReady = resolve; + this.rejectIframeReady = reject; + }); + + this.messageHandler = (event: MessageEvent) => { + if ( + !event.data || + !this.iframe || + !this.iframe.contentWindow || + event.source !== this.iframe.contentWindow + ) { + return; + } + + const { command, call_id, output, error } = event.data; + if (command === 'iframeReady' && call_id === 'initial_ready') { + if (this.resolveIframeReady) { + this.resolveIframeReady(); + this.resolveIframeReady = null; + this.rejectIframeReady = null; + } + return; + } + if (typeof call_id !== 'string') { + return; + } + if (this.pendingCalls.has(call_id)) { + const callback = this.pendingCalls.get(call_id)!; + callback({ + type: 'function_call_output', + call_id: call_id, + output: error ? `Error: ${error}` : (output ?? ''), + } as ToolCallOutput); + this.pendingCalls.delete(call_id); + } + }; + window.addEventListener('message', this.messageHandler); + + this.iframe = document.createElement('iframe'); + this.iframe.style.display = 'none'; + this.iframe.sandbox.add('allow-scripts'); + + // Use srcdoc with the imported HTML content + this.iframe.srcdoc = iframeHTMLContent; + + document.body.appendChild(this.iframe); + + setTimeout(() => { + if (this.rejectIframeReady) { + this.rejectIframeReady(new Error('Iframe readiness timeout')); + this.resolveIframeReady = null; + this.rejectIframeReady = null; + } + }, 5000); + } + + async _process(tc: ToolCallRequest): Promise { + let error = null; + if ( + typeof window === 'undefined' || + !this.iframe || + !this.iframe.contentWindow || + !this.iframeReadyPromise + ) { + error = + 'Error: JavaScript interpreter is not available or iframe not ready.'; + } + + try { + await this.iframeReadyPromise; + } catch (e) { + error = `Error: Iframe for JavaScript interpreter failed to initialize. ${(e as Error).message}`; + } + + let args; + try { + args = JSON.parse(tc.function.arguments); + } catch (e) { + error = `Error: Could not parse arguments for tool call. ${(e as Error).message}`; + } + + const codeToExecute = args.code; + if (typeof codeToExecute !== 'string') { + error = 'Error: "code" argument must be a string.'; + } + + if (error) { + return { + type: 'function_call_output', + call_id: tc.call_id, + output: error, + } as ToolCallOutput; + } + + return new Promise((resolve) => { + this.pendingCalls.set(tc.call_id, resolve); + const message: IframeMessage = { + command: 'executeCode', + code: codeToExecute, + call_id: tc.call_id, + }; + this.iframe!.contentWindow!.postMessage(message, '*'); + }); + } +} diff --git a/tools/server/webui/src/utils/tool_calling/register_tools.ts b/tools/server/webui/src/utils/tool_calling/register_tools.ts new file mode 100644 index 0000000000000..dc0c2862cc41c --- /dev/null +++ b/tools/server/webui/src/utils/tool_calling/register_tools.ts @@ -0,0 +1,16 @@ +import { AgentTool } from './agent_tool'; +import { JSReplAgentTool } from './js_repl_tool'; + +/** + * Map of available tools for function calling. + * Note that these tools are not necessarily enabled by the user. + */ +export const AVAILABLE_TOOLS = new Map(); + +function registerTool(tool: T): T { + AVAILABLE_TOOLS.set(tool.id, tool); + return tool; +} + +// Available agent tools +export const jsReplTool = registerTool(new JSReplAgentTool()); diff --git a/tools/server/webui/src/utils/types.ts b/tools/server/webui/src/utils/types.ts index ba673dd9432da..b4171ee631e31 100644 --- a/tools/server/webui/src/utils/types.ts +++ b/tools/server/webui/src/utils/types.ts @@ -39,10 +39,11 @@ export interface Message { convId: string; type: 'text' | 'root'; timestamp: number; // timestamp from Date.now() - role: 'user' | 'assistant' | 'system'; + role: 'user' | 'assistant' | 'system' | 'tool'; content: string; timings?: TimingReport; extra?: MessageExtra[]; + tool_calls?: ToolCallRequest[]; // node based system for branching parent: Message['id']; children: Message['id'][]; @@ -83,6 +84,7 @@ export type APIMessageContentPart = export type APIMessage = { role: Message['role']; + tool_calls?: ToolCallRequest[]; content: string | APIMessageContentPart[]; }; @@ -113,6 +115,37 @@ export interface CanvasPyInterpreter { export type CanvasData = CanvasPyInterpreter; +export interface ToolCallRequest { + id: string; + type: 'function'; + call_id: string; + function: { + name: string; + arguments: string; // JSON string of arguments + }; +} + +export interface ToolCallSpec { + type: 'function'; + function: { + name: string; + description: string; + parameters: ToolCallParameters; + }; +} + +export interface ToolCallParameters { + type: 'object'; + properties: object; + required: string[]; +} + +export interface ToolCallOutput { + type: 'function_call_output'; + call_id: string; + output: string; +} + // a non-complete list of props, only contains the ones we need export interface LlamaCppServerProps { build_info: string;