diff --git a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb index eb4cfdbc..fd8fc1d0 100644 --- a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb +++ b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb @@ -41,7 +41,9 @@ "id": "e13eb331", "metadata": {}, "outputs": [], - "source": ["%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"] + "source": [ + "%pip install --upgrade --quiet langchain-nvidia-ai-endpoints" + ] }, { "cell_type": "markdown", @@ -65,11 +67,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "686c4d2f", "metadata": {}, "outputs": [], - "source": ["import getpass\nimport os\n\n# del os.environ['NVIDIA_API_KEY'] ## delete key and reset\nif os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n print(\"Valid NVIDIA_API_KEY already in environment. Delete to reset\")\nelse:\n nvapi_key = getpass.getpass(\"NVAPI Key (starts with nvapi-): \")\n assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n os.environ[\"NVIDIA_API_KEY\"] = nvapi_key"] + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# del os.environ['NVIDIA_API_KEY'] ## delete key and reset\n", + "if os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n", + "    print(\"Valid NVIDIA_API_KEY already in environment. Delete to reset\")\n", + "else:\n", + "    nvapi_key = getpass.getpass(\"NVAPI Key (starts with nvapi-): \")\n", + "    assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n", + "    os.environ[\"NVIDIA_API_KEY\"] = nvapi_key" + ] }, { "cell_type": "markdown", @@ -91,7 +104,14 @@ "outputId": "e9c4cc72-8db6-414b-d8e9-95de93fc5db4" }, "outputs": [], - "source": ["## Core LC Chat Interface\nfrom langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nllm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")\nresult = llm.invoke(\"Write a ballad about LangChain.\")\nprint(result.content)"] + "source": [ + "## Core LC Chat Interface\n", + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "\n", + "llm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")\n", + "result = llm.invoke(\"Write a ballad about LangChain.\")\n", + "print(result.content)" + ] }, { "cell_type": "markdown", @@ -110,7 +130,12 @@ "id": "49838930", "metadata": {}, "outputs": [], - "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\n\n# connect to an embedding NIM running at localhost:8000, specifying a specific model\nllm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\", model=\"meta/llama3-8b-instruct\")"] + "source": [ + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "\n", + "# connect to a chat NIM running at localhost:8000, specifying a specific model\n", + "llm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\", model=\"meta/llama3-8b-instruct\")" + ] }, { "cell_type": "markdown", @@ -128,7 +153,11 @@ "id": "01fa5095-be72-47b0-8247-e9fac799435d", "metadata": {}, "outputs": [], - "source": ["print(llm.batch([\"What's 2*3?\", \"What's 2*6?\"]))\n# Or via the async API\n# await llm.abatch([\"What's 2*3?\", \"What's 2*6?\"])"] + "source": [ + "print(llm.batch([\"What's 2*3?\", \"What's 2*6?\"]))\n", + "# Or via the async API\n", + "# await llm.abatch([\"What's 2*3?\", \"What's 2*6?\"])" + ] }, { "cell_type": "code", @@ -136,7 +165,11 @@ "id": "75189ac6-e13f-414f-9064-075c77d6e754", "metadata": {}, "outputs": [], - "source": ["for chunk in llm.stream(\"How far can a seagull fly in one 
day?\"):\n # Show the token separations\n print(chunk.content, end=\"|\")"] + "source": [ + "for chunk in llm.stream(\"How far can a seagull fly in one day?\"):\n", + " # Show the token separations\n", + " print(chunk.content, end=\"|\")" + ] }, { "cell_type": "code", @@ -144,7 +177,12 @@ "id": "8a9a4122-7a10-40c0-a979-82a769ce7f6a", "metadata": {}, "outputs": [], - "source": ["async for chunk in llm.astream(\n \"How long does it take for monarch butterflies to migrate?\"\n):\n print(chunk.content, end=\"|\")"] + "source": [ + "async for chunk in llm.astream(\n", + " \"How long does it take for monarch butterflies to migrate?\"\n", + "):\n", + " print(chunk.content, end=\"|\")" + ] }, { "cell_type": "markdown", @@ -166,7 +204,10 @@ "id": "5b8a312d-38e9-4528-843e-59451bdadbac", "metadata": {}, "outputs": [], - "source": ["ChatNVIDIA.get_available_models()\n# llm.get_available_models()"] + "source": [ + "ChatNVIDIA.get_available_models()\n", + "# llm.get_available_models()" + ] }, { "cell_type": "markdown", @@ -206,7 +247,19 @@ "id": "f5f7aee8-e90c-4d5a-ac97-0dd3d45c3f4c", "metadata": {}, "outputs": [], - "source": ["from langchain_core.output_parsers import StrOutputParser\nfrom langchain_core.prompts import ChatPromptTemplate\nfrom langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nprompt = ChatPromptTemplate.from_messages(\n [(\"system\", \"You are a helpful AI assistant named Fred.\"), (\"user\", \"{input}\")]\n)\nchain = prompt | ChatNVIDIA(model=\"meta/llama3-8b-instruct\") | StrOutputParser()\n\nfor txt in chain.stream({\"input\": \"What's your name?\"}):\n print(txt, end=\"\")"] + "source": [ + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", \"You are a helpful AI assistant named Fred.\"), (\"user\", \"{input}\")]\n", + ")\n", + "chain = prompt | ChatNVIDIA(model=\"meta/llama3-8b-instruct\") | StrOutputParser()\n", + "\n", + "for txt in chain.stream({\"input\": \"What's your name?\"}):\n", + " print(txt, end=\"\")" + ] }, { "cell_type": "markdown", @@ -224,7 +277,21 @@ "id": "49aa569b-5f33-47b3-9edc-df58313eb038", "metadata": {}, "outputs": [], - "source": ["prompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are an expert coding AI. Respond only in valid python; no narration whatsoever.\",\n ),\n (\"user\", \"{input}\"),\n ]\n)\nchain = prompt | ChatNVIDIA(model=\"meta/codellama-70b\") | StrOutputParser()\n\nfor txt in chain.stream({\"input\": \"How do I solve this fizz buzz problem?\"}):\n print(txt, end=\"\")"] + "source": [ + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are an expert coding AI. 
Respond only in valid python; no narration whatsoever.\",\n", + " ),\n", + " (\"user\", \"{input}\"),\n", + " ]\n", + ")\n", + "chain = prompt | ChatNVIDIA(model=\"meta/codellama-70b\") | StrOutputParser()\n", + "\n", + "for txt in chain.stream({\"input\": \"How do I solve this fizz buzz problem?\"}):\n", + " print(txt, end=\"\")" + ] }, { "cell_type": "markdown", @@ -244,7 +311,15 @@ "id": "26625437-1695-440f-b792-b85e6add9a90", "metadata": {}, "outputs": [], - "source": ["import IPython\nimport requests\n\nimage_url = \"https://www.nvidia.com/content/dam/en-zz/Solutions/research/ai-playground/nvidia-picasso-3c33-p@2x.jpg\" ## Large Image\nimage_content = requests.get(image_url).content\n\nIPython.display.Image(image_content)"] + "source": [ + "import IPython\n", + "import requests\n", + "\n", + "image_url = \"https://www.nvidia.com/content/dam/en-zz/Solutions/research/ai-playground/nvidia-picasso-3c33-p@2x.jpg\" ## Large Image\n", + "image_content = requests.get(image_url).content\n", + "\n", + "IPython.display.Image(image_content)" + ] }, { "cell_type": "code", @@ -252,7 +327,11 @@ "id": "dfbbe57c-27a5-4cbb-b967-19c4e7d29fd0", "metadata": {}, "outputs": [], - "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\n\nllm = ChatNVIDIA(model=\"nvidia/neva-22b\")"] + "source": [ + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "\n", + "llm = ChatNVIDIA(model=\"nvidia/neva-22b\")" + ] }, { "cell_type": "markdown", @@ -268,7 +347,20 @@ "id": "432ea2a2-4d39-43f8-a236-041294171f14", "metadata": {}, "outputs": [], - "source": ["from langchain_core.messages import HumanMessage\n\nllm.invoke(\n [\n HumanMessage(\n content=[\n {\"type\": \"text\", \"text\": \"Describe this image:\"},\n {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n ]\n )\n ]\n)"] + "source": [ + "from langchain_core.messages import HumanMessage\n", + "\n", + "llm.invoke(\n", + " [\n", + " HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": \"Describe this image:\"},\n", + " {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n", + " ]\n", + " )\n", + " ]\n", + ")" + ] }, { "cell_type": "markdown", @@ -366,7 +458,15 @@ "id": "c58f1dd0", "metadata": {}, "outputs": [], - "source": ["import IPython\nimport requests\n\nimage_url = \"https://picsum.photos/seed/kitten/300/200\"\nimage_content = requests.get(image_url).content\n\nIPython.display.Image(image_content)"] + "source": [ + "import IPython\n", + "import requests\n", + "\n", + "image_url = \"https://picsum.photos/seed/kitten/300/200\"\n", + "image_content = requests.get(image_url).content\n", + "\n", + "IPython.display.Image(image_content)" + ] }, { "cell_type": "code", @@ -374,7 +474,28 @@ "id": "8c721629-42eb-4006-bf68-0296f7925ebc", "metadata": {}, "outputs": [], - "source": ["import base64\n\nfrom langchain_core.messages import HumanMessage\n\n## Works for simpler images. For larger images, see actual implementation\nb64_string = base64.b64encode(image_content).decode(\"utf-8\")\n\nllm.invoke(\n [\n HumanMessage(\n content=[\n {\"type\": \"text\", \"text\": \"Describe this image:\"},\n {\n \"type\": \"image_url\",\n \"image_url\": {\"url\": f\"data:image/png;base64,{b64_string}\"},\n },\n ]\n )\n ]\n)"] + "source": [ + "import base64\n", + "\n", + "from langchain_core.messages import HumanMessage\n", + "\n", + "## Works for simpler images. 
For larger images, see actual implementation\n", + "b64_string = base64.b64encode(image_content).decode(\"utf-8\")\n", + "\n", + "llm.invoke(\n", + "    [\n", + "        HumanMessage(\n", + "            content=[\n", + "                {\"type\": \"text\", \"text\": \"Describe this image:\"},\n", + "                {\n", + "                    \"type\": \"image_url\",\n", + "                    \"image_url\": {\"url\": f\"data:image/png;base64,{b64_string}\"},\n", + "                },\n", + "            ]\n", + "        )\n", + "    ]\n", + ")" + ] }, { "cell_type": "markdown", @@ -392,7 +513,10 @@ "id": "00c06a9a-497b-4192-a842-b075e27401aa", "metadata": {}, "outputs": [], - "source": ["base64_with_mime_type = f\"data:image/png;base64,{b64_string}\"\nllm.invoke(f'What\\'s in this image?\\n')"] + "source": [ + "base64_with_mime_type = f\"data:image/png;base64,{b64_string}\"\n", + "llm.invoke(f'What\\'s in this image?\\n<img src=\"{base64_with_mime_type}\" />')" + ] }, { "cell_type": "markdown", @@ -420,7 +544,9 @@ "id": "082ccb21-91e1-4e71-a9ba-4bff1e89f105", "metadata": {}, "outputs": [], - "source": ["%pip install --upgrade --quiet langchain"] + "source": [ + "%pip install --upgrade --quiet langchain" + ] }, { "cell_type": "code", @@ -430,7 +556,41 @@ "id": "fd2c6bc1" }, "outputs": [], - "source": ["from langchain_core.chat_history import InMemoryChatMessageHistory\nfrom langchain_core.runnables.history import RunnableWithMessageHistory\n\n# store is a dictionary that maps session IDs to their corresponding chat histories.\nstore = {}  # memory is maintained outside the chain\n\n\n# A function that returns the chat history for a given session ID.\ndef get_session_history(session_id: str) -> InMemoryChatMessageHistory:\n    if session_id not in store:\n        store[session_id] = InMemoryChatMessageHistory()\n    return store[session_id]\n\n\nchat = ChatNVIDIA(\n    model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n    temperature=0.1,\n    max_tokens=100,\n    top_p=1.0,\n)\n\n# Define a RunnableConfig object, with a `configurable` key. session_id determines thread\nconfig = {\"configurable\": {\"session_id\": \"1\"}}\n\nconversation = RunnableWithMessageHistory(\n    chat,\n    get_session_history,\n)\n\nconversation.invoke(\n    \"Hi I'm Srijan Dubey.\",  # input or query\n    config=config,\n)"] + "source": [ + "from langchain_core.chat_history import InMemoryChatMessageHistory\n", + "from langchain_core.runnables.history import RunnableWithMessageHistory\n", + "\n", + "# store is a dictionary that maps session IDs to their corresponding chat histories.\n", + "store = {}  # memory is maintained outside the chain\n", + "\n", + "\n", + "# A function that returns the chat history for a given session ID.\n", + "def get_session_history(session_id: str) -> InMemoryChatMessageHistory:\n", + "    if session_id not in store:\n", + "        store[session_id] = InMemoryChatMessageHistory()\n", + "    return store[session_id]\n", + "\n", + "\n", + "chat = ChatNVIDIA(\n", + "    model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", + "    temperature=0.1,\n", + "    max_tokens=100,\n", + "    top_p=1.0,\n", + ")\n", + "\n", + "# Define a RunnableConfig object, with a `configurable` key. session_id determines thread\n", + "config = {\"configurable\": {\"session_id\": \"1\"}}\n", + "\n", + "conversation = RunnableWithMessageHistory(\n", + "    chat,\n", + "    get_session_history,\n", + ")\n", + "\n", + "conversation.invoke(\n", + "    \"Hi I'm Srijan Dubey.\",  # input or query\n", + "    config=config,\n", + ")" + ] }, { "cell_type": "code", @@ -445,7 +605,12 @@ "outputId": "79acc89d-a820-4f2c-bac2-afe99da95580" }, "outputs": [], - "source": ["conversation.invoke(\n    \"I'm doing well! 
Just having a conversation with an AI.\",\n    config=config,\n)"] + "source": [ + "conversation.invoke(\n", + "    \"I'm doing well! Just having a conversation with an AI.\",\n", + "    config=config,\n", + ")" + ] }, { "cell_type": "code", @@ -460,7 +625,12 @@ "outputId": "a1714513-a8fd-4d14-f974-233e39d5c4f5" }, "outputs": [], - "source": ["conversation.invoke(\n    \"Tell me about yourself.\",\n    config=config,\n)"] + "source": [ + "conversation.invoke(\n", + "    \"Tell me about yourself.\",\n", + "    config=config,\n", + ")" + ] }, { "cell_type": "markdown", @@ -488,7 +658,10 @@ "id": "e36c8911", "metadata": {}, "outputs": [], - "source": ["tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\ntool_models"] + "source": [ + "tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\n", + "tool_models" + ] }, { "cell_type": "markdown", @@ -504,7 +677,21 @@ "id": "bd54f174", "metadata": {}, "outputs": [], - "source": ["from pydantic import Field\nfrom langchain_core.tools import tool\n\n@tool\ndef get_current_weather(\n    location: str = Field(..., description=\"The location to get the weather for.\")\n):\n    \"\"\"Get the current weather for a location.\"\"\"\n    ...\n\nllm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\nresponse = llm.invoke(\"What is the weather in Boston?\")\nresponse.tool_calls"] + "source": [ + "from pydantic import Field\n", + "from langchain_core.tools import tool\n", + "\n", + "@tool\n", + "def get_current_weather(\n", + "    location: str = Field(..., description=\"The location to get the weather for.\")\n", + "):\n", + "    \"\"\"Get the current weather for a location.\"\"\"\n", + "    ...\n", + "\n", + "llm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\n", + "response = llm.invoke(\"What is the weather in Boston?\")\n", + "response.tool_calls" + ] }, { "cell_type": "markdown", @@ -542,7 +729,11 @@ "id": "0515f558", "metadata": {}, "outputs": [], - "source": ["from langchain_nvidia_ai_endpoints import ChatNVIDIA\nstructured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\nstructured_models"] + "source": [ + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "structured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\n", + "structured_models" + ] }, { "cell_type": "markdown", @@ -558,7 +749,17 @@ "id": "482c37e8", "metadata": {}, "outputs": [], - "source": ["from pydantic import BaseModel, Field\n\nclass Person(BaseModel):\n    first_name: str = Field(..., description=\"The person's first name.\")\n    last_name: str = Field(..., description=\"The person's last name.\")\n\nllm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\nresponse = llm.invoke(\"Who is Michael Jeffrey Jordon?\")\nresponse"] + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "class Person(BaseModel):\n", + "    first_name: str = Field(..., description=\"The person's first name.\")\n", + "    last_name: str = Field(..., description=\"The person's last name.\")\n", + "\n", + "llm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\n", + "response = llm.invoke(\"Who is Michael Jeffrey Jordan?\")\n", + "response" + ] }, { "cell_type": "markdown", @@ -574,7 +775,24 @@ "id": "7f802912", "metadata": {}, "outputs": [], - "source": ["from enum import Enum\n\nclass Choices(Enum):\n    A = \"A\"\n    B = \"B\"\n    C = \"C\"\n\nllm = 
ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\nresponse = llm.invoke(\"\"\"\n What does 1+1 equal?\n A. -100\n B. 2\n C. doorstop\n \"\"\"\n)\nresponse"] + "source": [ + "from enum import Enum\n", + "\n", + "class Choices(Enum):\n", + " A = \"A\"\n", + " B = \"B\"\n", + " C = \"C\"\n", + "\n", + "llm = ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\n", + "response = llm.invoke(\"\"\"\n", + " What does 1+1 equal?\n", + " A. -100\n", + " B. 2\n", + " C. doorstop\n", + " \"\"\"\n", + ")\n", + "response" + ] }, { "cell_type": "code", @@ -582,7 +800,111 @@ "id": "02b7ef29", "metadata": {}, "outputs": [], - "source": ["model = structured_models[3].id\nllm = ChatNVIDIA(model=model).with_structured_output(Choices)\nprint(model)\nresponse = llm.invoke(\"\"\"\n What does 1+1 equal?\n A. -100\n B. 2\n C. doorstop\n \"\"\"\n)\nresponse"] + "source": [ + "model = structured_models[3].id\n", + "llm = ChatNVIDIA(model=model).with_structured_output(Choices)\n", + "print(model)\n", + "response = llm.invoke(\"\"\"\n", + " What does 1+1 equal?\n", + " A. -100\n", + " B. 2\n", + " C. doorstop\n", + " \"\"\"\n", + ")\n", + "response" + ] + }, + { + "cell_type": "markdown", + "id": "b5476b50", + "metadata": {}, + "source": [ + "### JSON Schema" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fb3e5cb", + "metadata": {}, + "outputs": [], + "source": [ + "json_schema = {\n", + " \"title\": \"joke\",\n", + " \"description\": \"Joke to tell user.\",\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"setup\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The setup of the joke\",\n", + " },\n", + " \"punchline\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The punchline to the joke\",\n", + " },\n", + " \"rating\": {\n", + " \"type\": \"integer\",\n", + " \"description\": \"How funny the joke is, from 1 to 10\",\n", + " \"default\": None,\n", + " },\n", + " },\n", + " \"required\": [\"setup\", \"punchline\"],\n", + "}\n", + "llm = ChatNVIDIA(model=structured_models[0].id)\n", + "structured_llm = llm.with_structured_output(json_schema)\n", + "\n", + "structured_llm.invoke(\"Tell me a joke about cats\")" + ] + }, + { + "cell_type": "markdown", + "id": "c8893046", + "metadata": {}, + "source": [ + "## [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode)\n", + "\n", + "Constrain the model to only generate valid JSON. Note that you must include a system message with instructions to use JSON for this mode to work.\n", + "\n", + "Only works with certain models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1c955cc", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import HumanMessage, SystemMessage\n", + "\n", + "chat = ChatNVIDIA(model=tool_models[0].id).bind(\n", + " response_format={\"type\": \"json_object\"}\n", + ")\n", + "\n", + "output = chat.invoke(\n", + " [\n", + " SystemMessage(\n", + " content=\"Extract the 'name' and 'origin' of any companies mentioned in the following statement. 
Return a JSON list.\"\n", + "        ),\n", + "        HumanMessage(\n", + "            content=\"Google was founded in the USA, while Deepmind was founded in the UK\"\n", + "        ),\n", + "    ]\n", + ")\n", + "print(output.content)" + ] }, { "cell_type": "code", "execution_count": null, "id": "2a884df2", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "json.loads(output.content)" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -604,7 +926,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 63ee21c5..6741a1a4 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -48,8 +48,11 @@ ) from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.utils.function_calling import ( + convert_to_openai_function, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from pydantic import BaseModel, Field, PrivateAttr from langchain_nvidia_ai_endpoints._common import _NVIDIAClient @@ -237,6 +240,28 @@ def _process_for_vlm( return inputs, extra_headers +def _convert_to_openai_response_format( + schema: Union[Dict[str, Any], Type], +) -> Union[Dict, TypeBaseModel]: + if isinstance(schema, type) and is_basemodel_subclass(schema): + return schema + + if ( + isinstance(schema, dict) + and "json_schema" in schema + and schema.get("type") == "json_schema" + ): + response_format = schema + elif isinstance(schema, dict) and "name" in schema and "schema" in schema: + response_format = {"type": "json_schema", "json_schema": schema} + else: + function = convert_to_openai_function(schema) + function["schema"] = function.pop("parameters") + response_format = {"type": "json_schema", "json_schema": function} + + return response_format + + _DEFAULT_MODEL_NAME: str = "meta/llama3-8b-instruct" @@ -649,6 +674,7 @@ def with_structured_output( # type: ignore self, schema: Union[Dict, Type], *, + method: Literal["json_mode", "json_schema"] = "json_schema", include_raw: bool = False, **kwargs: Any, ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: @@ -658,6 +684,13 @@ Args: schema (Union[Dict, Type]): The schema to bind to the model. include_raw (bool): Always False. Passing True raises an error. + method: The method for steering model generation, one of: + - "json_schema": + Uses the Structured Output API for supported models. + - "json_mode": + Uses JSON mode. Note that if using JSON mode then you + must include instructions in the model call for formatting + the output into the desired schema. **kwargs: Additional keyword arguments. Notes: @@ -773,13 +806,6 @@ class Choices(enum.Enum): For more, see https://python.langchain.com/docs/how_to/structured_output/ """ # noqa: E501 - if "method" in kwargs: - warnings.warn( - "The 'method' parameter is unnecessary and is ignored. 
" - "The appropriate method will be chosen automatically depending " - "on the type of schema provided." - ) - if kwargs.get("strict", True) is not True: warnings.warn( "Structured output always follows strict validation. " @@ -796,6 +822,8 @@ class Choices(enum.Enum): "being None when the LLM produces an incomplete response." ) + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) + # check if the model supports structured output, warn if it does not known_good = False # todo: we need to store model: Model in this class @@ -814,11 +842,32 @@ class Choices(enum.Enum): f"Model '{self.model}' is not known to support structured output. " "Your output may fail at inference time." ) + output_parser: BaseOutputParser # create a common type + + if method == "json_mode": + llm = self.bind(response_format={"type": "json_object"}) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + elif method == "json_schema": + response_format = _convert_to_openai_response_format(schema) + llm = self.bind(response_format=response_format) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected one of 'json_scheme' or " + f"'json_mode'. Received: '{method}'" + ) if isinstance(schema, dict): - output_parser: BaseOutputParser = JsonOutputParser() - nvext_param: Dict[str, Any] = {"guided_json": schema} - + output_parser = JsonOutputParser() + llm = self.bind(nvext={"guided_json": schema}) elif issubclass(schema, enum.Enum): # langchain's EnumOutputParser is not in langchain_core # and doesn't support streaming. this is a simple implementation @@ -845,7 +894,7 @@ def parse(self, response: str) -> Any: "Use StrEnum or ensure all member values are strings." ) output_parser = EnumOutputParser(enum=schema) - nvext_param = {"guided_choice": choices} + llm = self.bind(nvext={"guided_choice": choices}) elif is_basemodel_subclass(schema): # PydanticOutputParser does not support streaming. what we do @@ -867,7 +916,7 @@ def parse_result( json_schema = schema.model_json_schema() else: json_schema = schema.schema() - nvext_param = {"guided_json": json_schema} + llm = self.bind(nvext={"guided_json": json_schema}) else: raise ValueError( @@ -875,4 +924,4 @@ def parse_result( "representing a JSON schema, or an Enum." 
) - return super().bind(nvext=nvext_param) | output_parser + return llm | output_parser diff --git a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py index aaee110a..1f69a7cd 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py +++ b/libs/ai-endpoints/tests/integration_tests/test_bind_tools.py @@ -11,6 +11,7 @@ BaseMessage, BaseMessageChunk, ) +from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool from pydantic import Field @@ -787,3 +788,35 @@ def magic( assert response_in < baseline_in * tolerance assert response_out < baseline_out * tolerance assert response_total < baseline_total * tolerance + + +def test_json_mode(tool_model: str) -> None: + llm = ChatNVIDIA(model=tool_model).bind(response_format={"type": "json_object"}) + response = llm.invoke( + "Return this as json: {'a': 1}", + ) + assert isinstance(response.content, str) + assert json.loads(response.content) == {"a": 1} + + # Test streaming + full: Optional[Union[BaseMessage, ChatPromptTemplate]] = None + for chunk in llm.stream("Return this as json: {'a': 1}"): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert isinstance(full.content, str) + assert json.loads(full.content) == {"a": 1} + + +async def test_json_mode_async(tool_model: str) -> None: + llm = ChatNVIDIA(model=tool_model).bind(response_format={"type": "json_object"}) + response = await llm.ainvoke("Return this as json: {'a': 1}") + assert isinstance(response.content, str) + assert json.loads(response.content) == {"a": 1} + + # Test streaming + full: Optional[Union[BaseMessage, ChatPromptTemplate]] = None + async for chunk in llm.astream("Return this as json: {'a': 1}"): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert isinstance(full.content, str) + assert json.loads(full.content) == {"a": 1} diff --git a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py index 1f059e75..66bd8dbb 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py +++ b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py @@ -1,9 +1,10 @@ import enum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import pytest from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field +from pydantic import BaseModel as BaseModelProper from langchain_nvidia_ai_endpoints import ChatNVIDIA @@ -20,6 +21,26 @@ def do_stream(llm: ChatNVIDIA, message: str) -> Any: return result[-1] if result else None +class Joke(BaseModelProper): + """Joke to tell user.""" + + setup: str = Field(description="question to set up a joke") + punchline: str = Field(description="answer to resolve the joke") + + +class SelfEvaluation(BaseModelProper): + score: int + text: str + + +class JokeWithEvaluation(BaseModelProper): + """Joke to tell user.""" + + setup: str + punchline: str + self_evaluation: SelfEvaluation + + @pytest.mark.xfail(reason="Accuracy is not guaranteed") def test_accuracy(structured_model: str, mode: dict) -> None: class Person(BaseModel): @@ -66,14 +87,6 @@ class Person(BaseModel): assert person.birthplace == "Tainan, Taiwan" -class Joke(BaseModel): - """Joke to tell user.""" - - setup: str = Field(description="The setup of the joke") - punchline: str = Field(description="The 
punchline to the joke") -    rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") - - @pytest.mark.parametrize("func", [do_invoke, do_stream], ids=["invoke", "stream"]) def test_pydantic(structured_model: str, mode: dict, func: Callable) -> None: llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) @@ -184,3 +197,142 @@ def test_pydantic_incomplete(structured_model: str, mode: dict, func: Callable) structured_llm = llm.with_structured_output(Joke) result = func(structured_llm, "Tell me a joke about cats") assert result is None + + +def joke(result: Any) -> None: +    assert isinstance(result, dict) +    assert all(key in set(result.keys()) for key in {"setup", "punchline"}) + + +def nested_json(result: Any) -> None: +    assert isinstance(result, dict)  # for mypy +    assert set(result.keys()) == {"setup", "punchline", "self_evaluation"} +    assert set(result["self_evaluation"].keys()) == {"score", "text"} + + +@pytest.mark.parametrize( +    ("method", "strict"), +    [("json_schema", None), ("json_mode", None)], +) +def test_structured_output_json_strict( +    structured_model: str, +    mode: dict, +    method: Literal["json_mode", "json_schema"], +    strict: Optional[bool], +) -> None: +    """Test structured output with the json_schema and json_mode methods.""" + +    llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + +    # Test structured output with a Pydantic class +    chat = llm.with_structured_output(Joke, method=method, strict=strict) +    result = chat.invoke("Tell me a joke about cats.") + +    assert isinstance(result, Joke) + +    for chunk in chat.stream("Tell me a joke about cats."): +        assert isinstance(chunk, Joke) + +    # Test structured output with JSON schema +    chat = llm.with_structured_output( +        Joke.model_json_schema(), method=method, strict=strict +    ) +    result = chat.invoke("Tell me a joke about cats.") +    joke(result) + +    for chunk in chat.stream("Tell me a joke about cats."): +        assert isinstance(chunk, dict) +        joke(chunk) + + +@pytest.mark.parametrize( +    ("method", "strict"), [("json_schema", None), ("json_mode", None)] +) +def test_nested_structured_output_json_strict( +    structured_model: str, +    mode: dict, +    method: Literal["json_schema", "json_mode"], +    strict: Optional[bool], +) -> None: +    """Test structured output with a nested schema for both methods.""" + +    llm = ChatNVIDIA(model=structured_model, temperature=0, **mode) + +    # Schema +    chat = llm.with_structured_output( +        JokeWithEvaluation.model_json_schema(), method=method, strict=strict +    ) +    result = chat.invoke("Tell me a joke about cats.") +    nested_json(result) + +    for chunk in chat.stream("Tell me a joke about cats."): +        assert isinstance(chunk, dict) +        nested_json(chunk) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( +    ("method", "strict"), +    [("json_schema", None), ("json_mode", None)], +) +async def test_structured_output_json_strict_async( +    structured_model: str, +    method: Literal["json_schema", "json_mode"], +    strict: Optional[bool], +) -> None: +    """Test structured output with both methods (async).""" + +    llm = ChatNVIDIA(model=structured_model, temperature=0) + +    # Pydantic class +    chat = llm.with_structured_output(Joke, method=method, strict=strict) +    result = await chat.ainvoke("Tell me a joke about cats.") +    assert isinstance(result, Joke) + +    async for chunk in chat.astream("Tell me a joke about cats."): +        assert isinstance(chunk, Joke) + +    # Schema +    chat = llm.with_structured_output( +        Joke.model_json_schema(), method=method, strict=strict +    ) +    result = await chat.ainvoke("Tell me a 
joke about cats.") +    joke(result) + +    async for chunk in chat.astream("Tell me a joke about cats."): +        assert isinstance(chunk, dict) +        joke(chunk) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( +    ("method", "strict"), [("json_schema", None), ("json_mode", None)] +) +async def test_nested_structured_output_json_strict_async( +    structured_model: str, method: Literal["json_schema", "json_mode"], strict: Optional[bool] +) -> None: +    """Test structured output with a nested schema for both methods (async).""" + +    llm = ChatNVIDIA(model=structured_model, temperature=0) + +    # Schema +    chat = llm.with_structured_output( +        JokeWithEvaluation.model_json_schema(), method=method, strict=strict +    ) +    result = await chat.ainvoke("Tell me a joke about cats.") +    nested_json(result) + +    async for chunk in chat.astream("Tell me a joke about cats."): +        assert isinstance(chunk, dict) +        nested_json(chunk) + + +def test_json_mode_with_dict(structured_model: str) -> None: +    """Test json_mode with a dictionary schema.""" +    schema = { +        "type": "object", +        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, +    } + +    llm = ChatNVIDIA(model=structured_model) +    llm.with_structured_output(schema, method="json_mode") diff --git a/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py b/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py index a2a26436..920e32e0 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py +++ b/libs/ai-endpoints/tests/unit_tests/test_bind_tools.py @@ -342,3 +342,39 @@ def test_strict_no_warns(strict: Optional[bool]) -> None: tools=[xxyyzz_tool_annotated], **({"strict": strict} if strict is not None else {}), ) + + +def test_json_mode( +    requests_mock: requests_mock.Mocker, +    mock_v1_models: None, +) -> None: +    requests_mock.post( +        "https://integrate.api.nvidia.com/v1/chat/completions", +        json={ +            "id": "chatcmpl-ID", +            "object": "chat.completion", +            "created": 1234567890, +            "model": "BOGUS", +            "choices": [ +                { +                    "index": 0, +                    "message": { +                        "role": "assistant", +                        "content": '{"a": 1}', +                    }, +                    "logprobs": None, +                } +            ], +            "usage": { +                "prompt_tokens": 22, +                "completion_tokens": 20, +                "total_tokens": 42, +            }, +            "system_fingerprint": None, +        }, +    ) + +    llm = ChatNVIDIA(api_key="BOGUS").bind(response_format={"type": "json_object"}) +    response = llm.invoke("Return this as json: {'a': 1}") +    assert isinstance(response, AIMessage) +    assert json.loads(str(response.content)) == {"a": 1} diff --git a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py index c8dfd00b..744a6f82 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py +++ b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py @@ -19,19 +19,6 @@ class Joke(pydanticV2BaseModel): rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") -def test_method() -> None: -    with pytest.warns(UserWarning) as record: -        with warnings.catch_warnings(): -            warnings.filterwarnings( -                "ignore", -                category=UserWarning, -                message=".*not known to support structured output.*", -            ) -            ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, method="json_mode") -    assert len(record) == 1 -    assert "unnecessary" in str(record[0].message) - - def test_include_raw() -> None: with pytest.raises(NotImplementedError): ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, include_raw=True) @@ -226,7 +213,5 @@ def test_strict_no_warns(strict: Optional[bool]) -> None: "ignore", category=UserWarning, message=".*not known 
to support.*" ) - ChatNVIDIA(api_key="BOGUS").with_structured_output( - Joke, - **({"strict": strict} if strict is not None else {}), - ) + if strict: + ChatNVIDIA(api_key="BOGUS").with_structured_output(Joke, strict=strict)