Skip to content

Commit f124f18

Browse files
author
Cheuk Lun Ko
committed
Add math reasoning agent example
1 parent c30a588 commit f124f18

File tree

7 files changed

+291
-0
lines changed

7 files changed

+291
-0
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
2+
# Solve Simple Math Problem Using AI Agent
3+
4+
To run the code shown on this page, open the MLX file in MATLAB®: [mlx-scripts/SolveSimpleMathProblemUsingAIAgent.mlx](mlx-scripts/SolveSimpleMathProblemUsingAIAgent.mlx)
5+
6+
This example shows how to build an AI agent to find the smallest root of a quadratic equation.
7+
8+
9+
AI agents are programs that autonomously plan and execute workflows. Typically, agents use large language models (LLMs) to process user queries and identify actions that need to be taken, also known as *tool calls*. The agent then executes the tool calls that the LLM has identified and returns the result to the LLM. Then, the LLM generates an answer or executes more tool calls.
10+
11+
12+
![image_0.png](SolveSimpleMathProblemUsingAIAgent_media/image_0.png)
13+
14+
# Specify OpenAI API Key
15+
16+
This example uses the OpenAI® API, which requires an OpenAI API key. For information on how to obtain an OpenAI API key, as well as pricing, terms and conditions of use, and information about available models, see the OpenAI documentation at [https://platform.openai.com/docs/overview](https://platform.openai.com/docs/overview).
17+
18+
19+
To connect to the OpenAI API from MATLAB® using LLMs with MATLAB, specify the OpenAI API key as an environment variable and save it to a file called ".env".
20+
21+
22+
![image_1.png](SolveSimpleMathProblemUsingAIAgent_media/image_1.png)
23+
24+
25+
To connect to OpenAI, the ".env" file must be on the search path. Load the environment file using the [`loadenv`](https://www.mathworks.com/help/matlab/ref/loadenv.html) function.
26+
27+
```matlab
28+
loadenv(".env")
29+
```
30+
# Define Tools
31+
32+
First, to enable the LLM to make tool calls, define an `openAIFunction` object for each tool. You can then pass the tools to the LLM by specifying the `Tools` name\-value argument of the [`openAIChat`](../doc/functions/openAIChat.md) function.
33+
34+
35+
For examples of how to use tool calling with the Large Language Models (LLMs) with MATLAB add\-on, see:
36+
37+
- [Analyze Scientific Papers Using ChatGPT Function Calls](./AnalyzeScientificPapersUsingFunctionCalls.md)
38+
- [Analyze Text Data Using Parallel Function Calls with ChatGPT](./AnalyzeTextDataUsingParallelFunctionCallwithChatGPT.md)
39+
40+
In this example, use an LLM to find the smallest root of a second\-order polynomial. To find the smallest root of any second\-order polynomial, give the agent access to two tools:
41+
42+
- Compute the roots of a second\-order polynomial.
43+
- Find the smallest real number from a list of two numbers.
44+
## Compute Roots of Second\-Order Polynomial
45+
46+
The `solveQuadraticEquation` function takes the three coefficients of a quadratic polynomial as input and uses the [`roots`](https://www.mathworks.com/help/matlab/ref/roots.html) function to return a column vector containing the two roots.
47+
48+
```matlab
49+
function r = solveQuadraticEquation(a,b,c)
50+
r = roots([a b c]);
51+
end
52+
```
53+
54+
Create an [`openAIFunction`](../doc/functions/openAIFunction.md) object that represents the `solveQuadraticEquation` function. Add the three coefficients as individual scalar input parameters.
55+
56+
```matlab
57+
toolSolveQuadraticEquation = openAIFunction("solveQuadraticEquation", ...
58+
"Compute the roots of a second-order polynomial of the form ax^2 + bx + c = 0.");
59+
toolSolveQuadraticEquation = addParameter(toolSolveQuadraticEquation,"a",type="number");
60+
toolSolveQuadraticEquation = addParameter(toolSolveQuadraticEquation,"b",type="number");
61+
toolSolveQuadraticEquation = addParameter(toolSolveQuadraticEquation,"c",type="number");
62+
```
63+
## Find Smallest Real Number
64+
65+
The `smallestRealNumber` function computes the smallest real number from a list of two numbers. The LLM used in this example expresses tool calls, including input parameters, in JSON format. Roots of polynomials can be complex, but JSON does not support complex numbers. Therefore, specify the roots as strings and convert them back into numbers using the [`str2double`](https://www.mathworks.com/help/matlab/ref/str2double.html) function.
66+
67+
```matlab
68+
function xMin = smallestRealNumber(strX1,strX2)
69+
allRoots = [str2double(strX1) str2double(strX2)];
70+
realRoots = allRoots(imag(allRoots)==0);
71+
if isempty(realRoots)
72+
xMin = "No real numbers.";
73+
else
74+
xMin = min(realRoots);
75+
end
76+
end
77+
```
78+
79+
Create an `openAIFunction` object that represents the `smallestRealNumber` function. Add the two numbers as individual string input parameters.
80+
81+
```matlab
82+
toolSmallestRealNumber = openAIFunction("smallestRealNumber", ...
83+
"Compute the smallest real number from a list of two numbers.");
84+
toolSmallestRealNumber = addParameter(toolSmallestRealNumber,"x1",type="string");
85+
toolSmallestRealNumber = addParameter(toolSmallestRealNumber,"x2",type="string");
86+
```
87+
# Evaluate Tool Calls
88+
89+
To enable structured and scalable usage of multiple tools, store both the `openAIFunction` objects and their corresponding MATLAB function handles in a dictionary `toolRegistry`.
90+
91+
```matlab
92+
toolRegistry = dictionary;
93+
toolRegistry("solveQuadraticEquation") = struct( ...
94+
"toolSpecification",toolSolveQuadraticEquation, ...
95+
"functionHandle",@solveQuadraticEquation);
96+
toolRegistry("smallestRealNumber") = struct( ...
97+
"toolSpecification",toolSmallestRealNumber, ...
98+
"functionHandle",@smallestRealNumber);
99+
```
100+
101+
Define a function to evaluate tool calls identified by the LLM. LLMs can hallucinate tool calls or make errors about the parameters that the tools need. Therefore, first validate the tool name and parameters by comparing them to the `toolRegistry` dictionary. Then, run the functions associated with the tools using the [`feval`](https://www.mathworks.com/help/matlab/ref/feval.html) function.
102+
103+
```matlab
104+
function result = evaluateToolCall(toolCall,toolRegistry)
105+
% Validate tool name
106+
toolName = toolCall.function.name;
107+
assert(isKey(toolRegistry,toolName),"Invalid tool name ''%s''.",toolName)
108+
109+
% Validate JSON syntax
110+
try
111+
args = jsondecode(toolCall.function.arguments);
112+
catch
113+
error("Model returned invalid JSON syntax for arguments of tool ''%s''.",toolName);
114+
end
115+
116+
% Validate tool parameters
117+
tool = toolRegistry(toolName);
118+
requiredArgs = string(fieldnames(tool.toolSpecification.Parameters));
119+
assert(all(isfield(args,requiredArgs)),"Invalid tool parameters: %s",strjoin(fieldnames(args),","))
120+
121+
extraArgs = setdiff(string(fieldnames(args)),requiredArgs);
122+
if ~isempty(extraArgs)
123+
warning("Ignoring extra tool parameters: %s",strjoin(extraArgs,","));
124+
end
125+
126+
% Execute tool
127+
argValues = arrayfun(@(fieldName) args.(fieldName),requiredArgs,UniformOutput=false);
128+
try
129+
result = feval(tool.functionHandle,argValues{:});
130+
catch ME
131+
error("Tool call '%s' failed with error: %s",toolName,ME.message)
132+
end
133+
end
134+
```
135+
# Set Up ReAct Agent
136+
137+
Next, define the function that sets up and runs the AI agent.
138+
139+
140+
The agentic architecture used in this example is based on a ReAct agent [\[1\]](#M_4d5d).
141+
142+
143+
![image_2.png](SolveSimpleMathProblemUsingAIAgent_media/image_2.png)
144+
145+
146+
This architecture is an iterative workflow. For each iteration, the agent performs three steps:
147+
148+
1. Thought — The agent plans the next action. To do this, first generate a thought in natural language, then generate a tool call based on that thought.
149+
2. Action — The agent executes the next action.
150+
3. Observation — The agent observes the tool output.
151+
152+
Define the function `runAgent` that answers a user query `userQuery` using the ReAct agent architecture and the tools provided in `toolRegistry`.
153+
154+
```matlab
155+
function agentResponse = runAgent(userQuery,toolRegistry)
156+
```
157+
158+
To ensure the agent stops after it answers the user query, create a tool `finalAnswer` and add it to the tool list.
159+
160+
```matlab
161+
toolFinalAnswer = openAIFunction("finalAnswer","Call this when you have reached the final answer.");
162+
tools = [toolRegistry.values.toolSpecification toolFinalAnswer];
163+
```
164+
165+
Create a system prompt. Instruct the agent to call the `finalAnswer` tool after it answers the user query.
166+
167+
```matlab
168+
systemPrompt = ...
169+
"You are a mathematical reasoning agent that can call math tools. " + ...
170+
"Solve the problem. When done, call the tool finalAnswer else you will get stuck in a loop.";
171+
```
172+
173+
Connect to the OpenAI Chat Completion API using the [`openAIChat`](../doc/functions/openAIChat.md) function. Use the OpenAI model `"gpt-4.1"`. Provide the LLM with tools using the `Tools` name\-value argument. Initialize the message history.
174+
175+
```matlab
176+
chat = openAIChat(systemPrompt,ModelName="gpt-4.1",Tools=tools);
177+
history = messageHistory;
178+
```
179+
180+
Add the user query to the message history. Display the user query.
181+
182+
```matlab
183+
history = addUserMessage(history,userQuery);
184+
disp("User: " + userQuery);
185+
```
186+
187+
Initialize the agentic loop. To ensure the program terminates, limit the number of iterations to `10`.
188+
189+
```matlab
190+
maxSteps = 10;
191+
stepCount = 0;
192+
problemSolved = false;
193+
while ~problemSolved
194+
if stepCount >= maxSteps
195+
error("Agent stopped after reaching maximum step limit (%d).",maxSteps);
196+
end
197+
stepCount = stepCount + 1;
198+
```
199+
200+
Instruct the agent to plan the next step. Generate a response from the message history. To ensure the agent outputs text, set the `ToolChoice` name\-value argument to `"none"`.
201+
202+
```matlab
203+
history = addUserMessage(history,"Plan your next step.");
204+
[thought,completeOutput] = generate(chat,history,ToolChoice="none");
205+
disp("[Thought] " + thought);
206+
history = addResponseMessage(history,completeOutput);
207+
```
208+
209+
Instruct the agent to call a tool. Instruct the agent to always call a tool in this step.
210+
211+
```matlab
212+
history = addUserMessage(history,"Call tools to solve the problem. Always call a tool.");
213+
[~,completeOutput] = generate(chat,history);
214+
history = addResponseMessage(history,completeOutput);
215+
actions = completeOutput.tool_calls;
216+
```
217+
218+
If the agent calls the `finalAnswer` tool, add the return the final agent response to the message history.
219+
220+
```matlab
221+
if isscalar(actions) && strcmp(actions(1).function.name,"finalAnswer")
222+
history = addToolMessage(history,actions.id,"finalAnswer","Final answer below");
223+
history = addUserMessage(history,"Return the final answer as a statement.");
224+
agentResponse = generate(chat,history,ToolChoice="none");
225+
problemSolved = true;
226+
```
227+
228+
Otherwise, log and evaluate each tool call in the agent output.
229+
230+
```matlab
231+
else
232+
for i = 1:numel(actions)
233+
action = actions(i);
234+
toolName = action.function.name;
235+
fprintf("[Action] Calling tool '%s' with args: %s\n",toolName,jsonencode(action.function.arguments));
236+
observation = evaluateToolCall(action,toolRegistry);
237+
```
238+
239+
To enable the agent to observe the output, add the tool call result to the message history.
240+
241+
```matlab
242+
fprintf("[Observation] Result from tool '%s': %s\n",toolName,jsonencode(string(observation)));
243+
history = addToolMessage(history,action.id,toolName,"Observation: " + jsonencode(string(observation)));
244+
end
245+
end
246+
end
247+
end
248+
```
249+
# Answer Query
250+
251+
Define the query. Answer the query using the agent.
252+
253+
```matlab
254+
userQuery = "What is the smallest root of x^2+2x-3=0?";
255+
agentResponse = runAgent(userQuery,toolRegistry);
256+
```
257+
258+
```matlabTextOutput
259+
User: What is the smallest root of x^2+2x-3=0?
260+
[Thought] To find the smallest root, I will:
261+
1. Solve the quadratic equation x^2 + 2x - 3 = 0 to find both roots.
262+
2. Compare the two roots and select the smallest one.
263+
[Action] Calling tool 'solveQuadraticEquation' with args: "{\"a\":1,\"b\":2,\"c\":-3}"
264+
[Observation] Result from tool 'solveQuadraticEquation': ["-3","1"]
265+
[Thought] Now that I have both roots (-3 and 1), I will compare them to determine which is the smallest root. Then I will provide the smallest root as the final answer.
266+
[Action] Calling tool 'smallestRealNumber' with args: "{\"x1\":\"-3\",\"x2\":\"1\"}"
267+
[Observation] Result from tool 'smallestRealNumber': "-3"
268+
[Thought] I have identified -3 as the smallest root. My next step is to provide -3 as the final answer.
269+
```
270+
271+
272+
Display the response.
273+
274+
```matlab
275+
disp(agentResponse);
276+
```
277+
278+
```matlabTextOutput
279+
The smallest root of the equation x^2 + 2x - 3 = 0 is -3.
280+
```
281+
282+
# References
283+
<a id="M_4d5d"></a>
284+
285+
\[1\] Shunyu Yao, Jeffrey Zhao, Dian Yu, Nan Du, Izhak Shafran, Karthik Narasimhan, and Yuan Cao. "ReAct: Synergizing Reasoning and Acting in Language Models". ArXiv, 10 March 2023. [https://doi.org/10.48550/arXiv.2210.03629](https://doi.org/10.48550/arXiv.2210.03629).
286+
119 KB
Loading
5.55 KB
Loading
130 KB
Loading
Binary file not shown.
Binary file not shown.

tests/texampleTests.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ function testRetrievalAugmentedGenerationUsingOllamaAndMATLAB(testCase)
152152
evalc("RetrievalAugmentedGenerationUsingOllamaAndMATLAB");
153153
end
154154

155+
function testSolveSimpleMathProblemUsingAIAgent(testCase)
156+
testCase.startCapture("SolveSimpleMathProblemUsingAIAgent");
157+
evalc("SolveSimpleMathProblemUsingAIAgent");
158+
end
159+
155160
function testSummarizeLargeDocumentsUsingChatGPTandMATLAB(testCase)
156161
testCase.startCapture("SummarizeLargeDocumentsUsingChatGPTandMATLAB");
157162
evalc("SummarizeLargeDocumentsUsingChatGPTandMATLAB");

0 commit comments

Comments
 (0)