diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index c314b8b519913..f80d0def625c8 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -51,10 +51,10 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std:: result_.tool_calls.emplace_back(tool_call); return true; } -bool common_chat_msg_parser::add_tool_call(const json & tool_call) { +bool common_chat_msg_parser::add_tool_call(const json & tool_call, const char * arguments_name) { std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; - std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : ""; + std::string arguments = tool_call.contains(arguments_name) ? tool_call.at(arguments_name) : ""; return add_tool_call(name, id, arguments); } diff --git a/common/chat-parser.h b/common/chat-parser.h index 7ee355056b30a..9c8efbba656bb 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -59,7 +59,7 @@ class common_chat_msg_parser { bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); // Adds a tool call using the "name", "id" and "arguments" fields of the json object - bool add_tool_call(const nlohmann::ordered_json & tool_call); + bool add_tool_call(const nlohmann::ordered_json & tool_call, const char * arguments_name = "arguments"); // Adds an array of tool calls using their "name", "id" and "arguments" fields. bool add_tool_calls(const nlohmann::ordered_json & arr); diff --git a/common/chat.cpp b/common/chat.cpp index f1ab4c85a913e..dcc651623a838 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1096,6 +1096,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te tool_rules.push_back( builder.add_rule( name + "-call", + "\"<|python_tag|>\"? space " "\"{\" space " "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " @@ -1105,12 +1106,12 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", + "((?:<\\|python_tag\\|>\\s*)?\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", }); if (!builtin_tools.empty()) { data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); } + data.preserved_tokens.push_back("<|python_tag|>"); // Allow a few empty lines on top of the usual constrained json schema space rule. builder.add_rule("root", string_join(tool_rules, " | ")); data.additional_stops.push_back("<|eom_id|>"); @@ -1134,16 +1135,18 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w return; } - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); + static const common_regex python_tag_regex("\\s*<\\|python_tag\\|>"); - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + auto initial_pos = builder.pos(); + if (auto res = builder.try_consume_regex(python_tag_regex)) { + if (auto tc = builder.try_consume_json_with_dumped_args({{"parameters"}})) { + if (!builder.add_tool_call(tc->value, "parameters")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } else if (with_builtin_tools) { + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { auto fun_res = builder.consume_regex(function_name_regex); auto function_name = builder.str(fun_res.groups[1]); @@ -1171,17 +1174,24 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w if (!builder.add_tool_call(function_name, "", arguments)) { throw common_chat_msg_partial_exception("Incomplete tool call"); } - return; + } + } else if (auto tc = builder.try_consume_json_with_dumped_args({{"parameters"}})) { + if (!builder.add_tool_call(tc->value, "parameters")) { + auto has_unknown_keys = false; + for (const auto & [key, value] : tc->value.items()) { + if (key != "parameters" && key != "name" && key != "type") { + has_unknown_keys = true; + break; + } + } + if (has_unknown_keys) { + builder.move_to(initial_pos); + } else { + throw common_chat_msg_partial_exception("incomplete tool call"); + } } } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - + builder.add_content(builder.consume_rest()); } static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1453,8 +1463,8 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con if (has_raw_python) { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); } + data.preserved_tokens.push_back("<|python_tag|>"); auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); + assert_equals( + message_assist_call, + common_chat_parse( + "<|python_tag|>{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); + assert_equals( + simple_assist_msg("{\"something\": \"else\"}"), + common_chat_parse( + "{\"something\": \"else\"}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); + assert_equals( + message_assist_empty, + common_chat_parse( + "{\"some", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); + assert_equals( + message_assist_empty, + common_chat_parse( + "{\"parameters\": {\"arg1\": 1}", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,